/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.naive_bayes.BaseNaiveBayes;
import ai.sklearn4j.utils.ExtMath;
import ai.sklearn4j.utils.Preprocessing;

/**
 * Naive Bayes classifier state for binary (Bernoulli) features.
 *
 * <p>The model parameters (feature log probabilities, class log prior and
 * feature counts) are injected through the setters; this class only computes
 * the joint log likelihood from them. Inputs are binarized with
 * {@link #getBinarizationThreshold()} before scoring.
 */
public class BernoulliNaiveBayes extends BaseNaiveBayes {
    private NumpyArray<Double> featureLogProbabilities = null;
    private NumpyArray<Double> classLogPrior = null;
    private NumpyArray<Double> featureCounts = null;
    private double binarizationThreshold = 0.0;

    /**
     * Computes the joint log likelihood of the binarized input.
     *
     * <p>With {@code logP = featureLogProbabilities} and
     * {@code negProb = log(1 - exp(logP))}, the result is
     * {@code dot(x, (logP - negProb)^T) + classLogPrior + sum(negProb, 1)}.
     *
     * @param x input samples; index 1 of its shape is compared against index 1
     *          of the feature log probability shape as the feature count
     * @return the joint log likelihood array produced by the expression above
     * @throws ScikitLearnCoreException if the feature log probabilities or the
     *         class log prior have not been set, or if the feature count of
     *         {@code x} does not match the model
     */
    @Override
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> x) {
        // Both parameters default to null until a model is loaded; fail with a
        // descriptive error instead of a bare NullPointerException.
        if (this.featureLogProbabilities == null || this.classLogPrior == null) {
            throw new ScikitLearnCoreException("The model is not initialized: featureLogProbabilities and classLogPrior must be set before computing the joint log likelihood.");
        }

        NumpyArray<Double> binarized = Preprocessing.binarizeInput(x, this.binarizationThreshold);

        int expectedFeatures = this.featureLogProbabilities.getShape()[1];
        int actualFeatures = binarized.getShape()[1];
        if (expectedFeatures != actualFeatures) {
            throw new ScikitLearnCoreException(String.format("Expected input with %d features, got %d instead.", expectedFeatures, actualFeatures));
        }

        // negProb = log(1 - p), where p = exp(featureLogProbabilities).
        NumpyArray<Double> featureProbabilities = Numpy.exp(this.featureLogProbabilities);
        NumpyArray<Double> negProb = Numpy.log(Numpy.add(Numpy.multiply(featureProbabilities, -1.0), 1.0));

        // Contribution of the features: x . (logP - negProb)^T.
        NumpyArray<Double> featureTerm = ExtMath.dot(binarized, Numpy.subtract(this.featureLogProbabilities, negProb).transpose());

        // Input-independent contribution: class prior plus the summed negProb
        // (presumably one sum per class along axis 1 - confirm against Numpy.sum).
        NumpyArray<Double> classTerm = Numpy.add(this.classLogPrior, Numpy.sum(negProb, 1));

        return Numpy.add(featureTerm, classTerm);
    }

    /**
     * @return the feature log probabilities, or {@code null} if not set
     */
    public NumpyArray<Double> getFeatureLogProbabilities() {
        return this.featureLogProbabilities;
    }

    /**
     * @param featureLogProbabilities the feature log probabilities to use
     */
    public void setFeatureLogProbabilities(NumpyArray<Double> featureLogProbabilities) {
        this.featureLogProbabilities = featureLogProbabilities;
    }

    /**
     * @return the class log prior, or {@code null} if not set
     */
    public NumpyArray<Double> getClassLogPrior() {
        return this.classLogPrior;
    }

    /**
     * @param classLogPrior the class log prior to use
     */
    public void setClassLogPrior(NumpyArray<Double> classLogPrior) {
        this.classLogPrior = classLogPrior;
    }

    /**
     * @return the feature counts, or {@code null} if not set
     */
    public NumpyArray<Double> getFeatureCounts() {
        return this.featureCounts;
    }

    /**
     * Sets the feature counts. Named to match {@link #getFeatureCounts()}.
     *
     * @param featureCounts the feature counts to store
     */
    public void setFeatureCounts(NumpyArray<Double> featureCounts) {
        this.featureCounts = featureCounts;
    }

    /**
     * Sets the feature counts. Retained for backward compatibility with
     * existing callers; equivalent to {@link #setFeatureCounts(NumpyArray)}.
     *
     * @param featureCounts the feature counts to store
     */
    public void setFeatureCount(NumpyArray<Double> featureCounts) {
        setFeatureCounts(featureCounts);
    }

    /**
     * @return the threshold passed to the input binarization step
     */
    public double getBinarizationThreshold() {
        return this.binarizationThreshold;
    }

    /**
     * @param binarizationThreshold the threshold passed to the input
     *        binarization step
     */
    public void setBinarizationThreshold(double binarizationThreshold) {
        this.binarizationThreshold = binarizationThreshold;
    }
}

