/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.uplift;

import hex.Model;
import hex.ModelCategory;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.Sample;
import hex.tree.Score;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import hex.tree.SharedTreeModel;
import hex.tree.uplift.UpliftDRFModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import org.apache.log4j.Logger;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

/**
 * Uplift Distributed Random Forest (Uplift DRF) model builder.
 *
 * <p>Each training iteration grows one tree on the training frame and then copies it. The model
 * therefore stores two trees with identical split structure per iteration:
 * <ul>
 *   <li>tree 0 carries treatment-group leaf predictions;</li>
 *   <li>tree 1 carries control-group leaf predictions.</li>
 * </ul>
 * Out-of-bag scoring reports the difference of the two as the uplift.
 *
 * <p>{@link #init(boolean)} rejects the following as unsupported:
 * <ul>
 *   <li>offsets, weights and cross-validation;</li>
 *   <li>custom distributions and custom metrics;</li>
 *   <li>early stopping.</li>
 * </ul>
 * MOJO/POJO export and variable importance are not provided.
 */
public class UpliftDRF
extends SharedTree<UpliftDRFModel, UpliftDRFModel.UpliftDRFParameters, UpliftDRFModel.UpliftDRFOutput> {
    private static final Logger LOG = Logger.getLogger(UpliftDRF.class);

    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms) {
        super(parms);
        this.init(false);
    }

    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms, Key<UpliftDRFModel> key) {
        super(parms, key);
        this.init(false);
    }

    public UpliftDRF(UpliftDRFModel.UpliftDRFParameters parms, Job job) {
        super(parms, job);
        this.init(false);
    }

    public UpliftDRF(boolean startup_once) {
        super(new UpliftDRFModel.UpliftDRFParameters(), startup_once);
    }

    /** Builder parameters viewed as the uplift-specific parameter type. */
    private UpliftDRFModel.UpliftDRFParameters upliftParms() {
        return (UpliftDRFModel.UpliftDRFParameters) this._parms;
    }

    /** Output of the model under construction, viewed as the uplift-specific output type. */
    private UpliftDRFModel.UpliftDRFOutput upliftOutput() {
        return (UpliftDRFModel.UpliftDRFOutput) ((UpliftDRFModel) this._model)._output;
    }

    @Override
    public boolean haveMojo() {
        return false;
    }

    @Override
    public boolean havePojo() {
        return false;
    }

    /** Uplift DRF only builds binomial uplift models. */
    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.BinomialUplift};
    }

    @Override
    protected Driver trainModelImpl() {
        return new UpliftDRFDriver();
    }

    @Override
    public boolean scoreZeroTrees() {
        return false;
    }

    @Override
    public boolean providesVarImp() {
        return false;
    }

    /**
     * Validates the parameters on top of the shared-tree checks.
     *
     * <p>Errors are reported for:
     * <ul>
     *   <li>invalid {@code mtries};</li>
     *   <li>offset, weight and fold columns, and {@code nfolds > 0};</li>
     *   <li>non-binomial problems;</li>
     *   <li>a missing treatment column;</li>
     *   <li>custom distribution or metric functions;</li>
     *   <li>any early-stopping configuration.</li>
     * </ul>
     * A warning is issued when {@code sample_rate == 1} and no validation frame is given.
     *
     * @param expensive forwarded to {@code super.init}
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        UpliftDRFModel.UpliftDRFParameters p = this.upliftParms();
        // -1 and -2 are sentinel values resolved later in initializeModelSpecifics().
        boolean mtriesIsSentinel = p._mtries == -1 || p._mtries == -2;
        if (!mtriesIsSentinel && p._mtries < 1) {
            this.error("_mtries", "mtries must be -1 (converted to sqrt(features)) or -2 (All features) or >= 1 but it is " + p._mtries);
        }
        if (this._train != null) {
            int ncols = this._train.numCols();
            if (!mtriesIsSentinel && (p._mtries < 1 || p._mtries >= ncols)) {
                this.error("_mtries", "Computed mtries should be -1 or -2 or in interval [1," + ncols + "[ but it is " + p._mtries);
            }
        }
        if (p._sample_rate == 1.0 && this._valid == null) {
            this.warn("_sample_rate", "Sample rate is 100% and no validation dataset. There are no out-of-bag data to compute error estimates on the training data!");
        }
        if (this.hasOffsetCol()) {
            this.error("_offset_column", "Offsets are not yet supported for Uplift DRF.");
        }
        if (this.hasWeightCol()) {
            this.error("_weight_column", "Weights are not yet supported for Uplift DRF.");
        }
        if (this.hasFoldCol()) {
            this.error("_fold_column", "Cross-validation is not yet supported for Uplift DRF.");
        }
        if (p._nfolds > 0) {
            this.error("_nfolds", "Cross-validation is not yet supported for Uplift DRF.");
        }
        if (this._nclass == 1) {
            this.error("_distribution", "UpliftDRF currently support binomial classification problems only.");
        }
        // Enum identity comparison: unlike equals(), this does not throw when _distribution is null.
        if (this._nclass > 2 || p._distribution == DistributionFamily.multinomial) {
            this.error("_distribution", "UpliftDRF currently does not support multinomial distribution.");
        }
        if (p._treatment_column == null) {
            this.error("_treatment_column", "The treatment column has to be defined.");
        }
        if (p._custom_distribution_func != null) {
            this.error("_custom_distribution_func", "The custom distribution is not yet supported for Uplift DRF.");
        }
        if (p._custom_metric_func != null) {
            this.error("_custom_metric_func", "The custom metric is not yet supported for Uplift DRF.");
        }
        if (p._stopping_metric != ScoreKeeper.StoppingMetric.AUTO) {
            this.error("_stopping_metric", "The early stopping is not yet supported for Uplift DRF.");
        }
        if (p._stopping_rounds != 0) {
            this.error("_stopping_rounds", "The early stopping is not yet supported for Uplift DRF.");
        }
    }

    /**
     * Writes the out-of-bag averaged predictions for one row into {@code fs}.
     *
     * <ul>
     *   <li>{@code fs[1]}: treatment prediction, the accumulated tree-0 value divided by the OOB
     *       weight counter;</li>
     *   <li>{@code fs[2]}: control prediction, the same for tree 1;</li>
     *   <li>{@code fs[0]}: uplift, {@code fs[1] - fs[2]}.</li>
     * </ul>
     * A row with a zero OOB counter yields NaN/Infinity from the division.
     *
     * @return always {@code 0.0}
     */
    @Override
    protected double score1(Chunk[] chks, double weight, double offset, double[] fs, int row) {
        double oobWeight = this.chk_oobt(chks).atd(row);
        fs[1] = weight * this.chk_tree(chks, 0).atd(row) / oobWeight;
        fs[2] = weight * this.chk_tree(chks, 1).atd(row) / oobWeight;
        fs[0] = fs[1] - fs[2];
        return 0.0;
    }

    /**
     * Grows one layer of {@code tree}, always treated as class index 0, and blocks until the
     * layer is done.
     *
     * <p>The working frame consists of:
     * <ul>
     *   <li>the first {@code _ncols + 2} columns of {@code fr};</li>
     *   <li>the response column, if it was not already among them;</li>
     *   <li>the tree, work and node-id columns of class 0.</li>
     * </ul>
     *
     * @return {@code hcs} if at least one node was split, otherwise {@code null}; also
     *     {@code null} when {@code tree} is {@code null}
     */
    protected DHistogram[][][] buildLayer(Frame fr, int nbins, DTree tree, int[] leafs, DHistogram[][][] hcs, boolean build_tree_one_node) {
        if (tree == null) {
            return null;
        }
        final int k = 0;
        UpliftDRFModel.UpliftDRFParameters p = this.upliftParms();
        Vec[] vecs = fr.vecs();
        int selectedCol = this._ncols + 2;
        Frame fr2 = new Frame(Arrays.copyOf(fr._names, selectedCol), Arrays.copyOf(vecs, selectedCol));
        if (this.isSupervised() && fr2.find(p._response_column) == -1) {
            fr2.add(p._response_column, fr.vec(p._response_column));
        }
        int respIdx = fr2.find(p._response_column);
        int weightIdx = fr2.find(p._weights_column);
        int treatmentIdx = fr2.find(p._treatment_column);
        int predsIdx = fr2.numCols();
        fr2.add(fr._names[this.idx_tree(k)], vecs[this.idx_tree(k)]);
        int workIdx = fr2.numCols();
        fr2.add(fr._names[this.idx_work(k)], vecs[this.idx_work(k)]);
        int nidIdx = fr2.numCols();
        fr2.add(fr._names[this.idx_nids(k)], vecs[this.idx_nids(k)]);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Building a layer for class " + k + ":\n" + fr2.toTwoDimTable());
        }
        SharedTree.ScoreBuildOneTree sb1t = new SharedTree.ScoreBuildOneTree(this, k, nbins, tree, leafs, hcs, fr2, build_tree_one_node, this._improvPerVar, ((UpliftDRFModel.UpliftDRFParameters) ((UpliftDRFModel) this._model)._parms)._distribution, respIdx, weightIdx, predsIdx, workIdx, nidIdx, treatmentIdx);
        H2O.submitTask((H2O.H2OCountedCompleter) sb1t);
        sb1t.join();
        if (LOG.isTraceEnabled()) {
            // Logged at TRACE to match the guard (was LOG.info under isTraceEnabled()).
            LOG.trace("Done with this layer for class " + k + ":\n" + new Frame(new String[]{"TREE", "WORK", "NIDS"}, new Vec[]{vecs[this.idx_tree(k)], vecs[this.idx_work(k)], vecs[this.idx_nids(k)]}).toTwoDimTable());
        }
        return sb1t._did_split ? hcs : null;
    }

    @Override
    protected TwoDimTable createScoringHistoryTable() {
        UpliftDRFModel.UpliftDRFOutput out = this.upliftOutput();
        return UpliftDRF.createUpliftScoringHistoryTable(out, out._scored_train, out._scored_valid, this._job, out._training_time_ms, this.upliftParms()._custom_metric_func != null);
    }

    /**
     * Builds the uplift scoring-history table.
     *
     * <p>Columns:
     * <ul>
     *   <li>timestamp, duration and number of trees;</li>
     *   <li>training AUUC nbins, AUUC, normalized AUUC and Qini value, plus a custom-metric
     *       column when requested;</li>
     *   <li>the same validation columns when {@code _output._validation_metrics} is set.</li>
     * </ul>
     * Iteration 0 is always listed. Other iterations are listed only if their training AUUC, or
     * their validation AUUC when validation keepers exist, is not NaN.
     */
    static TwoDimTable createUpliftScoringHistoryTable(Model.Output _output, ScoreKeeper[] _scored_train, ScoreKeeper[] _scored_valid, Job job, long[] _training_time_ms, boolean hasCustomMetric) {
        boolean hasValidation = _output._validation_metrics != null;
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormat = new ArrayList<String>();
        addHistoryColumn(colHeaders, colTypes, colFormat, "Timestamp", "string", "%s");
        addHistoryColumn(colHeaders, colTypes, colFormat, "Duration", "string", "%s");
        addHistoryColumn(colHeaders, colTypes, colFormat, "Number of Trees", "long", "%d");
        addMetricColumns(colHeaders, colTypes, colFormat, "Training", hasCustomMetric);
        if (hasValidation) {
            addMetricColumns(colHeaders, colTypes, colFormat, "Validation", hasCustomMetric);
        }
        int rows = 0;
        for (int i = 0; i < _scored_train.length; ++i) {
            if (isReportedIteration(_scored_train, _scored_valid, i)) {
                ++rows;
            }
        }
        TwoDimTable table = new TwoDimTable("Scoring History", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");
        // The formatter is immutable, so one instance serves every row.
        DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss");
        int row = 0;
        for (int i = 0; i < _scored_train.length; ++i) {
            if (!isReportedIteration(_scored_train, _scored_valid, i)) {
                continue;
            }
            int col = 0;
            table.set(row, col++, (Object) fmt.print(_training_time_ms[i]));
            table.set(row, col++, (Object) PrettyPrint.msecs(_training_time_ms[i] - job.start_time(), true));
            table.set(row, col++, (Object) i);
            col = fillMetricCells(table, row, col, _scored_train[i], hasCustomMetric);
            if (hasValidation) {
                fillMetricCells(table, row, col, _scored_valid[i], hasCustomMetric);
            }
            ++row;
        }
        return table;
    }

    /** Appends one column definition to the three parallel header/type/format lists. */
    private static void addHistoryColumn(ArrayList<String> headers, ArrayList<String> types, ArrayList<String> formats, String header, String type, String format) {
        headers.add(header);
        types.add(type);
        formats.add(format);
    }

    /**
     * Appends the metric columns for one data set, each header prefixed by {@code prefix}:
     * AUUC nbins, AUUC, AUUC normalized, Qini value, and optionally Custom.
     */
    private static void addMetricColumns(ArrayList<String> headers, ArrayList<String> types, ArrayList<String> formats, String prefix, boolean hasCustomMetric) {
        addHistoryColumn(headers, types, formats, prefix + " AUUC nbins", "int", "%d");
        addHistoryColumn(headers, types, formats, prefix + " AUUC", "double", "%.5f");
        addHistoryColumn(headers, types, formats, prefix + " AUUC normalized", "double", "%.5f");
        addHistoryColumn(headers, types, formats, prefix + " Qini value", "double", "%.5f");
        if (hasCustomMetric) {
            addHistoryColumn(headers, types, formats, prefix + " Custom", "double", "%.5f");
        }
    }

    /**
     * Writes the metric cells of {@code st} into {@code table} starting at column {@code col},
     * in the order created by {@code addMetricColumns}.
     *
     * @return the column index following the last cell written
     */
    private static int fillMetricCells(TwoDimTable table, int row, int col, ScoreKeeper st, boolean hasCustomMetric) {
        table.set(row, col++, (Object) st._auuc_nbins);
        table.set(row, col++, (Object) st._AUUC);
        table.set(row, col++, (Object) st._auuc_normalized);
        table.set(row, col++, (Object) st._qini);
        if (hasCustomMetric) {
            table.set(row, col++, (Object) st._custom_metric);
        }
        return col;
    }

    /**
     * Reports whether iteration {@code i} gets a row in the history table. It does when it is
     * the first iteration, when its training AUUC is not NaN, or when validation keepers exist
     * and its validation AUUC is not NaN.
     */
    private static boolean isReportedIteration(ScoreKeeper[] train, ScoreKeeper[] valid, int i) {
        return i == 0 || !Double.isNaN(train[i]._AUUC) || (valid != null && !Double.isNaN(valid[i]._AUUC));
    }

    @Override
    protected UpliftScoreExtension makeScoreExtension() {
        return new UpliftScoreExtension();
    }

    /** Score extension that reports uplift as treatment minus control prediction. */
    private static class UpliftScoreExtension
    extends Score.ScoreExtension {
        @Override
        protected double getPrediction(double[] cdist) {
            return cdist[1] - cdist[2];
        }

        @Override
        protected int[] getResponseComplements(SharedTreeModel<?, ?, ?> m) {
            return new int[]{((SharedTreeModel.SharedTreeOutput) m._output).treatmentIdx()};
        }
    }

    /** Training driver: one treatment/control tree pair per iteration, with OOB scoring. */
    private class UpliftDRFDriver
    extends Driver {

        @Override
        protected boolean doOOBScoring() {
            return true;
        }

        /**
         * Resolves {@code _mtry_per_tree} and {@code _mtry} from the parameters.
         *
         * <p>{@code mtries} values:
         * <ul>
         *   <li>{@code -2}: all columns;</li>
         *   <li>{@code -1}: sqrt(ncols) for classifiers, ncols/3 otherwise, at least 1;</li>
         *   <li>anything else: used as given.</li>
         * </ul>
         * Afterwards the work column of each row's response class is set to 1.
         *
         * @throws IllegalArgumentException if either computed value falls outside [1, ncols]
         */
        @Override
        protected void initializeModelSpecifics() {
            UpliftDRFModel.UpliftDRFParameters p = UpliftDRF.this.upliftParms();
            int ncols = UpliftDRF.this._ncols;
            UpliftDRF.this._mtry_per_tree = Math.max(1, (int) (p._col_sample_rate_per_tree * (double) ncols));
            if (UpliftDRF.this._mtry_per_tree < 1 || UpliftDRF.this._mtry_per_tree > ncols) {
                throw new IllegalArgumentException("Computed mtry_per_tree should be in interval <1," + ncols + "> but it is " + UpliftDRF.this._mtry_per_tree);
            }
            if (p._mtries == -2) {
                UpliftDRF.this._mtry = ncols;
            } else if (p._mtries == -1) {
                UpliftDRF.this._mtry = UpliftDRF.this.isClassifier() ? Math.max((int) Math.sqrt(ncols), 1) : Math.max(ncols / 3, 1);
            } else {
                UpliftDRF.this._mtry = p._mtries;
            }
            if (UpliftDRF.this._mtry < 1 || UpliftDRF.this._mtry > ncols) {
                throw new IllegalArgumentException("Computed mtry should be in interval <1," + ncols + "> but it is " + UpliftDRF.this._mtry);
            }
            new MRTask() {
                @Override
                public void map(Chunk[] chks) {
                    Chunk cy = UpliftDRF.this.chk_resp(chks);
                    for (int i = 0; i < cy._len; ++i) {
                        if (cy.isNA(i)) {
                            continue;
                        }
                        int cls = (int) cy.at8(i);
                        UpliftDRF.this.chk_work(chks, cls).set(i, 1L);
                    }
                }
            }.doAll(UpliftDRF.this._train);
        }

        /**
         * Grows one treatment/control tree pair and accumulates its OOB predictions.
         * It then appends the trees to the model output.
         *
         * @return always {@code false}
         */
        @Override
        protected boolean buildNextKTrees() {
            DTree[] ktrees = new DTree[UpliftDRF.this._nclass];
            int[] leafs = new int[UpliftDRF.this._nclass];
            this.growTrees(ktrees, leafs, UpliftDRF.this._rand);
            // doAll blocks until the prediction columns are updated; the task object is not needed.
            new UpliftCollectPreds(ktrees, leafs).doAll(UpliftDRF.this._train, UpliftDRF.this.upliftParms()._build_tree_one_node);
            UpliftDRF.this.upliftOutput().addKTrees(ktrees);
            return false;
        }

        /**
         * Grows tree 0 layer by layer up to {@code _max_depth}, then copies it into tree 1.
         * Unsplit children become leaves with treatment (tree 0) or control (tree 1) predictions.
         */
        private void growTrees(DTree[] ktrees, int[] leafs, Random rand) {
            UpliftDRFModel.UpliftDRFParameters p = UpliftDRF.this.upliftParms();
            UpliftDRFModel.UpliftDRFOutput out = UpliftDRF.this.upliftOutput();
            DHistogram[][][] hcs = new DHistogram[UpliftDRF.this._nclass][1][UpliftDRF.this._ncols];
            int adj_nbins = Math.max(p._nbins_top_level, p._nbins);
            long rseed = rand.nextLong();
            for (int k = 0; k < UpliftDRF.this._nclass; ++k) {
                if (out._distribution[k] == 0.0) {
                    continue;
                }
                ktrees[k] = new DTree(UpliftDRF.this._train, UpliftDRF.this._ncols, UpliftDRF.this._mtry, UpliftDRF.this._mtry_per_tree, rseed, (SharedTreeModel.SharedTreeParameters) p);
                new DTree.UndecidedNode(ktrees[k], -1, DHistogram.initialHist(UpliftDRF.this._train, UpliftDRF.this._ncols, adj_nbins, hcs[k][0], rseed, (SharedTreeModel.SharedTreeParameters) p, this.getGlobalQuantilesKeys(), null, false, null), null, null);
            }
            // Run row sampling over the tree-0 node-id and response columns.
            // Only the task's effect on those columns matters; its result object is discarded.
            ((Sample) new Sample(ktrees[0], p._sample_rate, p._sample_rate_per_class).dfork(null, new Frame(new Vec[]{UpliftDRF.this.vec_nids(UpliftDRF.this._train, 0), UpliftDRF.this.vec_resp(UpliftDRF.this._train)}), p._build_tree_one_node)).getResult();
            for (int depth = 0; depth < p._max_depth; ++depth) {
                hcs = UpliftDRF.this.buildLayer(UpliftDRF.this._train, p._nbins, ktrees[0], leafs, hcs, p._build_tree_one_node);
                if (hcs == null) {
                    break; // no node was split in this layer
                }
            }
            this.finishTreatmentAndControlTrees(ktrees, leafs);
        }

        /**
         * Copies tree 0 into {@code ktrees[1]} and records tree 0's node count in
         * {@code leafs[0]}. Every terminal child then gets a treatment leaf in tree 0 and a
         * control leaf in tree 1.
         *
         * <p>An unsplit root becomes a leaf carrying the prior class distribution: index 1 for
         * treatment, index 0 for control.
         */
        private void finishTreatmentAndControlTrees(DTree[] ktrees, int[] leafs) {
            DTree treeTr = ktrees[0];
            ktrees[1] = new DTree(treeTr);
            DTree treeCt = ktrees[1];
            // Bound captured before the loop: leaves created below are appended past it.
            int leaf = leafs[0] = treeTr.len();
            for (int nid = 0; nid < leaf; ++nid) {
                if (!(treeTr.node(nid) instanceof DTree.DecidedNode)) {
                    continue;
                }
                DTree.DecidedNode dnTr = treeTr.decided(nid);
                DTree.DecidedNode dnCt = treeCt.decided(nid);
                if (dnTr._split == null) {
                    if (nid == 0) {
                        UpliftDRFModel.UpliftDRFOutput out = UpliftDRF.this.upliftOutput();
                        DTree.LeafNode lnTr = new DTree.LeafNode(treeTr, -1, 0);
                        lnTr._pred = (float) out._priorClassDist[1];
                        DTree.LeafNode lnCt = new DTree.LeafNode(treeCt, -1, 0);
                        lnCt._pred = (float) out._priorClassDist[0];
                    }
                    continue;
                }
                for (int i = 0; i < dnTr._nids.length; ++i) {
                    if (!this.isTerminalChild(treeTr, dnTr._nids[i])) {
                        continue;
                    }
                    DTree.LeafNode lnTr = new DTree.LeafNode(treeTr, nid);
                    lnTr._pred = (float) dnTr.predTreatment(i);
                    dnTr._nids[i] = lnTr.nid();
                    DTree.LeafNode lnCt = new DTree.LeafNode(treeCt, nid);
                    lnCt._pred = (float) dnCt.predControl(i);
                    dnCt._nids[i] = lnCt.nid();
                }
            }
        }

        /**
         * Reports whether child id {@code cnid} must be replaced by a leaf. That is the case when
         * it is absent (-1), still undecided, or decided without a split.
         */
        private boolean isTerminalChild(DTree tree, int cnid) {
            return cnid == -1
                    || tree.node(cnid) instanceof DTree.UndecidedNode
                    || (tree.node(cnid) instanceof DTree.DecidedNode && ((DTree.DecidedNode) tree.node(cnid))._split == null);
        }

        @Override
        protected UpliftDRFModel makeModel(Key modelKey, UpliftDRFModel.UpliftDRFParameters parms) {
            return new UpliftDRFModel((Key<UpliftDRFModel>) modelKey, parms, new UpliftDRFModel.UpliftDRFOutput(UpliftDRF.this));
        }

        /**
         * Per-row pass after a tree pair is grown. For each row it:
         * <ul>
         *   <li>adds the treatment and control leaf predictions to tree columns 0 and 1
         *       (out-of-bag rows with non-zero weight only);</li>
         *   <li>resets both node-id columns;</li>
         *   <li>adds the row weight to the OOB counter (OOB rows);</li>
         *   <li>sums the weight of scored OOB rows with a non-NA response into {@code allRows}.</li>
         * </ul>
         */
        private class UpliftCollectPreds
        extends MRTask<UpliftCollectPreds> {
            final DTree[] _trees;
            double allRows;

            UpliftCollectPreds(DTree[] trees, int[] leafs) {
                this._trees = trees;
            }

            @Override
            public void map(Chunk[] chks) {
                final Chunk y = UpliftDRF.this.chk_resp(chks);
                final Chunk oobt = UpliftDRF.this.chk_oobt(chks);
                // Typed as Chunk: chk_weight(...) yields a Chunk, which is not assignable to C0DChunk.
                final Chunk weights = UpliftDRF.this.hasWeightCol() ? UpliftDRF.this.chk_weight(chks) : new C0DChunk(1.0, chks[0]._len);
                final Chunk nids = UpliftDRF.this.chk_nids(chks, 0);
                final Chunk nids1 = UpliftDRF.this.chk_nids(chks, 1);
                final Chunk treatmentPreds = UpliftDRF.this.chk_tree(chks, 0);
                final Chunk controlPreds = UpliftDRF.this.chk_tree(chks, 1);
                final DTree treeT = this._trees[0];
                final DTree treeC = this._trees[1];
                for (int row = 0; row < oobt._len; ++row) {
                    double weight = weights.atd(row);
                    boolean wasOOBRow = ScoreBuildHistogram.isOOBRow((int) nids.at8(row));
                    if (weight != 0.0) {
                        // Preserved behaviour: a missing treatment tree skips the whole row,
                        // including the resets below.
                        if (treeT == null) {
                            continue;
                        }
                        if (wasOOBRow) {
                            int leafnid = this.findLeaf(treeT, chks, row, ScoreBuildHistogram.oob2Nid((int) nids.at8(row)));
                            treatmentPreds.set(row, (float) (treatmentPreds.atd(row) + ((DTree.LeafNode) treeT.node(leafnid)).pred()));
                            controlPreds.set(row, (float) (controlPreds.atd(row) + ((DTree.LeafNode) treeC.node(leafnid)).pred()));
                        }
                    }
                    nids.set(row, 0L);
                    nids1.set(row, 0L);
                    if (wasOOBRow) {
                        oobt.set(row, oobt.atd(row) + weight);
                    }
                    if (weight != 0.0 && wasOOBRow && !y.isNA(row)) {
                        this.allRows += weight;
                    }
                }
            }

            /**
             * Returns the id of the leaf that row {@code row} lands in, starting from node
             * {@code nid}.
             *
             * <p>An undecided start node is replaced by its parent. A root leaf yields 0.
             * A decided node without a split is replaced by its parent. The chosen decided
             * node then routes the row to a child id.
             */
            private int findLeaf(DTree tree, Chunk[] chks, int row, int nid) {
                if (tree.node(nid) instanceof DTree.UndecidedNode) {
                    nid = tree.node(nid).pid();
                }
                if (tree.root() instanceof DTree.LeafNode) {
                    return 0;
                }
                DTree.DecidedNode dn = tree.decided(nid);
                if (dn._split == null) {
                    dn = tree.decided(tree.node(nid).pid());
                }
                return dn.getChildNodeID(chks, row);
            }

            @Override
            public void reduce(UpliftCollectPreds mrt) {
                this.allRows += mrt.allRows;
            }
        }
    }

    public static enum UpliftMetricType {
        AUTO,
        KL,
        ChiSquared,
        Euclidean;

    }
}

