/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import com.intel.gkl.IntelGKLUtils;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.Hidden;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.TwoPassVariantWalker;
import org.broadinstitute.hellbender.engine.filters.CountingVariantFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadGroupBlackListReadFilter;
import org.broadinstitute.hellbender.engine.filters.VariantFilterLibrary;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.vqsr.TensorType;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsamplingIterator;
import org.broadinstitute.hellbender.utils.downsampling.ReservoirDownsampler;
import org.broadinstitute.hellbender.utils.io.IOUtils;
import org.broadinstitute.hellbender.utils.python.StreamingPythonScriptExecutor;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.runtime.AsynchronousStreamWriter;
import org.broadinstitute.hellbender.utils.variant.GATKVCFHeaderLines;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

/**
 * Annotates the variants of a VCF with a score computed by a Convolutional Neural Network (CNN) that runs in a
 * streaming Python process.
 *
 * <p>The work is split over the two passes of {@link TwoPassVariantWalker}:
 * <ol>
 *   <li><b>First pass</b>: each variant (plus its reads, when the tensor type requires reads) is serialized to a
 *       single text record. Records are collected into batches and handed to the Python process together with a
 *       {@code vqsr_cnn.score_and_write_batch} command that targets a temporary score file.</li>
 *   <li><b>Second pass</b>: the temporary score file is read back one line per variant, every line is checked
 *       against the variant currently being visited, and the score is added to the output VCF under the
 *       score key ({@code CNN_1D} or {@code CNN_2D}).</li>
 * </ol>
 */
@DocumentedFeature
@CommandLineProgramProperties(
        summary = CNNScoreVariants.USAGE_SUMMARY,
        oneLineSummary = CNNScoreVariants.USAGE_ONE_LINE_SUMMARY,
        programGroup = VariantFilteringProgramGroup.class)
public class CNNScoreVariants extends TwoPassVariantWalker {
    private static final String NL = String.format("%n");
    static final String USAGE_ONE_LINE_SUMMARY = "Apply a Convolutional Neural Net to filter annotated variants";
    static final String USAGE_SUMMARY = "Annotate a VCF with scores from a Convolutional Neural Network (CNN).The CNN determines a Log Odds Score for each variant.Pre-trained models (1D or 2D) are specified via the architecture argument.1D models will look at the reference sequence and variant annotations.2D models look at aligned reads, reference sequence, and variant annotations.2D models require a BAM file as input as well as the tensor-type argument to be set.";
    static final String DISABLE_AVX_CHECK_NAME = "disable-avx-check";
    static final String AVXREQUIRED_ERROR = "This tool requires AVX instruction set support by default due to its dependency on recent versions of the TensorFlow library.\n If you have an older (pre-1.6) version of TensorFlow installed that does not require AVX you may attempt to re-run the tool with the %s argument to bypass this check.\n Note that such configurations are not officially supported.";

    // Column positions of the tab-separated fields in each line of the score file written by Python.
    private static final int CONTIG_INDEX = 0;
    private static final int POS_INDEX = 1;
    private static final int REF_INDEX = 2;
    private static final int ALT_INDEX = 3;
    private static final int KEY_INDEX = 4;

    private static final int FIFO_STRING_INITIAL_CAPACITY = 1024;
    // NOTE(review): despite the MAX_ prefix these values are applied with Math.max in
    // customCommandLineValidation(), i.e. they act as lower bounds on the batch sizes. Confirm that is intended.
    private static final int MAX_BATCH_SIZE_1D = 1024;
    private static final int MAX_BATCH_SIZE_2D = 64;

    private static final String DATA_VALUE_SEPARATOR = ",";
    private static final String DATA_TYPE_SEPARATOR = "\t";
    private static final String ANNOTATION_SEPARATOR = ";";
    private static final String ANNOTATION_SET_STRING = "=";

    // java.io.File.createTempFile rejects prefixes shorter than three characters.
    private static final int MIN_TEMP_FILE_PREFIX_LENGTH = 3;
    private static final String TEMP_FILE_PREFIX_PADDING = "cnn_scores_";

    private static final String resourcePathReadTensor = "large/cnn_score_variants/small_2d.json";
    private static final String resourcePathReferenceTensor = "large/cnn_score_variants/1d_cnn_mix_train_full_bn.json";

    private List<String> defaultAnnotationKeys = new ArrayList<>(Arrays.asList("MQ", "DP", "SOR", "FS", "QD", "MQRankSum", "ReadPosRankSum"));

    @Argument(fullName="output", shortName="O", doc="Output file")
    private GATKPath outputFile;

    @Argument(fullName="architecture", shortName="architecture", doc="Neural Net architecture configuration json file", optional=true)
    private String architecture;

    @Argument(fullName="weights", shortName="weights", doc="Keras model HD5 file with neural net weights.", optional=true)
    private String weights;

    @Argument(fullName="tensor-type", shortName="tensor-type", doc="Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional=true)
    private TensorType tensorType = TensorType.reference;

    @Argument(fullName="window-size", shortName="window-size", doc="Neural Net input window size", minValue=0.0, optional=true)
    private int windowSize = 128;

    @Argument(fullName="read-limit", shortName="read-limit", doc="Maximum number of reads to encode in a tensor, for 2D models only.", minValue=0.0, optional=true)
    private int readLimit = 128;

    @Argument(fullName="filter-symbolic-and-sv", shortName="filter-symbolic-and-sv", doc="If set will filter symbolic and and structural variants from the input VCF", optional=true)
    private boolean filterSymbolicAndSV = false;

    // A copy of the defaults rather than an alias to the same list: if the argument parser fills this collection
    // in place, an alias would silently change the defaults too and the "non-default keys" warning in
    // onTraversalStart() could never fire.
    // NOTE(review): presumably the parser clears and refills the existing collection - confirm.
    @Advanced
    @Argument(fullName="info-annotation-keys", shortName="info-annotation-keys", doc="The VCF info fields to send to python.  This should only be changed if a new model has been trained which expects the annotations provided here.", optional=true)
    private List<String> annotationKeys = new ArrayList<>(this.defaultAnnotationKeys);

    @Advanced
    @Argument(fullName="inference-batch-size", shortName="inference-batch-size", doc="Size of batches for python to do inference on.", minValue=1.0, maxValue=4096.0, optional=true)
    private int inferenceBatchSize = 256;

    @Advanced
    @Argument(fullName="transfer-batch-size", shortName="transfer-batch-size", doc="Size of data to queue for python streaming.", minValue=1.0, maxValue=8192.0, optional=true)
    private int transferBatchSize = 512;

    @Advanced
    @Argument(fullName="inter-op-threads", shortName="inter-op-threads", doc="Number of inter-op parallelism threads to use for Tensorflow", minValue=0.0, maxValue=4096.0, optional=true)
    private int interOpThreads = 0;

    @Advanced
    @Argument(fullName="intra-op-threads", shortName="intra-op-threads", doc="Number of intra-op parallelism threads to use for Tensorflow", minValue=0.0, maxValue=4096.0, optional=true)
    private int intraOpThreads = 0;

    @Advanced
    @Argument(fullName="output-tensor-dir", shortName="output-tensor-dir", doc="Optional directory where tensors can be saved for debugging or visualization.", optional=true)
    private String outputTensorsDir = "";

    @Advanced
    @Argument(fullName=DISABLE_AVX_CHECK_NAME, shortName=DISABLE_AVX_CHECK_NAME, doc="If set, no check will be made for AVX support.  Use only if you have installed a pre-1.6 TensorFlow build. ", optional=true)
    private boolean disableAVXCheck = false;

    @Hidden
    @Argument(fullName="enable-journal", shortName="enable-journal", doc="Enable streaming process journal.", optional=true)
    private boolean enableJournal = false;

    @Hidden
    @Argument(fullName="keep-temp-file", shortName="keep-temp-file", doc="Keep the temporary file that python writes scores to.", optional=true)
    private boolean keepTempFile = false;

    @Hidden
    @Argument(fullName="python-profile", shortName="python-profile", doc="Run the tool with the Python CProfiler on and write results to this file.", optional=true)
    private File pythonProfileResults;

    final StreamingPythonScriptExecutor<String> pythonExecutor = new StreamingPythonScriptExecutor<>(true);

    // Records accumulated for the batch currently being assembled; replaced with a fresh list after every send.
    private List<String> batchList = new ArrayList<>(this.inferenceBatchSize);
    private int curBatchSize = 0;

    // Leading/trailing bases requested around each variant. Assigned in onTraversalStart() so that they reflect
    // the parsed --window-size value; a field initializer would only ever see the default.
    private int windowEnd;
    private int windowStart;

    // True while a batch has been handed to Python and its completion has not been awaited yet.
    private boolean waitforBatchCompletion = false;

    private File scoreFile;
    private String scoreKey;
    private Scanner scoreScan;
    private VariantContextWriter vcfWriter;
    private String annotationSetString;

    /**
     * Adjusts the batch sizes for the selected tensor type and validates argument combinations.
     *
     * @return an array with a single error message when the arguments are inconsistent, otherwise {@code null}
     */
    @Override
    protected String[] customCommandLineValidation() {
        final boolean isReadTensor = tensorType.equals(TensorType.read_tensor);
        final boolean isReference = tensorType.equals(TensorType.reference);

        if (isReadTensor) {
            transferBatchSize = Math.max(transferBatchSize, MAX_BATCH_SIZE_2D);
            inferenceBatchSize = Math.max(inferenceBatchSize, MAX_BATCH_SIZE_2D);
        } else if (isReference) {
            transferBatchSize = Math.max(transferBatchSize, MAX_BATCH_SIZE_1D);
            inferenceBatchSize = Math.max(inferenceBatchSize, MAX_BATCH_SIZE_1D);
        }

        if (inferenceBatchSize > transferBatchSize) {
            return new String[]{"Inference batch size must be less than or equal to transfer batch size."};
        }

        // A bundled default model is only needed when neither file was supplied. When exactly one of them is
        // given, initializePythonArgsAndModel() derives the other path from it, so that case must not be
        // rejected here with a "no default architecture" message.
        if (architecture == null && weights == null && !isReadTensor && !isReference) {
            return new String[]{"No default architecture for tensor type:" + tensorType.name()};
        }
        return null;
    }

    /** This tool always needs a reference: the reference window is part of every record sent to Python. */
    @Override
    public boolean requiresReference() {
        return true;
    }

    /** Filters out SV/symbolic variants when {@code --filter-symbolic-and-sv} is set, otherwise keeps everything. */
    @Override
    protected CountingVariantFilter makeVariantFilter() {
        return new CountingVariantFilter(filterSymbolicAndSV ? VariantFilterLibrary.NOT_SV_OR_SYMBOLIC : VariantFilterLibrary.ALLOW_ALL_VARIANTS);
    }

    /**
     * @return the inherited default read filters plus a read-group blacklist filter for the
     *         {@code ArtificialHaplotypeRG} and {@code ArtificialHaplotype} read group IDs
     */
    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        final List<ReadFilter> readFilters = new ArrayList<>(super.getDefaultReadFilters());
        final List<String> filterList = new ArrayList<>();
        filterList.add("ID:ArtificialHaplotypeRG");
        filterList.add("ID:ArtificialHaplotype");
        readFilters.add(new ReadGroupBlackListReadFilter(filterList, null));
        return readFilters;
    }

    /**
     * Checks hardware support, starts the Python process, opens the temporary score file on the Python side and
     * loads the model.
     *
     * @throws UserException.HardwareFeatureException if AVX is required but not supported
     * @throws GATKException if the temporary score file cannot be created
     */
    @Override
    public void onTraversalStart() {
        if (!disableAVXCheck) {
            final IntelGKLUtils utils = new IntelGKLUtils();
            utils.load(null);
            if (!utils.isAvxSupported()) {
                throw new UserException.HardwareFeatureException(String.format(AVXREQUIRED_ERROR, DISABLE_AVX_CHECK_NAME));
            }
        }

        final VCFHeader inputHeader = getHeaderForVariants();
        if (inputHeader.getGenotypeSamples().size() > 1) {
            logger.warn("CNNScoreVariants is a single sample tool but the input VCF has more than 1 sample.");
        }
        if (!annotationKeys.equals(defaultAnnotationKeys)) {
            logger.warn("Annotation keys are not the default you must also provide a trained model that expects these annotations.");
        }

        // Arguments are parsed by now, so the reference window can follow the requested window size.
        windowStart = windowSize / 2;
        windowEnd = windowSize / 2;

        pythonExecutor.start(Collections.emptyList(), enableJournal, pythonProfileResults);
        pythonExecutor.initStreamWriter(AsynchronousStreamWriter.stringSerializer);
        batchList = new ArrayList<>(transferBatchSize);

        try {
            scoreFile = File.createTempFile(makeTempFilePrefix(), ".temp");
            if (keepTempFile) {
                logger.info("Saving temp file from python:" + scoreFile.getAbsolutePath());
            } else {
                scoreFile.deleteOnExit();
            }
            pythonExecutor.sendSynchronousCommand(String.format("tempFile = open('%s', 'w+')" + NL, scoreFile.getAbsolutePath()));
            pythonExecutor.sendSynchronousCommand("import vqsr_cnn" + NL);
            scoreKey = getScoreKeyAndCheckModelAndReadsHarmony();
            annotationSetString = annotationKeys.stream().collect(Collectors.joining(DATA_VALUE_SEPARATOR));
            initializePythonArgsAndModel();
        } catch (IOException e) {
            throw new GATKException("Error when creating temp file and initializing python executor.", e);
        }
    }

    /**
     * Builds a temp-file prefix from the output file's base name. The base name is used unchanged when it is
     * long enough; otherwise (missing, or shorter than the three characters {@link File#createTempFile} demands)
     * it is padded with a fixed tool-specific prefix.
     */
    private String makeTempFilePrefix() {
        final String baseName = outputFile.getBaseName().orElse("");
        return baseName.length() >= MIN_TEMP_FILE_PREFIX_LENGTH ? baseName : TEMP_FILE_PREFIX_PADDING + baseName;
    }

    /** Serializes one variant into the current batch and sends the batch to Python once it is full. */
    @Override
    public void firstPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
        referenceContext.setWindow(windowStart, windowEnd);
        if (tensorType.isReadsRequired()) {
            transferReadsToPythonViaFifo(variant, readsContext, referenceContext);
        } else {
            transferToPythonViaFifo(variant, referenceContext);
        }
        sendBatchIfReady();
    }

    /**
     * Flushes the last (partial) batch, closes the score file on the Python side, terminates the Python process
     * and prepares the score scanner and the output VCF writer for the second pass.
     *
     * @throws GATKException if the temporary score file cannot be opened for reading
     */
    @Override
    public void afterFirstPass() {
        if (waitforBatchCompletion) {
            pythonExecutor.waitForPreviousBatchCompletion();
        }
        if (curBatchSize > 0) {
            executePythonCommand();
            pythonExecutor.waitForPreviousBatchCompletion();
        }
        pythonExecutor.sendSynchronousCommand("tempFile.close()" + NL);
        pythonExecutor.terminate();

        try {
            scoreScan = new Scanner(scoreFile);
            vcfWriter = createVCFWriter(outputFile);
            scoreScan.useDelimiter("\\n");
            writeVCFHeader(vcfWriter);
        } catch (IOException e) {
            throw new GATKException("Error when trying to create temporary score file scanner.", e);
        }
    }

    /**
     * Reads the next line of the score file, verifies that it describes {@code variant}, and writes the variant
     * (with the score attribute when the line carries one) to the output VCF.
     *
     * @throws GATKException if the score file has no more lines or its next line does not match the variant
     */
    @Override
    protected void secondPassApply(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
        if (!scoreScan.hasNextLine()) {
            throw new GATKException("Score file out of sync with original VCF. Score file has no more lines."
                    + "\n But VCF has:" + variant.toStringWithoutGenotypes());
        }
        final String sv = scoreScan.nextLine();
        final String[] scoredVariant = sv.split("\\t");
        if (!scoreLineMatchesVariant(scoredVariant, variant)) {
            String errorMsg = "Score file out of sync with original VCF. Score file has:" + sv;
            errorMsg = errorMsg + "\n But VCF has:" + variant.toStringWithoutGenotypes();
            throw new GATKException(errorMsg);
        }

        final VariantContextBuilder builder = new VariantContextBuilder(variant);
        if (scoredVariant.length > KEY_INDEX) {
            builder.attribute(scoreKey, scoredVariant[KEY_INDEX]);
        }
        vcfWriter.add(builder.make());
    }

    /**
     * @return true when the score line has at least contig, position, ref and alt fields and all four equal the
     *         corresponding values of {@code variant} as they were serialized by {@link #getVariantDataString}
     */
    private static boolean scoreLineMatchesVariant(final String[] scoredVariant, final VariantContext variant) {
        return scoredVariant.length > ALT_INDEX
                && variant.getContig().equals(scoredVariant[CONTIG_INDEX])
                && Integer.toString(variant.getStart()).equals(scoredVariant[POS_INDEX])
                && variant.getReference().getBaseString().equals(scoredVariant[REF_INDEX])
                && variant.getAlternateAlleles().toString().equals(scoredVariant[ALT_INDEX]);
    }

    /** Closes the output VCF writer and the score file scanner if they were opened. */
    @Override
    public void closeTool() {
        logger.info("Done scoring variants with CNN.");
        if (vcfWriter != null) {
            vcfWriter.close();
        }
        if (scoreScan != null) {
            scoreScan.close();
        }
    }

    /** Adds a reads-free (1D) record for {@code variant} to the current batch. */
    private void transferToPythonViaFifo(VariantContext variant, ReferenceContext referenceContext) {
        batchList.add(getVariantRecordPrefix(variant, referenceContext) + "\n");
        ++curBatchSize;
    }

    /**
     * Sends the current batch to Python once it holds {@code transferBatchSize} records. Before sending, waits
     * for the previously sent batch (if any) to complete; the newly sent batch is left running.
     */
    private void sendBatchIfReady() {
        if (curBatchSize != transferBatchSize) {
            return;
        }
        if (waitforBatchCompletion) {
            pythonExecutor.waitForPreviousBatchCompletion();
            waitforBatchCompletion = false;
        }
        executePythonCommand();
        waitforBatchCompletion = true;
        curBatchSize = 0;
        batchList = new ArrayList<>(transferBatchSize);
    }

    /**
     * Adds a (2D) record for {@code variant} to the current batch: the common variant fields followed by the
     * serialized reads, downsampled to at most {@code readLimit} via a reservoir downsampler.
     */
    private void transferReadsToPythonViaFifo(VariantContext variant, ReadsContext readsContext, ReferenceContext referenceContext) {
        final StringBuilder sb = new StringBuilder(FIFO_STRING_INITIAL_CAPACITY);
        sb.append(getVariantRecordPrefix(variant, referenceContext)).append(DATA_TYPE_SEPARATOR);

        final ReadsDownsamplingIterator readIt = new ReadsDownsamplingIterator(readsContext.iterator(), new ReservoirDownsampler(readLimit));
        if (!readIt.hasNext()) {
            logger.warn("No reads at contig:" + variant.getContig() + " site:" + variant.getStart());
        }
        while (readIt.hasNext()) {
            sb.append(GATKReadToString(readIt.next()));
        }
        sb.append(NL);
        batchList.add(sb.toString());
        ++curBatchSize;
    }

    /**
     * @return the fields shared by 1D and 2D records, tab separated and without a trailing separator:
     *         variant data, reference window, info annotations and variant type (SNP, INDEL or OTHER)
     */
    private String getVariantRecordPrefix(final VariantContext variant, final ReferenceContext referenceContext) {
        // copyOfRange pads with zero bytes when fewer than windowSize bases are available.
        final String referenceWindow = new String(
                Arrays.copyOfRange(referenceContext.getBases(), 0, windowSize), StandardCharsets.UTF_8);
        final String variantType = variant.isSNP() ? "SNP" : (variant.isIndel() ? "INDEL" : "OTHER");
        return String.join(DATA_TYPE_SEPARATOR,
                getVariantDataString(variant), referenceWindow, getVariantInfoString(variant), variantType);
    }

    /**
     * @return the read as tab-terminated fields in this order: bases, comma-separated base qualities, cigar,
     *         reverse-strand flag, mate-reverse-strand flag ("false" for unpaired reads), first-of-pair flag,
     *         mapping quality, unclipped start
     */
    private String GATKReadToString(GATKRead read) {
        final StringBuilder sb = new StringBuilder(FIFO_STRING_INITIAL_CAPACITY);
        sb.append(read.getBasesString()).append(DATA_TYPE_SEPARATOR);
        appendQualityBytes(sb, read.getBaseQualities());
        sb.append(read.getCigar().toString()).append(DATA_TYPE_SEPARATOR);
        sb.append(read.isReverseStrand()).append(DATA_TYPE_SEPARATOR);
        // The mate strand is only queried for paired reads.
        sb.append(read.isPaired() && read.mateIsReverseStrand()).append(DATA_TYPE_SEPARATOR);
        sb.append(read.isFirstOfPair()).append(DATA_TYPE_SEPARATOR);
        sb.append(read.getMappingQuality()).append(DATA_TYPE_SEPARATOR);
        sb.append(read.getUnclippedStart()).append(DATA_TYPE_SEPARATOR);
        return sb.toString();
    }

    /** Appends the qualities as comma-separated decimal values followed by one tab (just the tab when empty). */
    private void appendQualityBytes(StringBuilder sb, byte[] qualities) {
        for (int i = 0; i < qualities.length; ++i) {
            if (i > 0) {
                sb.append(DATA_VALUE_SEPARATOR);
            }
            sb.append((int) qualities[i]);
        }
        sb.append(DATA_TYPE_SEPARATOR);
    }

    /**
     * @return contig, start, reference bases and alternate allele list, tab separated. The start is rendered
     *         with {@link Integer#toString(int)}, the same rendering secondPassApply compares against.
     */
    private String getVariantDataString(VariantContext variant) {
        return String.join(DATA_TYPE_SEPARATOR,
                variant.getContig(),
                Integer.toString(variant.getStart()),
                variant.getReference().getBaseString(),
                variant.getAlternateAlleles().toString());
    }

    /** @return {@code key=value;} for every configured annotation key that is present on the variant */
    private String getVariantInfoString(VariantContext variant) {
        final StringBuilder sb = new StringBuilder(FIFO_STRING_INITIAL_CAPACITY);
        for (final String attributeKey : annotationKeys) {
            if (variant.hasAttribute(attributeKey)) {
                sb.append(attributeKey)
                  .append(ANNOTATION_SET_STRING)
                  .append(variant.getAttributeAsString(attributeKey, "0"))
                  .append(ANNOTATION_SEPARATOR);
            }
        }
        return sb.toString();
    }

    /** Starts an asynchronous batch write of {@code batchList} together with the Python scoring command. */
    private void executePythonCommand() {
        final String pythonCommand = String.format(
                "vqsr_cnn.score_and_write_batch(model, tempFile, %d, %d, '%s', '%s', %d, %d, '%s')",
                curBatchSize, inferenceBatchSize, tensorType, annotationSetString, windowSize, readLimit, outputTensorsDir) + NL;
        pythonExecutor.startBatchWrite(pythonCommand, batchList);
    }

    /**
     * Writes the output header: the input header lines, the INFO line for the score key, the default tool
     * header lines, and the input's genotype samples in sorted order.
     */
    private void writeVCFHeader(VariantContextWriter vcfWriter) {
        final VCFHeader inputHeader = getHeaderForVariants();
        final Set<VCFHeaderLine> hInfo = new HashSet<>(inputHeader.getMetaDataInSortedOrder());
        hInfo.add(GATKVCFHeaderLines.getInfoLine(scoreKey));
        hInfo.addAll(getDefaultToolVCFHeaderLines());
        final TreeSet<String> samples = new TreeSet<>(inputHeader.getGenotypeSamples());
        vcfWriter.writeHeader(new VCFHeader(hInfo, samples));
    }

    /**
     * @return {@code CNN_2D} when the tensor type needs reads and reads were provided, {@code CNN_1D} when the
     *         tensor type does not need reads (a warning is logged if reads were provided anyway)
     * @throws GATKException if the tensor type needs reads but none were provided
     */
    private String getScoreKeyAndCheckModelAndReadsHarmony() {
        if (tensorType.isReadsRequired()) {
            if (!hasReads()) {
                throw new GATKException("2D Models require a SAM/BAM file specified via -I (-input) argument.");
            }
            return "CNN_2D";
        }
        if (hasReads()) {
            logger.warn(String.format("Reads are available, but tensor type %s does not use them.", tensorType.name()));
        }
        return "CNN_1D";
    }

    /**
     * Resolves the architecture and weights files and asks Python to start a session and load the model.
     *
     * <p>When neither file was given, the bundled resources for the tensor type are written to temp files.
     * When only one was given, the other path is derived by swapping the {@code .json}/{@code .hd5} extension.
     *
     * @throws GATKException if neither file was given and the tensor type has no bundled default
     */
    private void initializePythonArgsAndModel() {
        if (architecture == null && weights == null) {
            final String defaultArchitectureResource;
            if (tensorType.equals(TensorType.read_tensor)) {
                defaultArchitectureResource = resourcePathReadTensor;
            } else if (tensorType.equals(TensorType.reference)) {
                defaultArchitectureResource = resourcePathReferenceTensor;
            } else {
                throw new GATKException("No default architecture for tensor type:" + tensorType.name());
            }
            architecture = IOUtils.writeTempResourceFromPath(defaultArchitectureResource, null).getAbsolutePath();
            weights = IOUtils.writeTempResourceFromPath(defaultArchitectureResource.replace(".json", ".hd5"), null).getAbsolutePath();
        } else if (weights == null) {
            weights = architecture.replace(".json", ".hd5");
        } else if (architecture == null) {
            architecture = weights.replace(".hd5", ".json");
        }

        final String getArgsAndModel = String.format(
                "args, model = vqsr_cnn.start_session_get_args_and_model(%d, %d, '%s', weights_hd5='%s')",
                intraOpThreads, interOpThreads, architecture, weights) + NL;
        logger.info("Using key:" + scoreKey + " for CNN architecture:" + architecture + " and weights:" + weights);
        pythonExecutor.sendSynchronousCommand(getArgsAndModel);
    }
}

