/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.SupervisedModel;
import hex.SupervisedModelBuilder;
import hex.VarImp;
import hex.tree.CompressedTree;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.TreeStats;
import java.util.Arrays;
import water.DKV;
import water.Futures;
import water.Iced;
import water.Key;

/**
 * Common base for models made of many trees.
 *
 * <p>Trees are stored in compressed form in the DKV; this model only keeps their keys
 * (in {@link SharedTreeOutput#_treeKeys}). A row is scored by summing the score of every
 * stored tree into the prediction buffer.
 */
public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput>
        extends SupervisedModel<M, P, O> {
    /** Name of the prediction element type; matches the {@code float[]} buffers filled by {@code score0}. */
    static final String PRED_TYPE = "float";

    /**
     * @param selfKey key identifying this model
     * @param parms   parameters the model was built with
     * @param output  output holder (tree keys, error history, ...)
     */
    public SharedTreeModel(Key selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a single row with every tree group.
     *
     * <p>{@code preds} is cleared first and then each group's contribution is added in
     * order via {@link #score0(double[], float[], int)}.
     *
     * @param data  feature values of the row
     * @param preds buffer receiving the accumulated scores
     * @return the same {@code preds} array
     */
    protected float[] score0(double[] data, float[] preds) {
        Arrays.fill(preds, 0.0f);
        int group = 0;
        while (group < ((SharedTreeOutput) this._output)._treeKeys.length) {
            score0(data, preds, group);
            group++;
        }
        return preds;
    }

    /**
     * Adds the contribution of one tree group (one {@code addKTrees} call) to {@code preds}.
     *
     * <p>If the group holds exactly one key, its score goes to {@code preds[0]}; otherwise
     * the tree at position {@code c} is added to {@code preds[c + 1]} and {@code preds[0]}
     * is not touched here. Null keys (no tree built for that position) are skipped.
     *
     * @param data    feature values of the row
     * @param preds   buffer accumulated into
     * @param treeIdx index of the tree group in {@code _treeKeys}
     */
    public void score0(double[] data, float[] preds, int treeIdx) {
        Key<CompressedTree>[] group = ((SharedTreeOutput) this._output)._treeKeys[treeIdx];
        boolean singleSlot = group.length == 1;
        for (int pos = 0; pos < group.length; pos++) {
            Key<CompressedTree> treeKey = group[pos];
            if (treeKey != null) {
                int slot = singleSlot ? 0 : pos + 1;
                preds[slot] += ((CompressedTree) DKV.get(treeKey).get()).score(data);
            }
        }
    }

    /** Always {@code false} at this level; presumably overridden by SpeeDRF-derived models — TODO confirm. */
    boolean isFromSpeeDRF() {
        return false;
    }

    /**
     * Schedules removal of every stored tree key, then delegates to the superclass removal.
     *
     * @param fs futures collecting the pending removals
     * @return whatever the superclass implementation returns
     */
    protected Futures remove_impl(Futures fs) {
        for (Key<CompressedTree>[] group : ((SharedTreeOutput) this._output)._treeKeys) {
            for (Key<CompressedTree> treeKey : group) {
                if (treeKey != null) {
                    treeKey.remove(fs);
                }
            }
        }
        return super.remove_impl(fs);
    }

    /**
     * Output shared by tree models: tree keys, tree statistics and per-iteration error history.
     *
     * <p>Field declaration order is intentionally left unchanged.
     */
    public static abstract class SharedTreeOutput
            extends SupervisedModel.SupervisedOutput {
        // Not read or written in this class; presumably filled in by a builder — TODO confirm.
        public double _initialPrediction;
        // Number of tree groups added so far via addKTrees.
        public int _ntrees = 0;
        // Aggregated shape statistics, updated on every addKTrees call.
        final TreeStats _treeStats;
        // One row per addKTrees call; row[c] is the key of the tree for position c, or null.
        public Key<CompressedTree>[][] _treeKeys = new Key[this._ntrees][];
        // Training MSE history; always _ntrees + 1 long, index 0 set by the constructor.
        public double[] _mse_train;
        // Validation MSE history with the same layout, or null if constructed with NaN.
        public double[] _mse_valid;
        // Variable importance; not populated in this class.
        public VarImp _varimp;

        /**
         * @param b         the builder producing this output
         * @param mse_train initial training MSE (stored at index 0)
         * @param mse_valid initial validation MSE; {@code NaN} disables validation history
         */
        public SharedTreeOutput(SharedTree b, double mse_train, double mse_valid) {
            super((SupervisedModelBuilder) b);
            this._treeStats = new TreeStats();
            this._mse_train = new double[]{mse_train};
            this._mse_valid = Double.isNaN(mse_valid) ? null : new double[]{mse_valid};
        }

        /**
         * Compresses one tree per position, stores them in the DKV and records their keys
         * as a new group.
         *
         * <p>Null entries in {@code trees} leave a null key in the new group. Both MSE
         * histories are grown by one (new slot is 0.0) to keep length {@code _ntrees + 1}.
         * Returns only after all DKV puts have completed.
         *
         * @param trees one tree per position; length expected to equal {@code nclasses()}
         */
        public void addKTrees(DTree[] trees) {
            assert (this.nclasses() == trees.length);
            this._treeStats.updateBy(trees);

            this._treeKeys = (Key[][]) Arrays.copyOf(this._treeKeys, this._ntrees + 1);
            this._treeKeys[this._ntrees] = new Key[trees.length];
            Key[] group = this._treeKeys[this._ntrees];

            Futures pending = new Futures();
            for (int cls = 0; cls < this.nclasses(); cls++) {
                DTree tree = trees[cls];
                if (tree == null) {
                    continue;
                }
                CompressedTree compressed = tree.compress(this._ntrees, cls);
                group[cls] = compressed._key;
                DKV.put((Key) group[cls], (Iced) compressed, (Futures) pending);
            }

            this._ntrees += 1;
            int historyLen = this._ntrees + 1;
            this._mse_train = Arrays.copyOf(this._mse_train, historyLen);
            this._mse_valid = this._mse_valid == null ? null : Arrays.copyOf(this._mse_valid, historyLen);
            pending.blockForPending();
        }

        /**
         * Renders one stored tree as text.
         *
         * @param tnum index of the tree group
         * @param knum position within the group
         * @return the tree's textual form
         */
        public String toStringTree(int tnum, int knum) {
            CompressedTree tree = (CompressedTree) this._treeKeys[tnum][knum].get();
            return tree.toString(this);
        }
    }

    /** Tuning parameters shared by tree model builders. Field order is left unchanged. */
    public static abstract class SharedTreeParameters
            extends SupervisedModel.SupervisedParameters {
        // Not referenced in this class; presumably a cap on tree depth — TODO confirm.
        static final int MAX_SUPPORTED_LEVELS = 1000;
        // Default 50.
        public int _ntrees = 50;
        // Default 5.
        public int _max_depth = 5;
        // Default 10.
        public int _min_rows = 10;
        // Default 20.
        public int _nbins = 20;
        // Default false.
        public boolean _variable_importance = false;
        // No explicit default (0).
        public long _seed;
        // No explicit default (false).
        public boolean _checkpoint;
    }
}

