/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils.mcmc;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.primes.Primes;
import org.apache.commons.math3.random.RandomGenerator;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.mcmc.AbstractSliceSampler;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Slice sampler for a univariate variable whose unnormalized log posterior is
 * {@code logPrior(x) + sum_d logLikelihood(d, x)} over a fixed list of data points.
 *
 * <p>The slice-height test compares the proposed and current points using the mean per-datum
 * log-likelihood difference. It avoids evaluating every data point by processing (pseudo-randomly
 * ordered) minibatches and stopping early once a Student-t test estimates the probability of a
 * wrong decision to be below {@code approxThreshold}. If the test never stops early, all data are
 * used and the decision is exact.
 */
public final class MinibatchSliceSampler<DATA> extends AbstractSliceSampler {
    private final List<DATA> data;
    private final Function<Double, Double> logPrior;
    private final BiFunction<DATA, Double, Double> logLikelihood;
    private final int minibatchSize;
    private final double approxThreshold;
    private final int numDataPoints;

    // Quantities that depend only on xSample. They are recomputed whenever isGreaterThanSliceHeight
    // sees a different xSample. Presumably the base class calls it many times for one xSample while
    // building/shrinking the slice interval; confirm against AbstractSliceSampler.
    private Double xSampleCache = null;
    private Double logPriorCache = null;
    private Map<DATA, Double> logLikelihoodsCache = null;

    /**
     * @param rng             random generator; passed to the base class and used here to randomize minibatch order
     * @param data            data points; defensively copied into an unmodifiable list
     * @param logPrior        log prior as a function of the sampled variable
     * @param logLikelihood   per-datum log likelihood, {@code (datum, x) -> value}
     * @param xMin            lower bound; proposals below it are never accepted
     * @param xMax            upper bound; proposals above it are never accepted
     * @param width           slice width, passed to the base class
     * @param minibatchSize   number of data points per minibatch; must be greater than 1
     * @param approxThreshold early-stopping threshold on the estimated error probability; must be non-negative
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double xMin,
                                 final double xMax,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        super(rng, xMin, xMax, width);
        Utils.nonNull(data);
        Utils.nonNull(logPrior);
        Utils.nonNull(logLikelihood);
        Utils.validateArg(minibatchSize > 1, "Minibatch size must be greater than 1.");
        ParamUtils.isPositiveOrZero(approxThreshold, "Minibatch approximation threshold must be non-negative.");
        this.data = Collections.unmodifiableList(new ArrayList<>(data));
        this.logPrior = logPrior;
        this.logLikelihood = logLikelihood;
        this.minibatchSize = minibatchSize;
        this.approxThreshold = approxThreshold;
        this.numDataPoints = this.data.size();
    }

    /**
     * Creates a sampler whose variable is unbounded, i.e. with {@code xMin = -Infinity} and {@code xMax = +Infinity}.
     * See the full constructor for parameter descriptions.
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        this(rng, data, logPrior, logLikelihood, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
                width, minibatchSize, approxThreshold);
    }

    /**
     * Decides whether {@code logPost(xProposed) > logPost(xSample) - z}, where
     * {@code logPost(x) = logPrior(x) + sum_d logLikelihood(d, x)}.
     *
     * <p>Dividing by the number of data points N turns this into: is the mean of
     * {@code logLikelihood(d, xProposed) - logLikelihood(d, xSample)} greater than
     * {@code mu0 = (logPrior(xSample) - logPrior(xProposed) - z) / N}?
     * That mean is estimated sequentially over minibatches. The loop stops early when a t-test
     * error estimate drops below {@code approxThreshold}, and otherwise consumes all data.
     *
     * @return false if {@code xProposed} lies outside {@code [xMin, xMax]}; otherwise the (possibly approximate) decision
     */
    @Override
    boolean isGreaterThanSliceHeight(final double xProposed, final double xSample, final double z) {
        if (xProposed < xMin || xMax < xProposed) {
            return false;
        }
        if (xSampleCache == null || xSampleCache.doubleValue() != xSample) {
            xSampleCache = xSample;
            logPriorCache = logPrior.apply(xSample);
            logLikelihoodsCache = new HashMap<>(numDataPoints);
        }
        if (xSampleCache == null || logPriorCache == null || logLikelihoodsCache == null) {
            throw new GATKException.ShouldNeverReachHereException("Cache for xSample is in an invalid state.");
        }

        // with no data the posterior is the prior, so the test is exact and cheap
        if (numDataPoints == 0) {
            return logPrior.apply(xProposed) > logPriorCache - z;
        }

        final double mu0 = (logPriorCache - logPrior.apply(xProposed) - z) / numDataPoints;

        final int numMinibatches = Math.max(numDataPoints / minibatchSize, 1);
        // a single minibatch covers all data, so its order is irrelevant and no shuffling is needed
        final Iterator<DATA> dataIterator = numMinibatches > 1 ? lazyShuffleIterator(rng, data) : data.iterator();

        int numDataIndicesSeen = 0;
        double logLikelihoodDifferencesMean = 0.;
        double logLikelihoodDifferencesSquaredMean = 0.;
        for (int minibatchIndex = 0; minibatchIndex < numMinibatches; minibatchIndex++) {
            // the last minibatch absorbs the remainder so that every data point can be examined
            final int actualMinibatchSize = minibatchIndex == numMinibatches - 1
                    ? numDataPoints - numDataIndicesSeen
                    : minibatchSize;

            double logLikelihoodDifferencesMinibatchSum = 0.;
            double logLikelihoodDifferencesSquaredMinibatchSum = 0.;
            for (int i = 0; i < actualMinibatchSize; i++) {
                final DATA dataPoint = dataIterator.next();
                final double logLikelihoodxSample =
                        logLikelihoodsCache.computeIfAbsent(dataPoint, d -> logLikelihood.apply(d, xSample));
                final double logLikelihoodxProposed = logLikelihood.apply(dataPoint, xProposed);
                final double logLikelihoodDifference = logLikelihoodxProposed - logLikelihoodxSample;
                logLikelihoodDifferencesMinibatchSum += logLikelihoodDifference;
                logLikelihoodDifferencesSquaredMinibatchSum += logLikelihoodDifference * logLikelihoodDifference;
            }

            // running means over all data seen so far
            final int numDataIndicesSeenAfter = numDataIndicesSeen + actualMinibatchSize;
            logLikelihoodDifferencesMean =
                    (numDataIndicesSeen * logLikelihoodDifferencesMean + logLikelihoodDifferencesMinibatchSum)
                            / numDataIndicesSeenAfter;
            logLikelihoodDifferencesSquaredMean =
                    (numDataIndicesSeen * logLikelihoodDifferencesSquaredMean + logLikelihoodDifferencesSquaredMinibatchSum)
                            / numDataIndicesSeenAfter;
            numDataIndicesSeen = numDataIndicesSeenAfter;

            if (numDataIndicesSeen == numDataPoints) {
                break;  // all data used: the decision below is exact
            }
            final double delta = estimateErrorProbability(
                    logLikelihoodDifferencesMean, logLikelihoodDifferencesSquaredMean, mu0,
                    numDataIndicesSeen, numDataPoints);
            if (delta < approxThreshold) {
                break;
            }
        }
        return logLikelihoodDifferencesMean > mu0;
    }

    /**
     * Estimates the probability that comparing a subsample mean against {@code mu0} gives a different
     * answer than the full-data mean would. It uses a Student-t distribution with {@code numSeen - 1}
     * degrees of freedom and a standard error that includes the finite-population correction
     * {@code sqrt(1 - numSeen / numTotal)}.
     *
     * @param mean          mean of the differences seen so far
     * @param meanOfSquares mean of the squared differences seen so far
     * @param mu0           threshold the mean is compared against
     * @param numSeen       number of differences seen so far; must be at least 2
     * @param numTotal      total number of data points
     * @return the estimated error probability {@code 1 - T.cdf(|mean - mu0| / standardError)}
     */
    private static double estimateErrorProbability(final double mean,
                                                   final double meanOfSquares,
                                                   final double mu0,
                                                   final int numSeen,
                                                   final int numTotal) {
        // clamp: rounding can make E[x^2] - E[x]^2 marginally negative, which would otherwise yield NaN
        final double variance = Math.max(meanOfSquares - mean * mean, 0.);
        final double standardError = Math.sqrt(1. - (double) numSeen / numTotal) * Math.sqrt(variance / (numSeen - 1));
        // only the CDF is needed, so no RandomGenerator is supplied to the distribution
        final TDistribution tDistribution = new TDistribution(null, numSeen - 1);
        return 1. - tDistribution.cumulativeProbability(Math.abs((mean - mu0) / standardError));
    }

    /**
     * Returns an iterator that yields each element of {@code data} exactly once, in a pseudo-random
     * order, without copying the list. It walks the residues modulo a prime {@code p > data.size()}
     * with a random step and skips residues that are not valid indices.
     *
     * @throws NoSuchElementException from {@code next()} once every element has been returned
     */
    private static <T> Iterator<T> lazyShuffleIterator(final RandomGenerator rng, final List<T> data) {
        final int numDataPoints = data.size();
        // The modulus must be a prime strictly greater than numDataPoints, so every increment in
        // [1, numDataPoints] is coprime to it and the walk visits all residues. With
        // nextPrime(numDataPoints), a prime-sized list could draw increment == modulus and return
        // the same element forever.
        final int modulus = Primes.nextPrime(numDataPoints + 1);
        final int increment = rng.nextInt(numDataPoints) + 1;
        return new Iterator<T>() {
            private int numSeen = 0;
            private int index = increment;

            @Override
            public boolean hasNext() {
                return numSeen < numDataPoints;
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                do {
                    index = (index + increment) % modulus;
                } while (index >= numDataPoints);
                numSeen++;
                return data.get(index);
            }
        };
    }
}

