/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.segmentation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AlleleFractionSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AlleleFractionSegment;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.utils.segmentation.KernelSegmenter;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Segments allelic-count data per chromosome by running a {@link KernelSegmenter} over the
 * alternate-allele fractions of the counts in each chromosome.
 */
public final class AlleleFractionKernelSegmenter {
    private static final Logger logger = LogManager.getLogger(AlleleFractionKernelSegmenter.class);

    /**
     * Chromosomes with fewer allelic counts than this are not segmented; all of their counts are
     * reported as a single segment instead.
     */
    private static final int MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME = 10;

    /**
     * Maps a kernel variance to a kernel function.
     * A variance of zero yields the linear kernel {@code x * y}. Any other variance yields the
     * Gaussian kernel {@code exp(-(x - y)^2 / (2 * variance))}.
     */
    private static final Function<Double, BiFunction<Double, Double, Double>> KERNEL = variance -> variance == 0.
            ? (x, y) -> x * y
            : (x, y) -> FastMath.exp(-(x - y) * (x - y) / (2. * variance));

    private final AllelicCountCollection allelicCounts;

    /** Allelic counts grouped by contig, in the order in which each contig is first encountered. */
    private final Map<String, List<AllelicCount>> allelicCountsPerChromosome;

    /**
     * @param allelicCounts allelic counts to segment; must be non-null
     */
    public AlleleFractionKernelSegmenter(final AllelicCountCollection allelicCounts) {
        Utils.nonNull(allelicCounts);
        this.allelicCounts = allelicCounts;
        // LinkedHashMap preserves the order in which contigs first appear in the records.
        this.allelicCountsPerChromosome = allelicCounts.getRecords().stream()
                .collect(Collectors.groupingBy(AllelicCount::getContig, LinkedHashMap::new, Collectors.toList()));
    }

    /**
     * Finds a segmentation of the allelic counts, independently for each chromosome.
     *
     * <p>Chromosomes with fewer than {@code MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME} counts are
     * emitted as a single segment spanning all of their counts. For all other chromosomes, the
     * kernel segmenter is run on the alternate-allele fractions. The returned changepoints (in
     * index order) are then turned into contiguous segments.
     *
     * @param maxNumSegmentsPerChromosome            maximum number of segments per chromosome; must be positive
     * @param kernelVariance                         variance of the Gaussian kernel; must be non-negative
     *                                               (zero selects a linear kernel)
     * @param kernelApproximationDimension           dimension of the kernel approximation; must be positive
     * @param windowSizes                            window sizes; must be non-null, positive, and unique
     * @param numChangepointsPenaltyLinearFactor     linear factor of the changepoint-count penalty; must be non-negative
     * @param numChangepointsPenaltyLogLinearFactor  log-linear factor of the changepoint-count penalty; must be non-negative
     * @return the segments of all chromosomes, with the metadata of the input allelic counts
     */
    public AlleleFractionSegmentCollection findSegmentation(final int maxNumSegmentsPerChromosome,
                                                            final double kernelVariance,
                                                            final int kernelApproximationDimension,
                                                            final List<Integer> windowSizes,
                                                            final double numChangepointsPenaltyLinearFactor,
                                                            final double numChangepointsPenaltyLogLinearFactor) {
        ParamUtils.isPositive(maxNumSegmentsPerChromosome, "Maximum number of segments must be positive.");
        ParamUtils.isPositiveOrZero(kernelVariance, "Variance of Gaussian kernel must be non-negative (if zero, a linear kernel will be used).");
        ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
        Utils.nonNull(windowSizes);
        Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
        Utils.validateArg(new HashSet<>(windowSizes).size() == windowSizes.size(), "Window sizes must all be unique.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLinearFactor, "Linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLogLinearFactor, "Log-linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");

        final int maxNumChangepointsPerChromosome = maxNumSegmentsPerChromosome - 1;
        // The kernel depends only on the variance, so build it once for all chromosomes.
        final BiFunction<Double, Double, Double> kernel = KERNEL.apply(kernelVariance);

        logger.info(String.format("Finding changepoints in %d data points and %d chromosomes...",
                allelicCounts.size(), allelicCountsPerChromosome.size()));

        final List<AlleleFractionSegment> segments = new ArrayList<>();
        for (final Map.Entry<String, List<AllelicCount>> entry : allelicCountsPerChromosome.entrySet()) {
            final String chromosome = entry.getKey();
            final List<AllelicCount> allelicCountsInChromosome = entry.getValue();
            final int numAllelicCountsInChromosome = allelicCountsInChromosome.size();
            logger.info(String.format("Finding changepoints in %d data points in chromosome %s...",
                    numAllelicCountsInChromosome, chromosome));

            if (numAllelicCountsInChromosome < MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME) {
                logger.warn(String.format("Number of points in chromosome %s (%d) is less than that required (%d), skipping segmentation...",
                        chromosome, numAllelicCountsInChromosome, MIN_NUM_POINTS_REQUIRED_PER_CHROMOSOME));
                final int start = allelicCountsInChromosome.get(0).getStart();
                final int end = allelicCountsInChromosome.get(numAllelicCountsInChromosome - 1).getEnd();
                segments.add(new AlleleFractionSegment(new SimpleInterval(chromosome, start, end), numAllelicCountsInChromosome));
                continue;
            }

            final List<Double> alternateAlleleFractionsInChromosome = allelicCountsInChromosome.stream()
                    .map(AllelicCount::getAlternateAlleleFraction)
                    .collect(Collectors.toList());
            final List<Integer> changepoints = new ArrayList<>(
                    new KernelSegmenter<>(alternateAlleleFractionsInChromosome).findChangepoints(
                            maxNumChangepointsPerChromosome, kernel, kernelApproximationDimension, windowSizes,
                            numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor,
                            KernelSegmenter.ChangepointSortOrder.INDEX));

            // Each changepoint is the index of the last point in a segment. Make sure the final
            // segment ends at the last point. Check for the index actually being added, so that a
            // changepoint already at the last point does not produce a duplicate (which would
            // index past the end of the list below).
            final int lastIndex = numAllelicCountsInChromosome - 1;
            if (!changepoints.contains(lastIndex)) {
                changepoints.add(lastIndex);
            }

            int previousChangepoint = -1;
            for (final int changepoint : changepoints) {
                final int start = allelicCountsInChromosome.get(previousChangepoint + 1).getStart();
                final int end = allelicCountsInChromosome.get(changepoint).getEnd();
                final List<AllelicCount> allelicCountsInSegment =
                        allelicCountsInChromosome.subList(previousChangepoint + 1, changepoint + 1);
                segments.add(new AlleleFractionSegment(new SimpleInterval(chromosome, start, end), allelicCountsInSegment));
                previousChangepoint = changepoint;
            }
        }

        logger.info(String.format("Found %d segments in %d chromosomes.", segments.size(), allelicCountsPerChromosome.size()));
        return new AlleleFractionSegmentCollection((SampleLocatableMetadata) allelicCounts.getMetadata(), segments);
    }
}

