/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ScoreKeeper;
import hex.ToEigenVec;
import hex.VarImp;
import hex.genmodel.GenModel;
import hex.glm.GLMModel;
import hex.tree.CompressedTree;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.TreeJCodeGen;
import hex.tree.TreeStats;
import hex.util.LinearAlgebraUtils;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.IcedUtils;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.SB;
import water.util.SBPrintStream;
import water.util.TwoDimTable;
import water.util.VecUtils;

public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput>
extends Model<M, P, O>
implements Model.LeafNodeAssignment,
Model.GetMostImportantFeatures {
    /**
     * Returns the leading row headers of the variable-importance table.
     *
     * @param n maximum number of names to return; capped at the number of row headers
     * @return the first {@code min(n, rows)} row headers, or {@code null} when the model has no
     *         output or the output has no variable-importance table
     */
    public String[] getMostImportantFeatures(int n) {
        if (this._output == null) {
            return null;
        }
        TwoDimTable importances = ((SharedTreeOutput)this._output)._variable_importances;
        if (importances == null) {
            return null;
        }
        String[] rowHeaders = importances.getRowHeaders();
        int count = Math.min(n, rowHeaders.length);
        return Arrays.copyOf(rowHeaders, count);
    }

    /**
     * Returns the {@link ToEigenVec} instance exposed as {@code LinearAlgebraUtils.toEigen}.
     *
     * @return the shared {@code LinearAlgebraUtils.toEigen} object
     */
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    /**
     * Creates the metric builder matching this model's category.
     *
     * @param domain response domain handed to the binomial and multinomial builders
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException the result of {@code H2O.unimpl()} for any other model category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        switch (output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Creates a tree model by forwarding all arguments to the {@code Model} constructor.
     *
     * @param selfKey key of this model
     * @param parms   model parameters
     * @param output  model output
     */
    public SharedTreeModel(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Builds a frame with one categorical column per stored tree, where each cell holds the
     * decision path (as returned by {@code CompressedTree.getDecisionPath}) that the row takes
     * through that tree.
     *
     * @param frame           input rows; wrapped in a new {@code Frame} and adapted via
     *                        {@code adaptTestForTrain} before scoring
     * @param destination_key key under which the resulting frame is stored in the DKV
     * @return the frame of categorical decision-path columns, named {@code T<tree>} or
     *         {@code T<tree>.C<class>}
     */
    public Frame scoreLeafNodeAssignment(Frame frame, Key<Frame> destination_key) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        // Count the non-null per-class trees of the first tree group. The column count below
        // assumes every tree group has this same number of non-null keys.
        int classTrees = 0;
        for (int i = 0; i < ((SharedTreeOutput)this._output)._treeKeys[0].length; ++i) {
            if (((SharedTreeOutput)this._output)._treeKeys[0][i] == null) continue;
            ++classTrees;
        }
        final int outputcols = ((SharedTreeOutput)this._output)._treeKeys.length * classTrees;
        // Column names: "T<n>" when a group holds a single tree, otherwise "T<n>.C<class>".
        String[] names = new String[outputcols];
        int col = 0;
        for (int tidx = 0; tidx < ((SharedTreeOutput)this._output)._treeKeys.length; ++tidx) {
            Key<CompressedTree>[] keys = ((SharedTreeOutput)this._output)._treeKeys[tidx];
            for (int c = 0; c < keys.length; ++c) {
                if (keys[c] == null) continue;
                names[col++] = "T" + (tidx + 1) + (keys.length == 1 ? "" : ".C" + (c + 1));
            }
        }
        Frame res = new MRTask(){

            public void map(Chunk[] chks, NewChunk[] idx) {
                double[] input = new double[chks.length];
                String[] output = new String[outputcols];
                for (int row = 0; row < chks[0]._len; ++row) {
                    // Materialize the current row as a double[] for tree traversal.
                    for (int i = 0; i < chks.length; ++i) {
                        input[i] = chks[i].atd(row);
                    }
                    int col = 0;
                    for (int tidx = 0; tidx < ((SharedTreeOutput)SharedTreeModel.this._output)._treeKeys.length; ++tidx) {
                        Key<CompressedTree>[] keys;
                        for (Key<CompressedTree> key : keys = ((SharedTreeOutput)SharedTreeModel.this._output)._treeKeys[tidx]) {
                            if (key == null) continue;
                            // The tree is fetched from the DKV for every row and every tree.
                            String pred = ((CompressedTree)DKV.get(key).get()).getDecisionPath(input);
                            output[col++] = pred;
                        }
                    }
                    assert (col == outputcols);
                    for (int i = 0; i < outputcols; ++i) {
                        idx[i].addStr((Object)output[i]);
                    }
                }
            }
        // NOTE(review): (byte)2 is the output vec type code; since rows are added with addStr it is
        // presumably the string vec type - confirm against the Vec type constants.
        }.doAll(outputcols, (byte)2, adaptFrm).outputFrame(destination_key, names, (String[][])null);
        // Convert each output vec to a categorical vec.
        Vec[] nvecs = new Vec[res.vecs().length];
        for (int c = 0; c < res.vecs().length; ++c) {
            Vec vv = res.vec(c);
            try {
                nvecs[c] = vv.toCategoricalVec();
                continue;
            }
            catch (Exception e) {
                // On failure, clean up via VecUtils.deleteVecs(nvecs, c) before rethrowing.
                VecUtils.deleteVecs((Vec[])nvecs, (int)c);
                throw e;
            }
        }
        // Drop the intermediate frame and publish the categorical frame under the same key.
        res.delete();
        res = new Frame(destination_key, names, nvecs);
        DKV.put((Keyed)res);
        return res;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Appends calibrated probability columns to a prediction frame when a calibration model is
     * present.
     *
     * <p>Column 1 of {@code predictFr} is scored with the calibration model; columns 1 and 2 of
     * the calibration output are then added to {@code predictFr} under the names
     * {@code "cal_" + <name of column 1 + i>}. The temporary calibration input key and the
     * remaining calibration output are removed in all cases.
     *
     * @param predictFr prediction frame to extend
     * @return {@code predictFr} unchanged when there is no calibration model, otherwise the
     *         updated frame
     * @throws RuntimeException the result of {@code H2O.unimpl(...)} when a calibration model
     *         exists but the model category is not binomial
     */
    protected Frame postProcessPredictions(Frame predictFr) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        if (output._calib_model == null) {
            return predictFr;
        }
        if (output.getModelCategory() != ModelCategory.Binomial) {
            throw H2O.unimpl((String)"Calibration is only supported for binomial models");
        }
        Key calibInputKey = Key.make();
        Frame calibOutput = null;
        try {
            Frame calibInput = new Frame(calibInputKey, new String[]{"p"}, new Vec[]{predictFr.vec(1)});
            calibOutput = output._calib_model.score(calibInput);
            assert (calibOutput._names.length == 3);
            Vec[] calPredictions = calibOutput.remove(new int[]{1, 2});
            predictFr.write_lock();
            int position = 0;
            for (Vec calibrated : calPredictions) {
                predictFr.add("cal_" + predictFr.name(1 + position), calibrated);
                ++position;
            }
            return (Frame)predictFr.update();
        }
        finally {
            // Always drop the temporary input key and whatever is left of the calibration output.
            DKV.remove((Key)calibInputKey);
            if (calibOutput != null) {
                calibOutput.remove();
            }
        }
    }

    /**
     * Scores one row using every stored tree group.
     *
     * @param data   row values
     * @param preds  prediction buffer, overwritten by the scoring
     * @param offset offset forwarded to the four-argument overload
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds, double offset) {
        int allTrees = ((SharedTreeOutput)this._output)._treeKeys.length;
        return this.score0(data, preds, offset, allTrees);
    }

    /**
     * Scores one row with an offset of {@code 0.0}.
     *
     * @param data  row values
     * @param preds prediction buffer
     * @return the result of the offset-taking overload
     */
    protected double[] score0(double[] data, double[] preds) {
        return this.score0(data, preds, 0.0);
    }

    /**
     * Zeroes {@code preds} and accumulates the raw scores of the first {@code ntrees} tree groups.
     *
     * @param data   row values
     * @param preds  prediction buffer; cleared, then filled with summed tree scores
     * @param offset not used by this implementation
     * @param ntrees number of tree groups (starting at index 0) to score
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        Arrays.fill(preds, 0.0);
        int tidx = 0;
        while (tidx < ntrees) {
            this.score0(data, preds, tidx);
            ++tidx;
        }
        return preds;
    }

    /**
     * Adds the scores of one tree group to {@code preds}.
     *
     * <p>A group with a single tree accumulates into {@code preds[0]}; otherwise the tree for
     * class {@code c} accumulates into {@code preds[c + 1]}. Missing (null) trees are skipped.
     *
     * @param data    row values
     * @param preds   accumulator for the tree scores
     * @param treeIdx index of the tree group in {@code _treeKeys}
     */
    private void score0(double[] data, double[] preds, int treeIdx) {
        Key<CompressedTree>[] classKeys = ((SharedTreeOutput)this._output)._treeKeys[treeIdx];
        for (int c = 0; c < classKeys.length; ++c) {
            Key<CompressedTree> treeKey = classKeys[c];
            if (treeKey != null) {
                double pred = ((CompressedTree)DKV.get(treeKey).get()).score(data);
                assert (!Double.isInfinite(pred));
                preds[classKeys.length == 1 ? 0 : c + 1] += pred;
            }
        }
    }

    /**
     * Creates a deep copy of this model under a new key.
     *
     * <p>The copy has its model metrics cleared. Every non-null compressed tree is copied and
     * stored in the DKV under a key from {@code CompressedTree.makeTreeKey}; every non-null
     * auxiliary tree is copied under a key derived from the matching new tree key through
     * {@code GenModel.createAuxKey}.
     *
     * @param result key assigned to the cloned model
     * @return the cloned model
     */
    protected M deepClone(Key<M> result) {
        SharedTreeModel newModel = (SharedTreeModel)IcedUtils.deepCopy(this.self());
        newModel._key = result;
        SharedTreeOutput clonedOutput = (SharedTreeOutput)newModel._output;
        clonedOutput.clearModelMetrics();
        clonedOutput._training_metrics = null;
        clonedOutput._validation_metrics = null;

        // Re-home the regular trees under fresh keys.
        Key<CompressedTree>[][] treeKeys = clonedOutput._treeKeys;
        for (int i = 0; i < treeKeys.length; ++i) {
            for (int j = 0; j < treeKeys[i].length; ++j) {
                if (treeKeys[i][j] != null) {
                    CompressedTree source = (CompressedTree)DKV.get(treeKeys[i][j]).get();
                    CompressedTree copy = (CompressedTree)IcedUtils.deepCopy((Iced)source);
                    copy._key = CompressedTree.makeTreeKey(i, j);
                    Key freshKey = copy._key;
                    treeKeys[i][j] = freshKey;
                    DKV.put((Key)freshKey, (Iced)copy);
                }
            }
        }

        // Re-home the auxiliary trees; their keys are derived from the new regular tree keys.
        Key<CompressedTree>[][] treeKeysAux = clonedOutput._treeKeysAux;
        if (treeKeysAux == null) {
            return (M)((Object)newModel);
        }
        for (int i = 0; i < treeKeysAux.length; ++i) {
            for (int j = 0; j < treeKeysAux[i].length; ++j) {
                if (treeKeysAux[i][j] != null) {
                    CompressedTree source = (CompressedTree)DKV.get(treeKeysAux[i][j]).get();
                    CompressedTree copy = (CompressedTree)IcedUtils.deepCopy((Iced)source);
                    copy._key = Key.make((String)GenModel.createAuxKey((String)treeKeys[i][j].toString()));
                    Key freshAuxKey = copy._key;
                    treeKeysAux[i][j] = freshAuxKey;
                    DKV.put((Key)freshAuxKey, (Iced)copy);
                }
            }
        }
        return (M)((Object)newModel);
    }

    /**
     * Removes every tree, every auxiliary tree and the calibration model (when present) before
     * delegating to the superclass.
     *
     * <p>{@code _treeKeysAux} is treated as optional: a null auxiliary key table is skipped
     * instead of raising a {@code NullPointerException}.
     *
     * @param fs futures collecting the pending removals
     * @return the result of {@code super.remove_impl(fs)}
     */
    protected Futures remove_impl(Futures fs) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        for (Key<CompressedTree>[] ks : output._treeKeys) {
            for (Key<CompressedTree> k : ks) {
                if (k != null) {
                    k.remove(fs);
                }
            }
        }
        Key<CompressedTree>[][] auxKeys = output._treeKeysAux;
        // The auxiliary key table may be absent; there is nothing to remove in that case.
        if (auxKeys != null) {
            for (Key<CompressedTree>[] ks : auxKeys) {
                for (Key<CompressedTree> k : ks) {
                    if (k != null) {
                        k.remove(fs);
                    }
                }
            }
        }
        if (output._calib_model != null) {
            output._calib_model.remove(fs);
        }
        return super.remove_impl(fs);
    }

    /**
     * Passes every tree key, then every auxiliary tree key, to {@code ab.putKey} before
     * delegating to the superclass.
     *
     * <p>{@code _treeKeysAux} is treated as optional: a null auxiliary key table contributes
     * nothing instead of raising a {@code NullPointerException}.
     *
     * @param ab buffer receiving the keys
     * @return the result of {@code super.writeAll_impl(ab)}
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        for (Key<CompressedTree>[] ks : output._treeKeys) {
            for (Key<CompressedTree> k : ks) {
                ab.putKey(k);
            }
        }
        Key<CompressedTree>[][] auxKeys = output._treeKeysAux;
        // The auxiliary key table may be absent; skip it rather than fail.
        if (auxKeys != null) {
            for (Key<CompressedTree>[] ks : auxKeys) {
                for (Key<CompressedTree> k : ks) {
                    ab.putKey(k);
                }
            }
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Passes every tree key, then every auxiliary tree key, to {@code ab.getKey} before
     * delegating to the superclass.
     *
     * <p>{@code _treeKeysAux} is treated as optional: a null auxiliary key table is skipped
     * instead of raising a {@code NullPointerException}.
     *
     * @param ab buffer the keys are read from
     * @param fs futures collecting pending work
     * @return the result of {@code super.readAll_impl(ab, fs)}
     */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        for (Key<CompressedTree>[] ks : output._treeKeys) {
            for (Key<CompressedTree> k : ks) {
                ab.getKey(k, fs);
            }
        }
        Key<CompressedTree>[][] auxKeys = output._treeKeysAux;
        // The auxiliary key table may be absent; skip it rather than fail.
        if (auxKeys != null) {
            for (Key<CompressedTree>[] ks : auxKeys) {
                for (Key<CompressedTree> k : ks) {
                    ab.getKey(k, fs);
                }
            }
        }
        return super.readAll_impl(ab, fs);
    }

    /**
     * Returns this instance cast (unchecked) to the self type {@code M}.
     *
     * @return {@code this}
     */
    private M self() {
        return (M)((Object)this);
    }

    /**
     * Reports whether the model is considered too large for Java code generation.
     *
     * @return {@code true} when there is no output, or when the number of trees multiplied by
     *         the mean number of leaves exceeds 1,000,000
     */
    protected boolean toJavaCheckTooBig() {
        if (this._output == null) {
            return true;
        }
        TreeStats stats = ((SharedTreeOutput)this._output)._treeStats;
        return (float)stats._num_trees * stats._mean_leaves > 1000000.0f;
    }

    /**
     * Flag consulted by {@code toJavaPredictBody}: when it is {@code true} and the model has
     * exactly two classes, code generation skips the tree for class index 1.
     *
     * @return {@code true} in this base implementation
     */
    protected boolean binomialOpt() {
        return true;
    }

    /**
     * Emits the generated-code accessors {@code isSupervised()}, {@code nfeatures()} and
     * {@code nclasses()}.
     *
     * @param sb      stream receiving the generated source
     * @param fileCtx file-level generator pipeline (not used here)
     * @return {@code sb}
     */
    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        SharedTreeOutput output = (SharedTreeOutput)this._output;
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();
        String nfeaturesDecl = "public int nfeatures() { return " + output.nfeatures() + "; }";
        sb.ip(nfeaturesDecl).nl();
        String nclassesDecl = "public int nclasses() { return " + output.nclasses() + "; }";
        sb.ip(nclassesDecl).nl();
        return sb;
    }

    /**
     * Emits the body of the generated predict method and registers one deferred generator per
     * tree group.
     *
     * <p>The body clears {@code preds}, calls {@code <forest>.score0(data,preds)} for each tree
     * group and finishes with {@link #toJavaUnifyPreds}. Each deferred generator writes the
     * forest class (summing the per-class tree scores into {@code preds}) followed by the
     * per-class tree classes produced by {@code TreeJCodeGen}. Missing trees are skipped, as is
     * the class-1 tree of a two-class model when {@link #binomialOpt()} is {@code true}.
     *
     * @param body        stream receiving the predict-method body
     * @param classCtx    class-level generator pipeline (not used here)
     * @param fileCtx     file-level pipeline that receives the forest/tree class generators
     * @param verboseCode forwarded to {@code TreeJCodeGen}
     */
    protected void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, final boolean verboseCode) {
        final int nclass = ((SharedTreeOutput)this._output).nclasses();
        body.ip("java.util.Arrays.fill(preds,0);").nl();
        final String mname = JCodeGen.toJavaId((String)this._key.toString());
        for (int t = 0; t < ((SharedTreeOutput)this._output)._treeKeys.length; ++t) {
            this.toJavaForestName(body.i(), mname, t).p(".score0(data,preds);").nl();
            final int treeIdx = t;
            fileCtx.add((Object)new CodeGenerator(){

                public void generate(JCodeSB out) {
                    try {
                        out.nl();
                        SharedTreeModel.this.toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
                        out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
                        // Forest class: one "preds[k] += tree.score0(fdata)" line per emitted tree.
                        for (int c = 0; c < nclass; ++c) {
                            if (((SharedTreeOutput)SharedTreeModel.this._output)._treeKeys[treeIdx][c] != null && !(SharedTreeModel.this.binomialOpt() && c == 1 && nclass == 2)) {
                                SharedTreeModel.this.toJavaTreeName(out.ip("preds[").p(nclass == 1 ? 0 : c + 1).p("] += "), mname, treeIdx, c).p(".score0(fdata);").nl();
                            }
                        }
                        out.di(1).ip("}").nl();
                        out.di(1).ip("}").nl();
                        // Tree classes: generated with the same skip rule as above.
                        for (int c = 0; c < nclass; ++c) {
                            if (((SharedTreeOutput)SharedTreeModel.this._output)._treeKeys[treeIdx][c] != null && !(SharedTreeModel.this.binomialOpt() && c == 1 && nclass == 2)) {
                                String javaClassName = SharedTreeModel.this.toJavaTreeName(new SB(), mname, treeIdx, c).toString();
                                CompressedTree ct = ((SharedTreeOutput)SharedTreeModel.this._output).ctree(treeIdx, c);
                                SB treeSource = new SB();
                                new TreeJCodeGen(SharedTreeModel.this, ct, treeSource, javaClassName, verboseCode).generate();
                                out.p((JCodeSB)treeSource);
                            }
                        }
                    }
                    catch (Throwable failure) {
                        failure.printStackTrace();
                        throw new IllegalArgumentException("Internal error creating the POJO.", failure);
                    }
                }
            });
        }
        this.toJavaUnifyPreds(body);
    }

    /**
     * Hook invoked at the end of {@code toJavaPredictBody}, after all per-forest scoring calls
     * have been emitted; subclasses write the remaining predict-body code to the given stream.
     *
     * @param var1 stream receiving the predict-method body
     */
    protected abstract void toJavaUnifyPreds(SBPrintStream var1);

    /**
     * Appends the generated class name of a single tree,
     * {@code <mname>_Tree_<t>_class_<c>}, to {@code sb}.
     *
     * @param sb    builder to append to
     * @param mname Java identifier of the model
     * @param t     tree group index
     * @param c     class index
     * @return the result of the chained {@code p(...)} calls, cast to {@code T}
     */
    protected <T extends JCodeSB> T toJavaTreeName(T sb, String mname, int t, int c) {
        JCodeSB named = sb.p(mname).p("_Tree_").p(t).p("_class_").p(c);
        return (T)named;
    }

    /**
     * Appends the generated class name of a tree group, {@code <mname>_Forest_<t>}, to
     * {@code sb}.
     *
     * @param sb    builder to append to
     * @param mname Java identifier of the model
     * @param t     tree group index
     * @return the result of the chained {@code p(...)} calls, cast to {@code T}
     */
    protected <T extends JCodeSB> T toJavaForestName(T sb, String mname, int t) {
        JCodeSB named = sb.p(mname).p("_Forest_").p(t);
        return (T)named;
    }

    /**
     * Output shared by tree-based models: the keys of the compressed trees (regular and
     * auxiliary), tree statistics, the per-tree scoring history and optional calibration model.
     */
    public static abstract class SharedTreeOutput
    extends Model.Output {
        public double _init_f;
        /** Number of tree groups added so far through {@link #addKTrees}. */
        public int _ntrees = 0;
        public final TreeStats _treeStats;
        /** Tree keys indexed as {@code [tree group][class]}; entries may be null. */
        public Key<CompressedTree>[][] _treeKeys;
        /** Auxiliary tree keys, indexed like {@link #_treeKeys}. */
        public Key<CompressedTree>[][] _treeKeysAux;
        public ScoreKeeper[] _scored_train;
        /** May be null; {@link #addKTrees} only extends it when it is non-null. */
        public ScoreKeeper[] _scored_valid;
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public TwoDimTable _variable_importances;
        public VarImp _varimp;
        public GLMModel _calib_model;

        /**
         * Returns the non-empty entries of the scoring history.
         *
         * <p>The validation history is used when validation metrics exist, otherwise the
         * training history. A missing (null) history yields an empty array instead of a
         * {@code NullPointerException}.
         *
         * @return the non-empty score keepers, in order; never null
         */
        public ScoreKeeper[] scoreKeepers() {
            ScoreKeeper[] history = this._validation_metrics != null ? this._scored_valid : this._scored_train;
            if (history == null) {
                return new ScoreKeeper[0];
            }
            ArrayList<ScoreKeeper> nonEmpty = new ArrayList<ScoreKeeper>();
            for (ScoreKeeper keeper : history) {
                if (!keeper.isEmpty()) {
                    nonEmpty.add(keeper);
                }
            }
            return nonEmpty.toArray(new ScoreKeeper[nonEmpty.size()]);
        }

        /**
         * Creates an output with no trees and a single NaN-initialized entry in each scoring
         * history.
         *
         * @param b builder passed on to the {@code Model.Output} constructor
         */
        public SharedTreeOutput(SharedTree b) {
            super((ModelBuilder)b);
            this._treeKeys = new Key[this._ntrees][];
            this._treeKeysAux = new Key[this._ntrees][];
            this._treeStats = new TreeStats();
            this._scored_train = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
            this._scored_valid = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
            this._modelClassDist = this._priorClassDist;
        }

        /**
         * Appends one tree group (one tree per class) to the output.
         *
         * <p>Each non-null tree is compressed and stored in the DKV together with an auxiliary
         * compressed tree built from the tree's {@code _abAux} buffer; tree statistics are
         * updated and the scoring-history and training-time arrays grow by one slot.
         *
         * @param trees one tree per class; null entries are skipped
         */
        public void addKTrees(DTree[] trees) {
            assert (this.nclasses() == trees.length);
            // Grow both key tables by one row for the new tree group.
            this._treeKeys = (Key[][])Arrays.copyOf(this._treeKeys, this._ntrees + 1);
            this._treeKeysAux = (Key[][])Arrays.copyOf(this._treeKeysAux, this._ntrees + 1);
            this._treeKeys[this._ntrees] = new Key[trees.length];
            Key[] keys = this._treeKeys[this._ntrees];
            this._treeKeysAux[this._ntrees] = new Key[trees.length];
            Key[] keysAux = this._treeKeysAux[this._ntrees];
            Futures fs = new Futures();
            for (int i = 0; i < this.nclasses(); ++i) {
                if (trees[i] == null) continue;
                CompressedTree ct = trees[i].compress(this._ntrees, i, this._domains);
                keys[i] = ct._key;
                DKV.put((Key)keys[i], (Iced)ct, (Futures)fs);
                this._treeStats.updateBy(trees[i]);
                // The auxiliary tree key is derived from the regular tree key.
                CompressedTree ctAux = new CompressedTree(trees[i]._abAux.buf(), -1, -1L, -1, -1, this._domains);
                keysAux[i] = ctAux._key = Key.make((String)GenModel.createAuxKey((String)ct._key.toString()));
                DKV.put((Keyed)ctAux);
            }
            ++this._ntrees;
            this._scored_train = (ScoreKeeper[])ArrayUtils.copyAndFillOf((Object[])this._scored_train, (int)(this._ntrees + 1), (Object)new ScoreKeeper());
            this._scored_valid = this._scored_valid != null ? (ScoreKeeper[])ArrayUtils.copyAndFillOf((Object[])this._scored_valid, (int)(this._ntrees + 1), (Object)new ScoreKeeper()) : null;
            this._training_time_ms = ArrayUtils.copyAndFillOf((long[])this._training_time_ms, (int)(this._ntrees + 1), (long)System.currentTimeMillis());
            fs.blockForPending();
        }

        /**
         * Fetches a compressed tree by position.
         *
         * @param tnum tree group index
         * @param knum class index
         * @return the tree stored under {@code _treeKeys[tnum][knum]}
         */
        public CompressedTree ctree(int tnum, int knum) {
            return (CompressedTree)this._treeKeys[tnum][knum].get();
        }

        /**
         * Renders a compressed tree as text via {@code CompressedTree.toString(this)}.
         *
         * @param tnum tree group index
         * @param knum class index
         * @return textual form of the tree
         */
        public String toStringTree(int tnum, int knum) {
            return this.ctree(tnum, knum).toString(this);
        }
    }

    /**
     * Parameters shared by tree-based model builders. Field initializers are the defaults.
     */
    public static abstract class SharedTreeParameters
    extends Model.Parameters {
        public int _ntrees = 50;
        public int _max_depth = 5;
        public double _min_rows = 10.0;
        public int _nbins = 20;
        public int _nbins_cats = 1024;
        public double _min_split_improvement = 1.0E-5;
        public HistogramType _histogram_type = HistogramType.AUTO;
        public double _r2_stopping = Double.MAX_VALUE;
        public int _nbins_top_level = 1024;
        public boolean _build_tree_one_node = false;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        public double _sample_rate = 0.632;
        public double[] _sample_rate_per_class;
        public boolean _calibrate_model = false;
        public Key<Frame> _calibration_frame;
        public double _col_sample_rate_change_per_level = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        /** Names of the public fields that must keep their value when resuming from a checkpoint. */
        private static String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = new String[]{"_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};

        /**
         * Resolves the calibration frame.
         *
         * @return the frame stored under {@code _calibration_frame}, or {@code null} when no
         *         calibration frame key is set
         */
        public Frame calib() {
            if (this._calibration_frame == null) {
                return null;
            }
            return (Frame)this._calibration_frame.get();
        }

        /**
         * Number of progress units of a training run.
         *
         * @return {@code _ntrees}, plus one when the histogram type is {@code QuantilesGlobal}
         *         or {@code RoundRobin}
         */
        public long progressUnits() {
            boolean extraUnit = this._histogram_type == HistogramType.QuantilesGlobal || this._histogram_type == HistogramType.RoundRobin;
            return this._ntrees + (extraUnit ? 1 : 0);
        }

        /**
         * @return the names of the fields that may not differ from the checkpointed model
         */
        protected String[] getCheckpointNonModifiableFields() {
            return CHECKPOINT_NON_MODIFIABLE_FIELDS;
        }

        /**
         * Verifies that none of the non-modifiable fields differs from the checkpoint's
         * parameters.
         *
         * @param checkpointParameters parameters of the checkpointed model
         * @throws H2OIllegalArgumentException when a non-modifiable field has a different value,
         *         or when the field cannot be looked up on the checkpoint parameters
         */
        public void validateWithCheckpoint(SharedTreeParameters checkpointParameters) {
            for (Field current : ((Object)((Object)this)).getClass().getFields()) {
                boolean locked = ArrayUtils.contains((String[])this.getCheckpointNonModifiableFields(), (String)current.getName());
                if (!locked) {
                    continue;
                }
                for (Field previous : ((Object)((Object)checkpointParameters)).getClass().getFields()) {
                    if (!previous.equals(current)) {
                        continue;
                    }
                    boolean unchanged;
                    try {
                        unchanged = PojoUtils.equals((Object)((Object)this), (Field)current, (Object)((Object)checkpointParameters), (Field)((Object)((Object)checkpointParameters)).getClass().getField(current.getName()));
                    }
                    catch (NoSuchFieldException e) {
                        throw new H2OIllegalArgumentException(current.getName(), "TreeBuilder", (Object)("Field " + current.getName() + " is not supported by checkpoint!"));
                    }
                    if (!unchanged) {
                        throw new H2OIllegalArgumentException(current.getName(), "TreeBuilder", (Object)("Field " + current.getName() + " cannot be modified if checkpoint is specified!"));
                    }
                }
            }
        }

        /** Histogram strategies selectable through {@code _histogram_type}. */
        public static enum HistogramType {
            AUTO,
            UniformAdaptive,
            Random,
            QuantilesGlobal,
            RoundRobin;

        }
    }
}

