/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.mnb;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.mnb.MultinomialNaiveBayesModel;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

public class MultinomialNaiveBayesTrainer
implements Trainer<Label>,
WeightedExamples {
    @Config(description="Smoothing parameter.")
    private double alpha = 1.0;
    private int trainInvocationCount = 0;

    public MultinomialNaiveBayesTrainer() {
        this(1.0);
    }

    public MultinomialNaiveBayesTrainer(double alpha) {
        if (alpha <= 0.0) {
            throw new IllegalArgumentException("alpha parameter must be > 0");
        }
        this.alpha = alpha;
    }

    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, -1);
    }

    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableOutputInfo labelInfos = examples.getOutputIDInfo();
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        HashMap labelWeights = new HashMap();
        for (Pair label : labelInfos) {
            labelWeights.put((Integer)label.getA(), new HashMap());
        }
        for (Example ex : examples) {
            int idx = labelInfos.getID((Output)((Label)ex.getOutput()));
            Map featureMap = (Map)labelWeights.get(idx);
            double curWeight = ex.getWeight();
            for (Feature feat : ex) {
                if (feat.getValue() < 0.0) {
                    throw new IllegalStateException("Multinomial Naive Bayes requires non-negative features. Found feature " + feat.toString());
                }
                featureMap.merge(featureInfos.getID(feat.getName()), curWeight * feat.getValue(), Double::sum);
            }
        }
        if (invocationCount != -1) {
            this.setInvocationCount(invocationCount);
        }
        TrainerProvenance trainerProvenance = this.getProvenance();
        ModelProvenance provenance = new ModelProvenance(MultinomialNaiveBayesModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), trainerProvenance, runProvenance);
        ++this.trainInvocationCount;
        SparseVector[] labelVectors = new SparseVector[labelInfos.size()];
        for (int i = 0; i < labelInfos.size(); ++i) {
            SparseVector sv = SparseVector.createSparseVector((int)featureInfos.size(), (Map)((Map)labelWeights.get(i)));
            double unsmoothedZ = sv.oneNorm();
            sv.foreachInPlace(d -> Math.log((d + this.alpha) / (unsmoothedZ + (double)featureInfos.size() * this.alpha)));
            labelVectors[i] = sv;
        }
        DenseSparseMatrix labelWordProbs = DenseSparseMatrix.createFromSparseVectors((SparseVector[])labelVectors);
        return new MultinomialNaiveBayesModel("", provenance, featureInfos, (ImmutableOutputInfo<Label>)labelInfos, labelWordProbs, this.alpha);
    }

    public int getInvocationCount() {
        return this.trainInvocationCount;
    }

    public void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCount = invocationCount;
    }

    public String toString() {
        return "MultinomialNaiveBayesTrainer(alpha=" + this.alpha + ")";
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }
}

