/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.SupervisedModel;
import hex.SupervisedModelBuilder;
import hex.tree.CompressedTree;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.TreeJCodeGen;
import hex.tree.TreeStats;
import java.util.Arrays;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Key;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.SB;
import water.util.TwoDimTable;

public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput>
extends SupervisedModel<M, P, O> {
    /**
     * Threshold used by {@link #toJavaCheckTooBig()}: a model whose
     * (number of trees * mean leaves per tree) exceeds this value is reported as too big.
     */
    private static final float JAVA_MAX_TOTAL_LEAVES = 5000.0f;

    public SharedTreeModel(Key selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Returns a metric builder matching this model's category.
     * Any category other than Binomial, Multinomial or Regression throws the result of
     * {@code H2O.unimpl()}.
     *
     * @param domain response domain handed to the classification metric builders
     *               (not used for regression)
     * @return a binomial, multinomial or regression metric builder
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (_output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Scores one row by zeroing {@code preds} and then accumulating the contribution of every
     * tree iteration via {@link #score0(double[], double[], int)}.
     *
     * @param data  the row's feature values
     * @param preds output buffer, overwritten in place
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds) {
        Arrays.fill(preds, 0.0);
        for (int t = 0; t < _output._treeKeys.length; ++t) {
            score0(data, preds, t);
        }
        return preds;
    }

    /**
     * Adds the predictions of all per-class trees of iteration {@code treeIdx} into {@code preds}.
     * When the iteration has a single key its contribution goes to {@code preds[0]}; otherwise the
     * tree for class {@code c} contributes to {@code preds[c + 1]}.
     *
     * @param data    the row's feature values
     * @param preds   accumulation buffer (not cleared here)
     * @param treeIdx index of the tree iteration in {@code _output._treeKeys}
     */
    public void score0(double[] data, double[] preds, int treeIdx) {
        final Key<CompressedTree>[] keys = _output._treeKeys[treeIdx];
        for (int c = 0; c < keys.length; ++c) {
            // A null key means no tree was built for this class (addKTrees skips null trees).
            if (keys[c] == null) continue;
            final double pred = ((CompressedTree) DKV.get(keys[c]).get()).score(data);
            assert !Double.isInfinite(pred);
            final int slot = keys.length == 1 ? 0 : c + 1;
            preds[slot] += pred;
        }
    }

    /**
     * Removes every stored compressed tree from the DKV (registering the work on {@code fs}),
     * then delegates to the superclass removal.
     */
    protected Futures remove_impl(Futures fs) {
        for (Key<CompressedTree>[] perClass : _output._treeKeys) {
            if (perClass == null) continue;
            for (Key<CompressedTree> k : perClass) {
                if (k != null) k.remove(fs);
            }
        }
        return super.remove_impl(fs);
    }

    /**
     * Reports whether the model is too big for Java code generation: true when there is no
     * output yet, or when (number of trees * mean leaves) exceeds {@link #JAVA_MAX_TOTAL_LEAVES}.
     */
    protected boolean toJavaCheckTooBig() {
        if (_output == null) return true;
        final TreeStats stats = _output._treeStats;
        return (float) stats._num_trees * stats._mean_leaves > JAVA_MAX_TOTAL_LEAVES;
    }

    /**
     * Hook for subclasses. When true and the model has exactly two classes, the tree for class 1
     * is not emitted in generated Java. The default is false.
     */
    protected boolean binomialOpt() {
        return false;
    }

    /** Emits the generated class's supervision flag, feature/class counts and model category. */
    protected SB toJavaInit(SB sb, SB fileContext) {
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();
        sb.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
        sb.ip("public int nclasses() { return " + _output.nclasses() + "; }").nl();
        sb.ip("public ModelCategory getModelCategory() { return ModelCategory." + _output.getModelCategory() + "; }").nl();
        return sb;
    }

    /**
     * Emits the generated predict body.
     *
     * The body zeroes {@code preds}, cleans {@code data} through
     * {@code hex.genmodel.GenModel.SharedTree_clean}, and calls one generated
     * {@code <mname>_Forest_<t>.score0} per tree iteration. Each forest class sums its per-class
     * tree classes {@code <mname>_Tree_<t>_class_<c>} into {@code preds}. Finally the body
     * delegates to {@link #toJavaUnifyPreds}.
     *
     * Trees whose key is null (no tree was built for that class) are skipped both at the call
     * site and when emitting tree classes. This matches {@link #score0(double[], double[], int)}
     * and avoids dereferencing the null key in {@code ctree}.
     */
    protected void toJavaPredictBody(SB body, SB classCtx, SB file) {
        final int nclass = _output.nclasses();
        body.ip("java.util.Arrays.fill(preds,0);").nl();
        body.ip("double[] fdata = hex.genmodel.GenModel.SharedTree_clean(data);").nl();
        final String mname = JCodeGen.toJavaId(_key.toString());
        for (int t = 0; t < _output._treeKeys.length; ++t) {
            // Call site in the predict body.
            toJavaForestName(body.i(), mname, t).p(".score0(fdata,preds);").nl();
            // Forest class for iteration t: sums the per-class trees into preds.
            file.nl();
            toJavaForestName(file.ip("class "), mname, t).p(" {").nl().ii(1);
            file.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
            for (int c = 0; c < nclass; ++c) {
                if (!emitsJavaTree(t, c, nclass)) continue;
                toJavaTreeName(file.ip("preds[").p(nclass == 1 ? 0 : c + 1).p("] += "), mname, t, c).p(".score0(fdata);").nl();
            }
            file.di(1).ip("}").nl(); // end score0
            file.di(1).ip("}").nl(); // end forest class
            // One class per emitted tree, referenced by the forest above.
            for (int c = 0; c < nclass; ++c) {
                if (!emitsJavaTree(t, c, nclass)) continue;
                toJavaTreeName(file.ip("class "), mname, t, c).p(" {").nl().ii(1);
                final CompressedTree ct = _output.ctree(t, c);
                new TreeJCodeGen(this, ct, file).generate();
                file.di(1).ip("}").nl();
            }
        }
        toJavaUnifyPreds(body, file);
    }

    /**
     * Decides whether the tree for iteration {@code t}, class {@code c} appears in generated Java.
     * It is omitted under the {@link #binomialOpt()} two-class optimisation, or when no tree
     * exists for it (missing or null key).
     */
    private boolean emitsJavaTree(int t, int c, int nclass) {
        if (binomialOpt() && c == 1 && nclass == 2) return false;
        final Key<CompressedTree>[] keys = _output._treeKeys[t];
        return c < keys.length && keys[c] != null;
    }

    protected abstract void toJavaUnifyPreds(SB var1, SB var2);

    /** Appends the generated class name {@code <mname>_Tree_<t>_class_<c>} to {@code sb}. */
    protected SB toJavaTreeName(SB sb, String mname, int t, int c) {
        return sb.p(mname).p("_Tree_").p(t).p("_class_").p(c);
    }

    /** Appends the generated class name {@code <mname>_Forest_<t>} to {@code sb}. */
    protected SB toJavaForestName(SB sb, String mname, int t) {
        return sb.p(mname).p("_Forest_").p(t);
    }

    /** Output state shared by tree models: the per-iteration tree keys and training history. */
    public static abstract class SharedTreeOutput
    extends SupervisedModel.SupervisedOutput {
        public double _init_f;
        /** Number of tree iterations added so far via {@link #addKTrees}. */
        public int _ntrees = 0;
        /** Statistics updated from every non-null tree passed to {@link #addKTrees}. */
        public final TreeStats _treeStats;
        /**
         * {@code _treeKeys[t][c]} is the DKV key of the compressed tree for iteration {@code t}
         * and class {@code c}, or null when no tree was built for that class.
         */
        public Key<CompressedTree>[][] _treeKeys;
        /**
         * Training MSE history. Index 0 holds the constructor value; {@link #addKTrees} extends the
         * array to {@code _ntrees + 1} entries using NaN as the fill value.
         */
        public double[] _mse_train;
        /**
         * Validation MSE history, maintained like {@link #_mse_train}. {@link #addKTrees} sets it
         * to null when {@code _validation_metrics} is null.
         */
        public double[] _mse_valid;
        /**
         * Timestamps from {@code System.currentTimeMillis()}: one taken at construction, then
         * extended by {@link #addKTrees} using the current time as the fill value.
         */
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public TwoDimTable _variable_importances;

        @SuppressWarnings("unchecked")
        public SharedTreeOutput(SharedTree b, double mse_train, double mse_valid) {
            super((SupervisedModelBuilder) b);
            // _ntrees is 0 here, so this starts as an empty array of iterations.
            _treeKeys = new Key[_ntrees][];
            _treeStats = new TreeStats();
            _mse_train = new double[]{mse_train};
            _mse_valid = new double[]{mse_valid};
        }

        /**
         * Appends one tree iteration holding one tree per class.
         *
         * Each non-null tree is compressed, stored in the DKV and recorded in
         * {@link #_treeStats}. Null trees leave a null key in the new row. The call then increments
         * {@link #_ntrees}, extends the MSE and timing histories, and blocks until the DKV puts
         * complete.
         *
         * @param trees one entry per class; expected length is {@code nclasses()}
         */
        @SuppressWarnings("unchecked")
        public void addKTrees(DTree[] trees) {
            assert nclasses() == trees.length;
            _treeKeys = Arrays.copyOf(_treeKeys, _ntrees + 1);
            _treeKeys[_ntrees] = new Key[trees.length];
            final Key[] keys = _treeKeys[_ntrees];
            final Futures fs = new Futures();
            // Bound by the array actually allocated so a length mismatch cannot index past it
            // when asserts are disabled.
            for (int i = 0; i < keys.length; ++i) {
                if (trees[i] == null) continue;
                final CompressedTree ct = trees[i].compress(_ntrees, i);
                keys[i] = ct._key;
                DKV.put((Key) keys[i], (Iced) ct, fs);
                _treeStats.updateBy(trees[i]);
            }
            ++_ntrees;
            _mse_train = ArrayUtils.copyAndFillOf(_mse_train, _ntrees + 1, Double.NaN);
            _mse_valid = _validation_metrics != null
                    ? ArrayUtils.copyAndFillOf(_mse_valid, _ntrees + 1, Double.NaN)
                    : null;
            _training_time_ms = ArrayUtils.copyAndFillOf(_training_time_ms, _ntrees + 1, System.currentTimeMillis());
            fs.blockForPending();
        }

        /**
         * Fetches the compressed tree for iteration {@code tnum}, class {@code knum}.
         * Throws NullPointerException if that entry's key is null (no tree for that class).
         */
        public CompressedTree ctree(int tnum, int knum) {
            return (CompressedTree) _treeKeys[tnum][knum].get();
        }

        /** Renders the tree at ({@code tnum}, {@code knum}) via {@link CompressedTree#toString}. */
        public String toStringTree(int tnum, int knum) {
            return ctree(tnum, knum).toString(this);
        }
    }

    /** Parameters shared by tree model builders; the field initializers give the defaults. */
    public static abstract class SharedTreeParameters
    extends SupervisedModel.SupervisedParameters {
        // Not referenced within this class; presumably used by builders elsewhere — confirm.
        static final int MAX_SUPPORTED_LEVELS = 1000;
        public int _ntrees = 50;
        public int _max_depth = 5;
        public int _min_rows = 10;
        public int _nbins = 20;
        public long _seed;
        public boolean _checkpoint;
    }
}

