/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.MathEx;
import smile.stat.distribution.DiscreteDistribution;
import smile.stat.distribution.DiscreteExponentialFamily;
import smile.stat.distribution.DiscreteMixture;
import smile.stat.distribution.ExponentialDistribution;
import smile.stat.distribution.GeometricDistribution;

/**
 * Shifted geometric distribution: the number of Bernoulli trials up to and
 * including the first success. The support is {@code k = 1, 2, 3, ...} and the
 * probability mass function is {@code P(X = k) = (1 - p)^(k - 1) * p}, giving
 * {@code E[X] = 1 / p} and {@code Var[X] = (1 - p) / p^2}.
 */
public class ShiftedGeometricDistribution
extends DiscreteDistribution
implements DiscreteExponentialFamily {
    private static final long serialVersionUID = 2L;
    /** Probability of success on each trial, in (0, 1]. */
    public final double p;
    /** Entropy in bits (base-2 logarithm), precomputed by the constructor. */
    private final double entropy;
    /** Exponential distribution used by {@link #rand()}; created on first use. */
    private ExponentialDistribution expDist;

    /**
     * Constructor.
     *
     * @param p the probability of success, in (0, 1].
     * @throws IllegalArgumentException if {@code p} is not in (0, 1], including NaN.
     */
    public ShiftedGeometricDistribution(double p) {
        // Written as a negated range test so that NaN is rejected too;
        // "p <= 0 || p > 1" is false for NaN.
        if (!(p > 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        this.p = p;
        // At p == 1 the distribution is a point mass at 1 (entropy 0); the
        // general formula would evaluate 0 * log2(0) = NaN there.
        this.entropy = p == 1.0
                ? 0.0
                : (-p * MathEx.log2(p) - (1.0 - p) * MathEx.log2(1.0 - p)) / p;
    }

    /**
     * Estimates the distribution from data by maximum likelihood.
     * For support {@code {1, 2, ...}} the MLE is {@code p = n / sum(x)}, i.e. the
     * reciprocal of the sample mean, so the fitted {@link #mean()} equals the
     * sample mean.
     *
     * @param data the samples; every value must be at least 1.
     * @return the fitted distribution.
     * @throws IllegalArgumentException if {@code data} is empty or contains a
     *         value {@code <= 0} (which has zero probability under this model).
     */
    public static ShiftedGeometricDistribution fit(int[] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        double sum = 0.0;
        for (int x : data) {
            if (x <= 0) {
                throw new IllegalArgumentException("Invalid value " + x);
            }
            sum += x;
        }
        return new ShiftedGeometricDistribution(data.length / sum);
    }

    /**
     * Returns 1: the distribution has the single parameter {@code p}.
     */
    @Override
    public int length() {
        return 1;
    }

    /** Returns the mean {@code 1 / p}. */
    @Override
    public double mean() {
        return 1.0 / this.p;
    }

    /** Returns the variance {@code (1 - p) / p^2}. */
    @Override
    public double variance() {
        return (1.0 - this.p) / (this.p * this.p);
    }

    /** Returns the standard deviation {@code sqrt(1 - p) / p}. */
    @Override
    public double sd() {
        return Math.sqrt(1.0 - this.p) / this.p;
    }

    /** Returns the entropy in bits, computed once by the constructor. */
    @Override
    public double entropy() {
        return this.entropy;
    }

    @Override
    public String toString() {
        return String.format("Shifted Geometric Distribution(%.4f)", this.p);
    }

    /**
     * Draws a random sample from {1, 2, ...}.
     * If {@code E ~ Exp(-ln(1 - p))}, then {@code floor(E)} takes value {@code j}
     * with probability {@code (1 - p)^j p} for {@code j = 0, 1, ...}; adding 1
     * shifts that onto this distribution's support.
     *
     * @return a sample, always a whole number {@code >= 1}.
     */
    @Override
    public double rand() {
        if (this.p == 1.0) {
            // Point mass at 1; avoids building an exponential with infinite rate.
            return 1.0;
        }
        if (this.expDist == null) {
            // log1p(-p) is more accurate than log(1 - p) when p is small.
            this.expDist = new ExponentialDistribution(-Math.log1p(-this.p));
        }
        return Math.floor(this.expDist.rand()) + 1.0;
    }

    /**
     * Probability mass function.
     *
     * @param k the number of trials.
     * @return {@code (1 - p)^(k - 1) * p} for {@code k >= 1}; 0 otherwise.
     */
    @Override
    public double p(int k) {
        if (k <= 0) {
            return 0.0;
        }
        return Math.pow(1.0 - this.p, k - 1) * this.p;
    }

    /**
     * Log of the probability mass function.
     *
     * @param k the number of trials.
     * @return {@code (k - 1) ln(1 - p) + ln p} for {@code k >= 1};
     *         negative infinity otherwise.
     */
    @Override
    public double logp(int k) {
        if (k <= 0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (k == 1) {
            // Avoid 0 * ln(0) = NaN when p == 1.
            return Math.log(this.p);
        }
        return (double)(k - 1) * Math.log1p(-this.p) + Math.log(this.p);
    }

    /**
     * Cumulative distribution function {@code P(X <= k)}.
     * It is a step function: {@code 1 - (1 - p)^floor(k)} for {@code k >= 1},
     * and 0 for {@code k < 1}.
     *
     * @param k the evaluation point.
     * @return the probability that a sample is at most {@code k}.
     */
    @Override
    public double cdf(double k) {
        if (k < 1.0) {
            return 0.0;
        }
        // 1 - (1-p)^m computed as -expm1(m * log1p(-p)) to keep precision for small p.
        return -Math.expm1(Math.floor(k) * Math.log1p(-this.p));
    }

    /**
     * Inverse of the CDF.
     * It brackets the answer by stepping away from a starting point
     * {@code max(sqrt(1/p), 5)} with doubling step sizes. It then delegates to
     * the inherited {@code quantile(p, nl, nu)} with the bracket found.
     *
     * @param p the probability, in [0, 1].
     * @return the quantile as computed by the inherited bracketed search.
     * @throws IllegalArgumentException if {@code p} is outside [0, 1].
     */
    @Override
    public double quantile(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        int n = (int)Math.max(Math.sqrt(1.0 / this.p), 5.0);
        int inc = 1;
        int nl;
        int nu;
        if (p < this.cdf(n)) {
            // Step downwards (not below 0) until cdf(n) <= p.
            do {
                n = Math.max(n - inc, 0);
                inc *= 2;
            } while (p < this.cdf(n));
            nl = n;
            nu = n + inc / 2;
        } else {
            // Step upwards until cdf(n) >= p.
            // NOTE(review): for very small this.p with p near 1 the int n/inc may
            // overflow before cdf(n) reaches p — confirm whether a cap is needed.
            do {
                inc *= 2;
                n += inc;
            } while (p > this.cdf(n));
            nu = n;
            nl = n - inc / 2;
        }
        return this.quantile(p, nl, nu);
    }

    /**
     * M-step of EM for a mixture component.
     * The new parameter is the posteriori-weighted maximum-likelihood estimate
     * {@code p = 1 / mean}, where {@code mean = sum(w_i x_i) / sum(w_i)}.
     *
     * @param x the samples; presumably all {@code >= 1} — values below 1 can
     *          drive {@code 1 / mean} above 1 and make the constructor throw.
     * @param posteriori the posteriori probability of each sample for this component.
     * @return a component holding {@code sum(posteriori)} and the re-estimated
     *         shifted geometric distribution.
     */
    @Override
    public DiscreteMixture.Component M(int[] x, double[] posteriori) {
        double alpha = 0.0;
        double mean = 0.0;
        for (int i = 0; i < x.length; ++i) {
            alpha += posteriori[i];
            mean += (double)x[i] * posteriori[i];
        }
        mean /= alpha;
        // The component must stay in the shifted family; returning an
        // unshifted GeometricDistribution would change the model's support.
        return new DiscreteMixture.Component(alpha, new ShiftedGeometricDistribution(1.0 / mean));
    }
}

