/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.denoising;

import com.google.common.primitives.Doubles;
import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.arguments.CopyNumberArgumentValidationUtils;
import org.broadinstitute.hellbender.tools.copynumber.denoising.GCBiasCorrector;
import org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisedCopyRatioResult;
import org.broadinstitute.hellbender.tools.copynumber.denoising.SVDReadCountPanelOfNormals;
import org.broadinstitute.hellbender.tools.copynumber.formats.collections.SimpleCountCollection;
import org.broadinstitute.hellbender.tools.copynumber.formats.metadata.SampleLocatableMetadata;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.MatrixSummaryUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

public final class SVDDenoisingUtils {
    private static final Logger logger = LogManager.getLogger(SVDDenoisingUtils.class);
    private static final double EPSILON = 1.0E-9;
    private static final double LN2_EPSILON = Math.log(1.0E-9) * MathUtils.INV_LOG_2;

    private SVDDenoisingUtils() {
    }

    static PreprocessedStandardizedResult preprocessAndStandardizePanel(RealMatrix readCounts, double[] intervalGCContent, double minimumIntervalMedianPercentile, double maximumZerosInSamplePercentage, double maximumZerosInIntervalPercentage, double extremeSampleMedianPercentile, boolean doImputeZeros, double extremeOutlierTruncationPercentile) {
        logger.info("Preprocessing read counts...");
        PreprocessedStandardizedResult preprocessedStandardizedResult = SVDDenoisingUtils.preprocessPanel(readCounts, intervalGCContent, minimumIntervalMedianPercentile, maximumZerosInSamplePercentage, maximumZerosInIntervalPercentage, extremeSampleMedianPercentile, doImputeZeros, extremeOutlierTruncationPercentile);
        logger.info("Panel read counts preprocessed.");
        logger.info("Standardizing read counts...");
        SVDDenoisingUtils.divideBySampleMedianAndTransformToLog2(preprocessedStandardizedResult.preprocessedStandardizedValues);
        logger.info("Subtracting median of sample medians...");
        double[] sampleLog2Medians = MatrixSummaryUtils.getRowMedians(preprocessedStandardizedResult.preprocessedStandardizedValues);
        final double medianOfSampleMedians = new Median().evaluate(sampleLog2Medians);
        preprocessedStandardizedResult.preprocessedStandardizedValues.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return value - medianOfSampleMedians;
            }
        });
        logger.info("Panel read counts standardized.");
        return preprocessedStandardizedResult;
    }

    static SVDDenoisedCopyRatioResult denoise(SVDReadCountPanelOfNormals panelOfNormals, SimpleCountCollection readCounts, int numEigensamples) {
        RealMatrix denoisedCopyRatioValues;
        Utils.nonNull(panelOfNormals);
        if (!CopyNumberArgumentValidationUtils.isSameDictionary(panelOfNormals.getSequenceDictionary(), ((SampleLocatableMetadata)readCounts.getMetadata()).getSequenceDictionary())) {
            logger.warn("Sequence dictionaries in panel and case sample do not match.");
        }
        ParamUtils.isPositiveOrZero(numEigensamples, "Number of eigensamples to use for denoising must be non-negative.");
        Utils.validateArg(numEigensamples <= panelOfNormals.getNumEigensamples(), "Number of eigensamples to use for denoising is greater than the number available in the panel of normals.");
        logger.info("Validating sample intervals against original intervals used to build panel of normals...");
        Utils.validateArg(panelOfNormals.getOriginalIntervals().equals(readCounts.getIntervals()), "Sample intervals must be identical to the original intervals used to build the panel of normals.");
        logger.info("Preprocessing and standardizing sample read counts...");
        RealMatrix standardizedCopyRatioValues = SVDDenoisingUtils.preprocessAndStandardizeSample(panelOfNormals, readCounts.getCounts());
        if (numEigensamples == 0 || panelOfNormals.getNumEigensamples() == 0) {
            logger.warn("A zero number of eigensamples was specified or no eigensamples were available to perform denoising; denoised copy ratios will be identical to the standardized copy ratios...");
            denoisedCopyRatioValues = standardizedCopyRatioValues;
        } else {
            logger.info(String.format("Using %d out of %d eigensamples to denoise...", numEigensamples, panelOfNormals.getNumEigensamples()));
            logger.info("Subtracting projection onto space spanned by eigensamples...");
            denoisedCopyRatioValues = SVDDenoisingUtils.subtractProjection(standardizedCopyRatioValues, panelOfNormals.getEigensampleVectors(), numEigensamples);
        }
        logger.info("Sample denoised.");
        return new SVDDenoisedCopyRatioResult((SampleLocatableMetadata)readCounts.getMetadata(), panelOfNormals.getPanelIntervals(), standardizedCopyRatioValues, denoisedCopyRatioValues);
    }

    public static RealMatrix preprocessAndStandardizeSample(double[] readCounts, double[] intervalGCContent) {
        Utils.nonNull(readCounts);
        Utils.validateArg(intervalGCContent == null || readCounts.length == intervalGCContent.length, "Number of intervals for read counts must match those for GC-content annotations.");
        Array2DRowRealMatrix result = new Array2DRowRealMatrix((double[][])new double[][]{readCounts});
        logger.info("Preprocessing read counts...");
        SVDDenoisingUtils.transformToFractionalCoverage((RealMatrix)result);
        SVDDenoisingUtils.performOptionalGCBiasCorrection((RealMatrix)result, intervalGCContent);
        logger.info("Sample read counts preprocessed.");
        logger.info("Standardizing read counts...");
        SVDDenoisingUtils.divideBySampleMedianAndTransformToLog2((RealMatrix)result);
        logger.info("Subtracting sample median...");
        final double[] sampleLog2Medians = MatrixSummaryUtils.getRowMedians((RealMatrix)result);
        result.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return value - sampleLog2Medians[sampleIndex];
            }
        });
        logger.info("Sample read counts standardized.");
        return result;
    }

    private static PreprocessedStandardizedResult preprocessPanel(RealMatrix readCounts, double[] intervalGCContent, double minimumIntervalMedianPercentile, double maximumZerosInSamplePercentage, double maximumZerosInIntervalPercentage, double extremeSampleMedianPercentile, boolean doImputeZeros, double extremeOutlierTruncationPercentile) {
        SVDDenoisingUtils.transformToFractionalCoverage(readCounts);
        SVDDenoisingUtils.performOptionalGCBiasCorrection(readCounts, intervalGCContent);
        int numOriginalSamples = readCounts.getRowDimension();
        int numOriginalIntervals = readCounts.getColumnDimension();
        boolean[] filterSamples = new boolean[numOriginalSamples];
        boolean[] filterIntervals = new boolean[numOriginalIntervals];
        double[] originalIntervalMedians = MatrixSummaryUtils.getColumnMedians(readCounts);
        if (minimumIntervalMedianPercentile == 0.0) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding filtering step will be skipped...", "minimum-interval-median-percentile"));
        } else {
            double minimumIntervalMedianThreshold = new Percentile(minimumIntervalMedianPercentile).evaluate(originalIntervalMedians);
            logger.info(String.format("Filtering intervals with median (across samples) less than or equal to the %.2f percentile (%.2f)...", minimumIntervalMedianPercentile, minimumIntervalMedianThreshold));
            IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> originalIntervalMedians[intervalIndex] <= minimumIntervalMedianThreshold).forEach(intervalIndex -> {
                filterIntervals[intervalIndex] = true;
            });
            logger.info(String.format("After filtering, %d out of %d intervals remain...", SVDDenoisingUtils.countNumberPassingFilter(filterIntervals), numOriginalIntervals));
        }
        logger.info("Dividing by interval medians...");
        IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex]).forEach(intervalIndex -> IntStream.range(0, numOriginalSamples).filter(sampleIndex -> !filterSamples[sampleIndex]).forEach(sampleIndex -> {
            double value = readCounts.getEntry(sampleIndex, intervalIndex);
            readCounts.setEntry(sampleIndex, intervalIndex, value / originalIntervalMedians[intervalIndex]);
        }));
        if (maximumZerosInSamplePercentage == 100.0) {
            logger.info(String.format("A value of 100 was provided for argument %s, so the corresponding filtering step will be skipped...", "maximum-zeros-in-sample-percentage"));
        } else {
            logger.info(String.format("Filtering samples with a fraction of zero-coverage intervals greater than or equal to %.2f percent...", maximumZerosInSamplePercentage));
            int numPassingIntervals = SVDDenoisingUtils.countNumberPassingFilter(filterIntervals);
            IntStream.range(0, numOriginalSamples).filter(sampleIndex -> !filterSamples[sampleIndex]).forEach(sampleIndex -> {
                double numZerosInSample = IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex] && readCounts.getEntry(sampleIndex, intervalIndex) == 0.0).count();
                if (numZerosInSample / (double)numPassingIntervals >= maximumZerosInSamplePercentage / 100.0) {
                    filterSamples[sampleIndex] = true;
                }
            });
            logger.info(String.format("After filtering, %d out of %d samples remain...", SVDDenoisingUtils.countNumberPassingFilter(filterSamples), numOriginalSamples));
        }
        if (maximumZerosInIntervalPercentage == 100.0) {
            logger.info(String.format("A value of 100 was provided for argument %s, so the corresponding filtering step will be skipped...", "maximum-zeros-in-interval-percentage"));
        } else {
            logger.info(String.format("Filtering intervals with a fraction of zero-coverage samples greater than or equal to %.2f percent...", maximumZerosInIntervalPercentage));
            int numPassingSamples = SVDDenoisingUtils.countNumberPassingFilter(filterSamples);
            IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex]).forEach(intervalIndex -> {
                double numZerosInInterval = IntStream.range(0, numOriginalSamples).filter(sampleIndex -> !filterSamples[sampleIndex] && readCounts.getEntry(sampleIndex, intervalIndex) == 0.0).count();
                if (numZerosInInterval / (double)numPassingSamples >= maximumZerosInIntervalPercentage / 100.0) {
                    filterIntervals[intervalIndex] = true;
                }
            });
            logger.info(String.format("After filtering, %d out of %d intervals remain...", SVDDenoisingUtils.countNumberPassingFilter(filterIntervals), numOriginalIntervals));
        }
        if (extremeSampleMedianPercentile == 0.0) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding filtering step will be skipped...", "extreme-sample-median-percentile"));
        } else {
            double[] sampleMedians = IntStream.range(0, numOriginalSamples).mapToDouble(sampleIndex -> new Median().evaluate(IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex]).mapToDouble(intervalIndex -> readCounts.getEntry(sampleIndex, intervalIndex)).toArray())).toArray();
            double minimumSampleMedianThreshold = new Percentile(extremeSampleMedianPercentile).evaluate(sampleMedians);
            double maximumSampleMedianThreshold = new Percentile(100.0 - extremeSampleMedianPercentile).evaluate(sampleMedians);
            logger.info(String.format("Filtering samples with a median (across intervals) strictly below the %.2f percentile (%.2f) or strictly above the %.2f percentile (%.2f)...", extremeSampleMedianPercentile, minimumSampleMedianThreshold, 100.0 - extremeSampleMedianPercentile, maximumSampleMedianThreshold));
            IntStream.range(0, numOriginalSamples).filter(sampleIndex -> sampleMedians[sampleIndex] < minimumSampleMedianThreshold || sampleMedians[sampleIndex] > maximumSampleMedianThreshold).forEach(sampleIndex -> {
                filterSamples[sampleIndex] = true;
            });
            logger.info(String.format("After filtering, %d out of %d samples remain...", SVDDenoisingUtils.countNumberPassingFilter(filterSamples), numOriginalSamples));
        }
        int[] panelIntervalIndices = IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex]).toArray();
        int[] panelSampleIndices = IntStream.range(0, numOriginalSamples).filter(sampleIndex -> !filterSamples[sampleIndex]).toArray();
        RealMatrix preprocessedReadCounts = readCounts.getSubMatrix(panelSampleIndices, panelIntervalIndices);
        double[] panelIntervalFractionalMedians = IntStream.range(0, numOriginalIntervals).filter(intervalIndex -> !filterIntervals[intervalIndex]).mapToDouble(intervalIndex -> originalIntervalMedians[intervalIndex]).toArray();
        SVDDenoisingUtils.logHeapUsage();
        logger.info("Performing garbage collection...");
        System.gc();
        SVDDenoisingUtils.logHeapUsage();
        if (!doImputeZeros) {
            logger.info("Skipping imputation of zero-coverage values...");
        } else {
            int numPanelIntervals = panelIntervalIndices.length;
            final double[] intervalNonZeroMedians = IntStream.range(0, numPanelIntervals).mapToObj(intervalIndex -> Arrays.stream(preprocessedReadCounts.getColumn(intervalIndex)).filter(value -> value > 0.0).toArray()).mapToDouble(nonZeroValues -> new Median().evaluate(nonZeroValues)).toArray();
            final int[] numImputed = new int[]{0};
            preprocessedReadCounts.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

                public double visit(int sampleIndex, int intervalIndex, double value) {
                    if (value == 0.0) {
                        numImputed[0] = numImputed[0] + 1;
                        return intervalNonZeroMedians[intervalIndex];
                    }
                    return value;
                }
            });
            logger.info(String.format("%d zero-coverage values were imputed to the median of the non-zero values in the corresponding interval...", numImputed[0]));
        }
        if (extremeOutlierTruncationPercentile == 0.0) {
            logger.info(String.format("A value of 0 was provided for argument %s, so the corresponding truncation step will be skipped...", "extreme-outlier-truncation-percentile"));
        } else if ((long)preprocessedReadCounts.getRowDimension() * (long)preprocessedReadCounts.getColumnDimension() > Integer.MAX_VALUE) {
            logger.warn("The number of matrix elements exceeds Integer.MAX_VALUE, so outlier truncation will be skipped...");
        } else {
            double[] values = Doubles.concat((double[][])preprocessedReadCounts.getData());
            final double minimumOutlierTruncationThreshold = new Percentile(extremeOutlierTruncationPercentile).evaluate(values);
            final double maximumOutlierTruncationThreshold = new Percentile(100.0 - extremeOutlierTruncationPercentile).evaluate(values);
            final int[] numTruncated = new int[]{0};
            preprocessedReadCounts.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

                public double visit(int sampleIndex, int intervalIndex, double value) {
                    if (value < minimumOutlierTruncationThreshold) {
                        numTruncated[0] = numTruncated[0] + 1;
                        return minimumOutlierTruncationThreshold;
                    }
                    if (value > maximumOutlierTruncationThreshold) {
                        numTruncated[0] = numTruncated[0] + 1;
                        return maximumOutlierTruncationThreshold;
                    }
                    return value;
                }
            });
            logger.info(String.format("%d values strictly below the %.2f percentile (%.2f) or strictly above the %.2f percentile (%.2f) were truncated to the corresponding value...", numTruncated[0], extremeOutlierTruncationPercentile, minimumOutlierTruncationThreshold, 100.0 - extremeOutlierTruncationPercentile, maximumOutlierTruncationThreshold));
        }
        return new PreprocessedStandardizedResult(preprocessedReadCounts, panelIntervalFractionalMedians, filterSamples, filterIntervals);
    }

    private static void logHeapUsage() {
        int mb = 0x100000;
        Runtime runtime = Runtime.getRuntime();
        logger.info("Heap utilization statistics [MB]:");
        logger.info("Used memory: " + (runtime.totalMemory() - runtime.freeMemory()) / 0x100000L);
        logger.info("Free memory: " + runtime.freeMemory() / 0x100000L);
        logger.info("Total memory: " + runtime.totalMemory() / 0x100000L);
        logger.info("Maximum memory: " + runtime.maxMemory() / 0x100000L);
    }

    private static RealMatrix preprocessAndStandardizeSample(SVDReadCountPanelOfNormals panelOfNormals, double[] readCounts) {
        Array2DRowRealMatrix result = new Array2DRowRealMatrix((double[][])new double[][]{readCounts});
        logger.info("Preprocessing read counts...");
        SVDDenoisingUtils.transformToFractionalCoverage((RealMatrix)result);
        SVDDenoisingUtils.performOptionalGCBiasCorrection((RealMatrix)result, panelOfNormals.getOriginalIntervalGCContent());
        logger.info("Subsetting sample intervals to post-filter panel intervals...");
        HashSet<SimpleInterval> panelIntervals = new HashSet<SimpleInterval>(panelOfNormals.getPanelIntervals());
        int[] subsetIntervalIndices = IntStream.range(0, panelOfNormals.getOriginalIntervals().size()).filter(i -> panelIntervals.contains(panelOfNormals.getOriginalIntervals().get(i))).toArray();
        result = result.getSubMatrix(new int[]{0}, subsetIntervalIndices);
        logger.info("Dividing by interval medians from the panel of normals...");
        final double[] intervalMedians = panelOfNormals.getPanelIntervalFractionalMedians();
        result.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return value / intervalMedians[intervalIndex];
            }
        });
        logger.info("Sample read counts preprocessed.");
        logger.info("Standardizing read counts...");
        SVDDenoisingUtils.divideBySampleMedianAndTransformToLog2((RealMatrix)result);
        logger.info("Subtracting sample median...");
        final double[] sampleLog2Medians = MatrixSummaryUtils.getRowMedians((RealMatrix)result);
        result.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return value - sampleLog2Medians[sampleIndex];
            }
        });
        logger.info("Sample read counts standardized.");
        return result;
    }

    private static RealMatrix subtractProjection(RealMatrix standardizedValues, double[][] eigensampleVectors, int numEigensamples) {
        if (numEigensamples == 0) {
            return standardizedValues.copy();
        }
        int numIntervals = eigensampleVectors.length;
        int numAllEigensamples = eigensampleVectors[0].length;
        logger.info("Distributing the standardized read counts...");
        logger.info("Composing eigensample matrix for the requested number of eigensamples and transposing them...");
        Array2DRowRealMatrix eigensampleTruncatedMatrix = numEigensamples == numAllEigensamples ? new Array2DRowRealMatrix(eigensampleVectors, false) : new Array2DRowRealMatrix(eigensampleVectors, false).getSubMatrix(0, numIntervals - 1, 0, numEigensamples - 1);
        logger.info("Computing projection...");
        RealMatrix projection = standardizedValues.multiply((RealMatrix)eigensampleTruncatedMatrix).multiply(eigensampleTruncatedMatrix.transpose());
        logger.info("Subtracting projection...");
        return standardizedValues.subtract(projection);
    }

    private static int countNumberPassingFilter(boolean[] filter) {
        int numPassingFilter = (int)IntStream.range(0, filter.length).filter(i -> !filter[i]).count();
        if (numPassingFilter == 0) {
            throw new UserException.BadInput("Filtering removed all samples or intervals.  Select less strict filtering criteria.");
        }
        return numPassingFilter;
    }

    private static void transformToFractionalCoverage(RealMatrix matrix) {
        logger.info("Transforming read counts to fractional coverage...");
        final double[] sampleSums = IntStream.range(0, matrix.getRowDimension()).mapToDouble(r -> MathUtils.sum(matrix.getRow(r))).toArray();
        matrix.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return value / sampleSums[sampleIndex];
            }
        });
    }

    private static void performOptionalGCBiasCorrection(RealMatrix matrix, double[] intervalGCContent) {
        if (intervalGCContent != null) {
            logger.info("Performing GC-bias correction...");
            GCBiasCorrector.correctGCBias(matrix, intervalGCContent);
        }
    }

    private static void divideBySampleMedianAndTransformToLog2(RealMatrix matrix) {
        logger.info("Dividing by sample medians and transforming to log2 space...");
        final double[] sampleMedians = MatrixSummaryUtils.getRowMedians(matrix);
        IntStream.range(0, sampleMedians.length).forEach(sampleIndex -> ParamUtils.isPositive(sampleMedians[sampleIndex], sampleMedians.length == 1 ? "Sample does not have a positive sample median." : String.format("Sample at index %s does not have a positive sample median.", sampleIndex)));
        matrix.walkInOptimizedOrder((RealMatrixChangingVisitor)new DefaultRealMatrixChangingVisitor(){

            public double visit(int sampleIndex, int intervalIndex, double value) {
                return SVDDenoisingUtils.safeLog2(value / sampleMedians[sampleIndex]);
            }
        });
    }

    private static double safeLog2(double x) {
        return x < 1.0E-9 ? LN2_EPSILON : Math.log(x) * MathUtils.INV_LOG_2;
    }

    static final class PreprocessedStandardizedResult {
        final RealMatrix preprocessedStandardizedValues;
        final double[] panelIntervalFractionalMedians;
        final boolean[] filterSamples;
        final boolean[] filterIntervals;

        private PreprocessedStandardizedResult(RealMatrix preprocessedStandardizedValues, double[] panelIntervalFractionalMedians, boolean[] filterSamples, boolean[] filterIntervals) {
            this.preprocessedStandardizedValues = preprocessedStandardizedValues;
            this.panelIntervalFractionalMedians = panelIntervalFractionalMedians;
            this.filterSamples = filterSamples;
            this.filterIntervals = filterIntervals;
        }
    }
}

