/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;

public class RunAll {
    private static final Logger logger = Logger.getLogger(RunAll.class.getName());

    public static void main(String[] args) throws IOException {
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        RunAllOptions o = new RunAllOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null || o.directory == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        Pair data = null;
        try {
            data = o.general.load((OutputFactory)new LabelFactory());
        }
        catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load data", e);
            System.exit(1);
        }
        Dataset train = (Dataset)data.getA();
        Dataset test = (Dataset)data.getB();
        logger.info("Creating directory - " + o.directory.toString());
        if (!o.directory.exists() && !o.directory.mkdirs()) {
            logger.warning("Failed to create directory.");
        }
        HashMap<String, Double> performances = new HashMap<String, Double>();
        List trainers = cm.lookupAll(Trainer.class);
        for (Trainer trainer : trainers) {
            String name = trainer.getClass().getSimpleName();
            logger.info("Training model using " + trainer.toString());
            Model curModel = trainer.train(train);
            LabelEvaluator evaluator = new LabelEvaluator();
            LabelEvaluation evaluation = (LabelEvaluation)evaluator.evaluate(curModel, test);
            Double old = performances.put(name, evaluation.microAveragedF1());
            if (old != null) {
                logger.info("Found two trainers with the name " + name);
            }
            String outputPath = o.directory.toString() + "/" + name;
            if (o.protobuf) {
                curModel.serializeToFile(Paths.get(outputPath + ".model", new String[0]));
            } else {
                try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputPath + ".model"));){
                    oos.writeObject(curModel);
                }
            }
            try (PrintWriter writer = new PrintWriter(new OutputStreamWriter((OutputStream)new FileOutputStream(outputPath + ".output"), StandardCharsets.UTF_8));){
                writer.println("Model = " + name);
                writer.println("Provenance = " + curModel.toString());
                writer.println();
                ConfusionMatrix matrix = evaluation.getConfusionMatrix();
                writer.println("ConfusionMatrix:\n" + matrix.toString());
                writer.println();
                writer.println("Evaluation:\n" + evaluation.toString());
            }
        }
        for (Map.Entry entry : performances.entrySet()) {
            logger.info("Trainer = " + (String)entry.getKey() + ", F1 = " + entry.getValue());
        }
    }

    public static class RunAllOptions
    implements Options {
        public DataOptions general;
        @Option(charName=100, longName="output-directory", usage="Directory to write out the models and test reports.")
        public File directory;
        @Option(longName="write-protobuf-models", usage="Write out models in protobuf format.")
        public boolean protobuf;

        public String getOptionsDescription() {
            return "Performs the same training and test experiment on all Trainers in the supplied configuration file.";
        }
    }
}

