/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.KeyValue;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.PlattScalingHelper;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostMojoWriter;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostPojoWriter;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.predict.AssignLeafNodeTask;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictTreeSHAPTask;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import hex.tree.xgboost.predict.XGBoostJavaBigScorePredict;
import hex.tree.xgboost.predict.XGBoostModelMetrics;
import hex.tree.xgboost.predict.XGBoostNativeBigScorePredict;
import hex.tree.xgboost.util.BoosterHelper;
import hex.tree.xgboost.util.PredictConfiguration;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.StringJoiner;
import java.util.stream.Stream;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.PredictorFactory;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.H2ONode;
import water.Iced;
import water.IcedUtils;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.codegen.CodeGeneratorPipeline;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.Log;
import water.util.SBPrintStream;

public class XGBoostModel
extends Model<XGBoostModel, XGBoostParameters, XGBoostOutput>
implements SharedTreeGraphConverter,
Model.LeafNodeAssignment,
Model.Contributions {
    private static final String PROP_VERBOSITY = "sys.ai.h2o..xgboost.verbosity";
    private static final String PROP_NTHREAD = "sys.ai.h2o.xgboost.nthreadMax";
    private XGBoostModelInfo model_info;

    public XGBoostModelInfo model_info() {
        return this.model_info;
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (((XGBoostOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Multinomial: {
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(((XGBoostOutput)this._output).nclasses(), domain);
            }
            case Regression: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
        }
        throw H2O.unimpl();
    }

    public XGBoostModel(Key<XGBoostModel> selfKey, XGBoostParameters parms, XGBoostOutput output, Frame train, Frame valid) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
        DataInfo dinfo = XGBoost.makeDataInfo(train, valid, (XGBoostParameters)this._parms, output.nclasses());
        DKV.put((Keyed)dinfo);
        this.setDataInfoToOutput(dinfo);
        this.model_info = new XGBoostModelInfo(parms, dinfo);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void dump(String format) {
        File fmFile = null;
        try {
            String[] d;
            Booster b = BoosterHelper.loadModel(new ByteArrayInputStream(this.model_info._boosterBytes));
            fmFile = File.createTempFile("xgboost-feature-map", ".bin");
            FileOutputStream os = new FileOutputStream(fmFile);
            os.write(this.model_info._featureMap.getBytes());
            os.close();
            String fmFilePath = fmFile.getAbsolutePath();
            for (String l : d = b.getModelDump(fmFilePath, true, format)) {
                System.out.println(l);
            }
        }
        catch (Exception e) {
            Log.err((Throwable)e);
        }
        finally {
            if (fmFile != null) {
                fmFile.delete();
            }
        }
    }

    public static XGBoostParameters.Backend getActualBackend(XGBoostParameters p) {
        if (p._backend == XGBoostParameters.Backend.auto || p._backend == XGBoostParameters.Backend.gpu) {
            if (H2O.getCloudSize() > 1) {
                Log.info((Object[])new Object[]{"GPU backend not supported in distributed mode. Using CPU backend."});
                return XGBoostParameters.Backend.cpu;
            }
            if (!p.gpuIncompatibleParams().isEmpty()) {
                Log.info((Object[])new Object[]{"GPU backend not supported for the choice of parameters (" + p.gpuIncompatibleParams() + "). Using CPU backend."});
                return XGBoostParameters.Backend.cpu;
            }
            if (XGBoost.hasGPU(H2O.CLOUD.members()[0], p._gpu_id)) {
                Log.info((Object[])new Object[]{"Using GPU backend (gpu_id: " + p._gpu_id + ")."});
                return XGBoostParameters.Backend.gpu;
            }
            Log.info((Object[])new Object[]{"No GPU (gpu_id: " + p._gpu_id + ") found. Using CPU backend."});
            return XGBoostParameters.Backend.cpu;
        }
        Log.info((Object[])new Object[]{"Using CPU backend."});
        return XGBoostParameters.Backend.cpu;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public static BoosterParms createParams(XGBoostParameters p, int nClasses, String[] coefNames) {
        int nthread;
        XGBoostParameters.Backend actualBackend;
        HashMap<String, Object> params = new HashMap<String, Object>();
        if (p._n_estimators != 0) {
            Log.info((Object[])new Object[]{"Using user-provided parameter n_estimators instead of ntrees."});
            params.put("nround", p._n_estimators);
            p._ntrees = p._n_estimators;
        } else {
            params.put("nround", p._ntrees);
            p._n_estimators = p._ntrees;
        }
        if (p._eta != 0.3) {
            Log.info((Object[])new Object[]{"Using user-provided parameter eta instead of learn_rate."});
            params.put("eta", p._eta);
            p._learn_rate = p._eta;
        } else {
            params.put("eta", p._learn_rate);
            p._eta = p._learn_rate;
        }
        params.put("max_depth", p._max_depth);
        if (System.getProperty(PROP_VERBOSITY) != null) {
            params.put("verbosity", System.getProperty(PROP_VERBOSITY));
        } else {
            params.put("silent", p._quiet_mode);
        }
        if (p._subsample != 1.0) {
            Log.info((Object[])new Object[]{"Using user-provided parameter subsample instead of sample_rate."});
            params.put("subsample", p._subsample);
            p._sample_rate = p._subsample;
        } else {
            params.put("subsample", p._sample_rate);
            p._subsample = p._sample_rate;
        }
        if (p._colsample_bytree != 1.0) {
            Log.info((Object[])new Object[]{"Using user-provided parameter colsample_bytree instead of col_sample_rate_per_tree."});
            params.put("colsample_bytree", p._colsample_bytree);
            p._col_sample_rate_per_tree = p._colsample_bytree;
        } else {
            params.put("colsample_bytree", p._col_sample_rate_per_tree);
            p._colsample_bytree = p._col_sample_rate_per_tree;
        }
        if (p._colsample_bylevel != 1.0) {
            Log.info((Object[])new Object[]{"Using user-provided parameter colsample_bylevel instead of col_sample_rate."});
            params.put("colsample_bylevel", p._colsample_bylevel);
            p._col_sample_rate = p._colsample_bylevel;
        } else {
            params.put("colsample_bylevel", p._col_sample_rate);
            p._colsample_bylevel = p._col_sample_rate;
        }
        if (p._max_delta_step != 0.0f) {
            Log.info((Object[])new Object[]{"Using user-provided parameter max_delta_step instead of max_abs_leafnode_pred."});
            params.put("max_delta_step", Float.valueOf(p._max_delta_step));
            p._max_abs_leafnode_pred = p._max_delta_step;
        } else {
            params.put("max_delta_step", Float.valueOf(p._max_abs_leafnode_pred));
            p._max_delta_step = p._max_abs_leafnode_pred;
        }
        params.put("seed", (int)(p._seed % Integer.MAX_VALUE));
        params.put("grow_policy", p._grow_policy.toString());
        if (p._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
            params.put("max_bin", p._max_bins);
            params.put("max_leaves", p._max_leaves);
            params.put("min_sum_hessian_in_leaf", Float.valueOf(p._min_sum_hessian_in_leaf));
            params.put("min_data_in_leaf", Float.valueOf(p._min_data_in_leaf));
        }
        params.put("booster", p._booster.toString());
        if (p._booster == XGBoostParameters.Booster.dart) {
            params.put("sample_type", p._sample_type.toString());
            params.put("normalize_type", p._normalize_type.toString());
            params.put("rate_drop", Float.valueOf(p._rate_drop));
            params.put("one_drop", p._one_drop ? "1" : "0");
            params.put("skip_drop", Float.valueOf(p._skip_drop));
        }
        if ((actualBackend = XGBoostModel.getActualBackend(p)) == XGBoostParameters.Backend.gpu) {
            params.put("gpu_id", p._gpu_id);
            if (p._booster == XGBoostParameters.Booster.gblinear) {
                Log.info((Object[])new Object[]{"Using gpu_coord_descent updater."});
                params.put("updater", "gpu_coord_descent");
            } else {
                Log.info((Object[])new Object[]{"Using gpu_hist tree method."});
                params.put("max_bin", p._max_bins);
                params.put("updater", "grow_gpu_hist");
            }
        } else if (p._booster == XGBoostParameters.Booster.gblinear) {
            Log.info((Object[])new Object[]{"Using coord_descent updater."});
            params.put("updater", "coord_descent");
        } else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto && p._monotone_constraints != null) {
            Log.info((Object[])new Object[]{"Using hist tree method for distributed computation with monotone_constraints."});
            params.put("tree_method", XGBoostParameters.TreeMethod.hist.toString());
            params.put("max_bin", p._max_bins);
        } else {
            Log.info((Object[])new Object[]{"Using " + p._tree_method.toString() + " tree method."});
            params.put("tree_method", p._tree_method.toString());
            if (p._tree_method == XGBoostParameters.TreeMethod.hist) {
                params.put("max_bin", p._max_bins);
            }
        }
        if (p._min_child_weight != 1.0) {
            Log.info((Object[])new Object[]{"Using user-provided parameter min_child_weight instead of min_rows."});
            params.put("min_child_weight", p._min_child_weight);
            p._min_rows = p._min_child_weight;
        } else {
            params.put("min_child_weight", p._min_rows);
            p._min_child_weight = p._min_rows;
        }
        if (p._gamma != 0.0f) {
            Log.info((Object[])new Object[]{"Using user-provided parameter gamma instead of min_split_improvement."});
            params.put("gamma", Float.valueOf(p._gamma));
            p._min_split_improvement = p._gamma;
        } else {
            params.put("gamma", Float.valueOf(p._min_split_improvement));
            p._gamma = p._min_split_improvement;
        }
        params.put("lambda", Float.valueOf(p._reg_lambda));
        params.put("alpha", Float.valueOf(p._reg_alpha));
        if (nClasses == 2) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId());
        } else if (nClasses == 1) {
            if (p._distribution == DistributionFamily.gamma) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId());
            } else if (p._distribution == DistributionFamily.tweedie) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId());
                params.put("tweedie_variance_power", p._tweedie_power);
            } else if (p._distribution == DistributionFamily.poisson) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId());
            } else {
                if (p._distribution != DistributionFamily.gaussian && p._distribution != DistributionFamily.AUTO) throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_SQUAREDERROR.getId());
            }
        } else {
            params.put("objective", XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId());
            params.put("num_class", nClasses);
        }
        assert (XGBoostMojoModel.ObjectiveType.fromXGBoost((String)((String)params.get("objective"))) != null);
        int nthreadMax = XGBoostModel.getMaxNThread();
        int n = nthread = p._nthread != -1 ? Math.min(p._nthread, nthreadMax) : nthreadMax;
        if (nthread < p._nthread) {
            Log.warn((Object[])new Object[]{"Requested nthread=" + p._nthread + " but the cluster has only " + nthreadMax + " available.Training will use nthread=" + nthread + " instead of the user specified value."});
        }
        params.put("nthread", nthread);
        Map<String, Integer> monotoneConstraints = p.monotoneConstraints();
        if (!monotoneConstraints.isEmpty()) {
            int constraintsUsed = 0;
            StringBuilder sb = new StringBuilder();
            sb.append("(");
            for (String coef : coefNames) {
                String direction;
                if (monotoneConstraints.containsKey(coef)) {
                    direction = monotoneConstraints.get(coef).toString();
                    ++constraintsUsed;
                } else {
                    direction = "0";
                }
                sb.append(direction);
                sb.append(",");
            }
            sb.replace(sb.length() - 1, sb.length(), ")");
            params.put("monotone_constraints", sb.toString());
            assert (constraintsUsed == monotoneConstraints.size());
        }
        Log.info((Object[])new Object[]{"XGBoost Parameters:"});
        for (Map.Entry s : params.entrySet()) {
            Log.info((Object[])new Object[]{" " + (String)s.getKey() + " = " + s.getValue()});
        }
        Log.info((Object[])new Object[]{""});
        return BoosterParms.fromMap(Collections.unmodifiableMap(params));
    }

    protected XGBoostModel deepClone(Key<XGBoostModel> result) {
        XGBoostModel newModel = (XGBoostModel)IcedUtils.deepCopy((Iced)this);
        newModel._key = result;
        ((XGBoostOutput)newModel._output).clearModelMetrics(false);
        ((XGBoostOutput)newModel._output)._training_metrics = null;
        ((XGBoostOutput)newModel._output)._validation_metrics = null;
        return newModel;
    }

    static int getMaxNThread() {
        if (System.getProperty(PROP_NTHREAD) != null) {
            return Integer.getInteger(PROP_NTHREAD);
        }
        int maxNodesPerHost = 1;
        HashSet<String> checkedNodes = new HashSet<String>();
        for (H2ONode node : H2O.CLOUD.members()) {
            String nodeHost = node.getIp();
            if (checkedNodes.contains(nodeHost)) continue;
            checkedNodes.add(nodeHost);
            long cnt = Stream.of(H2O.CLOUD.members()).filter(h -> h.getIp().equals(nodeHost)).count();
            if (cnt <= (long)maxNodesPerHost) continue;
            maxNodesPerHost = (int)cnt;
        }
        return Math.max(1, H2O.ARGS.nthreads / maxNodesPerHost);
    }

    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(this.model_info.getDataInfoKey());
        return super.writeAll_impl(ab);
    }

    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(this.model_info.getDataInfoKey(), fs);
        return super.readAll_impl(ab, fs);
    }

    public XGBoostMojoWriter getMojo() {
        return new XGBoostMojoWriter(this);
    }

    private ModelMetrics makeMetrics(Frame data, Frame originalData, boolean isTrain, String description) {
        Log.debug((Object[])new Object[]{"Making metrics: " + description});
        return new XGBoostModelMetrics((XGBoostOutput)this._output, data, originalData, isTrain, this).compute();
    }

    final void doScoring(Frame _train, Frame _trainOrig, Frame _valid, Frame _validOrig) {
        ModelMetrics mm;
        ((XGBoostOutput)this._output)._training_metrics = mm = this.makeMetrics(_train, _trainOrig, true, "Metrics reported on training frame");
        ((XGBoostOutput)this._output)._scored_train[((XGBoostOutput)this._output)._ntrees].fillFrom(mm);
        this.addModelMetrics(mm);
        if (_valid != null) {
            ((XGBoostOutput)this._output)._validation_metrics = mm = this.makeMetrics(_valid, _validOrig, false, "Metrics reported on validation frame");
            ((XGBoostOutput)this._output)._scored_valid[((XGBoostOutput)this._output)._ntrees].fillFrom(mm);
            this.addModelMetrics(mm);
        }
    }

    protected Frame postProcessPredictions(Frame adaptedFrame, Frame predictFr, Job j) {
        return PlattScalingHelper.postProcessPredictions((Frame)predictFr, (Job)j, (PlattScalingHelper.OutputWithCalibration)((PlattScalingHelper.OutputWithCalibration)this._output));
    }

    protected double[] score0(double[] data, double[] preds) {
        return this.score0(data, preds, 0.0);
    }

    public double[] score0(double[] data, double[] preds, double offset) {
        float[] out;
        DataInfo di = this.model_info.dataInfo();
        assert (di != null);
        MutableOneHotEncoderFVec row = new MutableOneHotEncoderFVec(di, ((XGBoostOutput)this._output)._sparse);
        row.setInput(data);
        Predictor predictor = PredictorFactory.makePredictor(this.model_info._boosterBytes);
        if (((XGBoostOutput)this._output).hasOffset()) {
            out = predictor.predict((FVec)row, (float)offset);
        } else {
            if (offset != 0.0) {
                throw new UnsupportedOperationException("Unsupported: offset != 0");
            }
            out = predictor.predict((FVec)row);
        }
        return XGBoostMojoModel.toPreds((double[])data, (float[])out, (double[])preds, (int)((XGBoostOutput)this._output).nclasses(), (double[])((XGBoostOutput)this._output)._priorClassDist, (double)this.defaultThreshold());
    }

    protected XGBoostBigScorePredict setupBigScorePredict(Model.BigScore bs) {
        return this.setupBigScorePredict(false);
    }

    public XGBoostBigScorePredict setupBigScorePredict(boolean isTrain) {
        DataInfo di = this.model_info().scoringInfo(isTrain);
        return PredictConfiguration.useJavaScoring() ? this.setupBigScorePredictJava(di) : this.setupBigScorePredictNative(di);
    }

    private XGBoostBigScorePredict setupBigScorePredictNative(DataInfo di) {
        BoosterParms boosterParms = XGBoostModel.createParams((XGBoostParameters)this._parms, ((XGBoostOutput)this._output).nclasses(), di.coefNames());
        return new XGBoostNativeBigScorePredict(this.model_info, (XGBoostParameters)this._parms, (XGBoostOutput)this._output, di, boosterParms, this.defaultThreshold());
    }

    private XGBoostBigScorePredict setupBigScorePredictJava(DataInfo di) {
        return new XGBoostJavaBigScorePredict(this.model_info, (XGBoostOutput)this._output, di, (XGBoostParameters)this._parms, this.defaultThreshold());
    }

    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return this.scoreContributions(frame, destination_key, null);
    }

    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        DataInfo di = this.model_info().dataInfo();
        assert (di != null);
        String[] outputNames = (String[])ArrayUtils.append((Object[])di.coefNames(), (Object[])new String[]{"BiasTerm"});
        return ((PredictTreeSHAPTask)new PredictTreeSHAPTask(di, this.model_info(), (XGBoostOutput)this._output).withPostMapAction((MRTask.PostMapAction)JobUpdatePostMap.forJob(j)).doAll(outputNames.length, (byte)3, adaptFrm)).outputFrame(destination_key, outputNames, null);
    }

    public Frame scoreLeafNodeAssignment(Frame frame, Model.LeafNodeAssignment.LeafNodeAssignmentType type, Key<Frame> destination_key) {
        AssignLeafNodeTask task = AssignLeafNodeTask.make(this.model_info.scoringInfo(false), (XGBoostOutput)this._output, this.model_info._boosterBytes, type);
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        return task.execute(adaptFrm, destination_key);
    }

    private void setDataInfoToOutput(DataInfo dinfo) {
        ((XGBoostOutput)this._output).setNames(dinfo._adaptedFrame.names(), dinfo._adaptedFrame.typesStr());
        ((XGBoostOutput)this._output)._domains = dinfo._adaptedFrame.domains();
        ((XGBoostOutput)this._output)._nums = dinfo._nums;
        ((XGBoostOutput)this._output)._cats = dinfo._cats;
        ((XGBoostOutput)this._output)._catOffsets = dinfo._catOffsets;
        ((XGBoostOutput)this._output)._useAllFactorLevels = dinfo._useAllFactorLevels;
    }

    protected Futures remove_impl(Futures fs, boolean cascade) {
        DataInfo di = this.model_info().dataInfo();
        if (di != null) {
            di.remove(fs);
        }
        if (((XGBoostOutput)this._output)._calib_model != null) {
            ((XGBoostOutput)this._output)._calib_model.remove(fs);
        }
        return super.remove_impl(fs, cascade);
    }

    public SharedTreeGraph convert(int treeNumber, String treeClassName) {
        GradBooster booster = XGBoostJavaMojoModel.makePredictor((byte[])this.model_info._boosterBytes).getBooster();
        if (!(booster instanceof GBTree)) {
            throw new IllegalArgumentException("XGBoost model is not backed by a tree-based booster. Booster class is " + booster.getClass().getCanonicalName());
        }
        RegTree[][] groupedTrees = ((GBTree)booster).getGroupedTrees();
        int treeClass = this.getXGBoostClassIndex(treeClassName);
        if (treeClass >= groupedTrees.length) {
            throw new IllegalArgumentException(String.format("Given XGBoost model does not have given class '%s'.", treeClassName));
        }
        RegTree[] treesInGroup = groupedTrees[treeClass];
        if (treeNumber >= treesInGroup.length || treeNumber < 0) {
            throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", treesInGroup.length));
        }
        RegTreeNode[] treeNodes = treesInGroup[treeNumber].getNodes();
        assert (treeNodes.length >= 1);
        SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
        SharedTreeSubgraph sharedTreeSubgraph = sharedTreeGraph.makeSubgraph(((XGBoostOutput)this._output)._training_metrics._description);
        XGBoostUtils.FeatureProperties featureProperties = XGBoostUtils.assembleFeatureNames(this.model_info.dataInfo());
        XGBoostModel.constructSubgraph(treeNodes, sharedTreeSubgraph.makeRootNode(), 0, sharedTreeSubgraph, featureProperties, true);
        return sharedTreeGraph;
    }

    private static void constructSubgraph(RegTreeNode[] xgBoostNodes, SharedTreeNode sharedTreeNode, int nodeIndex, SharedTreeSubgraph sharedTreeSubgraph, XGBoostUtils.FeatureProperties featureProperties, boolean inclusiveNA) {
        RegTreeNode xgBoostNode = xgBoostNodes[nodeIndex];
        if (featureProperties._oneHotEncoded[xgBoostNode.getSplitIndex()]) {
            sharedTreeNode.setSplitValue(1.0f);
        } else {
            sharedTreeNode.setSplitValue(xgBoostNode.getSplitCondition());
        }
        sharedTreeNode.setPredValue(xgBoostNode.getLeafValue());
        sharedTreeNode.setInclusiveNa(inclusiveNA);
        sharedTreeNode.setNodeNumber(nodeIndex);
        if (!xgBoostNode.isLeaf()) {
            sharedTreeNode.setCol(xgBoostNode.getSplitIndex(), featureProperties._names[xgBoostNode.getSplitIndex()]);
            XGBoostModel.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode), xgBoostNode.getLeftChildIndex(), sharedTreeSubgraph, featureProperties, xgBoostNode.default_left());
            XGBoostModel.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeRightChildNode(sharedTreeNode), xgBoostNode.getRightChildIndex(), sharedTreeSubgraph, featureProperties, !xgBoostNode.default_left());
        }
    }

    public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
        return this.convert(treeNumber, treeClass);
    }

    private int getXGBoostClassIndex(String treeClass) {
        ModelCategory modelCategory = ((XGBoostOutput)this._output).getModelCategory();
        if (ModelCategory.Regression.equals((Object)modelCategory) && treeClass != null && !treeClass.isEmpty()) {
            throw new IllegalArgumentException("There should be no tree class specified for regression.");
        }
        if (treeClass == null || treeClass.isEmpty()) {
            switch (modelCategory) {
                case Binomial: 
                case Regression: {
                    return 0;
                }
            }
            throw new IllegalArgumentException(String.format("Model category '%s' requires tree class to be specified.", modelCategory));
        }
        Object[] domain = ((XGBoostOutput)this._output)._domains[((XGBoostOutput)this._output)._domains.length - 1];
        int treeClassIndex = ArrayUtils.find((Object[])domain, (Object)treeClass);
        if (ModelCategory.Binomial.equals((Object)modelCategory) && treeClassIndex != 0) {
            throw new IllegalArgumentException(String.format("For binomial XGBoost model, only one tree for class %s has been built.", domain[0]));
        }
        if (treeClassIndex < 0) {
            throw new IllegalArgumentException(String.format("No such class '%s' in tree.", treeClass));
        }
        return treeClassIndex;
    }

    public boolean isFeatureUsedInPredict(String featureName) {
        int featureIdx = ArrayUtils.find((Object[])((XGBoostOutput)this._output)._varimp._names, (Object)featureName);
        if (featureIdx == -1 && ((XGBoostOutput)this._output)._catOffsets.length > 1) {
            featureIdx = ArrayUtils.find((Object[])((XGBoostOutput)this._output)._names, (Object)featureName);
            if (featureIdx == -1 || !((XGBoostOutput)this._output)._column_types[featureIdx].equals("Enum")) {
                return false;
            }
            for (int i = 0; i < ((XGBoostOutput)this._output)._varimp._names.length; ++i) {
                if (!((XGBoostOutput)this._output)._varimp._names[i].startsWith(featureName.concat(".")) || ((XGBoostOutput)this._output)._varimp._varimp[i] == 0.0f) continue;
                return true;
            }
            return false;
        }
        return featureIdx != -1 && (double)((XGBoostOutput)this._output)._varimp._varimp[featureIdx] != 0.0;
    }

    protected boolean toJavaCheckTooBig() {
        return this._output == null || ((XGBoostOutput)this._output)._ntrees * ((XGBoostParameters)this._parms)._max_depth > 1000;
    }

    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();
        sb.ip("public int nclasses() { return ").p(((XGBoostOutput)this._output).nclasses()).p("; }").nl();
        return sb;
    }

    protected void toJavaPredictBody(SBPrintStream sb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        String namePrefix = JCodeGen.toJavaId((String)this._key.toString());
        Predictor p = PredictorFactory.makePredictor(this.model_info._boosterBytes, false);
        XGBoostPojoWriter.make(p, namePrefix, (XGBoostOutput)this._output, this.defaultThreshold()).renderJavaPredictBody(sb, fileCtx);
    }

    public static class XGBoostParameters
    extends Model.Parameters
    implements Model.GetNTrees,
    PlattScalingHelper.ParamsWithCalibration {
        public boolean _quiet_mode = true;
        public int _ntrees = 50;
        public int _n_estimators;
        public int _max_depth = 6;
        public double _min_rows = 1.0;
        public double _min_child_weight = 1.0;
        public double _learn_rate = 0.3;
        public double _eta = 0.3;
        public double _learn_rate_annealing = 1.0;
        public double _sample_rate = 1.0;
        public double _subsample = 1.0;
        public double _col_sample_rate = 1.0;
        public double _colsample_bylevel = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        public double _colsample_bytree = 1.0;
        public KeyValue[] _monotone_constraints;
        public float _max_abs_leafnode_pred = 0.0f;
        public float _max_delta_step = 0.0f;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        public float _min_split_improvement = 0.0f;
        public float _gamma;
        public int _nthread = -1;
        public String _save_matrix_directory;
        public boolean _build_tree_one_node = false;
        public int _max_bins = 256;
        public int _max_leaves = 0;
        public float _min_sum_hessian_in_leaf = 100.0f;
        public float _min_data_in_leaf = 0.0f;
        public TreeMethod _tree_method = TreeMethod.auto;
        public GrowPolicy _grow_policy = GrowPolicy.depthwise;
        public Booster _booster = Booster.gbtree;
        public DMatrixType _dmatrix_type = DMatrixType.auto;
        public float _reg_lambda = 1.0f;
        public float _reg_alpha = 0.0f;
        public boolean _calibrate_model;
        public Key<Frame> _calibration_frame;
        public DartSampleType _sample_type = DartSampleType.uniform;
        public DartNormalizeType _normalize_type = DartNormalizeType.tree;
        public float _rate_drop = 0.0f;
        public boolean _one_drop = false;
        public float _skip_drop = 0.0f;
        public int _gpu_id = 0;
        public Backend _backend = Backend.auto;
        static String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = new String[]{"_tree_method", "_grow_policy", "_booster", "_sample_rate", "_max_depth", "_min_rows"};

        public String algoName() {
            return "XGBoost";
        }

        public String fullName() {
            return "XGBoost";
        }

        public String javaName() {
            return XGBoostModel.class.getName();
        }

        public long progressUnits() {
            return this._ntrees;
        }

        Map<String, Object> gpuIncompatibleParams() {
            HashMap<String, Object> incompat = new HashMap<String, Object>();
            if (TreeMethod.auto != this._tree_method && TreeMethod.hist != this._tree_method && Booster.gblinear != this._booster) {
                incompat.put("tree_method", "Only auto and hist are supported tree_method on GPU backend.");
            }
            if (this._max_depth > 15 || this._max_depth < 1) {
                incompat.put("max_depth", this._max_depth + " . Max depth must be greater than 0 and lower than 16 for GPU backend.");
            }
            if (this._grow_policy == GrowPolicy.lossguide) {
                incompat.put("grow_policy", (Object)GrowPolicy.lossguide);
            }
            return incompat;
        }

        Map<String, Integer> monotoneConstraints() {
            if (this._monotone_constraints == null || this._monotone_constraints.length == 0) {
                return Collections.emptyMap();
            }
            HashMap<String, Integer> constraints = new HashMap<String, Integer>(this._monotone_constraints.length);
            for (KeyValue constraint : this._monotone_constraints) {
                double val = constraint.getValue();
                if (val == 0.0) continue;
                if (constraints.containsKey(constraint.getKey())) {
                    throw new IllegalStateException("Duplicate definition of constraint for feature '" + constraint.getKey() + "'.");
                }
                int direction = val < 0.0 ? -1 : 1;
                constraints.put(constraint.getKey(), direction);
            }
            return constraints;
        }

        public int getNTrees() {
            return this._ntrees;
        }

        public Frame getCalibrationFrame() {
            return this._calibration_frame != null ? (Frame)this._calibration_frame.get() : null;
        }

        public boolean calibrateModel() {
            return this._calibrate_model;
        }

        public Model.Parameters getParams() {
            return this;
        }

        public static enum Backend {
            auto,
            gpu,
            cpu;

        }

        public static enum DMatrixType {
            auto,
            dense,
            sparse;

        }

        public static enum DartNormalizeType {
            tree,
            forest;

        }

        public static enum DartSampleType {
            uniform,
            weighted;

        }

        public static enum Booster {
            gbtree,
            gblinear,
            dart;

        }

        public static enum GrowPolicy {
            depthwise,
            lossguide;

        }

        public static enum TreeMethod {
            auto,
            exact,
            approx,
            hist;

        }
    }
}

