/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.segmentation;

import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.MultidimensionalSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyRatio;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.MultidimensionalSegment;
import org.broadinstitute.hellbender.tools.copynumber.utils.segmentation.KernelSegmenter;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Jointly segments denoised copy ratios and allelic counts by running {@link KernelSegmenter}
 * independently on each chromosome.
 *
 * <p>Each copy-ratio interval becomes one two-dimensional point: its log2 copy ratio paired with the
 * alternate-allele fraction of the first allelic-count site (according to the copy-ratio comparator)
 * overlapping that interval. Intervals with no overlapping site use a placeholder count with
 * ref count = alt count = 1.
 */
public final class MultidimensionalKernelSegmenter {
    private static final Logger logger = LogManager.getLogger(MultidimensionalKernelSegmenter.class);

    /** Chromosomes with fewer points than this are emitted as a single segment without running the segmenter. */
    private static final int MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME = 10;

    private static final SimpleInterval DUMMY_INTERVAL = new SimpleInterval("DUMMY", 1, 1);

    /** Placeholder for copy-ratio intervals containing no allelic-count site (ref count = alt count = 1). */
    private static final AllelicCount BALANCED_ALLELIC_COUNT = new AllelicCount(DUMMY_INTERVAL, 1, 1);

    /**
     * Maps a standard deviation to a one-dimensional kernel.
     *
     * <p>When the standard deviation is positive, the kernel is the Gaussian density centered on the first
     * argument and evaluated at the second. When it is zero, the kernel is the linear kernel {@code x * y}.
     */
    private static final Function<Double, BiFunction<Double, Double, Double>> KERNEL =
            standardDeviation -> standardDeviation == 0.
                    ? (x, y) -> x * y
                    : (x, y) -> new NormalDistribution(null, x, standardDeviation).density(y);

    private final CopyRatioCollection denoisedCopyRatios;
    private final OverlapDetector<CopyRatio> copyRatioMidpointOverlapDetector;
    private final AllelicCountCollection allelicCounts;
    private final OverlapDetector<AllelicCount> allelicCountOverlapDetector;
    private final Comparator<Locatable> comparator;
    /** Points grouped by contig, in the order contigs first appear in the copy-ratio records. */
    private final Map<String, List<MultidimensionalPoint>> multidimensionalPointsPerChromosome;

    /**
     * Builds the per-chromosome multidimensional points used for segmentation.
     *
     * @param denoisedCopyRatios denoised copy ratios; must be non-null
     * @param allelicCounts      allelic counts; must be non-null, and their metadata must equal that of
     *                           {@code denoisedCopyRatios}
     */
    public MultidimensionalKernelSegmenter(final CopyRatioCollection denoisedCopyRatios,
                                           final AllelicCountCollection allelicCounts) {
        Utils.nonNull(denoisedCopyRatios);
        Utils.nonNull(allelicCounts);
        Utils.validateArg(denoisedCopyRatios.getMetadata().equals(allelicCounts.getMetadata()),
                "Metadata do not match.");
        this.denoisedCopyRatios = denoisedCopyRatios;
        this.copyRatioMidpointOverlapDetector = denoisedCopyRatios.getMidpointOverlapDetector();
        this.allelicCounts = allelicCounts;
        this.allelicCountOverlapDetector = allelicCounts.getOverlapDetector();

        // Number of copy-ratio intervals that contain at least one allelic-count site.
        final long numAllelicCountsToUse = denoisedCopyRatios.getRecords().stream()
                .filter(this.allelicCountOverlapDetector::overlapsAny)
                .count();
        logger.info(String.format("Using first allelic-count site in each copy-ratio interval (%d / %d) for multidimensional segmentation...",
                numAllelicCountsToUse, allelicCounts.size()));

        this.comparator = denoisedCopyRatios.getComparator();
        this.multidimensionalPointsPerChromosome = denoisedCopyRatios.getRecords().stream()
                .map(cr -> new MultidimensionalPoint(
                        cr.getInterval(),
                        cr.getLog2CopyRatioValue(),
                        this.allelicCountOverlapDetector.getOverlaps(cr).stream()
                                .min(this.comparator)
                                .orElse(BALANCED_ALLELIC_COUNT)
                                .getAlternateAlleleFraction()))
                .collect(Collectors.groupingBy(MultidimensionalPoint::getContig, LinkedHashMap::new, Collectors.toList()));
    }

    /**
     * Segments each chromosome independently.
     *
     * @param maxNumSegmentsPerChromosome           maximum number of segments per chromosome; must be positive
     * @param kernelVarianceCopyRatio               variance of the copy-ratio Gaussian kernel; non-negative
     *                                              (zero selects a linear kernel)
     * @param kernelVarianceAlleleFraction          variance of the allele-fraction Gaussian kernel; non-negative
     *                                              (zero selects a linear kernel)
     * @param kernelScalingAlleleFraction           non-negative weight applied to the allele-fraction kernel
     * @param kernelApproximationDimension          dimension of the kernel approximation; must be positive
     * @param windowSizes                           window sizes; must all be positive and unique
     * @param numChangepointsPenaltyLinearFactor    non-negative linear factor of the changepoint-count penalty
     * @param numChangepointsPenaltyLogLinearFactor non-negative log-linear factor of the changepoint-count penalty
     * @return the segments of all chromosomes, labeled with the allelic-count metadata
     */
    public MultidimensionalSegmentCollection findSegmentation(final int maxNumSegmentsPerChromosome,
                                                              final double kernelVarianceCopyRatio,
                                                              final double kernelVarianceAlleleFraction,
                                                              final double kernelScalingAlleleFraction,
                                                              final int kernelApproximationDimension,
                                                              final List<Integer> windowSizes,
                                                              final double numChangepointsPenaltyLinearFactor,
                                                              final double numChangepointsPenaltyLogLinearFactor) {
        ParamUtils.isPositive(maxNumSegmentsPerChromosome, "Maximum number of segments must be positive.");
        ParamUtils.isPositiveOrZero(kernelVarianceCopyRatio, "Variance of copy-ratio Gaussian kernel must be non-negative (if zero, a linear kernel will be used).");
        ParamUtils.isPositiveOrZero(kernelVarianceAlleleFraction, "Variance of allele-fraction Gaussian kernel must be non-negative (if zero, a linear kernel will be used).");
        ParamUtils.isPositiveOrZero(kernelScalingAlleleFraction, "Scaling of allele-fraction Gaussian kernel must be non-negative.");
        ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
        Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
        Utils.validateArg(new HashSet<>(windowSizes).size() == windowSizes.size(), "Window sizes must all be unique.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLinearFactor, "Linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLogLinearFactor, "Log-linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");

        final int maxNumChangepointsPerChromosome = maxNumSegmentsPerChromosome - 1;
        final BiFunction<MultidimensionalPoint, MultidimensionalPoint, Double> kernel =
                constructKernel(kernelVarianceCopyRatio, kernelVarianceAlleleFraction, kernelScalingAlleleFraction);

        logger.info(String.format("Finding changepoints in (%d, %d) data points and %d chromosomes...",
                denoisedCopyRatios.size(), allelicCounts.size(), multidimensionalPointsPerChromosome.size()));

        final List<MultidimensionalSegment> segments = new ArrayList<>();
        for (final Map.Entry<String, List<MultidimensionalPoint>> entry : multidimensionalPointsPerChromosome.entrySet()) {
            final String chromosome = entry.getKey();
            final List<MultidimensionalPoint> points = entry.getValue();
            final int numPoints = points.size();
            logger.info(String.format("Finding changepoints in %d data points in chromosome %s...", numPoints, chromosome));

            if (numPoints < MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME) {
                logger.warn(String.format("Number of points in chromosome %s (%d) is less than that required (%d), skipping segmentation...",
                        chromosome, numPoints, MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME));
                segments.add(createSegment(chromosome, points, 0, numPoints - 1));
                continue;
            }

            final List<Integer> changepoints = new ArrayList<>(new KernelSegmenter<MultidimensionalPoint>(points)
                    .findChangepoints(maxNumChangepointsPerChromosome, kernel, kernelApproximationDimension,
                            windowSizes, numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor,
                            KernelSegmenter.ChangepointSortOrder.INDEX));

            // A changepoint is the index of the last point of a segment, so the final segment must end at the
            // last index. Test for that same index; testing numPoints (out of range) would allow a duplicate
            // last changepoint and an out-of-bounds lookup below.
            final int lastIndex = numPoints - 1;
            if (!changepoints.contains(lastIndex)) {
                changepoints.add(lastIndex);
            }

            int previousChangepoint = -1;
            for (final int changepoint : changepoints) {
                segments.add(createSegment(chromosome, points, previousChangepoint + 1, changepoint));
                previousChangepoint = changepoint;
            }
        }
        logger.info(String.format("Found %d segments in %d chromosomes.",
                segments.size(), multidimensionalPointsPerChromosome.keySet().size()));
        return new MultidimensionalSegmentCollection((SampleLocatableMetadata) allelicCounts.getMetadata(), segments);
    }

    /**
     * Builds the segment on {@code chromosome} spanning the start of {@code points.get(startIndex)} to the
     * end of {@code points.get(endIndex)}, inclusive.
     */
    private MultidimensionalSegment createSegment(final String chromosome,
                                                  final List<MultidimensionalPoint> points,
                                                  final int startIndex,
                                                  final int endIndex) {
        final int start = points.get(startIndex).getStart();
        final int end = points.get(endIndex).getEnd();
        return new MultidimensionalSegment(new SimpleInterval(chromosome, start, end),
                comparator, copyRatioMidpointOverlapDetector, allelicCountOverlapDetector);
    }

    /**
     * Returns the sum of a copy-ratio kernel and a scaled allele-fraction kernel.
     *
     * <p>The one-dimensional kernels are built once here rather than on every evaluation.
     */
    private BiFunction<MultidimensionalPoint, MultidimensionalPoint, Double> constructKernel(final double kernelVarianceCopyRatio,
                                                                                            final double kernelVarianceAlleleFraction,
                                                                                            final double kernelScalingAlleleFraction) {
        final BiFunction<Double, Double, Double> kernelCopyRatio = KERNEL.apply(Math.sqrt(kernelVarianceCopyRatio));
        final BiFunction<Double, Double, Double> kernelAlleleFraction = KERNEL.apply(Math.sqrt(kernelVarianceAlleleFraction));
        return (p1, p2) -> kernelCopyRatio.apply(p1.log2CopyRatio, p2.log2CopyRatio)
                + kernelScalingAlleleFraction * kernelAlleleFraction.apply(p1.alternateAlleleFraction, p2.alternateAlleleFraction);
    }

    /** One segmentation input point: an interval with its log2 copy ratio and alternate-allele fraction. */
    private static final class MultidimensionalPoint implements Locatable {
        private final SimpleInterval interval;
        private final double log2CopyRatio;
        private final double alternateAlleleFraction;

        MultidimensionalPoint(final SimpleInterval interval,
                              final double log2CopyRatio,
                              final double alternateAlleleFraction) {
            this.interval = interval;
            this.log2CopyRatio = log2CopyRatio;
            this.alternateAlleleFraction = alternateAlleleFraction;
        }

        @Override
        public String getContig() {
            return interval.getContig();
        }

        @Override
        public int getStart() {
            return interval.getStart();
        }

        @Override
        public int getEnd() {
            return interval.getEnd();
        }
    }
}

