/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Distribution;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ScoreKeeper;
import hex.VarImp;
import hex.tree.CompressedTree;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.TreeJCodeGen;
import hex.tree.TreeStats;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Key;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.JCodeSB;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.RandomUtils;
import water.util.SB;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput>
extends Model<M, P, O> {
    /**
     * Computes a deviance value by delegating to a freshly built {@link Distribution}.
     *
     * <p>The distribution is constructed on every call from this model's
     * {@code _distribution} and {@code _tweedie_power} parameters.
     *
     * @param w first argument forwarded to {@code Distribution.deviance} (presumably the observation weight -- confirm)
     * @param y second argument forwarded to {@code Distribution.deviance} (presumably the actual response -- confirm)
     * @param f third argument forwarded to {@code Distribution.deviance} (presumably the predicted value -- confirm)
     * @return the value returned by {@code Distribution.deviance(w, y, f)}
     */
    public double deviance(double w, double y, double f) {
        final SharedTreeParameters parms = (SharedTreeParameters)this._parms;
        final Distribution dist = new Distribution(parms._distribution, parms._tweedie_power);
        return dist.deviance(w, y, f);
    }

    /**
     * Returns the variable importances stored on this model's output.
     *
     * @return the {@code _varimp} field of the output; may be {@code null} if it was never assigned
     */
    public final VarImp varImp() {
        final SharedTreeOutput output = (SharedTreeOutput)this._output;
        return output._varimp;
    }

    /**
     * Creates the metric builder matching the model category reported by the output.
     *
     * <p>Binomial, multinomial and regression categories are supported; any other
     * category results in the exception produced by {@code H2O.unimpl()}.
     *
     * @param domain response domain handed to the binomial and multinomial builders
     *               (unused for regression)
     * @return a new metric builder for this model's category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final SharedTreeOutput output = (SharedTreeOutput)this._output;
        switch (output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                // No metric builder exists for the remaining categories.
                throw H2O.unimpl();
        }
    }

    /**
     * Creates a shared-tree model; all arguments are forwarded unchanged to the
     * superclass constructor.
     *
     * @param selfKey key passed to the superclass as the model's own key
     * @param parms   model parameters
     * @param output  model output
     */
    public SharedTreeModel(Key selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a single row using a weight of {@code 1.0} and an offset of {@code 0.0}.
     *
     * @param data  input row
     * @param preds array that receives the predictions
     * @return the result of the four-argument {@code score0}
     */
    protected double[] score0(double[] data, double[] preds) {
        final double unitWeight = 1.0;
        final double zeroOffset = 0.0;
        return this.score0(data, preds, unitWeight, zeroOffset);
    }

    /**
     * Scores a single row by accumulating the output of every tree into {@code preds}.
     *
     * <p>{@code preds} is zeroed first, then each tree group in {@code _treeKeys}
     * adds its contribution. This implementation does not read {@code weight} or
     * {@code offset}.
     *
     * @param data   input row
     * @param preds  array that is cleared and then filled with the summed tree outputs
     * @param weight not used by this implementation
     * @param offset not used by this implementation
     * @return the same {@code preds} array that was passed in
     */
    protected double[] score0(double[] data, double[] preds, double weight, double offset) {
        Arrays.fill(preds, 0.0);
        int tree = 0;
        while (tree < ((SharedTreeOutput)this._output)._treeKeys.length) {
            this.score0(data, preds, tree);
            tree++;
        }
        return preds;
    }

    /**
     * Adds the predictions of one tree group (one tree per class slot) to {@code preds}.
     *
     * <p>Slots whose key is {@code null} are skipped. When the group holds a single
     * tree its score goes to {@code preds[0]}; otherwise the tree in slot {@code c}
     * contributes to {@code preds[c + 1]}.
     *
     * @param data    input row
     * @param preds   accumulator updated in place
     * @param treeIdx index into {@code _treeKeys} selecting the tree group
     */
    private void score0(double[] data, double[] preds, int treeIdx) {
        final Key<CompressedTree>[] classKeys = ((SharedTreeOutput)this._output)._treeKeys[treeIdx];
        for (int cls = 0; cls < classKeys.length; cls++) {
            final Key<CompressedTree> key = classKeys[cls];
            if (key != null) {
                final CompressedTree tree = (CompressedTree)DKV.get(key).get();
                final double pred = tree.score(data);
                assert (!Double.isInfinite(pred));
                final int slot = classKeys.length == 1 ? 0 : cls + 1;
                preds[slot] += pred;
            }
        }
    }

    /**
     * Removes every non-null tree key referenced by the output, then delegates to
     * the superclass implementation.
     *
     * @param fs futures object passed to each key removal and to the superclass
     * @return the value returned by {@code super.remove_impl(fs)}
     */
    protected Futures remove_impl(Futures fs) {
        final Key<CompressedTree>[][] allKeys = ((SharedTreeOutput)this._output)._treeKeys;
        for (Key<CompressedTree>[] perClassKeys : allKeys) {
            for (Key<CompressedTree> key : perClassKeys) {
                if (key != null) {
                    key.remove(fs);
                }
            }
        }
        return super.remove_impl(fs);
    }

    /**
     * Reports whether the model is considered too big for Java code generation.
     *
     * @return {@code true} when there is no output, or when the number of trees
     *         multiplied by the mean number of leaves exceeds one million
     */
    protected boolean toJavaCheckTooBig() {
        if (this._output == null) {
            return true;
        }
        final TreeStats stats = ((SharedTreeOutput)this._output)._treeStats;
        return (float)stats._num_trees * stats._mean_leaves > 1000000.0f;
    }

    /**
     * Controls a code-generation shortcut for two-class models.
     *
     * <p>When this returns {@code true} and the model has exactly two classes,
     * {@code toJavaPredictBody} emits no scoring call and no tree class for
     * class index 1. This base implementation always enables the shortcut.
     *
     * @return {@code true}
     */
    protected boolean binomialOpt() {
        return true;
    }

    /**
     * Emits the generated-model accessor methods {@code isSupervised()},
     * {@code nfeatures()} and {@code nclasses()}.
     *
     * @param sb      stream receiving the generated source
     * @param fileCtx not used by this implementation
     * @return the same {@code sb} that was passed in
     */
    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        final SharedTreeOutput output = (SharedTreeOutput)this._output;
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();
        final String nfeaturesDecl = "public int nfeatures() { return " + output.nfeatures() + "; }";
        sb.ip(nfeaturesDecl).nl();
        final String nclassesDecl = "public int nclasses() { return " + output.nclasses() + "; }";
        sb.ip(nclassesDecl).nl();
        return sb;
    }

    /**
     * Emits the body of the generated predict method and registers one deferred
     * generator per tree group that writes the supporting classes.
     *
     * <p>The predict body clears {@code preds}, cleans the input row into
     * {@code fdata}, and calls {@code <model>_Forest_<t>.score0(fdata,preds)} for
     * every tree group. For each group a generator is added to {@code fileCtx};
     * when run it writes the forest class followed by one tree class per class
     * index. Class index 1 is omitted when {@link #binomialOpt()} is true and the
     * model has exactly two classes. Finally {@link #toJavaUnifyPreds} is invoked.
     *
     * @param body        stream receiving the predict-method body
     * @param classCtx    not used by this implementation
     * @param fileCtx     pipeline that collects the deferred forest/tree class generators
     * @param verboseCode forwarded to {@code TreeJCodeGen}
     */
    protected void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, final boolean verboseCode) {
        final int nclass = ((SharedTreeOutput)this._output).nclasses();
        body.ip("java.util.Arrays.fill(preds,0);").nl();
        body.ip("double[] fdata = hex.genmodel.GenModel.SharedTree_clean(data);").nl();
        final String mname = JCodeGen.toJavaId((String)this._key.toString());
        for (int t = 0; t < ((SharedTreeOutput)this._output)._treeKeys.length; ++t) {
            // The anonymous generator needs a final copy of the loop index.
            final int treeIdx = t;
            this.toJavaForestName(body.i(), mname, treeIdx).p(".score0(fdata,preds);").nl();
            fileCtx.add((Object)new CodeGenerator(){

                public void generate(JCodeSB out) {
                    out.nl();
                    // Forest class: sums the per-class tree scores into preds.
                    SharedTreeModel.this.toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
                    out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
                    for (int cls = 0; cls < nclass; ++cls) {
                        final boolean omitted = SharedTreeModel.this.binomialOpt() && cls == 1 && nclass == 2;
                        if (!omitted) {
                            final int predIdx = nclass == 1 ? 0 : cls + 1;
                            SharedTreeModel.this.toJavaTreeName(out.ip("preds[").p(predIdx).p("] += "), mname, treeIdx, cls).p(".score0(fdata);").nl();
                        }
                    }
                    out.di(1).ip("}").nl();
                    out.di(1).ip("}").nl();
                    // One generated class per emitted tree.
                    for (int cls = 0; cls < nclass; ++cls) {
                        final boolean omitted = SharedTreeModel.this.binomialOpt() && cls == 1 && nclass == 2;
                        if (!omitted) {
                            String javaClassName = SharedTreeModel.this.toJavaTreeName(new SB(), mname, treeIdx, cls).toString();
                            CompressedTree ct = ((SharedTreeOutput)SharedTreeModel.this._output).ctree(treeIdx, cls);
                            SB treeCode = new SB();
                            new TreeJCodeGen(SharedTreeModel.this, ct, treeCode, javaClassName, verboseCode).generate();
                            out.p((JCodeSB)treeCode);
                        }
                    }
                }
            });
        }
        this.toJavaUnifyPreds(body);
    }

    /**
     * Hook implemented by subclasses; called by {@code toJavaPredictBody} as its
     * last step, after the calls to all forest classes have been emitted.
     *
     * @param var1 stream receiving the predict-method body
     */
    protected abstract void toJavaUnifyPreds(SBPrintStream var1);

    /**
     * Appends the generated class name of a single tree,
     * {@code <mname>_Tree_<t>_class_<c>}, to {@code sb}.
     *
     * @param sb    builder to append to
     * @param mname model name used as the prefix
     * @param t     tree index
     * @param c     class index
     * @return the result of the last append, cast to {@code T}
     */
    protected <T extends JCodeSB> T toJavaTreeName(T sb, String mname, int t, int c) {
        return (T)sb
            .p(mname)
            .p("_Tree_")
            .p(t)
            .p("_class_")
            .p(c);
    }

    /**
     * Appends the generated class name of a tree group,
     * {@code <mname>_Forest_<t>}, to {@code sb}.
     *
     * @param sb    builder to append to
     * @param mname model name used as the prefix
     * @param t     tree index
     * @return the result of the last append, cast to {@code T}
     */
    protected <T extends JCodeSB> T toJavaForestName(T sb, String mname, int t) {
        return (T)sb
            .p(mname)
            .p("_Forest_")
            .p(t);
    }

    /**
     * Lists the keys published by this model: every existing tree key, followed
     * by the keys reported by the superclass.
     *
     * <p>Slots in {@code _treeKeys} may be {@code null} ({@code addKTrees} leaves a
     * slot empty when no tree was built for that class). Such slots are skipped so
     * the returned list never contains {@code null}, consistent with how
     * {@code score0} and {@code remove_impl} treat them.
     *
     * @return a new list of non-null tree keys followed by the superclass keys
     */
    public List<Key> getPublishedKeys() {
        final SharedTreeOutput output = (SharedTreeOutput)this._output;
        assert (output._ntrees == output._treeKeys.length) : "Tree model is inconsistent: number of trees do not match number of tree keys!";
        List superP = super.getPublishedKeys();
        ArrayList<Key> p = new ArrayList<Key>(output._ntrees * output.nclasses());
        for (Key<CompressedTree>[] perClassKeys : output._treeKeys) {
            for (Key<CompressedTree> key : perClassKeys) {
                // Empty slots carry no published object; do not hand null keys to callers.
                if (key != null) {
                    p.add(key);
                }
            }
        }
        p.addAll(superP);
        return p;
    }

    /**
     * Output shared by tree-based models: the keys of the compressed trees plus
     * per-tree bookkeeping (tree statistics, score keepers and timestamps).
     */
    public static abstract class SharedTreeOutput
    extends Model.Output {
        public double _init_f;
        /** Number of tree groups added so far; incremented by {@link #addKTrees}. */
        public int _ntrees = 0;
        /** Running statistics, updated with every tree passed to {@link #addKTrees}. */
        public final TreeStats _treeStats;
        /** Tree keys indexed as {@code [tree group][class]}; a slot may be {@code null}. */
        public Key<CompressedTree>[][] _treeKeys;
        /** Training score keepers; holds {@code _ntrees + 1} entries after each {@link #addKTrees}. */
        public ScoreKeeper[] _scored_train;
        /** Validation score keepers; grown like {@link #_scored_train} unless {@code null}. */
        public ScoreKeeper[] _scored_valid;
        /** Wall-clock timestamps in milliseconds; one at construction plus one per added tree group. */
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public TwoDimTable _variable_importances;
        public VarImp _varimp;

        /**
         * Creates an output with no trees and one initial score keeper for each of
         * the training and validation sets.
         *
         * @param b         builder forwarded to the superclass constructor
         * @param mse_train value used to create the first training score keeper
         * @param mse_valid value used to create the first validation score keeper
         */
        public SharedTreeOutput(SharedTree b, double mse_train, double mse_valid) {
            super((ModelBuilder)b);
            this._treeStats = new TreeStats();
            this._treeKeys = new Key[this._ntrees][];
            final ScoreKeeper firstTrain = new ScoreKeeper(mse_train);
            this._scored_train = new ScoreKeeper[]{firstTrain};
            final ScoreKeeper firstValid = new ScoreKeeper(mse_valid);
            this._scored_valid = new ScoreKeeper[]{firstValid};
            this._modelClassDist = this._priorClassDist;
        }

        /**
         * Appends one tree group (one tree per class) to this output.
         *
         * <p>Each non-null tree is compressed, stored in the DKV under its own key
         * and folded into the tree statistics; {@code null} entries leave the
         * corresponding key slot {@code null}. Afterwards the tree count is
         * incremented and the score-keeper and timestamp arrays are extended to
         * {@code _ntrees + 1} entries. The method blocks until all pending DKV
         * writes have completed.
         *
         * @param trees trees of the new group, one per class; entries may be {@code null}
         */
        public void addKTrees(DTree[] trees) {
            assert (this.nclasses() == trees.length);
            final int treeIdx = this._ntrees;
            this._treeKeys = (Key[][])Arrays.copyOf(this._treeKeys, treeIdx + 1);
            this._treeKeys[treeIdx] = new Key[trees.length];
            final Key[] groupKeys = this._treeKeys[treeIdx];
            final Futures pending = new Futures();
            for (int cls = 0; cls < this.nclasses(); ++cls) {
                final DTree tree = trees[cls];
                if (tree != null) {
                    final CompressedTree compressed = tree.compress(treeIdx, cls);
                    groupKeys[cls] = compressed._key;
                    DKV.put((Key)groupKeys[cls], (Iced)compressed, (Futures)pending);
                    this._treeStats.updateBy(tree);
                }
            }
            this._ntrees++;
            final int scoredLength = this._ntrees + 1;
            this._scored_train = (ScoreKeeper[])ArrayUtils.copyAndFillOf((Object[])this._scored_train, scoredLength, (Object)new ScoreKeeper());
            if (this._scored_valid != null) {
                this._scored_valid = (ScoreKeeper[])ArrayUtils.copyAndFillOf((Object[])this._scored_valid, scoredLength, (Object)new ScoreKeeper());
            }
            this._training_time_ms = ArrayUtils.copyAndFillOf((long[])this._training_time_ms, scoredLength, System.currentTimeMillis());
            pending.blockForPending();
        }

        /**
         * Fetches the compressed tree stored under {@code _treeKeys[tnum][knum]}.
         *
         * @param tnum tree group index
         * @param knum class index within the group
         * @return the object obtained from that key, cast to {@code CompressedTree}
         */
        public CompressedTree ctree(int tnum, int knum) {
            final Key<CompressedTree> key = this._treeKeys[tnum][knum];
            return (CompressedTree)key.get();
        }

        /**
         * Renders one tree as text.
         *
         * @param tnum tree group index
         * @param knum class index within the group
         * @return the result of {@code CompressedTree.toString(this)} for that tree
         */
        public String toStringTree(int tnum, int knum) {
            final CompressedTree tree = this.ctree(tnum, knum);
            return tree.toString(this);
        }
    }

    /**
     * Parameters shared by tree-based model builders, together with the check
     * applied when training is restarted from a checkpoint.
     */
    public static abstract class SharedTreeParameters
    extends Model.Parameters {
        public int _ntrees = 50;
        public int _max_depth = 5;
        public double _min_rows = 10.0;
        public int _nbins = 20;
        public int _nbins_cats = 1024;
        public double _r2_stopping = 0.999999;
        /** Defaults to a pseudo-random value derived from {@code System.nanoTime()}. */
        public long _seed = RandomUtils.getRNG((long[])new long[]{System.nanoTime()}).nextLong();
        public int _nbins_top_level = 1024;
        public boolean _build_tree_one_node = false;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        /** Names of the fields that may differ from the checkpointed model's parameters. */
        private static final String[] MODIFIABLE_BY_CHECKPOINT_FIELDS = new String[]{"_ntrees", "_max_depth", "_min_rows", "_r2_stopping"};

        /**
         * Returns the names of the fields that {@link #validateWithCheckpoint} exempts
         * from its equality check. Subclasses may override this to change the list.
         *
         * @return the exempt field names
         */
        protected String[] getCheckpointModifiableFields() {
            return MODIFIABLE_BY_CHECKPOINT_FIELDS;
        }

        /**
         * Verifies that these parameters are compatible with those of a checkpoint.
         *
         * <p>Every field declared directly on the runtime class of {@code this}
         * (inherited fields are not inspected, because {@code getDeclaredFields()}
         * is used) must hold the same value as the same-named field declared on the
         * runtime class of {@code checkpointParameters}, unless its name is listed
         * by {@link #getCheckpointModifiableFields()}.
         *
         * <p>Each field is tested against the whole modifiable list exactly once.
         * Previously a modifiable field was only skipped for the list entry equal to
         * its own name and was still compared for every other entry, so changing it
         * was rejected; an empty list disabled the validation altogether.
         *
         * @param checkpointParameters parameters of the checkpointed model
         * @throws H2OIllegalArgumentException if a non-modifiable field differs, or if
         *         the checkpoint's parameter class does not declare the field
         */
        public void validateWithCheckpoint(SharedTreeParameters checkpointParameters) {
            final List<String> modifiable = Arrays.asList(this.getCheckpointModifiableFields());
            for (Field f : this.getClass().getDeclaredFields()) {
                if (modifiable.contains(f.getName())) {
                    // Explicitly allowed to differ from the checkpoint.
                    continue;
                }
                try {
                    final Field checkpointField = checkpointParameters.getClass().getDeclaredField(f.getName());
                    if (!PojoUtils.equals((Object)this, (Field)f, (Object)checkpointParameters, (Field)checkpointField)) {
                        throw new H2OIllegalArgumentException(f.getName(), "TreeBuilder", (Object)"Field cannot be modified if checkpoint is specified!");
                    }
                }
                catch (NoSuchFieldException e) {
                    throw new H2OIllegalArgumentException(f.getName(), "TreeBuilder", (Object)"Field is not supported by checkpoint!");
                }
            }
        }
    }
}

