/*
 * Decompiled with CFR 0.152.
 */
package hex.naivebayes;

import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.SupervisedModel;
import hex.SupervisedModelBuilder;
import hex.genmodel.GenModel;
import hex.naivebayes.NaiveBayes;
import hex.schemas.NaiveBayesModelV3;
import water.H2O;
import water.Key;
import water.api.ModelSchema;
import water.util.TwoDimTable;

public class NaiveBayesModel
extends SupervisedModel<NaiveBayesModel, NaiveBayesParameters, NaiveBayesOutput> {
    public NaiveBayesModel(Key selfKey, NaiveBayesParameters parms, NaiveBayesOutput output) {
        super(selfKey, (SupervisedModel.SupervisedParameters)parms, (SupervisedModel.SupervisedOutput)output);
    }

    public ModelSchema schema() {
        return new NaiveBayesModelV3();
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (((NaiveBayesOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Multinomial: {
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length, domain);
            }
        }
        throw H2O.unimpl();
    }

    protected double[] score0(double[] data, double[] preds) {
        double[] nums = new double[((NaiveBayesOutput)this._output)._levels.length];
        assert (preds.length == ((NaiveBayesOutput)this._output)._levels.length + 1);
        for (int rlevel = 0; rlevel < ((NaiveBayesOutput)this._output)._levels.length; ++rlevel) {
            int col;
            nums[rlevel] = Math.log(((NaiveBayesOutput)this._output)._apriori_raw[rlevel]);
            for (col = 0; col < ((NaiveBayesOutput)this._output)._ncats; ++col) {
                if (Double.isNaN(data[col])) continue;
                int plevel = (int)data[col];
                double prob = plevel < ((NaiveBayesOutput)this._output)._pcond_raw.length ? ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][plevel] : ((NaiveBayesParameters)this._parms)._laplace / ((double)((NaiveBayesOutput)this._output)._rescnt[rlevel] + ((NaiveBayesParameters)this._parms)._laplace * (double)((NaiveBayesOutput)this._output)._domains[col].length);
                int n = rlevel;
                nums[n] = nums[n] + Math.log(prob <= ((NaiveBayesParameters)this._parms)._eps_prob ? ((NaiveBayesParameters)this._parms)._min_prob : prob);
            }
            for (col = ((NaiveBayesOutput)this._output)._ncats; col < data.length; ++col) {
                double mean;
                if (Double.isNaN(data[col])) continue;
                double x = data[col];
                double d = mean = Double.isNaN(((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][0]) ? 0.0 : ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][0];
                double stddev = Double.isNaN(((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1]) ? 1.0 : (((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1] <= ((NaiveBayesParameters)this._parms)._eps_sdev ? ((NaiveBayesParameters)this._parms)._min_sdev : ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1]);
                double prob = Math.exp(-((x - mean) * (x - mean)) / (2.0 * stddev * stddev)) / (stddev * Math.sqrt(Math.PI * 2));
                int n = rlevel;
                nums[n] = nums[n] + Math.log(prob <= ((NaiveBayesParameters)this._parms)._eps_prob ? ((NaiveBayesParameters)this._parms)._min_prob : prob);
            }
        }
        for (int i = 0; i < nums.length; ++i) {
            double sum = 0.0;
            for (int j = 0; j < nums.length; ++j) {
                sum += Math.exp(nums[j] - nums[i]);
            }
            preds[i + 1] = 1.0 / sum;
        }
        preds[0] = GenModel.getPrediction((double[])preds, (double[])data, (double)this.defaultThreshold());
        return preds;
    }

    public static class NaiveBayesOutput
    extends SupervisedModel.SupervisedOutput {
        public TwoDimTable _apriori;
        public double[] _apriori_raw;
        public TwoDimTable[] _pcond;
        public double[][][] _pcond_raw;
        public int[] _rescnt;
        public String[] _levels;
        public int _ncats;

        public NaiveBayesOutput(NaiveBayes b) {
            super((SupervisedModelBuilder)b);
        }
    }

    public static class NaiveBayesParameters
    extends SupervisedModel.SupervisedParameters {
        public double _laplace = 0.0;
        public double _eps_sdev = 0.0;
        public double _min_sdev = 0.001;
        public double _eps_prob = 0.0;
        public double _min_prob = 0.001;
        public boolean _compute_metrics = true;
    }
}

