/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.SplittableRandom;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.ColumnarExplainer;
import org.tribuo.classification.explanations.lime.LIMEBase;
import org.tribuo.classification.explanations.lime.LIMEExplanation;
import org.tribuo.data.columnar.FieldProcessor;
import org.tribuo.data.columnar.ResponseProcessor;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.impl.ArrayExample;
import org.tribuo.impl.ListExample;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.util.Util;
import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Tokenizer;

/**
 * A LIME explainer for columnar rows described by a {@link RowProcessor}.
 * <p>
 * Every field of the row processor is classified by its {@link FieldProcessor} feature type:
 * <ul>
 *     <li>{@code CATEGORICAL} and {@code REAL} fields are tabular features, resampled from the
 *     model's feature statistics;</li>
 *     <li>{@code BINARISED_CATEGORICAL} fields (optionally split into several namespaces) are resampled
 *     from a per-namespace CDF built from the model's feature counts;</li>
 *     <li>{@code TEXT} fields are perturbed by randomly deleting tokens.</li>
 * </ul>
 * Row processors containing {@code FeatureProcessor}s are rejected.
 * <p>
 * The supplied tokenizer is cloned lazily, once per thread, through a {@link ThreadLocal}.
 */
public class LIMEColumnar extends LIMEBase implements ColumnarExplainer<Regressor> {
    /** Private copy of the row processor used to turn input rows into examples. */
    private final RowProcessor<Label> generator;
    /** Binarised categorical processors keyed by namespace name ({@code field} or {@code field#i}). */
    private final Map<String, FieldProcessor> binarisedFields = new HashMap<>();
    /** Categorical and real valued processors keyed by field name. */
    private final Map<String, FieldProcessor> tabularFields = new HashMap<>();
    /** Text processors keyed by field name. */
    private final Map<String, FieldProcessor> textFields = new HashMap<>();
    private final ResponseProcessor<Label> responseProcessor;
    /** The model's feature infos belonging to each binarised namespace, in model feature order. */
    private final Map<String, List<VariableInfo>> binarisedInfos;
    /** Per-namespace CDF over the binarised features; the final slot is the "no feature present" case. */
    private final Map<String, double[]> binarisedCDFs;
    private final ImmutableFeatureMap binarisedDomain;
    private final ImmutableFeatureMap tabularDomain;
    private final ImmutableFeatureMap textDomain;
    private final Tokenizer tokenizer;
    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;

    /**
     * Constructs a columnar LIME explainer.
     *
     * @param rng The RNG used to seed sampling.
     * @param innerModel The model being explained.
     * @param explanationTrainer The sparse trainer for the surrogate model (passed to {@link LIMEBase}).
     * @param numSamples The number of samples to draw per explanation (passed to {@link LIMEBase}).
     * @param exampleGenerator The row processor describing the input columns; it is copied, not mutated.
     * @param tokenizer The tokenizer for text fields; it must support {@code clone()}.
     * @throws IllegalArgumentException If the row processor has FeatureProcessors, a field has an
     *     unsupported feature type, or the model has features not produced by any field processor.
     * @throws IllegalStateException If a binarised field has non-binary features, or its feature counts
     *     exceed the number of training examples.
     */
    public LIMEColumnar(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples, RowProcessor<Label> exampleGenerator, Tokenizer tokenizer) {
        super(rng, innerModel, explanationTrainer, numSamples);
        this.generator = exampleGenerator.copy();
        this.responseProcessor = this.generator.getResponseProcessor();
        this.tokenizer = tokenizer;
        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {
            try {
                return this.tokenizer.clone();
            } catch (CloneNotSupportedException e) {
                throw new IllegalArgumentException("Tokenizer not cloneable", e);
            }
        });
        if (!this.generator.isConfigured()) {
            this.generator.expandRegexMapping(innerModel);
        }
        // Checked before partitioning the features. Features produced by FeatureProcessors would
        // otherwise surface below as generic "unsupported features", hiding the real cause.
        if (!this.generator.getFeatureProcessors().isEmpty()) {
            throw new IllegalArgumentException("LIMEColumnar does not support FeatureProcessors.");
        }

        // Every model feature must be claimed by exactly one field processor, so we drain this list.
        List<VariableInfo> remaining = new ArrayList<>();
        for (VariableInfo info : innerModel.getFeatureIDMap()) {
            remaining.add(info);
        }
        this.binarisedInfos = new HashMap<>();
        List<VariableInfo> allBinarisedInfos = new ArrayList<>();
        List<VariableInfo> tabularInfos = new ArrayList<>();
        List<VariableInfo> textInfos = new ArrayList<>();

        for (Map.Entry<String, FieldProcessor> entry : this.generator.getFieldProcessors().entrySet()) {
            String fieldName = entry.getKey();
            FieldProcessor processor = entry.getValue();
            switch (processor.getFeatureType()) {
                case BINARISED_CATEGORICAL: {
                    int numNamespaces = processor.getNumNamespaces();
                    if (numNamespaces > 1) {
                        // Each namespace is sampled independently, as "field#i".
                        for (int i = 0; i < numNamespaces; i++) {
                            this.registerBinarised(fieldName + "#" + i, fieldName, processor, remaining, allBinarisedInfos);
                        }
                    } else {
                        this.registerBinarised(fieldName, fieldName, processor, remaining, allBinarisedInfos);
                    }
                    break;
                }
                case CATEGORICAL:
                case REAL: {
                    this.tabularFields.put(fieldName, processor);
                    tabularInfos.addAll(extractInfos(remaining, fieldName + "@"));
                    break;
                }
                case TEXT: {
                    this.textFields.put(fieldName, processor);
                    textInfos.addAll(extractInfos(remaining, fieldName + "@"));
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unsupported feature type " + processor.getFeatureType());
                }
            }
        }
        if (!remaining.isEmpty()) {
            throw new IllegalArgumentException("Found " + remaining.size() + " unsupported features.");
        }
        this.tabularDomain = new ImmutableFeatureMap(tabularInfos);
        this.textDomain = new ImmutableFeatureMap(textInfos);
        this.binarisedDomain = new ImmutableFeatureMap(allBinarisedInfos);
        this.binarisedCDFs = new HashMap<>();
        for (Map.Entry<String, List<VariableInfo>> entry : this.binarisedInfos.entrySet()) {
            this.binarisedCDFs.put(entry.getKey(), this.buildBinarisedCDF(entry.getKey(), entry.getValue()));
        }
    }

    /**
     * Removes every info whose name starts with {@code prefix} from {@code remaining}.
     *
     * @param remaining The unclaimed infos; matching entries are removed from it.
     * @param prefix The feature name prefix to match.
     * @return The removed infos, in their original order.
     */
    private static List<VariableInfo> extractInfos(List<VariableInfo> remaining, String prefix) {
        List<VariableInfo> matched = new ArrayList<>();
        ListIterator<VariableInfo> it = remaining.listIterator();
        while (it.hasNext()) {
            VariableInfo info = it.next();
            if (info.getName().startsWith(prefix)) {
                matched.add(info);
                it.remove();
            }
        }
        return matched;
    }

    /**
     * Records one binarised categorical namespace.
     * <p>
     * Stores the processor under {@code namespace}, and claims the model features prefixed by
     * {@code namespace + "@"} from {@code remaining}. It also checks that each claimed feature is
     * binary (exactly one unique observation).
     */
    private void registerBinarised(String namespace, String processorName, FieldProcessor processor, List<VariableInfo> remaining, List<VariableInfo> allBinarisedInfos) {
        this.binarisedFields.put(namespace, processor);
        List<VariableInfo> namespaceInfos = this.binarisedInfos.computeIfAbsent(namespace, k -> new ArrayList<>());
        for (VariableInfo info : extractInfos(remaining, namespace + "@")) {
            int uniqueObservations = ((CategoricalInfo) info).getUniqueObservations();
            if (uniqueObservations != 1) {
                throw new IllegalStateException("Processor " + processorName + ", should have been binary, but had " + uniqueObservations + " unique values");
            }
            namespaceInfos.add(info);
            allBinarisedInfos.add(info);
        }
    }

    /**
     * Builds the sampling CDF for one binarised namespace.
     * <p>
     * Slot {@code i} holds the count of {@code infos.get(i)}. The final slot holds the number of
     * training examples in which none of the namespace's features fired.
     */
    private double[] buildBinarisedCDF(String namespace, List<VariableInfo> infos) {
        long[] counts = new long[infos.size() + 1];
        long totalCount = 0L;
        for (int i = 0; i < infos.size(); i++) {
            counts[i] = infos.get(i).getCount();
            totalCount += counts[i];
        }
        long zeroCount = this.numTrainingExamples - totalCount;
        if (zeroCount < 0L) {
            throw new IllegalStateException("Processor " + namespace + " purports to be a BINARISED_CATEGORICAL, but had overlap in it's elements");
        }
        counts[infos.size()] = zeroCount;
        return Util.generateCDF(counts, this.numTrainingExamples);
    }

    @Override
    public LIMEExplanation explain(Map<String, String> input) {
        return this.explainWithSamples(input).getA();
    }

    /**
     * Explains a single input row and also returns the samples used to fit the surrogate.
     *
     * @param input The input row, mapping field names to raw string values.
     * @return The explanation and the sampled training data.
     * @throws IllegalArgumentException If the row processor returns no example for {@code input}.
     */
    protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
        Optional<Example<Label>> optExample = this.generator.generateExample(input, false);
        if (!optExample.isPresent()) {
            throw new IllegalArgumentException("Label not found in input " + input.toString());
        }
        Example<Label> example = optExample.get();
        if (this.textDomain.size() == 0 && this.binarisedCDFs.isEmpty()) {
            // Only tabular fields are present, so delegate to the Example-based overload.
            return this.explainWithSamples(example);
        }

        Prediction<Label> prediction = this.innerModel.predict(example);
        ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));
        // Keep both the tabular and the binarised features of the explained row. The samples drawn in
        // sampleData carry binarised features too, so the surrogate must see them on this point as well.
        for (Feature f : example) {
            String name = f.getName();
            if (this.tabularDomain.getID(name) != -1 || this.binarisedDomain.getID(name) != -1) {
                labelledExample.add(f);
            }
        }
        // This is vectorised against tabularDomain only. It is the reference point for sampling
        // and for distances.
        SparseVector tabularVector = SparseVector.createSparseVector(labelledExample, this.tabularDomain, false);

        Map<String, String> exampleTextValues = new HashMap<>();
        Map<String, List<Token>> exampleTextTokens = new HashMap<>();
        for (Map.Entry<String, FieldProcessor> e : this.textFields.entrySet()) {
            String fieldName = e.getKey();
            String value = input.get(fieldName);
            if (value == null) {
                continue;
            }
            List<Token> tokens = this.tokenizerThreadLocal.get().tokenize(value);
            // Each token position becomes a binary "token present" feature, all active for the original.
            for (int i = 0; i < tokens.size(); i++) {
                labelledExample.add(this.nameFeature(fieldName, tokens.get(i).text, i), 1.0);
            }
            exampleTextValues.put(fieldName, value);
            exampleTextTokens.put(fieldName, tokens);
        }

        List<Example<Regressor>> sample = this.sampleData(tabularVector, exampleTextValues, exampleTextTokens);
        SparseModel<Regressor> model = this.trainExplainer(labelledExample, sample);
        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
        predictions.add(model.predict(labelledExample));
        RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory));
        return new Pair<>(new LIMEExplanation(model, prediction, evaluation), sample);
    }

    /**
     * Names the binary feature for a token of a text field, as {@code field@token@idxN}.
     *
     * @param fieldName The text field name.
     * @param name The token text.
     * @param idx The token position within the field.
     * @return The feature name.
     */
    protected String nameFeature(String fieldName, String name, int idx) {
        return fieldName + "@" + name + "@idx" + idx;
    }

    /**
     * Draws {@code numSamples} perturbed samples around the input.
     * <p>
     * Each sample is labelled by the inner model and weighted by its distance from the input.
     *
     * @param tabularVector The input's tabular features.
     * @param text The input's text field values, keyed by field name.
     * @param textTokens The tokens of each entry in {@code text}.
     * @return The labelled, weighted samples.
     */
    private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> text, Map<String, List<Token>> textTokens) {
        List<Example<Regressor>> output = new ArrayList<>();
        Random innerRNG = new Random(this.rng.nextLong());
        for (int i = 0; i < this.numSamples; i++) {
            ListExample<Label> sampledExample = new ListExample<>(LabelFactory.UNKNOWN_LABEL);

            // Tabular then binarised features. The RNG draw order matches the original implementation.
            List<Feature> tabularFeatures = this.sampleTabularFeatures(tabularVector, innerRNG);
            this.sampleBinarisedFeatures(innerRNG, tabularFeatures);
            sampledExample.addAll(tabularFeatures);
            double tabularDistance = measureDistance(this.tabularDomain, this.numTrainingExamples, tabularVector,
                    SparseVector.createSparseVector(sampledExample, this.tabularDomain, false));

            // textFeatures feed the inner model; perturbedFeatures (token on/off) feed the surrogate.
            List<Feature> textFeatures = new ArrayList<>();
            List<Feature> perturbedFeatures = new ArrayList<>();
            double textDistance = 0.0;
            long numTokens = 0L;
            for (Map.Entry<String, String> entry : text.entrySet()) {
                String fieldName = entry.getKey();
                List<Token> tokens = textTokens.get(fieldName);
                numTokens += tokens.size();
                int[] activeFeatures = new int[tokens.size()];
                char[] sampledText = entry.getValue().toCharArray();
                for (int j = 0; j < activeFeatures.length; j++) {
                    activeFeatures[j] = innerRNG.nextInt(2);
                    if (activeFeatures[j] == 0) {
                        // Drop this token by blanking its character span; the blanks are stripped below.
                        textDistance += 1.0;
                        Token dropped = tokens.get(j);
                        Arrays.fill(sampledText, dropped.start, dropped.end, '\u0000');
                    }
                }
                String sampledString = new String(sampledText).replace("\u0000", "");
                textFeatures.addAll(this.textFields.get(fieldName).process(sampledString));
                for (int j = 0; j < activeFeatures.length; j++) {
                    perturbedFeatures.add(new Feature(this.nameFeature(fieldName, tokens.get(j).text, j), activeFeatures[j]));
                }
            }
            sampledExample.addAll(textFeatures);
            // With no tokens there is nothing to perturb. Avoid 0/0 = NaN, which would poison the weight.
            double totalTextDistance = numTokens == 0L ? 0.0 : textDistance / (double) numTokens;

            Prediction<Label> samplePrediction = this.innerModel.predict(sampledExample);
            double weight = this.sampleWeight(tabularFeatures.size(), tabularDistance, perturbedFeatures.size(), totalTextDistance);
            ArrayExample<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
            labelledSample.addAll(tabularFeatures);
            labelledSample.addAll(perturbedFeatures);
            output.add(labelledSample);
        }
        return output;
    }

    /**
     * Samples the categorical and real valued features for one perturbed example.
     * <p>
     * Categorical values come from their observed frequencies. A real feature is present with
     * probability {@code count / numTrainingExamples}. When present, it is drawn from a Gaussian
     * centred on the input value with the feature's training variance.
     */
    private List<Feature> sampleTabularFeatures(SparseVector tabularVector, Random innerRNG) {
        List<Feature> sampled = new ArrayList<>();
        for (VariableInfo info : this.tabularDomain) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = tabularVector.get(id);
            if (info instanceof CategoricalInfo) {
                double value = ((CategoricalInfo) info).frequencyBasedSample(innerRNG, this.numTrainingExamples);
                // A (numerically) zero sample means the feature is absent.
                if (Math.abs(value) > 1.0E-10) {
                    sampled.add(new Feature(info.getName(), value));
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                double threshold = (double) realInfo.getCount() / (double) this.numTrainingExamples;
                if (innerRNG.nextDouble() < threshold) {
                    double value = innerRNG.nextGaussian() * Math.sqrt(realInfo.getVariance()) + inputValue;
                    sampled.add(new Feature(info.getName(), value));
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        return sampled;
    }

    /**
     * Samples at most one feature from each binarised namespace and appends it to {@code sink}.
     * <p>
     * Drawing the last CDF slot means no feature is emitted for that namespace.
     */
    private void sampleBinarisedFeatures(Random innerRNG, List<Feature> sink) {
        for (Map.Entry<String, double[]> e : this.binarisedCDFs.entrySet()) {
            int sample = Util.sampleFromCDF(e.getValue(), innerRNG);
            if (sample != e.getValue().length - 1) {
                VariableInfo info = this.binarisedInfos.get(e.getKey()).get(sample);
                sink.add(new Feature(info.getName(), 1.0));
            }
        }
    }

    /**
     * Combines the tabular and text distances into a sample weight.
     * <p>
     * The two terms are averaged, each weighted by how many features it covers. If the sample has no
     * tabular or token features at all, only the tabular term is used, which avoids dividing by zero.
     */
    private double sampleWeight(int numTabular, double tabularDistance, int numPerturbed, double textDistance) {
        double tabularTerm = kernelDist(tabularDistance, this.kernelWidth);
        int totalLength = numTabular + numPerturbed;
        if (totalLength == 0) {
            return 1.0 - tabularTerm;
        }
        double combined = (numTabular * tabularTerm + numPerturbed * textDistance) / (double) totalLength;
        return 1.0 - combined;
    }
}

