/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect.clustering;

import com.google.common.primitives.Doubles;
import htsjdk.variant.variantcontext.VariantContext;
import it.unimi.dsi.fastutil.ints.Int2DoubleArrayMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.util.MathArrays;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.mutect.MutectStats;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.AlleleFractionCluster;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.BetaBinomialCluster;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.BinomialCluster;
import org.broadinstitute.hellbender.tools.walkers.mutect.clustering.Datum;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.M2FiltersArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Mixture model over the alt read counts of candidate somatic variants.
 *
 * <p>Cluster 0 is a beta-binomial "background" cluster and cluster 1 is a beta-binomial high-allele-fraction
 * cluster. During one-time initialization, up to {@code MAX_BINOMIAL_CLUSTERS} binomial clusters are split off
 * from the background cluster's weight. Alongside the cluster weights and parameters, the model learns
 * per-indel-length log priors of somatic variation and the log prior of a real variant versus a technical
 * artifact.</p>
 *
 * <p>Usage: call {@link #record} for each site, then {@link #learnAndClearAccumulatedData} to run EM over the
 * accumulated data. Instances hold mutable state.</p>
 */
public class SomaticClusteringModel {
    protected final Logger logger = LogManager.getLogger(this.getClass());

    private boolean clustersHaveBeenInitialized;

    // Indel lengths in [-MAX, MAX] get their own prior; other lengths fall back to the smallest prior seen.
    private static final int MAX_INDEL_SIZE_IN_PRIOR_MAP = 10;
    private static final int NUM_INITIALIZATION_QUANTILES = 50;
    // Not referenced in this class; MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER == 0.1 * NUM_INITIALIZATION_QUANTILES.
    private static final double MIN_QUANTILE_FOR_MAKING_CLUSTER = 0.1;
    private static final int MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER = 5;

    // Natural-log prior of a somatic variant, keyed by indel length (0 = SNV).
    private final Map<Integer, Double> logVariantPriors = new Int2DoubleArrayMap();
    private double logVariantVsArtifactPrior;
    private final OptionalDouble callableSites;

    private static final double INITIAL_HIGH_AF_WEIGHT = 0.01;
    public static final double MAX_FRACTION_OF_BACKGROUND_TO_SPLIT_OFF = 0.9;
    private static final double REGULARIZING_PSEUDOCOUNT = 1.0;

    // Natural-log mixture weights, parallel to 'clusters'.
    private double[] logClusterWeights;

    private static final int NUM_ITERATIONS = 5;
    private static final int MAX_BINOMIAL_CLUSTERS = 5;

    private static final BetaDistributionShape INITIAL_HIGH_AF_BETA = new BetaDistributionShape(10.0, 1.0);
    private static final BetaDistributionShape INITIAL_BACKGROUND_BETA = BetaDistributionShape.FLAT_BETA;

    final List<Datum> data = new ArrayList<>();
    final List<AlleleFractionCluster> clusters = new ArrayList<>();

    private final MutableInt obviousArtifactCount = new MutableInt(0);
    private static final double OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD = 0.9;

    /**
     * Creates the model with a background and a high-AF beta-binomial cluster.
     *
     * @param MTFAC       source of the initial SNV/indel log priors and variant-vs-artifact log prior
     * @param mutectStats Mutect stats. The "callable" statistic, if present and at least 1, enables learning of
     *                    per-indel-length priors.
     */
    public SomaticClusteringModel(final M2FiltersArgumentCollection MTFAC, final List<MutectStats> mutectStats) {
        IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP)
                .forEach(n -> logVariantPriors.put(n, MTFAC.getLogIndelPrior()));
        logVariantPriors.put(0, MTFAC.getLogSnvPrior());
        logVariantVsArtifactPrior = MTFAC.initialLogPriorOfVariantVersusArtifact;

        final OptionalDouble callableSitesFromStats = mutectStats.stream()
                .filter(stat -> stat.getStatistic().equals("callable"))
                .mapToDouble(MutectStats::getValue)
                .findFirst();
        final boolean noCallableSites = callableSitesFromStats.isPresent() && callableSitesFromStats.getAsDouble() < 1.0;
        if (noCallableSites) {
            logger.warn("No callable sites found in Mutect stats.  Running without the full somatic clustering model.  Something is seriously wrong!");
        }
        callableSites = noCallableSites ? OptionalDouble.empty() : callableSitesFromStats;

        clusters.add(new BetaBinomialCluster(INITIAL_BACKGROUND_BETA));
        clusters.add(new BetaBinomialCluster(INITIAL_HIGH_AF_BETA));
        // The weights must be complementary, so background = log(1 - w) = log1p(-w), not log1p(+w).
        logClusterWeights = new double[] {Math.log1p(-INITIAL_HIGH_AF_WEIGHT), Math.log(INITIAL_HIGH_AF_WEIGHT)};
    }

    /**
     * Accumulates the evidence for one variant context for the next learning pass.
     *
     * <p>Symbolic alt alleles have their depth zeroed and are skipped. An alt whose artifact probability exceeds
     * {@code OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD} only increments the obvious-artifact count. An alt whose
     * non-somatic probability exceeds the same threshold is dropped. Every other alt is stored as a {@link Datum}.</p>
     *
     * @param tumorADs                tumor allele depths indexed by allele (0 = reference). <b>Modified in place:</b>
     *                                entries of symbolic alts are set to 0.
     * @param tumorLogOdds            one tumor log odds per alt allele
     * @param artifactProbabilities   one artifact probability per alt allele
     * @param nonSomaticProbabilities one non-somatic probability per alt allele
     * @param vc                      the variant context the alt alleles belong to
     */
    public void record(final int[] tumorADs, final double[] tumorLogOdds, final List<Double> artifactProbabilities,
                       final List<Double> nonSomaticProbabilities, final VariantContext vc) {
        // Alt index i corresponds to tumorADs[i + 1], because tumorADs[0] is the reference depth.
        for (int altIndex = 0; altIndex < vc.getNAlleles() - 1; altIndex++) {
            if (vc.getAlternateAllele(altIndex).isSymbolic()) {
                tumorADs[altIndex + 1] = 0;
            }
        }
        final int totalAD = (int) MathUtils.sum(tumorADs);

        for (int altIndex = 0; altIndex < tumorLogOdds.length; altIndex++) {
            if (vc.getAlternateAllele(altIndex).isSymbolic()) {
                continue;
            }
            final double artifactProb = artifactProbabilities.get(altIndex);
            final double nonSomaticProb = nonSomaticProbabilities.get(altIndex);
            if (artifactProb > OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD) {
                obviousArtifactCount.increment();
                continue;
            }
            if (nonSomaticProb > OBVIOUS_ARTIFACT_PROBABILITY_THRESHOLD) {
                continue;
            }
            data.add(new Datum(tumorLogOdds[altIndex], artifactProb, nonSomaticProb, tumorADs[altIndex + 1], totalAD,
                    indelLength(vc, altIndex)));
        }
    }

    /**
     * Returns the natural-log prior of a somatic variant for the given alt allele, based on its indel length.
     *
     * @param vc       the variant context
     * @param altIndex 0-based index among the alternate alleles
     * @return the log prior. For SNVs it includes an extra log(1/3) term (presumably one of three possible
     *         alternate bases).
     */
    public double getLogPriorOfSomaticVariant(final VariantContext vc, final int altIndex) {
        return getLogPriorOfSomaticVariant(indelLength(vc, altIndex));
    }

    /**
     * Returns the current natural-log prior probability of a real variant as opposed to a technical artifact.
     *
     * @return the current log prior of a real variant versus a technical artifact
     */
    public double getLogPriorOfVariantVersusArtifact() {
        return logVariantVsArtifactPrior;
    }

    /**
     * Returns the posterior probability that {@code datum} is a sequencing error.
     *
     * <p>The variant hypothesis uses the cluster mixture, scored with each cluster's
     * {@code correctedLogLikelihood}. It is weighed against the somatic prior for the datum's indel length.</p>
     *
     * @param datum the datum to evaluate
     * @return the posterior probability of a sequencing error, as computed by
     *         {@code Mutect2FilteringEngine.posteriorProbabilityOfError}
     */
    public double probabilityOfSequencingError(final Datum datum) {
        final double[] logClusterLikelihoods = new IndexRange(0, clusters.size())
                .mapToDouble(c -> logClusterWeights[c] + clusters.get(c).correctedLogLikelihood(datum));
        final double variantLogLikelihood = NaturalLogUtils.logSumExp(logClusterLikelihoods);
        return Mutect2FilteringEngine.posteriorProbabilityOfError(variantLogLikelihood, getLogPriorOfSomaticVariant(datum.getIndelLength()));
    }

    // Probability that a datum is neither an artifact, nor non-somatic, nor a sequencing error.
    private double probabilityOfSomaticVariant(final Datum datum) {
        final double artifactProb = datum.getArtifactProb();
        final double nonSomaticProb = datum.getNonSequencingErrorProb();
        final double sequencingErrorProb = probabilityOfSequencingError(datum);
        return (1.0 - artifactProb) * (1.0 - nonSomaticProb) * (1.0 - sequencingErrorProb);
    }

    /**
     * Greedily adds binomial clusters centered on the largest peak of the background cluster's allele-fraction
     * responsibilities.
     *
     * <p>Stops when no qualifying peak remains, when the BIC-like score decreases (the last cluster is then rolled
     * back), or after {@code MAX_BINOMIAL_CLUSTERS} clusters have been added.</p>
     */
    private void initializeClusters() {
        Utils.validate(!clustersHaveBeenInitialized, "Clusters have already been initialized.");
        final double[] somaticProbs = data.stream().mapToDouble(this::probabilityOfSomaticVariant).toArray();

        double previousBIC = Double.NEGATIVE_INFINITY;
        for (int cluster = 0; cluster < MAX_BINOMIAL_CLUSTERS; cluster++) {
            final double[] oldLogClusterWeights = Arrays.copyOf(logClusterWeights, logClusterWeights.length);

            final double[] backgroundProbsGivenSomatic = data.stream()
                    .mapToDouble(datum -> backgroundProbGivenSomatic(datum.getTotalCount(), datum.getAltCount()))
                    .toArray();
            final double[] backgroundProbs = MathArrays.ebeMultiply(somaticProbs, backgroundProbsGivenSomatic);
            final double[] alleleFractionQuantiles = calculateAlleleFractionQuantiles();
            final double[] totalQuantileBackgroundResponsibilities =
                    calculateQuantileBackgroundResponsibilities(alleleFractionQuantiles, backgroundProbs);
            final List<Pair<Double, Double>> peaksAndMasses =
                    calculatePeaksAndMasses(alleleFractionQuantiles, totalQuantileBackgroundResponsibilities);
            if (peaksAndMasses.isEmpty()) {
                break;
            }

            // Largest mass wins. Ties keep the first peak in allele-fraction order.
            final Pair<Double, Double> biggestPeakAndMass = peaksAndMasses.stream()
                    .max(Comparator.comparingDouble(Pair::getRight))
                    .get();
            // Don't create a cluster for a peak at a very low allele fraction.
            final int minQuantileIndex = Math.min(MIN_QUANTILE_INDEX_FOR_MAKING_CLUSTER, alleleFractionQuantiles.length - 1);
            if (biggestPeakAndMass.getLeft() < alleleFractionQuantiles[minQuantileIndex]) {
                break;
            }

            final double totalMass = peaksAndMasses.stream().mapToDouble(Pair::getRight).sum();
            final double fractionOfBackgroundToSplit =
                    Math.min(MAX_FRACTION_OF_BACKGROUND_TO_SPLIT_OFF, biggestPeakAndMass.getRight() / totalMass);
            // Split background weight w into f*w (new cluster) and (1 - f)*w (remaining background).
            final double newClusterLogWeight = Math.log(fractionOfBackgroundToSplit) + logClusterWeights[0];
            final double newBackgroundLogWeight = Math.log1p(-fractionOfBackgroundToSplit) + logClusterWeights[0];

            clusters.add(new BinomialCluster(biggestPeakAndMass.getLeft()));
            final double[] newLogWeights = Arrays.copyOf(logClusterWeights, logClusterWeights.length + 1);
            newLogWeights[0] = newBackgroundLogWeight;
            newLogWeights[newLogWeights.length - 1] = newClusterLogWeight;
            logClusterWeights = newLogWeights;

            for (int n = 0; n < NUM_ITERATIONS; n++) {
                performEMIteration(false);
            }

            final double[] logLikelihoodsGivenSomatic = data.stream()
                    .mapToDouble(datum -> logLikelihoodGivenSomatic(datum.getTotalCount(), datum.getAltCount()))
                    .toArray();
            final double weightedLogLikelihood = MathUtils.sum(MathArrays.ebeMultiply(somaticProbs, logLikelihoodsGivenSomatic));
            final double effectiveSomaticCount = MathUtils.sum(somaticProbs);
            final double numParameters = 2 * clusters.size();
            final double currentBIC = weightedLogLikelihood - numParameters * Math.log(effectiveSomaticCount);
            if (currentBIC < previousBIC) {
                // The new cluster did not improve the penalized fit: drop it and restore the previous weights.
                // NOTE(review): parameters learned by the surviving clusters during the EM iterations above are
                // not restored.
                clusters.remove(clusters.size() - 1);
                logClusterWeights = oldLogClusterWeights;
                break;
            }
            previousBIC = currentBIC;
        }
        clustersHaveBeenInitialized = true;
    }

    /**
     * Returns distinct allele fractions at (approximately) {@code NUM_INITIALIZATION_QUANTILES} quantiles.
     *
     * <p>Data are sorted by allele fraction and weighted by their somatic probability. The values are in
     * ascending order.</p>
     */
    private double[] calculateAlleleFractionQuantiles() {
        final List<Pair<Double, Double>> alleleFractionsAndSomaticProbs = data.stream()
                .<Pair<Double, Double>>map(datum -> ImmutablePair.of(
                        (double) datum.getAltCount() / datum.getTotalCount(), probabilityOfSomaticVariant(datum)))
                .sorted(Comparator.comparingDouble(Pair::getLeft))
                .collect(Collectors.toList());

        final double totalSomaticProb = alleleFractionsAndSomaticProbs.stream().mapToDouble(Pair::getRight).sum();
        final double quantileStep = totalSomaticProb / NUM_INITIALIZATION_QUANTILES;
        double quantileProb = quantileStep;
        double cumulativeProb = 0.0;
        final List<Double> alleleFractionQuantiles = new ArrayList<>(NUM_INITIALIZATION_QUANTILES);
        for (final Pair<Double, Double> alleleFractionAndSomaticProb : alleleFractionsAndSomaticProbs) {
            cumulativeProb += alleleFractionAndSomaticProb.getRight();
            if (cumulativeProb > quantileProb) {
                alleleFractionQuantiles.add(alleleFractionAndSomaticProb.getLeft());
                // Advance past every quantile boundary that this datum crossed.
                while (cumulativeProb > quantileProb) {
                    quantileProb += quantileStep;
                }
            }
        }
        return Doubles.toArray(alleleFractionQuantiles.stream().distinct().collect(Collectors.toList()));
    }

    /**
     * Sums, over all data, the background-weighted binomial density of each datum at each allele-fraction quantile.
     * The binomial probability is multiplied by (totalCount + 1).
     */
    private double[] calculateQuantileBackgroundResponsibilities(final double[] alleleFractionQuantiles, final double[] backgroundProbs) {
        final double[] totalQuantileResponsibilities = new double[alleleFractionQuantiles.length];
        for (int n = 0; n < data.size(); n++) {
            final Datum datum = data.get(n);
            final double backgroundProb = backgroundProbs[n];
            final double[] quantileResponsibilities = MathUtils.applyToArray(alleleFractionQuantiles,
                    f -> MathUtils.binomialProbability(datum.getTotalCount(), datum.getAltCount(), f));
            MathUtils.applyToArrayInPlace(quantileResponsibilities, x -> x * backgroundProb * (double) (datum.getTotalCount() + 1));
            MathUtils.addToArrayInPlace(totalQuantileResponsibilities, quantileResponsibilities);
        }
        return totalQuantileResponsibilities;
    }

    /**
     * Partitions the quantile responsibility curve at its interior local minima.
     *
     * @return one (peak allele fraction, trapezoid-rule mass) pair per segment; the last segment always ends at
     *         the final quantile. The result is empty if there are no quantiles.
     */
    private List<Pair<Double, Double>> calculatePeaksAndMasses(final double[] alleleFractionQuantiles, final double[] totalQuantileResponsibilities) {
        final List<Pair<Double, Double>> peaksAndMasses = new ArrayList<>();
        final int lastIndex = alleleFractionQuantiles.length - 1;
        double currentPeakMass = 0.0;
        double currentPeak = 0.0;
        double currentPeakResponsibility = 0.0;
        for (int q = 0; q <= lastIndex; q++) {
            final double leftResponsibility = q == 0 ? 0.0 : totalQuantileResponsibilities[q - 1];
            final double responsibility = totalQuantileResponsibilities[q];
            final double rightResponsibility = q == lastIndex ? 0.0 : totalQuantileResponsibilities[q + 1];
            final double leftAlleleFraction = q == 0 ? 0.0 : alleleFractionQuantiles[q - 1];
            final double alleleFraction = alleleFractionQuantiles[q];

            // Trapezoidal rule between consecutive quantiles.
            currentPeakMass += (alleleFraction - leftAlleleFraction) * (leftResponsibility + responsibility) / 2.0;
            if (responsibility > currentPeakResponsibility) {
                currentPeak = alleleFraction;
                currentPeakResponsibility = responsibility;
            }

            final int leftCompare = Double.compare(responsibility, leftResponsibility);
            final int rightCompare = Double.compare(responsibility, rightResponsibility);
            final boolean localMin = (leftCompare < 0 && rightCompare <= 0) || (leftCompare <= 0 && rightCompare < 0);
            if ((localMin && q > 0) || q == lastIndex) {
                peaksAndMasses.add(ImmutablePair.of(currentPeak, currentPeakMass));
                currentPeakMass = 0.0;
                currentPeak = alleleFraction;
                currentPeakResponsibility = responsibility;
            }
        }
        return peaksAndMasses;
    }

    // Log prior for an indel length. An unseen length is first stored with the smallest existing prior.
    private double getLogPriorOfSomaticVariant(final int indelLength) {
        if (!logVariantPriors.containsKey(indelLength)) {
            logVariantPriors.put(indelLength, logVariantPriors.values().stream().mapToDouble(d -> d).min().getAsDouble());
        }
        return logVariantPriors.get(indelLength) + (indelLength == 0 ? MathUtils.LOG_ONE_THIRD : 0.0);
    }

    /**
     * Initializes the clusters on first use, then runs {@code NUM_ITERATIONS} EM iterations that also update the
     * somatic priors. Afterwards it clears the accumulated data and resets the obvious-artifact count.
     */
    public void learnAndClearAccumulatedData() {
        if (!clustersHaveBeenInitialized) {
            initializeClusters();
        }
        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            performEMIteration(true);
        }
        data.clear();
        obviousArtifactCount.setValue(0);
    }

    /**
     * Runs one EM iteration.
     *
     * <p>E-step: computes cluster responsibilities scaled by each datum's somatic probability. M-step: updates the
     * cluster log weights (with a pseudocount) and cluster parameters. If {@code updateSomaticPriors} is true, it
     * also updates the variant-vs-artifact prior and, when callable sites are known, the per-indel-length priors.</p>
     */
    private void performEMIteration(final boolean updateSomaticPriors) {
        final Map<Integer, MutableDouble> variantCountsByIndelLength =
                IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).boxed()
                        .collect(Collectors.toMap(l -> l, l -> new MutableDouble(0.0)));
        final List<double[]> responsibilities = new ArrayList<>(data.size());
        final double[] totalClusterResponsibilities = new double[clusters.size()];

        for (final Datum datum : data) {
            final double somaticProb = probabilityOfSomaticVariant(datum);
            variantCountsByIndelLength.computeIfAbsent(datum.getIndelLength(), l -> new MutableDouble(0.0)).add(somaticProb);

            final double[] clusterResponsibilitiesIfSomatic = NaturalLogUtils.normalizeFromLogToLinearSpace(
                    logClusterLikelihoods(datum.getTotalCount(), datum.getAltCount()));
            final double[] clusterResponsibilities = MathArrays.scale(somaticProb, clusterResponsibilitiesIfSomatic);
            MathUtils.addToArrayInPlace(totalClusterResponsibilities, clusterResponsibilities);
            responsibilities.add(clusterResponsibilities);
        }

        MathUtils.applyToArrayInPlace(totalClusterResponsibilities, x -> x + REGULARIZING_PSEUDOCOUNT);
        logClusterWeights = MathUtils.applyToArrayInPlace(MathUtils.normalizeSumToOne(totalClusterResponsibilities), Math::log);

        final double technicalArtifactCount = obviousArtifactCount.intValue()
                + data.stream().mapToDouble(Datum::getArtifactProb).sum();
        final double variantCount = variantCountsByIndelLength.values().stream().mapToDouble(MutableDouble::doubleValue).sum();

        if (updateSomaticPriors) {
            logVariantVsArtifactPrior = Math.log((variantCount + REGULARIZING_PSEUDOCOUNT)
                    / (variantCount + technicalArtifactCount + REGULARIZING_PSEUDOCOUNT * 2.0));
            if (callableSites.isPresent()) {
                IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).forEach(n -> {
                    final double empiricalRatio = variantCountsByIndelLength.getOrDefault(n, new MutableDouble(0.0)).doubleValue()
                            / callableSites.getAsDouble();
                    // Floor the empirical rate so that no prior becomes log(0).
                    logVariantPriors.put(n, Math.log(Math.max(empiricalRatio, n == 0 ? 1.0E-8 : 1.0E-9)));
                });
            }
        }

        new IndexRange(0, clusters.size()).forEach(n -> {
            final double[] responsibilitiesForThisCluster = responsibilities.stream().mapToDouble(array -> array[n]).toArray();
            clusters.get(n).learn(data, responsibilitiesForThisCluster);
        });
    }

    /**
     * Returns the mixture log likelihood of the read counts, assuming the variant is somatic.
     *
     * @param totalCount total read count
     * @param altCount   alt read count
     * @return log of the sum over clusters of weight times likelihood
     */
    public double logLikelihoodGivenSomatic(final int totalCount, final int altCount) {
        return NaturalLogUtils.logSumExp(logClusterLikelihoods(totalCount, altCount));
    }

    // Posterior probability that the read counts belong to the background cluster (index 0), given somatic.
    private double backgroundProbGivenSomatic(final int totalCount, final int altCount) {
        return NaturalLogUtils.normalizeFromLogToLinearSpace(logClusterLikelihoods(totalCount, altCount))[0];
    }

    // Per-cluster log(weight) + log(likelihood) of the read counts; shared by the E-step and the likelihood queries.
    private double[] logClusterLikelihoods(final int totalCount, final int altCount) {
        return new IndexRange(0, clusters.size())
                .mapToDouble(c -> logClusterWeights[c] + clusters.get(c).logLikelihood(totalCount, altCount));
    }

    /**
     * Returns human-readable (key, value) pairs describing the learned model.
     *
     * <p>The pairs cover the log prior for SNVs and for each indel length in [-10, 10], the two beta-binomial
     * clusters, and the binomial clusters sorted by descending weight.</p>
     *
     * @return a new mutable list of metadata pairs
     */
    public List<Pair<String, String>> clusteringMetadata() {
        final List<Pair<String, String>> result = new ArrayList<>();
        IntStream.rangeClosed(-MAX_INDEL_SIZE_IN_PRIOR_MAP, MAX_INDEL_SIZE_IN_PRIOR_MAP).forEach(n -> {
            final double logPrior = logVariantPriors.get(n);
            final String type = n == 0 ? "SNV" : (n < 0 ? "deletion" : "insertion") + " of length " + Math.abs(n);
            result.add(ImmutablePair.of("Ln prior of " + type, Double.toString(logPrior)));
        });
        result.add(ImmutablePair.of("Background beta-binomial cluster",
                String.format("weight = %.4f, %s", Math.exp(logClusterWeights[0]), clusters.get(0).toString())));
        result.add(ImmutablePair.of("High-AF beta-binomial cluster",
                String.format("weight = %.4f, %s", Math.exp(logClusterWeights[1]), clusters.get(1).toString())));
        IntStream.range(2, clusters.size()).boxed()
                .sorted(Comparator.comparingDouble((Integer c) -> -logClusterWeights[c]))
                .forEach(c -> result.add(ImmutablePair.of("Binomial cluster",
                        String.format("weight = %.4f, %s", Math.exp(logClusterWeights[c]), clusters.get(c).toString()))));
        return result;
    }

    /**
     * Returns the length difference between an alt allele and the reference.
     *
     * @param vc       the variant context
     * @param altIndex 0-based index among the alternate alleles
     * @return alt length minus ref length: positive for insertions, negative for deletions, 0 for equal length
     */
    public static int indelLength(final VariantContext vc, final int altIndex) {
        return vc.getAlternateAllele(altIndex).length() - vc.getReference().length();
    }
}

