/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber;

import htsjdk.samtools.util.Interval;
import htsjdk.samtools.util.IntervalList;
import htsjdk.samtools.util.Locatable;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.argumentcollections.IntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.argumentcollections.RequiredIntervalArgumentCollection;
import org.broadinstitute.hellbender.cmdline.programgroups.CopyNumberProgramGroup;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AnnotatedIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.LocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AnnotatedInterval;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.SimpleCount;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.annotation.AnnotationKey;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.annotation.CopyNumberAnnotations;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;

@CommandLineProgramProperties(summary="Filters intervals based on annotations and/or count statistics", oneLineSummary="Filters intervals based on annotations and/or count statistics", programGroup=CopyNumberProgramGroup.class)
@DocumentedFeature
/**
 * Filters genomic intervals and writes the surviving intervals to a Picard interval-list file.
 *
 * <p>The intervals given by the interval arguments are first resolved and then intersected with the
 * intervals of the inputs: an annotated-intervals file and/or read-count files. The intersected intervals
 * are then filtered in this order:</p>
 * <ol>
 *     <li>annotation-range filters (GC content, mappability, segmental-duplication content), each applied
 *     only when the corresponding annotation is present;</li>
 *     <li>the low-count filter, then the extreme-count filter, applied only when read-count files are given;</li>
 *     <li>removal of any interval that is the only remaining interval on its contig.</li>
 * </ol>
 * <p>If any step leaves no intervals, a {@link UserException.BadInput} is thrown.</p>
 */
public final class FilterIntervals
extends CommandLineProgram {
    public static final String MINIMUM_GC_CONTENT_LONG_NAME = "minimum-gc-content";
    public static final String MAXIMUM_GC_CONTENT_LONG_NAME = "maximum-gc-content";
    public static final String MINIMUM_MAPPABILITY_LONG_NAME = "minimum-mappability";
    public static final String MAXIMUM_MAPPABILITY_LONG_NAME = "maximum-mappability";
    public static final String MINIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME = "minimum-segmental-duplication-content";
    public static final String MAXIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME = "maximum-segmental-duplication-content";
    public static final String LOW_COUNT_FILTER_COUNT_THRESHOLD_LONG_NAME = "low-count-filter-count-threshold";
    public static final String LOW_COUNT_FILTER_PERCENTAGE_OF_SAMPLES_LONG_NAME = "low-count-filter-percentage-of-samples";
    public static final String EXTREME_COUNT_FILTER_MINIMUM_PERCENTILE_LONG_NAME = "extreme-count-filter-minimum-percentile";
    public static final String EXTREME_COUNT_FILTER_MAXIMUM_PERCENTILE_LONG_NAME = "extreme-count-filter-maximum-percentile";
    public static final String EXTREME_COUNT_FILTER_PERCENTAGE_OF_SAMPLES_LONG_NAME = "extreme-count-filter-percentage-of-samples";

    @Argument(doc="Input file containing annotations for genomic intervals (output of AnnotateIntervals).  Must be provided if no counts files are provided.", fullName="annotated-intervals", optional=true)
    private File inputAnnotatedIntervalsFile = null;

    @Argument(doc="Input TSV or HDF5 files containing integer read counts in genomic intervals (output of CollectReadCounts).  Must be provided if no annotated-intervals file is provided.", fullName="input", shortName="I", optional=true)
    private List<File> inputReadCountFiles = new ArrayList<>();

    @Argument(doc="Output Picard interval-list file containing the filtered intervals.", fullName="output", shortName="O")
    private File outputFilteredIntervalsFile;

    @ArgumentCollection
    protected IntervalArgumentCollection intervalArgumentCollection = new RequiredIntervalArgumentCollection();

    @Argument(doc="Minimum allowed value for GC-content annotation (inclusive).", fullName=MINIMUM_GC_CONTENT_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double minimumGCContent = 0.1;

    @Argument(doc="Maximum allowed value for GC-content annotation (inclusive).", fullName=MAXIMUM_GC_CONTENT_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double maximumGCContent = 0.9;

    @Argument(doc="Minimum allowed value for mappability annotation (inclusive).", fullName=MINIMUM_MAPPABILITY_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double minimumMappability = 0.9;

    @Argument(doc="Maximum allowed value for mappability annotation (inclusive).", fullName=MAXIMUM_MAPPABILITY_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double maximumMappability = 1.0;

    @Argument(doc="Minimum allowed value for segmental-duplication-content annotation (inclusive).", fullName=MINIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double minimumSegmentalDuplicationContent = 0.0;

    @Argument(doc="Maximum allowed value for segmental-duplication-content annotation (inclusive).", fullName=MAXIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME, minValue=0.0, maxValue=1.0, optional=true)
    private double maximumSegmentalDuplicationContent = 0.5;

    @Argument(doc="Count-threshold parameter for the low-count filter.  Intervals with a count strictly less than this threshold in a percentage of samples strictly greater than low-count-filter-percentage-of-samples will be filtered out.  (This is the first count-based filter applied.)", fullName=LOW_COUNT_FILTER_COUNT_THRESHOLD_LONG_NAME, minValue=0.0, optional=true)
    private int lowCountFilterCountThreshold = 5;

    @Argument(doc="Percentage-of-samples parameter for the low-count filter.  Intervals with a count strictly less than low-count-filter-count-threshold in a percentage of samples strictly greater than this will be filtered out.  (This is the first count-based filter applied.)", fullName=LOW_COUNT_FILTER_PERCENTAGE_OF_SAMPLES_LONG_NAME, minValue=0.0, maxValue=100.0, optional=true)
    private double lowCountFilterPercentageOfSamples = 90.0;

    @Argument(doc="Minimum-percentile parameter for the extreme-count filter.  Intervals with a count that has a percentile strictly less than this in a percentage of samples strictly greater than extreme-count-filter-percentage-of-samples will be filtered out.  (This is the second count-based filter applied.)", fullName=EXTREME_COUNT_FILTER_MINIMUM_PERCENTILE_LONG_NAME, minValue=0.0, maxValue=100.0, optional=true)
    private double extremeCountFilterMinimumPercentile = 1.0;

    @Argument(doc="Maximum-percentile parameter for the extreme-count filter.  Intervals with a count that has a percentile strictly greater than this in a percentage of samples strictly greater than extreme-count-filter-percentage-of-samples will be filtered out.  (This is the second count-based filter applied.)", fullName=EXTREME_COUNT_FILTER_MAXIMUM_PERCENTILE_LONG_NAME, minValue=0.0, maxValue=100.0, optional=true)
    private double extremeCountFilterMaximumPercentile = 99.0;

    @Argument(doc="Percentage-of-samples parameter for the extreme-count filter.  Intervals with a count that has a percentile outside of [extreme-count-filter-minimum-percentile, extreme-count-filter-maximum-percentile] in a percentage of samples strictly greater than this will be filtered out.  (This is the second count-based filter applied.)", fullName=EXTREME_COUNT_FILTER_PERCENTAGE_OF_SAMPLES_LONG_NAME, minValue=0.0, maxValue=100.0, optional=true)
    private double extremeCountFilterPercentageOfSamples = 90.0;

    /**
     * Validates arguments, resolves and intersects intervals, filters them, and writes the result
     * to {@code outputFilteredIntervalsFile} as a Picard interval list.
     *
     * @return always {@code null}
     */
    @Override
    public Object doWork() {
        validateArguments();

        final Pair<SimpleIntervalCollection, AnnotatedIntervalCollection> intersectedIntervalsPair =
                resolveAndValidateIntervals(logger, inputAnnotatedIntervalsFile, inputReadCountFiles, intervalArgumentCollection);
        final SimpleIntervalCollection intersectedIntervals = intersectedIntervalsPair.getLeft();
        // Null when no annotated-intervals file was provided.
        final AnnotatedIntervalCollection intersectedAnnotatedIntervals = intersectedIntervalsPair.getRight();

        final SimpleIntervalCollection filteredIntervals = filterIntervals(intersectedIntervals, intersectedAnnotatedIntervals);

        logger.info(String.format("Writing filtered intervals to %s...", outputFilteredIntervalsFile.getAbsolutePath()));
        final IntervalList filteredIntervalList = new IntervalList(filteredIntervals.getMetadata().getSequenceDictionary());
        filteredIntervals.getIntervals().forEach(i -> filteredIntervalList.add(new Interval(i)));
        filteredIntervalList.write(outputFilteredIntervalsFile);

        logger.info(String.format("%s complete.", getClass().getSimpleName()));
        return null;
    }

    /**
     * Checks the interval arguments, the presence and uniqueness of inputs, and the output file.
     * Also checks that each minimum/maximum argument pair describes a non-empty range.
     */
    private void validateArguments() {
        CopyNumberArgumentValidationUtils.validateIntervalArgumentCollection(intervalArgumentCollection);
        if (inputAnnotatedIntervalsFile == null && inputReadCountFiles.isEmpty()) {
            throw new UserException("Must provide annotated intervals or counts.");
        }
        Utils.validateArg(inputReadCountFiles.size() == new HashSet<>(inputReadCountFiles).size(),
                "List of input read-count files cannot contain duplicates.");
        CopyNumberArgumentValidationUtils.validateInputs(inputAnnotatedIntervalsFile);
        inputReadCountFiles.forEach(f -> CopyNumberArgumentValidationUtils.validateInputs(f));
        CopyNumberArgumentValidationUtils.validateOutputFiles(outputFilteredIntervalsFile);

        // An inverted range would either remove every interval or, for the extreme-count filter with
        // 100% of samples, silently filter nothing; reject it up front with a clear message.
        validateMinimumAndMaximum(minimumGCContent, maximumGCContent,
                MINIMUM_GC_CONTENT_LONG_NAME, MAXIMUM_GC_CONTENT_LONG_NAME);
        validateMinimumAndMaximum(minimumMappability, maximumMappability,
                MINIMUM_MAPPABILITY_LONG_NAME, MAXIMUM_MAPPABILITY_LONG_NAME);
        validateMinimumAndMaximum(minimumSegmentalDuplicationContent, maximumSegmentalDuplicationContent,
                MINIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME, MAXIMUM_SEGMENTAL_DUPLICATION_CONTENT_LONG_NAME);
        validateMinimumAndMaximum(extremeCountFilterMinimumPercentile, extremeCountFilterMaximumPercentile,
                EXTREME_COUNT_FILTER_MINIMUM_PERCENTILE_LONG_NAME, EXTREME_COUNT_FILTER_MAXIMUM_PERCENTILE_LONG_NAME);
    }

    /**
     * Throws {@link IllegalArgumentException} (via {@link Utils#validateArg}) if {@code minValue > maxValue}.
     */
    private static void validateMinimumAndMaximum(final double minValue,
                                                  final double maxValue,
                                                  final String minName,
                                                  final String maxName) {
        Utils.validateArg(minValue <= maxValue,
                String.format("Value for %s (%s) must be less than or equal to value for %s (%s).",
                        minName, minValue, maxName, maxValue));
    }

    /**
     * Resolves the interval arguments against the sequence dictionary of the primary input and
     * intersects them with the intervals of the inputs.
     *
     * <p>The primary input is the annotated-intervals file if given, otherwise the first read-count file.
     * Only the first read-count file takes part in the intersection; the other read-count files are
     * checked later, in {@link #constructReadCountMatrix}.</p>
     *
     * @return a pair of the intersected intervals and the matching annotated intervals. The right
     *         element is {@code null} if no annotated-intervals file was given.
     * @throws IllegalArgumentException if no interval survives the intersection
     */
    private static Pair<SimpleIntervalCollection, AnnotatedIntervalCollection> resolveAndValidateIntervals(final Logger logger,
                                                                                                         final File inputAnnotatedIntervalsFile,
                                                                                                         final List<File> inputReadCountFiles,
                                                                                                         final IntervalArgumentCollection intervalArgumentCollection) {
        final LocatableMetadata metadata;
        final List<SimpleInterval> resolved;
        final List<SimpleInterval> intersected;
        final List<AnnotatedInterval> intersectedAnnotated;     // stays null when there are no annotations
        if (inputAnnotatedIntervalsFile != null && inputReadCountFiles.isEmpty()) {
            final AnnotatedIntervalCollection inputAnnotatedIntervals = new AnnotatedIntervalCollection(inputAnnotatedIntervalsFile);
            metadata = inputAnnotatedIntervals.getMetadata();
            resolved = intervalArgumentCollection.getIntervals(metadata.getSequenceDictionary());
            intersected = ListUtils.intersection(resolved, inputAnnotatedIntervals.getIntervals());
            intersectedAnnotated = subsetAnnotatedIntervals(inputAnnotatedIntervals, intersected);
        } else if (inputAnnotatedIntervalsFile == null && !inputReadCountFiles.isEmpty()) {
            final File firstReadCountFile = inputReadCountFiles.get(0);
            final SimpleCountCollection firstReadCounts = SimpleCountCollection.read(firstReadCountFile);
            metadata = firstReadCounts.getMetadata();
            resolved = intervalArgumentCollection.getIntervals(metadata.getSequenceDictionary());
            intersected = ListUtils.intersection(resolved, firstReadCounts.getIntervals());
            intersectedAnnotated = null;
        } else {
            final AnnotatedIntervalCollection inputAnnotatedIntervals = new AnnotatedIntervalCollection(inputAnnotatedIntervalsFile);
            final File firstReadCountFile = inputReadCountFiles.get(0);
            final SimpleCountCollection firstReadCounts = SimpleCountCollection.read(firstReadCountFile);
            final SampleLocatableMetadata firstReadCountsMetadata = firstReadCounts.getMetadata();
            // The result of this check was previously discarded; report mismatches like constructReadCountMatrix does.
            if (!CopyNumberArgumentValidationUtils.isSameDictionary(
                    inputAnnotatedIntervals.getMetadata().getSequenceDictionary(),
                    firstReadCountsMetadata.getSequenceDictionary())) {
                logger.warn(String.format("Sequence dictionary for annotated-intervals file %s is inconsistent with that for read-counts file %s.",
                        inputAnnotatedIntervalsFile, firstReadCountFile));
            }
            metadata = inputAnnotatedIntervals.getMetadata();
            resolved = intervalArgumentCollection.getIntervals(metadata.getSequenceDictionary());
            intersected = ListUtils.intersection(
                    ListUtils.intersection(resolved, inputAnnotatedIntervals.getIntervals()),
                    firstReadCounts.getIntervals());
            intersectedAnnotated = subsetAnnotatedIntervals(inputAnnotatedIntervals, intersected);
        }

        Utils.validateArg(!intersected.isEmpty(), "At least one interval must remain after intersection.");
        logger.info(String.format("After interval resolution, %d intervals remain...", resolved.size()));
        logger.info(String.format("After interval intersection, %d intervals remain...", intersected.size()));

        final SimpleIntervalCollection intersectedIntervals = new SimpleIntervalCollection(metadata, intersected);
        final AnnotatedIntervalCollection intersectedAnnotatedIntervals = intersectedAnnotated == null
                ? null
                : new AnnotatedIntervalCollection(metadata, intersectedAnnotated);
        // Later filters index both collections with the same index, so their interval lists must agree exactly.
        if (intersectedAnnotatedIntervals != null
                && !intersectedIntervals.getRecords().equals(intersectedAnnotatedIntervals.getIntervals())) {
            throw new GATKException.ShouldNeverReachHereException("After intersection, intervals should match those of annotated intervals.");
        }
        return Pair.of(intersectedIntervals, intersectedAnnotatedIntervals);
    }

    /**
     * Returns the annotated-interval records whose interval is in {@code intervals}, in record order.
     */
    private static List<AnnotatedInterval> subsetAnnotatedIntervals(final AnnotatedIntervalCollection annotatedIntervals,
                                                                    final List<SimpleInterval> intervals) {
        final HashSet<SimpleInterval> intervalSet = new HashSet<>(intervals);
        return annotatedIntervals.getRecords().stream()
                .filter(ai -> intervalSet.contains(ai.getInterval()))
                .collect(Collectors.toList());
    }

    /**
     * Applies all configured filters to the intersected intervals.
     *
     * @param intersectedIntervals          intervals to filter
     * @param intersectedAnnotatedIntervals annotations aligned index-by-index with {@code intersectedIntervals},
     *                                      or {@code null} to skip annotation-based filters
     * @return the intervals that pass every filter, in their original order
     * @throws UserException.BadInput if any filter step removes all intervals
     */
    private SimpleIntervalCollection filterIntervals(final SimpleIntervalCollection intersectedIntervals,
                                                     final AnnotatedIntervalCollection intersectedAnnotatedIntervals) {
        final int numIntersectedIntervals = intersectedIntervals.size();
        // mask[i] == true marks interval i as filtered out; filters only ever set entries to true.
        final boolean[] mask = new boolean[numIntersectedIntervals];

        if (intersectedAnnotatedIntervals != null) {
            applyAnnotationFilters(intersectedIntervals, intersectedAnnotatedIntervals, mask);
        }

        if (!inputReadCountFiles.isEmpty()) {
            final RealMatrix readCountMatrix = constructReadCountMatrix(logger, inputReadCountFiles, intersectedIntervals);
            logger.info("Applying count-based filters...");
            applyLowCountFilter(readCountMatrix, mask);
            applyExtremeCountFilter(readCountMatrix, mask);
        }

        applySingletonContigFilter(intersectedIntervals, mask);

        logger.info(String.format("%d / %d intervals passed all filters...", countNumberPassing(mask), numIntersectedIntervals));
        final List<SimpleInterval> intervals = intersectedIntervals.getRecords();
        return new SimpleIntervalCollection(
                intersectedIntervals.getMetadata(),
                IntStream.range(0, numIntersectedIntervals)
                        .filter(i -> !mask[i])
                        .mapToObj(intervals::get)
                        .collect(Collectors.toList()));
    }

    /**
     * Applies the GC-content, mappability and segmental-duplication-content range filters.
     * Each filter runs only if its annotation key is present on the first annotated record.
     */
    private void applyAnnotationFilters(final SimpleIntervalCollection intersectedIntervals,
                                        final AnnotatedIntervalCollection intersectedAnnotatedIntervals,
                                        final boolean[] mask) {
        logger.info("Applying annotation-based filters...");
        final List<AnnotationKey<?>> annotationKeys = intersectedAnnotatedIntervals.getRecords().get(0).getAnnotationMap().getKeys();
        if (annotationKeys.contains(CopyNumberAnnotations.GC_CONTENT)) {
            updateMaskByAnnotationFilter(logger, intersectedIntervals, intersectedAnnotatedIntervals, mask,
                    CopyNumberAnnotations.GC_CONTENT, "GC-content", minimumGCContent, maximumGCContent);
        }
        if (annotationKeys.contains(CopyNumberAnnotations.MAPPABILITY)) {
            updateMaskByAnnotationFilter(logger, intersectedIntervals, intersectedAnnotatedIntervals, mask,
                    CopyNumberAnnotations.MAPPABILITY, "mappability", minimumMappability, maximumMappability);
        }
        if (annotationKeys.contains(CopyNumberAnnotations.SEGMENTAL_DUPLICATION_CONTENT)) {
            updateMaskByAnnotationFilter(logger, intersectedIntervals, intersectedAnnotatedIntervals, mask,
                    CopyNumberAnnotations.SEGMENTAL_DUPLICATION_CONTENT, "segmental-duplication-content",
                    minimumSegmentalDuplicationContent, maximumSegmentalDuplicationContent);
        }
    }

    /**
     * Low-count filter: masks an unmasked interval when the number of samples with a count strictly below
     * {@code lowCountFilterCountThreshold} is strictly greater than
     * {@code lowCountFilterPercentageOfSamples}% of all samples.
     *
     * @param readCountMatrix samples x intervals matrix of counts
     */
    private void applyLowCountFilter(final RealMatrix readCountMatrix, final boolean[] mask) {
        final int numSamples = readCountMatrix.getRowDimension();
        final int numIntervals = mask.length;
        final double maxNumSamplesBelowThreshold = lowCountFilterPercentageOfSamples * numSamples / 100.0;
        IntStream.range(0, numIntervals).filter(i -> !mask[i]).forEach(i -> {
            final long numSamplesBelowThreshold = Arrays.stream(readCountMatrix.getColumn(i))
                    .filter(c -> c < lowCountFilterCountThreshold)
                    .count();
            if (numSamplesBelowThreshold > maxNumSamplesBelowThreshold) {
                mask[i] = true;
            }
        });
        logger.info(String.format("After applying low-count filter (intervals with a count < %d in > %s%% of samples fail), %d / %d intervals remain...",
                lowCountFilterCountThreshold, lowCountFilterPercentageOfSamples, countNumberPassing(mask), numIntervals));
    }

    /**
     * Extreme-count filter.
     *
     * <p>For each sample, percentile thresholds are computed from the counts of intervals that are still
     * unmasked. An interval is "extreme" in that sample if its count lies outside
     * [min-percentile threshold, max-percentile threshold]. A percentile argument of 0 gives a threshold of 0.</p>
     *
     * <p>An unmasked interval is masked when it is extreme in strictly more than
     * {@code extremeCountFilterPercentageOfSamples}% of samples.</p>
     *
     * @param readCountMatrix samples x intervals matrix of counts
     */
    private void applyExtremeCountFilter(final RealMatrix readCountMatrix, final boolean[] mask) {
        final int numSamples = readCountMatrix.getRowDimension();
        final int numIntervals = mask.length;
        final boolean[][] percentileMask = new boolean[numSamples][numIntervals];
        for (int sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
            final double[] counts = readCountMatrix.getRow(sampleIndex);
            // Never empty here: the low-count filter's countNumberPassing call throws if nothing survived.
            final double[] filteredCounts = IntStream.range(0, numIntervals)
                    .filter(i -> !mask[i])
                    .mapToDouble(i -> counts[i])
                    .toArray();
            // Percentile rejects p == 0, so treat a 0th-percentile bound as a threshold of 0.
            final double minimumThreshold = extremeCountFilterMinimumPercentile == 0.0
                    ? 0.0
                    : new Percentile(extremeCountFilterMinimumPercentile).evaluate(filteredCounts);
            final double maximumThreshold = extremeCountFilterMaximumPercentile == 0.0
                    ? 0.0
                    : new Percentile(extremeCountFilterMaximumPercentile).evaluate(filteredCounts);
            for (int intervalIndex = 0; intervalIndex < numIntervals; ++intervalIndex) {
                final double count = counts[intervalIndex];
                percentileMask[sampleIndex][intervalIndex] = !(minimumThreshold <= count && count <= maximumThreshold);
            }
        }
        final double maxNumExtremeSamples = extremeCountFilterPercentageOfSamples * numSamples / 100.0;
        IntStream.range(0, numIntervals).filter(i -> !mask[i]).forEach(i -> {
            final long numExtremeSamples = IntStream.range(0, numSamples)
                    .filter(sampleIndex -> percentileMask[sampleIndex][i])
                    .count();
            if (numExtremeSamples > maxNumExtremeSamples) {
                mask[i] = true;
            }
        });
        logger.info(String.format("After applying extreme-count filter (intervals with a count percentile outside of [%s, %s] in > %s%% of samples fail), %d / %d intervals remain...",
                extremeCountFilterMinimumPercentile, extremeCountFilterMaximumPercentile, extremeCountFilterPercentageOfSamples,
                countNumberPassing(mask), numIntervals));
    }

    /**
     * Masks any unmasked interval that is the only unmasked interval on its contig, logging a warning for each.
     * Per-contig counts are taken before any masking in this step.
     */
    private void applySingletonContigFilter(final SimpleIntervalCollection intersectedIntervals, final boolean[] mask) {
        final List<SimpleInterval> intervals = intersectedIntervals.getRecords();
        final Map<String, Long> contigToIntervalCountMap = IntStream.range(0, mask.length)
                .filter(i -> !mask[i])
                .mapToObj(intervals::get)
                .collect(Collectors.groupingBy(SimpleInterval::getContig, Collectors.counting()));
        IntStream.range(0, mask.length).filter(i -> !mask[i]).forEach(i -> {
            final String contig = intervals.get(i).getContig();
            if (contigToIntervalCountMap.get(contig) == 1L) {
                logger.warn(String.format("After applying provided filters, contig %s was left with a single interval that was filtered out.", contig));
                mask[i] = true;
            }
        });
    }

    /**
     * Masks every unmasked interval whose annotation value lies outside the inclusive range
     * [{@code minValue}, {@code maxValue}]; a NaN value is also masked.
     * {@code intersectedAnnotatedIntervals} must be aligned index-by-index with {@code intersectedIntervals}.
     */
    private static void updateMaskByAnnotationFilter(final Logger logger,
                                                     final SimpleIntervalCollection intersectedIntervals,
                                                     final AnnotatedIntervalCollection intersectedAnnotatedIntervals,
                                                     final boolean[] mask,
                                                     final AnnotationKey<Double> annotationKey,
                                                     final String filterName,
                                                     final double minValue,
                                                     final double maxValue) {
        final List<AnnotatedInterval> annotatedIntervals = intersectedAnnotatedIntervals.getRecords();
        IntStream.range(0, intersectedIntervals.size()).filter(i -> !mask[i]).forEach(i -> {
            final double value = annotatedIntervals.get(i).getAnnotationMap().getValue(annotationKey);
            if (!(minValue <= value && value <= maxValue)) {
                mask[i] = true;
            }
        });
        logger.info(String.format("After applying %s filter (intervals with values outside of [%s, %s] fail), %d / %d intervals remain...",
                filterName, minValue, maxValue, countNumberPassing(mask), intersectedIntervals.size()));
    }

    /**
     * Reads every read-count file and builds a samples x intervals count matrix. Row {@code s} holds
     * file {@code s}; column {@code j} holds interval {@code j} of {@code intersectedIntervals}.
     *
     * <p>Counts are looked up by interval rather than by record position. Row values therefore line up
     * with {@code intersectedIntervals} even if a file orders its records differently, e.g. because its
     * sequence dictionary differs (which only triggers a warning).</p>
     *
     * @throws IllegalArgumentException if a file lacks a count for any intersected interval
     */
    private static RealMatrix constructReadCountMatrix(final Logger logger,
                                                       final List<File> inputReadCountFiles,
                                                       final SimpleIntervalCollection intersectedIntervals) {
        logger.info("Validating and aggregating input read-counts files...");
        final int numSamples = inputReadCountFiles.size();
        final List<SimpleInterval> intervals = intersectedIntervals.getRecords();
        final int numIntervals = intervals.size();
        final RealMatrix readCountMatrix = new Array2DRowRealMatrix(numSamples, numIntervals);
        final ListIterator<File> inputReadCountFilesIterator = inputReadCountFiles.listIterator();
        while (inputReadCountFilesIterator.hasNext()) {
            final int sampleIndex = inputReadCountFilesIterator.nextIndex();
            final File inputReadCountFile = inputReadCountFilesIterator.next();
            logger.info(String.format("Aggregating read-counts file %s (%d / %d)", inputReadCountFile, sampleIndex + 1, numSamples));
            final SimpleCountCollection readCounts = SimpleCountCollection.read(inputReadCountFile);
            final SampleLocatableMetadata readCountsMetadata = readCounts.getMetadata();
            if (!CopyNumberArgumentValidationUtils.isSameDictionary(
                    readCountsMetadata.getSequenceDictionary(),
                    intersectedIntervals.getMetadata().getSequenceDictionary())) {
                logger.warn(String.format("Sequence dictionary for read-counts file %s is inconsistent with those for other inputs.", inputReadCountFile));
            }
            final Map<SimpleInterval, Double> countByInterval = readCounts.getRecords().stream()
                    .collect(Collectors.toMap(SimpleCount::getInterval, c -> (double) c.getCount(), (first, second) -> first));
            final double[] subsetReadCounts = new double[numIntervals];
            for (int intervalIndex = 0; intervalIndex < numIntervals; intervalIndex++) {
                final Double count = countByInterval.get(intervals.get(intervalIndex));
                if (count == null) {
                    throw new IllegalArgumentException(String.format("Intervals for read-count file %s do not contain all specified intervals.", inputReadCountFile));
                }
                subsetReadCounts[intervalIndex] = count;
            }
            readCountMatrix.setRow(sampleIndex, subsetReadCounts);
        }
        return readCountMatrix;
    }

    /**
     * Returns the number of unmasked (passing) intervals.
     *
     * @throws UserException.BadInput if no interval passes
     */
    private static int countNumberPassing(final boolean[] mask) {
        final int numPassing = (int) IntStream.range(0, mask.length).filter(i -> !mask[i]).count();
        if (numPassing == 0) {
            throw new UserException.BadInput("Filtering removed all intervals.  Select less strict filtering criteria.");
        }
        return numPassing;
    }
}

