/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.interop.tensorflow.DenseFeatureConverter;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.GradientOptimiser;
import org.tribuo.interop.tensorflow.ImageConverter;
import org.tribuo.interop.tensorflow.LabelConverter;
import org.tribuo.interop.tensorflow.TensorFlowCheckpointModel;
import org.tribuo.interop.tensorflow.TensorFlowNativeModel;
import org.tribuo.interop.tensorflow.TensorFlowTrainer;
import org.tribuo.util.Util;

/**
 * Command line entry point which trains a TensorFlow classification model on a
 * LibSVM format training file, evaluates it on a LibSVM format testing file,
 * prints the evaluation, and optionally serializes the trained model.
 */
public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    /**
     * Loads the training and testing datasets from LibSVM format files.
     * <p>
     * The testing source reuses the zero-indexing flag and maximum feature id
     * discovered while reading the training file, and the testing dataset is
     * built against the training dataset's feature and output id maps.
     *
     * @param trainingPath  Path to the training file.
     * @param testingPath   Path to the testing file.
     * @param outputFactory Factory used to construct the labels.
     * @return A pair of (training dataset, testing dataset).
     * @throws IOException If either file could not be read.
     */
    private static Pair<Dataset<Label>, Dataset<Label>> load(Path trainingPath, Path testingPath, OutputFactory<Label> outputFactory) throws IOException {
        logger.info(String.format("Loading data from %s", trainingPath));
        LibSVMDataSource<Label> trainSource = new LibSVMDataSource<>(trainingPath, outputFactory);
        MutableDataset<Label> train = new MutableDataset<>(trainSource);
        boolean zeroIndexed = trainSource.isZeroIndexed();
        int maxFeatureID = trainSource.getMaxFeatureID();
        logger.info(String.format("Loaded %d training examples for %s", train.size(), train.getOutputs().toString()));
        logger.info("Found " + train.getFeatureIDMap().size() + " features");

        // Static types are kept at DataSource / FeatureMap / OutputInfo so the same
        // ImmutableDataset constructor is selected as with the original explicit casts.
        DataSource<Label> testSource = new LibSVMDataSource<>(testingPath, outputFactory, zeroIndexed, maxFeatureID);
        FeatureMap featureMap = train.getFeatureIDMap();
        OutputInfo<Label> outputInfo = train.getOutputIDInfo();
        // NOTE(review): the trailing 'false' is forwarded unchanged; its meaning is
        // defined by the ImmutableDataset constructor - confirm against its Javadoc.
        ImmutableDataset<Label> test = new ImmutableDataset<>(testSource, featureMap, outputInfo, false);
        logger.info(String.format("Loaded %d testing examples", test.size()));
        return new Pair<Dataset<Label>, Dataset<Label>>(train, test);
    }

    /**
     * Parses an image format string of the form {@code "W,H,C"}.
     * <p>
     * Whitespace around each field is ignored.
     *
     * @param imageFormat The comma separated width, height and channel count.
     * @return A three element array of {width, height, channels}.
     * @throws IllegalArgumentException If the string is null, does not have exactly
     *                                  three fields, or a field is not an integer.
     */
    private static int[] parseImageFormat(String imageFormat) {
        if (imageFormat == null) {
            throw new IllegalArgumentException("Image format is null");
        }
        String[] fields = imageFormat.split(",");
        if (fields.length != 3) {
            throw new IllegalArgumentException("Expected 3 comma separated fields, found " + fields.length);
        }
        int[] dims = new int[fields.length];
        for (int i = 0; i < fields.length; i++) {
            // NumberFormatException is an IllegalArgumentException, so callers need a single catch.
            dims[i] = Integer.parseInt(fields[i].trim());
        }
        return dims;
    }

    /**
     * Runs a train/test experiment using the supplied command line arguments.
     *
     * @param args The command line arguments, see {@link TensorflowOptions}.
     * @throws IOException If the data could not be loaded, or the model could not be trained or saved.
     */
    public static void main(String[] args) throws IOException {
        LabsLogFormatter.setAllLogFormatters();

        TensorflowOptions o = new TensorflowOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }

        if (o.trainingPath == null || o.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        // Validate the cheap arguments before paying for the data load.
        if (o.inputName == null || o.inputName.isEmpty() || o.outputName == null || o.outputName.isEmpty()) {
            throw new IllegalArgumentException("Must specify both 'input-name' and 'output-name'");
        }

        FeatureConverter inputConverter;
        switch (o.inputType) {
            case IMAGE: {
                int[] dims;
                try {
                    dims = parseImageFormat(o.imageFormat);
                } catch (IllegalArgumentException e) {
                    // Covers a wrong field count as well as non-numeric fields.
                    logger.info(cm.usage());
                    logger.info("Invalid image format specified. Found " + o.imageFormat);
                    return;
                }
                inputConverter = new ImageConverter(o.inputName, dims[0], dims[1], dims[2]);
                break;
            }
            case DENSE: {
                inputConverter = new DenseFeatureConverter(o.inputName);
                break;
            }
            default: {
                logger.info(cm.usage());
                logger.info("Unknown input type. Found " + o.inputType);
                return;
            }
        }

        Pair<Dataset<Label>, Dataset<Label>> data = load(o.trainingPath, o.testingPath, new LabelFactory());
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();

        LabelConverter labelConverter = new LabelConverter();
        TensorFlowTrainer<Label> trainer;
        if (o.checkpointPath == null) {
            logger.info("Using TensorflowTrainer");
            trainer = new TensorFlowTrainer<Label>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval);
        } else {
            logger.info("Using TensorflowCheckpointTrainer, writing to path " + o.checkpointPath);
            trainer = new TensorFlowTrainer<Label>(o.protobufPath, o.outputName, o.optimiser, o.getGradientParams(), inputConverter, labelConverter, o.batchSize, o.epochs, o.testBatchSize, o.loggingInterval, o.checkpointPath);
        }
        logger.info("Training using " + trainer.toString());

        long trainStart = System.currentTimeMillis();
        Model<Label> model = trainer.train(train);
        try {
            long trainStop = System.currentTimeMillis();
            logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

            long testStart = System.currentTimeMillis();
            LabelEvaluator evaluator = new LabelEvaluator();
            LabelEvaluation evaluation = evaluator.evaluate(model, test);
            long testStop = System.currentTimeMillis();
            logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));

            if (model.generatesProbabilities()) {
                logger.info("Average AUC = " + evaluation.averageAUCROC(false));
                logger.info("Average weighted AUC = " + evaluation.averageAUCROC(true));
            }
            System.out.println(evaluation.toString());
            System.out.println(evaluation.getConfusionMatrix().toString());

            if (o.outputPath != null) {
                if (o.saveToProto) {
                    model.serializeToFile(o.outputPath);
                } else {
                    try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(o.outputPath))) {
                        oos.writeObject(model);
                    }
                }
                logger.info("Serialized model to file: " + o.outputPath);
            }
        } finally {
            // Close the model even when evaluation or serialization fails, so it is
            // not leaked on the error path. The concrete model type follows from
            // which trainer constructor was used above.
            if (o.checkpointPath == null) {
                ((TensorFlowNativeModel<?>) model).close();
            } else {
                ((TensorFlowCheckpointModel<?>) model).close();
            }
        }
    }

    /**
     * Command line options for {@link TrainTest}.
     */
    public static class TensorflowOptions
    implements Options {
        /** Default optimiser parameter names; never handed out directly, always copied. */
        private static final List<String> DEFAULT_PARAM_NAMES = new ArrayList<>();
        /** Default optimiser parameter values, index-aligned with {@link #DEFAULT_PARAM_NAMES}. */
        private static final List<Float> DEFAULT_PARAM_VALUES = new ArrayList<>();

        static {
            DEFAULT_PARAM_NAMES.add("learningRate");
            DEFAULT_PARAM_NAMES.add("initialAccumulatorValue");
            DEFAULT_PARAM_VALUES.add(0.01f);
            DEFAULT_PARAM_VALUES.add(0.1f);
        }

        /** Where to write the trained model; nothing is written when null. */
        @Option(charName='f', longName="model-output-path", usage="Path to serialize model to.")
        public Path outputPath;
        /** If true the model is written via {@code serializeToFile}, otherwise via Java serialization. */
        @Option(longName="model-save-to-proto", usage="Save the Tribuo model out as a protobuf.")
        public boolean saveToProto;
        /** Training data file, required. */
        @Option(charName='u', longName="training-file", usage="Path to the libsvm format training file.")
        public Path trainingPath;
        /** Testing data file, required. */
        @Option(charName='v', longName="testing-file", usage="Path to the libsvm format testing file.")
        public Path testingPath;
        /** Name of the output operation, required. */
        @Option(charName='l', longName="output-name", usage="Name of the output operation.")
        public String outputName;
        /** Optimiser parameter names; must be the same length as {@link #gradientParamValues}. */
        @Option(longName="optimizer-param-names", usage="Gradient optimizer param names, see org.tribuo.interop.tensorflow.GradientOptimiser.")
        public List<String> gradientParamNames = new ArrayList<>(DEFAULT_PARAM_NAMES);
        /** Optimiser parameter values; must be the same length as {@link #gradientParamNames}. */
        @Option(longName="optimizer-param-values", usage="Gradient optimizer param values, see org.tribuo.interop.tensorflow.GradientOptimiser.")
        public List<Float> gradientParamValues = new ArrayList<>(DEFAULT_PARAM_VALUES);
        /** The gradient optimiser, defaults to ADAGRAD. */
        @Option(charName='g', longName="gradient-optimizer", usage="The gradient optimizer to use.")
        public GradientOptimiser optimiser = GradientOptimiser.ADAGRAD;
        /** Minibatch size used at test time, defaults to 16. */
        @Option(longName="test-batch-size", usage="Test time minibatch size.")
        public int testBatchSize = 16;
        /** Minibatch size used for training, defaults to 128. */
        @Option(charName='b', longName="batch-size", usage="Minibatch size.")
        public int batchSize = 128;
        /** Number of training epochs, defaults to 5. */
        @Option(charName='e', longName="num-epochs", usage="Number of gradient descent epochs.")
        public int epochs = 5;
        /** Interval between loss log messages, defaults to 1000. */
        @Option(longName="logging-interval", usage="Interval between logging the loss.")
        public int loggingInterval = 1000;
        /** Name of the input placeholder, required. */
        @Option(charName='n', longName="input-name", usage="Name of the input placeholder.")
        public String inputName;
        /** Comma separated width,height,channels; only read when the input type is IMAGE. */
        @Option(longName="image-format", usage="Image format, in [W,H,C]. Defaults to MNIST.")
        public String imageFormat = "28,28,1";
        /** Which feature converter to build, defaults to IMAGE. */
        @Option(charName='t', longName="input-type", usage="Input type.")
        public InputType inputType = InputType.IMAGE;
        /** Path to the protobuf holding the network description; passed to the trainer. */
        @Option(charName='m', longName="model-protobuf", usage="Path to the protobuf containing the network description.")
        public Path protobufPath;
        /** Checkpoint directory; when null the trainer is built without a checkpoint path. */
        @Option(charName='p', longName="checkpoint-dir", usage="Path to the checkpoint base directory.")
        public Path checkpointPath;

        @Override
        public String getOptionsDescription() {
            return "Trains and tests a Tensorflow classification model.";
        }

        /**
         * Zips the gradient parameter names and values into a map.
         *
         * @return A fresh map from parameter name to parameter value.
         * @throws IllegalArgumentException If the name and value lists differ in length.
         */
        public Map<String, Float> getGradientParams() {
            if (this.gradientParamNames.size() != this.gradientParamValues.size()) {
                throw new IllegalArgumentException("Must supply both name and value for the gradient parameters, found " + this.gradientParamNames.size() + " names, and " + this.gradientParamValues.size() + " values.");
            }
            Map<String, Float> output = new HashMap<>();
            for (int i = 0; i < this.gradientParamNames.size(); ++i) {
                output.put(this.gradientParamNames.get(i), this.gradientParamValues.get(i));
            }
            return output;
        }
    }

    /**
     * The kind of input the model consumes, which selects the feature converter.
     */
    public static enum InputType {
        /** Features are fed through a {@code DenseFeatureConverter}. */
        DENSE,
        /** Features are fed through an {@code ImageConverter} shaped by the image format option. */
        IMAGE;

    }
}

