/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.sv.evidence;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.BreakpointEvidence;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.PartitionCrossingChecker;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.ReadMetadata;
import org.broadinstitute.hellbender.tools.spark.sv.utils.PairedStrandedIntervalTree;
import org.broadinstitute.hellbender.tools.spark.sv.utils.PairedStrandedIntervals;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVInterval;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVIntervalTree;
import org.broadinstitute.hellbender.tools.spark.sv.utils.StrandedInterval;
import org.broadinstitute.hellbender.utils.Utils;
import scala.Tuple2;

public final class BreakpointDensityFilter
implements Iterator<BreakpointEvidence> {
    private final ReadMetadata readMetadata;
    private final double minEvidenceWeight;
    private final double minCoherentEvidenceWeight;
    private final PartitionCrossingChecker partitionCrossingChecker;
    private final SVIntervalTree<List<BreakpointEvidence>> evidenceTree;
    private final int minEvidenceMapq;
    private final Iterator<SVIntervalTree.Entry<List<BreakpointEvidence>>> treeItr;
    // Iterator over the evidence group currently being emitted; null when no group is active.
    private Iterator<BreakpointEvidence> listItr;

    /*
     * Filters a stream of BreakpointEvidence, passing through only evidence groups that are
     * sufficiently supported by nearby evidence.
     *
     * All input evidence is consumed eagerly (in the constructor) into an interval tree keyed by
     * each evidence's location. Evidence that SVIntervalTree.find matches to an existing interval
     * is appended to that interval's list, so each tree entry holds a non-empty group.
     *
     * While iterating, a group is emitted in full when any of these hold:
     *   - some member of the group is already marked validated;
     *   - hasEnoughOverlappers(groupInterval) is true;
     *   - the group's interval is on a partition boundary, per the PartitionCrossingChecker.
     * Groups passing either of the first two tests have every member marked validated (a side
     * effect on the evidence objects). Boundary-only groups are passed through without being
     * marked.
     */

    /**
     * Builds the filter, consuming {@code evidenceItr} completely.
     *
     * @param evidenceItr source of evidence; fully drained by this constructor
     * @param readMetadata supplies the coverage used to scale the thresholds, and is passed to
     *        the distal-target queries
     * @param minEvidenceWeightPerCoverage multiplied by coverage to give the minimum summed
     *        weight of overlapping evidence
     * @param minCoherentEvidenceWeightPerCoverage multiplied by coverage to give the minimum
     *        number of mutually overlapping (local, distal) interval pairs
     * @param partitionCrossingChecker decides which intervals lie on a partition boundary;
     *        those are passed through even when unsupported
     * @param minEvidenceMapq mapping-quality threshold forwarded to the distal-target queries
     */
    public BreakpointDensityFilter(Iterator<BreakpointEvidence> evidenceItr, ReadMetadata readMetadata, double minEvidenceWeightPerCoverage, double minCoherentEvidenceWeightPerCoverage, PartitionCrossingChecker partitionCrossingChecker, int minEvidenceMapq) {
        this.readMetadata = readMetadata;
        this.minEvidenceWeight = minEvidenceWeightPerCoverage * (double) readMetadata.getCoverage();
        this.minCoherentEvidenceWeight = minCoherentEvidenceWeightPerCoverage * (double) readMetadata.getCoverage();
        this.partitionCrossingChecker = partitionCrossingChecker;
        this.evidenceTree = buildTree(evidenceItr);
        this.treeItr = this.evidenceTree.iterator();
        this.listItr = null;
        this.minEvidenceMapq = minEvidenceMapq;
    }

    /**
     * Reports whether another evidence item passes the filter.
     *
     * <p>Not a pure query: when the current group is exhausted, this advances the underlying tree
     * iterator past rejected groups and may mark evidence as validated.
     */
    @Override
    public boolean hasNext() {
        if (listItr != null && listItr.hasNext()) {
            return true;
        }
        listItr = null;
        while (treeItr.hasNext()) {
            final SVIntervalTree.Entry<List<BreakpointEvidence>> entry = treeItr.next();
            final SVInterval curInterval = entry.getInterval();
            final List<BreakpointEvidence> evidenceList = entry.getValue();
            final boolean passes;
            if (isValidated(evidenceList) || hasEnoughOverlappers(curInterval)) {
                evidenceList.forEach(ev -> ev.setValidated(true));
                passes = true;
            } else {
                // Unsupported evidence on a partition boundary is kept so that it can be
                // re-evaluated together with data from the neighbouring partition.
                passes = partitionCrossingChecker.onBoundary(curInterval);
            }
            if (passes) {
                // Groups are never empty: addToTree always stores at least one element.
                listItr = evidenceList.iterator();
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the next evidence item that passes the filter.
     *
     * @throws NoSuchElementException if no further evidence passes
     */
    @Override
    public BreakpointEvidence next() {
        if (!hasNext()) {
            throw new NoSuchElementException("No next element.");
        }
        return listItr.next();
    }

    /** Drains {@code evidenceItr} into a tree that groups evidence by location. */
    private static SVIntervalTree<List<BreakpointEvidence>> buildTree(final Iterator<BreakpointEvidence> evidenceItr) {
        final SVIntervalTree<List<BreakpointEvidence>> tree = new SVIntervalTree<>();
        evidenceItr.forEachRemaining(evidence -> addToTree(tree, evidence.getLocation(), evidence));
        return tree;
    }

    /** Returns true if at least one member of {@code evList} is already marked validated. */
    private boolean isValidated(final List<BreakpointEvidence> evList) {
        return evList.stream().anyMatch(BreakpointEvidence::isValidated);
    }

    /**
     * Decides whether the evidence overlapping {@code interval} is dense enough to keep.
     *
     * <p>Two independent tests are applied. Either one succeeding returns true:
     * <ol>
     *   <li>The running sum of {@code getWeight()} over all evidence in tree entries overlapping
     *       {@code interval} reaches {@code minEvidenceWeight}.</li>
     *   <li>Some (local, distal) pair, built from the overlapping evidence that has distal
     *       targets, is overlapped by at least {@code minCoherentEvidenceWeight} such pairs.
     *       NOTE(review): this test counts pairs and does not sum weights, despite the threshold's
     *       name. The count presumably includes the pair itself.</li>
     * </ol>
     *
     * @param interval the interval whose neighbourhood is examined
     * @return true if either density test passes
     */
    @VisibleForTesting
    boolean hasEnoughOverlappers(final SVInterval interval) {
        final Iterator<SVIntervalTree.Entry<List<BreakpointEvidence>>> itr = evidenceTree.overlappers(interval);
        final PairedStrandedIntervalTree<BreakpointEvidence> targetIntervalTree = new PairedStrandedIntervalTree<>();
        // long accumulator: the sum cannot silently overflow the way an int could.
        long weight = 0L;
        while (itr.hasNext()) {
            final List<BreakpointEvidence> evidenceForInterval = itr.next().getValue();
            weight += evidenceForInterval.stream().mapToLong(BreakpointEvidence::getWeight).sum();
            if (weight >= minEvidenceWeight) {
                return true;
            }
            for (final BreakpointEvidence evidence : evidenceForInterval) {
                if (!evidence.hasDistalTargets(readMetadata, minEvidenceMapq)) {
                    continue;
                }
                final List<StrandedInterval> distalTargets = evidence.getDistalTargets(readMetadata, minEvidenceMapq);
                for (final StrandedInterval distalTarget : distalTargets) {
                    targetIntervalTree.put(
                            new PairedStrandedIntervals(
                                    new StrandedInterval(evidence.getLocation(), evidence.isEvidenceUpstreamOfBreakpoint()),
                                    distalTarget),
                            evidence);
                }
            }
        }
        for (final Tuple2<PairedStrandedIntervals, BreakpointEvidence> next : targetIntervalTree) {
            final long coherentEvidenceCount = Utils.stream(targetIntervalTree.overlappers(next._1())).count();
            if (coherentEvidenceCount >= minCoherentEvidenceWeight) {
                return true;
            }
        }
        return false;
    }

    /**
     * Appends {@code value} to the list stored at {@code interval}. If {@code tree.find}
     * reports no entry for that interval, a new single-element list is inserted instead.
     */
    private static <T> void addToTree(final SVIntervalTree<List<T>> tree, final SVInterval interval, final T value) {
        final SVIntervalTree.Entry<List<T>> entry = tree.find(interval);
        if (entry != null) {
            entry.getValue().add(value);
        } else {
            final List<T> valueList = new ArrayList<>(1);
            valueList.add(value);
            tree.put(interval, valueList);
        }
    }
}

