/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.sv.evidence;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.util.FVec;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.CigarElement;
import htsjdk.samtools.CigarOperator;
import htsjdk.samtools.TextCigarCodec;
import htsjdk.samtools.util.IOUtil;
import htsjdk.tribble.bed.BEDFeature;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.spark.sv.StructuralVariationDiscoveryArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.BreakpointEvidence;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.EvidenceFeatures;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.EvidenceOverlapChecker;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.LibraryStatistics;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.PartitionCrossingChecker;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.ReadMetadata;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVInterval;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVIntervalTree;
import org.broadinstitute.hellbender.tools.spark.utils.IntHistogram;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.io.IOUtils;

public final class XGBoostEvidenceFilter
implements Iterator<BreakpointEvidence> {
    // NOTE(review): presumably the value handed to ObjFunction.useFastMathExp in loadPredictor;
    // the decompiled call site shows an inlined literal `true` — confirm.
    private static final boolean USE_FAST_MATH_EXP = true;
    // Evidence classes in the order that defines their integer code: a class's position in this
    // list is the value stored for it in evidenceTypeMap.
    private static final List<Class<?>> DEFAULT_EVIDENCE_TYPE_ORDER = Arrays.asList(BreakpointEvidence.TemplateSizeAnomaly.class, BreakpointEvidence.MateUnmapped.class, BreakpointEvidence.InterContigPair.class, BreakpointEvidence.SplitRead.class, BreakpointEvidence.LargeIndel.class, BreakpointEvidence.WeirdTemplateSize.class, BreakpointEvidence.SameStrandPair.class, BreakpointEvidence.OutiesPair.class);
    // Unmodifiable lookup from evidence class to its index in DEFAULT_EVIDENCE_TYPE_ORDER;
    // used by getFeatures to produce the "evidenceType" feature.
    private static final Map<Class<?>, Integer> evidenceTypeMap = XGBoostEvidenceFilter.evidenceTypeOrderToImmutableMap(DEFAULT_EVIDENCE_TYPE_ORDER);
    // Classpath resource used by loadPredictor when no model file location is supplied.
    private static final String DEFAULT_PREDICTOR_RESOURCE_PATH = "/large/sv_evidence_classifier.bin";
    // NOTE(review): the constants below are not referenced by name in this decompiled file; the
    // same literal values appear inlined in getFeatures, getMappingQuality,
    // getMappingQualityForOverlap and CigarQualityInfo — presumably compile-time inlining, confirm.
    private static final double DEFAULT_GOOD_GAP_OVERLAP = 0.0;
    private static final double DEFAULT_GOOD_MAPPABILITY = 1.0;
    private static final int DEFAULT_GOOD_MAPPING_QUALITY = 60;
    private static final double NON_READ_MAPPING_QUALITY = 60.0;
    private static final double NON_READ_CIGAR_LENGTHS = 0.0;
    // Consulted in hasNext: intervals reported as on a boundary are emitted without classification.
    private final PartitionCrossingChecker partitionCrossingChecker;
    // Classifier whose predictSingle output is compared against thresholdProbability.
    private final Predictor predictor;
    // An evidence list passes when any member's predicted probability is strictly greater than this.
    private final double thresholdProbability;
    // Source of coverage, contig names, and per-library statistics used when building features.
    private final ReadMetadata readMetadata;
    // Supplies the tree iterator walked by hasNext and the overlapper iterators used for features.
    private final EvidenceOverlapChecker evidenceOverlapChecker;
    // Per-evidence raw (not coverage-scaled) overlap statistics, filled lazily by cacheOverlapInfo.
    private final Map<BreakpointEvidence, UnscaledOverlapInfo> rawFeatureCache;
    // Iterator over interval-tree entries, each holding a list of evidence.
    private Iterator<SVIntervalTree.Entry<List<BreakpointEvidence>>> treeItr;
    // Iterator over the evidence list currently being emitted; null when no list is active.
    private Iterator<BreakpointEvidence> listItr;
    // Genome-gap intervals, or null when running without the gaps annotation.
    private final FeatureDataSource<BEDFeature> genomeGaps;
    // umap S100 mappability intervals, or null when running without that annotation.
    private final FeatureDataSource<BEDFeature> umapS100Mappability;

    XGBoostEvidenceFilter(Iterator<BreakpointEvidence> evidenceItr, ReadMetadata readMetadata, StructuralVariationDiscoveryArgumentCollection.FindBreakpointEvidenceSparkArgumentCollection params, PartitionCrossingChecker partitionCrossingChecker) {
        if (params.svGenomeGapsFile == null && params.runWithoutGapsAnnotation) {
            this.genomeGaps = null;
        } else if (params.svGenomeGapsFile != null && !params.runWithoutGapsAnnotation) {
            this.genomeGaps = new FeatureDataSource(params.svGenomeGapsFile);
        } else {
            throw new IllegalArgumentException("XGBoostEvidenceFilter requires specifying --sv-genome-gaps-file or passing --run-without-gaps-annotation (but not both)");
        }
        if (params.svGenomeUmapS100File == null && params.runWithoutUmapS100Annotation) {
            this.umapS100Mappability = null;
        } else if (params.svGenomeUmapS100File != null && !params.runWithoutUmapS100Annotation) {
            this.umapS100Mappability = new FeatureDataSource(params.svGenomeUmapS100File);
        } else {
            throw new IllegalArgumentException("XGBoostEvidenceFilter requires specifying --sv-genome-umap-s100-file or passing --run-without-umap-s100-annotation (but not both)");
        }
        this.predictor = XGBoostEvidenceFilter.loadPredictor(params.svEvidenceFilterModelFile);
        this.partitionCrossingChecker = partitionCrossingChecker;
        this.thresholdProbability = params.svEvidenceFilterThresholdProbability;
        this.readMetadata = readMetadata;
        this.evidenceOverlapChecker = new EvidenceOverlapChecker(evidenceItr, readMetadata, params.minEvidenceMapQ);
        this.rawFeatureCache = new HashMap<BreakpointEvidence, UnscaledOverlapInfo>();
        this.listItr = null;
        this.treeItr = this.evidenceOverlapChecker.getTreeIterator();
    }

    /**
     * Builds an unmodifiable map from each class to its position in the given list.
     *
     * @param evidenceTypeOrder classes in the order that defines their integer code
     * @return unmodifiable map from class to zero-based index; if a class appears more than once,
     *         its last position wins
     */
    private static Map<Class<?>, Integer> evidenceTypeOrderToImmutableMap(List<Class<?>> evidenceTypeOrder) {
        final Map<Class<?>, Integer> indexByType = new HashMap<>();
        int position = 0;
        for (final Class<?> type : evidenceTypeOrder) {
            indexByType.put(type, position);
            ++position;
        }
        return Collections.unmodifiableMap(indexByType);
    }

    /**
     * Loads an XGBoost predictor from a model file, or from the bundled default classifier
     * resource when no location is given.
     *
     * @param modelFileLocation location of the model file, opened via {@code BucketUtils.openFile};
     *                          {@code null} selects the default classpath resource
     * @return the loaded predictor
     * @throws GATKException if the model cannot be opened or parsed; the original exception is
     *                       attached as the cause
     */
    public static Predictor loadPredictor(String modelFileLocation) {
        ObjFunction.useFastMathExp(true);
        final boolean useDefaultModel = modelFileLocation == null;
        try (InputStream inputStream = useDefaultModel ? XGBoostEvidenceFilter.resourcePathToInputStream(DEFAULT_PREDICTOR_RESOURCE_PATH) : BucketUtils.openFile(modelFileLocation)) {
            return new Predictor(inputStream);
        } catch (Exception e) {
            final String modelSource = useDefaultModel ? DEFAULT_PREDICTOR_RESOURCE_PATH : modelFileLocation;
            // Chain the cause so the underlying stack trace is not lost.
            throw new GATKException("Unable to load predictor from classifier file " + modelSource + ": " + e.getMessage(), e);
        }
    }

    /**
     * Opens a classpath resource as a stream, wrapping it in a decompressing stream when the path
     * has a block-compressed extension.
     *
     * @param resourcePath classpath resource path, resolved relative to this class
     * @return stream over the resource contents
     * @throws IOException if the resource does not exist on the classpath, or if the
     *                     decompressing wrapper cannot be created
     */
    private static InputStream resourcePathToInputStream(String resourcePath) throws IOException {
        final InputStream inputStream = XGBoostEvidenceFilter.class.getResourceAsStream(resourcePath);
        if (inputStream == null) {
            // getResourceAsStream signals "not found" with null; fail here with a clear message
            // rather than letting a NullPointerException surface further down.
            throw new IOException("Resource not found on classpath: " + resourcePath);
        }
        if (IOUtil.hasBlockCompressedExtension(resourcePath)) {
            return IOUtils.makeZippedInputStream(new BufferedInputStream(inputStream));
        }
        return inputStream;
    }

    /**
     * Reports whether more evidence is available, advancing through the interval tree as needed.
     *
     * <p>A tree entry's evidence list is selected for emission when any member is already
     * validated, when its interval is on a partition boundary, or when any member passes the
     * classifier. In the last case every member of the list is marked validated.
     *
     * @return {@code true} if the current list has more elements or another list was selected
     */
    @Override
    public boolean hasNext() {
        if (this.listItr != null && this.listItr.hasNext()) {
            return true;
        }
        this.listItr = null;
        while (this.treeItr.hasNext()) {
            final SVIntervalTree.Entry<List<BreakpointEvidence>> entry = this.treeItr.next();
            final SVInterval interval = entry.getInterval();
            final List<BreakpointEvidence> candidates = entry.getValue();
            final boolean acceptedWithoutScoring = this.isValidated(candidates) || this.partitionCrossingChecker.onBoundary(interval);
            if (!acceptedWithoutScoring) {
                if (!this.anyPassesFilter(candidates)) {
                    // Nothing in this entry qualifies; move on to the next tree entry.
                    continue;
                }
                candidates.forEach(ev -> ev.setValidated(true));
            }
            this.listItr = candidates.iterator();
            return true;
        }
        return false;
    }

    /**
     * Returns the next piece of evidence from the currently selected list.
     *
     * @return the next evidence
     * @throws NoSuchElementException if {@link #hasNext()} reports no further evidence
     */
    @Override
    public BreakpointEvidence next() {
        if (this.hasNext()) {
            return this.listItr.next();
        }
        throw new NoSuchElementException("No next element.");
    }

    /**
     * Checks whether at least one element of the list is already marked validated.
     *
     * @param evList evidence to inspect
     * @return {@code true} if any element reports {@code isValidated()}; {@code false} for an
     *         empty list
     */
    private boolean isValidated(List<BreakpointEvidence> evList) {
        return evList.stream().anyMatch(ev -> ev.isValidated());
    }

    /**
     * Checks whether at least one element of the list scores strictly above the threshold.
     *
     * <p>Elements are scored in list order and scoring stops at the first one that passes.
     *
     * @param evidenceList evidence to score
     * @return {@code true} if any predicted probability is greater than
     *         {@code thresholdProbability}
     */
    private boolean anyPassesFilter(List<BreakpointEvidence> evidenceList) {
        return evidenceList.stream().anyMatch(candidate -> this.predictProbability(candidate) > this.thresholdProbability);
    }

    /**
     * Scores a single piece of evidence with the classifier.
     *
     * @param evidence evidence to score
     * @return the value returned by the predictor for this evidence's feature vector
     */
    @VisibleForTesting
    double predictProbability(BreakpointEvidence evidence) {
        final EvidenceFeatures features = this.getFeatures(evidence);
        return this.predictor.predictSingle((FVec) features);
    }

    /**
     * Assembles the classifier feature vector for one piece of evidence.
     *
     * <p>The position of each value in the array is part of the model contract and must not be
     * reordered.
     *
     * @param evidence evidence to describe
     * @return features in the fixed order listed in the body
     * @throws IllegalStateException if the evidence class has no entry in {@code evidenceTypeMap},
     *                               or if the template-size / read-count feature is undefined
     *                               for it
     */
    @VisibleForTesting
    EvidenceFeatures getFeatures(BreakpointEvidence evidence) {
        final CigarQualityInfo cigarQualityInfo = new CigarQualityInfo(evidence);
        final Integer evidenceTypeIndex = evidenceTypeMap.get(evidence.getClass());
        if (evidenceTypeIndex == null) {
            // Fail with the offending class name instead of a bare NullPointerException on unboxing.
            throw new IllegalStateException("evidenceType feature is not defined for evidence class " + evidence.getClass().getName());
        }
        final double evidenceType = evidenceTypeIndex.intValue();
        final double mappingQuality = this.getMappingQuality(evidence);
        final CoverageScaledOverlapInfo individualOverlapInfo = this.getIndividualOverlapInfo(evidence);
        final CoverageScaledOverlapInfo clusterOverlapInfo = this.getClusterOverlapInfo(evidence);
        // Without an annotation track, fall back to 0.0 gap overlap and 1.0 mappability.
        final double referenceGapOverlap = this.genomeGaps == null ? 0.0 : XGBoostEvidenceFilter.getGenomeIntervalsOverlap(evidence, this.genomeGaps, this.readMetadata);
        final double umapS100 = this.umapS100Mappability == null ? 1.0 : XGBoostEvidenceFilter.getGenomeIntervalsOverlap(evidence, this.umapS100Mappability, this.readMetadata);
        final double templateSizeOrReadCount = this.getTemplateSizeOrReadCount(evidence);
        return new EvidenceFeatures(new double[]{
                cigarQualityInfo.basesMatched,
                cigarQualityInfo.referenceLength,
                evidenceType,
                mappingQuality,
                templateSizeOrReadCount,
                individualOverlapInfo.numOverlap,
                individualOverlapInfo.totalOverlapMappingQuality,
                individualOverlapInfo.meanOverlapMappingQuality,
                individualOverlapInfo.numCoherent,
                individualOverlapInfo.totalCoherentMappingQuality,
                clusterOverlapInfo.numOverlap,
                clusterOverlapInfo.totalOverlapMappingQuality,
                clusterOverlapInfo.meanOverlapMappingQuality,
                clusterOverlapInfo.numCoherent,
                clusterOverlapInfo.totalCoherentMappingQuality,
                referenceGapOverlap,
                umapS100
        });
    }

    /**
     * Returns the mapping-quality feature for a piece of evidence.
     *
     * @param evidence evidence to inspect
     * @return the read's mapping quality for {@code ReadEvidence}; 60.0 for any other evidence
     */
    private double getMappingQuality(BreakpointEvidence evidence) {
        if (evidence instanceof BreakpointEvidence.ReadEvidence) {
            final BreakpointEvidence.ReadEvidence readEvidence = (BreakpointEvidence.ReadEvidence) evidence;
            return (double) readEvidence.getMappingQuality();
        }
        // Evidence that is not backed by a read gets this fixed value.
        return 60.0;
    }

    /**
     * Returns the integer mapping quality used when accumulating overlap statistics.
     *
     * @param evidence overlapping evidence
     * @return the read's mapping quality for {@code ReadEvidence}; 60 for any other evidence
     */
    private int getMappingQualityForOverlap(BreakpointEvidence evidence) {
        if (evidence instanceof BreakpointEvidence.ReadEvidence) {
            final BreakpointEvidence.ReadEvidence readEvidence = (BreakpointEvidence.ReadEvidence) evidence;
            return readEvidence.getMappingQuality();
        }
        // Evidence that is not backed by a read gets this fixed value.
        return 60;
    }

    /**
     * Computes the feature that holds a template-size value for read evidence and a read-count
     * value for template-size anomalies.
     *
     * @param evidence evidence to inspect
     * @return {@code getTemplateSize} for {@code ReadEvidence}, {@code getReadCounts} for
     *         {@code TemplateSizeAnomaly}
     * @throws IllegalStateException if the evidence is neither of those two kinds
     */
    private double getTemplateSizeOrReadCount(BreakpointEvidence evidence) {
        final double featureValue;
        if (evidence instanceof BreakpointEvidence.ReadEvidence) {
            featureValue = this.getTemplateSize((BreakpointEvidence.ReadEvidence) evidence);
        } else if (evidence instanceof BreakpointEvidence.TemplateSizeAnomaly) {
            featureValue = this.getReadCounts((BreakpointEvidence.TemplateSizeAnomaly) evidence);
        } else {
            throw new IllegalStateException("templateSizeOrReadCount feature is only defined for ReadEvidence and TemplateSizeAnomaly, not " + evidence.getClass().getName());
        }
        return featureValue;
    }

    /**
     * Converts a read's template size into the cumulative fraction reported by its library's
     * template-size CDF.
     *
     * @param readEvidence read-backed evidence
     * @return the CDF fraction at bin {@code min(|templateSize|, cdf.size() - 1)}
     */
    private double getTemplateSize(BreakpointEvidence.ReadEvidence readEvidence) {
        final int absTemplateSize = Math.abs(readEvidence.getTemplateSize());
        final String library = this.readMetadata.getReadGroupToLibraryMap().get(readEvidence.getReadGroup());
        final IntHistogram.CDF templateSizeCDF = this.readMetadata.getLibraryStatistics(library).getCDF();
        // Clamp to the last bin so oversized templates map to the final CDF entry.
        final int lastBin = templateSizeCDF.size() - 1;
        final int cdfBin = Math.min(absTemplateSize, lastBin);
        return templateSizeCDF.getFraction(cdfBin);
    }

    /**
     * Returns the anomaly's read count divided by the coverage from the read metadata.
     *
     * @param templateSizeAnomaly anomaly evidence
     * @return read count scaled by coverage
     */
    private double getReadCounts(BreakpointEvidence.TemplateSizeAnomaly templateSizeAnomaly) {
        final int readCount = templateSizeAnomaly.getReadCount();
        final double coverage = (double) this.readMetadata.getCoverage();
        return (double) readCount / coverage;
    }

    /**
     * Returns coverage-scaled overlap statistics for the evidence itself, computing and caching
     * the raw statistics on first use.
     *
     * @param evidence evidence to describe
     * @return overlap statistics scaled by the coverage from the read metadata
     */
    private CoverageScaledOverlapInfo getIndividualOverlapInfo(BreakpointEvidence evidence) {
        UnscaledOverlapInfo raw = this.rawFeatureCache.get(evidence);
        if (raw == null) {
            // cacheOverlapInfo stores a non-null entry for the evidence, so re-read it afterwards.
            this.cacheOverlapInfo(evidence);
            raw = this.rawFeatureCache.get(evidence);
        }
        return new CoverageScaledOverlapInfo(
                raw.numOverlap,
                raw.numCoherent,
                raw.totalOverlapMappingQuality,
                raw.totalCoherentMappingQuality,
                raw.meanOverlapMappingQuality,
                this.readMetadata.getCoverage());
    }

    /**
     * Returns coverage-scaled cluster statistics: for each statistic, the maximum of that
     * statistic over all evidence overlapping the given evidence (excluding evidence equal to it).
     *
     * <p>All maxima start at zero, so with no overlappers every statistic is zero before scaling.
     *
     * @param evidence evidence whose overlappers are summarized
     * @return per-statistic maxima scaled by the coverage from the read metadata
     */
    private CoverageScaledOverlapInfo getClusterOverlapInfo(BreakpointEvidence evidence) {
        int maxNumOverlap = 0;
        int maxNumCoherent = 0;
        int maxTotalOverlapMapQ = 0;
        int maxTotalCoherentMapQ = 0;
        double maxMeanOverlapMapQ = 0.0;
        for (EvidenceOverlapChecker.OverlapperIterator neighbors = this.evidenceOverlapChecker.overlappers(evidence); neighbors.hasNext(); ) {
            final BreakpointEvidence neighbor = (BreakpointEvidence) neighbors.next();
            if (neighbor.equals(evidence)) {
                continue;
            }
            UnscaledOverlapInfo neighborInfo = this.rawFeatureCache.get(neighbor);
            if (neighborInfo == null) {
                // Raw statistics are computed lazily, once per evidence.
                this.cacheOverlapInfo(neighbor);
                neighborInfo = this.rawFeatureCache.get(neighbor);
            }
            maxNumOverlap = Math.max(maxNumOverlap, neighborInfo.numOverlap);
            maxNumCoherent = Math.max(maxNumCoherent, neighborInfo.numCoherent);
            maxTotalOverlapMapQ = Math.max(maxTotalOverlapMapQ, neighborInfo.totalOverlapMappingQuality);
            maxTotalCoherentMapQ = Math.max(maxTotalCoherentMapQ, neighborInfo.totalCoherentMappingQuality);
            maxMeanOverlapMapQ = Math.max(maxMeanOverlapMapQ, neighborInfo.meanOverlapMappingQuality);
        }
        return new CoverageScaledOverlapInfo(maxNumOverlap, maxNumCoherent, maxTotalOverlapMapQ, maxTotalCoherentMapQ, maxMeanOverlapMapQ, this.readMetadata.getCoverage());
    }

    /**
     * Computes the raw overlap statistics for one piece of evidence and stores them in
     * {@code rawFeatureCache}.
     *
     * <p>Every overlapper other than the evidence itself contributes to the overlap count and
     * total; overlappers flagged coherent additionally contribute to the coherent count and total.
     *
     * @param evidence evidence whose overlappers are tallied
     */
    private void cacheOverlapInfo(BreakpointEvidence evidence) {
        int overlapCount = 0;
        int overlapMapQSum = 0;
        int coherentCount = 0;
        int coherentMapQSum = 0;
        for (EvidenceOverlapChecker.OverlapAndCoherenceIterator pairs = this.evidenceOverlapChecker.overlappersWithCoherence(evidence); pairs.hasNext(); ) {
            final ImmutablePair<BreakpointEvidence, Boolean> pair = pairs.next();
            final BreakpointEvidence other = pair.getLeft();
            if (!other.equals(evidence)) {
                final int mapQ = this.getMappingQualityForOverlap(other);
                overlapCount++;
                overlapMapQSum += mapQ;
                if (pair.getRight()) {
                    coherentCount++;
                    coherentMapQSum += mapQ;
                }
            }
        }
        final UnscaledOverlapInfo tallied = new UnscaledOverlapInfo(overlapCount, coherentCount, overlapMapQSum, coherentMapQSum);
        this.rawFeatureCache.put(evidence, tallied);
    }

    /**
     * Computes the summed overlap between the evidence location and the features returned by
     * querying {@code genomeIntervals}, as a fraction of the query interval's size.
     *
     * <p>Each returned feature contributes its own overlap length, so features that overlap one
     * another are counted more than once.
     *
     * @param evidence evidence whose location is queried
     * @param genomeIntervals interval track to query
     * @param readMetadata used to translate the location's contig index into a contig name
     * @return summed overlap length divided by the query interval size
     */
    private static double getGenomeIntervalsOverlap(BreakpointEvidence evidence, FeatureDataSource<BEDFeature> genomeIntervals, ReadMetadata readMetadata) {
        final SVInterval location = evidence.getLocation();
        final String contigName = readMetadata.getContigName(location.getContig());
        // NOTE(review): "end - 1" presumably converts a half-open SVInterval to a closed SimpleInterval — confirm.
        final SimpleInterval queryInterval = new SimpleInterval(contigName, location.getStart(), location.getEnd() - 1);
        int coveredBases = 0;
        for (final Iterator<BEDFeature> featureItr = genomeIntervals.query(queryInterval); featureItr.hasNext(); ) {
            final BEDFeature feature = featureItr.next();
            final int sharedEnd = Math.min(queryInterval.getEnd(), feature.getEnd());
            final int sharedStart = Math.max(queryInterval.getStart(), feature.getStart());
            // "+ 1" counts both endpoints of the shared range.
            coveredBases += sharedEnd + 1 - sharedStart;
        }
        return (double) coveredBases / (double) queryInterval.size();
    }

    /**
     * CIGAR-derived features for a piece of evidence. Both values are 0 for evidence that is not
     * {@code ReadEvidence}.
     */
    private static class CigarQualityInfo {
        /** Summed length of CIGAR elements that consume both reference and read bases. */
        final double basesMatched;
        /** Summed length of CIGAR elements that consume reference bases. */
        final double referenceLength;

        /**
         * Decodes the evidence's CIGAR string (when it is read evidence) and tallies the lengths.
         *
         * @param evidence evidence to inspect
         */
        CigarQualityInfo(BreakpointEvidence evidence) {
            int matchedTotal = 0;
            int referenceTotal = 0;
            if (evidence instanceof BreakpointEvidence.ReadEvidence) {
                final BreakpointEvidence.ReadEvidence readEvidence = (BreakpointEvidence.ReadEvidence) evidence;
                for (final CigarElement element : TextCigarCodec.decode(readEvidence.getCigarString()).getCigarElements()) {
                    final CigarOperator operator = element.getOperator();
                    if (operator.consumesReferenceBases()) {
                        referenceTotal += element.getLength();
                        if (operator.consumesReadBases()) {
                            matchedTotal += element.getLength();
                        }
                    }
                }
            }
            this.basesMatched = matchedTotal;
            this.referenceLength = referenceTotal;
        }
    }

    /**
     * Overlap statistics in which the counts and totals have been divided by coverage. The mean
     * mapping quality is stored as given, without scaling.
     */
    private static class CoverageScaledOverlapInfo {
        final double numOverlap;
        final double totalOverlapMappingQuality;
        final double meanOverlapMappingQuality;
        final double numCoherent;
        final double totalCoherentMappingQuality;

        /**
         * @param numOverlap raw overlapper count
         * @param numCoherent raw coherent-overlapper count
         * @param totalOverlapMappingQuality raw summed mapping quality over overlappers
         * @param totalCoherentMappingQuality raw summed mapping quality over coherent overlappers
         * @param meanOverlapMappingQuality mean mapping quality, stored unscaled
         * @param coverage divisor applied to the four counts/totals
         */
        CoverageScaledOverlapInfo(int numOverlap, int numCoherent, int totalOverlapMappingQuality, int totalCoherentMappingQuality, double meanOverlapMappingQuality, double coverage) {
            // int / double promotes the int operand to double before dividing.
            this.numOverlap = numOverlap / coverage;
            this.numCoherent = numCoherent / coverage;
            this.totalOverlapMappingQuality = totalOverlapMappingQuality / coverage;
            this.totalCoherentMappingQuality = totalCoherentMappingQuality / coverage;
            this.meanOverlapMappingQuality = meanOverlapMappingQuality;
        }
    }

    /**
     * Raw (not coverage-scaled) overlap statistics for one piece of evidence, plus the derived
     * mean mapping quality of its overlappers.
     */
    private static class UnscaledOverlapInfo {
        final int numOverlap;
        final int numCoherent;
        final int totalOverlapMappingQuality;
        final int totalCoherentMappingQuality;
        /**
         * {@code totalOverlapMappingQuality / numOverlap} in floating point; NaN when both are
         * zero. NOTE(review): presumably the predictor treats NaN as a missing feature — confirm.
         */
        final double meanOverlapMappingQuality;

        /**
         * @param numOverlap number of overlappers
         * @param numCoherent number of coherent overlappers
         * @param totalOverlapMappingQuality summed mapping quality over overlappers
         * @param totalCoherentMappingQuality summed mapping quality over coherent overlappers
         */
        UnscaledOverlapInfo(int numOverlap, int numCoherent, int totalOverlapMappingQuality, int totalCoherentMappingQuality) {
            this.numOverlap = numOverlap;
            this.numCoherent = numCoherent;
            this.totalOverlapMappingQuality = totalOverlapMappingQuality;
            this.totalCoherentMappingQuality = totalCoherentMappingQuality;
            final double overlapTotal = totalOverlapMappingQuality;
            final double overlapCount = numOverlap;
            this.meanOverlapMappingQuality = overlapTotal / overlapCount;
        }
    }
}

