/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import ai.sklearn4j.core.libraries.numpy.wrappers.Dim2DoubleNumpyWrapper;
import ai.sklearn4j.naive_bayes.BaseNaiveBayes;

/**
 * Gaussian Naive Bayes classifier whose parameters (per-class means, variances and class
 * priors) are supplied through setters rather than learned here.
 *
 * <p>The per-class joint log-likelihood of each sample is computed by
 * {@link #jointLogLikelihood(NumpyArray)}. This mirrors the formula used by scikit-learn's
 * {@code GaussianNB._joint_log_likelihood}.
 */
public class GaussianNaiveBayes extends BaseNaiveBayes {
    /** Class prior probabilities; element {@code c} is used as P(class c) in the likelihood. */
    private NumpyArray<Double> classPriors = null;
    /**
     * Priors as stored by this model; not read by this class directly.
     * NOTE(review): presumably the user-specified {@code priors} parameter of sklearn's
     * GaussianNB — confirm against the loader.
     */
    private NumpyArray<Double> priors = null;
    /** Per-class, per-feature variances, indexed {@code [class][feature]}. */
    private NumpyArray<Double> sigma = null;
    /** Per-class, per-feature means, indexed {@code [class][feature]}. */
    private NumpyArray<Double> theta = null;

    /**
     * Computes the joint log-likelihood {@code log P(c) + log P(x | c)} of every sample for
     * every class, under an independent Gaussian model per feature.
     *
     * <p>For sample {@code i} and class {@code c} the value is
     * {@code log(prior[c]) - 0.5 * (sum_j log(2*pi*var[c][j]) + sum_j (x[i][j] - mean[c][j])^2 / var[c][j])}.
     *
     * @param x samples to score, read as {@code x.get(sample, feature)}; the number of samples is
     *     {@code x.getShape()[0]}
     * @return a {@code [sampleCount][classCount]} array of joint log-likelihoods
     */
    @Override
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> x) {
        int count = x.getShape()[0];
        // The number of classes comes from the base class's classCounts array.
        int classCount = this.classCounts.getShape()[0];
        int featureCount = this.sigma.getShape()[1];
        double[][] jointLogLikelihood = new double[count][classCount];
        double[][] variance = ((Dim2DoubleNumpyWrapper) this.sigma.getWrapper()).getArray();
        double[][] mean = ((Dim2DoubleNumpyWrapper) this.theta.getWrapper()).getArray();

        for (int cls = 0; cls < classCount; ++cls) {
            double[] classVariance = variance[cls];
            double[] classMean = mean[cls];

            // Sample-independent part of the Gaussian log-density: sum_j log(2*pi*var_j).
            double sumOfLogVariance = 0.0;
            for (int feature = 0; feature < featureCount; ++feature) {
                sumOfLogVariance += Math.log(Math.PI * 2 * classVariance[feature]);
            }
            double logPrior = Math.log(this.classPriors.get(cls));

            for (int i = 0; i < count; ++i) {
                double scaledSquaredDistance = 0.0;
                for (int feature = 0; feature < featureCount; ++feature) {
                    // Fetch (and unbox) the sample value once, then square by multiplication.
                    double diff = x.get(i, feature) - classMean[feature];
                    scaledSquaredDistance += diff * diff / classVariance[feature];
                }
                double logDensity = -0.5 * (sumOfLogVariance + scaledSquaredDistance);
                jointLogLikelihood[i][cls] = logDensity + logPrior;
            }
        }
        return NumpyArrayFactory.from(jointLogLikelihood);
    }

    /**
     * Returns the class prior probabilities used by {@link #jointLogLikelihood(NumpyArray)}.
     *
     * @return the class priors, or {@code null} if not set
     */
    public NumpyArray<Double> getClassPriors() {
        return this.classPriors;
    }

    /**
     * Sets the class prior probabilities used by {@link #jointLogLikelihood(NumpyArray)}.
     *
     * @param classPriors the class priors
     */
    public void setClassPriors(NumpyArray<Double> classPriors) {
        this.classPriors = classPriors;
    }

    /**
     * Returns the stored priors (not used by this class's likelihood computation).
     *
     * @return the priors, or {@code null} if not set
     */
    public NumpyArray<Double> getPriors() {
        return this.priors;
    }

    /**
     * Sets the stored priors (not used by this class's likelihood computation).
     *
     * @param priors the priors
     */
    public void setPriors(NumpyArray<Double> priors) {
        this.priors = priors;
    }

    /**
     * Returns the per-class, per-feature variances.
     *
     * @return the variances, or {@code null} if not set
     */
    public NumpyArray<Double> getSigma() {
        return this.sigma;
    }

    /**
     * Sets the per-class, per-feature variances; must be a 2-D double array indexed
     * {@code [class][feature]}.
     *
     * @param sigma the variances
     */
    public void setSigma(NumpyArray<Double> sigma) {
        this.sigma = sigma;
    }

    /**
     * Returns the per-class, per-feature means.
     *
     * @return the means, or {@code null} if not set
     */
    public NumpyArray<Double> getTheta() {
        return this.theta;
    }

    /**
     * Sets the per-class, per-feature means; must be a 2-D double array indexed
     * {@code [class][feature]}.
     *
     * @param theta the means
     */
    public void setTheta(NumpyArray<Double> theta) {
        this.theta = theta;
    }
}

