/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.pipelines.metrics;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.metrics.MetricsFile;
import htsjdk.samtools.util.Histogram;
import htsjdk.samtools.util.SequenceUtil;
import java.io.File;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.filters.MetricsReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.metrics.MetricsUtils;
import org.broadinstitute.hellbender.utils.R.RScriptExecutor;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.read.ReadUtils;
import picard.cmdline.programgroups.DiagnosticsAndQCProgramGroup;

@DocumentedFeature
@CommandLineProgramProperties(summary="Program to chart quality score distributions in a SAM/BAM file.", oneLineSummary="QualityScoreDistribution on Spark", programGroup=DiagnosticsAndQCProgramGroup.class)
@BetaFeature
/**
 * Spark tool that tallies base-quality scores over the input reads and writes them as a metrics
 * histogram, optionally rendering a chart with an R script.
 *
 * <p>Reads are first filtered with a {@link MetricsReadFilter} configured from
 * {@code --pfReadsOnly} and {@code --alignedReadsOnly}. Bases that are no-calls are skipped unless
 * {@code --includeNoCalls} is set. When original base qualities are available for a read
 * (via {@link ReadUtils#getOriginalBaseQualities}), they are tallied into a second histogram.
 */
@DocumentedFeature
@CommandLineProgramProperties(summary="Program to chart quality score distributions in a SAM/BAM file.", oneLineSummary="QualityScoreDistribution on Spark", programGroup=DiagnosticsAndQCProgramGroup.class)
@BetaFeature
public final class QualityScoreDistributionSpark extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    /** Name of the R script resource (resolved relative to this class) used to draw the chart. */
    private static final String R_SCRIPT = "qualityScoreDistribution.R";

    @Argument(doc="uri for the output file: a local file path", shortName="O", fullName="output", optional=true)
    public String out;
    @Argument(shortName="C", fullName="chart", doc="A file (with .pdf extension) to write the chart to.", optional=true)
    public File chartOutput;
    @Argument(shortName="A", fullName="alignedReadsOnly", doc="If set to true calculate mean quality over aligned reads only.")
    public boolean alignedReadsOnly = false;
    @Argument(shortName="F", fullName="pfReadsOnly", doc="If set to true calculate mean quality over PF reads only.")
    public boolean pfReadsOnly = false;
    @Argument(shortName="NC", fullName="includeNoCalls", doc="If set to true, include quality for no-call bases in the distribution.")
    public boolean includeNoCalls = false;

    /**
     * Disables the engine's default read filtering.
     *
     * <p>Read selection is done by the {@link MetricsReadFilter} built in {@link #runTool} instead.
     */
    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        return Collections.singletonList(ReadFilterLibrary.ALLOW_ALL_READS);
    }

    /**
     * Filters the reads, aggregates per-quality counts across the RDD, and writes the results.
     *
     * @param ctx the Spark context (reads are obtained through {@link #getReads()})
     */
    @Override
    protected void runTool(final JavaSparkContext ctx) {
        final JavaRDD<GATKRead> reads = getReads();
        final MetricsReadFilter metricsFilter = new MetricsReadFilter(pfReadsOnly, alignedReadsOnly);
        // Spark's Function/Function2 interfaces are Serializable, so these lambdas and method
        // references can be shipped to executors without intersection casts.
        final JavaRDD<GATKRead> filteredReads = reads.filter(read -> metricsFilter.test(read));
        // Both operators mutate and return their first argument. RDD.aggregate's documentation
        // explicitly permits this to avoid allocating a new accumulator per element.
        final Counts result = filteredReads.aggregate(new Counts(includeNoCalls), Counts::addRead, Counts::merge);
        final MetricsFile<?, Byte> metrics = makeMetrics(result);
        saveResults(metrics, getHeaderForReads(), getReadSourceName());
    }

    /**
     * Converts the aggregated counts into histograms keyed by quality value.
     *
     * <p>The quality histogram is always added. The original-quality histogram is added only when
     * it has at least one non-zero bin.
     *
     * @param result aggregated per-quality counts
     * @return a metrics file holding the histogram(s)
     */
    private MetricsFile<?, Byte> makeMetrics(final Counts result) {
        final Histogram<Byte> qHisto = new Histogram<>("QUALITY", "COUNT_OF_Q");
        final Histogram<Byte> oqHisto = new Histogram<>("QUALITY", "COUNT_OF_OQ");
        for (int q = 0; q < result.qCounts.length; ++q) {
            // Only non-zero bins are recorded so empty qualities do not appear in the output.
            if (result.qCounts[q] > 0L) {
                qHisto.increment((byte) q, (double) result.qCounts[q]);
            }
            if (result.oqCounts[q] > 0L) {
                oqHisto.increment((byte) q, (double) result.oqCounts[q]);
            }
        }
        final MetricsFile<?, Byte> metrics = getMetricsFile();
        metrics.addHistogram(qHisto);
        if (!oqHisto.isEmpty()) {
            metrics.addHistogram(oqHisto);
        }
        return metrics;
    }

    /**
     * Writes the metrics to {@link #out} and, when there is data and a chart was requested, runs the
     * R script to render it to {@link #chartOutput}.
     *
     * <p>NOTE(review): {@code out} is declared optional; behavior of
     * {@code MetricsUtils.saveMetrics} with a null path is not visible here — confirm.
     *
     * @param metrics       histograms produced by {@link #makeMetrics}
     * @param readsHeader   header of the input reads; a lone read group's library becomes the subtitle
     * @param inputFileName name of the read source, passed to the R script
     */
    private void saveResults(final MetricsFile<?, Byte> metrics, final SAMFileHeader readsHeader, final String inputFileName) {
        MetricsUtils.saveMetrics(metrics, out);
        // makeMetrics always adds the quality histogram, so checking only for an empty histogram
        // list would never detect an input without countable bases. Inspect the contents instead.
        final boolean noValidBases = metrics.getAllHistograms().stream().allMatch(Histogram::isEmpty);
        if (noValidBases) {
            logger.warn("No valid bases found in input file.");
        } else if (chartOutput != null) {
            String plotSubtitle = "";
            final List<SAMReadGroupRecord> readGroups = readsHeader.getReadGroups();
            if (readGroups.size() == 1) {
                final String library = readGroups.get(0).getLibrary();
                if (library != null) {
                    plotSubtitle = library;
                }
            }
            final RScriptExecutor executor = new RScriptExecutor();
            executor.addScript(getQualityScoreDistributionRScriptResource());
            executor.addArgs(out, chartOutput.getAbsolutePath(), inputFileName, plotSubtitle);
            executor.exec();
        }
    }

    /** @return the R script resource used to chart the quality score distribution */
    @VisibleForTesting
    static Resource getQualityScoreDistributionRScriptResource() {
        return new Resource(R_SCRIPT, QualityScoreDistributionSpark.class);
    }

    /**
     * Mutable, serializable accumulator of base counts indexed by quality value
     * (0..{@link #MAX_BASE_QUALITY}), for both current and original base qualities.
     */
    @VisibleForTesting
    static final class Counts implements Serializable {
        private static final long serialVersionUID = 1L;
        public static final int MAX_BASE_QUALITY = 127;
        private final long[] qCounts = new long[MAX_BASE_QUALITY + 1];
        private final long[] oqCounts = new long[MAX_BASE_QUALITY + 1];
        private final boolean includeNoCalls;

        /** @param includeNoCalls whether bases that are no-calls contribute to the counts */
        Counts(final boolean includeNoCalls) {
            this.includeNoCalls = includeNoCalls;
        }

        /**
         * Adds every (non-skipped) base of {@code read} to the counts, in place.
         *
         * <p>Quality bytes are used directly as array indices, so a negative quality value would
         * raise an ArrayIndexOutOfBoundsException.
         *
         * @param read the read to tally
         * @return this accumulator
         */
        Counts addRead(final GATKRead read) {
            final byte[] bases = read.getBases();
            final byte[] quals = read.getBaseQualities();
            final byte[] oq = ReadUtils.getOriginalBaseQualities(read);
            for (int i = 0; i < quals.length; ++i) {
                if (!includeNoCalls && SequenceUtil.isNoCall(bases[i])) {
                    continue;
                }
                qCounts[quals[i]]++;
                if (oq != null) {
                    oqCounts[oq[i]]++;
                }
            }
            return this;
        }

        /**
         * Adds {@code counts2}'s tallies into this accumulator, in place.
         *
         * @param counts2 the accumulator to fold in (left unchanged)
         * @return this accumulator
         */
        Counts merge(final Counts counts2) {
            for (int q = 0; q < qCounts.length; ++q) {
                qCounts[q] += counts2.qCounts[q];
                oqCounts[q] += counts2.oqCounts[q];
            }
            return this;
        }

        /** @return the live per-quality counts array (not a copy) */
        long[] getQualCounts() {
            return qCounts;
        }

        /** @return the live per-original-quality counts array (not a copy) */
        long[] getOrigQualCounts() {
            return oqCounts;
        }
    }
}

