/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.argparser.ExperimentalFeature;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.walkers.vqsr.TensorType;
import org.broadinstitute.hellbender.utils.io.Resource;
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

@CommandLineProgramProperties(summary="Train a CNN model for filtering variants", oneLineSummary="Train a CNN model for filtering variants", programGroup=VariantFilteringProgramGroup.class)
@DocumentedFeature
@ExperimentalFeature
/**
 * Trains a CNN variant-filtering model by delegating to the bundled Python script {@code training.py}.
 *
 * <p>This tool does no training itself. It translates its command-line arguments into the
 * script's {@code --snake_case} options and runs the script through {@link PythonScriptExecutor}.
 * On startup it checks that the {@code vqsr_cnn} Python package is available.</p>
 *
 * <p>If neither {@code --conv-layers} nor {@code --fc-layers} is given, a default 1D or 2D model
 * is trained, depending on {@code --tensor-type}. In that mode the architecture-only switches
 * ({@code --conv-batch-normalize}, {@code --fc-batch-normalize}, {@code --spatial-dropout},
 * {@code --annotation-shortcut}) are not forwarded to the script.</p>
 */
public class CNNVariantTrain
extends CommandLineProgram {
    @Argument(fullName="input-tensor-dir", shortName="input-tensor-dir", doc="Directory of training tensors to read.")
    private String inputTensorDir;
    @Argument(fullName="output-dir", shortName="output-dir", doc="Directory where models will be saved, defaults to current working directory.", optional=true)
    private String outputDir = "./";
    @Argument(fullName="tensor-type", shortName="tensor-type", doc="Type of tensors to use as input: reference for 1D reference tensors and read_tensor for 2D tensors.", optional=true)
    private TensorType tensorType = TensorType.reference;
    @Argument(fullName="model-name", shortName="model-name", doc="Name of the model to be trained.", optional=true)
    private String modelName = "variant_filter_model";
    @Argument(fullName="epochs", shortName="epochs", doc="Maximum number of training epochs.", optional=true, minValue=0.0)
    private int epochs = 10;
    @Argument(fullName="training-steps", shortName="training-steps", doc="Number of training steps per epoch.", optional=true, minValue=0.0)
    private int trainingSteps = 10;
    @Argument(fullName="validation-steps", shortName="validation-steps", doc="Number of validation steps per epoch.", optional=true, minValue=0.0)
    private int validationSteps = 2;
    @Argument(fullName="image-dir", shortName="image-dir", doc="Path where plots and figures are saved.", optional=true)
    private String imageDir;
    @Argument(fullName="conv-width", shortName="conv-width", doc="Width of convolution kernels", optional=true)
    private int convWidth = 5;
    @Argument(fullName="conv-height", shortName="conv-height", doc="Height of convolution kernels", optional=true)
    private int convHeight = 5;
    @Argument(fullName="conv-dropout", shortName="conv-dropout", doc="Dropout rate in convolution layers", optional=true)
    private float convDropout = 0.0f;
    @Argument(fullName="conv-batch-normalize", shortName="conv-batch-normalize", doc="Batch normalize convolution layers", optional=true)
    private boolean convBatchNormalize = false;
    @Argument(fullName="conv-layers", shortName="conv-layers", doc="List of number of filters to use in each convolutional layer", optional=true)
    private List<Integer> convLayers = new ArrayList<Integer>();
    @Argument(fullName="padding", shortName="padding", doc="Padding for convolution layers, valid or same", optional=true)
    private String padding = "valid";
    @Argument(fullName="spatial-dropout", shortName="spatial-dropout", doc="Spatial dropout on convolution layers", optional=true)
    private boolean spatialDropout = false;
    @Argument(fullName="fc-layers", shortName="fc-layers", doc="List of number of filters to use in each fully-connected layer", optional=true)
    private List<Integer> fcLayers = new ArrayList<Integer>();
    @Argument(fullName="fc-dropout", shortName="fc-dropout", doc="Dropout rate in fully-connected layers", optional=true)
    private float fcDropout = 0.0f;
    @Argument(fullName="fc-batch-normalize", shortName="fc-batch-normalize", doc="Batch normalize fully-connected layers", optional=true)
    private boolean fcBatchNormalize = false;
    @Argument(fullName="annotation-units", shortName="annotation-units", doc="Number of units connected to the annotation input layer", optional=true)
    private int annotationUnits = 16;
    @Argument(fullName="annotation-shortcut", shortName="annotation-shortcut", doc="Shortcut connections on the annotation layers.", optional=true)
    private boolean annotationShortcut = false;
    @Advanced
    @Argument(fullName="channels-last", shortName="channels-last", doc="Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional=true)
    private boolean channelsLast = true;
    @Advanced
    @Argument(fullName="annotation-set", shortName="annotation-set", doc="Which set of annotations to use.", optional=true)
    private String annotationSet = "best_practices";

    /** Runs {@code training.py}; created once per tool instance. */
    final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true);

    /**
     * Fails fast if the {@code vqsr_cnn} Python package is missing, before any work is attempted.
     */
    @Override
    protected void onStartup() {
        PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn");
    }

    /**
     * Builds the argument list for {@code training.py} and runs it.
     *
     * @return the boolean result of {@link PythonScriptExecutor#executeScript}, boxed as an {@code Object}
     * @throws GATKException if {@code --tensor-type} has no corresponding training mode
     */
    @Override
    protected Object doWork() {
        final Resource pythonScriptResource = new Resource("training.py", CNNVariantTrain.class);
        final List<String> arguments = new ArrayList<>(Arrays.asList(
                "--data_dir", inputTensorDir,
                "--output_dir", outputDir,
                "--tensor_name", tensorType.name(),
                "--annotation_set", annotationSet,
                "--conv_width", Integer.toString(convWidth),
                "--conv_height", Integer.toString(convHeight),
                "--conv_dropout", Float.toString(convDropout),
                "--padding", padding,
                "--fc_dropout", Float.toString(fcDropout),
                "--annotation_units", Integer.toString(annotationUnits),
                "--epochs", Integer.toString(epochs),
                "--training_steps", Integer.toString(trainingSteps),
                "--validation_steps", Integer.toString(validationSteps),
                "--gatk_version", getVersion(),
                "--id", modelName));

        arguments.add(channelsLast ? "--channels_last" : "--channels_first");

        if (imageDir != null) {
            arguments.addAll(Arrays.asList("--image_dir", imageDir));
        }

        // Any explicit layer list switches the script from its default model to a
        // command-line-specified architecture.
        final boolean customArchitecture = !convLayers.isEmpty() || !fcLayers.isEmpty();
        if (customArchitecture) {
            addFlagIf(arguments, convBatchNormalize, "--conv_batch_normalize");
            addFlagIf(arguments, fcBatchNormalize, "--fc_batch_normalize");
            addFlagIf(arguments, spatialDropout, "--spatial_dropout");
            addFlagIf(arguments, annotationShortcut, "--annotation_shortcut");
            addIntListArgument(arguments, "--conv_layers", convLayers);
            addIntListArgument(arguments, "--fc_layers", fcLayers);
        }
        arguments.addAll(Arrays.asList("--mode", trainingMode(customArchitecture)));

        logger.info("Args are:" + Arrays.toString(arguments.toArray()));
        final boolean pythonReturnCode = pythonExecutor.executeScript(pythonScriptResource, null, arguments);
        return pythonReturnCode;
    }

    /**
     * Selects the value passed to the script's {@code --mode} option.
     *
     * @param customArchitecture true when the user supplied {@code --conv-layers} and/or {@code --fc-layers}
     * @return the mode name for the current {@link #tensorType}
     * @throws GATKException if {@link #tensorType} is neither {@code reference} nor {@code read_tensor}
     */
    private String trainingMode(final boolean customArchitecture) {
        if (tensorType == TensorType.reference) {
            return customArchitecture ? "train_args_model_on_reference_and_annotations" : "train_default_1d_model";
        }
        if (tensorType == TensorType.read_tensor) {
            return customArchitecture ? "train_args_model_on_read_tensors_and_annotations" : "train_default_2d_model";
        }
        throw new GATKException("Unknown tensor mapping mode:" + tensorType.name());
    }

    /** Appends the value-less switch {@code flag} to {@code arguments} when {@code enabled} is true. */
    private static void addFlagIf(final List<String> arguments, final boolean enabled, final String flag) {
        if (enabled) {
            arguments.add(flag);
        }
    }

    /**
     * Appends {@code flag} followed by each value in {@code values}. If {@code values} is empty,
     * nothing is appended at all.
     */
    private static void addIntListArgument(final List<String> arguments, final String flag, final List<Integer> values) {
        if (values.isEmpty()) {
            // Emitting the bare flag with no values lets the user set only one of --conv-layers /
            // --fc-layers and still break the script. NOTE(review): this assumes training.py declares
            // these options as one-or-more (argparse nargs='+'), which rejects an empty list.
            // Omitting the flag lets the script's own default apply. Confirm against vqsr_cnn.
            return;
        }
        arguments.add(flag);
        for (final Integer value : values) {
            arguments.add(Integer.toString(value));
        }
    }
}

