/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect.filtering;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.variant.variantcontext.VariantContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.util.CombinatoricsUtils;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.StrandBiasUtils;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.ErrorProbabilities;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.ErrorType;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2AlleleFilter;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.tools.walkers.validation.basicshortmutpileup.BetaBinomialDistribution;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/**
 * Mutect2 allele filter for strand-bias artifacts: alt alleles whose supporting reads come
 * predominantly from one read strand.
 *
 * <p>For each alt allele three hypotheses are compared:
 * <ul>
 *   <li>forward artifact: forward-strand alt counts follow the learned artifact beta-binomial,
 *       reverse-strand alt counts follow the non-artifact (sequencing-error-like) beta-binomial;</li>
 *   <li>reverse artifact: the mirror image;</li>
 *   <li>no artifact: both strands share a single allele fraction with a flat Beta(1,1) prior.</li>
 * </ul>
 * The error probability reported for an allele is the sum of the two artifact posteriors. The
 * artifact prior and the artifact-strand beta parameters are re-estimated in {@link #learnParameters()}.
 */
public class StrandArtifactFilter
extends Mutect2AlleleFilter {
    // Initial beta prior on the alt allele fraction of the artifact strand.
    private static final double INITIAL_ALPHA_STRAND = 1.0;
    private static final double INITIAL_BETA_STRAND = 20.0;

    // Current beta prior on the artifact-strand alt allele fraction; updated by learnParameters().
    private double alphaStrand = INITIAL_ALPHA_STRAND;
    private double betaStrand = INITIAL_BETA_STRAND;

    // Beta prior on the alt fraction of the non-artifact strand (presumably modelling sequencing error).
    // A larger beta means a smaller expected alt fraction; longer indels get larger betas.
    private static final double ALPHA_SEQ = 1.0;
    private static final double BETA_SEQ_SNV = 1000.0;
    private static final double BETA_SEQ_SHORT_INDEL = 5000.0;
    private static final double BETA_SEQ_LONG_INDEL = 50000.0;
    // Indels of at least this length use BETA_SEQ_LONG_INDEL.
    private static final int LONG_INDEL_SIZE = 3;

    // Alt alleles whose length differs from the reference by more than this are never called strand artifacts.
    private static final int LONGEST_STRAND_ARTIFACT_INDEL_SIZE = 4;

    private static final double INITIAL_STRAND_ARTIFACT_PRIOR = 0.001;
    private double strandArtifactPrior = INITIAL_STRAND_ARTIFACT_PRIOR;

    // Pseudocounts regularizing the learned strand artifact prior.
    private static final double ARTIFACT_PSEUDOCOUNT = 1.0;
    private static final double NON_ARTIFACT_PSEUDOCOUNT = 1000.0;

    // Alleles above this artifact probability contribute to learning the artifact-strand beta prior.
    private static final double POTENTIAL_ARTIFACT_THRESHOLD = 0.1;

    // Per-allele E-step results accumulated between learning passes.
    private final List<EStep> eSteps = new ArrayList<>();

    @Override
    public ErrorType errorType() {
        return ErrorType.ARTIFACT;
    }

    /**
     * Returns, for each non-symbolic alt allele, the posterior probability that it is a strand artifact
     * (forward plus reverse artifact responsibility). Empty when no strand-bias data is available.
     */
    @Override
    public List<Double> calculateErrorProbabilityForAlleles(final VariantContext vc, final Mutect2FilteringEngine filteringEngine, final ReferenceContext referenceContext) {
        return calculateArtifactProbabilities(vc, filteringEngine).stream()
                .map(EStep::getArtifactProbability)
                .collect(Collectors.toList());
    }

    /**
     * Computes the E-step (artifact responsibilities and strand counts) for every non-symbolic alt allele.
     *
     * @param vc the variant; its per-allele strand-bias table is read via {@link StrandBiasUtils#getSBsForAlleles}
     * @param filteringEngine unused here; kept for interface compatibility
     * @return one {@link EStep} per non-symbolic alt allele, or an empty list if there is no usable SB data
     */
    public List<EStep> calculateArtifactProbabilities(final VariantContext vc, final Mutect2FilteringEngine filteringEngine) {
        // One entry per allele (reference first); element 0 is the forward count, element 1 the reverse count.
        List<List<Integer>> sbs = StrandBiasUtils.getSBsForAlleles(vc);
        if (sbs == null || sbs.size() <= 1) {
            return Collections.emptyList();
        }
        if (vc.hasSymbolicAlleles()) {
            sbs = GATKVariantContextUtils.removeDataForSymbolicAlleles(vc, sbs);
        }

        // Must stay index-aligned with altSBs: skip the symbolic alts whose SB data was removed above,
        // otherwise a symbolic alt that is not last would shift every later allele's indel size.
        final int refLength = vc.getReference().length();
        final List<Integer> indelSizes = vc.getAlternateAlleles().stream()
                .filter(alt -> !alt.isSymbolic())
                .map(alt -> Math.abs(refLength - alt.length()))
                .collect(Collectors.toList());

        final int totalFwd = sbs.stream().mapToInt(sb -> sb.get(0)).sum();
        final int totalRev = sbs.stream().mapToInt(sb -> sb.get(1)).sum();
        final List<List<Integer>> altSBs = sbs.subList(1, sbs.size());

        return IntStream.range(0, altSBs.size()).mapToObj(i -> {
            final List<Integer> altSB = altSBs.get(i);
            final int altIndelSize = indelSizes.get(i);
            // Alleles with no supporting reads, or indels too long to be strand artifacts, get zero responsibility.
            if (altSB.stream().mapToInt(Integer::intValue).sum() == 0 || altIndelSize > LONGEST_STRAND_ARTIFACT_INDEL_SIZE) {
                return new EStep(0.0, 0.0, totalFwd, totalRev, altSB.get(0), altSB.get(1));
            }
            return strandArtifactProbability(strandArtifactPrior, totalFwd, totalRev, altSB.get(0), altSB.get(1), altIndelSize);
        }).collect(Collectors.toList());
    }

    /** Records E-step results for {@code vc} when all required INFO annotations are present. */
    @Override
    protected void accumulateDataForLearning(final VariantContext vc, final ErrorProbabilities errorProbabilities, final Mutect2FilteringEngine filteringEngine) {
        if (requiredInfoAnnotations().stream().allMatch(vc::hasAttribute)) {
            eSteps.addAll(calculateArtifactProbabilities(vc, filteringEngine));
        }
    }

    @Override
    protected void clearAccumulatedData() {
        eSteps.clear();
    }

    /**
     * M-step: re-estimates the strand artifact prior from all accumulated E-steps, and the artifact-strand
     * beta prior from the alleles that are plausibly artifacts. Clears the accumulated data afterwards.
     */
    @Override
    protected void learnParameters() {
        final List<EStep> potentialArtifacts = eSteps.stream()
                .filter(eStep -> eStep.getArtifactProbability() > POTENTIAL_ARTIFACT_THRESHOLD)
                .collect(Collectors.toList());
        final double totalArtifacts = potentialArtifacts.stream().mapToDouble(EStep::getArtifactProbability).sum();
        final double totalNonArtifacts = eSteps.stream().mapToDouble(e -> 1.0 - e.getArtifactProbability()).sum();
        strandArtifactPrior = (totalArtifacts + ARTIFACT_PSEUDOCOUNT) / (totalArtifacts + ARTIFACT_PSEUDOCOUNT + totalNonArtifacts + NON_ARTIFACT_PSEUDOCOUNT);

        // With no potential artifacts the objective below is identically zero and the optimizer's answer is
        // arbitrary, so keep the current artifact-strand parameters instead of drifting them.
        if (!potentialArtifacts.isEmpty()) {
            final double artifactAltCount = potentialArtifacts.stream()
                    .mapToDouble(e -> e.getForwardArtifactResponsibility() * e.getForwardAltCount()
                            + e.getReverseArtifactResponsibility() * e.getReverseAltCount())
                    .sum();
            final double artifactDepth = potentialArtifacts.stream()
                    .mapToDouble(e -> e.getForwardArtifactResponsibility() * e.getForwardCount()
                            + e.getReverseArtifactResponsibility() * e.getReverseCount())
                    .sum();
            final double artifactBetaMean = (artifactAltCount + INITIAL_ALPHA_STRAND) / (artifactDepth + INITIAL_ALPHA_STRAND + INITIAL_BETA_STRAND);

            // The beta mean is fixed from the data, so beta is determined by alpha and only alpha is optimized.
            final DoubleUnaryOperator objective = alpha -> {
                final double beta = (1.0 / artifactBetaMean - 1.0) * alpha;
                return potentialArtifacts.stream()
                        .mapToDouble(e -> e.getForwardArtifactResponsibility() * artifactStrandLogLikelihood(e.getForwardCount(), e.getForwardAltCount(), alpha, beta)
                                + e.getReverseArtifactResponsibility() * artifactStrandLogLikelihood(e.getReverseCount(), e.getReverseAltCount(), alpha, beta))
                        .sum();
            };
            // Alpha is searched in [0.01, 100] starting from the initial alpha; the remaining arguments are
            // presumably the optimizer's tolerances and evaluation budget (see OptimizationUtils.max).
            alphaStrand = OptimizationUtils.max(objective, 0.01, 100.0, INITIAL_ALPHA_STRAND, 0.01, 0.01, 100).getPoint();
            betaStrand = (1.0 / artifactBetaMean - 1.0) * alphaStrand;
        }
        eSteps.clear();
    }

    /**
     * Computes the posterior responsibilities of the forward-artifact and reverse-artifact hypotheses.
     *
     * @param strandArtifactPrior total prior probability of a strand artifact, split evenly between strands
     * @param forwardCount total forward-strand read count at the site
     * @param reverseCount total reverse-strand read count at the site
     * @param forwardAltCount forward-strand alt read count (must not exceed {@code forwardCount})
     * @param reverseAltCount reverse-strand alt read count (must not exceed {@code reverseCount})
     * @param indelSize absolute length difference between the alt and reference alleles
     * @return the E-step holding both artifact responsibilities and the counts used
     */
    @VisibleForTesting
    EStep strandArtifactProbability(final double strandArtifactPrior, final int forwardCount, final int reverseCount, final int forwardAltCount, final int reverseAltCount, final int indelSize) {
        final double forwardLogLikelihood = artifactStrandLogLikelihood(forwardCount, forwardAltCount) + nonArtifactStrandLogLikelihood(reverseCount, reverseAltCount, indelSize);
        final double reverseLogLikelihood = artifactStrandLogLikelihood(reverseCount, reverseAltCount) + nonArtifactStrandLogLikelihood(forwardCount, forwardAltCount, indelSize);

        // No artifact: the binomial-coefficient ratio times the Beta(1,1)-binomial of the pooled counts gives the
        // joint likelihood of both strands' alt counts under one allele fraction shared by the two strands.
        final int totalCount = forwardCount + reverseCount;
        final int totalAltCount = forwardAltCount + reverseAltCount;
        final double noneLogLikelihood = CombinatoricsUtils.binomialCoefficientLog(forwardCount, forwardAltCount)
                + CombinatoricsUtils.binomialCoefficientLog(reverseCount, reverseAltCount)
                - CombinatoricsUtils.binomialCoefficientLog(totalCount, totalAltCount)
                + new BetaBinomialDistribution(null, 1.0, 1.0, totalCount).logProbability(totalAltCount);

        final double forwardLogPrior = Math.log(strandArtifactPrior / 2.0);
        final double reverseLogPrior = Math.log(strandArtifactPrior / 2.0);
        final double noneLogPrior = Math.log(1.0 - strandArtifactPrior);

        // Convert natural-log posteriors to log10 and normalize the three hypotheses against each other.
        final double[] forwardReverseNoneProbs = MathUtils.normalizeLog10(new double[]{
                (forwardLogLikelihood + forwardLogPrior) * MathUtils.LOG10_E,
                (reverseLogLikelihood + reverseLogPrior) * MathUtils.LOG10_E,
                (noneLogLikelihood + noneLogPrior) * MathUtils.LOG10_E}, false, true);
        return new EStep(forwardReverseNoneProbs[0], forwardReverseNoneProbs[1], forwardCount, reverseCount, forwardAltCount, reverseAltCount);
    }

    @Override
    public String filterName() {
        return "strand_bias";
    }

    @Override
    protected List<String> requiredInfoAnnotations() {
        return Collections.singletonList("AS_SB_TABLE");
    }

    /** Log-likelihood of one strand's alt count under the currently learned artifact beta-binomial. */
    private double artifactStrandLogLikelihood(final int strandCount, final int strandAltCount) {
        return artifactStrandLogLikelihood(strandCount, strandAltCount, alphaStrand, betaStrand);
    }

    /** Log-likelihood of {@code strandAltCount} of {@code strandCount} reads under a Beta(alpha, beta)-binomial. */
    private static double artifactStrandLogLikelihood(final int strandCount, final int strandAltCount, final double alpha, final double beta) {
        return new BetaBinomialDistribution(null, alpha, beta, strandCount).logProbability(strandAltCount);
    }

    /**
     * Log-likelihood of one strand's alt count on the non-artifact strand. A zero length difference
     * (SNVs and same-length substitutions) uses the SNV prior; indels use a size-dependent prior.
     */
    private static double nonArtifactStrandLogLikelihood(final int strandCount, final int strandAltCount, final int indelSize) {
        final double betaSeq = indelSize == 0 ? BETA_SEQ_SNV
                : (indelSize < LONG_INDEL_SIZE ? BETA_SEQ_SHORT_INDEL : BETA_SEQ_LONG_INDEL);
        return new BetaBinomialDistribution(null, ALPHA_SEQ, betaSeq, strandCount).logProbability(strandAltCount);
    }

    @Override
    public Optional<String> phredScaledPosteriorAnnotationName() {
        return Optional.of("STRANDQ");
    }

    /**
     * Immutable E-step result for one alt allele: the posterior responsibilities of the forward- and
     * reverse-artifact hypotheses, together with the site strand counts and the allele's strand alt counts.
     */
    public static final class EStep {
        private final double forwardArtifactResponsibility;
        private final double reverseArtifactResponsibility;
        private final int forwardCount;
        private final int reverseCount;
        private final int forwardAltCount;
        private final int reverseAltCount;

        public EStep(final double forwardArtifactResponsibility, final double reverseArtifactResponsibility, final int forwardCount, final int reverseCount, final int forwardAltCount, final int reverseAltCount) {
            this.forwardArtifactResponsibility = forwardArtifactResponsibility;
            this.reverseArtifactResponsibility = reverseArtifactResponsibility;
            this.forwardCount = forwardCount;
            this.reverseCount = reverseCount;
            this.forwardAltCount = forwardAltCount;
            this.reverseAltCount = reverseAltCount;
        }

        public double getForwardArtifactResponsibility() {
            return forwardArtifactResponsibility;
        }

        public double getReverseArtifactResponsibility() {
            return reverseArtifactResponsibility;
        }

        /** Total probability that the allele is a strand artifact on either strand. */
        public double getArtifactProbability() {
            return getForwardArtifactResponsibility() + getReverseArtifactResponsibility();
        }

        public int getForwardCount() {
            return forwardCount;
        }

        public int getReverseCount() {
            return reverseCount;
        }

        public int getForwardAltCount() {
            return forwardAltCount;
        }

        public int getReverseAltCount() {
            return reverseAltCount;
        }
    }
}

