/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.Math;
import smile.math.special.Gamma;
import smile.stat.distribution.AbstractDistribution;
import smile.stat.distribution.ExponentialFamily;
import smile.stat.distribution.Mixture;

/**
 * Gamma distribution with shape parameter {@code k} and scale parameter
 * {@code theta}, both strictly positive. The density is
 * <pre>
 *     f(x; k, theta) = x^(k-1) exp(-x / theta) / (theta^k Gamma(k)),   x &gt;= 0
 * </pre>
 * and zero for negative {@code x}.
 */
public class GammaDistribution
extends AbstractDistribution
implements ExponentialFamily {
    /** Scale parameter theta (&gt; 0). */
    private double theta;
    /** Shape parameter k (&gt; 0). */
    private double k;
    /** Cached log(theta), used by {@link #logp(double)}. */
    private double logTheta;
    /** Cached log(Gamma(k)); kept in log space so large shapes do not overflow. */
    private double logGammaK;
    /** Cached differential entropy. */
    private double entropy;

    /**
     * Creates a gamma distribution with the given parameters.
     *
     * @param shape the shape parameter k; must be &gt; 0
     * @param scale the scale parameter theta; must be &gt; 0
     * @throws IllegalArgumentException if either parameter is not strictly
     *         positive (including NaN)
     */
    public GammaDistribution(double shape, double scale) {
        // Written as !(x > 0) rather than x <= 0 so that NaN is rejected too.
        if (!(shape > 0.0)) {
            throw new IllegalArgumentException("Invalid shape: " + shape);
        }
        if (!(scale > 0.0)) {
            throw new IllegalArgumentException("Invalid scale: " + scale);
        }
        setParameters(shape, scale);
    }

    /**
     * Fits a gamma distribution to the given samples. With
     * {@code s = log(mean(x)) - mean(log(x))}, the shape is estimated by the
     * closed-form approximation
     * {@code k = (3 - s + sqrt((s - 3)^2 + 24 s)) / (12 s)} to the maximum
     * likelihood estimate. The scale is then {@code mean(x) / k}.
     *
     * @param data the samples; all must be strictly positive
     * @throws IllegalArgumentException if {@code data} is empty, contains a
     *         non-positive (or NaN) value, or has no dispersion (e.g. all
     *         samples equal), in which case the shape cannot be estimated
     */
    public GammaDistribution(double[] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        double sum = 0.0;
        double sumLog = 0.0;
        for (double x : data) {
            if (!(x > 0.0)) {
                throw new IllegalArgumentException("Samples contain non-positive values.");
            }
            sum += x;
            sumLog += Math.log(x);
        }
        double mu = sum / (double) data.length;
        double s = Math.log(mu) - sumLog / (double) data.length;
        // s >= 0 by Jensen's inequality. s == 0 (constant data) would give an
        // infinite shape and a zero scale, which is not a valid distribution.
        if (!(s > 0.0)) {
            throw new IllegalArgumentException("Samples have no dispersion; cannot estimate shape.");
        }
        double shape = (3.0 - s + Math.sqrt(Math.sqr(s - 3.0) + 24.0 * s)) / (12.0 * s);
        setParameters(shape, mu / shape);
    }

    /**
     * Stores the parameters and recomputes every cached derived quantity.
     * Callers must already have validated that both parameters are positive.
     */
    private void setParameters(double shape, double scale) {
        this.k = shape;
        this.theta = scale;
        this.logTheta = Math.log(scale);
        this.logGammaK = Gamma.logGamma(shape);
        this.entropy = shape + this.logTheta + this.logGammaK + (1.0 - shape) * Gamma.digamma(shape);
    }

    /** Returns the scale parameter theta. */
    public double getScale() {
        return this.theta;
    }

    /** Returns the shape parameter k. */
    public double getShape() {
        return this.k;
    }

    /** Returns the number of parameters (shape and scale). */
    @Override
    public int npara() {
        return 2;
    }

    /** Returns the mean {@code k * theta}. */
    @Override
    public double mean() {
        return this.k * this.theta;
    }

    /** Returns the variance {@code k * theta^2}. */
    @Override
    public double var() {
        return this.k * this.theta * this.theta;
    }

    /** Returns the standard deviation {@code sqrt(k) * theta}. */
    @Override
    public double sd() {
        return Math.sqrt(this.k) * this.theta;
    }

    /**
     * Returns the differential entropy
     * {@code k + log(theta) + log(Gamma(k)) + (1 - k) * digamma(k)}.
     */
    @Override
    public double entropy() {
        return this.entropy;
    }

    /**
     * Returns a description of the distribution. Note that the scale is printed
     * first, followed by the shape (the reverse of the constructor's order).
     */
    @Override
    public String toString() {
        return String.format("Gamma Distribution(%.4f, %.4f)", this.theta, this.k);
    }

    /**
     * Draws a sample as the sum of {@code k} independent exponential variates
     * with scale theta. The cost is linear in {@code k}.
     *
     * @throws IllegalArgumentException if the shape is not an integer
     */
    @Override
    public double rand() {
        if (this.k - Math.floor(this.k) != 0.0) {
            throw new IllegalArgumentException("Gamma random number generator support only integer shape parameter.");
        }
        double sumLogU = 0.0;
        for (int i = 0; i < this.k; ++i) {
            // 1 - u keeps the argument away from 0 when the generator returns
            // 0.0, which would otherwise yield log(0) = -Infinity.
            sumLogU += Math.log(1.0 - Math.random());
        }
        return -this.theta * sumLogU;
    }

    /**
     * Returns the density at {@code x}: zero for negative {@code x}, otherwise
     * {@code exp(logp(x))}. Evaluating in log space avoids the overflow of
     * {@code Gamma(k)} and {@code x^(k-1)} for large shapes.
     */
    @Override
    public double p(double x) {
        if (x < 0.0) {
            return 0.0;
        }
        return Math.exp(logp(x));
    }

    /**
     * Returns the log density at {@code x}. The result is -Infinity for
     * negative or infinite {@code x}. At {@code x == 0} it is +Infinity when
     * k &lt; 1, {@code -log(theta)} when k == 1, and -Infinity when k &gt; 1.
     */
    @Override
    public double logp(double x) {
        if (x < 0.0 || x == Double.POSITIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        // When k == 1 the (k - 1) * log(x) term is identically zero. Evaluating
        // it at x == 0 would give 0 * -Infinity = NaN.
        double powerTerm = this.k == 1.0 ? 0.0 : (this.k - 1.0) * Math.log(x);
        return powerTerm - x / this.theta - this.k * this.logTheta - this.logGammaK;
    }

    /**
     * Returns the cumulative probability P(X &lt;= x): zero for negative
     * {@code x}, otherwise the regularized lower incomplete gamma function
     * evaluated at {@code (k, x / theta)}.
     */
    @Override
    public double cdf(double x) {
        if (x < 0.0) {
            return 0.0;
        }
        return Gamma.regularizedIncompleteGamma(this.k, x / this.theta);
    }

    /**
     * Returns the inverse of {@link #cdf(double)} at probability {@code p}.
     *
     * @throws IllegalArgumentException if {@code p} is outside [0, 1]
     */
    @Override
    public double quantile(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        return Gamma.inverseRegularizedIncompleteGamma(this.k, p) * this.theta;
    }

    /**
     * Mixture M-step. The component weight is the sum of posteriors. The new
     * gamma component matches the posterior-weighted mean m and variance v
     * (method of moments): shape = m^2 / v and scale = v / m.
     *
     * @param x          the samples
     * @param posteriori the posterior weight of each sample; assumed to have
     *                   the same length as {@code x}
     */
    @Override
    public Mixture.Component M(double[] x, double[] posteriori) {
        double alpha = 0.0;
        double weightedSum = 0.0;
        for (int i = 0; i < x.length; ++i) {
            alpha += posteriori[i];
            weightedSum += x[i] * posteriori[i];
        }
        double mean = weightedSum / alpha;

        double weightedSqDev = 0.0;
        for (int i = 0; i < x.length; ++i) {
            double d = x[i] - mean;
            weightedSqDev += d * d * posteriori[i];
        }
        double variance = weightedSqDev / alpha;

        Mixture.Component c = new Mixture.Component();
        c.priori = alpha;
        c.distribution = new GammaDistribution(mean * mean / variance, variance / mean);
        return c;
    }
}

