/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.rng.distribution.impl;

import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistributionEx;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.BaseDistribution;
import org.nd4j.linalg.api.rng.distribution.impl.SaddlePointExpansion;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Binomial distribution with a fixed number of trials.
 *
 * <p>The success probability is either a single scalar (see {@link #getProbabilityOfSuccess()})
 * or, when built with {@link #BinomialDistribution(int, INDArray)}, an array of probabilities
 * that is used only by {@link #sample(INDArray)} / {@link #sample(int[])}. The scalar query
 * methods ({@link #probability(int)}, the {@code cumulativeProbability} overloads, mean,
 * variance and support bounds) always read the scalar probability, which is {@code 0.0} for
 * instances built with the array-based constructor.
 */
public class BinomialDistribution
extends BaseDistribution {
    private final int numberOfTrials;
    private final double probabilityOfSuccess;
    /** Per-element success probabilities, or {@code null} when the scalar probability is used. */
    private final INDArray p;

    /**
     * Creates a binomial distribution that samples with the global Nd4j random generator.
     *
     * @param trials number of trials; must be non-negative
     * @param p probability of success of each trial; must lie in [0, 1]
     * @throws NotPositiveException if {@code trials < 0}
     * @throws OutOfRangeException if {@code p} is outside [0, 1] or is NaN
     */
    public BinomialDistribution(int trials, double p) {
        this(Nd4j.getRandom(), trials, p);
    }

    /**
     * Creates a binomial distribution that samples with the given random generator.
     *
     * @param rng random generator used by {@link #sample(INDArray)}
     * @param trials number of trials; must be non-negative
     * @param p probability of success of each trial; must lie in [0, 1]
     * @throws NotPositiveException if {@code trials < 0}
     * @throws OutOfRangeException if {@code p} is outside [0, 1] or is NaN
     */
    public BinomialDistribution(Random rng, int trials, double p) {
        super(rng);
        if (trials < 0) {
            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_TRIALS, trials);
        }
        // Written as a negated range check so that NaN is rejected too.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new OutOfRangeException(p, 0, 1);
        }
        this.numberOfTrials = trials;
        this.probabilityOfSuccess = p;
        this.p = null;
    }

    /**
     * Creates a binomial distribution whose samples use a per-element success probability
     * taken from {@code p}. The probabilities themselves are not validated here.
     *
     * <p>NOTE(review): the scalar probability of this instance is {@code 0.0}, so
     * {@link #probability(int)}, the CDF methods, mean and variance do not reflect {@code p}.
     *
     * @param n number of trials; must be non-negative
     * @param p success probabilities; presumably the same shape as the arrays later passed to
     *          {@link #sample(INDArray)} (elements are read with the output's indices) — confirm
     * @throws NotPositiveException if {@code n < 0}
     */
    public BinomialDistribution(int n, INDArray p) {
        this.random = Nd4j.getRandom();
        if (n < 0) {
            throw new NotPositiveException(LocalizedFormats.NUMBER_OF_TRIALS, n);
        }
        this.numberOfTrials = n;
        this.probabilityOfSuccess = 0.0;
        this.p = p;
    }

    /** Returns the number of trials. */
    public int getNumberOfTrials() {
        return this.numberOfTrials;
    }

    /** Returns the scalar success probability ({@code 0.0} for the array-based constructor). */
    public double getProbabilityOfSuccess() {
        return this.probabilityOfSuccess;
    }

    /**
     * Returns the probability mass P(X = x).
     *
     * @param x number of successes
     * @return P(X = x), or 0 when {@code x} is outside {@code [0, numberOfTrials]}
     */
    public double probability(int x) {
        if (x < 0 || x > this.numberOfTrials) {
            return 0.0;
        }
        return FastMath.exp(SaddlePointExpansion.logBinomialProbability(
                x, this.numberOfTrials, this.probabilityOfSuccess, 1.0 - this.probabilityOfSuccess));
    }

    /**
     * Returns the cumulative probability P(X &lt;= x).
     *
     * @param x number of successes
     * @return 0 for {@code x < 0}, 1 for {@code x >= numberOfTrials}, otherwise P(X &lt;= x)
     */
    public double cumulativeProbability(int x) {
        if (x < 0) {
            return 0.0;
        }
        if (x >= this.numberOfTrials) {
            return 1.0;
        }
        // P(X <= x) = 1 - I_p(x + 1, n - x), with I the regularized incomplete beta function.
        return 1.0 - Beta.regularizedBeta(this.probabilityOfSuccess, x + 1.0, this.numberOfTrials - x);
    }

    /**
     * Always returns 0: this is a discrete distribution, so it has no density. Use
     * {@link #probability(int)} for the probability mass.
     */
    @Override
    public double density(double x) {
        return 0.0;
    }

    /**
     * Returns the cumulative probability P(X &lt;= x) for a real-valued argument.
     *
     * <p>The binomial CDF is a step function, so the result for a non-integer {@code x} equals
     * the result for {@code floor(x)}.
     *
     * @param x point at which the CDF is evaluated
     * @return P(X &lt;= x), or NaN if {@code x} is NaN
     */
    @Override
    public double cumulativeProbability(double x) {
        if (Double.isNaN(x)) {
            return Double.NaN;
        }
        if (x < 0.0) {
            return 0.0;
        }
        if (x >= this.numberOfTrials) {
            return 1.0;
        }
        // Here 0 <= floor(x) < numberOfTrials, so the int conversion cannot overflow.
        return this.cumulativeProbability((int) FastMath.floor(x));
    }

    /**
     * Returns P(x0 &lt; X &lt;= x1).
     *
     * @param x0 exclusive lower bound
     * @param x1 inclusive upper bound
     * @return {@code cumulativeProbability(x1) - cumulativeProbability(x0)}
     * @throws NumberIsTooLargeException if {@code x0 > x1}
     */
    @Override
    public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
        if (x0 > x1) {
            throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true);
        }
        return this.cumulativeProbability(x1) - this.cumulativeProbability(x0);
    }

    /** Returns the mean {@code n * p} computed from the scalar probability. */
    @Override
    public double getNumericalMean() {
        return (double) this.numberOfTrials * this.probabilityOfSuccess;
    }

    /** Returns the variance {@code n * p * (1 - p)} computed from the scalar probability. */
    @Override
    public double getNumericalVariance() {
        double prob = this.probabilityOfSuccess;
        return (double) this.numberOfTrials * prob * (1.0 - prob);
    }

    /** Returns 0, or {@code numberOfTrials} when the scalar probability is 1. */
    @Override
    public double getSupportLowerBound() {
        return this.probabilityOfSuccess < 1.0 ? 0.0 : (double) this.numberOfTrials;
    }

    /** Returns {@code numberOfTrials}, or 0 when the scalar probability is 0. */
    @Override
    public double getSupportUpperBound() {
        return this.probabilityOfSuccess > 0.0 ? (double) this.numberOfTrials : 0.0;
    }

    /**
     * Returns false.
     *
     * <p>NOTE(review): the bound values are attainable outcomes of a binomial; false may be
     * intended because {@link #density(double)} is identically 0 — confirm.
     */
    @Override
    public boolean isSupportLowerBoundInclusive() {
        return false;
    }

    /**
     * Returns false.
     *
     * <p>NOTE(review): see {@link #isSupportLowerBoundInclusive()}.
     */
    @Override
    public boolean isSupportUpperBoundInclusive() {
        return false;
    }

    /** Returns true. */
    @Override
    public boolean isSupportConnected() {
        return true;
    }

    /**
     * Allocates an uninitialized array of the given shape in the default Nd4j order and fills
     * it with samples.
     *
     * @param shape shape of the result
     * @return the sampled array
     */
    @Override
    public INDArray sample(int[] shape) {
        INDArray ret = Nd4j.createUninitialized(shape, Nd4j.order().charValue());
        return this.sample(ret);
    }

    /**
     * Fills {@code ret} with binomial samples.
     *
     * <p>When this distribution's generator exposes a native state pointer, sampling is
     * delegated to a {@link BinomialDistributionEx} op executed with that generator. Otherwise,
     * elements are drawn one at a time with the commons-math sampler, again using this
     * distribution's generator. In both cases the per-element probabilities are used when
     * present; otherwise the scalar probability is used.
     *
     * @param ret array to fill
     * @return the filled array (the op result on the native path, {@code ret} otherwise)
     */
    @Override
    public INDArray sample(INDArray ret) {
        if (this.random.getStatePointer() != null) {
            if (this.p != null) {
                return Nd4j.getExecutioner().exec(new BinomialDistributionEx(ret, this.numberOfTrials, this.p), this.random);
            }
            return Nd4j.getExecutioner().exec(new BinomialDistributionEx(ret, this.numberOfTrials, this.probabilityOfSuccess), this.random);
        }
        // Use the generator this distribution was built with (as the native path does) rather
        // than whatever Nd4j.getRandom() returns at call time.
        RandomGenerator generator = (RandomGenerator) this.random;
        NdIndexIterator idxIter = new NdIndexIterator(ret.shape());
        long len = ret.length();
        if (this.p != null) {
            // Each element has its own probability, so a fresh sampler is needed per element.
            for (long i = 0; i < len; i++) {
                long[] idx = (long[]) idxIter.next();
                org.apache.commons.math3.distribution.BinomialDistribution binomialDistribution = new org.apache.commons.math3.distribution.BinomialDistribution(generator, this.numberOfTrials, this.p.getDouble(idx));
                ret.putScalar(idx, binomialDistribution.sample());
            }
        } else {
            org.apache.commons.math3.distribution.BinomialDistribution binomialDistribution = new org.apache.commons.math3.distribution.BinomialDistribution(generator, this.numberOfTrials, this.probabilityOfSuccess);
            // long counter: an int would overflow for arrays longer than Integer.MAX_VALUE.
            for (long i = 0; i < len; i++) {
                ret.putScalar((long[]) idxIter.next(), binomialDistribution.sample());
            }
        }
        return ret;
    }
}

