/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.experiments;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.data.DataOptions;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.TextPipeline;
import org.tribuo.data.text.impl.SimpleTextDataSource;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.data.text.impl.TokenPipeline;
import org.tribuo.datasource.LibSVMDataSource;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.tokens.Tokenizer;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

/**
 * Command line driver that loads a trained classification model and evaluates it on a test dataset.
 * <p>
 * It prints the {@link LabelEvaluation}, the confusion matrix and, when the model generates
 * probabilities, the average (and weighted average) AUC. When {@code --predictions} is supplied,
 * it also writes per-example prediction scores as CSV.
 */
public class Test {
    private static final Logger logger = Logger.getLogger(Test.class.getName());

    /**
     * Loads the model and the test dataset described by the supplied options.
     * <p>
     * The model is read either as a protobuf ({@code --read-protobuf-model}) or via Java
     * serialization, and is then cast to a {@link Label} model. The test data is read in the
     * format selected by {@code --input-format} and mapped into the model's feature and output
     * domains.
     *
     * @param o The command line options.
     * @return A pair of the loaded model and the test dataset.
     * @throws IOException If the model or the dataset could not be read.
     * @throws IllegalArgumentException If a serialized file references an unknown class, the
     *     dataset is not a {@link Label} dataset, the CSV response name is missing, or the input
     *     format is unsupported.
     */
    public static Pair<Model<Label>, Dataset<Label>> load(ConfigurableTestOptions o) throws IOException {
        Model<Label> model = loadModel(o);
        Dataset<Label> test = loadDataset(o, model);
        logger.info(String.format("Loaded %d testing examples for %s", test.size(), test.getOutputs().toString()));
        return new Pair<>(model, test);
    }

    /**
     * Reads the model at {@code o.modelPath} and casts it to a classification model.
     */
    private static Model<Label> loadModel(ConfigurableTestOptions o) throws IOException {
        Path modelPath = o.modelPath;
        logger.info(String.format("Loading model from %s", modelPath));
        Model<?> tmpModel;
        if (o.protobufModel) {
            tmpModel = Model.deserializeFromFile(modelPath);
        } else {
            // NOTE(review): native Java deserialization can run code embedded in the stream;
            // only load model files from trusted sources (or prefer --read-protobuf-model).
            try (ObjectInputStream mois = new ObjectInputStream(
                    new BufferedInputStream(new FileInputStream(modelPath.toFile())))) {
                tmpModel = (Model<?>) mois.readObject();
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException("Unknown class in serialised model", e);
            }
        }
        return tmpModel.castModel(Label.class);
    }

    /**
     * Reads the test data at {@code o.testingPath} in the configured input format and maps it
     * into {@code model}'s feature and output domains.
     */
    private static Dataset<Label> loadDataset(ConfigurableTestOptions o, Model<Label> model) throws IOException {
        Path datasetPath = o.testingPath;
        logger.info(String.format("Loading data from %s", datasetPath));
        switch (o.inputFormat) {
            case SERIALIZED: {
                logger.info("Deserialising dataset from " + datasetPath);
                Dataset<?> deserTest;
                // NOTE(review): native Java deserialization; only load dataset files from trusted sources.
                try (ObjectInputStream oits = new ObjectInputStream(
                        new BufferedInputStream(new FileInputStream(datasetPath.toFile())))) {
                    deserTest = (Dataset<?>) oits.readObject();
                } catch (ClassNotFoundException e) {
                    throw new IllegalArgumentException("Unknown class in serialised dataset", e);
                }
                return copyIntoModelDomain(deserTest, model);
            }
            case SERIALIZED_PROTOBUF:
                return copyIntoModelDomain(Dataset.deserializeFromFile(datasetPath), model);
            case LIBSVM: {
                int maxFeatureID = model.getFeatureIDMap().size() - 1;
                LibSVMDataSource<Label> source =
                        new LibSVMDataSource<>(datasetPath, new LabelFactory(), o.zeroIndexed, maxFeatureID);
                return new ImmutableDataset<>(source, model, true);
            }
            case TEXT: {
                Tokenizer tokenizer = new BreakIteratorTokenizer(Locale.US);
                // The hashing dimension is only passed to the pipeline when one was requested (> 0).
                TokenPipeline pipeline = o.hashDim > 0
                        ? new TokenPipeline(tokenizer, o.ngram, o.termCounting, o.hashDim)
                        : new TokenPipeline(tokenizer, o.ngram, o.termCounting);
                TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(pipeline);
                SimpleTextDataSource<Label> source =
                        new SimpleTextDataSource<>(datasetPath, new LabelFactory(), extractor);
                return new ImmutableDataset<>(source, model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            }
            case CSV: {
                if (o.csvResponseName == null) {
                    throw new IllegalArgumentException("Please supply a response column name");
                }
                CSVLoader<Label> loader = new CSVLoader<>(new LabelFactory());
                return new ImmutableDataset<>(loader.loadDataSource(datasetPath, o.csvResponseName),
                        model.getFeatureIDMap(), model.getOutputIDInfo(), true);
            }
            default:
                throw new IllegalArgumentException("Unsupported input format " + o.inputFormat);
        }
    }

    /**
     * Checks that {@code dataset} holds {@link Label} outputs and copies it into the feature and
     * output domains of {@code model}.
     *
     * @throws IllegalArgumentException If the dataset's outputs are not labels.
     */
    private static Dataset<Label> copyIntoModelDomain(Dataset<?> dataset, Model<Label> model) {
        if (!dataset.validate(Label.class)) {
            throw new IllegalArgumentException("Invalid test dataset type, expected Label.class");
        }
        Dataset<Label> labelled = Dataset.castDataset(dataset, Label.class);
        return ImmutableDataset.copyDataset(labelled, model.getFeatureIDMap(), model.getOutputIDInfo());
    }

    /**
     * Entry point: parses the options, loads the model and data, evaluates the model, and
     * optionally writes predictions.
     * <p>
     * The process exits with status 1 if the model or testing path is missing or loading fails.
     *
     * @param args The command line arguments.
     */
    public static void main(String[] args) {
        LabsLogFormatter.setAllLogFormatters();
        ConfigurableTestOptions o = new ConfigurableTestOptions();
        ConfigurationManager cm;
        try {
            cm = new ConfigurationManager(args, o);
        } catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.modelPath == null || o.testingPath == null) {
            logger.info(cm.usage());
            System.exit(1);
        }
        Pair<Model<Label>, Dataset<Label>> loaded;
        try {
            loaded = load(o);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Failed to load model/data", e);
            System.exit(1);
            return; // unreachable; satisfies the compiler's definite-assignment check for 'loaded'
        }
        Model<Label> model = loaded.getA();
        Dataset<Label> test = loaded.getB();
        logger.info("Model is " + model.toString());
        logger.info("Labels are " + model.getOutputIDInfo().toReadableString());
        LabelEvaluator labelEvaluator = new LabelEvaluator();
        long testStart = System.currentTimeMillis();
        List<Prediction<Label>> predictions = model.predict(test);
        LabelEvaluation evaluation = labelEvaluator.evaluate(model, predictions, test.getProvenance());
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration(testStart, testStop));
        System.out.println(evaluation.toString());
        System.out.println(evaluation.getConfusionMatrix().toString());
        if (model.generatesProbabilities()) {
            System.out.println("Average AUC = " + evaluation.averageAUCROC(false));
            System.out.println("Average weighted AUC = " + evaluation.averageAUCROC(true));
        }
        if (o.predictionPath != null) {
            writePredictions(o.predictionPath, model, predictions);
        }
    }

    /**
     * Writes the predictions as CSV.
     * <p>
     * The header is {@code Label} followed by the model's labels in sorted order. Each row is
     * the example's output label followed by the score for each of those labels. I/O failures
     * are logged, not propagated.
     */
    private static void writePredictions(Path predictionPath, Model<Label> model, List<Prediction<Label>> predictions) {
        try (BufferedWriter wrt = Files.newBufferedWriter(predictionPath)) {
            List<String> labels = model.getOutputIDInfo().getDomain().stream()
                    .map(Label::getLabel)
                    .sorted()
                    .collect(Collectors.toList());
            wrt.write("Label,");
            wrt.write(String.join(",", labels));
            wrt.newLine();
            for (Prediction<Label> pred : predictions) {
                Example<Label> ex = pred.getExample();
                wrt.write(ex.getOutput().getLabel() + ",");
                wrt.write(labels.stream().map(l -> formatScore(pred, l)).collect(Collectors.joining(",")));
                wrt.newLine();
            }
            wrt.flush();
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Error writing predictions", e);
        }
    }

    /**
     * Formats the score that {@code pred} assigns to {@code label}.
     * <p>
     * Returns {@code NaN} when the prediction has no score for that label, so a missing entry
     * doesn't abort the output file with a NullPointerException.
     */
    private static String formatScore(Prediction<Label> pred, String label) {
        Label score = pred.getOutputScores().get(label);
        return Double.toString(score == null ? Double.NaN : score.getScore());
    }

    /**
     * Command line options for {@link Test}.
     */
    public static class ConfigurableTestOptions implements Options {
        @Option(longName = "hashing-dimension", usage = "Hashing dimension used for standard text format.")
        public int hashDim = 0;
        @Option(longName = "ngram", usage = "Ngram size to generate when using standard text format. Defaults to 2.")
        public int ngram = 2;
        @Option(longName = "term-counting", usage = "Use term counts instead of boolean when using the standard text format.")
        public boolean termCounting;
        @Option(longName = "csv-response-name", usage = "Response name in the csv file.")
        public String csvResponseName;
        @Option(longName = "libsvm-zero-indexed", usage = "Is the libsvm file zero indexed.")
        public boolean zeroIndexed = false;
        @Option(charName = 'f', longName = "model-path", usage = "Path to the serialized model to test.")
        public Path modelPath;
        @Option(charName = 'o', longName = "predictions", usage = "Path to write model predictions")
        public Path predictionPath;
        @Option(charName = 's', longName = "input-format", usage = "Loads the data using the specified format. Defaults to LIBSVM.")
        public DataOptions.InputFormat inputFormat = DataOptions.InputFormat.LIBSVM;
        @Option(charName = 'v', longName = "testing-file", usage = "Path to the testing file.")
        public Path testingPath;
        @Option(longName = "read-protobuf-model", usage = "Load the model in protobuf format.")
        public boolean protobufModel;

        @Override
        public String getOptionsDescription() {
            return "Tests an already trained classifier on a dataset.";
        }
    }
}

