/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.genotyper;

import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.GenotypeLikelihoods;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.broadinstitute.hellbender.tools.walkers.genotyper.DRAGENGenotypesModel;
import org.broadinstitute.hellbender.tools.walkers.genotyper.FRDBQDUtils;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAlleleCounts;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculator;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculators;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerGenotypingDebugger;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.genotyper.LikelihoodMatrix;
import org.broadinstitute.hellbender.utils.read.GATKRead;

/**
 * DRAGEN-mode genotype likelihood calculator.
 *
 * <p>In addition to the standard {@link GenotypeLikelihoodCalculator} computation, this class exposes two
 * per-genotype models that reuse the per-read likelihood tables the superclass fills in during
 * {@link #genotypeLikelihoods} / {@link #rawGenotypeLikelihoods}:
 * <ul>
 *   <li>{@link #calculateBQDLikelihoods}: a per-strand model scoring each (presumably homozygous) genotype
 *       against an alternative "error allele" using running mean base qualities and a homopolymer
 *       adjustment (BQD; presumably DRAGEN's base-quality-dropout model, to be confirmed).</li>
 *   <li>{@link #calculateFRDLikelihoods}: a model over mapping-quality-derived thresholds that lets a
 *       fraction of reads come from a "foreign" allele (FRD; presumably foreign-read detection, to be
 *       confirmed).</li>
 * </ul>
 *
 * <p>Both models must be given the same {@link LikelihoodMatrix} instance that was most recently passed to
 * {@link #genotypeLikelihoods} or {@link #rawGenotypeLikelihoods}; this is checked by reference identity.
 * The cached matrix is a plain mutable field with no synchronization.
 */
public final class GenotypeLikelihoodCalculatorDRAGEN
extends GenotypeLikelihoodCalculator {
    /** Fixed error-read prior (alpha) of the BQD model; the non-error prior is {@code 1 - alpha}. */
    static final double BQD_FIXED_ERROR_RATE = 0.5;
    /** Scale factor applied to the running mean base quality when forming the BQD phred score. */
    static final double PHRED_SCALED_ADJUSTMENT_FOR_BQ_SCORE = 2.5;

    /** Multiplier converting a phred-scaled value to log10 space. */
    private static final double PHRED_TO_LOG10 = -0.1;
    /** Tolerance used when comparing a read's stored critical value against an FRD threshold. */
    private static final double FRD_THRESHOLD_EPSILON = 1.0E-7;
    /** Upper bound on the estimated fraction of foreign reads (beta) in the FRD model. */
    private static final double MAX_FOREIGN_READ_FRACTION = 0.5;

    /** The matrix the per-read tables were last computed for; {@code null} while (re)computing. */
    private LikelihoodMatrix<?, ?> cachedLikelihoods = null;
    private final double cachedLog10ErrorRate;
    private final double cachedLog10NonErrorRate;

    /**
     * @param ploidy sample ploidy; must be at least 1
     * @param alleleCount number of alleles handled by this calculator
     * @param alleleFirstGenotypeOffsetByPloidy genotype offset table, forwarded to the superclass
     * @param genotypeTableByPloidy genotype allele-count table, forwarded to the superclass
     * @throws IllegalArgumentException if {@code ploidy < 1}
     */
    protected GenotypeLikelihoodCalculatorDRAGEN(final int ploidy, final int alleleCount, final int[][] alleleFirstGenotypeOffsetByPloidy, final GenotypeAlleleCounts[][] genotypeTableByPloidy) {
        super(ploidy, alleleCount, alleleFirstGenotypeOffsetByPloidy, genotypeTableByPloidy);
        Utils.validateArg(ploidy > 0, () -> "ploidy must be at least 1 but was " + ploidy);
        this.cachedLog10ErrorRate = Math.log10(BQD_FIXED_ERROR_RATE);
        this.cachedLog10NonErrorRate = Math.log10(1.0 - BQD_FIXED_ERROR_RATE);
    }

    /**
     * Computes BQD log10 likelihoods.
     *
     * <p>For every allele G, the genotype at index {@code calculators.genotypeCount(ploidy, G + 1) - 1} is
     * scored against every other allele E. This is presumably the homozygous-G genotype, to be confirmed
     * against the genotype ordering. Both G and E must have the reference allele's length. The forward-
     * and reverse-strand phred scores are summed, converted to log10, and the maximum over all E is kept.
     * Genotypes never scored stay at {@code -Infinity}.
     *
     * @param sampleLikelihoods must be the same instance last passed to {@link #genotypeLikelihoods} or
     *                          {@link #rawGenotypeLikelihoods}
     * @param strandForward forward-strand read containers, fed to the per-strand model in list order
     * @param strandReverse reverse-strand read containers, fed to the per-strand model in list order
     * @param paddedReference reference bases used for the homopolymer adjustment
     * @param offsetForRefIntoEvent offset of the event within {@code paddedReference}
     * @param calculators used to locate the genotype index for each allele
     * @return an array of length {@code genotypeCount} with log10 BQD scores
     * @throws IllegalStateException if {@code sampleLikelihoods} is not the cached matrix
     */
    public <A extends Allele> double[] calculateBQDLikelihoods(final LikelihoodMatrix<GATKRead, A> sampleLikelihoods, final List<DRAGENGenotypesModel.DragenReadContainer> strandForward, final List<DRAGENGenotypesModel.DragenReadContainer> strandReverse, final byte[] paddedReference, final int offsetForRefIntoEvent, final GenotypeLikelihoodCalculators calculators) {
        Utils.validate(sampleLikelihoods == this.cachedLikelihoods, "There was a mismatch between the sample stored by the genotyper and the one requested for BQD, this will result in invalid genotype calling");
        final double[] outputArray = new double[this.genotypeCount];
        Arrays.fill(outputArray, Double.NEGATIVE_INFINITY);
        final int alleleCount = sampleLikelihoods.numberOfAlleles();
        final int refAlleleLength = sampleLikelihoods.getAllele(0).length();
        final int readCount = sampleLikelihoods.evidenceCount();
        // Stride between consecutive alleles in readAlleleLikelihoodByAlleleCount. Each allele's entry used
        // here starts readCount past the allele's base (presumably the "allele count == 1" slot; confirm
        // against the superclass layout).
        final int alleleDataSize = readCount * (this.ploidy + 1);
        for (int gtAlleleIndex = 0; gtAlleleIndex < alleleCount; gtAlleleIndex++) {
            final A gtAllele = sampleLikelihoods.getAllele(gtAlleleIndex);
            final int indexForGT = calculators.genotypeCount(this.ploidy, gtAlleleIndex + 1) - 1;
            final double[] readLikelihoodsForGT = this.readLikelihoodsByGenotypeIndex[indexForGT];
            for (int errorAlleleIndex = 0; errorAlleleIndex < alleleCount; errorAlleleIndex++) {
                final A errorAllele = sampleLikelihoods.getAllele(errorAlleleIndex);
                // Only distinct, reference-length (SNP-like) allele pairs participate in BQD.
                if (gtAllele == errorAllele || gtAllele.length() != refAlleleLength || errorAllele.length() != refAlleleLength) {
                    continue;
                }
                final int offsetForReadLikelihoodGivenAlleleIndex = readCount + errorAlleleIndex * alleleDataSize;
                final byte baseOfErrorAllele = errorAllele.getBases()[0];
                final double forwardHomopolymerAdjustment = FRDBQDUtils.computeForwardHomopolymerAdjustment(paddedReference, offsetForRefIntoEvent, baseOfErrorAllele);
                final double reverseHomopolymerAdjustment = FRDBQDUtils.computeReverseHomopolymerAdjustment(paddedReference, offsetForRefIntoEvent, baseOfErrorAllele);
                final double minScoreFoundForwardsStrand = this.computeBQDModelForStrandData(strandForward, forwardHomopolymerAdjustment, readLikelihoodsForGT, offsetForReadLikelihoodGivenAlleleIndex, true, errorAlleleIndex);
                final double minScoreFoundReverseStrand = this.computeBQDModelForStrandData(strandReverse, reverseHomopolymerAdjustment, readLikelihoodsForGT, offsetForReadLikelihoodGivenAlleleIndex, false, errorAlleleIndex);
                final double modelScoreInLog10 = (minScoreFoundForwardsStrand + minScoreFoundReverseStrand) * PHRED_TO_LOG10;
                outputArray[indexForGT] = Math.max(outputArray[indexForGT], modelScoreInLog10);
            }
        }
        return outputArray;
    }

    /**
     * Scores one strand under the BQD model and returns the best (minimum) phred score.
     *
     * <p>For every split point {@code n} in {@code 0..reads.size()}, the first {@code n} reads are scored
     * as a mixture of the error allele and the genotype, and the remaining reads under the genotype alone.
     * The running mean base quality (scaled, minus the homopolymer adjustment, floored at 0) is added as a
     * penalty. The minimum over {@code n} is returned.
     *
     * @param positionSortedReads reads of one strand; presumably sorted by position (per the name)
     * @param homopolymerAdjustment phred penalty subtracted from the scaled mean base quality
     * @param readLikelihoodsForGT per-read log10 likelihoods for the genotype being scored
     * @param offsetForReadLikelihoodGivenAlleleIndex offset of the error allele's per-read block
     * @param forwards whether this is the forward strand (debug output only)
     * @param errorAlleleIndex index of the error allele (debug output only)
     * @return the minimum phred score, or {@code 0.0} when there are no reads
     */
    private double computeBQDModelForStrandData(final List<DRAGENGenotypesModel.DragenReadContainer> positionSortedReads, final double homopolymerAdjustment, final double[] readLikelihoodsForGT, final int offsetForReadLikelihoodGivenAlleleIndex, final boolean forwards, final int errorAlleleIndex) {
        if (positionSortedReads.isEmpty()) {
            return 0.0;
        }
        // Strand label used in debug output: 1 = forward, 2 = reverse.
        final int theta = forwards ? 1 : 2;
        if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
            HaplotypeCallerGenotypingDebugger.println("errorAllele index: " + errorAlleleIndex + " theta: " + theta + " homopolymerAdjustment: " + homopolymerAdjustment);
        }
        final int evidenceSize = positionSortedReads.size();
        // Prefix sums over the first i reads; index 0 is the empty prefix.
        final double[] cumulativeProbReadForErrorAllele = new double[evidenceSize + 1];
        final double[] cumulativeMeanBaseQualityPhredAdjusted = new double[evidenceSize + 1];
        final double[] cumulativeProbGenotype = new double[evidenceSize + 1];
        double totalBaseQuality = 0.0;
        int baseQualityDenominator = 0;
        for (int i = 1; i <= evidenceSize; i++) {
            final DRAGENGenotypesModel.DragenReadContainer container = positionSortedReads.get(i - 1);
            final int readIndex = container.getIndexInLikelihoodsObject();
            final double homozygousGenotypeContribution;
            final double errorAlleleContribution;
            if (readIndex != -1) {
                // Adding log10(1/2) halves the genotype read likelihood. This presumably undoes the
                // superclass's un-normalized sum over the two allele copies; TODO confirm for ploidy != 2.
                homozygousGenotypeContribution = readLikelihoodsForGT[readIndex] + MathUtils.LOG10_ONE_HALF;
                errorAlleleContribution = this.readAlleleLikelihoodByAlleleCount[offsetForReadLikelihoodGivenAlleleIndex + readIndex];
            } else {
                // Reads absent from the likelihood matrix contribute nothing.
                homozygousGenotypeContribution = 0.0;
                errorAlleleContribution = 0.0;
            }
            final double phredContributionForRead = homozygousGenotypeContribution == 0.0 && errorAlleleContribution == 0.0
                    ? 0.0
                    : -10.0 * MathUtils.approximateLog10SumLog10(errorAlleleContribution + this.cachedLog10ErrorRate, homozygousGenotypeContribution + this.cachedLog10NonErrorRate);
            cumulativeProbReadForErrorAllele[i] = cumulativeProbReadForErrorAllele[i - 1] + phredContributionForRead;
            cumulativeProbGenotype[i] = cumulativeProbGenotype[i - 1] - 10.0 * homozygousGenotypeContribution;
            if (container.hasValidBaseQuality()) {
                totalBaseQuality += (double) container.getBaseQuality();
                baseQualityDenominator++;
            }
            final double meanBaseQuality = totalBaseQuality / (double) (baseQualityDenominator == 0 ? 1 : baseQualityDenominator);
            cumulativeMeanBaseQualityPhredAdjusted[i] = Math.max(0.0, meanBaseQuality * PHRED_SCALED_ADJUSTMENT_FOR_BQ_SCORE - homopolymerAdjustment);
        }
        final double totalProbGenotype = cumulativeProbGenotype[evidenceSize];
        double minScoreFound = Double.POSITIVE_INFINITY;
        int nIndexUsed = 0;
        for (int n = 0; n <= evidenceSize; n++) {
            final double remainingProbGenotype = totalProbGenotype - cumulativeProbGenotype[n];
            final double bqdScore = cumulativeMeanBaseQualityPhredAdjusted[n] + cumulativeProbReadForErrorAllele[n] + remainingProbGenotype;
            if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                HaplotypeCallerGenotypingDebugger.println(String.format("n=%d: %.2f, cum_phred_bq=%.2f, cum_prob_r_Error=%.2f, prob_G_remaining=%.2f", n, bqdScore, cumulativeMeanBaseQualityPhredAdjusted[n], cumulativeProbReadForErrorAllele[n], remainingProbGenotype));
            }
            if (bqdScore < minScoreFound) {
                minScoreFound = bqdScore;
                nIndexUsed = n;
            }
        }
        if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
            HaplotypeCallerGenotypingDebugger.println(String.format("theta=%d n%d=%2d, best_phred_score =%5.2f q_mean=%5.2f, alpha=%4.2f, Ph(E)=%4.2f;  ", theta, theta, nIndexUsed, minScoreFound, cumulativeMeanBaseQualityPhredAdjusted[nIndexUsed], BQD_FIXED_ERROR_RATE, cumulativeProbReadForErrorAllele[nIndexUsed]));
        }
        return minScoreFound;
    }

    /**
     * Computes FRD log10 likelihoods.
     *
     * <p>Each allele F in turn is treated as a candidate foreign allele, and every other allele G is scored
     * at genotype index {@code calculators.genotypeCount(ploidy, G + 1) - 1}. Three sub-models are evaluated:
     * forward-strand reads, reverse-strand reads, and all reads. The one with the highest score is kept. When
     * {@code maxEffectiveDepthForHetAdjustment > 0}, that score is shrunk toward the ploidy-model likelihood
     * of the mixed genotype {@code allelesToIndex(G, F)} in proportion to {@code min(depth, cap) / depth}.
     * The maximum over all F is stored per genotype. Genotypes never scored stay at {@code -Infinity}.
     *
     * <p>Side effect: for every F, each container in {@code readContainers} has its PF value overwritten
     * (see {@link #computeCriticalValues}).
     *
     * @param sampleLikelihoods must be the same instance last passed to {@link #genotypeLikelihoods} or
     *                          {@link #rawGenotypeLikelihoods}
     * @param ploidyModelLikelihoods log10 likelihoods indexed via {@code allelesToIndex}; only read when
     *                               {@code maxEffectiveDepthForHetAdjustment > 0}
     * @param readContainers all read containers of the sample
     * @param snipAprioriHet phred-scaled prior penalty for a non-reference, reference-length foreign allele
     * @param indelAprioriHet phred-scaled prior penalty for a foreign allele of different length than the reference
     * @param maxEffectiveDepthForHetAdjustment depth cap for the shrinkage; {@code <= 0} disables it
     * @param calculators used to locate the genotype index for each allele
     * @return an array of length {@code genotypeCount} with log10 FRD scores
     * @throws IllegalStateException if {@code sampleLikelihoods} is not the cached matrix
     */
    public <A extends Allele> double[] calculateFRDLikelihoods(final LikelihoodMatrix<GATKRead, A> sampleLikelihoods, final double[] ploidyModelLikelihoods, final List<DRAGENGenotypesModel.DragenReadContainer> readContainers, final double snipAprioriHet, final double indelAprioriHet, final int maxEffectiveDepthForHetAdjustment, final GenotypeLikelihoodCalculators calculators) {
        Utils.validate(sampleLikelihoods == this.cachedLikelihoods, "There was a mismatch between the sample stored by the genotyper and the one requested for FRD, this will result in invalid genotyping");
        final double[] outputArray = new double[this.genotypeCount];
        Arrays.fill(outputArray, Double.NEGATIVE_INFINITY);
        final int alleleCount = sampleLikelihoods.numberOfAlleles();
        final int refAlleleLength = sampleLikelihoods.getAllele(0).length();
        final int readCount = sampleLikelihoods.evidenceCount();
        final int alleleDataSize = readCount * (this.ploidy + 1);
        for (int fAlleleIndex = 0; fAlleleIndex < alleleCount; fAlleleIndex++) {
            final int offsetForReadLikelihoodGivenAlleleIndex = readCount + fAlleleIndex * alleleDataSize;
            final boolean isIndel = sampleLikelihoods.getAllele(fAlleleIndex).length() != refAlleleLength;
            // The reference allele (index 0) gets no prior penalty; others get the SNP or indel prior in log10.
            final double log10MapqPriorAdjustment = fAlleleIndex == 0 ? 0.0 : (isIndel ? indelAprioriHet : snipAprioriHet) * PHRED_TO_LOG10;
            // Must run before the strand models: it stores the PF value each read is compared against.
            final List<Double> criticalThresholds = this.computeCriticalValues(readContainers, log10MapqPriorAdjustment).getCriticalThresholdsTotal();
            if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                HaplotypeCallerGenotypingDebugger.println("fIndex: " + fAlleleIndex + " criticalValues: \n" + criticalThresholds.stream().map(Object::toString).collect(Collectors.joining("\n")));
            }
            for (int gtAlleleIndex = 0; gtAlleleIndex < alleleCount; gtAlleleIndex++) {
                if (gtAlleleIndex == fAlleleIndex) {
                    continue;
                }
                final int indexForGT = calculators.genotypeCount(this.ploidy, gtAlleleIndex + 1) - 1;
                final double[] readLikelihoodsForGT = this.readLikelihoodsByGenotypeIndex[indexForGT];
                if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                    HaplotypeCallerGenotypingDebugger.println("indexForGT " + indexForGT + " ooffsetForReadLikelihoodGivenAlleleIndex =" + offsetForReadLikelihoodGivenAlleleIndex);
                    HaplotypeCallerGenotypingDebugger.println("\nForwards Strands: ");
                }
                final double[] maxLog10FForwardsStrand = this.computeFRDModelForStrandData(readContainers, c -> !c.isReverseStrand(), offsetForReadLikelihoodGivenAlleleIndex, readLikelihoodsForGT, criticalThresholds);
                if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                    HaplotypeCallerGenotypingDebugger.println("\nReverse Strands: ");
                }
                final double[] maxLog10FReverseStrand = this.computeFRDModelForStrandData(readContainers, c -> c.isReverseStrand(), offsetForReadLikelihoodGivenAlleleIndex, readLikelihoodsForGT, criticalThresholds);
                if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                    HaplotypeCallerGenotypingDebugger.println("\nBoth Strands: ");
                }
                final double[] maxLog10FBothStrands = this.computeFRDModelForStrandData(readContainers, c -> true, offsetForReadLikelihoodGivenAlleleIndex, readLikelihoodsForGT, criticalThresholds);
                if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                    // Arrays.toString: plain concatenation would print the arrays' identity hashes.
                    HaplotypeCallerGenotypingDebugger.println("gtAlleleIndex : " + gtAlleleIndex + " fAlleleIndex: " + fAlleleIndex + " forwards: " + Arrays.toString(maxLog10FForwardsStrand) + " reverse: " + Arrays.toString(maxLog10FReverseStrand) + " both: " + Arrays.toString(maxLog10FBothStrands));
                }
                // Each model is {score, P(F) threshold used}; keep the one with the highest score
                // (ties resolve to the earlier model).
                double[] localBestModel = maxLog10FForwardsStrand;
                if (localBestModel[0] < maxLog10FReverseStrand[0]) {
                    localBestModel = maxLog10FReverseStrand;
                }
                if (localBestModel[0] < maxLog10FBothStrands[0]) {
                    localBestModel = maxLog10FBothStrands;
                }
                if (maxEffectiveDepthForHetAdjustment > 0) {
                    final double localBestModelScore = localBestModel[0] - localBestModel[1];
                    final int closestGTAlleleIndex = this.allelesToIndex(gtAlleleIndex, fAlleleIndex);
                    final double log10LikelihoodsForPloidyModel = ploidyModelLikelihoods[closestGTAlleleIndex] + MathUtils.LOG10_ONE_HALF;
                    final int depthForGenotyping = readCount;
                    // min(depth, cap) / depth. With zero reads the factor's limit is 1 (depth <= cap);
                    // computing 0/0 would yield NaN and poison outputArray through Math.max.
                    final double depthScale = depthForGenotyping == 0
                            ? 1.0
                            : (double) Math.min(depthForGenotyping, maxEffectiveDepthForHetAdjustment) / (double) depthForGenotyping;
                    final double adjustedBestModel = log10LikelihoodsForPloidyModel + (localBestModelScore - log10LikelihoodsForPloidyModel) * depthScale;
                    outputArray[indexForGT] = Math.max(outputArray[indexForGT], adjustedBestModel + localBestModel[1]);
                    if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                        HaplotypeCallerGenotypingDebugger.println("best FRD likelihoods: " + localBestModelScore + " P(F) score used: " + localBestModel[1] + "  use MaxEffectiveDepth: " + maxEffectiveDepthForHetAdjustment);
                        HaplotypeCallerGenotypingDebugger.println("Using array index " + closestGTAlleleIndex + " for mixture gt with likelihood of " + log10LikelihoodsForPloidyModel + " adjusted based on depth: " + depthForGenotyping);
                        HaplotypeCallerGenotypingDebugger.println("p_rG_adj : " + adjustedBestModel);
                    }
                } else {
                    outputArray[indexForGT] = Math.max(outputArray[indexForGT], localBestModel[0]);
                }
            }
        }
        return outputArray;
    }

    /**
     * Evaluates the FRD model for the reads selected by {@code predicate} over every critical threshold.
     *
     * <p>For each threshold lpf:
     * <ul>
     *   <li>Beta is the mean posterior share of the foreign allele over the selected reads, capped at
     *       {@link #MAX_FOREIGN_READ_FRACTION}.</li>
     *   <li>Each selected read is then scored as a beta-weighted mixture of the foreign allele and the
     *       genotype.</li>
     *   <li>Unselected reads are scored under the genotype alone.</li>
     *   <li>The candidate score is {@code lpf + sum}.</li>
     * </ul>
     * Reads flagged {@code wasFilteredByHMM()} are ignored.
     *
     * @return {@code {best score, threshold that produced it}}; {@code {-Infinity, 0.0}} when there are no reads
     */
    private double[] computeFRDModelForStrandData(final List<DRAGENGenotypesModel.DragenReadContainer> positionSortedReads, final Predicate<DRAGENGenotypesModel.DragenReadContainer> predicate, final int offsetForReadLikelihoodGivenAlleleIndex, final double[] readLikelihoodsForGT, final List<Double> criticalThresholdsSorted) {
        if (positionSortedReads.isEmpty()) {
            return new double[]{Double.NEGATIVE_INFINITY, 0.0};
        }
        int counter = 0;
        double maxLpspi = Double.NEGATIVE_INFINITY;
        double lpfApplied = 0.0;
        for (final double logProbFAllele : criticalThresholdsSorted) {
            // Pass 1: estimate beta, the fraction of selected reads attributable to the foreign allele.
            double fAlleleProbRatio = 0.0;
            double fAlleleProbDenom = 0.0;
            for (final DRAGENGenotypesModel.DragenReadContainer container : positionSortedReads) {
                if (container.wasFilteredByHMM() || !predicate.test(container)) {
                    continue;
                }
                final int readIndex = container.getIndexInLikelihoodsObject();
                final double lpReadGivenF = this.log10ForeignAlleleLikelihoodAtCutoff(container, readIndex, logProbFAllele, offsetForReadLikelihoodGivenAlleleIndex);
                final double lpReadGivenGT = readLikelihoodsForGT[readIndex] + MathUtils.LOG10_ONE_HALF;
                fAlleleProbRatio += Math.pow(10.0, lpReadGivenF - MathUtils.approximateLog10SumLog10(lpReadGivenF, lpReadGivenGT));
                fAlleleProbDenom += 1.0;
            }
            final double foreignAlleleLikelihood = Math.min(fAlleleProbRatio / fAlleleProbDenom, MAX_FOREIGN_READ_FRACTION);
            final double log10ForeignAlleleLikelihood = Math.log10(foreignAlleleLikelihood);
            final double log10NotForeignAlleleLikelihood = Math.log10(1.0 - foreignAlleleLikelihood);

            // Pass 2: log10 likelihood of all reads under the foreign-read hypothesis with that beta.
            double cumulativeLog10LikelihoodOfForeignReadHypothesis = 0.0;
            for (final DRAGENGenotypesModel.DragenReadContainer container : positionSortedReads) {
                if (container.wasFilteredByHMM()) {
                    continue;
                }
                final int readIndex = container.getIndexInLikelihoodsObject();
                final double log10LikelihoodReadForGenotype = readLikelihoodsForGT[readIndex] + MathUtils.LOG10_ONE_HALF;
                if (predicate.test(container)) {
                    final double log10LikelihoodOfForeignAlleleGivenLPFCutoff = this.log10ForeignAlleleLikelihoodAtCutoff(container, readIndex, logProbFAllele, offsetForReadLikelihoodGivenAlleleIndex);
                    cumulativeLog10LikelihoodOfForeignReadHypothesis += MathUtils.approximateLog10SumLog10(log10ForeignAlleleLikelihood + log10LikelihoodOfForeignAlleleGivenLPFCutoff, log10NotForeignAlleleLikelihood + log10LikelihoodReadForGenotype);
                } else {
                    cumulativeLog10LikelihoodOfForeignReadHypothesis += log10LikelihoodReadForGenotype;
                }
            }
            final double lpsi = logProbFAllele + cumulativeLog10LikelihoodOfForeignReadHypothesis;
            if (HaplotypeCallerGenotypingDebugger.isEnabled()) {
                HaplotypeCallerGenotypingDebugger.println("beta: " + foreignAlleleLikelihood + " localMaxLpspi: " + lpsi + " for lpf: " + logProbFAllele + " with LP_R_GF: " + cumulativeLog10LikelihoodOfForeignReadHypothesis + " index: " + counter++);
            }
            if (lpsi > maxLpspi) {
                maxLpspi = lpsi;
                lpfApplied = logProbFAllele;
            }
        }
        return new double[]{maxLpspi, lpfApplied};
    }

    /**
     * Returns a read's log10 likelihood under the foreign allele at threshold {@code logProbFAllele}. Reads
     * whose stored PF value does not exceed the threshold (within {@link #FRD_THRESHOLD_EPSILON}) are
     * excluded from the foreign allele and return {@code -Infinity}.
     */
    private double log10ForeignAlleleLikelihoodAtCutoff(final DRAGENGenotypesModel.DragenReadContainer container, final int readIndex, final double logProbFAllele, final int offsetForReadLikelihoodGivenAlleleIndex) {
        return container.getPhredPFValue() + FRD_THRESHOLD_EPSILON <= logProbFAllele
                ? Double.NEGATIVE_INFINITY
                : this.readAlleleLikelihoodByAlleleCount[offsetForReadLikelihoodGivenAlleleIndex + readIndex];
    }

    /**
     * Derives the FRD critical thresholds from read mapping qualities.
     *
     * <p>For each read, the threshold is {@code -0.1 * phredMQ + log10MapqPriorAdjustment}. That value is
     * written back into the read via {@code setPhredPFValue}, and the distinct values are collected. Only
     * the "total" threshold set is populated here; the forward and reverse sets are always empty.
     *
     * @param readContainers reads whose PF values are overwritten
     * @param log10MapqPriorAdjustment log10 prior shift added to every threshold
     */
    private FRDCriticalThresholds computeCriticalValues(final List<DRAGENGenotypesModel.DragenReadContainer> readContainers, final double log10MapqPriorAdjustment) {
        final Set<Double> criticalThresholdsForwards = new HashSet<>();
        final Set<Double> criticalThresholdsReverse = new HashSet<>();
        final Set<Double> criticalThresholdsTotal = new HashSet<>();
        for (final DRAGENGenotypesModel.DragenReadContainer readContainer : readContainers) {
            final double log10CriticalValue = readContainer.getPhredScaledMappingQuality() * PHRED_TO_LOG10 + log10MapqPriorAdjustment;
            readContainer.setPhredPFValue(log10CriticalValue);
            criticalThresholdsTotal.add(log10CriticalValue);
        }
        return new FRDCriticalThresholds(criticalThresholdsForwards, criticalThresholdsReverse, criticalThresholdsTotal);
    }

    /**
     * Computes genotype likelihoods via the superclass and caches {@code likelihoods} so that the BQD/FRD
     * models can verify they are run against the same matrix.
     */
    @Override
    public <EVIDENCE, A extends Allele> GenotypeLikelihoods genotypeLikelihoods(final LikelihoodMatrix<EVIDENCE, A> likelihoods) {
        // Clear first so an exception in the superclass cannot leave a stale matrix cached.
        this.cachedLikelihoods = null;
        final GenotypeLikelihoods output = super.genotypeLikelihoods(likelihoods);
        this.cachedLikelihoods = likelihoods;
        return output;
    }

    /**
     * Returns the superclass's raw per-genotype values for {@code likelihoods} and caches the matrix, as
     * {@link #genotypeLikelihoods} does.
     */
    public <EVIDENCE, A extends Allele> double[] rawGenotypeLikelihoods(final LikelihoodMatrix<EVIDENCE, A> likelihoods) {
        this.cachedLikelihoods = null;
        final double[] output = super.getReadRawReadLikelihoodsByGenotypeIndex(likelihoods);
        this.cachedLikelihoods = likelihoods;
        return output;
    }

    /** Sorted (ascending), de-duplicated FRD threshold lists. */
    private static final class FRDCriticalThresholds {
        private final List<Double> criticalThresholdsForwards;
        private final List<Double> criticalThresholdsReverse;
        private final List<Double> criticalThresholdsTotal;

        private FRDCriticalThresholds(final Set<Double> criticalThresholdsForwards, final Set<Double> criticalThresholdsReverse, final Set<Double> criticalThresholdsTotal) {
            this.criticalThresholdsForwards = criticalThresholdsForwards.stream().sorted(Double::compareTo).collect(Collectors.toList());
            this.criticalThresholdsReverse = criticalThresholdsReverse.stream().sorted(Double::compareTo).collect(Collectors.toList());
            this.criticalThresholdsTotal = criticalThresholdsTotal.stream().sorted(Double::compareTo).collect(Collectors.toList());
        }

        public List<Double> getCriticalThresholdsTotal() {
            return this.criticalThresholdsTotal;
        }

        public List<Double> getCriticalThresholdsForwards() {
            return this.criticalThresholdsForwards;
        }

        public List<Double> getCriticalThresholdsReverse() {
            return this.criticalThresholdsReverse;
        }
    }
}

