/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.readorientation;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.Histogram;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.MathArrays;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.walkers.readorientation.AltSiteRecord;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactPrior;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactState;
import org.broadinstitute.hellbender.tools.walkers.readorientation.BetaDistributionShape;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2FilterConstants;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2FilterUtils;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ReadOrientation;
import org.broadinstitute.hellbender.tools.walkers.validation.basicshortmutpileup.BetaBinomialDistribution;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.NaturalLogUtils;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Learns a prior distribution over {@link ArtifactState}s for a single reference context by
 * expectation-maximization (EM).
 *
 * <p>Three kinds of evidence are combined:</p>
 * <ul>
 *     <li>{@code refHistogram}: counts of sites with no alt reads, keyed by depth. Its value label is the
 *     reference context, which must be one of {@link F1R2FilterConstants#CANONICAL_KMERS}.</li>
 *     <li>{@code altDepthOneHistograms}: counts of sites with a single alt read, keyed by depth. Each label is
 *     parsed by {@link F1R2FilterUtils#labelToTriplet} into (string, alt allele, read orientation).</li>
 *     <li>{@code altDesignMatrix}: individual {@link AltSiteRecord}s, each carrying its own depth, alt count,
 *     alt F1R2 count and alt allele.</li>
 * </ul>
 * <p>Only histogram bins for depths {@code 1..maxDepth} are used; a missing bin counts as zero sites.</p>
 */
public class LearnReadOrientationModelEngine {
    private static final double ALT_PSEUDOCOUNT = 1.0;
    private static final double REF_PSEUDOCOUNT = 9.0;
    private static final double PSEUDOCOUNT_OF_HOM_LIKELY = 10000.0;
    private static final double PSEUDOCOUNT_OF_HOM_UNLIKELY = 3.0;
    private static final double BALANCED_HET_PSEUDOCOUNT = 5.0;
    private static final double BALANCED_F1R2_PRIOR = 10.0;
    private static final double PSEUDOCOUNT_OF_SOMATIC_ALT = 2.0;
    private static final double PSEUDOCOUNT_OF_SOMATIC_REF = 5.0;
    private static final double PSEUDOCOUNT_OF_LIKELY_OUTCOME = 100.0;
    private static final double PSEUDOCOUNT_OF_RARE_OUTCOME = 1.0;

    /** Beta pseudocounts (alpha, beta) for the alt allele fraction of each state. */
    private static final Map<ArtifactState, BetaDistributionShape> alleleFractionPseudoCounts = getPseudoCountsForAlleleFraction();
    /** Beta pseudocounts (alpha, beta) for the fraction of alt reads that are F1R2, per state. */
    private static final Map<ArtifactState, BetaDistributionShape> altF1R2FractionPseudoCounts = getPseudoCountsForAltF1R2Fraction();

    private final double convergenceThreshold;
    private final int maxEMIterations;
    private final int maxDepth;
    private final String referenceContext;
    private final Nucleotide refAllele;
    private final Histogram<Integer> refHistogram;
    private final List<Histogram<Integer>> altDepthOneHistograms;
    private final List<AltSiteRecord> altDesignMatrix;
    // Row n holds the state responsibilities of altDesignMatrix.get(n). Has at least one row (see constructor).
    private final RealMatrix altResponsibilities;
    // Responsibilities of a site with a single alt read, keyed by (depth, alt allele, read orientation).
    private final Map<Triple<Integer, Nucleotide, ReadOrientation>, double[]> responsibilitiesOfAltDepth1Sites;
    // Row i holds the state responsibilities of a ref site with depth i + 1.
    private final RealMatrix refResponsibilities;
    private final int numAltExamples;
    private final int numRefExamples;
    private final int numExamples;
    private final MutableInt numIterations = new MutableInt();
    private final Logger logger;
    // Per-state effective counts from the most recent M-step, before the pseudocounts are added.
    private RealVector effectiveCounts = new ArrayRealVector(F1R2FilterConstants.NUM_STATES);

    /**
     * @param refHistogram              depth histogram of ref sites; its value label is the reference context
     * @param altDepthOneHistograms     depth histograms of single-alt-read sites, one per labelled (alt allele, orientation)
     * @param altDesignMatrixForContext individually recorded alt sites for this context; may be empty
     * @param convergenceThreshold      EM stops once the L2 distance between successive priors is at most this
     * @param maxEMIterations           upper bound on the number of EM iterations
     * @param maxDepth                  largest depth considered; must be positive
     * @param logger                    receives one convergence summary line per call to {@link #learnPriorForArtifactStates()}
     */
    public LearnReadOrientationModelEngine(final Histogram<Integer> refHistogram, final List<Histogram<Integer>> altDepthOneHistograms,
                                           final List<AltSiteRecord> altDesignMatrixForContext, final double convergenceThreshold,
                                           final int maxEMIterations, final int maxDepth, final Logger logger) {
        this.refHistogram = Utils.nonNull(refHistogram);
        this.altDepthOneHistograms = Utils.nonNull(altDepthOneHistograms);
        this.altDesignMatrix = Utils.nonNull(altDesignMatrixForContext);
        Utils.validateArg(maxDepth > 0, String.format("maxDepth must be positive but got %d", maxDepth));
        this.referenceContext = refHistogram.getValueLabel();
        Utils.validate(this.referenceContext.length() == 3, String.format("reference context must have length %d but got %s", 3, this.referenceContext));
        Utils.validate(F1R2FilterConstants.CANONICAL_KMERS.contains(this.referenceContext), this.referenceContext + " is not in the set of canonical kmers");
        this.numAltExamples = this.altDesignMatrix.size() + altDepthOneHistograms.stream().mapToInt(h -> (int) h.getSumOfValues()).sum();
        this.numRefExamples = (int) refHistogram.getSumOfValues();
        this.numExamples = this.numAltExamples + this.numRefExamples;
        this.refResponsibilities = new Array2DRowRealMatrix(maxDepth, F1R2FilterConstants.NUM_STATES);
        // Commons Math rejects zero-row matrices, so keep one (unused, all-zero) row when there are no
        // design-matrix sites; only the first altDesignMatrix.size() rows are ever written or summed.
        this.altResponsibilities = new Array2DRowRealMatrix(Math.max(1, this.altDesignMatrix.size()), F1R2FilterConstants.NUM_STATES);
        this.responsibilitiesOfAltDepth1Sites = new HashMap<>();
        this.refAllele = F1R2FilterUtils.getMiddleBase(this.referenceContext);
        this.convergenceThreshold = convergenceThreshold;
        this.maxEMIterations = maxEMIterations;
        this.maxDepth = maxDepth;
        this.logger = logger;
    }

    /**
     * Runs EM, starting from the flat prior for this context's ref allele, until successive state priors are
     * within {@code convergenceThreshold} (L2 distance) or {@code maxEMIterations} iterations have run.
     *
     * @return the learned state prior for this reference context, along with the example counts
     */
    public ArtifactPrior learnPriorForArtifactStates() {
        final double[] pseudocounts = getFlatPrior(refAllele);
        double[] statePrior = Arrays.copyOf(pseudocounts, F1R2FilterConstants.NUM_STATES);
        // Reset so that each call is bounded by, and reports, its own iteration count.
        numIterations.setValue(0);
        double l2Distance;
        do {
            final double[] oldStatePrior = Arrays.copyOf(statePrior, F1R2FilterConstants.NUM_STATES);
            takeEstep(statePrior);
            statePrior = takeMstep(pseudocounts);
            l2Distance = MathArrays.distance(oldStatePrior, statePrior);
            numIterations.increment();
        } while (l2Distance > convergenceThreshold && numIterations.intValue() < maxEMIterations);

        // Judge convergence by the final distance, not the iteration count: converging on the last allowed
        // iteration is still convergence, and a NaN distance is not.
        final boolean converged = l2Distance <= convergenceThreshold;
        if (converged) {
            logger.info(String.format("Context %s: with %s ref and %s alt examples, EM converged in %d steps", referenceContext, numRefExamples, numAltExamples, numIterations.intValue()));
        } else {
            logger.info(String.format("Context %s: with %s ref and %s alt examples, EM failed to converge within %d steps", referenceContext, numRefExamples, numAltExamples, maxEMIterations));
        }
        return new ArtifactPrior(referenceContext, statePrior, numExamples, numAltExamples);
    }

    /**
     * E-step: recomputes the state responsibilities of every ref depth, every design-matrix site, and every
     * (depth, alt allele, orientation) combination of a single-alt-read site under the given state prior.
     */
    private void takeEstep(final double[] artifactPriors) {
        for (int i = 0; i < maxDepth; i++) {
            final int depth = i + 1;
            refResponsibilities.setRow(i, computeResponsibilities(refAllele, refAllele, 0, 0, depth, artifactPriors, false));
        }

        for (int n = 0; n < altDesignMatrix.size(); n++) {
            final AltSiteRecord example = altDesignMatrix.get(n);
            altResponsibilities.setRow(n, computeResponsibilities(refAllele, example.getAltAllele(), example.getAltCount(),
                    example.getAltF1R2(), example.getDepth(), artifactPriors, false));
        }

        for (int i = 0; i < maxDepth; i++) {
            final int depth = i + 1;
            for (final Nucleotide altAllele : Nucleotide.STANDARD_BASES) {
                if (altAllele == refAllele) {
                    continue;
                }
                for (final ReadOrientation orientation : ReadOrientation.values()) {
                    // The single alt read is either F1R2 (count 1) or F2R1 (count 0).
                    final int f1r2AltCount = orientation == ReadOrientation.F1R2 ? 1 : 0;
                    responsibilitiesOfAltDepth1Sites.put(createKey(depth, altAllele, orientation),
                            computeResponsibilities(refAllele, altAllele, 1, f1r2AltCount, depth, artifactPriors, false));
                }
            }
        }
    }

    /**
     * M-step: sums the responsibilities over all evidence, weighting histogram depths by their site counts,
     * stores the result in {@link #effectiveCounts}, and returns the new state prior, which is the effective
     * counts plus {@code pseudocounts}, normalized to sum to one.
     */
    private double[] takeMstep(final double[] pseudocounts) {
        final double[] effectiveAltCountsFromDesignMatrix = sumAltDesignMatrixResponsibilities();

        double[] effectiveAltCountsFromHistograms = new double[F1R2FilterConstants.NUM_STATES];
        for (final Histogram<Integer> histogram : altDepthOneHistograms) {
            final Triple<String, Nucleotide, ReadOrientation> triplet = F1R2FilterUtils.labelToTriplet(histogram.getValueLabel());
            final Nucleotide altAllele = triplet.getMiddle();
            final ReadOrientation orientation = triplet.getRight();
            final double[] effectiveAltCountsFromHistogram = MathUtils.sumArrayFunction(0, maxDepth,
                    i -> MathArrays.scale(getBinCount(histogram, i + 1),
                            responsibilitiesOfAltDepth1Sites.get(createKey(i + 1, altAllele, orientation))));
            effectiveAltCountsFromHistograms = MathArrays.ebeAdd(effectiveAltCountsFromHistograms, effectiveAltCountsFromHistogram);
        }

        final double[] effectiveAltCounts = MathArrays.ebeAdd(effectiveAltCountsFromDesignMatrix, effectiveAltCountsFromHistograms);
        final double[] effectiveRefCounts = MathUtils.sumArrayFunction(0, maxDepth,
                i -> MathArrays.scale(getBinCount(refHistogram, i + 1), refResponsibilities.getRow(i)));
        effectiveCounts = new ArrayRealVector(MathArrays.ebeAdd(effectiveAltCounts, effectiveRefCounts));
        return MathUtils.normalizeSumToOne(effectiveCounts.add(new ArrayRealVector(pseudocounts)).toArray());
    }

    /** Sums the responsibility rows of all design-matrix sites; all zeros when there are none. */
    private double[] sumAltDesignMatrixResponsibilities() {
        final double[] sums = new double[F1R2FilterConstants.NUM_STATES];
        for (int n = 0; n < altDesignMatrix.size(); n++) {
            final double[] row = altResponsibilities.getRow(n);
            for (int s = 0; s < sums.length; s++) {
                sums[s] += row[s];
            }
        }
        return sums;
    }

    /** Returns the count stored in {@code histogram} for {@code depth}, or 0 when there is no such bin. */
    private static double getBinCount(final Histogram<Integer> histogram, final int depth) {
        return histogram.get(depth) == null ? 0.0 : histogram.get(depth).getValue();
    }

    /**
     * Computes the per-state responsibilities of a site, indexed by {@link ArtifactState#ordinal()}.
     *
     * <p>States in {@link ArtifactState#getRefToRefArtifacts(Nucleotide)} for {@code refAllele}, and artifact
     * states whose alt allele differs from {@code altAllele}, get zero weight. If {@code givenNotHomRef},
     * {@link ArtifactState#HOM_REF} gets zero weight as well. The log weights of the remaining states come from
     * {@link #computeLogPosterior} and are converted to linear space by
     * {@link NaturalLogUtils#normalizeFromLogToLinearSpace}.</p>
     *
     * @param altDepth     number of alt reads at the site
     * @param f1r2AltCount number of those alt reads in F1R2 orientation
     * @param depth        total depth at the site
     * @param artifactPrior prior probability of each state, indexed by ordinal
     */
    public static double[] computeResponsibilities(final Nucleotide refAllele, final Nucleotide altAllele, final int altDepth,
                                                   final int f1r2AltCount, final int depth, final double[] artifactPrior,
                                                   final boolean givenNotHomRef) {
        final double[] logUnnormalizedResponsibilities = new double[F1R2FilterConstants.NUM_STATES];
        final List<ArtifactState> refToRefArtifacts = ArtifactState.getRefToRefArtifacts(refAllele);
        for (final ArtifactState state : ArtifactState.values()) {
            final int stateIndex = state.ordinal();
            final boolean impossible = refToRefArtifacts.contains(state)
                    || (ArtifactState.artifactStates.contains(state) && state.getAltAlleleOfArtifact() != altAllele);
            logUnnormalizedResponsibilities[stateIndex] = impossible
                    ? Double.NEGATIVE_INFINITY
                    : computeLogPosterior(altDepth, f1r2AltCount, depth, artifactPrior[stateIndex],
                            alleleFractionPseudoCounts.get(state), altF1R2FractionPseudoCounts.get(state));
        }
        if (givenNotHomRef) {
            logUnnormalizedResponsibilities[ArtifactState.HOM_REF.ordinal()] = Double.NEGATIVE_INFINITY;
        }
        return NaturalLogUtils.normalizeFromLogToLinearSpace(logUnnormalizedResponsibilities);
    }

    /**
     * Returns the unnormalized log posterior of one state:
     * {@code log(statePrior) + log BetaBinomial(altDepth | depth, afPseudoCounts)
     *  + log BetaBinomial(altF1R2Depth | altDepth, f1r2PseudoCounts)}.
     * Rejects, via {@code Utils.validateArg}, a {@code statePrior} that is not a valid probability.
     */
    private static double computeLogPosterior(final int altDepth, final int altF1R2Depth, final int depth, final double statePrior,
                                              final BetaDistributionShape afPseudoCounts, final BetaDistributionShape f1r2PseudoCounts) {
        Utils.validateArg(MathUtils.isValidProbability(statePrior), String.format("statePrior must be a probability but got %f", statePrior));
        final double logAlleleFractionLikelihood = new BetaBinomialDistribution(null, afPseudoCounts.getAlpha(), afPseudoCounts.getBeta(), depth).logProbability(altDepth);
        final double logF1R2Likelihood = new BetaBinomialDistribution(null, f1r2PseudoCounts.getAlpha(), f1r2PseudoCounts.getBeta(), altDepth).logProbability(altF1R2Depth);
        return Math.log(statePrior) + logAlleleFractionLikelihood + logF1R2Likelihood;
    }

    /** Builds the per-state Beta shapes for the alt allele fraction. */
    private static Map<ArtifactState, BetaDistributionShape> getPseudoCountsForAlleleFraction() {
        final Map<ArtifactState, BetaDistributionShape> pseudoCounts = new HashMap<>(ArtifactState.values().length);
        ArtifactState.getF1R2ArtifactStates().forEach(s -> pseudoCounts.put(s, new BetaDistributionShape(ALT_PSEUDOCOUNT, REF_PSEUDOCOUNT)));
        ArtifactState.getF2R1ArtifactStates().forEach(s -> pseudoCounts.put(s, new BetaDistributionShape(ALT_PSEUDOCOUNT, REF_PSEUDOCOUNT)));
        pseudoCounts.put(ArtifactState.HOM_REF, new BetaDistributionShape(PSEUDOCOUNT_OF_HOM_UNLIKELY, PSEUDOCOUNT_OF_HOM_LIKELY));
        pseudoCounts.put(ArtifactState.GERMLINE_HET, new BetaDistributionShape(BALANCED_HET_PSEUDOCOUNT, BALANCED_HET_PSEUDOCOUNT));
        pseudoCounts.put(ArtifactState.SOMATIC_HET, new BetaDistributionShape(PSEUDOCOUNT_OF_SOMATIC_ALT, PSEUDOCOUNT_OF_SOMATIC_REF));
        pseudoCounts.put(ArtifactState.HOM_VAR, new BetaDistributionShape(PSEUDOCOUNT_OF_HOM_LIKELY, PSEUDOCOUNT_OF_HOM_UNLIKELY));
        return pseudoCounts;
    }

    /**
     * Builds the per-state Beta shapes for the fraction of alt reads that are F1R2. F1R2 artifact states are
     * skewed toward F1R2, F2R1 artifact states toward F2R1, and non-artifact states are balanced.
     */
    private static Map<ArtifactState, BetaDistributionShape> getPseudoCountsForAltF1R2Fraction() {
        final Map<ArtifactState, BetaDistributionShape> pseudoCounts = new HashMap<>(ArtifactState.values().length);
        ArtifactState.getF1R2ArtifactStates().forEach(z -> pseudoCounts.put(z, new BetaDistributionShape(PSEUDOCOUNT_OF_LIKELY_OUTCOME, PSEUDOCOUNT_OF_RARE_OUTCOME)));
        ArtifactState.getF2R1ArtifactStates().forEach(z -> pseudoCounts.put(z, new BetaDistributionShape(PSEUDOCOUNT_OF_RARE_OUTCOME, PSEUDOCOUNT_OF_LIKELY_OUTCOME)));
        ArtifactState.getNonArtifactStates().forEach(z -> pseudoCounts.put(z, new BetaDistributionShape(BALANCED_F1R2_PRIOR, BALANCED_F1R2_PRIOR)));
        return pseudoCounts;
    }

    /** Returns the responsibilities of a ref site with depth {@code rowNum + 1} from the last E-step. */
    @VisibleForTesting
    public double[] getRefResonsibilities(int rowNum) {
        return this.refResponsibilities.getRow(rowNum);
    }

    /** Returns the responsibilities of the {@code rowNum}-th design-matrix site from the last E-step. */
    @VisibleForTesting
    public double[] getAltResonsibilities(int rowNum) {
        return this.altResponsibilities.getRow(rowNum);
    }

    /** Not implemented: always returns {@code null}. */
    @VisibleForTesting
    public double[] getAltDepth1Resonsibilities(int rowNum) {
        return null;
    }

    /** Returns the per-state effective counts from the last M-step (without pseudocounts). */
    @VisibleForTesting
    public RealVector getEffectiveCounts() {
        return this.effectiveCounts;
    }

    /** Returns the effective count of {@code state} from the last M-step (without pseudocounts). */
    @VisibleForTesting
    public double getEffectiveCounts(ArtifactState state) {
        return this.effectiveCounts.getEntry(state.ordinal());
    }

    /**
     * Returns a prior that is uniform over all states except those in
     * {@link ArtifactState#getRefToRefArtifacts(Nucleotide)} for {@code refAllele}, which get probability 0.
     */
    public static double[] getFlatPrior(Nucleotide refAllele) {
        final List<ArtifactState> refToRefStates = ArtifactState.getRefToRefArtifacts(refAllele);
        final double[] prior = new double[F1R2FilterConstants.NUM_STATES];
        Arrays.fill(prior, 1.0 / (double) (F1R2FilterConstants.NUM_STATES - refToRefStates.size()));
        for (final ArtifactState s : refToRefStates) {
            prior[s.ordinal()] = 0.0;
        }
        return prior;
    }

    /** Builds the key used by {@link #responsibilitiesOfAltDepth1Sites}. */
    private static Triple<Integer, Nucleotide, ReadOrientation> createKey(final int depth, final Nucleotide altAllele, final ReadOrientation orientation) {
        return ImmutableTriple.of(depth, altAllele, orientation);
    }
}

