/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.mnb;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.mnb.protos.MultinomialNaiveBayesProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.math.util.ExpNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/**
 * A {@link Model} for multinomial Naive Bayes classification.
 * <p>
 * The model stores a label-by-feature {@link DenseSparseMatrix} of per-feature weights
 * ({@code labelWordProbs}, rows indexed by output id, columns by feature id) and an additive
 * smoothing parameter {@code alpha}. Prediction combines the example's non-negative feature
 * values with each label's row, adds a correction for features whose smoothed weight is not
 * stored in the sparse row, and normalises the scores with an {@link ExpNormalizer}.
 */
public class MultinomialNaiveBayesModel extends Model<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Protobuf serialization version.
     */
    public static final int CURRENT_VERSION = 0;

    // Label-by-feature weights: dimension 1 is the output domain, dimension 2 the feature domain
    // (checked on deserialization). Field name kept stable for Java serialization compatibility.
    private final DenseSparseMatrix labelWordProbs;

    // Additive smoothing parameter; when it is 0 no unobserved-feature correction is applied in predict.
    private final double alpha;

    private static final VectorNormalizer normalizer = new ExpNormalizer();

    MultinomialNaiveBayesModel(String name, ModelProvenance description, ImmutableFeatureMap featureInfos, ImmutableOutputInfo<Label> labelInfos, DenseSparseMatrix labelWordProbs, double alpha) {
        super(name, description, featureInfos, labelInfos, true);
        this.labelWordProbs = labelWordProbs;
        this.alpha = alpha;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name.
     * @param message   The serialized data.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message does not contain a {@link MultinomialNaiveBayesProto}.
     * @throws IllegalArgumentException       If the version is not supported.
     * @throws IllegalStateException          If the protobuf is internally inconsistent (non-label or empty
     *                                        output domain, wrongly sized weights, invalid alpha).
     */
    public static MultinomialNaiveBayesModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        MultinomialNaiveBayesProto proto = message.unpack(MultinomialNaiveBayesProto.class);
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // An empty domain has no output 0 to inspect; reject it explicitly rather than failing with an NPE.
        if (carrier.outputDomain().size() == 0) {
            throw new IllegalStateException("Invalid protobuf, output domain is empty");
        }
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(Label.class)) {
            throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
        }
        @SuppressWarnings("unchecked") // guarded by the output class check above
        ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

        Tensor weights = Tensor.deserialize(proto.getLabelWordProbs());
        if (!(weights instanceof DenseSparseMatrix)) {
            throw new IllegalStateException("Invalid protobuf, label word probs must be a sparse matrix, found " + weights.getClass());
        }
        DenseSparseMatrix labelWordProbs = (DenseSparseMatrix) weights;
        if (labelWordProbs.getDimension1Size() != outputDomain.size()) {
            throw new IllegalStateException("Invalid protobuf, labelWordProbs not the right size, expected " + outputDomain.size() + ", found " + labelWordProbs.getDimension1Size());
        }
        if (labelWordProbs.getDimension2Size() != carrier.featureDomain().size()) {
            throw new IllegalStateException("Invalid protobuf, labelWordProbs not the right size, expected " + carrier.featureDomain().size() + ", found " + labelWordProbs.getDimension2Size());
        }

        double alpha = proto.getAlpha();
        // NaN slips past a plain "< 0" comparison, so test for it explicitly.
        if (alpha < 0.0 || Double.isNaN(alpha)) {
            throw new IllegalStateException("Invalid protobuf, alpha must be non-negative, found " + alpha);
        }
        return new MultinomialNaiveBayesModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain, labelWordProbs, alpha);
    }

    /**
     * Predicts the label of the supplied example.
     * <p>
     * Each label's score is the product of the example's feature vector with that label's row of
     * {@code labelWordProbs}. When {@code alpha > 0}, the score is corrected for features that the
     * example contains but the label's sparse row does not store. Scores are normalised with an
     * {@link ExpNormalizer} and the prediction is flagged as probabilistic.
     *
     * @param example The example to predict.
     * @return A prediction whose output is the highest scoring label, with the full score distribution.
     * @throws IllegalArgumentException If the example has a negative feature value, or contains no
     *                                  features known to this model.
     */
    @Override
    public Prediction<Label> predict(Example<Label> example) {
        SparseVector exVector = SparseVector.createSparseVector(example, featureIDMap, false);
        if (exVector.minValue() < 0.0) {
            throw new IllegalArgumentException("Example has negative feature values, example = " + example.toString());
        }
        if (exVector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        // The label-by-feature matrix is sparse, so the smoothed weight of a feature a label never
        // stored is not part of the matrix product below. Add that portion of the inner product here.
        int numLabels = outputIDInfo.size();
        double[] alphaOffsets = new double[numLabels];
        int vocabSize = labelWordProbs.getDimension2Size();
        if (alpha > 0.0) {
            for (int i = 0; i < numLabels; i++) {
                SparseVector labelRow = labelWordProbs.getRow(i);
                // NOTE(review): if rows hold log-probabilities, oneNorm() sums |log p| rather than the
                // label's total feature count -- confirm this matches the trainer's smoothing denominator.
                double unobservedProb = Math.log(alpha / (labelRow.oneNorm() + (vocabSize * alpha)));
                // Presumably the indices active in the example but absent from this label's row.
                int[] mismatchedIndices = exVector.difference(labelRow);
                double inExampleFactor = 0.0;
                for (int idx : mismatchedIndices) {
                    inExampleFactor += exVector.get(idx) * unobservedProb;
                }
                alphaOffsets[i] = inExampleFactor;
            }
        }

        DenseVector prediction = labelWordProbs.leftMultiply(exVector);
        prediction.intersectAndAddInPlace(DenseVector.createDenseVector(alphaOffsets));
        prediction.normalize(normalizer);

        Map<String, Label> distribution = new LinkedHashMap<>();
        Label maxLabel = null;
        double maxScore = Double.NEGATIVE_INFINITY;
        for (VectorTuple vt : prediction) {
            String name = outputIDInfo.getOutput(vt.index).getLabel();
            Label label = new Label(name, vt.value);
            if (vt.value > maxScore) {
                maxScore = vt.value;
                maxLabel = label;
            }
            distribution.put(name, label);
        }
        return new Prediction<>(maxLabel, distribution, exVector.numActiveElements(), example, true);
    }

    /**
     * Returns, for each label, the features with the largest stored weights, sorted by descending weight.
     *
     * @param n The maximum number of features to return per label; a negative value returns every
     *          stored feature.
     * @return A map from label name to its (feature name, weight) pairs.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        for (Pair<Integer, Label> label : outputIDInfo) {
            int labelId = label.getA();
            List<Pair<String, Double>> features = new ArrayList<>(labelWordProbs.numActiveElements(labelId));
            for (VectorTuple vt : labelWordProbs.getRow(labelId)) {
                features.add(new Pair<>(featureIDMap.get(vt.index).getName(), vt.value));
            }
            features.sort(Comparator.comparing(x -> -x.getB()));
            // Rows are sparse and may hold fewer than maxFeatures entries, so compare against the
            // list size (not the feature map size) to avoid subList throwing IndexOutOfBoundsException.
            if (maxFeatures < features.size()) {
                features = new ArrayList<>(features.subList(0, maxFeatures));
            }
            topFeatures.put(label.getB().getLabel(), features);
        }
        return topFeatures;
    }

    /**
     * Explains a prediction by listing, for every label, the stored weight of each example feature
     * that is present in this model's feature map.
     *
     * @param example The example to explain.
     * @return An excuse containing the prediction and the per-label feature weights.
     */
    @Override
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        Map<String, List<Pair<String, Double>>> explanation = new HashMap<>();
        for (Pair<Integer, Label> label : outputIDInfo) {
            SparseVector labelRow = labelWordProbs.getRow(label.getA());
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature f : example) {
                int id = featureIDMap.getID(f.getName());
                // Features unknown to the model have a negative id and are skipped.
                if (id > -1) {
                    scores.add(new Pair<>(f.getName(), labelRow.get(id)));
                }
            }
            explanation.put(label.getB().getLabel(), scores);
        }
        return Optional.of(new Excuse<>(example, predict(example), explanation));
    }

    /**
     * Copies this model with a new name and provenance, duplicating the weight matrix.
     */
    @Override
    protected MultinomialNaiveBayesModel copy(String newName, ModelProvenance newProvenance) {
        return new MultinomialNaiveBayesModel(newName, newProvenance, featureIDMap, outputIDInfo, new DenseSparseMatrix(labelWordProbs), alpha);
    }

    /**
     * Serializes this model to a {@link ModelProto} at {@link #CURRENT_VERSION}.
     *
     * @return The protobuf representation of this model.
     */
    @Override
    public ModelProto serialize() {
        ModelDataCarrier<Label> carrier = createDataCarrier();
        MultinomialNaiveBayesProto.Builder modelBuilder = MultinomialNaiveBayesProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setLabelWordProbs(labelWordProbs.serialize());
        modelBuilder.setAlpha(alpha);

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(MultinomialNaiveBayesModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }
}

