/*
 * Decompiled with CFR 0.152.
 */
package hivemall.utils.math;

import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;
import hivemall.utils.math.MatrixUtils;
import java.util.AbstractMap;
import java.util.Map;
import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;

/**
 * Static helpers for the probit function, Gaussian densities and their negative
 * log-likelihood, a Hellinger-style distance between Gaussians, and chi-square
 * statistics.
 *
 * <p>Throughout this class the {@code sigma} arguments are used as variances (or
 * covariance matrices), not standard deviations: the univariate density divides
 * the squared difference by {@code sigma} and normalizes by {@code sqrt(sigma)}.
 */
public final class StatsUtils {

    /** Not instantiable. */
    private StatsUtils() {
    }

    /**
     * Computes {@code sqrt(2) * inverseErf(2p - 1)}, i.e. the probit (inverse of the
     * standard normal CDF) of {@code p}.
     *
     * <p>Note that a NaN argument is not rejected by the range check below (both
     * comparisons are false for NaN) and is handed to {@code MathUtils.inverseErf}.
     *
     * @param p probability in {@code [0, 1]}
     * @return the probit of {@code p}
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    public static double probit(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p must be in [0,1]");
        }
        return Math.sqrt(2.0) * MathUtils.inverseErf(2.0 * p - 1.0);
    }

    /**
     * Probit of {@code p} clamped to {@code [-range, range]}.
     *
     * <p>{@code p == 0} and {@code p == 1} short-circuit to {@code -range} and
     * {@code range} respectively without evaluating the inverse error function.
     *
     * @param p probability in {@code [0, 1]}
     * @param range positive bound on the magnitude of the result
     * @return the clamped probit value
     * @throws IllegalArgumentException if {@code range <= 0} or {@code p} is outside
     *         {@code [0, 1]}
     */
    public static double probit(double p, double range) {
        if (range <= 0.0) {
            throw new IllegalArgumentException("range must be > 0: " + range);
        }
        if (p == 0.0) {
            return -range;
        }
        if (p == 1.0) {
            return range;
        }
        final double v = probit(p);
        return (v < 0.0) ? Math.max(v, -range) : Math.min(v, range);
    }

    /**
     * Density of a univariate normal distribution at {@code x}.
     *
     * @param x point at which the density is evaluated
     * @param x_hat mean of the distribution
     * @param sigma variance of the distribution
     * @return the density, or {@code 0.0} when {@code sigma == 0}
     */
    public static double pdf(double x, double x_hat, double sigma) {
        if (sigma == 0.0) {
            // Degenerate distribution: report 0 instead of dividing by zero.
            return 0.0;
        }
        final double diff = x - x_hat;
        final double numerator = Math.exp(-0.5 * diff * diff / sigma);
        final double denominator = Math.sqrt(Math.PI * 2) * Math.sqrt(sigma);
        return numerator / denominator;
    }

    /**
     * Density of a multivariate normal distribution at {@code x}.
     *
     * <p>The arguments are validated through {@code Preconditions.checkArgument}:
     * {@code x} and {@code x_hat} must have the same dimension and {@code sigma}
     * must be a square matrix with that many rows.
     *
     * @param x point at which the density is evaluated
     * @param x_hat mean vector
     * @param sigma covariance matrix
     * @return the density, or {@code 0.0} when the normalizing constant is exactly 0
     */
    public static double pdf(@Nonnull RealVector x, @Nonnull RealVector x_hat, @Nonnull RealMatrix sigma) {
        final int dim = x.getDimension();
        Preconditions.checkArgument(x_hat.getDimension() == dim, "|x| != |x_hat|, |x|=" + dim + ", |x_hat|=" + x_hat.getDimension());
        Preconditions.checkArgument(sigma.getRowDimension() == dim, "|x| != |sigma|, |x|=" + dim + ", |sigma|=" + sigma.getRowDimension());
        Preconditions.checkArgument(sigma.isSquare(), "Sigma is not square matrix");

        final LUDecomposition lu = new LUDecomposition(sigma);
        final double detSigma = lu.getDeterminant();
        final double denominator = Math.pow(Math.PI * 2, 0.5 * (double) dim) * Math.pow(detSigma, 0.5);
        if (denominator == 0.0) {
            // Avoid dividing by zero.
            return 0.0;
        }

        final RealMatrix invSigma;
        final DecompositionSolver solver = lu.getSolver();
        if (solver.isNonSingular()) {
            invSigma = solver.getInverse();
        } else {
            // The LU solver cannot invert a singular matrix; invert through an SVD instead.
            invSigma = new SingularValueDecomposition(sigma).getSolver().getInverse();
        }

        // Quadratic form (x - x_hat)^T * invSigma * (x - x_hat).
        final RealVector diff = x.subtract(x_hat);
        final double quadraticForm = invSigma.preMultiply(diff).dotProduct(diff);
        final double numerator = Math.exp(-0.5 * quadraticForm);
        return numerator / denominator;
    }

    /**
     * Negative log of the univariate normal density of {@code actual}.
     *
     * @param actual observed value
     * @param predicted mean of the distribution
     * @param sigma variance of the distribution
     * @return {@code -log(pdf(actual, predicted, sigma))}, or {@code 0.0} when the
     *         density is exactly 0 (so the result is never infinite because of log(0))
     */
    public static double logLoss(double actual, double predicted, double sigma) {
        final double p = pdf(actual, predicted, sigma);
        return (p == 0.0) ? 0.0 : -Math.log(p);
    }

    /**
     * Negative log of the multivariate normal density of {@code actual}.
     *
     * @param actual observed vector
     * @param predicted mean vector
     * @param sigma covariance matrix
     * @return {@code -log(pdf(actual, predicted, sigma))}, or {@code 0.0} when the
     *         density is exactly 0
     */
    public static double logLoss(@Nonnull RealVector actual, @Nonnull RealVector predicted, @Nonnull RealMatrix sigma) {
        final double p = pdf(actual, predicted, sigma);
        return (p == 0.0) ? 0.0 : -Math.log(p);
    }

    /**
     * Distance between two univariate normal distributions, computed as
     * {@code 1 - (sigma1^(1/4) * sigma2^(1/4) * exp(-(mu1 - mu2)^2 / (4 * (sigma1 + sigma2))))
     * / sqrt((sigma1 + sigma2) / 2)}.
     *
     * <p>NOTE(review): with {@code sigma1}/{@code sigma2} as variances this matches the
     * closed form of the <em>squared</em> Hellinger distance; confirm callers expect
     * the squared form.
     *
     * @param mu1 mean of the first distribution
     * @param sigma1 variance of the first distribution
     * @param mu2 mean of the second distribution
     * @param sigma2 variance of the second distribution
     * @return the distance; {@code 0.0} when {@code sigma1 + sigma2 == 0}, and
     *         {@code 1.0} when the denominator evaluates to exactly 0
     */
    public static double hellingerDistance(double mu1, double sigma1, double mu2, double sigma2) {
        final double sigmaSum = sigma1 + sigma2;
        if (sigmaSum == 0.0) {
            return 0.0;
        }
        final double numerator = Math.pow(sigma1, 0.25) * Math.pow(sigma2, 0.25) * Math.exp(-0.25 * Math.pow(mu1 - mu2, 2.0) / sigmaSum);
        final double denominator = Math.sqrt(sigmaSum / 2.0);
        if (denominator == 0.0) {
            // Avoid dividing by zero.
            return 1.0;
        }
        return 1.0 - numerator / denominator;
    }

    /**
     * Multivariate counterpart of {@link #hellingerDistance(double, double, double, double)}:
     * {@code 1 - (det(sigma1)^(1/4) * det(sigma2)^(1/4) * exp(-d^T * S^-1 * d / 8)) / sqrt(det(S))}
     * where {@code d = mu1 - mu2} and {@code S = (sigma1 + sigma2) / 2}.
     *
     * @param mu1 mean vector of the first distribution
     * @param sigma1 covariance matrix of the first distribution
     * @param mu2 mean vector of the second distribution
     * @param sigma2 covariance matrix of the second distribution
     * @return the distance, or {@code 1.0} when {@code sqrt(det(S))} is exactly 0
     */
    public static double hellingerDistance(@Nonnull RealVector mu1, @Nonnull RealMatrix sigma1, @Nonnull RealVector mu2, @Nonnull RealMatrix sigma2) {
        final RealVector muSub = mu1.subtract(mu2);
        final RealMatrix sigmaMean = sigma1.add(sigma2).scalarMultiply(0.5);
        final LUDecomposition sigmaMeanLu = new LUDecomposition(sigmaMean);
        final double denominator = Math.sqrt(sigmaMeanLu.getDeterminant());
        if (denominator == 0.0) {
            // Avoid dividing by zero (and inverting a matrix with zero determinant).
            return 1.0;
        }
        final RealMatrix sigmaMeanInv = sigmaMeanLu.getSolver().getInverse();
        final double sigma1Det = MatrixUtils.det(sigma1);
        final double sigma2Det = MatrixUtils.det(sigma2);
        final double numerator = Math.pow(sigma1Det, 0.25) * Math.pow(sigma2Det, 0.25) * Math.exp(-0.125 * sigmaMeanInv.preMultiply(muSub).dotProduct(muSub));
        return 1.0 - numerator / denominator;
    }

    /**
     * Chi-square statistic {@code sum((observed[i] - e[i])^2 / e[i])} of observed
     * against expected counts.
     *
     * <p>When the totals of the two arrays differ by more than {@code 1e-5}, the
     * expected counts are rescaled by {@code sum(observed) / sum(expected)} before the
     * statistic is computed, i.e. {@code e[i] = ratio * expected[i]}.
     *
     * @param observed observed counts; no element may be negative
     * @param expected expected counts; checked with {@code MathArrays.checkPositive}
     * @return the chi-square statistic
     * @throws DimensionMismatchException if {@code observed} has fewer than 2 elements
     *         or the two arrays differ in length
     * @throws NotPositiveException if an element of {@code observed} is negative
     */
    public static double chiSquare(@Nonnull double[] observed, @Nonnull double[] expected) {
        final int n = observed.length;
        if (n < 2) {
            throw new DimensionMismatchException(n, 2);
        }
        if (expected.length != n) {
            throw new DimensionMismatchException(n, expected.length);
        }
        MathArrays.checkPositive(expected);
        for (double o : observed) {
            if (o < 0.0) {
                throw new NotPositiveException(o);
            }
        }

        double sumObserved = 0.0;
        double sumExpected = 0.0;
        for (int i = 0; i < n; ++i) {
            sumObserved += observed[i];
            sumExpected += expected[i];
        }

        // A ratio of exactly 1.0 leaves every expected count untouched (x * 1.0 == x),
        // so one loop covers both the rescaled and the non-rescaled case.
        final double ratio = (FastMath.abs(sumObserved - sumExpected) > 1.0E-5) ? sumObserved / sumExpected : 1.0;

        double sumSq = 0.0;
        for (int i = 0; i < n; ++i) {
            final double scaledExpected = ratio * expected[i];
            final double dev = observed[i] - scaledExpected;
            sumSq += dev * dev / scaledExpected;
        }
        return sumSq;
    }

    /**
     * P-value of the chi-square statistic of {@code observed} against {@code expected},
     * computed as {@code 1 - CDF(statistic)} for a chi-squared distribution with
     * {@code expected.length - 1} degrees of freedom.
     *
     * <p>The statistic is computed (and the input thereby validated) before the
     * distribution is created, so invalid input is reported by
     * {@link #chiSquare(double[], double[])}.
     *
     * @param observed observed counts
     * @param expected expected counts
     * @return the p-value
     * @throws DimensionMismatchException see {@link #chiSquare(double[], double[])}
     * @throws NotPositiveException see {@link #chiSquare(double[], double[])}
     */
    public static double chiSquareTest(@Nonnull double[] observed, @Nonnull double[] expected) {
        final double chi2 = chiSquare(observed, expected);
        final ChiSquaredDistribution distribution = new ChiSquaredDistribution((double) expected.length - 1.0);
        return 1.0 - distribution.cumulativeProbability(chi2);
    }

    /**
     * Row-wise chi-square statistics and p-values.
     *
     * <p>Row {@code i} of the result is computed from {@code observeds[i]} and
     * {@code expecteds[i]} with {@code expecteds[i].length - 1} degrees of freedom.
     * Empty input yields a pair of empty arrays.
     *
     * @param observeds observed counts, one row per test
     * @param expecteds expected counts, one row per test; must have as many rows as
     *        {@code observeds}
     * @return an entry whose key holds the chi-square statistics and whose value holds
     *         the corresponding p-values, both indexed by row
     * @throws DimensionMismatchException see {@link #chiSquare(double[], double[])}
     * @throws NotPositiveException see {@link #chiSquare(double[], double[])}
     */
    public static Map.Entry<double[], double[]> chiSquare(@Nonnull double[][] observeds, @Nonnull double[][] expecteds) {
        Preconditions.checkArgument(observeds.length == expecteds.length);

        final int len = expecteds.length;
        final double[] chi2s = new double[len];
        final double[] ps = new double[len];

        // The distribution only depends on the row length, so it is shared between
        // consecutive rows of equal length and rebuilt when the length changes.
        ChiSquaredDistribution distribution = null;
        int distributionRowLength = -1;
        for (int i = 0; i < len; ++i) {
            final double[] expected = expecteds[i];
            // Validates the row (length >= 2, matching lengths, sign constraints).
            chi2s[i] = chiSquare(observeds[i], expected);
            if (distribution == null || expected.length != distributionRowLength) {
                distribution = new ChiSquaredDistribution((double) expected.length - 1.0);
                distributionRowLength = expected.length;
            }
            ps[i] = 1.0 - distribution.cumulativeProbability(chi2s[i]);
        }
        return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps);
    }
}

