/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.caller;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.formats.CopyNumberFormatsUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CalledCopyRatioSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CalledCopyRatioSegment;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyRatioSegment;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Calls each copy-ratio segment as neutral, deleted or amplified.
 *
 * <p>The procedure implemented here is:</p>
 * <ol>
 *     <li>Convert each segment's mean log2 copy ratio to copy-ratio space ({@code 2^log2CR}).</li>
 *     <li>Select the segments whose copy ratio lies in the closed copy-neutral interval
 *         {@code [neutralSegmentCopyRatioLowerBound, neutralSegmentCopyRatioUpperBound]}.</li>
 *     <li>Compute the length-weighted mean and standard deviation of those segments, drop the ones
 *         further than {@code outlierNeutralSegmentCopyRatioZScoreThreshold} standard deviations from
 *         the mean, and recompute the statistics on what remains.</li>
 *     <li>In {@link #makeCalls()}, call every segment inside the copy-neutral interval as neutral, and
 *         call the others by comparing their deviation from the recomputed mean against
 *         {@code callingCopyRatioZScoreThreshold} recomputed standard deviations.</li>
 * </ol>
 *
 * <p>The calling statistics are computed once, in the constructor.</p>
 */
public final class SimpleCopyRatioCaller {
    private static final Logger logger = LogManager.getLogger(SimpleCopyRatioCaller.class);
    private final double neutralSegmentCopyRatioLowerBound;
    private final double neutralSegmentCopyRatioUpperBound;
    private final double outlierNeutralSegmentCopyRatioZScoreThreshold;
    private final double callingCopyRatioZScoreThreshold;
    private final Statistics callingStatistics;
    private final CopyRatioSegmentCollection copyRatioSegments;

    /**
     * Validates the arguments and eagerly computes the statistics used for calling.
     *
     * @param copyRatioSegments                             segments to call; must not be null
     * @param neutralSegmentCopyRatioLowerBound             lower bound (inclusive, copy-ratio space) of the
     *                                                      copy-neutral interval; must be non-negative
     * @param neutralSegmentCopyRatioUpperBound             upper bound (inclusive, copy-ratio space) of the
     *                                                      copy-neutral interval; must exceed the lower bound
     * @param outlierNeutralSegmentCopyRatioZScoreThreshold z-score threshold used to discard outlying
     *                                                      copy-neutral segments; must be positive
     * @param callingCopyRatioZScoreThreshold               z-score threshold used to call segments outside
     *                                                      the copy-neutral interval; must be positive
     */
    public SimpleCopyRatioCaller(final CopyRatioSegmentCollection copyRatioSegments,
                                 final double neutralSegmentCopyRatioLowerBound,
                                 final double neutralSegmentCopyRatioUpperBound,
                                 final double outlierNeutralSegmentCopyRatioZScoreThreshold,
                                 final double callingCopyRatioZScoreThreshold) {
        ParamUtils.isPositiveOrZero(neutralSegmentCopyRatioLowerBound, "Copy-neutral lower bound must be non-negative.");
        Utils.validateArg(neutralSegmentCopyRatioLowerBound < neutralSegmentCopyRatioUpperBound, "Copy-neutral lower bound must be less than upper bound.");
        ParamUtils.isPositive(outlierNeutralSegmentCopyRatioZScoreThreshold, "Outlier z-score threshold must be positive.");
        ParamUtils.isPositive(callingCopyRatioZScoreThreshold, "Calling z-score threshold must be positive.");
        this.copyRatioSegments = Utils.nonNull(copyRatioSegments);
        this.neutralSegmentCopyRatioLowerBound = neutralSegmentCopyRatioLowerBound;
        this.neutralSegmentCopyRatioUpperBound = neutralSegmentCopyRatioUpperBound;
        this.outlierNeutralSegmentCopyRatioZScoreThreshold = outlierNeutralSegmentCopyRatioZScoreThreshold;
        this.callingCopyRatioZScoreThreshold = callingCopyRatioZScoreThreshold;
        // Must come last: the computation reads every field assigned above.
        this.callingStatistics = this.calculateCallingStatistics();
    }

    /**
     * Produces one call per input segment, in the order the segments are returned by the input collection.
     *
     * @return a new collection carrying the input collection's metadata and the called segments
     */
    public CalledCopyRatioSegmentCollection makeCalls() {
        final List<CopyRatioSegment> segments = this.copyRatioSegments.getRecords();
        final List<CalledCopyRatioSegment> calledSegments = new ArrayList<>(segments.size());
        for (final CopyRatioSegment segment : segments) {
            calledSegments.add(new CalledCopyRatioSegment(segment, this.callSegment(segment)));
        }
        return new CalledCopyRatioSegmentCollection((SampleLocatableMetadata) this.copyRatioSegments.getMetadata(), calledSegments);
    }

    /**
     * Decides the call for a single segment.
     *
     * <p>A segment inside the copy-neutral interval is always neutral. Otherwise its deviation from the
     * calling mean is compared against the calling standard deviation scaled by the calling z-score
     * threshold. If the calling statistics are NaN or infinite, both comparisons are false and the
     * segment falls through to neutral.</p>
     */
    private CalledCopyRatioSegment.Call callSegment(final CopyRatioSegment segment) {
        final double copyRatioMean = toCopyRatio(segment);
        if (this.isInNeutralRange(copyRatioMean)) {
            return CalledCopyRatioSegment.Call.NEUTRAL;
        }
        final double copyRatioDeviation = copyRatioMean - this.callingStatistics.mean;
        final double callingDeviation = this.callingStatistics.standardDeviation * this.callingCopyRatioZScoreThreshold;
        if (copyRatioDeviation < -callingDeviation) {
            return CalledCopyRatioSegment.Call.DELETION;
        }
        if (copyRatioDeviation > callingDeviation) {
            return CalledCopyRatioSegment.Call.AMPLIFICATION;
        }
        return CalledCopyRatioSegment.Call.NEUTRAL;
    }

    /**
     * Computes the length-weighted statistics used for z-score calling: statistics of the copy-neutral
     * segments are computed, outliers are removed using those statistics, and the statistics are then
     * recomputed on the remaining segments. Progress is logged at INFO level.
     */
    private Statistics calculateCallingStatistics() {
        final List<CopyRatioSegment> copyNeutralSegments = this.copyRatioSegments.getRecords().stream()
                .filter(s -> this.isInNeutralRange(toCopyRatio(s)))
                .collect(Collectors.toList());
        logger.info(String.format("%d segments in copy-neutral region [%s, %s]...", copyNeutralSegments.size(), CopyNumberFormatsUtils.formatDouble(this.neutralSegmentCopyRatioLowerBound), CopyNumberFormatsUtils.formatDouble(this.neutralSegmentCopyRatioUpperBound)));

        final Statistics unfilteredStatistics = calculateLengthWeightedStatistics(copyNeutralSegments);
        logger.info(String.format("Length-weighted mean of segments in copy-neutral region (CR space): %s", CopyNumberFormatsUtils.formatDouble(unfilteredStatistics.mean)));
        logger.info(String.format("Length-weighted standard deviation for segments in copy-neutral region : %s", CopyNumberFormatsUtils.formatDouble(unfilteredStatistics.standardDeviation)));

        // Keep only segments within the outlier z-score threshold of the unfiltered mean.
        // If the unfiltered statistics are NaN this comparison is false for every segment.
        final double maxOutlierDeviation = unfilteredStatistics.standardDeviation * this.outlierNeutralSegmentCopyRatioZScoreThreshold;
        final List<CopyRatioSegment> filteredCopyNeutralSegments = copyNeutralSegments.stream()
                .filter(s -> Math.abs(toCopyRatio(s) - unfilteredStatistics.mean) <= maxOutlierDeviation)
                .collect(Collectors.toList());
        logger.info(String.format("%d / %d segments in copy-neutral region remain after outliers filtered using z-score threshold (%s)...", filteredCopyNeutralSegments.size(), copyNeutralSegments.size(), CopyNumberFormatsUtils.formatDouble(this.outlierNeutralSegmentCopyRatioZScoreThreshold)));

        final Statistics statistics = calculateLengthWeightedStatistics(filteredCopyNeutralSegments);
        logger.info(String.format("Length-weighted mean for z-score calling (CR space): %s", CopyNumberFormatsUtils.formatDouble(statistics.mean)));
        logger.info(String.format("Length-weighted standard deviation for z-score calling (CR space): %s", CopyNumberFormatsUtils.formatDouble(statistics.standardDeviation)));

        if (!Double.isFinite(statistics.mean) || !Double.isFinite(statistics.standardDeviation)) {
            // The statistics below divide by the total length and by (n - 1) / n, so they are not
            // finite when fewer than two segments are available. Results are left unchanged; this
            // only surfaces the otherwise silent fallback to neutral calls.
            logger.warn(String.format("Calling statistics are not finite (%d copy-neutral segments remain after filtering; at least two are needed); segments outside the copy-neutral region will be called neutral.", filteredCopyNeutralSegments.size()));
        }
        return statistics;
    }

    /**
     * Computes the length-weighted mean and standard deviation, in copy-ratio space, of the given segments.
     *
     * <p>Each segment is weighted by {@code getInterval().getLengthOnReference()}. The variance is the
     * weighted sum of squared deviations divided by {@code ((n - 1) / n) * totalLength}, where {@code n}
     * is the number of segments. No guard is applied for {@code n < 2} or a zero total length, so the
     * returned values may be NaN or infinite in those cases.</p>
     *
     * @param copyRatioSegments segments to summarize
     * @return the length-weighted mean and standard deviation
     */
    private static Statistics calculateLengthWeightedStatistics(final List<CopyRatioSegment> copyRatioSegments) {
        final List<Integer> segmentLengths = copyRatioSegments.stream()
                .map(c -> c.getInterval().getLengthOnReference())
                .collect(Collectors.toList());
        final double totalLength = segmentLengths.stream().mapToDouble(Integer::doubleValue).sum();
        final int numSegments = segmentLengths.size();
        final double lengthWeightedCopyRatioMean = IntStream.range(0, numSegments)
                .mapToDouble(i -> segmentLengths.get(i) * toCopyRatio(copyRatioSegments.get(i)))
                .sum() / totalLength;
        final double lengthWeightedCopyRatioStandardDeviation = Math.sqrt(IntStream.range(0, numSegments)
                .mapToDouble(i -> segmentLengths.get(i) * Math.pow(toCopyRatio(copyRatioSegments.get(i)) - lengthWeightedCopyRatioMean, 2.0))
                .sum() / ((double) (numSegments - 1) / (double) numSegments * totalLength));
        return new Statistics(lengthWeightedCopyRatioMean, lengthWeightedCopyRatioStandardDeviation);
    }

    /** Converts a segment's mean log2 copy ratio to copy-ratio space. */
    private static double toCopyRatio(final CopyRatioSegment segment) {
        return Math.pow(2.0, segment.getMeanLog2CopyRatio());
    }

    /** Returns whether the given copy ratio lies in the closed copy-neutral interval. */
    private boolean isInNeutralRange(final double copyRatio) {
        return this.neutralSegmentCopyRatioLowerBound <= copyRatio && copyRatio <= this.neutralSegmentCopyRatioUpperBound;
    }

    /** Immutable holder for a mean and a standard deviation. */
    private static final class Statistics {
        private final double mean;
        private final double standardDeviation;

        private Statistics(final double mean, final double standardDeviation) {
            this.mean = mean;
            this.standardDeviation = standardDeviation;
        }
    }
}

