/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.models;

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.models.CopyRatioParameter;
import org.broadinstitute.hellbender.tools.copynumber.models.CopyRatioSegmentedData;
import org.broadinstitute.hellbender.tools.copynumber.models.CopyRatioState;
import org.broadinstitute.hellbender.tools.copynumber.models.FunctionCache;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.mcmc.MinibatchSliceSampler;
import org.broadinstitute.hellbender.utils.mcmc.ParameterSampler;

/**
 * Samplers for the parameters of the copy-ratio model. Each sampler draws a new value for one parameter
 * conditioned on the current {@link CopyRatioState} and the observed {@link CopyRatioSegmentedData}.
 *
 * <p>In the likelihoods below, a non-outlier log2 copy-ratio point is treated as normally distributed around
 * its segment mean with a single global variance. An outlier point instead contributes a fixed log-likelihood,
 * which is supplied to {@link OutlierIndicatorsSampler}.</p>
 */
final class CopyRatioSamplers {
    private static final Logger logger = LogManager.getLogger(CopyRatioSamplers.class);

    /** Cache of natural logarithms, used for the {@code log(variance)} term in {@link VarianceSampler}. */
    private static final FunctionCache<Double> logCache = new FunctionCache<Double>(FastMath::log);

    /** Flat (constant) log prior used for every slice-sampled parameter. */
    private static final Function<Double, Double> UNIFORM_LOG_PRIOR = x -> 0.0;

    /** Minibatch size for the global (variance) slice sampler, which runs over all non-outlier points. */
    private static final int GLOBAL_MINIBATCH_SIZE = 1000;

    /** Minibatch size for the per-segment (mean) slice samplers. */
    private static final int SEGMENT_MINIBATCH_SIZE = 100;

    /** Approximation threshold passed to {@link MinibatchSliceSampler}; see that class for its semantics. */
    private static final double APPROX_THRESHOLD = 0.1;

    private CopyRatioSamplers() {
    }

    /**
     * Returns {@code (quantity - mean)^2 / (2 * variance)}, the negated exponent of a normal density.
     *
     * @param quantity observed value
     * @param mean     mean of the normal distribution
     * @param variance variance of the normal distribution
     * @return the squared standardized distance divided by two
     */
    private static double normalTerm(final double quantity, final double mean, final double variance) {
        return (quantity - mean) * (quantity - mean) / (2.0 * variance);
    }

    /**
     * Samples a Bernoulli outlier indicator for every point.
     *
     * <p>For a point {@code x} in a segment with mean {@code m}, the outlier probability is
     * {@code p*U / (p*U + (1 - p)*N(x | m, variance))}. Here {@code p} is the current outlier probability,
     * {@code log U} is {@code outlierUniformLogLikelihood}, and {@code N} is the normal density. The ratio is
     * evaluated in log space.</p>
     */
    static final class OutlierIndicatorsSampler
            implements ParameterSampler<CopyRatioState.OutlierIndicators, CopyRatioParameter, CopyRatioState, CopyRatioSegmentedData> {
        private final double outlierUniformLogLikelihood;

        /**
         * @param outlierUniformLogLikelihood log-likelihood assigned to any point under the outlier component
         */
        OutlierIndicatorsSampler(final double outlierUniformLogLikelihood) {
            this.outlierUniformLogLikelihood = outlierUniformLogLikelihood;
        }

        @Override
        public CopyRatioState.OutlierIndicators sample(final RandomGenerator rng,
                                                       final CopyRatioState state,
                                                       final CopyRatioSegmentedData data) {
            logger.debug("Sampling outlier indicators...");
            final double outlierProbability = state.outlierProbability();
            final double variance = state.variance();
            final double outlierUnnormalizedLogProbability =
                    FastMath.log(outlierProbability) + outlierUniformLogLikelihood;
            // log((1 - p) / sqrt(2 pi variance)): the part of the non-outlier log probability shared by all points
            final double notOutlierUnnormalizedLogProbabilityPrefactor =
                    FastMath.log((1.0 - outlierProbability) / FastMath.sqrt(Math.PI * 2 * variance));
            final List<Boolean> indicators = new ArrayList<>(data.getNumPoints());
            // NOTE(review): indicators are appended in segment-major order, while other samplers read them via
            // IndexedCopyRatio.getIndex(); this assumes both orders coincide -- confirm against CopyRatioSegmentedData.
            for (int segmentIndex = 0; segmentIndex < data.getNumSegments(); segmentIndex++) {
                final double segmentMean = state.segmentMean(segmentIndex);
                for (final CopyRatioSegmentedData.IndexedCopyRatio indexedCopyRatio : data.getIndexedCopyRatiosInSegment(segmentIndex)) {
                    final double notOutlierUnnormalizedLogProbability = notOutlierUnnormalizedLogProbabilityPrefactor
                            - normalTerm(indexedCopyRatio.getLog2CopyRatioValue(), segmentMean, variance);
                    final double conditionalProbability = FastMath.exp(outlierUnnormalizedLogProbability
                            - NaturalLogUtils.logSumLog(outlierUnnormalizedLogProbability, notOutlierUnnormalizedLogProbability));
                    indicators.add(rng.nextDouble() < conditionalProbability);
                }
            }
            return new CopyRatioState.OutlierIndicators(indicators);
        }
    }

    /**
     * Samples each segment mean with a {@link MinibatchSliceSampler}, using a flat prior and
     * {@code meanMin}/{@code meanMax} as the sampler bounds.
     *
     * <p>Outlier points contribute zero to the log-likelihood. Segments containing no points are assigned
     * {@link Double#NaN}.</p>
     */
    static final class SegmentMeansSampler
            implements ParameterSampler<CopyRatioState.SegmentMeans, CopyRatioParameter, CopyRatioState, CopyRatioSegmentedData> {
        private final double meanMin;
        private final double meanMax;
        private final double meanSliceSamplingWidth;

        /**
         * @param meanMin                lower bound passed to the slice sampler
         * @param meanMax                upper bound passed to the slice sampler
         * @param meanSliceSamplingWidth slice-sampling width passed to the slice sampler
         */
        SegmentMeansSampler(final double meanMin, final double meanMax, final double meanSliceSamplingWidth) {
            this.meanMin = meanMin;
            this.meanMax = meanMax;
            this.meanSliceSamplingWidth = meanSliceSamplingWidth;
        }

        @Override
        public CopyRatioState.SegmentMeans sample(final RandomGenerator rng,
                                                  final CopyRatioState state,
                                                  final CopyRatioSegmentedData data) {
            final double variance = state.variance();
            final BiFunction<CopyRatioSegmentedData.IndexedCopyRatio, Double, Double> logConditionalPDF = (icr, newMean) ->
                    state.outlierIndicator(icr.getIndex())
                            ? 0.0
                            : -normalTerm(icr.getLog2CopyRatioValue(), newMean, variance);
            final List<Double> means = new ArrayList<>(data.getNumSegments());
            for (int segmentIndex = 0; segmentIndex < data.getNumSegments(); segmentIndex++) {
                final List<CopyRatioSegmentedData.IndexedCopyRatio> indexedCopyRatiosInSegment =
                        data.getIndexedCopyRatiosInSegment(segmentIndex);
                if (indexedCopyRatiosInSegment.isEmpty()) {
                    means.add(Double.NaN);
                    continue;
                }
                logger.debug("Sampling mean for segment {}...", segmentIndex);
                final MinibatchSliceSampler<CopyRatioSegmentedData.IndexedCopyRatio> sampler =
                        new MinibatchSliceSampler<CopyRatioSegmentedData.IndexedCopyRatio>(
                                rng, indexedCopyRatiosInSegment, UNIFORM_LOG_PRIOR, logConditionalPDF,
                                meanMin, meanMax, meanSliceSamplingWidth,
                                SEGMENT_MINIBATCH_SIZE, APPROX_THRESHOLD);
                means.add(sampler.sample(state.segmentMean(segmentIndex)));
            }
            return new CopyRatioState.SegmentMeans(means);
        }
    }

    /**
     * Samples the global outlier probability from its Beta posterior:
     * {@code Beta(alpha + numOutliers, beta + numPoints - numOutliers)}.
     */
    static final class OutlierProbabilitySampler
            implements ParameterSampler<Double, CopyRatioParameter, CopyRatioState, CopyRatioSegmentedData> {
        private final double outlierProbabilityPriorAlpha;
        private final double outlierProbabilityPriorBeta;

        /**
         * @param outlierProbabilityPriorAlpha alpha parameter of the Beta prior on the outlier probability
         * @param outlierProbabilityPriorBeta  beta parameter of the Beta prior on the outlier probability
         */
        OutlierProbabilitySampler(final double outlierProbabilityPriorAlpha, final double outlierProbabilityPriorBeta) {
            this.outlierProbabilityPriorAlpha = outlierProbabilityPriorAlpha;
            this.outlierProbabilityPriorBeta = outlierProbabilityPriorBeta;
        }

        @Override
        public Double sample(final RandomGenerator rng,
                             final CopyRatioState state,
                             final CopyRatioSegmentedData data) {
            logger.debug("Sampling outlier probability...");
            final int numPoints = data.getNumPoints();
            final long numOutliers = IntStream.range(0, numPoints).filter(state::outlierIndicator).count();
            return new BetaDistribution(rng,
                    outlierProbabilityPriorAlpha + (double) numOutliers,
                    outlierProbabilityPriorBeta + (double) numPoints - (double) numOutliers).sample();
        }
    }

    /**
     * Samples the global variance with a {@link MinibatchSliceSampler} over all non-outlier points, using a flat
     * prior and {@code varianceMin}/{@code varianceMax} as the sampler bounds.
     *
     * <p>Each point contributes {@code -0.5 * log(variance) - (x - m)^2 / (2 * variance)}, where {@code m} is the
     * current mean of the point's segment. This is the normal log-density up to a constant.</p>
     */
    static final class VarianceSampler
            implements ParameterSampler<Double, CopyRatioParameter, CopyRatioState, CopyRatioSegmentedData> {
        private final double varianceMin;
        private final double varianceMax;
        private final double varianceSliceSamplingWidth;

        /**
         * @param varianceMin                lower bound passed to the slice sampler
         * @param varianceMax                upper bound passed to the slice sampler
         * @param varianceSliceSamplingWidth slice-sampling width passed to the slice sampler
         */
        VarianceSampler(final double varianceMin, final double varianceMax, final double varianceSliceSamplingWidth) {
            this.varianceMin = varianceMin;
            this.varianceMax = varianceMax;
            this.varianceSliceSamplingWidth = varianceSliceSamplingWidth;
        }

        @Override
        public Double sample(final RandomGenerator rng,
                             final CopyRatioState state,
                             final CopyRatioSegmentedData data) {
            logger.debug("Sampling variance...");
            final List<CopyRatioSegmentedData.IndexedCopyRatio> nonOutlierIndexedCopyRatios = data.getIndexedCopyRatios().stream()
                    .filter(icr -> !state.outlierIndicator(icr.getIndex()))
                    .collect(Collectors.toList());
            final BiFunction<CopyRatioSegmentedData.IndexedCopyRatio, Double, Double> logConditionalPDF = (icr, newVariance) ->
                    -0.5 * logCache.computeIfAbsent(newVariance)
                            - normalTerm(icr.getLog2CopyRatioValue(), state.segmentMean(icr.getSegmentIndex()), newVariance);
            return new MinibatchSliceSampler<CopyRatioSegmentedData.IndexedCopyRatio>(
                    rng, nonOutlierIndexedCopyRatios, UNIFORM_LOG_PRIOR, logConditionalPDF,
                    varianceMin, varianceMax, varianceSliceSamplingWidth,
                    GLOBAL_MINIBATCH_SIZE, APPROX_THRESHOLD).sample(state.variance());
        }
    }
}

