/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.readorientation;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.metrics.Header;
import htsjdk.samtools.metrics.MetricsFile;
import htsjdk.samtools.util.CloserUtil;
import htsjdk.samtools.util.Histogram;
import htsjdk.samtools.util.IOUtil;
import htsjdk.samtools.util.SequenceUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.readorientation.AltSiteRecord;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactPrior;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactPriorCollection;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2CountsCollector;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2FilterConstants;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2FilterUtils;
import org.broadinstitute.hellbender.tools.walkers.readorientation.LearnReadOrientationModelEngine;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ReadOrientation;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.io.IOUtils;

/**
 * Command-line tool that learns artifact prior probabilities for the orientation bias mixture model filter.
 *
 * <p>The tool extracts one or more .tar.gz inputs (outputs of CollectF1R2Counts), groups the ref and alt
 * histograms and the alt site tables they contain by sample, and for each sample and each canonical
 * reference context runs a {@link LearnReadOrientationModelEngine}. Each context's data is first combined
 * with the data of its reverse complement. One artifact prior table per sample is written, and the tables
 * are bundled into the output .tar.gz.</p>
 */
@CommandLineProgramProperties(summary="Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", oneLineSummary="Get the maximum likelihood estimates of artifact prior probabilities in the orientation bias mixture model filter", programGroup=ShortVariantDiscoveryProgramGroup.class)
@DocumentedFeature
public class LearnReadOrientationModel
extends CommandLineProgram {
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 1.0E-4;
    public static final int DEFAULT_MAX_ITERATIONS = 20;
    /** Initial list capacity handed to {@link AltSiteRecord#readAltSiteRecords} when reading an alt table. */
    private static final int DEFAULT_INITIAL_LIST_SIZE = 1000000;
    public static final String EM_CONVERGENCE_THRESHOLD_LONG_NAME = "convergence-threshold";
    public static final String MAX_EM_ITERATIONS_LONG_NAME = "num-em-iterations";
    public static final String MAX_DEPTH_LONG_NAME = "max-depth";
    public static final String ARTIFACT_PRIOR_EXTENSION = ".orientation_priors";

    @Argument(fullName="input", shortName="I", doc="One or more .tar.gz containing outputs of CollectF1R2Counts")
    private List<File> inputTarGzs;

    @Argument(fullName="output", shortName="O", doc="tar.gz of artifact prior tables")
    private File outputTarGz;

    @Argument(fullName=EM_CONVERGENCE_THRESHOLD_LONG_NAME, doc="Stop the EM when the distance between parameters between iterations falls below this value", optional=true)
    private double convergenceThreshold = DEFAULT_CONVERGENCE_THRESHOLD;

    @Argument(fullName=MAX_EM_ITERATIONS_LONG_NAME, doc="give up on EM after this many iterations", optional=true)
    private int maxEMIterations = DEFAULT_MAX_ITERATIONS;

    @Argument(fullName=MAX_DEPTH_LONG_NAME, doc="sites with depth higher than this value will be grouped", optional=true)
    private int maxDepth = 200;

    /** Summed ref histograms keyed by sample; populated by {@link #doWork()}. */
    private Map<String, List<Histogram<Integer>>> refHistogramsBySample;
    /** Summed alt histograms keyed by sample; populated by {@link #doWork()}. */
    private Map<String, List<Histogram<Integer>>> altHistogramsBySample;

    /**
     * Runs the tool: extracts the inputs, sums the histograms per sample, learns the artifact priors of
     * every sample that has alt site records, and writes the resulting tables into the output .tar.gz.
     *
     * @return the string {@code "SUCCESS"}
     * @throws UserException.CouldNotCreateOutputFile if the output path does not end in .tar.gz, or the
     *         prior tables cannot be listed or archived
     */
    @Override
    public Object doWork() {
        if (!outputTarGz.getAbsolutePath().endsWith(".tar.gz")) {
            throw new UserException.CouldNotCreateOutputFile(outputTarGz, "Output file must end in .tar.gz");
        }

        final List<File> tmpDirs = extractInputsToTempDirs();
        final List<File> refHistogramFiles = tmpDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getRefHistogramsFromExtractedTar(dir).stream())
                .collect(Collectors.toList());
        final List<File> altHistogramFiles = tmpDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getAltHistogramsFromExtractedTar(dir).stream())
                .collect(Collectors.toList());
        final List<File> altTableFiles = tmpDirs.stream()
                .flatMap(dir -> F1R2CountsCollector.getAltTablesFromExtractedTar(dir).stream())
                .collect(Collectors.toList());

        final Map<String, List<MetricsFile<?, Integer>>> refMetricsBySample = groupMetricsFilesBySample(refHistogramFiles);
        final Map<String, List<MetricsFile<?, Integer>>> altMetricsBySample = groupMetricsFilesBySample(altHistogramFiles);

        // Alt histograms may be absent altogether; when present they must mirror the ref histograms.
        final Set<String> refSamples = refMetricsBySample.keySet();
        final Set<String> altSamples = altMetricsBySample.keySet();
        Utils.validate(altSamples.isEmpty() || refSamples.equals(altSamples), "ref and alt histograms must have same samples");
        Utils.validate(altSamples.isEmpty() || refSamples.stream().allMatch(sample -> refMetricsBySample.get(sample).size() == altMetricsBySample.get(sample).size()), "Each sample must have the same number of alt and ref histograms");

        refHistogramsBySample = sumHistogramsBySample(refMetricsBySample, true);
        altHistogramsBySample = sumHistogramsBySample(altMetricsBySample, false);

        final Map<String, List<AltSiteRecord>> recordsBySample = gatherAltSiteRecords(altTableFiles);
        final Map<String, ArtifactPriorCollection> priorsBySample = new HashMap<>();
        for (final Map.Entry<String, List<AltSiteRecord>> entry : recordsBySample.entrySet()) {
            priorsBySample.put(entry.getKey(), learnPriorsForSample(entry.getKey(), entry.getValue()));
        }

        final File tmpPriorDir = IOUtils.createTempDir("priors");
        for (final Map.Entry<String, ArtifactPriorCollection> entry : priorsBySample.entrySet()) {
            final File destination = new File(tmpPriorDir, IOUtils.urlEncode(entry.getKey()) + ARTIFACT_PRIOR_EXTENSION);
            entry.getValue().writeArtifactPriors(destination);
        }

        // File.listFiles() returns null on an I/O error; fail with a clear message instead of passing null on.
        final File[] priorFiles = tmpPriorDir.listFiles();
        if (priorFiles == null) {
            throw new UserException.CouldNotCreateOutputFile(outputTarGz, "Could not list the artifact prior tables in " + tmpPriorDir.getAbsolutePath());
        }
        try {
            IOUtils.writeTarGz(outputTarGz.getAbsolutePath(), priorFiles);
        } catch (IOException e) {
            throw new UserException.CouldNotCreateOutputFile("Could not create output .tar.gz file.", e);
        }
        return "SUCCESS";
    }

    /** Extracts every input .tar.gz into its own temporary directory and returns the directories in input order. */
    private List<File> extractInputsToTempDirs() {
        final List<File> tmpDirs = new ArrayList<>(inputTarGzs.size());
        for (int n = 0; n < inputTarGzs.size(); n++) {
            final File tmpDir = IOUtils.createTempDir(Integer.toString(n));
            IOUtils.extractTarGz(inputTarGzs.get(n).toPath(), tmpDir.toPath());
            tmpDirs.add(tmpDir);
        }
        return tmpDirs;
    }

    /**
     * Reads each histogram file and groups the resulting metrics files by the string form of their first
     * header (presumably the sample name written by CollectF1R2Counts -- TODO confirm).
     */
    private static Map<String, List<MetricsFile<?, Integer>>> groupMetricsFilesBySample(final List<File> histogramFiles) {
        final Map<String, List<MetricsFile<?, Integer>>> metricsFilesBySample = new HashMap<>();
        for (final File histogramFile : histogramFiles) {
            final MetricsFile<?, Integer> metricsFile = readMetricsFile(histogramFile);
            final String sampleKey = metricsFile.getHeaders().get(0).toString();
            metricsFilesBySample.computeIfAbsent(sampleKey, key -> new ArrayList<>()).add(metricsFile);
        }
        return metricsFilesBySample;
    }

    /** Applies {@link #sumHistogramsFromFiles} to the metrics files of every sample. */
    private static Map<String, List<Histogram<Integer>>> sumHistogramsBySample(final Map<String, List<MetricsFile<?, Integer>>> metricsFilesBySample, final boolean ref) {
        final Map<String, List<Histogram<Integer>>> histogramsBySample = new HashMap<>();
        for (final Map.Entry<String, List<MetricsFile<?, Integer>>> entry : metricsFilesBySample.entrySet()) {
            histogramsBySample.put(entry.getKey(), sumHistogramsFromFiles(entry.getValue(), ref));
        }
        return histogramsBySample;
    }

    /**
     * Learns the artifact priors of one sample, one canonical reference context at a time. Contexts whose
     * combined ref histogram sums to zero, or that have no alt site records, are skipped with a log message.
     */
    private ArtifactPriorCollection learnPriorsForSample(final String sample, final List<AltSiteRecord> records) {
        final Map<String, List<AltSiteRecord>> altDesignMatrixByContext = records.stream()
                .collect(Collectors.groupingBy(AltSiteRecord::getReferenceContext));
        // A sample without histograms falls back to empty lists, so its contexts are skipped below
        // rather than failing with a NullPointerException.
        final List<Histogram<Integer>> refHistograms = refHistogramsBySample.getOrDefault(sample, Collections.emptyList());
        final List<Histogram<Integer>> altHistograms = altHistogramsBySample.getOrDefault(sample, Collections.emptyList());

        final ArtifactPriorCollection artifactPriorCollection = new ArtifactPriorCollection(sample);
        for (final String refContext : F1R2FilterConstants.CANONICAL_KMERS) {
            final String reverseComplement = SequenceUtil.reverseComplement(refContext);

            final Histogram<Integer> combinedRefHistograms = combineRefHistogramWithRC(refContext,
                    findRefHistogram(refHistograms, refContext), findRefHistogram(refHistograms, reverseComplement), maxDepth);
            final List<Histogram<Integer>> combinedAltHistograms = combineAltDepthOneHistogramWithRC(
                    filterByLabelPrefix(altHistograms, refContext), filterByLabelPrefix(altHistograms, reverseComplement), maxDepth);

            // mergeDesignMatrices appends the reverse-complemented records to altDesignMatrix in place.
            final List<AltSiteRecord> altDesignMatrix = altDesignMatrixByContext.getOrDefault(refContext, new ArrayList<>());
            final List<AltSiteRecord> altDesignMatrixRevComp = altDesignMatrixByContext.getOrDefault(reverseComplement, Collections.emptyList());
            mergeDesignMatrices(altDesignMatrix, altDesignMatrixRevComp);

            if (combinedRefHistograms.getSumOfValues() == 0.0 || altDesignMatrix.isEmpty()) {
                logger.info(String.format("Skipping the reference context %s as we didn't find either the ref or alt table for the context", refContext));
                continue;
            }

            final LearnReadOrientationModelEngine engine = new LearnReadOrientationModelEngine(combinedRefHistograms, combinedAltHistograms, altDesignMatrix, convergenceThreshold, maxEMIterations, maxDepth, logger);
            final ArtifactPrior artifactPrior = engine.learnPriorForArtifactStates();
            artifactPriorCollection.set(artifactPrior);
        }
        return artifactPriorCollection;
    }

    /** Returns the first histogram whose value label equals {@code context}, or a freshly created ref histogram. */
    private Histogram<Integer> findRefHistogram(final List<Histogram<Integer>> histograms, final String context) {
        return histograms.stream()
                .filter(h -> h.getValueLabel().equals(context))
                .findFirst()
                .orElseGet(() -> F1R2FilterUtils.createRefHistogram(context, maxDepth));
    }

    /** Returns the histograms whose value label starts with {@code prefix}. */
    private static List<Histogram<Integer>> filterByLabelPrefix(final List<Histogram<Integer>> histograms, final String prefix) {
        return histograms.stream()
                .filter(h -> h.getValueLabel().startsWith(prefix))
                .collect(Collectors.toList());
    }

    /**
     * Adds, for every bin key of {@code first}, the sum of the bin values of {@code first} and
     * {@code second} to {@code target}. Assumes {@code second} has a bin for each of those keys
     * -- TODO confirm that the histograms are always created with identical bins.
     */
    private static void addBinCounts(final Histogram<Integer> target, final Histogram<Integer> first, final Histogram<Integer> second) {
        for (final Integer depth : first.keySet()) {
            final double newCount = first.get(depth).getValue() + second.get(depth).getValue();
            target.increment(depth, newCount);
        }
    }

    /**
     * Combines the ref histogram of a context with the ref histogram of its reverse complement.
     *
     * @param refContext the reference context; must equal the value label of {@code refHistogram}
     * @param refHistogram histogram of {@code refContext}
     * @param refHistogramRevComp histogram whose value label is the reverse complement of {@code refContext}
     * @param maxDepth depth passed on to {@link F1R2FilterUtils#createRefHistogram} for the new histogram
     * @return a new histogram holding the bin-wise sum of the two inputs
     */
    @VisibleForTesting
    public static Histogram<Integer> combineRefHistogramWithRC(String refContext, Histogram<Integer> refHistogram, Histogram<Integer> refHistogramRevComp, int maxDepth) {
        Utils.validateArg(refHistogram.getValueLabel().equals(SequenceUtil.reverseComplement(refHistogramRevComp.getValueLabel())), "ref context = " + refHistogram.getValueLabel() + ", rev comp = " + refHistogramRevComp.getValueLabel());
        Utils.validateArg(refHistogram.getValueLabel().equals(refContext), "this better match");

        final Histogram<Integer> combinedRefHistogram = F1R2FilterUtils.createRefHistogram(refContext, maxDepth);
        addBinCounts(combinedRefHistogram, refHistogram, refHistogramRevComp);
        return combinedRefHistogram;
    }

    /**
     * Combines the alt histograms of a canonical context with those of its reverse complement. For every
     * standard base other than the middle base of the context, and every read orientation, the histogram
     * labelled (context, alt allele, orientation) is summed with the histogram labelled (reverse complement
     * context, reverse complement alt allele, other orientation). A missing histogram is replaced by a
     * freshly created one.
     *
     * @param altHistograms alt histograms of the canonical context; may be empty
     * @param altHistogramsRevComp alt histograms of the reverse complement context; may be empty
     * @param maxDepth depth passed on to {@link F1R2FilterUtils#createAltHistogram}
     * @return the combined histograms, or an empty list when both inputs are empty
     */
    @VisibleForTesting
    public static List<Histogram<Integer>> combineAltDepthOneHistogramWithRC(List<Histogram<Integer>> altHistograms, List<Histogram<Integer>> altHistogramsRevComp, int maxDepth) {
        if (altHistograms.isEmpty() && altHistogramsRevComp.isEmpty()) {
            return Collections.emptyList();
        }

        // Derive the canonical context from whichever list has data.
        final String refContext = !altHistograms.isEmpty()
                ? (String) F1R2FilterUtils.labelToTriplet(altHistograms.get(0).getValueLabel()).getLeft()
                : SequenceUtil.reverseComplement((String) F1R2FilterUtils.labelToTriplet(altHistogramsRevComp.get(0).getValueLabel()).getLeft());
        Utils.validateArg(F1R2FilterConstants.CANONICAL_KMERS.contains(refContext), "refContext must be the canonical representation but got " + refContext);

        final String reverseComplement = SequenceUtil.reverseComplement(refContext);
        final List<Histogram<Integer>> combinedHistograms = new ArrayList<>(F1R2FilterConstants.numAltHistogramsPerContext);
        for (final Nucleotide altAllele : Nucleotide.STANDARD_BASES) {
            if (altAllele == F1R2FilterUtils.getMiddleBase(refContext)) {
                continue;
            }
            final Nucleotide altAlleleRevComp = Nucleotide.valueOf(SequenceUtil.reverseComplement(altAllele.toString()));
            for (final ReadOrientation orientation : ReadOrientation.values()) {
                final ReadOrientation otherOrientation = ReadOrientation.getOtherOrientation(orientation);
                final String label = F1R2FilterUtils.tripletToLabel(refContext, altAllele, orientation);
                final String labelRevComp = F1R2FilterUtils.tripletToLabel(reverseComplement, altAlleleRevComp, otherOrientation);

                final Histogram<Integer> altHistogram = altHistograms.stream()
                        .filter(h -> h.getValueLabel().equals(label))
                        .findFirst()
                        .orElseGet(() -> F1R2FilterUtils.createAltHistogram(refContext, altAllele, orientation, maxDepth));
                final Histogram<Integer> altHistogramRevComp = altHistogramsRevComp.stream()
                        .filter(h -> h.getValueLabel().equals(labelRevComp))
                        .findFirst()
                        .orElseGet(() -> F1R2FilterUtils.createAltHistogram(reverseComplement, altAlleleRevComp, otherOrientation, maxDepth));

                final Histogram<Integer> combinedHistogram = F1R2FilterUtils.createAltHistogram(refContext, altAllele, orientation, maxDepth);
                addBinCounts(combinedHistogram, altHistogram, altHistogramRevComp);
                combinedHistograms.add(combinedHistogram);
            }
        }
        return combinedHistograms;
    }

    /**
     * Appends the reverse complement of every record in {@code altDesignMatrixRevComp} to
     * {@code altDesignMatrix}, which is modified in place. Does nothing when both lists are empty.
     *
     * @param altDesignMatrix records of a canonical reference context; must be mutable
     * @param altDesignMatrixRevComp records of the reverse complement context
     * @throws IllegalArgumentException presumably (via {@link Utils#validateArg}) if the first record of
     *         {@code altDesignMatrix} is not in a canonical context, or the first records of the two lists
     *         are not reverse complements of each other
     */
    @VisibleForTesting
    public static void mergeDesignMatrices(List<AltSiteRecord> altDesignMatrix, List<AltSiteRecord> altDesignMatrixRevComp) {
        if (altDesignMatrix.isEmpty() && altDesignMatrixRevComp.isEmpty()) {
            return;
        }

        Utils.validateArg(altDesignMatrix.isEmpty() || F1R2FilterConstants.CANONICAL_KMERS.contains(altDesignMatrix.get(0).getReferenceContext()), "altDesignMatrix must have the canonical representation");

        // Only the first record of each list is inspected, as a sanity check on how the lists were grouped.
        if (!altDesignMatrix.isEmpty() && !altDesignMatrixRevComp.isEmpty()) {
            final String refContext = altDesignMatrix.get(0).getReferenceContext();
            final String revCompContext = altDesignMatrixRevComp.get(0).getReferenceContext();
            Utils.validateArg(refContext.equals(SequenceUtil.reverseComplement(revCompContext)), "ref context and its rev comp don't match");
        }

        altDesignMatrix.addAll(altDesignMatrixRevComp.stream().map(AltSiteRecord::getReverseComplementOfRecord).collect(Collectors.toList()));
    }

    /**
     * Reads a metrics file from disk.
     *
     * @param file the metrics file to read
     * @return the parsed metrics file
     */
    public static MetricsFile<?, Integer> readMetricsFile(File file) {
        final MetricsFile<?, Integer> metricsFile = new MetricsFile<>();
        final BufferedReader reader = IOUtil.openFileForBufferedReading(file);
        try {
            metricsFile.read(reader);
        } finally {
            // Close the reader even when parsing fails so the file handle is not leaked.
            CloserUtil.close(reader);
        }
        return metricsFile;
    }

    /**
     * Sums the histograms of several metrics files. The histograms of the first file serve as the
     * accumulators: every histogram of the remaining files is added to the first-file histogram that has
     * the same value label. The first file's histograms are therefore modified.
     *
     * @param metricsFiles metrics files to sum; must not be null
     * @param ref whether the files hold ref histograms (one per kmer) rather than alt histograms
     * @return the histogram list of the first file after summation, or an empty list if there are no files
     */
    public static List<Histogram<Integer>> sumHistogramsFromFiles(List<MetricsFile<?, Integer>> metricsFiles, boolean ref) {
        Utils.nonNull(metricsFiles, "files may not be null");
        if (metricsFiles.isEmpty()) {
            return Collections.emptyList();
        }

        final List<Histogram<Integer>> histogramList = metricsFiles.get(0).getAllHistograms();
        if (ref) {
            Utils.validate(histogramList.size() == F1R2FilterConstants.NUM_KMERS, "The list of ref histograms need to include all kmers as enforced by CollectF1R2Counts");
            Utils.validate(histogramList.stream().allMatch(h -> F1R2FilterConstants.ALL_KMERS.contains(h.getValueLabel())), "a histogram contains an unsupported, non-kmer header");
        } else {
            Utils.validate(histogramList.size() == F1R2FilterConstants.NUM_KMERS * F1R2FilterConstants.numAltHistogramsPerContext, "The list of alt histograms missing some (kmer, alt allele, f1r2) triple");
        }

        for (int i = 1; i < metricsFiles.size(); ++i) {
            final List<Histogram<Integer>> ithHistograms = metricsFiles.get(i).getAllHistograms();
            for (final Histogram<Integer> jthHistogram : ithHistograms) {
                final String label = jthHistogram.getValueLabel();
                final Optional<Histogram<Integer>> accumulator = histogramList.stream()
                        .filter(h -> h.getValueLabel().equals(label))
                        .findAny();
                Utils.validate(accumulator.isPresent(), "Missing histogram header for: " + label);
                accumulator.get().addHistogram(jthHistogram);
            }
        }
        return histogramList;
    }

    /**
     * Reads every alt table and gathers the records by sample. Records of tables that share a sample are
     * appended to the list read from the first table of that sample.
     *
     * @param tables alt site tables to read
     * @return the alt site records keyed by sample
     */
    @VisibleForTesting
    static Map<String, List<AltSiteRecord>> gatherAltSiteRecords(List<File> tables) {
        final Map<String, List<AltSiteRecord>> result = new HashMap<>();
        for (final File table : tables) {
            final Pair<String, List<AltSiteRecord>> sampleAndRecords = AltSiteRecord.readAltSiteRecords(table.toPath(), DEFAULT_INITIAL_LIST_SIZE);
            final List<AltSiteRecord> records = sampleAndRecords.getRight();
            final List<AltSiteRecord> existing = result.putIfAbsent(sampleAndRecords.getLeft(), records);
            if (existing != null) {
                existing.addAll(records);
            }
        }
        return result;
    }
}

