/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import ai.sklearn4j.naive_bayes.BaseNaiveBayes;
import java.util.List;

/**
 * Categorical naive Bayes classifier for features whose values are integer category codes.
 *
 * <p>This class contains no training code: the parameters it reads when scoring samples are
 * supplied through its setters.
 */
public class CategoricalNaiveBayes
extends BaseNaiveBayes {
    /**
     * One log-probability table per feature; table {@code i} is indexed as
     * {@code [class][category]} for feature {@code i}.
     */
    private List<NumpyArray<Double>> featureLogProbabilities = null;

    /** Log prior of each class, added to every row of the joint log likelihood. */
    private NumpyArray<Double> classLogPrior = null;

    // NOTE(review): the two fields below are neither read nor written by this class and have no
    // accessors; presumably they are populated reflectively. TODO confirm they are still needed.
    private NumpyArray<Double> categoryCounts = null;
    private NumpyArray<Long> numberOfCategories = null;

    /**
     * Computes the unnormalized joint log likelihood {@code log P(c) + sum_i log P(x_i | c)}
     * for every sample and class.
     *
     * @param x samples, one row per sample and one column per feature; each value is treated
     *          as an integer category code (any fractional part is truncated)
     * @return an array with one row per sample and one column per class
     * @throws IllegalArgumentException if a value is NaN, negative, or not a category index
     *                                  known to the model for its feature
     */
    @Override
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> x) {
        int sampleCount = x.getShape()[0];
        int classCount = this.classCounts.getShape()[0];
        int featureCount = this.getNumberOfFeatures();

        // Accumulate in [class][sample] orientation and transpose once at the end. This avoids
        // building, transposing and adding a temporary NumpyArray for every feature, while
        // keeping the same per-element summation order (0.0 + feature 0 + feature 1 + ...).
        double[][] perClass = new double[classCount][sampleCount];
        for (int feature = 0; feature < featureCount; ++feature) {
            NumpyArray<Double> logProb = this.featureLogProbabilities.get(feature);
            int[] categories = this.categoryIndicesOfFeature(x, feature, logProb.getShape()[1]);
            for (int cls = 0; cls < classCount; ++cls) {
                double[] row = perClass[cls];
                for (int sample = 0; sample < sampleCount; ++sample) {
                    row[sample] += logProb.get(cls, categories[sample]);
                }
            }
        }

        NumpyArray<Double> jll = NumpyArrayFactory.from(perClass).transpose();
        return Numpy.add(jll, this.classLogPrior);
    }

    /**
     * Reads one column of {@code x} as integer category indices.
     *
     * @param x             samples, one row per sample
     * @param feature       index of the column to read
     * @param categoryCount number of categories the model has a log probability for
     * @return one category index per sample (any fractional part is truncated)
     * @throws IllegalArgumentException if a value is NaN or its truncated index falls
     *                                  outside {@code [0, categoryCount)}
     */
    private int[] categoryIndicesOfFeature(NumpyArray<Double> x, int feature, int categoryCount) {
        int[] indices = new int[x.getShape()[0]];
        for (int sample = 0; sample < indices.length; ++sample) {
            double value = x.get(sample, feature);
            int index = (int) value;
            // (int) NaN silently yields 0, so NaN is rejected explicitly; +/-Infinity saturate to
            // Integer.MAX_VALUE / Integer.MIN_VALUE and are caught by the range check.
            if (Double.isNaN(value) || index < 0 || index >= categoryCount) {
                throw new IllegalArgumentException("Feature " + feature + " of sample " + sample
                        + " has value " + value + ", which is not a category in [0, "
                        + categoryCount + ") known to the model");
            }
            indices[sample] = index;
        }
        return indices;
    }

    /**
     * Returns the per-class log prior.
     *
     * @return the stored array (not a copy); {@code null} until set
     */
    public NumpyArray<Double> getClassLogPrior() {
        return this.classLogPrior;
    }

    /**
     * Sets the per-class log prior used by {@link #jointLogLikelihood(NumpyArray)}.
     *
     * @param classLogPrior log prior of each class; stored by reference
     */
    public void setClassLogPrior(NumpyArray<Double> classLogPrior) {
        this.classLogPrior = classLogPrior;
    }

    /**
     * Returns the per-feature log-probability tables.
     *
     * @return the stored list (not a copy); {@code null} until set
     */
    public List<NumpyArray<Double>> getFeatureLogProbabilities() {
        return this.featureLogProbabilities;
    }

    /**
     * Sets the per-feature log-probability tables.
     *
     * @param featureLogProbabilities one table per feature, each indexed as
     *                                {@code [class][category]}; stored by reference
     */
    public void setFeatureLogProbabilities(List<NumpyArray<Double>> featureLogProbabilities) {
        this.featureLogProbabilities = featureLogProbabilities;
    }
}

