/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.CategoricalInfo;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.TabularExplainer;
import org.tribuo.classification.explanations.lime.LIMEExplanation;
import org.tribuo.impl.ArrayExample;
import org.tribuo.interop.ExternalModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.evaluation.RegressionEvaluator;
import org.tribuo.util.Util;

/**
 * Base implementation of LIME explanations for tabular classification models.
 * <p>
 * To explain a prediction it:
 * <ol>
 *     <li>samples perturbed points using the feature statistics recorded in the model's feature map;</li>
 *     <li>labels each sample with the inner classification model;</li>
 *     <li>weights each sample by a kernelised distance to the original example;</li>
 *     <li>fits a sparse, example-weighted regression model to the predicted label scores.</li>
 * </ol>
 * <p>
 * NOTE(review): {@link #rng} is a shared {@link SplittableRandom}, which is not thread-safe;
 * concurrent calls to {@link #explain} presumably need external synchronization — confirm.
 */
public class LIMEBase implements TabularExplainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LIMEBase.class.getName());

    /**
     * Multiplier applied to the number of features when computing the kernel width.
     */
    public static final double WIDTH_CONSTANT = 0.75;

    /**
     * Tolerance above which two categorical values are considered different when measuring distance.
     */
    public static final double DISTANCE_DELTA = 1.0E-12;

    protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory();
    protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true);

    /** Source of randomness; a fresh {@link Random} is seeded from it for each explanation. */
    protected final SplittableRandom rng;
    /** The classification model being explained. */
    protected final Model<Label> innerModel;
    /** Trainer for the sparse surrogate model; must implement {@link WeightedExamples}. */
    protected final SparseTrainer<Regressor> explanationTrainer;
    /** Number of perturbed samples generated per explanation. */
    protected final int numSamples;
    /** Total number of observations recorded in the inner model's output info. */
    protected final long numTrainingExamples;
    /** Kernel width, {@code (numFeatures * WIDTH_CONSTANT)^2}. */
    protected final double kernelWidth;
    private final ImmutableFeatureMap fMap;

    /**
     * Constructs a LIME explainer.
     *
     * @param rng                The random number generator used to seed sampling.
     * @param innerModel         The classification model to explain. It must produce probabilities and
     *                           must not be an {@link ExternalModel}.
     * @param explanationTrainer The sparse regression trainer used to fit the surrogate model. It must
     *                           implement {@link WeightedExamples}.
     * @param numSamples         The number of samples to draw per explanation.
     * @throws IllegalArgumentException If any of the above requirements are violated.
     */
    public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) {
        if (!(explanationTrainer instanceof WeightedExamples)) {
            throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString());
        }
        if (!innerModel.generatesProbabilities()) {
            throw new IllegalArgumentException("LIME requires the model generate probabilities.");
        }
        if (innerModel instanceof ExternalModel) {
            throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model.");
        }
        this.rng = rng;
        this.innerModel = innerModel;
        this.explanationTrainer = explanationTrainer;
        this.numSamples = numSamples;
        this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations();
        this.kernelWidth = Math.pow(innerModel.getFeatureIDMap().size() * WIDTH_CONSTANT, 2.0);
        this.fMap = innerModel.getFeatureIDMap();
    }

    /**
     * Explains the inner model's prediction on the supplied example.
     *
     * @param example The example to explain.
     * @return The LIME explanation.
     */
    public LIMEExplanation explain(Example<Label> example) {
        return explainWithSamples(example).getA();
    }

    /**
     * Explains the inner model's prediction on the supplied example, also returning the
     * weighted samples the surrogate model was trained on.
     * <p>
     * The surrogate is evaluated against the inner model's scores on the samples plus the
     * original (labelled) example.
     *
     * @param example The example to explain.
     * @return A pair of the explanation and the sampled, labelled, weighted examples.
     */
    protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
        // Label the input with the inner model's scores, at full weight.
        Prediction<Label> prediction = innerModel.predict(example);
        Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction), example, 1.0f);

        List<Example<Regressor>> sample = sampleData(example);

        SparseModel<Regressor> model = trainExplainer(labelledExample, sample);

        // Check how well the surrogate reproduces the inner model on the data it was fit to.
        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
        predictions.add(model.predict(labelledExample));
        RegressionEvaluation evaluation = evaluator.evaluate(model, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory));

        return new Pair<>(new LIMEExplanation(model, prediction, evaluation), sample);
    }

    /**
     * Draws {@link #numSamples} points around the example. Each point is labelled with the inner
     * model's scores and weighted by the kernelised distance to the example.
     */
    private List<Example<Regressor>> sampleData(Example<Label> example) {
        List<Example<Regressor>> output = new ArrayList<>();
        SparseVector exampleVector = SparseVector.createSparseVector(example, fMap, false);
        Random innerRNG = new Random(rng.nextLong());
        for (int i = 0; i < numSamples; i++) {
            Example<Label> sample = samplePoint(innerRNG, fMap, numTrainingExamples, exampleVector);
            Prediction<Label> samplePrediction = innerModel.predict(sample);
            double distance = measureDistance(fMap, numTrainingExamples, exampleVector, SparseVector.createSparseVector(sample, fMap, false));
            double weight = kernelDist(distance, kernelWidth);
            output.add(new ArrayExample<>(transformOutput(samplePrediction), sample, (float) weight));
        }
        return output;
    }

    /**
     * Samples a new point using the feature statistics in the feature map.
     * <p>
     * Categorical features are drawn with {@link CategoricalInfo#frequencyBasedSample} and are kept
     * only when the draw is non-zero (magnitude above 1e-10). Real features are included with
     * probability {@code count / numTrainingExamples}. When included, a real feature's value is drawn
     * from a Gaussian centred on the input's value with the feature's recorded variance.
     *
     * @param rng                 The random number generator.
     * @param fMap                The feature map describing the feature distributions.
     * @param numTrainingExamples The number of training examples the statistics were computed over.
     * @param input               The point to sample around.
     * @return A sampled example with an unknown label.
     * @throws IllegalStateException If a feature is neither categorical nor real.
     */
    public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
        ArrayList<String> names = new ArrayList<>();
        ArrayList<Double> values = new ArrayList<>();
        for (VariableInfo info : fMap) {
            int id = ((VariableIDInfo) info).getID();
            double inputValue = input.get(id);
            if (info instanceof CategoricalInfo) {
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(rng, numTrainingExamples);
                // Only record non-zero draws in the sparse example.
                if (Math.abs(sample) > 1e-10) {
                    names.add(info.getName());
                    values.add(sample);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // Include the feature at the frequency it was observed during training.
                double threshold = realInfo.getCount() / (double) numTrainingExamples;
                if (rng.nextDouble() < threshold) {
                    double sample = rng.nextGaussian() * Math.sqrt(realInfo.getVariance()) + inputValue;
                    names.add(info.getName());
                    values.add(sample);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, names.toArray(new String[0]), Util.toPrimitiveDouble(values));
    }

    /**
     * Trains the sparse surrogate model on the target example plus the samples.
     *
     * @param target  The labelled example being explained.
     * @param samples The labelled, weighted samples.
     * @return The trained sparse regression model.
     */
    protected SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) {
        MutableDataset<Regressor> explanationDataset = new MutableDataset<>(new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory);
        explanationDataset.add(target);
        explanationDataset.addAll(samples);
        return explanationTrainer.train(explanationDataset);
    }

    /**
     * Converts a distance into a similarity weight using {@code sqrt(exp(-input^2 / width))}.
     *
     * @param input The distance.
     * @param width The kernel width.
     * @return The kernelised weight.
     */
    public static double kernelDist(double input, double width) {
        return Math.sqrt(Math.exp(-(input * input) / width));
    }

    /**
     * Measures the distance between two sparse vectors as the square root of the summed
     * per-feature contributions.
     * <p>
     * Real features contribute {@code (a - b)^2 / range^2}. A feature present in only one vector is
     * compared against zero. Categorical features contribute 1 when their values differ, or when
     * the feature is present in only one vector, and 0 otherwise.
     * <p>
     * The merge relies on both vectors iterating their active elements in ascending index order.
     *
     * @param fMap                The feature map, used to look up feature types and ranges.
     * @param numTrainingExamples The number of training examples the feature statistics cover.
     * @param input               The first vector.
     * @param sample              The second vector.
     * @return The distance.
     * @throws IllegalStateException If a feature is neither categorical nor real.
     */
    public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) {
        double score = 0.0;
        VectorIterator itr = input.iterator();
        VectorIterator otherItr = sample.iterator();
        // Each tuple is the current, not-yet-scored element of its vector, or null once exhausted.
        // Its fields are read before its iterator advances, so a reused tuple instance is safe.
        VectorTuple tuple = nextOrNull(itr);
        VectorTuple otherTuple = nextOrNull(otherItr);
        while (tuple != null && otherTuple != null) {
            if (tuple.index == otherTuple.index) {
                score += calculateSingleDistance(fMap, numTrainingExamples, tuple.index, tuple.value, otherTuple.value);
                tuple = nextOrNull(itr);
                otherTuple = nextOrNull(otherItr);
            } else if (tuple.index < otherTuple.index) {
                // Only advance the side with the smaller index, so otherTuple can still match
                // a later element of input.
                score += calculateSingleDistance(fMap, numTrainingExamples, tuple.index, tuple.value);
                tuple = nextOrNull(itr);
            } else {
                score += calculateSingleDistance(fMap, numTrainingExamples, otherTuple.index, otherTuple.value);
                otherTuple = nextOrNull(otherItr);
            }
        }
        while (tuple != null) {
            score += calculateSingleDistance(fMap, numTrainingExamples, tuple.index, tuple.value);
            tuple = nextOrNull(itr);
        }
        while (otherTuple != null) {
            score += calculateSingleDistance(fMap, numTrainingExamples, otherTuple.index, otherTuple.value);
            otherTuple = nextOrNull(otherItr);
        }
        return Math.sqrt(score);
    }

    /** Returns the iterator's next tuple, or null if it is exhausted. */
    private static VectorTuple nextOrNull(VectorIterator itr) {
        return itr.hasNext() ? (VectorTuple) itr.next() : null;
    }

    /**
     * Distance contribution of a feature present in only one of the two vectors.
     */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) {
        VariableIDInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            // Present on one side only, so it always counts as a mismatch.
            return 1.0;
        }
        if (info instanceof RealInfo) {
            // Compare against an implicit zero on the missing side.
            return scaledSquaredDifference(value, featureRange(numTrainingExamples, (RealInfo) info));
        }
        throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
    }

    /**
     * Distance contribution of a feature present in both vectors.
     */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double firstValue, double secondValue) {
        VariableIDInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            return Math.abs(firstValue - secondValue) > DISTANCE_DELTA ? 1.0 : 0.0;
        }
        if (info instanceof RealInfo) {
            return scaledSquaredDifference(firstValue - secondValue, featureRange(numTrainingExamples, (RealInfo) info));
        }
        throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
    }

    /**
     * Range of a real feature used to normalise distances. When the feature was not observed in
     * every training example, zero is folded into the range. Presumably this is because unobserved
     * values are implicit zeros in the sparse representation.
     */
    private static double featureRange(long numTrainingExamples, RealInfo rInfo) {
        if (numTrainingExamples != (long) rInfo.getCount()) {
            return Math.max(rInfo.getMax(), 0.0) - Math.min(rInfo.getMin(), 0.0);
        }
        return rInfo.getMax() - rInfo.getMin();
    }

    /**
     * Returns {@code difference^2 / range^2}. A zero difference yields 0 even when the range is
     * zero (a constant feature), avoiding a 0/0 NaN that would poison the sample weights.
     */
    private static double scaledSquaredDifference(double difference, double range) {
        if (difference == 0.0) {
            return 0.0;
        }
        return (difference * difference) / (range * range);
    }

    /**
     * Converts a classification prediction into a {@link Regressor} with one dimension per label.
     * Each dimension is named after its label and holds that label's score.
     *
     * @param prediction The classification prediction.
     * @return The multi-dimensional regressor.
     */
    public static Regressor transformOutput(Prediction<Label> prediction) {
        Map<String, Label> outputs = prediction.getOutputScores();
        String[] names = new String[outputs.size()];
        double[] values = new double[outputs.size()];
        int i = 0;
        for (Map.Entry<String, Label> e : outputs.entrySet()) {
            names[i] = e.getKey();
            values[i] = e.getValue().getScore();
            i++;
        }
        return new Regressor(names, values);
    }
}

