/*
 * Decompiled with CFR 0.152.
 */
package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.base.ClassifierMixin;
import ai.sklearn4j.core.libraries.Scipy;
import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;

/**
 * Shared prediction logic for naive Bayes classifiers.
 *
 * <p>Concrete subclasses supply only {@link #jointLogLikelihood(NumpyArray)}. Hard-label
 * prediction, log-probabilities and probabilities are all derived from that single matrix.
 * Every reduction below runs along axis 1, so that axis is treated as the class axis and axis 0
 * as the sample axis. NOTE(review): presumably shape (n_samples, n_classes) — confirm against
 * the subclasses.
 */
public abstract class BaseNaiveBayes extends ClassifierMixin {

    /**
     * Computes the unnormalized joint log-likelihood of every sample under every class.
     *
     * @param x the input samples
     * @return a matrix of joint log-likelihood scores; axis 1 is reduced over by the callers in
     *     this class
     */
    protected abstract NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> x);

    /**
     * Predicts, for each sample, the position along axis 1 of its highest joint log-likelihood.
     *
     * <p>NOTE(review): this returns the raw argmax index, not a mapped class label. Confirm
     * whether callers translate the index into the original label space.
     *
     * @param x the input samples
     * @return the argmax of the joint log-likelihood along axis 1
     */
    @Override
    public NumpyArray<Long> predict(NumpyArray<Double> x) {
        return Numpy.argmax(this.jointLogLikelihood(x), 1);
    }

    /**
     * Returns per-sample log-probabilities by normalizing the joint log-likelihood.
     *
     * <p>Each row is shifted by its own log-sum-exp. Exponentiating a row of the result therefore
     * sums to one, up to floating-point error.
     *
     * @param x the input samples
     * @return the joint log-likelihood minus its row-wise log-sum-exp
     */
    @Override
    public NumpyArray<Double> predictLogProbabilities(NumpyArray<Double> x) {
        final NumpyArray<Double> scores = this.jointLogLikelihood(x);
        final NumpyArray<Double> logNormalizer = Scipy.logSumExponent(scores, 1);
        // Lift the 1-D normalizer to a row, then transpose it into a column so the subtraction
        // broadcasts across each sample's row. This assumes Numpy.atLeast2D mirrors
        // numpy.atleast_2d.
        return Numpy.subtract(scores, Numpy.atLeast2D(logNormalizer).transpose());
    }

    /**
     * Returns per-sample probabilities by exponentiating {@link #predictLogProbabilities}.
     *
     * @param x the input samples
     * @return the element-wise exponential of the normalized log-probabilities
     */
    @Override
    public NumpyArray<Double> predictProbabilities(NumpyArray<Double> x) {
        final NumpyArray<Double> logProbabilities = this.predictLogProbabilities(x);
        return Numpy.exp(logProbabilities);
    }
}

