/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.contamination;

import htsjdk.samtools.util.OverlapDetector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.DoubleUnaryOperator;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.commons.math3.util.FastMath;
import org.broadinstitute.hellbender.tools.walkers.contamination.ContaminationSegmenter;
import org.broadinstitute.hellbender.tools.walkers.contamination.MinorAlleleFractionRecord;
import org.broadinstitute.hellbender.tools.walkers.contamination.PileupSummary;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Estimates cross-sample contamination from {@link PileupSummary} records.
 *
 * <p>The constructor splits the sites into segments with {@link ContaminationSegmenter}. It then alternates a fixed
 * number of times ({@code NUM_ITERATIONS}) between two steps:
 * <ul>
 *   <li>fitting a minor allele fraction (MAF) for every segment, given the current contamination estimate;</li>
 *   <li>re-fitting one contamination value by maximum likelihood over the segments whose MAF clears a threshold
 *       (segments that do not look like loss of heterozygosity).</li>
 * </ul>
 * The fitted MAFs are later used by {@link #calculateContaminationFromHoms(List)} to select homozygous sites and
 * estimate contamination from the reads that support the opposite allele at those sites.
 */
public class ContaminationModel {
    /** First (highest) minor-allele-fraction threshold of the descending threshold scans. */
    public static final double INITIAL_MAF_THRESHOLD = 0.4;
    /** In {@link #calculateContaminationFromHoms(List)}, thresholds above this value use hom-alt sites. */
    public static final double MAF_TO_SWITCH_TO_HOM_REF = 0.25;
    /**
     * Thresholds above this value (and not above {@link #MAF_TO_SWITCH_TO_HOM_REF}) use hom-ref sites.
     * Thresholds at or below it use the "unscrupulous" hom-ref selection.
     */
    public static final double MAF_TO_SWITCH_TO_UNSCRUPULOUS_HOM_REF = 0.2;
    /** Tumor sites with an alt fraction below this are candidate hom-refs for the "unscrupulous" strategy. */
    public static final double UNSCRUPULOUS_HOM_REF_ALLELE_FRACTION = 0.15;
    /** Not referenced by this class's own logic. */
    public static final double UNSCRUPULOUS_HOM_REF_FRACTION_TO_REMOVE_FOR_POSSIBLE_LOH = 0.1;
    /** Percentile of candidate hom-ref alt fractions used as the "unscrupulous" alt-fraction cutoff. */
    public static final double UNSCRUPULOUS_HOM_REF_PERCENTILE = 90.0;
    /** Lower bound on the "unscrupulous" alt-fraction cutoff. */
    public static final double MINIMUM_UNSCRUPULOUS_HOM_REF_ALT_FRACTION_THRESHOLD = 0.1;
    /** Amount by which the minor-allele-fraction threshold is lowered at each step of a scan. */
    public static final double MAF_STEP_SIZE = 0.04;

    private final double contamination;
    private final double errorRate;
    private final List<Double> minorAlleleFractions;
    private final List<List<PileupSummary>> segments;

    /** Index of the homozygous-reference genotype in the genotype likelihood array. */
    public static final int HOM_REF = 0;
    /** Index of the homozygous-alt genotype in the genotype likelihood array. */
    public static final int HOM_ALT = 3;

    private static final int NUM_ITERATIONS = 3;
    private static final double MIN_FRACTION_OF_SITES_TO_USE = 0.25;
    private static final double MIN_RELATIVE_ERROR = 0.2;
    private static final double MIN_ABSOLUTE_ERROR = 0.001;
    private static final List<Double> CONTAMINATION_INITIAL_GUESSES = Arrays.asList(0.02, 0.05, 0.1, 0.2);

    /**
     * Fits per-segment minor allele fractions and a global contamination estimate.
     *
     * @param sites pileup summaries used for error-rate estimation and segmentation
     */
    public ContaminationModel(final List<PileupSummary> sites) {
        errorRate = calculateErrorRate(sites);
        segments = ContaminationSegmenter.findSegments(sites);
        final int numSegments = segments.size();

        // Placeholder values: every entry is overwritten in the first iteration, before it is read.
        final List<Double> mafGuesses = new ArrayList<>(Collections.nCopies(numSegments, 0.5));
        double contaminationGuess = 0.0;
        for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
            for (int s = 0; s < numSegments; s++) {
                mafGuesses.set(s, calculateMinorAlleleFraction(contaminationGuess, errorRate, segments.get(s)));
            }
            final Pair<List<List<PileupSummary>>, List<Double>> nonLOHSegmentsAndMafs = getNonLOHSegments(segments, mafGuesses);
            contaminationGuess = calculateContamination(errorRate, nonLOHSegmentsAndMafs.getLeft(), nonLOHSegmentsAndMafs.getRight());
        }
        minorAlleleFractions = mafGuesses;
        contamination = contaminationGuess;
    }

    /**
     * Returns the minor-allele-fraction threshold used at a step of a descending scan:
     * {@code INITIAL_MAF_THRESHOLD - step * MAF_STEP_SIZE}.
     *
     * <p>The threshold is derived from an integer step count rather than by repeatedly subtracting
     * {@code MAF_STEP_SIZE}. Repeated subtraction accumulates rounding error: starting from 0.4 it lands slightly
     * above 0.2 and slightly above 0.0 instead of on them. That skews strict comparisons against
     * {@link #MAF_TO_SWITCH_TO_UNSCRUPULOUS_HOM_REF} and 0.
     */
    private static double mafThresholdAtStep(final int step) {
        return INITIAL_MAF_THRESHOLD - step * MAF_STEP_SIZE;
    }

    /**
     * Chooses the segments used to fit contamination.
     *
     * <p>Scans MAF thresholds downward from {@link #INITIAL_MAF_THRESHOLD} while they remain positive. It returns
     * the segments whose MAF exceeds the first threshold under which those segments hold more than
     * {@code MIN_FRACTION_OF_SITES_TO_USE} of all sites. If no threshold qualifies, all segments are returned.
     *
     * @param segments the segments, each a list of sites
     * @param mafs one minor allele fraction per segment, in the same order
     * @return the selected segments and their MAFs
     */
    private static Pair<List<List<PileupSummary>>, List<Double>> getNonLOHSegments(final List<List<PileupSummary>> segments, final List<Double> mafs) {
        final int numSites = segments.stream().mapToInt(List::size).sum();
        for (int step = 0; mafThresholdAtStep(step) > 0.0; step++) {
            final double minMaf = mafThresholdAtStep(step);
            final int[] nonLOHIndices = IntStream.range(0, segments.size()).filter(s -> mafs.get(s) > minMaf).toArray();
            final List<List<PileupSummary>> nonLOHSegments = Arrays.stream(nonLOHIndices).mapToObj(segments::get).collect(Collectors.toList());
            final List<Double> nonLOHMafs = Arrays.stream(nonLOHIndices).mapToObj(mafs::get).collect(Collectors.toList());
            final int numNonLOHSites = nonLOHSegments.stream().mapToInt(List::size).sum();
            // When numSites == 0 the ratio is NaN, the comparison is false, and we fall through to the default.
            if ((double) numNonLOHSites / numSites > MIN_FRACTION_OF_SITES_TO_USE) {
                return ImmutablePair.of(nonLOHSegments, nonLOHMafs);
            }
        }
        return ImmutablePair.of(segments, mafs);
    }

    /**
     * Estimates contamination from the opposite-allele reads at homozygous tumor sites.
     *
     * <p>Scans MAF thresholds from {@link #INITIAL_MAF_THRESHOLD} down to 0 in steps of {@link #MAF_STEP_SIZE}. The
     * strategy depends on the threshold:
     * <ul>
     *   <li>above {@link #MAF_TO_SWITCH_TO_HOM_REF}: hom-alt sites;</li>
     *   <li>above {@link #MAF_TO_SWITCH_TO_UNSCRUPULOUS_HOM_REF}: hom-ref sites;</li>
     *   <li>otherwise: the "unscrupulous" hom-ref selection.</li>
     * </ul>
     * The first non-NaN estimate whose standard error is below
     * {@code MIN_RELATIVE_ERROR * estimate + MIN_ABSOLUTE_ERROR} is returned.
     *
     * @param tumorSites pileup summaries of the tumor sample
     * @return (contamination, standard error). If the scan finds no acceptable estimate, the result of the
     *         "unscrupulous" strategy is returned, or (0.0, 1.0) if that estimate is NaN.
     */
    public Pair<Double, Double> calculateContaminationFromHoms(final List<PileupSummary> tumorSites) {
        for (int step = 0; mafThresholdAtStep(step) >= 0.0; step++) {
            final double minMaf = mafThresholdAtStep(step);
            final Pair<Double, Double> result = calculateContamination(strategyForMafThreshold(minMaf), tumorSites, minMaf);
            final double estimate = result.getLeft();
            if (!Double.isNaN(estimate) && result.getRight() < estimate * MIN_RELATIVE_ERROR + MIN_ABSOLUTE_ERROR) {
                return result;
            }
        }
        final Pair<Double, Double> fallback = calculateContamination(Strategy.UNSCRUPULOUS_HOM_REF, tumorSites, 0.0);
        return Double.isNaN(fallback.getLeft()) ? Pair.of(0.0, 1.0) : fallback;
    }

    /** Maps a MAF threshold to the homozygous-site selection strategy used at that threshold. */
    private static Strategy strategyForMafThreshold(final double minMaf) {
        if (minMaf > MAF_TO_SWITCH_TO_HOM_REF) {
            return Strategy.HOM_ALT;
        }
        if (minMaf > MAF_TO_SWITCH_TO_UNSCRUPULOUS_HOM_REF) {
            return Strategy.HOM_REF;
        }
        return Strategy.UNSCRUPULOUS_HOM_REF;
    }

    /**
     * Estimates contamination from homozygous sites chosen by {@code strategy}.
     *
     * @return (contamination capped at 1.0, standard error). The estimate is NaN when no homozygous sites remain
     *         (0/0); the standard error is 1.0 in that case.
     */
    private Pair<Double, Double> calculateContamination(final Strategy strategy, final List<PileupSummary> tumorSites, final double minMaf) {
        final boolean useHomAlt = strategy == Strategy.HOM_ALT;
        final List<PileupSummary> homs = subsetSites(tumorSites, selectGenotypingHoms(strategy, tumorSites, minMaf));
        final double tumorErrorRate = calculateErrorRate(tumorSites);

        // Reads supporting the allele opposite to the homozygous genotype:
        // ref reads at hom-alt sites, alt reads at hom-ref sites.
        final ToIntFunction<PileupSummary> oppositeCount = useHomAlt ? PileupSummary::getRefCount : PileupSummary::getAltCount;
        final ToDoubleFunction<PileupSummary> oppositeAlleleFrequency = useHomAlt ? PileupSummary::getRefFrequency : PileupSummary::getAlleleFrequency;

        final long totalDepth = homs.stream().mapToLong(PileupSummary::getTotalCount).sum();
        final long oppositeDepth = homs.stream().mapToLong(oppositeCount::applyAsInt).sum();
        // Opposite-allele reads attributed to sequencing error. The division by 3 presumably reflects that only one
        // of the three possible substitutions yields the opposite allele.
        final long errorDepth = Math.round(totalDepth * tumorErrorRate / 3.0);
        final long contaminationOppositeDepth = Math.max(oppositeDepth - errorDepth, 0L);

        final double totalDepthWeightedByOppositeFrequency = homs.stream()
                .mapToDouble(ps -> ps.getTotalCount() * oppositeAlleleFrequency.applyAsDouble(ps))
                .sum();
        final double contamination = contaminationOppositeDepth / totalDepthWeightedByOppositeFrequency;

        final double stdError = homs.isEmpty() ? 1.0 : Math.sqrt(homs.stream().mapToDouble(ps -> {
            final double d = ps.getTotalCount();
            final double f = 1.0 - oppositeAlleleFrequency.applyAsDouble(ps);
            return (1.0 - f) * d * contamination * (1.0 - contamination + f * d * contamination);
        }).sum()) / totalDepthWeightedByOppositeFrequency;

        return Pair.of(Math.min(contamination, 1.0), stdError);
    }

    /**
     * Selects the homozygous sites to genotype for a strategy.
     *
     * <p>HOM_ALT and HOM_REF use the fitted model. The "unscrupulous" strategy takes tumor sites whose alt fraction
     * is below {@link #UNSCRUPULOUS_HOM_REF_ALLELE_FRACTION}. It keeps those at or below the
     * {@link #UNSCRUPULOUS_HOM_REF_PERCENTILE} percentile of their alt fractions, with the cutoff floored at
     * {@link #MINIMUM_UNSCRUPULOUS_HOM_REF_ALT_FRACTION_THRESHOLD}.
     */
    private List<PileupSummary> selectGenotypingHoms(final Strategy strategy, final List<PileupSummary> tumorSites, final double minMaf) {
        if (strategy == Strategy.HOM_ALT) {
            return homAlts(minMaf);
        }
        if (strategy == Strategy.HOM_REF) {
            return homRefs(minMaf);
        }
        final List<PileupSummary> candidateHomRefs = tumorSites.stream()
                .filter(site -> site.getAltFraction() < UNSCRUPULOUS_HOM_REF_ALLELE_FRACTION)
                .collect(Collectors.toList());
        final double[] candidateAltFractions = candidateHomRefs.stream().mapToDouble(PileupSummary::getAltFraction).toArray();
        final double altFractionThreshold = Math.max(MINIMUM_UNSCRUPULOUS_HOM_REF_ALT_FRACTION_THRESHOLD,
                new Percentile(UNSCRUPULOUS_HOM_REF_PERCENTILE).evaluate(candidateAltFractions));
        return candidateHomRefs.stream()
                .filter(site -> site.getAltFraction() <= altFractionThreshold)
                .collect(Collectors.toList());
    }

    /**
     * Returns sites, from segments whose fitted MAF exceeds {@code minMaf}, whose posterior probability of
     * {@code genotype} (under the fitted contamination, error rate and segment MAF) exceeds 0.5.
     */
    private List<PileupSummary> getType(final int genotype, final double minMaf) {
        final List<PileupSummary> result = new ArrayList<>();
        for (int s = 0; s < segments.size(); s++) {
            final double maf = minorAlleleFractions.get(s);
            if (maf > minMaf) {
                for (final PileupSummary site : segments.get(s)) {
                    if (probability(site, contamination, errorRate, maf, genotype) > 0.5) {
                        result.add(site);
                    }
                }
            }
        }
        return result;
    }

    private List<PileupSummary> homAlts(final double minMaf) {
        return getType(HOM_ALT, minMaf);
    }

    private List<PileupSummary> homRefs(final double minMaf) {
        return getType(HOM_REF, minMaf);
    }

    /**
     * Builds one record per segment. Each record spans the segment's first site start to its last site end and
     * carries the segment's fitted minor allele fraction.
     */
    public List<MinorAlleleFractionRecord> segmentationRecords() {
        final List<MinorAlleleFractionRecord> records = new ArrayList<>(segments.size());
        for (int n = 0; n < segments.size(); n++) {
            final List<PileupSummary> segment = segments.get(n);
            final PileupSummary first = segment.get(0);
            final PileupSummary last = segment.get(segment.size() - 1);
            final double maf = minorAlleleFractions.get(n);
            records.add(new MinorAlleleFractionRecord(new SimpleInterval(first.getContig(), first.getStart(), last.getEnd()), maf));
        }
        return records;
    }

    /**
     * Estimates the per-base error rate as 1.5 times the fraction of bases supporting neither ref nor alt. The 1.5
     * factor presumably scales from two of the three possible substitutions to all three.
     * Returns NaN when {@code sites} has no bases.
     */
    private static double calculateErrorRate(final List<PileupSummary> sites) {
        // Sum as long: IntStream.sum() returns an int and would overflow for large total depths.
        final long totalBases = sites.stream().mapToLong(PileupSummary::getTotalCount).sum();
        final long otherAltBases = sites.stream().mapToLong(PileupSummary::getOtherAltCount).sum();
        return 1.5 * ((double) otherAltBases / totalBases);
    }

    /**
     * Maximizes the segment log-likelihood over the minor allele fraction.
     *
     * <p>NOTE(review): the literal arguments presumably mean search bounds [0.1, 0.5], initial guess 0.4,
     * tolerances 0.01 and at most 20 evaluations. Confirm against {@code OptimizationUtils.max}.
     */
    private static double calculateMinorAlleleFraction(final double contamination, final double errorRate, final List<PileupSummary> segment) {
        final DoubleUnaryOperator objective = maf -> segmentLogLikelihood(segment, contamination, errorRate, maf);
        return OptimizationUtils.max(objective, 0.1, 0.5, 0.4, 0.01, 0.01, 20).getPoint();
    }

    /**
     * Maximizes the model log-likelihood over contamination. It runs one optimization per entry of
     * {@code CONTAMINATION_INITIAL_GUESSES} and returns the point of the best-scoring optimum.
     */
    private static double calculateContamination(final double errorRate, final List<List<PileupSummary>> segments, final List<Double> mafs) {
        final DoubleUnaryOperator objective = c -> modelLogLikelihood(segments, c, errorRate, mafs);
        final List<UnivariatePointValuePair> optima = CONTAMINATION_INITIAL_GUESSES.stream()
                .map(initial -> OptimizationUtils.max(objective, 0.0, 0.5, initial, 1.0E-4, 1.0E-4, 30))
                .collect(Collectors.toList());
        return Collections.max(optima, Comparator.comparingDouble(UnivariatePointValuePair::getValue)).getPoint();
    }

    /**
     * Computes unnormalized joint probabilities of the site's ref/alt read counts and each of four sample genotypes.
     *
     * <p>Index 0 is hom-ref, 1 and 2 are heterozygous with alt fraction {@code maf} and {@code 1 - maf}, and 3 is
     * hom-alt. The priors are Hardy–Weinberg proportions of the site's allele frequency {@code f} (the het prior is
     * split evenly between indices 1 and 2). The read-level alt fraction mixes the sample genotype's fraction with
     * {@code f} in proportion {@code c}.
     */
    private static double[] genotypeLikelihoods(final PileupSummary site, final double c, final double errorRate, final double maf) {
        final double f = site.getAlleleFrequency();
        final int k = site.getAltCount();
        final int n = k + site.getRefCount();
        final double[] samplePriors = {(1.0 - f) * (1.0 - f), f * (1.0 - f), f * (1.0 - f), f * f};
        final double[] sampleAFs = {errorRate / 3.0, maf, 1.0 - maf, 1.0 - errorRate};
        return new IndexRange(0, 4).mapToDouble(sg -> samplePriors[sg] * MathUtils.binomialProbability(n, k, (1.0 - c) * sampleAFs[sg] + c * f));
    }

    /** Posterior probability of {@code genotype} at {@code site}: its likelihood over the sum of all four. */
    private static double probability(final PileupSummary site, final double contamination, final double errorRate, final double minorAlleleFraction, final int genotype) {
        final double[] likelihoods = genotypeLikelihoods(site, contamination, errorRate, minorAlleleFraction);
        return likelihoods[genotype] / MathUtils.sum(likelihoods);
    }

    /** Sum over sites of the log of the genotype-marginalized likelihood. */
    private static double segmentLogLikelihood(final List<PileupSummary> segment, final double contamination, final double errorRate, final double minorAlleleFraction) {
        return segment.stream()
                .mapToDouble(site -> FastMath.log(MathUtils.sum(genotypeLikelihoods(site, contamination, errorRate, minorAlleleFraction))))
                .sum();
    }

    /**
     * Sum of segment log-likelihoods, each segment evaluated at its own MAF.
     *
     * @throws IllegalStateException (via {@code Utils.validate}) if the list sizes differ
     */
    private static double modelLogLikelihood(final List<List<PileupSummary>> segments, final double contamination, final double errorRate, final List<Double> mafs) {
        Utils.validate(segments.size() == mafs.size(), " Must have one MAF per segment");
        return new IndexRange(0, segments.size()).sum(n -> segmentLogLikelihood(segments.get(n), contamination, errorRate, mafs.get(n)));
    }

    /** Returns the elements of {@code sites} that overlap any locus in {@code subsetLoci}, in their original order. */
    private static List<PileupSummary> subsetSites(final List<PileupSummary> sites, final List<PileupSummary> subsetLoci) {
        final OverlapDetector<PileupSummary> od = OverlapDetector.create(subsetLoci);
        return sites.stream().filter(od::overlapsAny).collect(Collectors.toList());
    }

    /** How homozygous sites are selected for the contamination estimate. */
    private enum Strategy {
        HOM_ALT,
        HOM_REF,
        UNSCRUPULOUS_HOM_REF
    }
}

