/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.explanations.lime;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.TextExplainer;
import org.tribuo.classification.explanations.lime.LIMEBase;
import org.tribuo.classification.explanations.lime.LIMEExplanation;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Tokenizer;

public class LIMEText
extends LIMEBase
implements TextExplainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LIMEText.class.getName());
    private final TextFeatureExtractor<Label> extractor;
    private final Tokenizer tokenizer;
    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;

    public LIMEText(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples, TextFeatureExtractor<Label> extractor, Tokenizer tokenizer) {
        super(rng, innerModel, explanationTrainer, numSamples);
        this.extractor = extractor;
        this.tokenizer = tokenizer;
        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {
            try {
                return this.tokenizer.clone();
            }
            catch (CloneNotSupportedException e) {
                throw new IllegalArgumentException("Tokenizer not cloneable", e);
            }
        });
    }

    public LIMEExplanation explain(String inputText) {
        Example trueExample = this.extractor.extract((Output)LabelFactory.UNKNOWN_LABEL, inputText);
        Prediction prediction = this.innerModel.predict(trueExample);
        ArrayExample bowExample = new ArrayExample((Output)LIMEText.transformOutput((Prediction<Label>)prediction));
        List tokens = this.tokenizerThreadLocal.get().tokenize((CharSequence)inputText);
        for (int i = 0; i < tokens.size(); ++i) {
            bowExample.add(this.nameFeature(((Token)tokens.get((int)i)).text, i), 1.0);
        }
        List<Example<Regressor>> sample = this.sampleData(inputText, tokens);
        SparseModel<Regressor> model = this.trainExplainer((Example<Regressor>)bowExample, sample);
        ArrayList<Prediction> predictions = new ArrayList<Prediction>(model.predict(sample));
        predictions.add(model.predict((Example)bowExample));
        RegressionEvaluation evaluation = (RegressionEvaluation)evaluator.evaluate(model, predictions, (DataProvenance)new SimpleDataSourceProvenance("LIMEText sampled data", regressionFactory));
        return new LIMEExplanation(model, (Prediction<Label>)prediction, evaluation);
    }

    protected String nameFeature(String name, int idx) {
        return name + "@idx" + idx;
    }

    protected List<Example<Regressor>> sampleData(String inputText, List<Token> tokens) {
        ArrayList<Example<Regressor>> output = new ArrayList<Example<Regressor>>();
        Random innerRNG = new Random(this.rng.nextLong());
        for (int i = 0; i < this.numSamples; ++i) {
            double distance = 0.0;
            int[] activeFeatures = new int[tokens.size()];
            char[] sampledText = inputText.toCharArray();
            for (int j = 0; j < activeFeatures.length; ++j) {
                activeFeatures[j] = innerRNG.nextInt(2);
                if (activeFeatures[j] != 0) continue;
                distance += 1.0;
                Token curToken = tokens.get(j);
                Arrays.fill(sampledText, curToken.start, curToken.end, '\u0000');
            }
            String sampledString = new String(sampledText);
            Example sample = this.extractor.extract((Output)LabelFactory.UNKNOWN_LABEL, sampledString = sampledString.replace("\u0000", ""));
            if (sample.size() <= 0) continue;
            Prediction samplePrediction = this.innerModel.predict(sample);
            double weight = 1.0 - distance / (double)tokens.size();
            ArrayExample labelledSample = new ArrayExample((Output)LIMEText.transformOutput((Prediction<Label>)samplePrediction), (float)weight);
            for (int j = 0; j < activeFeatures.length; ++j) {
                labelledSample.add(this.nameFeature(tokens.get((int)j).text, j), (double)activeFeatures[j]);
            }
            output.add((Example<Regressor>)labelledSample);
        }
        return output;
    }
}

