/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.denoising;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMTextHeaderCodec;
import htsjdk.samtools.util.BufferedLineReader;
import htsjdk.samtools.util.Lazy;
import htsjdk.samtools.util.LineReader;
import java.io.File;
import java.io.StringWriter;
import java.io.Writer;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.SingularValueDecomposition;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.broadinstitute.hdf5.HDF5File;
import org.broadinstitute.hdf5.HDF5LibException;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.copynumber.denoising.SVDDenoisingUtils;
import org.broadinstitute.hellbender.tools.copynumber.denoising.SVDReadCountPanelOfNormals;
import org.broadinstitute.hellbender.tools.copynumber.utils.HDF5Utils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.spark.SparkConverter;

public final class HDF5SVDReadCountPanelOfNormals
implements SVDReadCountPanelOfNormals {
    private static final Logger logger = LogManager.getLogger(HDF5SVDReadCountPanelOfNormals.class);
    private static final int CHUNK_DIVISOR = 16;
    private static final int NUM_SLICES_FOR_SPARK_MATRIX_CONVERSION = 100;
    private static final double EPSILON = 1.0E-9;
    private static final double CURRENT_PON_VERSION = 7.0;
    private static final String PON_VERSION_STRING_FORMAT = "%.1f";
    private static final String VERSION_PATH = "/version/value";
    private static final String COMMAND_LINE_PATH = "/command_line/value";
    private static final String SEQUENCE_DICTIONARY_PATH = "/sequence_dictionary/value";
    private static final String ORIGINAL_DATA_GROUP_NAME = "/original_data";
    private static final String ORIGINAL_READ_COUNTS_PATH = "/original_data/read_counts_samples_by_intervals";
    private static final String ORIGINAL_SAMPLE_FILENAMES_PATH = "/original_data/sample_filenames";
    private static final String ORIGINAL_INTERVALS_PATH = "/original_data/intervals";
    private static final String ORIGINAL_INTERVAL_GC_CONTENT_PATH = "/original_data/interval_gc_content";
    private static final String PANEL_GROUP_NAME = "/panel";
    private static final String PANEL_SAMPLE_FILENAMES_PATH = "/panel/sample_filenames";
    private static final String PANEL_INTERVALS_PATH = "/panel/intervals";
    private static final String PANEL_INTERVAL_FRACTIONAL_MEDIANS_PATH = "/panel/interval_fractional_medians";
    private static final String PANEL_SINGULAR_VALUES_PATH = "/panel/singular_values";
    private static final String PANEL_EIGENSAMPLE_VECTORS_PATH = "/panel/transposed_eigensamples_samples_by_intervals";
    private static final String PANEL_NUM_EIGENSAMPLES_PATH = "/panel/transposed_eigensamples_samples_by_intervals/num_rows";
    private final HDF5File file;
    private final Lazy<SAMSequenceDictionary> sequenceDictionary;
    private final Lazy<List<SimpleInterval>> originalIntervals;
    private final Lazy<List<SimpleInterval>> panelIntervals;

    private HDF5SVDReadCountPanelOfNormals(HDF5File file) {
        IOUtils.canReadFile(file.getFile());
        this.file = file;
        this.sequenceDictionary = new Lazy(() -> {
            String sequenceDictionaryString = file.readStringArray(SEQUENCE_DICTIONARY_PATH)[0];
            return new SAMTextHeaderCodec().decode((LineReader)BufferedLineReader.fromString((String)sequenceDictionaryString), file.getFile().getAbsolutePath()).getSequenceDictionary();
        });
        this.originalIntervals = new Lazy(() -> HDF5Utils.readIntervals(file, ORIGINAL_INTERVALS_PATH));
        this.panelIntervals = new Lazy(() -> HDF5Utils.readIntervals(file, PANEL_INTERVALS_PATH));
    }

    @Override
    public double getVersion() {
        if (!this.file.isPresent(VERSION_PATH)) {
            throw new UserException.BadInput(String.format("The panel of normals is out of date and incompatible.  Please use a panel of normals that was created by CreateReadCountPanelOfNormals and is version %.1f.", 7.0));
        }
        return this.file.readDouble(VERSION_PATH);
    }

    @Override
    public int getNumEigensamples() {
        try {
            return (int)this.file.readDouble(PANEL_NUM_EIGENSAMPLES_PATH);
        }
        catch (HDF5LibException e) {
            return 0;
        }
    }

    @Override
    public SAMSequenceDictionary getSequenceDictionary() {
        return (SAMSequenceDictionary)this.sequenceDictionary.get();
    }

    @Override
    public double[][] getOriginalReadCounts() {
        return HDF5Utils.readChunkedDoubleMatrix(this.file, ORIGINAL_READ_COUNTS_PATH);
    }

    @Override
    public List<SimpleInterval> getOriginalIntervals() {
        return (List)this.originalIntervals.get();
    }

    @Override
    public double[] getOriginalIntervalGCContent() {
        if (!this.file.isPresent(ORIGINAL_INTERVAL_GC_CONTENT_PATH)) {
            return null;
        }
        return this.file.readDoubleArray(ORIGINAL_INTERVAL_GC_CONTENT_PATH);
    }

    @Override
    public List<SimpleInterval> getPanelIntervals() {
        return (List)this.panelIntervals.get();
    }

    @Override
    public double[] getPanelIntervalFractionalMedians() {
        return this.file.readDoubleArray(PANEL_INTERVAL_FRACTIONAL_MEDIANS_PATH);
    }

    @Override
    public double[] getSingularValues() {
        if (this.getNumEigensamples() == 0) {
            throw new UnsupportedOperationException("No singular values were available.  This is because the panel only contains a single sample or no eigensamples were requested upon panel creation.");
        }
        return this.file.readDoubleArray(PANEL_SINGULAR_VALUES_PATH);
    }

    @Override
    public double[][] getEigensampleVectors() {
        if (this.getNumEigensamples() == 0) {
            throw new UnsupportedOperationException("No eigensample vectors were available.  This is because the panel only contains a single sample or no eigensamples were requested upon panel creation.");
        }
        return new Array2DRowRealMatrix(HDF5Utils.readChunkedDoubleMatrix(this.file, PANEL_EIGENSAMPLE_VECTORS_PATH), false).transpose().getData();
    }

    public static HDF5SVDReadCountPanelOfNormals read(HDF5File file) {
        IOUtils.canReadFile(file.getFile());
        HDF5SVDReadCountPanelOfNormals pon = new HDF5SVDReadCountPanelOfNormals(file);
        if (pon.getVersion() < 7.0) {
            throw new UserException.BadInput(String.format("The version of the specified panel of normals (%f) is older than the current version (%f).", pon.getVersion(), 7.0));
        }
        return pon;
    }

    public static void create(File outFile, String commandLine, SAMSequenceDictionary sequenceDictionary, RealMatrix originalReadCounts, List<String> originalSampleFilenames, List<SimpleInterval> originalIntervals, double[] intervalGCContent, double minimumIntervalMedianPercentile, double maximumZerosInSamplePercentage, double maximumZerosInIntervalPercentage, double extremeSampleMedianPercentile, boolean doImputeZeros, double extremeOutlierTruncationPercentile, int numEigensamplesRequested, int maximumChunkSize, JavaSparkContext ctx) {
        try (HDF5File file = new HDF5File(outFile, HDF5File.OpenMode.CREATE);){
            logger.info(String.format("Creating read-count panel of normals at %s...", outFile.getAbsolutePath()));
            HDF5SVDReadCountPanelOfNormals pon = new HDF5SVDReadCountPanelOfNormals(file);
            logger.info(String.format("Writing version number (%.1f)...", 7.0));
            pon.writeVersion(7.0);
            logger.info("Writing command line...");
            pon.writeCommandLine(commandLine);
            logger.info("Writing sequence dictionary...");
            pon.writeSequenceDictionary(sequenceDictionary);
            logger.info(String.format("Writing original read counts (%d x %d)...", originalReadCounts.getColumnDimension(), originalReadCounts.getRowDimension()));
            pon.writeOriginalReadCountsPath(originalReadCounts, maximumChunkSize);
            logger.info(String.format("Writing original sample filenames (%d)...", originalSampleFilenames.size()));
            pon.writeOriginalSampleFilenames(originalSampleFilenames);
            logger.info(String.format("Writing original intervals (%d)...", originalIntervals.size()));
            pon.writeOriginalIntervals(originalIntervals);
            if (intervalGCContent != null) {
                logger.info(String.format("Writing GC-content annotations for original intervals (%d)...", intervalGCContent.length));
                pon.writeOriginalIntervalGCContent(intervalGCContent);
            }
            logger.info("Preprocessing and standardizing read counts...");
            SVDDenoisingUtils.PreprocessedStandardizedResult preprocessedStandardizedResult = SVDDenoisingUtils.preprocessAndStandardizePanel(originalReadCounts, intervalGCContent, minimumIntervalMedianPercentile, maximumZerosInSamplePercentage, maximumZerosInIntervalPercentage, extremeSampleMedianPercentile, doImputeZeros, extremeOutlierTruncationPercentile);
            List<String> panelSampleFilenames = IntStream.range(0, originalSampleFilenames.size()).filter(sampleIndex -> !preprocessedStandardizedResult.filterSamples[sampleIndex]).mapToObj(originalSampleFilenames::get).collect(Collectors.toList());
            List<SimpleInterval> panelIntervals = IntStream.range(0, originalIntervals.size()).filter(intervalIndex -> !preprocessedStandardizedResult.filterIntervals[intervalIndex]).mapToObj(originalIntervals::get).collect(Collectors.toList());
            logger.info(String.format("Writing panel sample filenames (%d)...", panelSampleFilenames.size()));
            pon.writePanelSampleFilenames(panelSampleFilenames);
            logger.info(String.format("Writing panel intervals (%d)...", panelIntervals.size()));
            pon.writePanelIntervals(panelIntervals);
            double[] panelIntervalFractionalMedians = preprocessedStandardizedResult.panelIntervalFractionalMedians;
            logger.info(String.format("Writing panel interval fractional medians (%d)...", panelIntervalFractionalMedians.length));
            pon.writePanelIntervalFractionalMedians(panelIntervalFractionalMedians);
            int numPanelSamples = preprocessedStandardizedResult.preprocessedStandardizedValues.getRowDimension();
            int numPanelIntervals = preprocessedStandardizedResult.preprocessedStandardizedValues.getColumnDimension();
            int numEigensamples = Math.min(numEigensamplesRequested, numPanelSamples);
            if (numEigensamples < numEigensamplesRequested) {
                logger.warn(String.format("%d eigensamples were requested but only %d are available in the panel of normals...", numEigensamplesRequested, numEigensamples));
            }
            logger.info(String.format("Performing SVD (truncated at %d eigensamples) of standardized counts (transposed to %d x %d)...", numEigensamples, numPanelIntervals, numPanelSamples));
            if (numPanelSamples > 1 && numEigensamples > 0) {
                SingularValueDecomposition svd = SparkConverter.convertRealMatrixToSparkRowMatrix(ctx, preprocessedStandardizedResult.preprocessedStandardizedValues.transpose(), 100).computeSVD(numEigensamples, true, 1.0E-9);
                double[] singularValues = svd.s().toArray();
                if (singularValues.length == 0 || Arrays.stream(singularValues).noneMatch(s -> s > 1.0E-9)) {
                    throw new UserException(String.format("No non-zero singular values were found.  It may be necessary to use stricter parameters for filtering.  For example, use a larger value of %s.", "minimum-interval-median-percentile"));
                }
                if (singularValues.length < numEigensamples) {
                    logger.warn(String.format("Attempted to truncate at %d eigensamples, but only %d non-zero singular values were found...", numEigensamples, singularValues.length));
                }
                double[][] eigensampleVectors = SparkConverter.convertSparkRowMatrixToRealMatrix((RowMatrix)svd.U(), numPanelIntervals).getData();
                logger.info(String.format("Writing singular values (%d)...", singularValues.length));
                pon.writeSingularValues(singularValues);
                logger.info(String.format("Writing eigensample vectors (transposed to %d x %d)...", eigensampleVectors[0].length, eigensampleVectors.length));
                pon.writeEigensampleVectors(eigensampleVectors, maximumChunkSize);
            } else {
                logger.info("No eigensamples could be computed because only a single sample was provided or no eigensamples were requested.");
            }
        }
        catch (RuntimeException exception) {
            logger.warn(String.format("Exception encountered during creation of panel of normals (%s).  Attempting to delete partial output in %s...", exception, outFile.getAbsolutePath()));
            IOUtils.tryDelete(outFile);
            throw new GATKException(String.format("Could not create panel of normals.  It may be necessary to use stricter parameters for filtering.  For example, use a larger value of %s.", "minimum-interval-median-percentile"), exception);
        }
        logger.info(String.format("Read-count panel of normals written to %s.", outFile.getAbsolutePath()));
    }

    private void writeVersion(double version) {
        this.file.makeDouble(VERSION_PATH, version);
    }

    private void writeCommandLine(String commandLine) {
        this.file.makeStringArray(COMMAND_LINE_PATH, new String[]{commandLine});
    }

    private void writeSequenceDictionary(SAMSequenceDictionary sequenceDictionary) {
        StringWriter stringWriter = new StringWriter();
        new SAMTextHeaderCodec().encode((Writer)stringWriter, new SAMFileHeader(sequenceDictionary));
        this.file.makeStringArray(SEQUENCE_DICTIONARY_PATH, new String[]{stringWriter.toString()});
    }

    private void writeOriginalReadCountsPath(RealMatrix originalReadCounts, int maximumChunkSize) {
        HDF5Utils.writeChunkedDoubleMatrix(this.file, ORIGINAL_READ_COUNTS_PATH, originalReadCounts.getData(), maximumChunkSize);
    }

    private void writeOriginalSampleFilenames(List<String> originalSampleFilenames) {
        this.file.makeStringArray(ORIGINAL_SAMPLE_FILENAMES_PATH, originalSampleFilenames.toArray(new String[originalSampleFilenames.size()]));
    }

    private void writeOriginalIntervals(List<SimpleInterval> originalIntervals) {
        HDF5Utils.writeIntervals(this.file, ORIGINAL_INTERVALS_PATH, originalIntervals);
    }

    private void writeOriginalIntervalGCContent(double[] originalIntervalGCContent) {
        this.file.makeDoubleArray(ORIGINAL_INTERVAL_GC_CONTENT_PATH, originalIntervalGCContent);
    }

    private void writePanelSampleFilenames(List<String> panelSampleFilenames) {
        this.file.makeStringArray(PANEL_SAMPLE_FILENAMES_PATH, panelSampleFilenames.toArray(new String[panelSampleFilenames.size()]));
    }

    private void writePanelIntervals(List<SimpleInterval> panelIntervals) {
        HDF5Utils.writeIntervals(this.file, PANEL_INTERVALS_PATH, panelIntervals);
    }

    private void writePanelIntervalFractionalMedians(double[] panelIntervalFractionalMedians) {
        this.file.makeDoubleArray(PANEL_INTERVAL_FRACTIONAL_MEDIANS_PATH, panelIntervalFractionalMedians);
    }

    private void writeSingularValues(double[] singularValues) {
        this.file.makeDoubleArray(PANEL_SINGULAR_VALUES_PATH, singularValues);
    }

    private void writeEigensampleVectors(double[][] eigensampleVectors, int maximumChunkSize) {
        HDF5Utils.writeChunkedDoubleMatrix(this.file, PANEL_EIGENSAMPLE_VECTORS_PATH, new Array2DRowRealMatrix(eigensampleVectors, false).transpose().getData(), maximumChunkSize);
    }
}

