/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.util.Util;

/**
 * Command line driver which loads a {@link Trainer} (and the data options) from a configuration
 * file, trains a classification {@link Model} on the training data, evaluates it on the test data,
 * and optionally writes the per-example prediction scores and the trained model to disk.
 */
public class ConfigurableTrainTest {
    private static final Logger logger = Logger.getLogger(ConfigurableTrainTest.class.getName());

    /**
     * Converts a list of {@code LABEL_NAME:weight} strings into a map from label to weight.
     * <p>
     * Each entry is split on its <em>last</em> {@code ':'}, so label names which themselves
     * contain a colon are still parsed correctly.
     *
     * @param input The weight specifications, may be {@code null}.
     * @return A mutable map from label to weight; empty when {@code input} is {@code null}.
     * @throws IllegalArgumentException If an entry has no {@code ':'}, an empty label name,
     *                                  an empty weight, or a weight which is not a valid float.
     */
    public static Map<Label, Float> processWeights(List<String> input) {
        Map<Label, Float> map = new HashMap<>();
        if (input == null) {
            return map;
        }
        for (String tuple : input) {
            int separator = tuple.lastIndexOf(':');
            // separator <= 0 covers "no colon" (-1) and "empty label" (0).
            if (separator <= 0 || separator == tuple.length() - 1) {
                throw new IllegalArgumentException(
                        "Invalid weight specification '" + tuple + "', expected LABEL_NAME:weight");
            }
            String labelName = tuple.substring(0, separator);
            float weight;
            try {
                weight = Float.parseFloat(tuple.substring(separator + 1));
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException(
                        "Invalid weight in '" + tuple + "', expected LABEL_NAME:weight", e);
            }
            map.put(new Label(labelName), weight);
        }
        return map;
    }

    /**
     * Entry point: parses the options, loads the data, trains and evaluates the configured
     * trainer, prints the evaluation, and optionally writes predictions and the model.
     * <p>
     * Exits the JVM with status 1 when required options are missing, the data cannot be loaded,
     * or the requested weighting cannot be applied.
     *
     * @param args Command line arguments, parsed by OLCUT's {@link ConfigurationManager} into
     *             a {@link ConfigurableTrainTestOptions} instance.
     */
    public static void main(String[] args) {
        LabsLogFormatter.setAllLogFormatters();
        ConfigurableTrainTestOptions o = new ConfigurableTrainTestOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            // Report the parse/usage message and stop without running anything.
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        // Check the trainer before loading data so a misconfiguration fails fast.
        if (o.trainer == null) {
            logger.warning("No trainer supplied");
            logger.info(cm.usage());
            System.exit(1);
        }

        Pair<Dataset<Label>, Dataset<Label>> data = null;
        try {
            data = o.general.load(new LabelFactory());
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
        }
        Dataset<Label> train = data.getA();
        Dataset<Label> test = data.getB();

        logger.info("Trainer is " + o.trainer.toString());
        if (o.weights != null) {
            applyWeights(o, train, cm);
        }
        logger.info("Labels are " + train.getOutputInfo().toReadableString());

        long trainStart = System.currentTimeMillis();
        Model<Label> model = o.trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training classifier " + Util.formatDuration(trainStart, trainStop));

        LabelEvaluator labelEvaluator = new LabelEvaluator();
        long testStart = System.currentTimeMillis();
        List<Prediction<Label>> predictions = model.predict(test);
        LabelEvaluation labelEvaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));

        System.out.println(labelEvaluation.toString());
        ConfusionMatrix<Label> matrix = labelEvaluation.getConfusionMatrix();
        System.out.println(matrix.toString());
        if (model.generatesProbabilities()) {
            System.out.println("Average AUC = " + labelEvaluation.averageAUCROC(false));
            System.out.println("Average weighted AUC = " + labelEvaluation.averageAUCROC(true));
        }

        if (o.predictionPath != null) {
            writePredictions(o.predictionPath, model, predictions);
        }
        if (o.general.outputPath != null) {
            try {
                o.general.saveModel(model);
            } catch (IOException e) {
                logger.log(Level.SEVERE, "Error writing model", e);
            }
        }
    }

    /**
     * Applies {@code o.weights} as label weights on the trainer if it is {@link WeightedLabels}.
     * Otherwise, if it is {@link WeightedExamples}, applies them as example weights on the
     * training data. Exits with status 1 if the weights are malformed or cannot be applied.
     */
    private static void applyWeights(ConfigurableTrainTestOptions o, Dataset<Label> train, ConfigurationManager cm) {
        Map<Label, Float> weightsMap;
        try {
            weightsMap = processWeights(o.weights);
        } catch (IllegalArgumentException e) {
            logger.log(Level.SEVERE, "Failed to parse weights", e);
            logger.info(cm.usage());
            System.exit(1);
            return; // unreachable at runtime; keeps weightsMap definitely assigned for the compiler
        }
        if (o.trainer instanceof WeightedLabels) {
            ((WeightedLabels) o.trainer).setLabelWeights(weightsMap);
            logger.info("Setting label weights using " + weightsMap.toString());
        } else if (o.trainer instanceof WeightedExamples) {
            // Example weights are stored on the dataset, which must be mutable to accept them.
            if (!(train instanceof MutableDataset)) {
                logger.warning("Example weights require a MutableDataset, but the training data is a "
                        + train.getClass().getName());
                logger.info(cm.usage());
                System.exit(1);
                return;
            }
            ((MutableDataset<Label>) train).setWeights(weightsMap);
            logger.info("Setting example weights using " + weightsMap.toString());
        } else {
            logger.warning("The selected trainer does not support weighted training. The chosen trainer is " + o.trainer.toString());
            logger.info(cm.usage());
            System.exit(1);
        }
    }

    /**
     * Writes one CSV row per prediction: the example's ground-truth label followed by the
     * predicted score for each label in the model's domain (columns sorted by label name).
     * A label with no score in a prediction is written as an empty field instead of failing.
     * I/O errors are logged and do not abort the program.
     */
    private static void writePredictions(Path path, Model<Label> model, List<Prediction<Label>> predictions) {
        try (BufferedWriter wrt = Files.newBufferedWriter(path)) {
            List<String> labels = model.getOutputIDInfo().getDomain().stream()
                    .map(Label::getLabel)
                    .sorted()
                    .collect(Collectors.toList());
            wrt.write("Label,");
            wrt.write(String.join(",", labels));
            wrt.newLine();
            for (Prediction<Label> pred : predictions) {
                Map<String, Label> scores = pred.getOutputScores();
                Example<Label> ex = pred.getExample();
                wrt.write(ex.getOutput().getLabel() + ",");
                wrt.write(labels.stream()
                        .map(l -> {
                            Label scored = scores.get(l);
                            return scored == null ? "" : Double.toString(scored.getScore());
                        })
                        .collect(Collectors.joining(",")));
                wrt.newLine();
            }
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }

    /**
     * Command line options for {@link ConfigurableTrainTest}.
     */
    public static class ConfigurableTrainTestOptions implements Options {
        /** Data loading options (training/testing paths) and model output options. */
        public DataOptions general;
        @Option(charName = 't', longName = "trainer", usage = "Load a trainer from the config file.")
        public Trainer<Label> trainer;
        @Option(charName = 'w', longName = "weights", usage = "A list of weights to use in classification. Format = LABEL_NAME:weight,LABEL_NAME:weight...")
        public List<String> weights;
        @Option(charName = 'o', longName = "predictions", usage = "Path to write model predictions")
        public Path predictionPath;

        @Override
        public String getOptionsDescription() {
            return "Loads a Trainer (and optionally a Datasource) from a config file, trains a Model, tests it and optionally saves it to disk.";
        }
    }
}

