/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.DoublePredicate;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.function.IntPredicate;
import java.util.function.IntToDoubleFunction;
import java.util.function.Supplier;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.EnumeratedDistribution;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.Pair;
import org.broadinstitute.hellbender.utils.DigammaCache;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.Log10Cache;
import org.broadinstitute.hellbender.utils.Log10FactorialCache;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

public final class MathUtils {
    // Large negative sentinel (-1e6); judging by the name it stands in for log10(0) — TODO confirm against callers.
    public static final double LOG10_P_OF_ZERO = -1000000.0;
    // log10(1/2).
    public static final double LOG10_ONE_HALF = Math.log10(0.5);
    // log10(1/3), computed as -log10(3).
    public static final double LOG10_ONE_THIRD = -Math.log10(3.0);
    // ln(1/3), computed as -ln(3).
    public static final double LOG_ONE_THIRD = -Math.log(3.0);
    // 1 / ln(2).
    public static final double INV_LOG_2 = 1.0 / Math.log(2.0);
    // ln(10): multiply a log10 value by this to obtain a natural log.
    private static final double LOG_10 = Math.log(10.0);
    // 1 / ln(10): multiply a natural log by this to obtain a log10 value.
    private static final double INV_LOG_10 = 1.0 / LOG_10;
    // log10(e): factor used by logToLog10.
    public static final double LOG10_E = Math.log10(Math.E);
    // sqrt(2 * pi): normalising constant of the normal density.
    private static final double ROOT_TWO_PI = Math.sqrt(Math.PI * 2);
    // Lookup objects backing log10(int), log10Factorial(int) and digamma(int) respectively.
    private static final Log10Cache LOG_10_CACHE = new Log10Cache();
    private static final Log10FactorialCache LOG_10_FACTORIAL_CACHE = new Log10FactorialCache();
    private static final DigammaCache DIGAMMA_CACHE = new DigammaCache();

    /** Private constructor: this class only exposes static members and is not meant to be instantiated. */
    private MathUtils() {
    }

    /**
     * Draws one element of {@code choices} at random, weighting each element by the value that
     * {@code probabilityFunction} returns for it.
     *
     * @param choices             candidate elements; checked with {@code Utils.nonNull}
     * @param probabilityFunction maps each candidate to its weight; checked with {@code Utils.nonNull}
     * @param rng                 random source handed to the {@link EnumeratedDistribution}; checked with {@code Utils.nonNull}
     * @param <E>                 element type
     * @return the sampled element
     */
    public static <E> E randomSelect(List<E> choices, Function<E, Double> probabilityFunction, RandomGenerator rng) {
        Utils.nonNull(choices);
        Utils.nonNull(probabilityFunction);
        Utils.nonNull(rng);
        // Fully parameterised types: no raw Pair/EnumeratedDistribution and no unchecked cast on the result.
        List<Pair<E, Double>> pmf = choices.stream()
                .map(e -> new Pair<E, Double>(e, probabilityFunction.apply(e)))
                .collect(Collectors.toList());
        return new EnumeratedDistribution<E>(rng, pmf).sample();
    }

    /**
     * Returns the gap between the second-smallest and the smallest entry of {@code values}.
     * Duplicates count separately, so an array whose minimum occurs twice yields 0.
     *
     * @param values       input array; checked with {@code Utils.nonNull}
     * @param defaultValue returned when the array has fewer than two entries
     * @return second smallest minus smallest, or {@code defaultValue}
     */
    public static int secondSmallestMinusSmallest(int[] values, int defaultValue) {
        Utils.nonNull(values);
        if (values.length <= 1) {
            return defaultValue;
        }
        int min = values[0];
        int runnerUp = Integer.MAX_VALUE;
        for (int idx = 1; idx < values.length; idx++) {
            int current = values[idx];
            if (current < min) {
                // The old minimum is demoted to runner-up.
                runnerUp = min;
                min = current;
            } else if (current < runnerUp) {
                runnerUp = current;
            }
        }
        return runnerUp - min;
    }

    /**
     * Returns a copy of {@code PLs} shifted so that its smallest entry becomes 0.
     * The input array is not modified.
     *
     * @param PLs values to shift
     * @return new array holding {@code PLs[i] - min(PLs)}
     */
    public static int[] normalizePLs(int[] PLs) {
        int[] shifted = new int[PLs.length];
        int minimum = MathUtils.arrayMin(PLs);
        int idx = 0;
        while (idx < PLs.length) {
            shifted[idx] = PLs[idx] - minimum;
            idx++;
        }
        return shifted;
    }

    /**
     * Element-by-element sum of two int arrays; neither input is modified.
     *
     * @param a first addend
     * @param b second addend
     * @return new array with {@code a[i] + b[i]} at each index
     * @throws DimensionMismatchException if the arrays differ in length
     */
    public static int[] ebeAdd(int[] a, int[] b) throws DimensionMismatchException {
        if (a.length != b.length) {
            throw new DimensionMismatchException(a.length, b.length);
        }
        int[] sums = new int[a.length];
        for (int idx = 0; idx < sums.length; idx++) {
            sums[idx] = a[idx] + b[idx];
        }
        return sums;
    }

    /**
     * Element-by-element sum of two double arrays; neither input is modified.
     *
     * @param a first addend
     * @param b second addend
     * @return new array with {@code a[i] + b[i]} at each index
     * @throws DimensionMismatchException if the arrays differ in length
     */
    public static double[] ebeAdd(double[] a, double[] b) throws DimensionMismatchException {
        if (a.length != b.length) {
            throw new DimensionMismatchException(a.length, b.length);
        }
        double[] sums = new double[a.length];
        for (int idx = 0; idx < sums.length; idx++) {
            sums[idx] = a[idx] + b[idx];
        }
        return sums;
    }

    /**
     * Adds up, component by component, the arrays produced by {@code function} for the inputs
     * {@code min, min + 1, ..., max - 1}. The array returned for {@code min} is used as the
     * accumulator and is always requested, even when {@code max == min}.
     *
     * @param min      first input (inclusive)
     * @param max      end of the input range (exclusive); must be at least {@code min}
     * @param function producer of equally sized arrays
     * @return the accumulated array
     */
    public static double[] sumArrayFunction(int min, int max, IntToDoubleArrayFunction function) {
        Utils.validateArg(max >= min, "max must be at least as great as min");
        double[] accumulator = function.apply(min);
        for (int input = min + 1; input < max; input++) {
            double[] term = function.apply(input);
            Utils.validateArg(term.length == accumulator.length, "array function returns different sizes for different inputs!");
            for (int idx = 0; idx < accumulator.length; idx++) {
                accumulator[idx] += term[idx];
            }
        }
        return accumulator;
    }

    /**
     * Adds {@code summand} to {@code array} element by element, overwriting {@code array}.
     *
     * @param array   array that is updated in place
     * @param summand values to add; must have the same length as {@code array}
     */
    public static void addToArrayInPlace(double[] array, double[] summand) {
        Utils.validateArg(array.length == summand.length, "Arrays must have same length");
        int idx = 0;
        while (idx < array.length) {
            array[idx] += summand[idx];
            idx++;
        }
    }

    /**
     * Adds {@code summand} to {@code array} element by element, overwriting {@code array}.
     *
     * @param array   array that is updated in place
     * @param summand values to add; must have the same length as {@code array}
     */
    public static void addToArrayInPlace(int[] array, int[] summand) {
        Utils.validateArg(array.length == summand.length, "Arrays must have same length");
        int idx = 0;
        while (idx < array.length) {
            array[idx] += summand[idx];
            idx++;
        }
    }

    /**
     * Median of an int array, obtained by evaluating commons-math {@link Median} on the values
     * widened to double and rounding the result with {@code FastMath.round}.
     *
     * @param values input values; checked with {@code Utils.nonNull}
     * @return the rounded median
     */
    public static int median(int[] values) {
        Utils.nonNull(values);
        double[] widened = Arrays.stream(values).asDoubleStream().toArray();
        double exactMedian = new Median().evaluate(widened);
        return (int) FastMath.round(exactMedian);
    }

    /**
     * Dot product of two vectors: the sum of the element-wise products computed by
     * {@code MathArrays.ebeMultiply}.
     *
     * @param a first vector; checked with {@code Utils.nonNull}
     * @param b second vector; checked with {@code Utils.nonNull}
     * @return sum over i of {@code a[i] * b[i]}
     */
    public static double dotProduct(double[] a, double[] b) {
        double[] left = Utils.nonNull(a);
        double[] right = Utils.nonNull(b);
        double[] products = MathArrays.ebeMultiply(left, right);
        return MathUtils.sum(products);
    }

    /**
     * Builds the arithmetic sequence {@code start, start + step, start + 2 * step, ...} up to
     * {@code limit}.
     *
     * <p>A tolerance of {@code 10^floor(min(0, log10(|step|) - 3))} is used twice: when
     * {@code |limit - start|} is below it a single-element array is returned, and when the last
     * generated value lies within it of {@code limit} that value is snapped to {@code limit}.
     *
     * @param start first value; must be finite
     * @param limit end of the sequence; must be finite
     * @param step  increment; must be finite and have the same sign as {@code limit - start}
     * @return the generated sequence
     * @throws IllegalArgumentException if the sequence would need more than
     *                                  {@link Integer#MAX_VALUE} elements
     */
    public static double[] doubles(double start, double limit, double step) {
        ParamUtils.isFinite(start, "the start must be finite");
        ParamUtils.isFinite(limit, "the limit must be finite");
        ParamUtils.isFinite(step, "the step must be finite");
        double tolerance = Math.pow(10.0, Math.floor(Math.min(0.0, Math.log10(Math.abs(step)) - 3.0)));
        double diff = limit - start;
        if (Math.abs(diff) < tolerance) {
            return new double[]{start};
        }
        // A strictly positive product means diff and step are both non-zero and share a sign,
        // so no separate zero-difference or opposite-sign handling is needed after this point.
        Utils.validateArg(diff * step > 0.0, "the difference between start and end must have the same sign as the step");
        long lengthAsLong = Math.round(Math.floor(1.0 + tolerance + diff / step));
        if (lengthAsLong > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("cannot produce such a large sequence with " + lengthAsLong + " elements");
        }
        int length = (int) lengthAsLong;
        double[] result = new double[length];
        for (int i = 0; i < length; ++i) {
            result[i] = start + step * (double) i;
        }
        // Snap the last element onto the limit when it only misses it by rounding noise.
        if (Math.abs(result[result.length - 1] - limit) <= tolerance) {
            result[result.length - 1] = limit;
        }
        return result;
    }

    /**
     * Creates an array of length {@code repeats} in which every entry equals {@code val}.
     *
     * @param repeats array length; validated with {@code ParamUtils.isPositiveOrZero}
     * @param val     value stored in every slot
     * @return the filled array
     */
    public static double[] doubles(int repeats, double val) {
        ParamUtils.isPositiveOrZero(repeats, "repeats must be 0 or greater");
        double[] filled = new double[repeats];
        for (int idx = 0; idx < filled.length; idx++) {
            filled[idx] = val;
        }
        return filled;
    }

    /**
     * Sum of the squares of the given integers.
     *
     * <p>Each square and the running total are computed in {@code double}, so values whose square
     * (or whose sum of squares) exceeds the {@code int} range no longer wrap around.
     *
     * @param collection integers to square and add; checked with {@code Utils.nonNull}
     * @return the sum of squares
     */
    public static double sumOfSquares(Collection<Integer> collection) {
        Utils.nonNull(collection);
        double total = 0.0;
        for (int value : collection) {
            // Widen before multiplying: value * value in int overflows for |value| > 46340.
            total += (double) value * value;
        }
        return total;
    }

    /**
     * Delegates to {@code Utils.getRandomDataGenerator().nextPermutation(n, k)}.
     * Presumably this yields {@code k} distinct indices drawn from {@code [0, n)} — confirm
     * against the commons-math documentation of {@code nextPermutation}.
     *
     * @param n first argument passed to {@code nextPermutation}
     * @param k second argument passed to {@code nextPermutation}
     * @return the array returned by {@code nextPermutation}
     */
    public static int[] sampleIndicesWithoutReplacement(int n, int k) {
        return Utils.getRandomDataGenerator().nextPermutation(n, k);
    }

    /**
     * Computes {@code log10(1 - 10^a)} by converting {@code a} to a natural log, applying
     * {@code NaturalLogUtils.log1mexp} and converting back to base 10.
     *
     * @param a a log10 value
     * @return {@code NaN} when {@code a > 0}, negative infinity when {@code a == 0}, otherwise
     *         the converted {@code log1mexp} result
     */
    public static double log10OneMinusPow10(double a) {
        if (a > 0.0) {
            return Double.NaN;
        } else if (a == 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        double naturalLog = a * LOG_10;
        double naturalResult = NaturalLogUtils.log1mexp(naturalLog);
        return naturalResult * INV_LOG_10;
    }

    /**
     * Checks that {@code vector} has the expected length, that every entry passes
     * {@link #isValidLog10Probability(double)} and, if requested, that the linear-space sum of
     * the entries equals 1 within a tolerance of 1e-4.
     *
     * @param vector          log10 values to check; checked with {@code Utils.nonNull}
     * @param expectedSize    required length
     * @param shouldSumToOne  whether to also check the sum
     * @return {@code true} when all requested checks pass
     */
    public static boolean isValidLog10ProbabilityVector(double[] vector, int expectedSize, boolean shouldSumToOne) {
        Utils.nonNull(vector);
        if (vector.length != expectedSize) {
            return false;
        }
        if (!MathUtils.allMatch(vector, MathUtils::isValidLog10Probability)) {
            return false;
        }
        if (!shouldSumToOne) {
            return true;
        }
        double linearSum = MathUtils.sumLog10(vector);
        return MathUtils.compareDoubles(linearSum, 1.0, 1.0E-4) == 0;
    }

    /**
     * Linear-space sum of values given in log10 space: {@code 10^log10SumLog10(log10values)}.
     *
     * @param log10values log10 values; checked with {@code Utils.nonNull}
     * @return the sum in linear space
     */
    public static double sumLog10(double[] log10values) {
        double log10Total = MathUtils.log10SumLog10(Utils.nonNull(log10values));
        return Math.pow(10.0, log10Total);
    }

    /**
     * Element-wise difference {@code x[k] - y[k]} of two equally long int arrays.
     *
     * @param x minuend; checked with {@code Utils.nonNull}
     * @param y subtrahend; checked with {@code Utils.nonNull}
     * @return array of differences as produced by {@code IndexRange.mapToInteger}
     */
    public static int[] vectorDiff(int[] x, int[] y) {
        Utils.nonNull(x, "x is null");
        Utils.nonNull(y, "y is null");
        Utils.validateArg(x.length == y.length, "Lengths of x and y must be the same");
        IndexRange indices = new IndexRange(0, x.length);
        return indices.mapToInteger(idx -> x[idx] - y[idx]);
    }

    /**
     * log10 of the multinomial coefficient {@code n! / (k[0]! * k[1]! * ...)}.
     *
     * @param n total number of trials; must be non-negative
     * @param k per-category counts; must be non-negative and add up to {@code n}
     * @return {@code log10Factorial(n)} minus the sum of {@code log10Factorial(k[j])}
     */
    public static double log10MultinomialCoefficient(int n, int[] k) {
        Utils.validateArg(n >= 0, "n: Must have non-negative number of trials");
        Utils.validateArg(MathUtils.allMatch(k, (int count) -> count >= 0), "Elements of k must be non-negative");
        Utils.validateArg(MathUtils.sum(k) == (long) n, "Sum of observations k must sum to total number of trials n");
        double log10Numerator = MathUtils.log10Factorial(n);
        double log10Denominator = new IndexRange(0, k.length).sum(idx -> MathUtils.log10Factorial(k[idx]));
        return log10Numerator - log10Denominator;
    }

    /**
     * Looks up the value for {@code i} in {@code LOG_10_CACHE}; by its name the cache holds
     * log10(i) — confirm against {@code Log10Cache}.
     *
     * @param i cache key
     * @return the cached value
     */
    public static double log10(int i) {
        return LOG_10_CACHE.get(i);
    }

    /**
     * Looks up the value for {@code i} in {@code DIGAMMA_CACHE}; by its name the cache holds
     * the digamma function at {@code i} — confirm against {@code DigammaCache}.
     *
     * @param i cache key
     * @return the cached value
     */
    public static double digamma(int i) {
        return DIGAMMA_CACHE.get(i);
    }

    /**
     * log10 of the linear-space sum of all entries of {@code log10values}.
     *
     * @param log10values log10 values; checked with {@code Utils.nonNull}
     * @return result of {@link #log10sumLog10(double[], int)} starting at index 0
     */
    public static double log10sumLog10(double[] log10values) {
        double[] checked = Utils.nonNull(log10values);
        return MathUtils.log10sumLog10(checked, 0);
    }

    /**
     * log10 of the linear-space sum of the entries of {@code log10p} from {@code start} to the
     * end of the array.
     *
     * @param log10p log10 values; checked with {@code Utils.nonNull}
     * @param start  first index to include
     * @return result of {@link #log10sumLog10(double[], int, int)} with {@code finish = log10p.length}
     */
    public static double log10sumLog10(double[] log10p, int start) {
        double[] checked = Utils.nonNull(log10p);
        int finish = log10p.length;
        return MathUtils.log10sumLog10(checked, start, finish);
    }

    /**
     * log10 of the linear-space sum of {@code log10p[start .. finish - 1]}, computed relative to
     * the largest entry so that the powers of ten stay at or below 1.
     *
     * <p>An empty or inverted range ({@code start >= finish}) yields negative infinity, matching
     * {@link #log10SumLog10(double[], int, int)}; a single-element range returns that element
     * unchanged.
     *
     * @param log10p log10 values; checked with {@code Utils.nonNull}
     * @param start  first index (inclusive)
     * @param finish last index (exclusive)
     * @return log10 of the sum, or negative infinity when every entry in range is negative infinity
     */
    public static double log10sumLog10(double[] log10p, int start, int finish) {
        Utils.nonNull(log10p);
        if (start >= finish) {
            // Nothing to add up: the sum is 0, whose log10 is negative infinity.
            return Double.NEGATIVE_INFINITY;
        }
        if (finish - start == 1) {
            return log10p[start];
        }
        int maxElementIndex = MathUtils.maxElementIndex(log10p, start, finish);
        double maxValue = log10p[maxElementIndex];
        if (maxValue == Double.NEGATIVE_INFINITY) {
            return maxValue;
        }
        // The maximum contributes exactly 1 after scaling, so it is added outside the lambda.
        double sum = 1.0 + new IndexRange(start, finish).sum(i -> i == maxElementIndex ? 0.0 : Math.pow(10.0, log10p[i] - maxValue));
        Utils.validateArg(!Double.isNaN(sum) && sum != Double.POSITIVE_INFINITY, "log10p values must be non-infinite and non-NAN");
        return maxValue + Math.log10(sum);
    }

    /**
     * Rounds to the nearest int by adding 0.5 (positive input) or subtracting 0.5 (otherwise)
     * and truncating, so halves are rounded away from zero.
     *
     * @param d value to round
     * @return the rounded value
     */
    public static int fastRound(double d) {
        if (d > 0.0) {
            return (int) (d + 0.5);
        }
        return (int) (d - 0.5);
    }

    /**
     * Converts a natural logarithm to base 10 by multiplying with {@code LOG10_E} (log10(e)).
     *
     * @param ln natural-log value
     * @return the same quantity expressed as a log10 value
     */
    public static double logToLog10(double ln) {
        return ln * LOG10_E;
    }

    /**
     * Table-based approximation of log10 of the linear-space sum of all entries of {@code vals}.
     *
     * @param vals log10 values; checked with {@code Utils.nonNull}
     * @return result of {@link #approximateLog10SumLog10(double[], int)} with {@code endIndex = vals.length}
     */
    public static double approximateLog10SumLog10(double[] vals) {
        double[] checked = Utils.nonNull(vals);
        int endIndex = vals.length;
        return MathUtils.approximateLog10SumLog10(checked, endIndex);
    }

    /**
     * Table-based approximation of log10 of the linear-space sum of {@code vals[0 .. endIndex - 1]}.
     * Starting from the largest entry, every other finite entry adds a correction looked up in
     * {@code JacobianLogTable}; entries 8 or more log10 units below the running sum add nothing.
     *
     * @param vals     log10 values; checked with {@code Utils.nonNull}
     * @param endIndex end of the range (exclusive)
     * @return the approximate log10 sum
     */
    public static double approximateLog10SumLog10(double[] vals, int endIndex) {
        Utils.nonNull(vals);
        int indexOfMax = MathUtils.maxElementIndex(vals, endIndex);
        double running = vals[indexOfMax];
        for (int idx = 0; idx < endIndex; idx++) {
            boolean skip = idx == indexOfMax || vals[idx] == Double.NEGATIVE_INFINITY;
            if (!skip) {
                double gap = running - vals[idx];
                double correction = gap < JacobianLogTable.MAX_TOLERANCE ? JacobianLogTable.get(gap) : 0.0;
                running += correction;
            }
        }
        return running;
    }

    /**
     * Table-based approximation of {@code log10(10^a + 10^b + 10^c)}, built from two pairwise
     * approximations: first {@code b} with {@code c}, then {@code a} with that result.
     *
     * @param a first log10 value
     * @param b second log10 value
     * @param c third log10 value
     * @return the approximate log10 sum
     */
    public static double approximateLog10SumLog10(double a, double b, double c) {
        double tail = MathUtils.approximateLog10SumLog10(b, c);
        return MathUtils.approximateLog10SumLog10(a, tail);
    }

    /**
     * Table-based approximation of {@code log10(10^a + 10^b)}: the larger argument plus a
     * correction looked up in {@code JacobianLogTable}. No correction is added when the two
     * values are 8 or more log10 units apart.
     *
     * @param a first log10 value
     * @param b second log10 value
     * @return the approximate log10 sum; the larger argument when the smaller is negative infinity
     */
    public static double approximateLog10SumLog10(double a, double b) {
        // Order the operands once instead of recursing with swapped arguments.
        double smaller = a;
        double larger = b;
        if (a > b) {
            smaller = b;
            larger = a;
        }
        if (smaller == Double.NEGATIVE_INFINITY) {
            return larger;
        }
        double gap = larger - smaller;
        double correction = gap < JacobianLogTable.MAX_TOLERANCE ? JacobianLogTable.get(gap) : 0.0;
        return larger + correction;
    }

    /**
     * Table-based approximation of log10 of the linear-space sum of
     * {@code vals[fromIndex .. toIndex - 1]}.
     *
     * @param vals      log10 values; checked with {@code Utils.nonNull}
     * @param fromIndex first index (inclusive)
     * @param toIndex   last index (exclusive)
     * @return the approximate log10 sum, or negative infinity when {@code fromIndex == toIndex}
     */
    public static double approximateLog10SumLog10(double[] vals, int fromIndex, int toIndex) {
        Utils.nonNull(vals);
        if (fromIndex == toIndex) {
            return Double.NEGATIVE_INFINITY;
        }
        int indexOfMax = MathUtils.maxElementIndex(vals, fromIndex, toIndex);
        double running = vals[indexOfMax];
        for (int idx = fromIndex; idx < toIndex; idx++) {
            if (idx == indexOfMax) {
                continue;
            }
            double current = vals[idx];
            if (current == Double.NEGATIVE_INFINITY) {
                continue;
            }
            double gap = running - current;
            // Entries far below the running sum are too small to change it.
            if (gap < JacobianLogTable.MAX_TOLERANCE) {
                running += JacobianLogTable.get(gap);
            }
        }
        return running;
    }

    /**
     * Adds up all entries of {@code values} from left to right.
     *
     * @param values numbers to add; checked with {@code Utils.nonNull}
     * @return the sum, 0.0 for an empty array
     */
    public static double sum(double[] values) {
        Utils.nonNull(values);
        double total = 0.0;
        for (int idx = 0; idx < values.length; idx++) {
            total += values[idx];
        }
        return total;
    }

    /**
     * Adds up all entries of {@code x} using a {@code long} accumulator.
     *
     * @param x numbers to add; checked with {@code Utils.nonNull}
     * @return the sum, 0 for an empty array
     */
    public static long sum(int[] x) {
        Utils.nonNull(x);
        return Arrays.stream(x).asLongStream().sum();
    }

    /**
     * Adds up all entries of {@code x}.
     *
     * <p>The accumulator is a {@code long}, so totals outside the {@code int} range are returned
     * intact instead of being truncated to 32 bits.
     *
     * @param x numbers to add; checked with {@code Utils.nonNull}
     * @return the sum, 0 for an empty array
     */
    public static long sum(long[] x) {
        Utils.nonNull(x);
        long total = 0L;
        for (long v : x) {
            total += v;
        }
        return total;
    }

    /**
     * Adds up {@code arr[start .. stop - 1]}.
     *
     * @param arr   numbers to add; checked with {@code Utils.nonNull}
     * @param start first index (inclusive); must satisfy {@code 0 <= start <= stop}
     * @param stop  last index (exclusive); must not exceed {@code arr.length}
     * @return the partial sum, 0.0 for an empty range
     */
    public static double sum(double[] arr, int start, int stop) {
        Utils.nonNull(arr);
        Utils.validateArg(start <= stop, () -> start + " > " + stop);
        Utils.validateArg(start >= 0, () -> start + " < " + 0);
        Utils.validateArg(stop <= arr.length, () -> stop + " >  " + arr.length);
        double total = 0.0;
        int idx = start;
        while (idx < stop) {
            total += arr[idx];
            idx++;
        }
        return total;
    }

    /**
     * Applies {@code function} to every element of {@code collection} and adds up the results
     * in iteration order.
     *
     * @param collection elements to evaluate
     * @param function   maps an element to a double
     * @param <E>        element type
     * @return the sum of the function values, 0.0 for an empty collection
     */
    public static <E> double sumDoubleFunction(Collection<E> collection, ToDoubleFunction<E> function) {
        double total = 0.0;
        for (E element : collection) {
            double contribution = function.applyAsDouble(element);
            total += contribution;
        }
        return total;
    }

    /**
     * Applies {@code function} to every element of {@code collection} and adds up the results
     * in an {@code int} accumulator.
     *
     * @param collection elements to evaluate
     * @param function   maps an element to an int
     * @param <E>        element type
     * @return the sum of the function values, 0 for an empty collection
     */
    public static <E> int sumIntFunction(Collection<E> collection, ToIntFunction<E> function) {
        int total = 0;
        for (E element : collection) {
            int contribution = function.applyAsInt(element);
            total += contribution;
        }
        return total;
    }

    /**
     * Compares two doubles with a tolerance of 1e-6 by delegating to
     * {@link #compareDoubles(double, double, double)}.
     *
     * @param a first value
     * @param b second value
     * @return the result of the three-argument overload
     */
    public static byte compareDoubles(double a, double b) {
        return MathUtils.compareDoubles(a, b, 1.0E-6);
    }

    /**
     * Compares two doubles with a tolerance.
     *
     * <p>Note the sign convention, which is the reverse of {@link Double#compare}: the result is
     * -1 when {@code a} is the larger value and 1 otherwise.
     *
     * @param a       first value
     * @param b       second value
     * @param epsilon tolerance; values closer than this are treated as equal
     * @return 0 when {@code |a - b| < epsilon}, -1 when {@code a > b}, 1 in every other case
     */
    public static byte compareDoubles(double a, double b, double epsilon) {
        double distance = Math.abs(a - b);
        if (distance < epsilon) {
            return 0;
        }
        return (byte) (a > b ? -1 : 1);
    }

    /**
     * Binomial coefficient "n choose k", obtained by exponentiating
     * {@link #log10BinomialCoefficient(int, int)}.
     *
     * @param n number of trials
     * @param k number of successes
     * @return {@code 10^log10BinomialCoefficient(n, k)}
     */
    public static double binomialCoefficient(int n, int k) {
        double log10Coefficient = MathUtils.log10BinomialCoefficient(n, k);
        return Math.pow(10.0, log10Coefficient);
    }

    /**
     * log10 of the binomial coefficient: {@code log10(n!) - log10(k!) - log10((n - k)!)}.
     *
     * @param n number of trials; must be non-negative
     * @param k number of successes; must lie in {@code [0, n]}
     * @return the log10 binomial coefficient
     */
    public static double log10BinomialCoefficient(int n, int k) {
        Utils.validateArg(n >= 0, "Must have non-negative number of trials");
        Utils.validateArg(k <= n && k >= 0, "k: Must have non-negative number of successes, and no more successes than number of trials");
        double log10NFactorial = MathUtils.log10Factorial(n);
        double log10KFactorial = MathUtils.log10Factorial(k);
        double log10RestFactorial = MathUtils.log10Factorial(n - k);
        return log10NFactorial - log10KFactorial - log10RestFactorial;
    }

    /**
     * Probability of {@code k} successes in {@code n} trials with success probability {@code p},
     * obtained by exponentiating {@link #log10BinomialProbability(int, int, double)}.
     *
     * @param n number of trials
     * @param k number of successes
     * @param p success probability in linear space
     * @return the binomial probability
     */
    public static double binomialProbability(int n, int k, double p) {
        double log10p = Math.log10(p);
        double log10Probability = MathUtils.log10BinomialProbability(n, k, log10p);
        return Math.pow(10.0, log10Probability);
    }

    /**
     * log10 of the binomial probability of {@code k} successes in {@code n} trials.
     *
     * @param n      number of trials
     * @param k      number of successes
     * @param log10p log10 of the success probability; must be below 1e-18
     * @return the log10 probability; the degenerate cases p = 0 and p = 1 return 0 or negative
     *         infinity without evaluating the general formula
     */
    public static double log10BinomialProbability(int n, int k, double log10p) {
        Utils.validateArg(log10p < 1.0E-18, "log10p: Log10-probability must be 0 or less");
        if (log10p == Double.NEGATIVE_INFINITY) {
            // p = 0: only the outcome with no successes has probability 1.
            if (k == 0) {
                return 0.0;
            }
            return Double.NEGATIVE_INFINITY;
        }
        if (log10p == 0.0) {
            // p = 1: only the outcome with all successes has probability 1.
            if (k == n) {
                return 0.0;
            }
            return Double.NEGATIVE_INFINITY;
        }
        double log10q = Math.log10(1.0 - Math.pow(10.0, log10p));
        double log10Coefficient = MathUtils.log10BinomialCoefficient(n, k);
        return log10Coefficient + log10p * (double) k + log10q * (double) (n - k);
    }

    /**
     * log10 of the binomial probability of {@code k} successes in {@code n} trials with success
     * probability 1/2: the log10 binomial coefficient plus {@code n * log10(1/2)}.
     *
     * @param n number of trials
     * @param k number of successes
     * @return the log10 probability
     */
    public static double log10BinomialProbability(int n, int k) {
        double log10Coefficient = MathUtils.log10BinomialCoefficient(n, k);
        double log10HalfToTheN = (double) n * LOG10_ONE_HALF;
        return log10Coefficient + log10HalfToTheN;
    }

    /**
     * log10 of the linear-space sum of the entries of {@code log10Values} from {@code start} to
     * the end of the array.
     *
     * @param log10Values log10 values; checked with {@code Utils.nonNull}
     * @param start       first index to include
     * @return result of {@link #log10SumLog10(double[], int, int)} with {@code finish = log10Values.length}
     */
    public static double log10SumLog10(double[] log10Values, int start) {
        double[] checked = Utils.nonNull(log10Values);
        int finish = log10Values.length;
        return MathUtils.log10SumLog10(checked, start, finish);
    }

    /**
     * log10 of the linear-space sum of all entries of {@code log10Values}.
     *
     * @param log10Values log10 values; checked with {@code Utils.nonNull}
     * @return result of {@link #log10SumLog10(double[], int)} starting at index 0
     */
    public static double log10SumLog10(double[] log10Values) {
        double[] checked = Utils.nonNull(log10Values);
        return MathUtils.log10SumLog10(checked, 0);
    }

    /**
     * log10 of the linear-space sum of {@code log10Values[start .. finish - 1]}, computed
     * relative to the largest entry so that the powers of ten stay at or below 1.
     *
     * @param log10Values log10 values; checked with {@code Utils.nonNull}
     * @param start       first index (inclusive)
     * @param finish      last index (exclusive)
     * @return log10 of the sum; negative infinity for an empty range or when the largest entry
     *         is negative infinity
     * @throws IllegalArgumentException if the scaled sum is NaN or positive infinity
     */
    public static double log10SumLog10(double[] log10Values, int start, int finish) {
        Utils.nonNull(log10Values);
        if (start >= finish) {
            return Double.NEGATIVE_INFINITY;
        }
        int indexOfMax = MathUtils.maxElementIndex(log10Values, start, finish);
        double largest = log10Values[indexOfMax];
        if (largest == Double.NEGATIVE_INFINITY) {
            return largest;
        }
        // The largest entry contributes exactly 1 after scaling.
        double scaledTotal = 1.0;
        for (int idx = start; idx < finish; idx++) {
            double current = log10Values[idx];
            if (idx != indexOfMax && current != Double.NEGATIVE_INFINITY) {
                scaledTotal += Math.pow(10.0, current - largest);
            }
        }
        if (Double.isNaN(scaledTotal) || scaledTotal == Double.POSITIVE_INFINITY) {
            throw new IllegalArgumentException("log10 p: Values must be non-infinite and non-NAN");
        }
        double correction = scaledTotal != 1.0 ? Math.log10(scaledTotal) : 0.0;
        return largest + correction;
    }

    /**
     * Computes {@code log10(10^a + 10^b)} relative to the larger argument.
     *
     * <p>When both arguments are negative infinity (both terms are zero) the result is negative
     * infinity, as in the array overloads, rather than the NaN that {@code -inf - -inf} would
     * produce.
     *
     * @param a first log10 value
     * @param b second log10 value
     * @return log10 of the sum of the two linear-space values
     */
    public static double log10SumLog10(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY && b == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        return a > b ? a + Math.log10(1.0 + Math.pow(10.0, b - a)) : b + Math.log10(1.0 + Math.pow(10.0, a - b));
    }

    /**
     * Computes {@code log10(10^a + 10^b + 10^c)} relative to the largest argument.
     *
     * <p>When all three arguments are negative infinity the result is negative infinity, as in
     * the array overloads, rather than the NaN that {@code -inf - -inf} would produce.
     *
     * @param a first log10 value
     * @param b second log10 value
     * @param c third log10 value
     * @return log10 of the sum of the three linear-space values
     */
    public static double log10SumLog10(double a, double b, double c) {
        if (a == Double.NEGATIVE_INFINITY && b == Double.NEGATIVE_INFINITY && c == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        if (a >= b && a >= c) {
            return a + Math.log10(1.0 + Math.pow(10.0, b - a) + Math.pow(10.0, c - a));
        }
        if (b >= c) {
            return b + Math.log10(1.0 + Math.pow(10.0, a - b) + Math.pow(10.0, c - b));
        }
        return c + Math.log10(1.0 + Math.pow(10.0, a - c) + Math.pow(10.0, b - c));
    }

    /**
     * log10 of the normal density with the given mean and standard deviation, evaluated at
     * {@code x}.
     *
     * <p>The argument check accepts {@code sd == 0} (the message now says so, in line with
     * {@link #normalDistribution(double, double, double)}), but the formula divides by
     * {@code sd}, so a zero standard deviation does not produce a finite result.
     *
     * @param mean mean of the distribution; must be finite
     * @param sd   standard deviation; must be finite and non-negative
     * @param x    point at which to evaluate the density; must be finite
     * @return the log10 density
     * @throws IllegalArgumentException if any argument is infinite or NaN
     */
    public static double normalDistributionLog10(double mean, double sd, double x) {
        Utils.validateArg(sd >= 0.0, "sd: Standard deviation of normal must be >= 0");
        if (!(MathUtils.wellFormedDouble(mean) && MathUtils.wellFormedDouble(sd) && MathUtils.wellFormedDouble(x))) {
            throw new IllegalArgumentException("mean, sd, or, x : Normal parameters must be well formatted (non-INF, non-NAN)");
        }
        // log10 of the normalising factor 1 / (sd * sqrt(2 * pi)).
        double a = -1.0 * Math.log10(sd * ROOT_TWO_PI);
        // Exponent -(x - mean)^2 / (2 * sd^2), converted from natural log to log10.
        double b = -1.0 * (MathUtils.square(x - mean) / (2.0 * MathUtils.square(sd))) / LOG_10;
        return a + b;
    }

    /**
     * Returns {@code x} multiplied by itself.
     *
     * @param x value to square
     * @return {@code x * x}
     */
    public static double square(double x) {
        return x * x;
    }

    /**
     * Squared Euclidean distance between {@code x} and {@code y}, summed over the indices of
     * {@code x}. The lengths of the two arrays are not compared.
     *
     * @param x first point; checked with {@code Utils.nonNull}
     * @param y second point; checked with {@code Utils.nonNull}
     * @return sum over i of {@code (x[i] - y[i])^2}
     */
    public static double distanceSquared(double[] x, double[] y) {
        Utils.nonNull(x);
        Utils.nonNull(y);
        IndexRange indices = new IndexRange(0, x.length);
        return indices.sum(idx -> MathUtils.square(x[idx] - y[idx]));
    }

    /**
     * Normalizes log10 values and returns them in linear space by calling
     * {@link #normalizeLog10(double[], boolean, boolean)} with {@code takeLog10OfOutput = false}
     * and {@code inPlace = true}; the input array is therefore overwritten.
     *
     * @param array log10 values; checked with {@code Utils.nonNull}
     * @return the normalized linear-space values
     */
    public static double[] normalizeFromLog10ToLinearSpace(double[] array) {
        double[] checked = Utils.nonNull(array);
        return MathUtils.normalizeLog10(checked, false, true);
    }

    /**
     * Normalizes log10 values and keeps them in log10 space by calling
     * {@link #normalizeLog10(double[], boolean, boolean)} with {@code takeLog10OfOutput = true}
     * and {@code inPlace = true}; the input array is therefore overwritten.
     *
     * @param array log10 values; checked with {@code Utils.nonNull}
     * @return the normalized log10 values
     */
    public static double[] normalizeLog10(double[] array) {
        double[] checked = Utils.nonNull(array);
        return MathUtils.normalizeLog10(checked, true, true);
    }

    /**
     * Subtracts the log10 sum of {@code array} from every entry, so that the linear-space
     * values add up to 1, and optionally converts the result to linear space.
     *
     * @param array             log10 values; checked with {@code Utils.nonNull}
     * @param takeLog10OfOutput {@code true} to return log10 values, {@code false} for linear space
     * @param inPlace           {@code true} to overwrite {@code array}, {@code false} to work on a copy
     * @return the normalized values
     */
    public static double[] normalizeLog10(double[] array, boolean takeLog10OfOutput, boolean inPlace) {
        double log10Sum = MathUtils.log10SumLog10(Utils.nonNull(array));
        DoubleUnaryOperator subtractLog10Sum = x -> x - log10Sum;
        double[] normalized;
        if (inPlace) {
            normalized = MathUtils.applyToArrayInPlace(array, subtractLog10Sum);
        } else {
            normalized = MathUtils.applyToArray(array, subtractLog10Sum);
        }
        if (takeLog10OfOutput) {
            return normalized;
        }
        return MathUtils.applyToArrayInPlace(normalized, x -> Math.pow(10.0, x));
    }

    /**
     * Normalizes log10 values by scaling relative to the largest entry.
     *
     * <p>The two modes treat the input differently: with {@code takeLog10OfOutput} the input
     * array itself is overwritten and returned, otherwise a new linear-space array is returned
     * and the input is left untouched.
     *
     * @param array             log10 values; checked with {@code Utils.nonNull}
     * @param takeLog10OfOutput {@code true} to return log10 values, {@code false} for linear space
     * @return the normalized values
     */
    public static double[] normalizeLog10DeleteMePlease(double[] array, boolean takeLog10OfOutput) {
        Utils.nonNull(array);
        double largest = MathUtils.arrayMax(array);
        double[] scaledLinear = MathUtils.applyToArray(array, (double x) -> Math.pow(10.0, x - largest));
        double scaledTotal = MathUtils.sum(scaledLinear);
        if (takeLog10OfOutput) {
            double log10Total = Math.log10(scaledTotal);
            return MathUtils.applyToArrayInPlace(array, x -> x - largest - log10Total);
        }
        return MathUtils.applyToArrayInPlace(scaledLinear, x -> x / scaledTotal);
    }

    /**
     * Subtracts the largest entry from every entry of {@code array}, in place, so that the
     * maximum becomes 0.
     *
     * @param array values to shift; checked with {@code Utils.nonNull} and overwritten
     * @return the same array instance after shifting
     */
    public static double[] scaleLogSpaceArrayForNumericalStability(double[] array) {
        Utils.nonNull(array);
        double largest = MathUtils.arrayMax(array);
        DoubleUnaryOperator subtractLargest = x -> x - largest;
        return MathUtils.applyToArrayInPlace(array, subtractLargest);
    }

    /**
     * Divides every entry by the sum of all entries. An empty array is returned as is; otherwise
     * a new array is produced and the input is left untouched.
     *
     * @param array values to normalize; checked with {@code Utils.nonNull}; their sum must not be negative
     * @return the normalized values
     */
    public static double[] normalizeSumToOne(double[] array) {
        Utils.nonNull(array);
        if (array.length == 0) {
            return array;
        }
        double total = MathUtils.sum(array);
        Utils.validateArg(total >= 0.0, () -> "Values in probability array sum to a negative number " + total);
        DoubleUnaryOperator divideByTotal = x -> x / total;
        return MathUtils.applyToArray(array, divideByTotal);
    }

    /**
     * Index of the largest entry of {@code array}.
     *
     * @param array values to scan; checked with {@code Utils.nonNull}
     * @return result of {@link #maxElementIndex(double[], int)} with {@code endIndex = array.length}
     */
    public static int maxElementIndex(double[] array) {
        double[] checked = Utils.nonNull(array);
        int endIndex = array.length;
        return MathUtils.maxElementIndex(checked, endIndex);
    }

    /**
     * Index of the largest entry within {@code array[start .. endIndex - 1]}. Ties resolve to
     * the lowest index; for an empty range {@code start} itself is returned.
     *
     * @param array    values to scan; checked with {@code Utils.nonNull}; must not be empty
     * @param start    first index (inclusive)
     * @param endIndex last index (exclusive); must not be smaller than {@code start}
     * @return index of the largest entry in the range
     */
    public static int maxElementIndex(double[] array, int start, int endIndex) {
        Utils.nonNull(array);
        Utils.validateArg(array.length > 0, "array may not be empty");
        Utils.validateArg(start <= endIndex, "Start cannot be after end.");
        int best = start;
        int candidate = start + 1;
        while (candidate < endIndex) {
            if (array[candidate] > array[best]) {
                best = candidate;
            }
            candidate++;
        }
        return best;
    }

    /**
     * Index of the largest entry within {@code array[0 .. endIndex - 1]}.
     *
     * @param array    values to scan; checked with {@code Utils.nonNull}
     * @param endIndex end of the range (exclusive)
     * @return result of {@link #maxElementIndex(double[], int, int)} starting at index 0
     */
    public static int maxElementIndex(double[] array, int endIndex) {
        double[] checked = Utils.nonNull(array);
        return MathUtils.maxElementIndex(checked, 0, endIndex);
    }

    /**
     * Largest entry of an int array.
     *
     * @param array values to scan; checked with {@code Utils.nonNull}
     * @return the entry at the index reported by {@link #maxElementIndex(int[])}
     */
    public static int arrayMax(int[] array) {
        Utils.nonNull(array);
        int indexOfMax = MathUtils.maxElementIndex(array);
        return array[indexOfMax];
    }

    /**
     * Largest entry within {@code array[from .. to - 1]}.
     *
     * @param array        values to scan
     * @param from         first index (inclusive)
     * @param to           last index (exclusive)
     * @param defaultValue returned when the range is empty and {@code from} is non-negative
     * @return the maximum of the range, or {@code defaultValue} for an empty range
     * @throws ArrayIndexOutOfBoundsException if the range is empty and {@code from} is negative
     */
    public static int arrayMax(int[] array, int from, int to, int defaultValue) {
        if (to <= from) {
            // Empty range: only the sign of 'from' decides between the default and an error.
            if (from < 0) {
                throw new ArrayIndexOutOfBoundsException(from);
            }
            return defaultValue;
        }
        int best = array[from];
        for (int idx = from + 1; idx < to; idx++) {
            if (array[idx] > best) {
                best = array[idx];
            }
        }
        return best;
    }

    /**
     * Largest entry of a double array.
     *
     * @param array values to scan; checked with {@code Utils.nonNull}
     * @return the entry at the index reported by {@link #maxElementIndex(double[])}
     */
    public static double arrayMax(double[] array) {
        Utils.nonNull(array);
        int indexOfMax = MathUtils.maxElementIndex(array);
        return array[indexOfMax];
    }

    /**
     * Smallest entry of an int array. The first element is read unconditionally, so an empty
     * array fails with an index error.
     *
     * @param array values to scan; checked with {@code Utils.nonNull}
     * @return the minimum
     */
    public static int arrayMin(int[] array) {
        Utils.nonNull(array);
        int smallest = array[0];
        for (int value : array) {
            if (value < smallest) {
                smallest = value;
            }
        }
        return smallest;
    }

    /**
     * Tests whether a value can be a log10 probability, i.e. whether it is at most 0.
     * NaN fails the comparison and is rejected; negative infinity is accepted.
     *
     * @param result value to test
     * @return {@code true} when {@code result <= 0.0}
     */
    public static boolean isValidLog10Probability(double result) {
        return result <= 0.0;
    }

    /**
     * Tests whether a value lies in the closed interval [0, 1]. NaN is rejected.
     *
     * @param result value to test
     * @return {@code true} when {@code 0 <= result <= 1}
     */
    public static boolean isValidProbability(double result) {
        if (result < 0.0) {
            return false;
        }
        return result <= 1.0;
    }

    /**
     * Converts a log10 value to a natural logarithm by multiplying with {@code LOG_10} (ln(10)).
     *
     * @param log10 log10 value
     * @return the same quantity expressed as a natural log
     */
    public static double log10ToLog(double log10) {
        return log10 * LOG_10;
    }

    /**
     * log10 of the gamma function: commons-math {@code Gamma.logGamma(x)} converted from
     * natural log to base 10.
     *
     * @param x argument of the gamma function
     * @return {@code logToLog10(Gamma.logGamma(x))}
     */
    public static double log10Gamma(double x) {
        double naturalLogGamma = Gamma.logGamma(x);
        return MathUtils.logToLog10(naturalLogGamma);
    }

    /**
     * Looks up the value for {@code n} in {@code LOG_10_FACTORIAL_CACHE}; by its name the cache
     * holds log10(n!) — confirm against {@code Log10FactorialCache}.
     *
     * @param n cache key
     * @return the cached value
     */
    public static double log10Factorial(int n) {
        return LOG_10_FACTORIAL_CACHE.get(n);
    }

    /**
     * Returns a new array holding {@code Math.log10} of every entry; the input is not modified.
     *
     * @param prRealSpace linear-space values; checked with {@code Utils.nonNull}
     * @return the log10 values
     */
    public static double[] toLog10(double[] prRealSpace) {
        Utils.nonNull(prRealSpace);
        return MathUtils.applyToArray(prRealSpace, (double linear) -> Math.log10(linear));
    }

    /**
     * Computes {@code log10(1 - x)} as {@code log10(1/x - 1) + log10(x)}.
     *
     * @param x value to subtract from 1
     * @return negative infinity for {@code x == 1}, 0 for {@code x == 0}; otherwise the computed
     *         value, replaced by 0 when it is infinite or positive
     */
    public static double log10OneMinusX(double x) {
        if (x == 1.0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (x == 0.0) {
            return 0.0;
        }
        double candidate = Math.log10(1.0 / x - 1.0) + Math.log10(x);
        boolean outOfRange = Double.isInfinite(candidate) || candidate > 0.0;
        if (outOfRange) {
            return 0.0;
        }
        return candidate;
    }

    /**
     * Median of a collection of numbers, evaluated by commons-math {@link Median} on the
     * {@code doubleValue()} of each element.
     *
     * @param values numbers to summarize; checked with {@code Utils.nonEmpty}
     * @param <T>    numeric element type
     * @return the median
     */
    public static <T extends Number> double median(Collection<T> values) {
        Utils.nonEmpty(values, "cannot take the median of a collection with no values.");
        double[] asDoubles = values.stream().mapToDouble(value -> value.doubleValue()).toArray();
        Median estimator = new Median();
        return estimator.evaluate(asDoubles);
    }

    /**
     * Rounds {@code in} to {@code n} decimal places. One ulp is added before scaling, which
     * nudges values sitting just below a rounding boundary upwards.
     *
     * @param in value to round
     * @param n  number of decimal places; must be positive
     * @return the rounded value
     */
    public static double roundToNDecimalPlaces(double in, int n) {
        Utils.validateArg(n > 0, "must round to at least one decimal place");
        double scale = Math.pow(10.0, n);
        double nudged = in + Math.ulp(in);
        long scaledAndRounded = Math.round(nudged * scale);
        return (double) scaledAndRounded / scale;
    }

    /**
     * Tests whether a value is an ordinary number, i.e. neither infinite nor NaN.
     *
     * @param val value to test
     * @return {@code true} for finite values
     */
    public static boolean wellFormedDouble(double val) {
        return Double.isFinite(val);
    }

    /**
     * Density of the normal distribution with the given mean and standard deviation at {@code x}.
     *
     * @param mean mean of the distribution; must be finite
     * @param sd   standard deviation; must be finite and non-negative
     * @param x    point at which to evaluate the density; must be finite
     * @return {@code exp(-(x - mean)^2 / (2 * sd^2)) / (sd * sqrt(2 * pi))}
     */
    public static double normalDistribution(double mean, double sd, double x) {
        Utils.validateArg(sd >= 0.0, "sd: Standard deviation of normal must be >= 0");
        boolean allFinite = MathUtils.wellFormedDouble(mean) && MathUtils.wellFormedDouble(sd) && MathUtils.wellFormedDouble(x);
        Utils.validateArg(allFinite, "mean, sd, or, x : Normal parameters must be well formatted (non-INF, non-NAN)");
        double deviation = x - mean;
        double exponent = -deviation * deviation / (2.0 * sd * sd);
        return Math.exp(exponent) / (sd * ROOT_TWO_PI);
    }

    /**
     * Dirichlet-multinomial likelihood of {@code counts} given the concentration parameters
     * {@code params}. Every term is a log10 quantity, so the returned value is in log10 space.
     *
     * @param params Dirichlet parameters; checked with {@code Utils.nonNull}
     * @param counts per-category counts; checked with {@code Utils.nonNull}; same length as {@code params}
     * @return the log10 likelihood
     */
    public static double dirichletMultinomial(double[] params, int[] counts) {
        Utils.nonNull(params);
        Utils.nonNull(counts);
        Utils.validateArg(params.length == counts.length, "The number of dirichlet parameters must match the number of categories");
        double paramTotal = MathUtils.sum(params);
        int countTotal = (int) MathUtils.sum(counts);
        double log10Prefactor = MathUtils.log10MultinomialCoefficient(countTotal, counts)
                + MathUtils.log10Gamma(paramTotal)
                - MathUtils.log10Gamma(paramTotal + (double) countTotal);
        double log10PerCategory = new IndexRange(0, counts.length)
                .sum(idx -> MathUtils.log10Gamma((double) counts[idx] + params[idx]) - MathUtils.log10Gamma(params[idx]));
        return log10Prefactor + log10PerCategory;
    }

    /**
     * Returns a new array holding {@code func} applied to every entry of {@code array}; the
     * input is not modified.
     *
     * @param array input values; checked with {@code Utils.nonNull}
     * @param func  function to apply; checked with {@code Utils.nonNull}
     * @return the transformed values
     */
    public static double[] applyToArray(double[] array, DoubleUnaryOperator func) {
        Utils.nonNull(func);
        Utils.nonNull(array);
        double[] transformed = new double[array.length];
        Arrays.setAll(transformed, idx -> func.applyAsDouble(array[idx]));
        return transformed;
    }

    /**
     * Returns a new double array holding {@code func} applied to every entry of the int array.
     *
     * @param array input values; checked with {@code Utils.nonNull}
     * @param func  function to apply; checked with {@code Utils.nonNull}
     * @return the transformed values
     */
    public static double[] applyToArray(int[] array, IntToDoubleFunction func) {
        Utils.nonNull(func);
        Utils.nonNull(array);
        double[] transformed = new double[array.length];
        Arrays.setAll(transformed, idx -> func.applyAsDouble(array[idx]));
        return transformed;
    }

    /**
     * Replaces every entry of {@code array} with {@code func} applied to it.
     *
     * @param array values to transform; checked with {@code Utils.nonNull} and overwritten
     * @param func  function to apply; checked with {@code Utils.nonNull}
     * @return the same array instance after the update
     */
    public static double[] applyToArrayInPlace(double[] array, DoubleUnaryOperator func) {
        Utils.nonNull(array);
        Utils.nonNull(func);
        int idx = 0;
        while (idx < array.length) {
            array[idx] = func.applyAsDouble(array[idx]);
            idx++;
        }
        return array;
    }

    /**
     * Tests whether every entry of {@code array} satisfies {@code pred}, stopping at the first
     * entry that does not.
     *
     * @param array values to test; checked with {@code Utils.nonNull}
     * @param pred  condition to check; checked with {@code Utils.nonNull}
     * @return {@code true} when no entry fails the predicate (including the empty array)
     */
    public static boolean allMatch(double[] array, DoublePredicate pred) {
        Utils.nonNull(array);
        Utils.nonNull(pred);
        return Arrays.stream(array).allMatch(pred);
    }

    /**
     * Tests whether every entry of {@code array} satisfies {@code pred}, stopping at the first
     * entry that does not.
     *
     * @param array values to test; checked with {@code Utils.nonNull}
     * @param pred  condition to check; checked with {@code Utils.nonNull}
     * @return {@code true} when no entry fails the predicate (including the empty array)
     */
    public static boolean allMatch(int[] array, IntPredicate pred) {
        Utils.nonNull(array, "array may not be null");
        Utils.nonNull(pred, "predicate may not be null");
        for (int idx = 0; idx < array.length; idx++) {
            if (!pred.test(array[idx])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Index of the largest entry of an int array. Ties resolve to the lowest index; 0 is
     * returned for an empty array and for an array consisting only of {@link Integer#MIN_VALUE}.
     *
     * @param array values to scan
     * @return index of the first occurrence of the maximum
     */
    public static int maxElementIndex(int[] array) {
        int bestIndex = 0;
        int bestValue = Integer.MIN_VALUE;
        int idx = 0;
        for (int value : array) {
            if (value > bestValue) {
                bestValue = value;
                bestIndex = idx;
            }
            idx++;
        }
        return bestIndex;
    }

    /**
     * Narrows a {@code long} to an {@code int}, throwing a caller-supplied exception when the
     * value does not fit.
     *
     * @param toConvert         value to narrow
     * @param exceptionSupplier creates the exception thrown for out-of-range values
     * @return {@code toConvert} as an int
     */
    public static int toIntExactOrThrow(long toConvert, Supplier<RuntimeException> exceptionSupplier) {
        boolean fitsInInt = toConvert >= Integer.MIN_VALUE && toConvert <= Integer.MAX_VALUE;
        if (!fitsInInt) {
            throw exceptionSupplier.get();
        }
        return (int) toConvert;
    }

    /**
     * Rational function of {@code p * (1 - p)}:
     * {@code q * (11 + 33 q) / (2 + 20 q)} with {@code q = p * (1 - p)}. Going by the method
     * name this approximates the entropy of a Bernoulli(p) variable without calling a logarithm.
     *
     * @param p success probability
     * @return the value of the rational expression
     */
    public static final double fastBernoulliEntropy(double p) {
        double q = p * (1.0 - p);
        double numerator = q * (11.0 + 33.0 * q);
        double denominator = 2.0 + 20.0 * q;
        return numerator / denominator;
    }

    /**
     * Precomputed values of {@code log10(1 + 10^(-d))} for {@code d = 0, 0.0001, ..., 8.0},
     * used by the {@code approximateLog10SumLog10} methods.
     */
    private static final class JacobianLogTable {
        /** Largest difference covered by the table; callers skip the lookup at or beyond it. */
        public static final double MAX_TOLERANCE = 8.0;
        /** Spacing between consecutive table entries. */
        private static final double TABLE_STEP = 1.0E-4;
        /** Reciprocal of the spacing, used to turn a difference into a table index. */
        private static final double INV_STEP = 10000.0;
        /** Entry k holds log10(1 + 10^(-k * TABLE_STEP)); 80001 entries cover 0 through 8.0 inclusive. */
        private static final double[] cache = new IndexRange(0, 80001)
                .mapToDouble(k -> Math.log10(1.0 + Math.pow(10.0, -k * TABLE_STEP)));

        private JacobianLogTable() {
        }

        /**
         * Returns the table entry nearest to {@code difference}.
         *
         * @param difference non-negative gap between two log10 values, expected below {@link #MAX_TOLERANCE}
         * @return the tabulated value of {@code log10(1 + 10^(-difference))}
         */
        public static double get(double difference) {
            double scaled = difference * INV_STEP;
            return cache[MathUtils.fastRound(scaled)];
        }
    }

    /**
     * Incrementally tracks the mean and the sum of squared deviations of a stream of
     * observations, so that mean, variance and standard deviation are available at any time
     * without storing the observations.
     */
    public static class RunningAverage {
        private double mean = 0.0;
        // Sum of squared deviations from the current mean.
        private double s = 0.0;
        private long obs_count = 0L;

        /**
         * Adds one observation, updating the mean and the sum of squared deviations.
         *
         * @param obs the new observation
         */
        public void add(double obs) {
            ++this.obs_count;
            double oldMean = this.mean;
            this.mean += (obs - this.mean) / (double) this.obs_count;
            // Uses both the previous and the updated mean.
            this.s += (obs - oldMean) * (obs - this.mean);
        }

        /**
         * Adds every number of the collection as an observation, in iteration order.
         *
         * @param col numbers to add
         */
        public void addAll(Collection<? extends Number> col) {
            for (Number o : col) {
                this.add(o.doubleValue());
            }
        }

        /** @return the mean of the observations seen so far (0 before any observation) */
        public double mean() {
            return this.mean;
        }

        /** @return the square root of {@link #var()}; NaN when exactly one observation was added */
        public double stddev() {
            return Math.sqrt(this.var());
        }

        /**
         * @return the sum of squared deviations divided by {@code count - 1}; this is 0/0 = NaN
         *         when exactly one observation was added
         */
        public double var() {
            return this.s / (double) (this.obs_count - 1L);
        }

        /** @return the number of observations added so far */
        public long observationCount() {
            return this.obs_count;
        }

        /** @return an independent copy carrying the same mean, deviation sum and count */
        @Override
        public RunningAverage clone() {
            RunningAverage ra = new RunningAverage();
            ra.mean = this.mean;
            ra.s = this.s;
            ra.obs_count = this.obs_count;
            return ra;
        }

        /**
         * Folds the observations summarized by {@code other} into this instance, as if they had
         * been added here one by one. {@code other} is not modified.
         *
         * <p>Besides the two within-group deviation sums, the combined sum of squared deviations
         * includes the between-group term {@code delta^2 * n1 * n2 / (n1 + n2)}, where
         * {@code delta} is the difference of the two means; without it the variance would be
         * understated whenever the means differ.
         *
         * @param other summary to merge in
         */
        public void merge(RunningAverage other) {
            long thisCount = this.obs_count;
            long otherCount = other.obs_count;
            // Guard against 0/0 when both summaries are empty.
            if (thisCount > 0L || otherCount > 0L) {
                double total = (double) (thisCount + otherCount);
                double otherMean = other.mean;
                double otherS = other.s;
                double delta = otherMean - this.mean;
                this.mean = (this.mean * (double) thisCount + otherMean * (double) otherCount) / total;
                this.s += otherS + delta * delta * ((double) thisCount * (double) otherCount / total);
            }
            this.obs_count = thisCount + otherCount;
        }
    }

    @FunctionalInterface
    public static interface IntToDoubleArrayFunction {
        public double[] apply(int var1);
    }
}

