/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.command.Command;
import com.oracle.labs.mlrg.olcut.command.CommandGroup;
import com.oracle.labs.mlrg.olcut.command.CommandInterpreter;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Path;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.jline.builtins.Completers;
import org.jline.reader.Completer;
import org.jline.reader.impl.completer.NullCompleter;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.lime.LIMEExplanation;
import org.tribuo.classification.explanations.lime.LIMEText;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.TextPipeline;
import org.tribuo.data.text.impl.BasicPipeline;
import org.tribuo.data.text.impl.TextFeatureExtractorImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.CARTJointRegressionTrainer;
import org.tribuo.util.tokens.Tokenizer;
import org.tribuo.util.tokens.universal.UniversalTokenizer;

/**
 * Interactive shell for inspecting a text classification {@link Model} and explaining
 * its predictions with {@link LIMEText}.
 * <p>
 * Input text is tokenized with a {@link UniversalTokenizer} and turned into features by a
 * {@link TextFeatureExtractorImpl} wrapping a {@link BasicPipeline} over that tokenizer.
 * The {@link LIMEText} explainer is rebuilt whenever the model, the number of LIME samples,
 * or the number of explanation features changes, so every setting takes effect immediately.
 * <p>
 * Commands that need a model return {@link #NO_MODEL_LOADED} instead of throwing when no
 * model has been loaded yet.
 */
public class LIMETextCLI implements CommandGroup {
    private static final Logger logger = Logger.getLogger(LIMETextCLI.class.getName());

    /** Message returned by commands that need a model when none has been loaded. */
    private static final String NO_MODEL_LOADED = "No model loaded, use loadModel first.";

    private final CommandInterpreter shell;
    /** The classification model being inspected; null until a load succeeds. */
    private Model<Label> model;
    private int numSamples = 100;
    private int numFeatures = 10;
    /** Trainer for LIME's local surrogate model; depends on {@link #numFeatures}. */
    private SparseTrainer<Regressor> limeTrainer = createLimeTrainer(this.numFeatures);
    private final Tokenizer tokenizer = new UniversalTokenizer();
    private final TextFeatureExtractor<Label> extractor = new TextFeatureExtractorImpl<>(new BasicPipeline(this.tokenizer, 2));
    /** The explainer; null until a model has been loaded. */
    private LIMEText limeText = null;

    /**
     * Creates the CLI with its own command interpreter and prompt.
     * The command group is registered with the shell in {@link #startShell()}.
     */
    public LIMETextCLI() {
        this.shell = new CommandInterpreter();
        this.shell.setPrompt("lime-text sh% ");
    }

    /**
     * Returns the command group name shown by the shell.
     *
     * @return the group name
     */
    public String getName() {
        return "LIME Text CLI";
    }

    /**
     * Returns the command group description shown by the shell.
     *
     * @return the group description
     */
    public String getDescription() {
        return "Commands for experimenting with LIME Text.";
    }

    /**
     * Completers for commands whose first argument is a file name.
     * Referenced by name from the {@code completers} attribute of {@link #loadModel}.
     *
     * @return a file name completer followed by a null completer
     */
    public Completer[] fileCompleter() {
        return new Completer[]{new Completers.FileNameCompleter(), new NullCompleter()};
    }

    /**
     * Registers this command group with the shell and starts it.
     */
    public void startShell() {
        this.shell.add(this);
        this.shell.start();
    }

    /**
     * Builds the sparse regression trainer LIME uses for its local surrogate model.
     * The tree depth is {@code (int) Math.log(numFeatures)}, i.e. floor(ln(numFeatures))
     * for positive inputs; the default of 10 features gives depth 2.
     *
     * @param numFeatures the requested number of explanation features (expected positive)
     * @return a new CART joint regression trainer
     */
    private static SparseTrainer<Regressor> createLimeTrainer(int numFeatures) {
        return new CARTJointRegressionTrainer((int) Math.log(numFeatures), true);
    }

    /**
     * Recreates {@link #limeText} from the current model, trainer, sample count, extractor
     * and tokenizer, using a fixed RNG seed of 1. Does nothing while no model is loaded,
     * so a null model is never handed to {@link LIMEText}.
     */
    private void rebuildLimeText() {
        if (this.model != null) {
            this.limeText = new LIMEText(new SplittableRandom(1L), this.model, this.limeTrainer, this.numSamples, this.extractor, this.tokenizer);
        }
    }

    /**
     * Loads a classification model from disk, either from a protobuf file or from a
     * Java-serialized file. If loading fails, the previously loaded model (if any) is kept.
     *
     * @param ci the invoking interpreter (unused)
     * @param path the model file
     * @param protobuf true to read a protobuf model, false for Java serialization
     * @return a success message, or "Failed to load model" (details are logged)
     */
    @Command(usage="<filename> <load-protobuf> - Load a model from disk.", completers="fileCompleter")
    public String loadModel(CommandInterpreter ci, File path, boolean protobuf) {
        String output = "Failed to load model";
        if (protobuf) {
            try {
                Model<?> tmpModel = Model.deserializeFromFile(path.toPath());
                this.model = tmpModel.castModel(Label.class);
                output = "Loaded model from path " + path.getAbsolutePath();
            }
            catch (IllegalStateException e) {
                logger.log(Level.SEVERE, "Failed to deserialize protobuf when reading from file " + path.getAbsolutePath(), e);
            }
            catch (IOException e) {
                logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
            }
            catch (ClassCastException e) {
                // castModel rejects models whose output type is not Label.
                logger.log(Level.SEVERE, "File " + path.getAbsolutePath() + " does not contain a classification (Label) model", e);
            }
        } else {
            // SECURITY NOTE: native Java deserialization can run code from the classes in the
            // stream; only load serialized models from trusted files.
            try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
                Model<?> tmpModel = (Model<?>) ois.readObject();
                this.model = tmpModel.castModel(Label.class);
                output = "Loaded model from path " + path.getAbsolutePath();
            }
            catch (ClassNotFoundException e) {
                logger.log(Level.SEVERE, "Failed to load class from stream " + path.getAbsolutePath(), e);
            }
            catch (FileNotFoundException e) {
                logger.log(Level.SEVERE, "Failed to open file " + path.getAbsolutePath(), e);
            }
            catch (IOException e) {
                logger.log(Level.SEVERE, "IOException when reading from file " + path.getAbsolutePath(), e);
            }
            catch (ClassCastException e) {
                // Either the stream did not hold a Model, or castModel rejected its output type.
                logger.log(Level.SEVERE, "File " + path.getAbsolutePath() + " does not contain a classification (Label) model", e);
            }
        }
        rebuildLimeText();
        return output;
    }

    /**
     * Reports whether the loaded model generates probabilities.
     *
     * @param ci the invoking interpreter (unused)
     * @return "true"/"false", or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Does the model generate probabilities")
    public String generatesProbabilities(CommandInterpreter ci) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        return Boolean.toString(this.model.generatesProbabilities());
    }

    /**
     * Shows the loaded model's {@code toString()}.
     *
     * @param ci the invoking interpreter (unused)
     * @return the model description, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Shows the model description")
    public String modelDescription(CommandInterpreter ci) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        return this.model.toString();
    }

    /**
     * Shows the feature map entry for a single feature.
     *
     * @param ci the invoking interpreter (unused)
     * @param featureName the feature to look up
     * @return the feature info, a not-found message, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Shows the information on a particular feature")
    public String featureInfo(CommandInterpreter ci, String featureName) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        VariableIDInfo f = this.model.getFeatureIDMap().get(featureName);
        if (f != null) {
            return f.toString();
        }
        return "Feature " + featureName + " not found.";
    }

    /**
     * Shows the result of {@code Model#getTopFeatures} for the loaded model.
     *
     * @param ci the invoking interpreter (unused)
     * @param numFeatures the count passed to {@code getTopFeatures}
     * @return the top features, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="<int> - Shows the top N features in the model")
    public String topFeatures(CommandInterpreter ci, int numFeatures) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        return String.valueOf(this.model.getTopFeatures(numFeatures));
    }

    /**
     * Shows the size of the loaded model's feature map.
     *
     * @param ci the invoking interpreter (unused)
     * @return the feature count, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Shows the number of features in the model")
    public String numFeatures(CommandInterpreter ci) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        return Integer.toString(this.model.getFeatureIDMap().size());
    }

    /**
     * Counts features whose recorded count is strictly greater than {@code minCount}.
     *
     * @param ci the invoking interpreter (unused)
     * @param minCount the exclusive count threshold
     * @return a summary message, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="<min count> - Shows the number of features that occurred more than min count times.")
    public String minCount(CommandInterpreter ci, int minCount) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        int counter = 0;
        for (VariableInfo f : this.model.getFeatureIDMap()) {
            if (f.getCount() > minCount) {
                counter++;
            }
        }
        return counter + " features occurred more than " + minCount + " times.";
    }

    /**
     * Shows the readable form of the model's output domain.
     *
     * @param ci the invoking interpreter (unused)
     * @return the label histogram, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Shows the output statistics")
    public String showLabelStats(CommandInterpreter ci) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        return "Label histogram : \n" + this.model.getOutputIDInfo().toReadableString();
    }

    /**
     * Sets the number of samples LIME draws per explanation and rebuilds the explainer,
     * so the new value applies to the next {@code explain} call.
     *
     * @param ci the invoking interpreter (unused)
     * @param newNumSamples the new sample count; must be positive
     * @return a confirmation or an error message
     */
    @Command(usage="Sets the number of samples to use in LIME")
    public String setNumSamples(CommandInterpreter ci, int newNumSamples) {
        if (newNumSamples < 1) {
            return "The number of samples must be positive, found " + newNumSamples;
        }
        this.numSamples = newNumSamples;
        rebuildLimeText();
        return "Set number of samples to " + this.numSamples;
    }

    /**
     * Explains the loaded model's prediction for the given text. The surrogate model's
     * active features for the predicted label are printed to the interpreter.
     *
     * @param ci the invoking interpreter, used for output
     * @param tokens the words of the text, joined with single spaces
     * @return the explanation, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Explain a text classification")
    public String explain(CommandInterpreter ci, String[] tokens) {
        if (this.limeText == null) {
            return NO_MODEL_LOADED;
        }
        String text = String.join(" ", tokens);
        LIMEExplanation explanation = this.limeText.explain(text);
        SparseModel<Regressor> surrogate = explanation.getModel();
        String predictedLabel = explanation.getPrediction().getOutput().getLabel();
        ci.out.println("Active features of the predicted class = " + surrogate.getActiveFeatures().get(predictedLabel));
        return "Explanation = " + explanation.toString();
    }

    /**
     * Sets the number of features LIME should use, which determines the surrogate
     * tree depth, and rebuilds the trainer and explainer.
     *
     * @param ci the invoking interpreter (unused)
     * @param newNumFeatures the new feature count; must be positive
     * @return a confirmation or an error message
     */
    @Command(usage="Sets the number of features LIME should use in an explanation")
    public String setNumFeatures(CommandInterpreter ci, int newNumFeatures) {
        if (newNumFeatures < 1) {
            // Math.log of a non-positive value would produce a meaningless tree depth.
            return "The number of features must be positive, found " + newNumFeatures;
        }
        this.numFeatures = newNumFeatures;
        this.limeTrainer = createLimeTrainer(this.numFeatures);
        rebuildLimeText();
        return "Set the number of features in LIME to " + this.numFeatures;
    }

    /**
     * Predicts the label of the given text with the loaded model.
     *
     * @param ci the invoking interpreter (unused)
     * @param tokens the words of the text, joined with single spaces
     * @return the prediction, or {@link #NO_MODEL_LOADED}
     */
    @Command(usage="Make a prediction")
    public String predict(CommandInterpreter ci, String[] tokens) {
        if (this.model == null) {
            return NO_MODEL_LOADED;
        }
        String text = String.join(" ", tokens);
        Prediction<Label> prediction = this.model.predict(this.extractor.extract(LabelFactory.UNKNOWN_LABEL, text));
        return "Prediction = " + prediction.toString();
    }

    /**
     * Entry point: parses options, optionally loads a model, then starts the shell.
     *
     * @param args command line arguments
     */
    public static void main(String[] args) {
        LIMETextCLIOptions options = new LIMETextCLIOptions();
        try {
            // Constructing the manager populates the @Option fields of options from args.
            ConfigurationManager cm = new ConfigurationManager(args, options, false);
            LIMETextCLI driver = new LIMETextCLI();
            if (options.modelFilename != null) {
                logger.log(Level.INFO, driver.loadModel(driver.shell, new File(options.modelFilename), options.protobufFormat));
            }
            driver.startShell();
        }
        catch (UsageException e) {
            System.out.println("Usage: " + e.getUsage());
        }
    }

    /**
     * Command line options for {@link LIMETextCLI}.
     */
    public static class LIMETextCLIOptions implements Options {
        /** Model file to load at startup; optional. */
        @Option(charName='f', longName="filename", usage="Model file to load. Optional.")
        public String modelFilename;
        /** Whether the startup model file is in protobuf format. */
        @Option(charName='p', longName="protobuf-model", usage="Load the model from a protobuf. Optional")
        public boolean protobufFormat;
    }
}

