/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import com.google.common.collect.ArrayListMultimap;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseEdge;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseGraph;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseVertex;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.ChainPruner;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.Path;
import org.broadinstitute.hellbender.tools.walkers.mutect.Mutect2Engine;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * A {@link ChainPruner} that removes chains whose multiplicity is poorly supported relative to the total
 * multiplicity at the vertices where the chain starts and ends, instead of using a fixed multiplicity cutoff.
 *
 * <p>Pruning runs in two passes. The first pass uses a fixed initial error probability to identify probable error
 * chains. The summed last-edge multiplicity of those chains, divided by the total edge multiplicity of all chains,
 * gives an estimated error rate. The second pass reruns the classification with that estimated rate to decide
 * which chains are removed. Chains containing any reference edge are never removed.</p>
 *
 * @param <V> vertex type of the graph being pruned
 * @param <E> edge type of the graph being pruned
 */
public class AdaptiveChainPruner<V extends BaseVertex, E extends BaseEdge> extends ChainPruner<V, E> {
    /** Error probability used in the first pass, before an error rate has been estimated from the graph. */
    private final double initialErrorProbability;
    /** Minimum log odds at the relevant end of a chain for that chain to be followed while growing good chains. */
    private final double logOddsThreshold;
    /** Minimum log odds, required at BOTH ends of a chain, for the chain to count toward seeding its end vertices. */
    private final double seedingLogOddsThreshold;
    /** Once more than this many variants have been accepted, growth of the good-chain set stops. */
    private final int maxUnprunedVariants;

    /**
     * @param initialErrorProbability error probability for the first pass; validated via {@code ParamUtils.isPositive}
     * @param logOddsThreshold        log-odds threshold for extending the set of good chains
     * @param seedingLogOddsThreshold log-odds threshold (at both chain ends) for seeding
     * @param maxUnprunedVariants     cap on the number of variants kept before growth stops
     */
    public AdaptiveChainPruner(final double initialErrorProbability, final double logOddsThreshold,
                               final double seedingLogOddsThreshold, final int maxUnprunedVariants) {
        ParamUtils.isPositive(initialErrorProbability, "Must have positive error probability");
        this.initialErrorProbability = initialErrorProbability;
        this.logOddsThreshold = logOddsThreshold;
        this.seedingLogOddsThreshold = seedingLogOddsThreshold;
        this.maxUnprunedVariants = maxUnprunedVariants;
    }

    /**
     * Determines which chains to remove, using an error rate estimated from the chains themselves.
     *
     * @param chains all chains of a single graph (the graph is taken from the first chain)
     * @return chains judged to be errors that contain no reference edge; empty if {@code chains} is empty
     */
    @Override
    protected Collection<Path<V, E>> chainsToRemove(final List<Path<V, E>> chains) {
        if (chains.isEmpty()) {
            return Collections.emptyList();
        }
        final BaseGraph<V, E> graph = chains.get(0).getGraph();

        // First pass: classify with the fixed initial error probability and estimate the graph's error rate.
        final Collection<Path<V, E>> probableErrorChains = likelyErrorChains(chains, graph, initialErrorProbability);
        final int errorCount = probableErrorChains.stream()
                .mapToInt(chain -> chain.getLastEdge().getMultiplicity())
                .sum();
        final int totalBases = chains.stream()
                .mapToInt(chain -> chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).sum())
                .sum();
        // Guard against 0/0 = NaN when every chain edge has zero multiplicity; fall back to the initial estimate.
        final double errorRate = totalBases > 0 ? (double) errorCount / totalBases : initialErrorProbability;

        // Second pass: classify with the estimated rate, never removing chains that touch the reference path.
        return likelyErrorChains(chains, graph, errorRate).stream()
                .filter(chain -> chain.getEdges().stream().noneMatch(BaseEdge::isRef))
                .collect(Collectors.toList());
    }

    /**
     * Grows a set of "good" chains from high-confidence seeds and returns every chain not in that set.
     *
     * <p>Seeds are the maximum-weight chain, plus the well-supported chains incident to any vertex touched by more
     * than two seedable chains. Growth proceeds from the endpoints of accepted chains along chains whose log odds at
     * the shared vertex meet {@link #logOddsThreshold}. Chains are processed highest log odds first.</p>
     *
     * @param chains    all chains in the graph
     * @param graph     the graph containing the chains
     * @param errorRate error rate used to compute each chain's log odds
     * @return the chains not judged good (returned as a set)
     */
    private Collection<Path<V, E>> likelyErrorChains(final List<Path<V, E>> chains, final BaseGraph<V, E> graph,
                                                     final double errorRate) {
        // Left/right log odds for every chain, computed once up front.
        final Map<Path<V, E>, Pair<Double, Double>> chainLogOdds = chains.stream()
                .collect(Collectors.toMap(chain -> chain, chain -> chainLogOdds(chain, graph, errorRate)));

        final ArrayListMultimap<V, Path<V, E>> vertexToSeedableChains = ArrayListMultimap.create();
        final ArrayListMultimap<V, Path<V, E>> vertexToGoodIncomingChains = ArrayListMultimap.create();
        final ArrayListMultimap<V, Path<V, E>> vertexToGoodOutgoingChains = ArrayListMultimap.create();

        for (final Path<V, E> chain : chains) {
            final double leftLogOdds = chainLogOdds.get(chain).getLeft();
            final double rightLogOdds = chainLogOdds.get(chain).getRight();
            if (rightLogOdds >= logOddsThreshold) {
                vertexToGoodIncomingChains.put(chain.getLastVertex(), chain);
            }
            if (leftLogOdds >= logOddsThreshold) {
                vertexToGoodOutgoingChains.put(chain.getFirstVertex(), chain);
            }
            // Seeding is stricter: it needs evidence at both ends of the chain.
            if (rightLogOdds >= seedingLogOddsThreshold && leftLogOdds >= seedingLogOddsThreshold) {
                vertexToSeedableChains.put(chain.getFirstVertex(), chain);
                vertexToSeedableChains.put(chain.getLastVertex(), chain);
            }
        }

        // Highest log odds first; ties broken by the first vertex's sequence so the order is deterministic.
        final PriorityQueue<Pair<Path<V, E>, Double>> chainsToAdd = new PriorityQueue<>(
                Comparator.comparingDouble((Pair<Path<V, E>, Double> entry) -> -entry.getRight())
                        .thenComparing(entry -> entry.getLeft().getFirstVertex().getSequence(),
                                BaseUtils.BASES_COMPARATOR));

        // The max-weight chain is always treated as good.
        chainsToAdd.add(ImmutablePair.of(getMaxWeightChain(chains), Double.POSITIVE_INFINITY));

        final Set<V> processedVertices = new LinkedHashSet<>();
        for (final V vertex : vertexToSeedableChains.keySet()) {
            if (vertexToSeedableChains.get(vertex).size() > 2) {
                enqueueGoodChainsAt(vertex, vertexToGoodOutgoingChains, vertexToGoodIncomingChains, chainLogOdds, chainsToAdd);
                processedVertices.add(vertex);
            }
        }

        final Set<Path<V, E>> goodChains = new LinkedHashSet<>();
        final Set<V> verticesThatAlreadyHaveOutgoingGoodChains = new HashSet<>();
        int variantCount = 0;
        while (!chainsToAdd.isEmpty() && variantCount <= maxUnprunedVariants) {
            final Path<V, E> chain = chainsToAdd.poll().getLeft();
            if (!goodChains.add(chain)) {
                continue;   // already accepted via another route
            }

            // A second (or later) good chain leaving the same vertex counts as a new variant.
            final boolean newVariant = !verticesThatAlreadyHaveOutgoingGoodChains.add(chain.getFirstVertex());
            if (newVariant) {
                variantCount++;
            }
            if (newVariant && variantCount > maxUnprunedVariants) {
                // NOTE(review): this chain was already added to goodChains above, so it is kept (not pruned) even
                // though it exceeds the limit; only its neighbours go unexplored. The loop condition then fails, so
                // this acts as a break. Confirm that keeping maxUnprunedVariants + 1 variants is intended.
                continue;
            }

            for (final V vertex : Arrays.asList(chain.getFirstVertex(), chain.getLastVertex())) {
                if (!processedVertices.contains(vertex)) {
                    enqueueGoodChainsAt(vertex, vertexToGoodOutgoingChains, vertexToGoodIncomingChains, chainLogOdds, chainsToAdd);
                    processedVertices.add(vertex);
                }
            }
        }

        return chains.stream().filter(chain -> !goodChains.contains(chain)).collect(Collectors.toSet());
    }

    /**
     * Enqueues every good chain incident to {@code vertex}. Outgoing chains are keyed by their left log odds and
     * are enqueued first; incoming chains are keyed by their right log odds.
     */
    private void enqueueGoodChainsAt(final V vertex,
                                     final ArrayListMultimap<V, Path<V, E>> goodOutgoing,
                                     final ArrayListMultimap<V, Path<V, E>> goodIncoming,
                                     final Map<Path<V, E>, Pair<Double, Double>> chainLogOdds,
                                     final PriorityQueue<Pair<Path<V, E>, Double>> queue) {
        goodOutgoing.get(vertex).forEach(c -> queue.add(ImmutablePair.of(c, chainLogOdds.get(c).getLeft())));
        goodIncoming.get(vertex).forEach(c -> queue.add(ImmutablePair.of(c, chainLogOdds.get(c).getRight())));
    }

    /**
     * Returns the chain with the largest single-edge multiplicity. Ties are broken by longer length, then by the
     * first vertex's sequence.
     *
     * @param chains non-empty collection of chains
     * @throws IllegalArgumentException if {@code chains} is empty
     */
    private Path<V, E> getMaxWeightChain(final Collection<Path<V, E>> chains) {
        return chains.stream()
                .max(Comparator.comparingInt((Path<V, E> chain) ->
                                chain.getEdges().stream().mapToInt(BaseEdge::getMultiplicity).max().orElse(0))
                        .thenComparingInt(Path::length)
                        .thenComparing(chain -> chain.getFirstVertex().getSequence(), BaseUtils.BASES_COMPARATOR))
                .orElseThrow(() -> new IllegalArgumentException("chains must be non-empty"));
    }

    /**
     * Computes the left and right log odds of a chain.
     *
     * <p>The left value compares the chain's first-edge multiplicity with the remaining outgoing multiplicity at its
     * first vertex; it is 0 if that vertex is a source. The right value compares the last-edge multiplicity with the
     * remaining incoming multiplicity at its last vertex; it is 0 if that vertex is a sink. Both come from
     * {@code Mutect2Engine.logLikelihoodRatio(otherMultiplicity, chainMultiplicity, errorRate)}. Presumably this is
     * the log odds that the chain's multiplicity is not explained by errors at {@code errorRate}; confirm against
     * Mutect2Engine.</p>
     *
     * @return pair of (left log odds, right log odds)
     */
    private Pair<Double, Double> chainLogOdds(final Path<V, E> chain, final BaseGraph<V, E> graph, final double errorRate) {
        final V first = chain.getFirstVertex();
        final V last = chain.getLastVertex();
        final int leftTotalMultiplicity = MathUtils.sumIntFunction(graph.outgoingEdgesOf(first), BaseEdge::getMultiplicity);
        final int rightTotalMultiplicity = MathUtils.sumIntFunction(graph.incomingEdgesOf(last), BaseEdge::getMultiplicity);
        final int leftMultiplicity = chain.getEdges().get(0).getMultiplicity();
        final int rightMultiplicity = chain.getLastEdge().getMultiplicity();

        final double leftLogOdds = graph.isSource(first) ? 0.0
                : Mutect2Engine.logLikelihoodRatio(leftTotalMultiplicity - leftMultiplicity, leftMultiplicity, errorRate);
        final double rightLogOdds = graph.isSink(last) ? 0.0
                : Mutect2Engine.logLikelihoodRatio(rightTotalMultiplicity - rightMultiplicity, rightMultiplicity, errorRate);
        return ImmutablePair.of(leftLogOdds, rightLogOdds);
    }
}

