/*
 * Decompiled with CFR 0.152.
 */
package hex.naivebayes;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.SupervisedModel;
import hex.SupervisedModelBuilder;
import hex.naivebayes.NaiveBayesModel;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.NaiveBayesV3;
import java.util.ArrayList;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

/**
 * Model builder for H2O Naive Bayes classifiers.
 *
 * <p>Training makes one distributed pass ({@code NBTask}) over the adapted training frame. For every
 * response level, that pass counts how often each categorical predictor level occurs and
 * accumulates the sum and sum of squares of each numeric predictor. Rows containing a missing
 * value in any column are skipped. {@code NaiveBayesDriver} then turns these statistics into:
 * <ul>
 *   <li>Laplace-smoothed prior (apriori) class probabilities;</li>
 *   <li>Laplace-smoothed conditional probability tables for categorical predictors;</li>
 *   <li>the per-class mean and sample standard deviation for numeric predictors.</li>
 * </ul>
 */
public class NaiveBayes
extends SupervisedModelBuilder<NaiveBayesModel, NaiveBayesModel.NaiveBayesParameters, NaiveBayesModel.NaiveBayesOutput> {
    /** Returns a new {@code NaiveBayesV3} schema instance for this builder. */
    public ModelBuilderSchema schema() {
        return new NaiveBayesV3();
    }

    /** Launches a {@code NaiveBayesDriver} through {@code start} and returns the resulting job. */
    public Job<NaiveBayesModel> trainModel() {
        return this.start(new NaiveBayesDriver(), 0L);
    }

    /**
     * Reports the model categories this builder can produce.
     *
     * <p>NOTE(review): only {@code Unknown} is reported even though the models are classifiers;
     * confirm whether callers depend on this before changing it.
     */
    public Model.ModelCategory[] can_build() {
        return new Model.ModelCategory[]{Model.ModelCategory.Unknown};
    }

    /** Marks this builder as {@code Stable}. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /**
     * Rejects the build when the estimated size of the conditional-probability tables exceeds
     * {@code H2O.SELF.get_max_mem()}.
     *
     * <p>The estimate is {@code numPredictors * numResponseLevels * sum(levelsPerPredictor) * 8}
     * bytes. Each numeric predictor (null domain) counts as two slots: mean and standard deviation.
     * The response is assumed to be the last column of {@code _train}. On failure, a validation
     * error is recorded on {@code _train} and the job is cancelled.
     *
     * <p>NOTE(review): the table actually allocated holds {@code numResponseLevels *
     * sum(levelsPerPredictor)} doubles. The extra {@code numPredictors} factor makes this estimate
     * conservative; confirm whether that margin is intentional.
     */
    protected void checkMemoryFootPrint() {
        final int npreds = this._train.numCols() - 1;
        // Widen to long BEFORE multiplying. The int product overflows for wide frames with many
        // response levels, and a wrapped value would let the memory guard pass spuriously.
        long mem_usage = (long) npreds * this._train.lastVec().cardinality();
        String[][] domains = this._train.domains();
        long count = 0L;
        for (int i = 0; i < npreds; ++i) {
            count += domains[i] == null ? 2L : (long) domains[i].length;
        }
        mem_usage *= count;
        mem_usage *= 8L; // 8 bytes per double
        long max_mem = H2O.SELF.get_max_mem();
        if (mem_usage > max_mem) {
            String msg = "Conditional probabilities won't fit in the driver node's memory (" + PrettyPrint.bytes((long)mem_usage) + " > " + PrettyPrint.bytes((long)max_mem) + ") - try reducing the number of columns, the number of response classes or the number of categorical factors of the predictors.";
            this.error("_train", msg);
            this.cancel(msg);
        }
    }

    /**
     * Creates the builder and runs the cheap (non-expensive) validation of {@code parms}.
     *
     * @param parms Naive Bayes parameters
     */
    public NaiveBayes(NaiveBayesModel.NaiveBayesParameters parms) {
        super("NaiveBayes", (SupervisedModel.SupervisedParameters)parms);
        this.init(false);
    }

    /**
     * Validates the Naive Bayes parameters on top of the generic supervised-builder checks.
     *
     * <p>The response must be categorical. The smoothing and threshold parameters must be within
     * range. Class balancing options are hidden because they do not apply here.
     *
     * @param expensive when true and no validation errors were recorded, also runs
     *                  {@link #checkMemoryFootPrint()}
     */
    public void init(boolean expensive) {
        super.init(expensive);
        if (this._response != null && !this._response.isEnum()) {
            this.error("_response", "Response must be a categorical column");
        }
        NaiveBayesModel.NaiveBayesParameters parms = (NaiveBayesModel.NaiveBayesParameters) this._parms;
        // NOTE(review): the message says "integer" but fractional values >= 0 are accepted.
        if (parms._laplace < 0.0) {
            this.error("_laplace", "Laplace smoothing must be an integer >= 0");
        }
        if (parms._min_sdev < 1.0E-10) {
            this.error("_min_sdev", "Min. standard deviation must be at least 1e-10");
        }
        if (parms._eps_sdev < 0.0) {
            this.error("_eps_sdev", "Threshold for standard deviation must be positive");
        }
        if (parms._min_prob < 1.0E-10) {
            this.error("_min_prob", "Min. probability must be at least 1e-10");
        }
        if (parms._eps_prob < 0.0) {
            this.error("_eps_prob", "Threshold for probability must be positive");
        }
        this.hide("_balance_classes", "Balance classes is not applicable to NaiveBayes.");
        this.hide("_class_sampling_factors", "Class sampling factors is not applicable to NaiveBayes.");
        this.hide("_max_after_balance_size", "Max after balance size is not applicable to NaiveBayes.");
        if (expensive && this.error_count() == 0) {
            this.checkMemoryFootPrint();
        }
    }

    /**
     * Builds the one-row "Model Summary" table. It reports the number of response levels and the
     * smallest and largest prior (apriori) class probability.
     *
     * @param output model output whose {@code _apriori_raw} is already filled (must be non-empty)
     * @return the summary table
     */
    private TwoDimTable createModelSummaryTable(NaiveBayesModel.NaiveBayesOutput output) {
        String[] colHeaders = {"Number of Response Levels", "Min Apriori Probability", "Max Apriori Probability"};
        String[] colTypes = {"long", "double", "double"};
        String[] colFormat = {"%d", "%.5f", "%.5f"};
        double[] apriori = output._apriori_raw;
        double apriori_min = apriori[0];
        double apriori_max = apriori[0];
        for (int i = 1; i < apriori.length; ++i) {
            // Explicit comparisons (not Math.min/max) so NaN entries are ignored rather than propagated.
            if (apriori[i] < apriori_min) {
                apriori_min = apriori[i];
            } else if (apriori[i] > apriori_max) {
                apriori_max = apriori[i];
            }
        }
        TwoDimTable table = new TwoDimTable("Model Summary", null, new String[1], colHeaders, colTypes, colFormat, "");
        table.set(0, 0, (Object)apriori.length);
        table.set(0, 1, (Object)apriori_min);
        table.set(0, 2, (Object)apriori_max);
        return table;
    }

    /**
     * Distributed pass that gathers the sufficient statistics for Naive Bayes.
     *
     * <p>The adapted frame is expected to hold the categorical predictors first, then the numeric
     * predictors, then the response as the last column. Rows with a NaN in any column are skipped.
     */
    private static class NBTask
    extends MRTask<NBTask> {
        final DataInfo _dinfo;
        final String[][] _domains;
        final int _nrescat;          // number of response levels
        final int _npreds;           // number of predictor columns (response excluded)
        public int _nobs;            // number of complete rows seen
        public int[] _rescnt;        // [response level] -> row count
        public int[][][] _jntcnt;    // [categorical predictor][response level][predictor level] -> row count
        public double[][][] _jntsum; // [numeric predictor][response level][0: sum of x, 1: sum of x^2]

        public NBTask(DataInfo dinfo, int nres) {
            this._dinfo = dinfo;
            this._nrescat = nres;
            this._domains = dinfo._adaptedFrame.domains();
            this._npreds = dinfo._adaptedFrame.numCols() - 1;
            assert (this._npreds == dinfo._nums + dinfo._cats);
            assert (this._nrescat == this._domains[this._npreds].length);
        }

        /** Accumulates per-chunk counts and sums over all complete rows. */
        public void map(Chunk[] chks) {
            this._nobs = 0;
            this._rescnt = new int[this._nrescat];
            if (this._dinfo._cats > 0) {
                this._jntcnt = new int[this._dinfo._cats][][];
                for (int i = 0; i < this._dinfo._cats; ++i) {
                    this._jntcnt[i] = new int[this._nrescat][this._domains[i].length];
                }
            }
            if (this._dinfo._nums > 0) {
                this._jntsum = new double[this._dinfo._nums][][];
                for (int i = 0; i < this._dinfo._nums; ++i) {
                    this._jntsum[i] = new double[this._nrescat][2];
                }
            }
            Chunk res = chks[this._npreds];
            for (int row = 0; row < chks[0]._len; ++row) {
                if (hasMissing(chks, row)) {
                    continue;
                }
                int rlevel = (int)res.atd(row);
                for (int col = 0; col < this._dinfo._cats; ++col) {
                    int plevel = (int)chks[col].atd(row);
                    this._jntcnt[col][rlevel][plevel]++;
                }
                for (int col = 0; col < this._dinfo._nums; ++col) {
                    double x = chks[this._dinfo._cats + col].atd(row);
                    double[] acc = this._jntsum[col][rlevel];
                    acc[0] += x;
                    acc[1] += x * x;
                }
                this._rescnt[rlevel]++;
                ++this._nobs;
            }
        }

        /** Returns true when any chunk holds a NaN (missing value) at {@code row}. */
        private static boolean hasMissing(Chunk[] chks, int row) {
            for (Chunk c : chks) {
                if (Double.isNaN(c.atd(row))) {
                    return true;
                }
            }
            return false;
        }

        /** Adds another task's partial statistics into this one, element-wise. */
        public void reduce(NBTask nt) {
            this._nobs += nt._nobs;
            ArrayUtils.add((int[])this._rescnt, (int[])nt._rescnt);
            if (null != this._jntcnt) {
                for (int col = 0; col < this._jntcnt.length; ++col) {
                    ArrayUtils.add((int[][])this._jntcnt[col], (int[][])nt._jntcnt[col]);
                }
            }
            if (null != this._jntsum) {
                for (int col = 0; col < this._jntsum.length; ++col) {
                    ArrayUtils.add((double[][])this._jntsum[col], (double[][])nt._jntsum[col]);
                }
            }
        }
    }

    /** Completer that validates, runs {@code NBTask}, fills the model and releases locks. */
    class NaiveBayesDriver
    extends H2O.H2OCountedCompleter<NaiveBayesDriver> {
        NaiveBayesDriver() {
        }

        /**
         * Converts the statistics gathered by {@code tsk} into the model's output.
         *
         * <p>Outputs are the raw and tabular priors, the conditional tables and the model summary.
         * When requested, training and validation metrics are also computed.
         *
         * <p>NOTE(review): this method indexes {@code out._domains} and {@code _train} columns with
         * adapted-frame positions. It therefore assumes {@code _train} already has categoricals
         * first, in the same order as {@code dinfo._adaptedFrame}; confirm this upstream.
         *
         * @param model the locked model being built
         * @param dinfo data layout used by {@code tsk}
         * @param tsk   completed statistics task
         */
        public void computeStatsFillModel(NaiveBayesModel model, DataInfo dinfo, NBTask tsk) {
            NaiveBayesModel.NaiveBayesParameters parms = (NaiveBayesModel.NaiveBayesParameters) NaiveBayes.this._parms;
            NaiveBayesModel.NaiveBayesOutput out = (NaiveBayesModel.NaiveBayesOutput) model._output;
            double laplace = parms._laplace;

            out._levels = NaiveBayes.this._response.domain();
            out._rescnt = tsk._rescnt;
            out._ncats = dinfo._cats;
            String[][] domains = out._domains;

            // pcond[pred][responseLevel][slot]: a categorical predictor has one slot per level.
            // A numeric predictor (null domain) has two slots: mean and standard deviation.
            double[] apriori = new double[tsk._nrescat];
            double[][][] pcond = new double[tsk._npreds][][];
            for (int i = 0; i < pcond.length; ++i) {
                int ncnt = domains[i] == null ? 2 : domains[i].length;
                pcond[i] = new double[tsk._nrescat][ncnt];
            }

            // Laplace-smoothed prior: (count + L) / (nobs + K * L)
            for (int i = 0; i < apriori.length; ++i) {
                apriori[i] = ((double)tsk._rescnt[i] + laplace) / ((double)tsk._nobs + (double)tsk._nrescat * laplace);
            }

            // Laplace-smoothed P(predictor level | response level)
            for (int col = 0; col < dinfo._cats; ++col) {
                assert (pcond[col].length == tsk._nrescat);
                for (int i = 0; i < pcond[col].length; ++i) {
                    for (int j = 0; j < pcond[col][i].length; ++j) {
                        pcond[col][i][j] = ((double)tsk._jntcnt[col][i][j] + laplace) / ((double)tsk._rescnt[i] + (double)domains[col].length * laplace);
                    }
                }
            }

            // Per-class mean and sample standard deviation (n - 1 denominator) of numeric predictors
            for (int col = 0; col < dinfo._nums; ++col) {
                int cidx = dinfo._cats + col;
                for (int i = 0; i < pcond[0].length; ++i) {
                    double num = tsk._rescnt[i];
                    double pmean = tsk._jntsum[col][i][0] / num;
                    double pvar = tsk._jntsum[col][i][1] / (num - 1.0) - pmean * pmean * num / (num - 1.0);
                    pcond[cidx][i][0] = pmean;
                    // The one-pass formula cancels catastrophically. A constant column can leave a
                    // tiny negative residue, whose sqrt would be NaN, so clamp it to 0.
                    // Math.max still propagates the NaN produced when num <= 1.
                    pcond[cidx][i][1] = Math.sqrt(Math.max(0.0, pvar));
                }
            }

            out._apriori_raw = apriori;
            out._pcond_raw = pcond;
            out._pcond = new TwoDimTable[pcond.length];
            String[] rowNames = NaiveBayes.this._response.domain();
            // NOTE(review): "%5f" is width 5 with default precision. The summary table uses
            // "%.5f"; confirm which was intended before changing the display.
            for (int col = 0; col < dinfo._cats; ++col) {
                String[] colNames = NaiveBayes.this._train.vec(col).domain();
                String[] colTypes = new String[colNames.length];
                String[] colFormats = new String[colNames.length];
                Arrays.fill(colTypes, "double");
                Arrays.fill(colFormats, "%5f");
                out._pcond[col] = new TwoDimTable(NaiveBayes.this._train.name(col), null, rowNames, colNames, colTypes, colFormats, "Y / " + NaiveBayes.this._train.name(col), new String[rowNames.length][], pcond[col]);
            }
            for (int col = 0; col < dinfo._nums; ++col) {
                int cidx = dinfo._cats + col;
                out._pcond[cidx] = new TwoDimTable(NaiveBayes.this._train.name(cidx), null, rowNames, new String[]{"Mean", "Std_Dev"}, new String[]{"double", "double"}, new String[]{"%5f", "%5f"}, "Y / " + NaiveBayes.this._train.name(cidx), new String[rowNames.length][], pcond[cidx]);
            }

            String[] aprioriTypes = new String[NaiveBayes.this._response.cardinality()];
            String[] aprioriFormats = new String[NaiveBayes.this._response.cardinality()];
            Arrays.fill(aprioriTypes, "double");
            Arrays.fill(aprioriFormats, "%5f");
            out._apriori = new TwoDimTable("Y", null, new String[1], NaiveBayes.this._response.domain(), aprioriTypes, aprioriFormats, "", new String[1][], new double[][]{apriori});
            out._model_summary = NaiveBayes.this.createModelSummaryTable(out);

            if (parms._compute_metrics) {
                // Predictions are discarded. The metrics are read back as the newest entry of
                // _model_metrics (presumably appended by score(); verify in Model).
                model.score(parms.train()).delete();
                out._training_metrics = (ModelMetricsSupervised)DKV.getGet((Key)out._model_metrics[out._model_metrics.length - 1]);
            }
            if (NaiveBayes.this._valid != null) {
                Frame pred = model.score(parms.valid());
                out._validation_metrics = (ModelMetrics)DKV.getGet((Key)out._model_metrics[out._model_metrics.length - 1]);
                pred.delete();
            }
        }

        /**
         * Runs the whole build and always releases locks and the {@code DataInfo}.
         *
         * <p>The build validates parameters, builds the {@code DataInfo}, runs {@code NBTask},
         * fills the model and marks the job done. A failure in a cancelled job is only logged.
         * Any other failure marks the job failed and is rethrown.
         */
        protected void compute2() {
            NaiveBayesModel model = null;
            DataInfo dinfo = null;
            try {
                ((NaiveBayesModel.NaiveBayesParameters)NaiveBayes.this._parms).read_lock_frames((Job)NaiveBayes.this);
                NaiveBayes.this.init(true);
                if (NaiveBayes.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + NaiveBayes.this.validationErrors());
                }
                dinfo = new DataInfo(Key.make(), NaiveBayes.this._train, NaiveBayes.this._valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false);
                model = new NaiveBayesModel(NaiveBayes.this.dest(), (NaiveBayesModel.NaiveBayesParameters)NaiveBayes.this._parms, new NaiveBayesModel.NaiveBayesOutput(NaiveBayes.this));
                model.delete_and_lock(NaiveBayes.this._key);
                NaiveBayes.this._train.read_lock(NaiveBayes.this._key);
                NBTask tsk = (NBTask)new NBTask(dinfo, NaiveBayes.this._response.cardinality()).doAll(dinfo._adaptedFrame);
                this.computeStatsFillModel(model, dinfo, tsk);
                model.update(NaiveBayes.this._key);
                NaiveBayes.this.done();
            }
            catch (Throwable t) {
                // The job may no longer be in the DKV. Guard against null so that an NPE here
                // does not replace the original failure.
                Job thisJob = (Job)DKV.getGet((Key)NaiveBayes.this._key);
                if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                    Log.info("Job cancelled by user.");
                } else {
                    t.printStackTrace();
                    NaiveBayes.this.failed(t);
                    throw t;
                }
            }
            finally {
                // NOTE(review): this unlock runs even when the _train.read_lock above was never
                // reached (e.g. on validation errors). Confirm Lockable.unlock tolerates that.
                NaiveBayes.this._train.unlock(NaiveBayes.this._key);
                if (model != null) {
                    model.unlock(NaiveBayes.this._key);
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
                ((NaiveBayesModel.NaiveBayesParameters)NaiveBayes.this._parms).read_unlock_frames((Job)NaiveBayes.this);
            }
            this.tryComplete();
        }
    }
}

