/*
 * Decompiled with CFR 0.152.
 */
package hex.naivebayes;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.genmodel.GenModel;
import hex.naivebayes.NaiveBayes;
import hex.schemas.NaiveBayesModelV3;
import water.H2O;
import water.Key;
import water.api.ModelSchema;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

public class NaiveBayesModel
extends Model<NaiveBayesModel, NaiveBayesParameters, NaiveBayesOutput> {
    public NaiveBayesModel(Key selfKey, NaiveBayesParameters parms, NaiveBayesOutput output) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
    }

    public ModelSchema schema() {
        return new NaiveBayesModelV3();
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (((NaiveBayesOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Multinomial: {
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(domain.length, domain);
            }
        }
        throw H2O.unimpl();
    }

    protected double[] score0(double[] data, double[] preds) {
        double[] nums = new double[((NaiveBayesOutput)this._output)._levels.length];
        assert (preds.length == ((NaiveBayesOutput)this._output)._levels.length + 1);
        for (int rlevel = 0; rlevel < ((NaiveBayesOutput)this._output)._levels.length; ++rlevel) {
            int col;
            nums[rlevel] = Math.log(((NaiveBayesOutput)this._output)._apriori_raw[rlevel]);
            for (col = 0; col < ((NaiveBayesOutput)this._output)._ncats; ++col) {
                if (Double.isNaN(data[col])) continue;
                int plevel = (int)data[col];
                double prob = plevel < ((NaiveBayesOutput)this._output)._pcond_raw.length ? ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][plevel] : ((NaiveBayesParameters)this._parms)._laplace / ((double)((NaiveBayesOutput)this._output)._rescnt[rlevel] + ((NaiveBayesParameters)this._parms)._laplace * (double)((NaiveBayesOutput)this._output)._domains[col].length);
                int n = rlevel;
                nums[n] = nums[n] + Math.log(prob <= ((NaiveBayesParameters)this._parms)._eps_prob ? ((NaiveBayesParameters)this._parms)._min_prob : prob);
            }
            for (col = ((NaiveBayesOutput)this._output)._ncats; col < data.length; ++col) {
                double mean;
                if (Double.isNaN(data[col])) continue;
                double x = data[col];
                double d = mean = Double.isNaN(((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][0]) ? 0.0 : ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][0];
                double stddev = Double.isNaN(((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1]) ? 1.0 : (((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1] <= ((NaiveBayesParameters)this._parms)._eps_sdev ? ((NaiveBayesParameters)this._parms)._min_sdev : ((NaiveBayesOutput)this._output)._pcond_raw[col][rlevel][1]);
                double prob = Math.exp(-((x - mean) * (x - mean)) / (2.0 * stddev * stddev)) / (stddev * Math.sqrt(Math.PI * 2));
                int n = rlevel;
                nums[n] = nums[n] + Math.log(prob <= ((NaiveBayesParameters)this._parms)._eps_prob ? ((NaiveBayesParameters)this._parms)._min_prob : prob);
            }
        }
        for (int i = 0; i < nums.length; ++i) {
            double sum = 0.0;
            for (int j = 0; j < nums.length; ++j) {
                sum += Math.exp(nums[j] - nums[i]);
            }
            preds[i + 1] = 1.0 / sum;
        }
        preds[0] = GenModel.getPrediction((double[])preds, (double[])((NaiveBayesOutput)this._output)._priorClassDist, (double[])data, (double)this.defaultThreshold());
        return preds;
    }

    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        sb = super.toJavaInit(sb, fileCtx);
        sb.ip("public boolean isSupervised() { return " + this.isSupervised() + "; }").nl();
        sb.ip("public int nfeatures() { return " + ((NaiveBayesOutput)this._output).nfeatures() + "; }").nl();
        sb.ip("public int nclasses() { return " + ((NaiveBayesOutput)this._output).nclasses() + "; }").nl();
        JCodeGen.toStaticVar((JCodeSB)sb, (String)"RESCNT", (int[])((NaiveBayesOutput)this._output)._rescnt, (String)"Count of categorical levels in response.");
        JCodeGen.toStaticVar((JCodeSB)sb, (String)"APRIORI", (double[])((NaiveBayesOutput)this._output)._apriori_raw, (String)"Apriori class distribution of the response.");
        JCodeGen.toStaticVar((JCodeSB)sb, (String)"PCOND", (double[][][])((NaiveBayesOutput)this._output)._pcond_raw, (String)"Conditional probability of predictors.");
        double[] dlen = null;
        if (((NaiveBayesOutput)this._output)._ncats > 0) {
            dlen = new double[((NaiveBayesOutput)this._output)._ncats];
            for (int i = 0; i < ((NaiveBayesOutput)this._output)._ncats; ++i) {
                dlen[i] = ((NaiveBayesOutput)this._output)._domains[i].length;
            }
        }
        JCodeGen.toStaticVar((JCodeSB)sb, (String)"DOMLEN", dlen, (String)"Number of unique levels for each categorical predictor.");
        return sb;
    }

    protected void toJavaPredictBody(SBPrintStream bodySb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
        bodySb.i().p("double mean, sdev, prob;").nl();
        bodySb.i().p("double[] nums = new double[" + ((NaiveBayesOutput)this._output)._levels.length + "];").nl();
        bodySb.i().p("for(int i = 0; i < " + ((NaiveBayesOutput)this._output)._levels.length + "; i++) {").nl();
        bodySb.i(1).p("nums[i] = Math.log(APRIORI[i]);").nl();
        bodySb.i(1).p("for(int j = 0; j < " + ((NaiveBayesOutput)this._output)._ncats + "; j++) {").nl();
        bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        bodySb.i(2).p("int level = (int)data[j];").nl();
        bodySb.i(2).p("prob = level < " + ((NaiveBayesOutput)this._output)._pcond_raw.length + " ? PCOND[j][i][level] : " + (((NaiveBayesParameters)this._parms)._laplace == 0.0 ? Integer.valueOf(0) : ((NaiveBayesParameters)this._parms)._laplace + "/(RESCNT[i] + " + ((NaiveBayesParameters)this._parms)._laplace + "*DOMLEN[j])")).p(";").nl();
        bodySb.i(2).p("nums[i] += Math.log(prob <= " + ((NaiveBayesParameters)this._parms)._eps_prob + " ? " + ((NaiveBayesParameters)this._parms)._min_prob + " : prob);").nl();
        bodySb.i(1).p("}").nl();
        bodySb.i(1).p("for(int j = " + ((NaiveBayesOutput)this._output)._ncats + "; j < data.length; j++) {").nl();
        bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        bodySb.i(2).p("mean = Double.isNaN(PCOND[j][i][0]) ? 0 : PCOND[j][i][0];").nl();
        bodySb.i(2).p("sdev = Double.isNaN(PCOND[j][i][1]) ? 1 : (PCOND[j][i][1] <= " + ((NaiveBayesParameters)this._parms)._eps_sdev + " ? " + ((NaiveBayesParameters)this._parms)._min_sdev + " : PCOND[j][i][1]);").nl();
        bodySb.i(2).p("prob = Math.exp(-((data[j]-mean)*(data[j]-mean))/(2.*sdev*sdev)) / (sdev*Math.sqrt(2.*Math.PI));").nl();
        bodySb.i(2).p("nums[i] += Math.log(prob <= " + ((NaiveBayesParameters)this._parms)._eps_prob + " ? " + ((NaiveBayesParameters)this._parms)._min_prob + " : prob);").nl();
        bodySb.i(1).p("}").nl();
        bodySb.i().p("}").nl();
        bodySb.i().p("double sum;").nl();
        bodySb.i().p("for(int i = 0; i < nums.length; i++) {").nl();
        bodySb.i(1).p("sum = 0;").nl();
        bodySb.i(1).p("for(int j = 0; j < nums.length; j++) {").nl();
        bodySb.i(2).p("sum += Math.exp(nums[j]-nums[i]);").nl();
        bodySb.i(1).p("}").nl();
        bodySb.i(1).p("preds[i+1] = 1/sum;").nl();
        bodySb.i().p("}").nl();
        bodySb.i().p("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + this.defaultThreshold() + ");").nl();
    }

    public static class NaiveBayesOutput
    extends Model.Output {
        public TwoDimTable _apriori;
        public double[] _apriori_raw;
        public TwoDimTable[] _pcond;
        public double[][][] _pcond_raw;
        public int[] _rescnt;
        public String[] _levels;
        public int _ncats;

        public NaiveBayesOutput(NaiveBayes b) {
            super((ModelBuilder)b);
        }
    }

    public static class NaiveBayesParameters
    extends Model.Parameters {
        public double _laplace = 0.0;
        public double _eps_sdev = 0.0;
        public double _min_sdev = 0.001;
        public double _eps_prob = 0.0;
        public double _min_prob = 0.001;
        public boolean _compute_metrics = true;
    }
}

