/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.models;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.Locatable;
import htsjdk.samtools.util.OverlapDetector;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.AllelicCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.CopyRatioCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.ModeledSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.MultidimensionalSegmentCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleIntervalCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SimpleLocatableMetadata;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.CopyRatio;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.ModeledSegment;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionModeller;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionPrior;
import org.broadinstitute.hellbender.tools.copynumber.models.CopyRatioModeller;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Fits a copy-ratio model and an allele-fraction model (each via MCMC) over a shared set of segments,
 * and can optionally smooth that segmentation by repeatedly merging adjacent segments on the same contig
 * whose posterior summaries are similar.
 *
 * <p>An initial fit is performed in the constructor. Results are exposed as a {@link ModeledSegmentCollection}
 * and as global-parameter decile files.</p>
 */
public final class MultidimensionalModeller {
    private static final Logger logger = LogManager.getLogger(MultidimensionalModeller.class);

    private final SampleLocatableMetadata metadata;
    private final CopyRatioCollection denoisedCopyRatios;
    private final OverlapDetector<CopyRatio> copyRatioMidpointOverlapDetector;
    private final AllelicCountCollection allelicCounts;
    private final OverlapDetector<AllelicCount> allelicCountOverlapDetector;
    private final AlleleFractionPrior alleleFractionPrior;

    private CopyRatioModeller copyRatioModeller;
    private AlleleFractionModeller alleleFractionModeller;

    /** Segmentation the models are (or will next be) fit against; replaced on every smoothing iteration. */
    private SimpleIntervalCollection currentSegments;

    /** Per-segment results; refilled by {@link #fitModel()} or replaced by merged segments during smoothing. */
    private final List<ModeledSegment> modeledSegments = new ArrayList<>();

    /** False when {@link #modeledSegments} holds merged-but-not-refit segments (see performSmoothingIteration). */
    private boolean isModelFit;

    private final int numSamplesCopyRatio;
    private final int numBurnInCopyRatio;
    private final int numSamplesAlleleFraction;
    private final int numBurnInAlleleFraction;

    /**
     * Validates the inputs and performs an initial fit of both models over the given segmentation.
     *
     * @param multidimensionalSegments initial segmentation; must contain at least one segment
     * @param denoisedCopyRatios       copy-ratio data; metadata must match the other inputs
     * @param allelicCounts            allelic-count data; metadata must match the other inputs
     * @param alleleFractionPrior      prior for the allele-fraction model; must be non-null
     * @param numSamplesCopyRatio      number of MCMC samples for the copy-ratio model
     * @param numBurnInCopyRatio       number of burn-in samples for the copy-ratio model
     * @param numSamplesAlleleFraction number of MCMC samples for the allele-fraction model
     * @param numBurnInAlleleFraction  number of burn-in samples for the allele-fraction model
     * @throws IllegalArgumentException if the metadata differ between inputs or there are no segments
     */
    public MultidimensionalModeller(final MultidimensionalSegmentCollection multidimensionalSegments,
                                    final CopyRatioCollection denoisedCopyRatios,
                                    final AllelicCountCollection allelicCounts,
                                    final AlleleFractionPrior alleleFractionPrior,
                                    final int numSamplesCopyRatio,
                                    final int numBurnInCopyRatio,
                                    final int numSamplesAlleleFraction,
                                    final int numBurnInAlleleFraction) {
        Utils.validateArg(Stream.of(
                        (SampleLocatableMetadata) Utils.nonNull(multidimensionalSegments).getMetadata(),
                        (SampleLocatableMetadata) Utils.nonNull(denoisedCopyRatios).getMetadata(),
                        (SampleLocatableMetadata) Utils.nonNull(allelicCounts).getMetadata())
                        .distinct().count() == 1L,
                "Metadata from all inputs must match.");
        ParamUtils.isPositive(multidimensionalSegments.size(), "Number of segments must be positive.");
        this.metadata = (SampleLocatableMetadata) multidimensionalSegments.getMetadata();
        this.currentSegments = new SimpleIntervalCollection(
                new SimpleLocatableMetadata(this.metadata.getSequenceDictionary()),
                multidimensionalSegments.getIntervals());
        this.denoisedCopyRatios = denoisedCopyRatios;
        this.copyRatioMidpointOverlapDetector = denoisedCopyRatios.getMidpointOverlapDetector();
        this.allelicCounts = allelicCounts;
        this.allelicCountOverlapDetector = allelicCounts.getOverlapDetector();
        this.alleleFractionPrior = Utils.nonNull(alleleFractionPrior);
        this.numSamplesCopyRatio = numSamplesCopyRatio;
        this.numBurnInCopyRatio = numBurnInCopyRatio;
        this.numSamplesAlleleFraction = numSamplesAlleleFraction;
        this.numBurnInAlleleFraction = numBurnInAlleleFraction;
        logger.info("Fitting initial model...");
        fitModel();
    }

    /**
     * Returns the current per-segment results wrapped with this modeller's sample metadata.
     */
    public ModeledSegmentCollection getModeledSegments() {
        return new ModeledSegmentCollection(metadata, modeledSegments);
    }

    /**
     * Fits both models by MCMC on {@link #currentSegments} and rebuilds {@link #modeledSegments} from
     * their per-segment posterior summaries and the number of data points overlapping each segment.
     */
    private void fitModel() {
        logger.info("Fitting copy-ratio model...");
        copyRatioModeller = new CopyRatioModeller(denoisedCopyRatios, currentSegments);
        copyRatioModeller.fitMCMC(numSamplesCopyRatio, numBurnInCopyRatio);

        logger.info("Fitting allele-fraction model...");
        alleleFractionModeller = new AlleleFractionModeller(allelicCounts, currentSegments, alleleFractionPrior);
        alleleFractionModeller.fitMCMC(numSamplesAlleleFraction, numBurnInAlleleFraction);

        modeledSegments.clear();
        final List<ModeledSegment.SimplePosteriorSummary> segmentMeansPosteriorSummaries =
                copyRatioModeller.getSegmentMeansPosteriorSummaries();
        final List<ModeledSegment.SimplePosteriorSummary> minorAlleleFractionsPosteriorSummaries =
                alleleFractionModeller.getMinorAlleleFractionsPosteriorSummaries();
        for (int segmentIndex = 0; segmentIndex < currentSegments.size(); segmentIndex++) {
            final SimpleInterval segment = (SimpleInterval) currentSegments.getRecords().get(segmentIndex);
            // Copy ratios are counted by midpoint overlap, allelic counts by ordinary overlap,
            // matching the detectors built in the constructor.
            final int numPointsCopyRatio = copyRatioMidpointOverlapDetector.getOverlaps(segment).size();
            final int numPointsAlleleFraction = allelicCountOverlapDetector.getOverlaps(segment).size();
            modeledSegments.add(new ModeledSegment(
                    segment,
                    numPointsCopyRatio,
                    numPointsAlleleFraction,
                    segmentMeansPosteriorSummaries.get(segmentIndex),
                    minorAlleleFractionsPosteriorSummaries.get(segmentIndex)));
        }
        isModelFit = true;
    }

    /**
     * Iteratively merges adjacent similar segments, refitting the models every
     * {@code numSmoothingIterationsPerFit} iterations, until no merge happens or the iteration cap is reached.
     * The models are always left fit on the final segmentation.
     *
     * @param maxNumSmoothingIterations                        maximum number of smoothing iterations; non-negative
     * @param numSmoothingIterationsPerFit                     refit period in iterations; 0 means refit only at the end;
     *                                                         non-negative
     * @param smoothingCredibleIntervalThresholdCopyRatio      copy-ratio similarity threshold; non-negative
     * @param smoothingCredibleIntervalThresholdAlleleFraction allele-fraction similarity threshold; non-negative
     * @throws IllegalArgumentException if any argument is negative
     */
    public void smoothSegments(final int maxNumSmoothingIterations,
                               final int numSmoothingIterationsPerFit,
                               final double smoothingCredibleIntervalThresholdCopyRatio,
                               final double smoothingCredibleIntervalThresholdAlleleFraction) {
        ParamUtils.isPositiveOrZero(maxNumSmoothingIterations,
                "The maximum number of smoothing iterations must be non-negative.");
        ParamUtils.isPositiveOrZero(numSmoothingIterationsPerFit,
                "The number of smoothing iterations per fit must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdCopyRatio,
                "The copy-ratio credible-interval threshold for segmentation smoothing must be non-negative.");
        ParamUtils.isPositiveOrZero(smoothingCredibleIntervalThresholdAlleleFraction,
                "The allele-fraction credible-interval threshold for segmentation smoothing must be non-negative.");

        logger.info(String.format("Initial number of segments before smoothing: %d", modeledSegments.size()));
        for (int iteration = 1; iteration <= maxNumSmoothingIterations; iteration++) {
            logger.info(String.format("Smoothing iteration: %d", iteration));
            final int prevNumSegments = modeledSegments.size();
            final boolean doModelFit =
                    numSmoothingIterationsPerFit > 0 && iteration % numSmoothingIterationsPerFit == 0;
            performSmoothingIteration(
                    smoothingCredibleIntervalThresholdCopyRatio,
                    smoothingCredibleIntervalThresholdAlleleFraction,
                    doModelFit);
            if (modeledSegments.size() == prevNumSegments) {
                break; // converged: nothing was merged this iteration
            }
        }
        if (!isModelFit) {
            fitModel();
        }
        logger.info(String.format("Final number of segments after smoothing: %d", modeledSegments.size()));
    }

    /**
     * Performs one merge pass over {@link #modeledSegments}, updates {@link #currentSegments}, and then either
     * refits the models or stores the merged (approximate) segments and marks the model as not fit.
     */
    private void performSmoothingIteration(final double intervalThresholdSegmentMean,
                                           final double intervalThresholdMinorAlleleFraction,
                                           final boolean doModelFit) {
        logger.info("Number of segments before smoothing iteration: " + modeledSegments.size());
        final List<ModeledSegment> mergedSegments = SimilarSegmentUtils.mergeSimilarSegments(
                modeledSegments, intervalThresholdSegmentMean, intervalThresholdMinorAlleleFraction);
        logger.info("Number of segments after smoothing iteration: " + mergedSegments.size());
        currentSegments = new SimpleIntervalCollection(
                new SimpleLocatableMetadata(metadata.getSequenceDictionary()),
                mergedSegments.stream().map(ModeledSegment::getInterval).collect(Collectors.toList()));
        if (doModelFit) {
            fitModel();
        } else {
            // mergedSegments is a separate list (copied in mergeSimilarSegments), so clearing first is safe.
            modeledSegments.clear();
            modeledSegments.addAll(mergedSegments);
            isModelFit = false;
        }
    }

    /**
     * Writes the global-parameter deciles of both models, refitting first if the model is not currently fit.
     *
     * @param copyRatioParameterFile      destination for copy-ratio global-parameter deciles; non-null
     * @param alleleFractionParameterFile destination for allele-fraction global-parameter deciles; non-null
     */
    public void writeModelParameterFiles(final File copyRatioParameterFile,
                                         final File alleleFractionParameterFile) {
        Utils.nonNull(copyRatioParameterFile);
        Utils.nonNull(alleleFractionParameterFile);
        ensureModelIsFit();
        logger.info(String.format("Writing posterior summaries for copy-ratio global parameters to %s...",
                copyRatioParameterFile.getAbsolutePath()));
        copyRatioModeller.getGlobalParameterDeciles().write(copyRatioParameterFile);
        logger.info(String.format("Writing posterior summaries for allele-fraction global parameters to %s...",
                alleleFractionParameterFile.getAbsolutePath()));
        alleleFractionModeller.getGlobalParameterDeciles().write(alleleFractionParameterFile);
    }

    @VisibleForTesting
    CopyRatioModeller getCopyRatioModeller() {
        return copyRatioModeller;
    }

    @VisibleForTesting
    AlleleFractionModeller getAlleleFractionModeller() {
        return alleleFractionModeller;
    }

    /** Refits the models (with a warning) if {@link #modeledSegments} is not the result of a full fit. */
    private void ensureModelIsFit() {
        if (!isModelFit) {
            logger.warn("Attempted to write results to file when model was not completely fit. Performing model fit now.");
            fitModel();
        }
    }

    /** Helpers for merging adjacent segments whose posterior summaries are similar. */
    private static final class SimilarSegmentUtils {
        private SimilarSegmentUtils() {
        }

        /**
         * Greedily merges adjacent segments on the same contig that are similar in both dimensions.
         * After a merge the merged segment is compared again against its new right neighbor.
         *
         * @return a new list; the input list is not modified
         */
        private static List<ModeledSegment> mergeSimilarSegments(final List<ModeledSegment> segments,
                                                                 final double intervalThresholdSegmentMean,
                                                                 final double intervalThresholdMinorAlleleFraction) {
            final List<ModeledSegment> mergedSegments = new ArrayList<>(segments);
            int index = 0;
            while (index < mergedSegments.size() - 1) {
                final ModeledSegment segment1 = mergedSegments.get(index);
                final ModeledSegment segment2 = mergedSegments.get(index + 1);
                if (segment1.getContig().equals(segment2.getContig())
                        && areSimilar(segment1, segment2, intervalThresholdSegmentMean, intervalThresholdMinorAlleleFraction)) {
                    mergedSegments.set(index, merge(segment1, segment2));
                    mergedSegments.remove(index + 1);
                    // stay on the same index so the merged segment is tested against its next neighbor
                } else {
                    index++;
                }
            }
            return mergedSegments;
        }

        /**
         * Two summaries are similar if either median is NaN (treated as missing data) or if the difference of
         * medians is smaller than {@code intervalThreshold} times either summary's 10th-90th decile width.
         */
        private static boolean areSimilar(final ModeledSegment.SimplePosteriorSummary summary1,
                                          final ModeledSegment.SimplePosteriorSummary summary2,
                                          final double intervalThreshold) {
            if (Double.isNaN(summary1.getDecile50()) || Double.isNaN(summary2.getDecile50())) {
                return true;
            }
            final double absoluteDifference = Math.abs(summary1.getDecile50() - summary2.getDecile50());
            return absoluteDifference < intervalThreshold * (summary1.getDecile90() - summary1.getDecile10())
                    || absoluteDifference < intervalThreshold * (summary2.getDecile90() - summary2.getDecile10());
        }

        /** Segments are similar only if both their copy-ratio and minor-allele-fraction summaries are similar. */
        private static boolean areSimilar(final ModeledSegment segment1,
                                          final ModeledSegment segment2,
                                          final double intervalThresholdSegmentMean,
                                          final double intervalThresholdMinorAlleleFraction) {
            return areSimilar(segment1.getLog2CopyRatioSimplePosteriorSummary(),
                              segment2.getLog2CopyRatioSimplePosteriorSummary(),
                              intervalThresholdSegmentMean)
                    && areSimilar(segment1.getMinorAlleleFractionSimplePosteriorSummary(),
                                  segment2.getMinorAlleleFractionSimplePosteriorSummary(),
                                  intervalThresholdMinorAlleleFraction);
        }

        /**
         * Combines two summaries by inverse-variance weighting, approximating each standard deviation as half
         * of the 10th-90th decile width. A summary with a NaN median is treated as missing, and the other one
         * is returned (summary1 if both are missing).
         *
         * <p>NOTE(review): if either decile width is zero, the division produces infinities and the merged mean
         * becomes NaN. TODO confirm upstream posteriors can never collapse to zero width.</p>
         */
        private static ModeledSegment.SimplePosteriorSummary merge(final ModeledSegment.SimplePosteriorSummary summary1,
                                                                   final ModeledSegment.SimplePosteriorSummary summary2) {
            final boolean isMissing1 = Double.isNaN(summary1.getDecile50());
            final boolean isMissing2 = Double.isNaN(summary2.getDecile50());
            if (isMissing1 && !isMissing2) {
                return summary2;
            }
            if (isMissing2) {
                return summary1;
            }
            final double standardDeviation1 = (summary1.getDecile90() - summary1.getDecile10()) / 2.0;
            final double standardDeviation2 = (summary2.getDecile90() - summary2.getDecile10()) / 2.0;
            final double variance = 1.0 / (1.0 / Math.pow(standardDeviation1, 2.0) + 1.0 / Math.pow(standardDeviation2, 2.0));
            final double mean = (summary1.getDecile50() / Math.pow(standardDeviation1, 2.0)
                    + summary2.getDecile50() / Math.pow(standardDeviation2, 2.0)) * variance;
            final double standardDeviation = Math.sqrt(variance);
            return new ModeledSegment.SimplePosteriorSummary(mean, mean - standardDeviation, mean + standardDeviation);
        }

        /** Merges two segments: spanning interval, summed point counts, combined posterior summaries. */
        private static ModeledSegment merge(final ModeledSegment segment1, final ModeledSegment segment2) {
            return new ModeledSegment(
                    mergeSegments(segment1.getInterval(), segment2.getInterval()),
                    segment1.getNumPointsCopyRatio() + segment2.getNumPointsCopyRatio(),
                    segment1.getNumPointsAlleleFraction() + segment2.getNumPointsAlleleFraction(),
                    merge(segment1.getLog2CopyRatioSimplePosteriorSummary(), segment2.getLog2CopyRatioSimplePosteriorSummary()),
                    merge(segment1.getMinorAlleleFractionSimplePosteriorSummary(), segment2.getMinorAlleleFractionSimplePosteriorSummary()));
        }

        /**
         * Returns the smallest interval spanning both inputs.
         *
         * @throws IllegalArgumentException if the intervals are on different contigs
         */
        private static SimpleInterval mergeSegments(final SimpleInterval segment1, final SimpleInterval segment2) {
            Utils.validateArg(segment1.getContig().equals(segment2.getContig()),
                    String.format("Cannot join segments %s and %s on different chromosomes.",
                            segment1.toString(), segment2.toString()));
            final int start = Math.min(segment1.getStart(), segment2.getStart());
            final int end = Math.max(segment1.getEnd(), segment2.getEnd());
            return new SimpleInterval(segment1.getContig(), start, end);
        }
    }
}

