/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.KeyValue;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.algos.xgboost.XGBoostNativeMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostMojoWriter;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.BoosterHelper;
import ml.dmlc.xgboost4j.java.PredictorFactory;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import ml.dmlc.xgboost4j.java.XGBoostScoreTask;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Key;
import water.Keyed;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.Log;

public class XGBoostModel
extends Model<XGBoostModel, XGBoostParameters, XGBoostOutput>
implements SharedTreeGraphConverter {
    private XGBoostModelInfo model_info;

    /**
     * Returns the {@link XGBoostModelInfo} held by this model.
     *
     * @return the value of the private {@code model_info} field
     */
    public XGBoostModelInfo model_info() {
        return this.model_info;
    }

    /**
     * Creates the metric builder matching this model's category.
     *
     * @param domain response domain handed to the binomial / multinomial builders
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException whatever {@code H2O.unimpl()} produces, for any other model category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final XGBoostOutput out = this._output;
        switch (out.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(out.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                // Only the three categories above are handled.
                throw H2O.unimpl();
        }
    }

    /**
     * Creates the model, builds its {@link DataInfo} from the training/validation frames,
     * stores that {@code DataInfo} in the DKV, copies its column metadata to the output and
     * initializes {@code model_info}.
     *
     * @param selfKey key of this model
     * @param parms   model parameters
     * @param output  model output; its {@code nclasses()} is used when building the DataInfo
     * @param train   training frame
     * @param valid   validation frame passed through to {@code XGBoost.makeDataInfo}
     */
    public XGBoostModel(Key<XGBoostModel> selfKey, XGBoostParameters parms, XGBoostOutput output, Frame train, Frame valid) {
        super(selfKey, parms, output);
        final DataInfo dataInfo = XGBoost.makeDataInfo(train, valid, this._parms, output.nclasses());
        DKV.put(dataInfo);
        setDataInfoToOutput(dataInfo);
        this.model_info = new XGBoostModelInfo(parms, dataInfo);
    }

    /**
     * Translates the H2O-level parameters into the native XGBoost parameter map and wraps it
     * in a {@link BoosterParms}.
     *
     * <p>Several settings exist under two names (an H2O name and an XGBoost alias). When the
     * alias differs from its default value it wins, and its value is also written back to the
     * H2O-named field of {@code p} - i.e. this method mutates {@code p}.
     *
     * @param p         model parameters; alias values are copied into their H2O-named fields
     * @param nClasses  1 selects a regression objective, 2 binary logistic, anything else multi-class
     * @param coefNames expanded feature names, used to order the monotone constraints
     * @return the assembled parameters
     * @throws UnsupportedOperationException for a regression distribution other than gamma,
     *         tweedie, poisson, gaussian or AUTO
     */
    public static BoosterParms createParams(XGBoostParameters p, int nClasses, String[] coefNames) {
        Map<String, Object> params = new HashMap<String, Object>();

        // --- number of trees / learning rate ---
        if (p._n_estimators != 0) {
            Log.info("Using user-provided parameter n_estimators instead of ntrees.");
            params.put("nround", p._n_estimators);
            p._ntrees = p._n_estimators;
        } else {
            params.put("nround", p._ntrees);
        }
        if (p._eta != 0.3) {
            Log.info("Using user-provided parameter eta instead of learn_rate.");
            params.put("eta", p._eta);
            p._learn_rate = p._eta;
        } else {
            params.put("eta", p._learn_rate);
        }
        params.put("max_depth", p._max_depth);
        params.put("silent", p._quiet_mode);

        // --- row / column sampling ---
        if (p._subsample != 1.0) {
            Log.info("Using user-provided parameter subsample instead of sample_rate.");
            params.put("subsample", p._subsample);
            p._sample_rate = p._subsample;
        } else {
            params.put("subsample", p._sample_rate);
        }
        if (p._colsample_bytree != 1.0) {
            Log.info("Using user-provided parameter colsample_bytree instead of col_sample_rate_per_tree.");
            params.put("colsample_bytree", p._colsample_bytree);
            p._col_sample_rate_per_tree = p._colsample_bytree;
        } else {
            params.put("colsample_bytree", p._col_sample_rate_per_tree);
        }
        if (p._colsample_bylevel != 1.0) {
            Log.info("Using user-provided parameter colsample_bylevel instead of col_sample_rate.");
            params.put("colsample_bylevel", p._colsample_bylevel);
            p._col_sample_rate = p._colsample_bylevel;
        } else {
            params.put("colsample_bylevel", p._col_sample_rate);
        }
        if (p._max_delta_step != 0.0f) {
            Log.info("Using user-provided parameter max_delta_step instead of max_abs_leafnode_pred.");
            params.put("max_delta_step", p._max_delta_step);
            p._max_abs_leafnode_pred = p._max_delta_step;
        } else {
            params.put("max_delta_step", p._max_abs_leafnode_pred);
        }

        // The long seed is folded into the int range before being handed over.
        params.put("seed", (int) (p._seed % Integer.MAX_VALUE));

        // --- tree construction ---
        params.put("tree_method", p._tree_method.toString());
        params.put("grow_policy", p._grow_policy.toString());
        if (p._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
            params.put("max_bins", p._max_bins);
            params.put("max_leaves", p._max_leaves);
            params.put("min_sum_hessian_in_leaf", p._min_sum_hessian_in_leaf);
            params.put("min_data_in_leaf", p._min_data_in_leaf);
        }

        // --- booster type ---
        params.put("booster", p._booster.toString());
        if (p._booster == XGBoostParameters.Booster.dart) {
            params.put("sample_type", p._sample_type.toString());
            params.put("normalize_type", p._normalize_type.toString());
            params.put("rate_drop", p._rate_drop);
            params.put("one_drop", p._one_drop ? "1" : "0");
            params.put("skip_drop", p._skip_drop);
        }

        // --- backend selection; every fall-through case keeps the CPU settings made above ---
        if (p._backend == XGBoostParameters.Backend.auto || p._backend == XGBoostParameters.Backend.gpu) {
            if (H2O.getCloudSize() > 1) {
                Log.info("GPU backend not supported in distributed mode. Using CPU backend.");
            } else if (!p.gpuIncompatibleParams().isEmpty()) {
                Log.info("GPU backend not supported for the choice of parameters (" + p.gpuIncompatibleParams() + "). Using CPU backend.");
            } else if (XGBoost.hasGPU(H2O.CLOUD.members()[0], p._gpu_id)) {
                Log.info("Using GPU backend (gpu_id: " + p._gpu_id + ").");
                params.put("gpu_id", p._gpu_id);
                if (p._tree_method == XGBoostParameters.TreeMethod.exact) {
                    Log.info("Using grow_gpu (exact) updater.");
                    params.put("tree_method", "exact");
                    params.put("updater", "grow_gpu");
                } else {
                    Log.info("Using grow_gpu_hist (approximate) updater.");
                    params.put("max_bins", p._max_bins);
                    params.put("tree_method", "exact");
                    params.put("updater", "grow_gpu_hist");
                }
            } else {
                Log.info("No GPU (gpu_id: " + p._gpu_id + ") found. Using CPU backend.");
            }
        } else {
            assert p._backend == XGBoostParameters.Backend.cpu;
            Log.info("Using CPU backend.");
        }

        // --- regularization ---
        if (p._min_child_weight != 1.0) {
            Log.info("Using user-provided parameter min_child_weight instead of min_rows.");
            params.put("min_child_weight", p._min_child_weight);
            p._min_rows = p._min_child_weight;
        } else {
            params.put("min_child_weight", p._min_rows);
        }
        if (p._gamma != 0.0f) {
            Log.info("Using user-provided parameter gamma instead of min_split_improvement.");
            params.put("gamma", p._gamma);
            p._min_split_improvement = p._gamma;
        } else {
            params.put("gamma", p._min_split_improvement);
        }
        params.put("lambda", p._reg_lambda);
        params.put("alpha", p._reg_alpha);

        // --- objective ---
        if (nClasses == 2) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId());
        } else if (nClasses == 1) {
            if (p._distribution == DistributionFamily.gamma) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId());
            } else if (p._distribution == DistributionFamily.tweedie) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId());
                params.put("tweedie_variance_power", p._tweedie_power);
            } else if (p._distribution == DistributionFamily.poisson) {
                params.put("objective", XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId());
            } else {
                if (p._distribution != DistributionFamily.gaussian && p._distribution != DistributionFamily.AUTO) {
                    throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
                }
                params.put("objective", XGBoostMojoModel.ObjectiveType.REG_LINEAR.getId());
            }
        } else {
            params.put("objective", XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId());
            params.put("num_class", nClasses);
        }
        assert XGBoostMojoModel.ObjectiveType.fromXGBoost((String) params.get("objective")) != null;

        // --- threading: never exceed the configured maximum ---
        final int nthreadMax = getMaxNThread();
        final int nthread = p._nthread != -1 ? Math.min(p._nthread, nthreadMax) : nthreadMax;
        if (nthread < p._nthread) {
            Log.warn("Requested nthread=" + p._nthread + " but the cluster has only " + nthreadMax + " available. Training will use nthread=" + nthreadMax + " instead of the user specified value.");
        }
        params.put("nthread", nthread);

        // --- monotone constraints: one direction per coefficient, in coefNames order ---
        Map<String, Integer> monotoneConstraints = p.monotoneConstraints();
        if (!monotoneConstraints.isEmpty()) {
            int constraintsUsed = 0;
            StringBuilder sb = new StringBuilder("(");
            for (int i = 0; i < coefNames.length; i++) {
                if (i > 0) {
                    sb.append(',');
                }
                Integer direction = monotoneConstraints.get(coefNames[i]);
                if (direction != null) {
                    sb.append(direction);
                    ++constraintsUsed;
                } else {
                    sb.append('0');
                }
            }
            // Appending the closing parenthesis (instead of overwriting the last character)
            // keeps the value well-formed even when coefNames is empty.
            sb.append(')');
            params.put("monotone_constraints", sb.toString());
            assert constraintsUsed == monotoneConstraints.size();
        }

        Log.info("XGBoost Parameters:");
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            Log.info(" " + entry.getKey() + " = " + entry.getValue());
        }
        Log.info("");
        return BoosterParms.fromMap(Collections.unmodifiableMap(params));
    }

    /**
     * Upper bound for the {@code nthread} XGBoost parameter.
     *
     * @return the integer system property {@code sys.ai.h2o.xgboost.nthread} if set,
     *         otherwise {@code H2O.ARGS.nthreads}
     */
    private static int getMaxNThread() {
        final Integer limit = Integer.getInteger("sys.ai.h2o.xgboost.nthread", H2O.ARGS.nthreads);
        return limit;
    }

    /**
     * Writes the DataInfo key of this model to the buffer, then delegates to the superclass.
     *
     * @param ab buffer to write to
     * @return the buffer returned by the superclass implementation
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(model_info._dataInfoKey);
        final AutoBuffer written = super.writeAll_impl(ab);
        return written;
    }

    /**
     * Reads back the DataInfo key from the buffer, then delegates to the superclass.
     *
     * @param ab buffer to read from
     * @param fs futures passed to {@code getKey} and to the superclass
     * @return the object returned by the superclass implementation
     */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(model_info._dataInfoKey, fs);
        final Keyed restored = super.readAll_impl(ab, fs);
        return restored;
    }

    /**
     * Creates a MOJO writer for this model.
     *
     * @return a new {@link XGBoostMojoWriter} constructed with this model
     */
    public XGBoostMojoWriter getMojo() {
        return new XGBoostMojoWriter(this);
    }

    /**
     * Runs the XGBoost score task on the given frame and returns the resulting metrics.
     * The prediction frame produced by the task is removed before returning.
     *
     * @param data         frame passed to the score task
     * @param originalData second frame passed to the score task
     * @param description  text included in the debug log message
     * @return the metrics produced by the score task
     */
    private ModelMetrics makeMetrics(Frame data, Frame originalData, String description) {
        Log.debug("Making metrics: " + description);
        final XGBoostScoreTask.XGBoostScoreTaskResult result = XGBoostScoreTask.runScoreTask(
                model_info(), this._output, this._parms, null, data, originalData, true, this);
        // Only the metrics are of interest here; drop the predictions.
        result.preds.remove();
        return result.mm;
    }

    /**
     * Computes training metrics (and validation metrics when a validation frame is given),
     * stores them on the output, fills the scoring-history slot at index {@code _ntrees}
     * and registers the metrics with the model.
     *
     * @param _train     training frame to score
     * @param _trainOrig frame passed alongside the training frame to the score task
     * @param _valid     validation frame, or {@code null} to skip validation scoring
     * @param _validOrig frame passed alongside the validation frame to the score task
     */
    final void doScoring(Frame _train, Frame _trainOrig, Frame _valid, Frame _validOrig) {
        final ModelMetrics trainMetrics = makeMetrics(_train, _trainOrig, "Metrics reported on training frame");
        this._output._training_metrics = trainMetrics;
        this._output._scored_train[this._output._ntrees].fillFrom(trainMetrics);
        addModelMetrics(trainMetrics);

        if (_valid == null) {
            return;
        }
        final ModelMetrics validMetrics = makeMetrics(_valid, _validOrig, "Metrics reported on validation frame");
        this._output._validation_metrics = validMetrics;
        this._output._scored_valid[this._output._ntrees].fillFrom(validMetrics);
        addModelMetrics(validMetrics);
    }

    /**
     * Always reports that no post-processing is needed.
     *
     * @return {@code false}
     */
    protected boolean needsPostProcess() {
        return false;
    }

    /**
     * Scores a single row with an offset of zero.
     *
     * @param data  input row
     * @param preds array receiving the predictions
     * @return the result of {@link #score0(double[], double[], double)} with offset {@code 0.0}
     */
    protected double[] score0(double[] data, double[] preds) {
        final double noOffset = 0.0;
        return score0(data, preds, noOffset);
    }

    /**
     * Scores a single row through the native XGBoost booster.
     *
     * <p>The booster is deserialized once per call and always disposed afterwards, whether
     * scoring succeeds or throws.
     *
     * @param data   input row
     * @param preds  array receiving the predictions
     * @param offset offset forwarded to the native scoring routine
     * @return the array returned by {@code XGBoostNativeMojoModel.score0}
     */
    public double[] score0(double[] data, double[] preds, double offset) {
        DataInfo di = (DataInfo) this.model_info._dataInfoKey.get();
        assert di != null;
        final double threshold = this.defaultThreshold();
        Booster booster = null;
        try {
            booster = this.model_info.deserializeBooster();
            // Score with the instance we hold on to; deserializing a second one here would
            // create a booster that nobody disposes.
            return XGBoostNativeMojoModel.score0(data, offset, preds, booster,
                    di._nums, di._cats, di._catOffsets, di._useAllFactorLevels,
                    this._output.nclasses(), this._output._priorClassDist, threshold, this._output._sparse);
        } finally {
            if (booster != null) {
                BoosterHelper.dispose(booster);
            }
        }
    }

    /**
     * Tells whether bulk scoring should use the Java predictor.
     *
     * @return {@code true} iff the system property
     *         {@code sys.ai.h2o.xgboost.predict.java.enable} is set to "true"
     */
    private boolean useJavaScoring() {
        final boolean javaScoringEnabled = Boolean.getBoolean("sys.ai.h2o.xgboost.predict.java.enable");
        return javaScoringEnabled;
    }

    /**
     * Chooses the bulk-scoring implementation: Java-based when {@link #useJavaScoring()}
     * is true, native otherwise.
     *
     * @param bs not used by this implementation
     * @return the selected predictor factory
     */
    protected Model.BigScorePredict setupBigScorePredict(Model.BigScore bs) {
        if (useJavaScoring()) {
            return setupBigScorePredictJava();
        }
        return setupBigScorePredictNative();
    }

    /**
     * Prepares native bulk scoring: rebuilds the booster parameters from the model
     * parameters and the DataInfo coefficient names.
     *
     * @return a predictor factory carrying those booster parameters
     */
    private Model.BigScorePredict setupBigScorePredictNative() {
        final DataInfo dataInfo = (DataInfo) model_info()._dataInfoKey.get();
        assert dataInfo != null;
        final int nclasses = this._output.nclasses();
        final String[] coefNames = dataInfo.coefNames();
        return new XGBoostBigScorePredict(createParams(this._parms, nclasses, coefNames));
    }

    /**
     * Prepares Java-based bulk scoring from the DataInfo, the model output, the default
     * threshold and the serialized booster bytes.
     *
     * @return a predictor factory for Java scoring
     */
    private Model.BigScorePredict setupBigScorePredictJava() {
        final DataInfo dataInfo = (DataInfo) model_info._dataInfoKey.get();
        assert dataInfo != null;
        final double threshold = defaultThreshold();
        final byte[] boosterBytes = model_info()._boosterBytes;
        return new XGBoostJavaBigScorePredict(dataInfo, this._output, threshold, boosterBytes);
    }

    /**
     * Copies column metadata to the model output: names/domains of the adapted frame,
     * names/domains of the original training frame, and the numeric/categorical layout
     * of the given DataInfo.
     *
     * @param dinfo source of the adapted frame and the column layout
     */
    private void setDataInfoToOutput(DataInfo dinfo) {
        final XGBoostOutput out = this._output;

        // Adapted (post-DataInfo) view of the columns.
        out._names = dinfo._adaptedFrame.names();
        out._domains = dinfo._adaptedFrame.domains();

        // Original training frame as referenced by the parameters.
        final Frame trainFrame = (Frame) this._parms._train.get();
        out._origNames = trainFrame.names();
        out._origDomains = trainFrame.domains();

        // Layout needed to expand a row at scoring time.
        out._nums = dinfo._nums;
        out._cats = dinfo._cats;
        out._catOffsets = dinfo._catOffsets;
        out._useAllFactorLevels = dinfo._useAllFactorLevels;
    }

    /**
     * Removes the DataInfo referenced by this model (if it still resolves) and then
     * delegates to the superclass.
     *
     * @param fs futures passed to the removals
     * @return the futures returned by the superclass implementation
     */
    protected Futures remove_impl(Futures fs) {
        if (this.model_info()._dataInfoKey != null) {
            DataInfo di = (DataInfo) this.model_info()._dataInfoKey.get();
            // The key may no longer resolve to a value; guard against dereferencing null
            // so the model itself can still be removed.
            if (di != null) {
                di.remove(fs);
            }
        }
        return super.remove_impl(fs);
    }

    /**
     * Converts one tree of this model into a {@link SharedTreeGraph}.
     *
     * @param treeNumber    zero-based index of the tree within its class group
     * @param treeClassName class whose tree should be converted; resolved through
     *                      {@code getXGBoostClassIndex}
     * @return a graph containing a single subgraph built from the selected tree
     * @throws IllegalStateException    if the booster bytes cannot be read
     * @throws IllegalArgumentException if the booster is not a {@link GBTree}, or the class
     *                                  or the tree number does not exist
     */
    public SharedTreeGraph convert(int treeNumber, String treeClassName) {
        final GradBooster booster;
        try {
            booster = new Predictor(new ByteArrayInputStream(this.model_info._boosterBytes)).getBooster();
        } catch (IOException e) {
            Log.err(e);
            // Keep the original exception as the cause so the root problem is not lost.
            throw new IllegalStateException("Booster bytes inaccessible. Not able to extract the predictor and construct tree graph.", e);
        }
        if (!(booster instanceof GBTree)) {
            // %s, not %d: the argument is a class name. A null booster is reported as "null".
            String boosterClass = booster == null ? "null" : booster.getClass().getCanonicalName();
            throw new IllegalArgumentException(String.format("Given XGBoost model is not backed by a tree-based booster. Booster class is %s", boosterClass));
        }
        RegTree[][] groupedTrees = ((GBTree) booster).getGroupedTrees();
        int treeClass = this.getXGBoostClassIndex(treeClassName);
        if (treeClass >= groupedTrees.length) {
            throw new IllegalArgumentException(String.format("Given XGBoost model does not have given class '%s'.", treeClassName));
        }
        RegTree[] treesInGroup = groupedTrees[treeClass];
        if (treeNumber >= treesInGroup.length || treeNumber < 0) {
            throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", treesInGroup.length));
        }
        RegTreeNode[] treeNodes = treesInGroup[treeNumber].getNodes();
        assert treeNodes.length >= 1;
        SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
        SharedTreeSubgraph sharedTreeSubgraph = sharedTreeGraph.makeSubgraph(this._output._training_metrics._description);
        XGBoostUtils.FeatureProperties featureProperties = XGBoostUtils.assembleFeatureNames((DataInfo) this.model_info._dataInfoKey.get());
        // Start the recursive copy at node 0 of the tree.
        constructSubgraph(treeNodes, sharedTreeSubgraph.makeRootNode(), 0, sharedTreeSubgraph, featureProperties, true);
        return sharedTreeGraph;
    }

    /**
     * Recursively copies the XGBoost node at {@code nodeIndex} and its descendants into the
     * shared tree representation.
     *
     * @param xgBoostNodes       all nodes of the XGBoost tree
     * @param sharedTreeNode     target node to populate
     * @param nodeIndex          index of the source node in {@code xgBoostNodes}
     * @param sharedTreeSubgraph subgraph used to create child nodes
     * @param featureProperties  feature names and one-hot flags indexed by split index
     * @param inclusiveNA        value stored on the target node via {@code setInclusiveNa}
     */
    private static void constructSubgraph(RegTreeNode[] xgBoostNodes, SharedTreeNode sharedTreeNode, int nodeIndex, SharedTreeSubgraph sharedTreeSubgraph, XGBoostUtils.FeatureProperties featureProperties, boolean inclusiveNA) {
        final RegTreeNode source = xgBoostNodes[nodeIndex];
        final int splitIndex = source.split_index();

        // One-hot encoded features get a fixed split value of 1; others keep the tree's condition.
        if (!featureProperties._oneHotEncoded[splitIndex]) {
            sharedTreeNode.setSplitValue(source.getSplitCondition());
        } else {
            sharedTreeNode.setSplitValue(1.0f);
        }
        sharedTreeNode.setPredValue(source.getLeafValue());
        sharedTreeNode.setCol(splitIndex, featureProperties._names[splitIndex]);
        sharedTreeNode.setInclusiveNa(inclusiveNA);
        sharedTreeNode.setNodeNumber(nodeIndex);

        // A child index of -1 means there is no child on that side.
        if (source.getLeftChildIndex() != -1) {
            final SharedTreeNode leftChild = sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode);
            constructSubgraph(xgBoostNodes, leftChild, source.getLeftChildIndex(), sharedTreeSubgraph, featureProperties, source.default_left());
        }
        if (source.getRightChildIndex() != -1) {
            final SharedTreeNode rightChild = sharedTreeSubgraph.makeRightChildNode(sharedTreeNode);
            constructSubgraph(xgBoostNodes, rightChild, source.getRightChildIndex(), sharedTreeSubgraph, featureProperties, !source.default_left());
        }
    }

    /**
     * Resolves a class name to the index of its tree group.
     *
     * <p>Regression models must not specify a class and always map to index 0; all other
     * models must specify one. The class is looked up in the last entry of the output's
     * {@code _domains}.
     *
     * @param treeClass class name, or {@code null}/empty for regression
     * @return index of the tree group
     * @throws IllegalArgumentException if a class is given for regression, missing for a
     *         non-regression model, not the first domain value of a binomial model, or unknown
     */
    private final int getXGBoostClassIndex(String treeClass) {
        final ModelCategory modelCategory = this._output.getModelCategory();
        final boolean classMissing = treeClass == null || treeClass.isEmpty();

        if (ModelCategory.Regression.equals(modelCategory)) {
            if (!classMissing) {
                throw new IllegalArgumentException("There should be no tree class specified for regression.");
            }
            return 0;
        }
        if (classMissing) {
            throw new IllegalArgumentException("Non-regressional models require tree class specified.");
        }

        final String[][] domains = this._output._domains;
        final String[] responseDomain = domains[domains.length - 1];
        final int treeClassIndex = ArrayUtils.find(responseDomain, treeClass);
        // The binomial check comes first, so an unknown class on a binomial model is
        // reported with the binomial message.
        if (ModelCategory.Binomial.equals(modelCategory) && treeClassIndex != 0) {
            throw new IllegalArgumentException(String.format("For binomial XGBoost model, only one tree for class %s has been built.", responseDomain[0]));
        }
        if (treeClassIndex < 0) {
            throw new IllegalArgumentException(String.format("No such class '%s' in tree.", treeClass));
        }
        return treeClassIndex;
    }

    /**
     * Reusable feature vector that presents a row as one-hot encoded categoricals followed
     * by the numeric columns. The row is replaced in place via {@link #setInput(double[])}.
     */
    private static class MutableOneHotEncoderFVec
    implements FVec {
        private final DataInfo _di;
        private final boolean _treatsZeroAsNA;
        /** For each one-hot position, the index of the categorical column it belongs to. */
        private final int[] _catMap;
        /** Per categorical column, the value filled in by {@code GenModel.setCats}. */
        private final int[] _catValues;
        private final float[] _numValues;
        /** Value reported for a one-hot position that is not active. */
        private final float _notHot;

        /**
         * @param di             supplies the number of categorical/numeric columns and the offsets
         * @param treatsZeroAsNA when true, inactive one-hot positions and numeric zeros are
         *                       reported as NaN
         */
        MutableOneHotEncoderFVec(DataInfo di, boolean treatsZeroAsNA) {
            _di = di;
            _treatsZeroAsNA = treatsZeroAsNA;
            _notHot = treatsZeroAsNA ? Float.NaN : 0.0f;
            _catValues = new int[di._cats];
            _numValues = new float[di._nums];

            final int[] offsets = di._catOffsets;
            if (offsets != null) {
                _catMap = new int[offsets[di._cats]];
                for (int cat = 0; cat < di._cats; cat++) {
                    for (int pos = offsets[cat]; pos < offsets[cat + 1]; pos++) {
                        _catMap[pos] = cat;
                    }
                }
            } else {
                _catMap = new int[0];
            }
        }

        /**
         * Loads a new row: categorical values through {@code GenModel.setCats}, numeric values
         * from the positions following the categorical columns.
         *
         * @param input row with categorical columns first, numeric columns after
         */
        void setInput(double[] input) {
            GenModel.setCats(input, _catValues, _di._cats, _di._catOffsets, _di._useAllFactorLevels);
            final int firstNumeric = _di._cats;
            for (int n = 0; n < _numValues.length; n++) {
                final float value = (float) input[firstNumeric + n];
                if (_treatsZeroAsNA && value == 0.0f) {
                    _numValues[n] = Float.NaN;
                } else {
                    _numValues[n] = value;
                }
            }
        }

        /**
         * @param index position in the expanded feature space
         * @return 1 for an active one-hot position, the "not hot" value for an inactive one,
         *         or the numeric value for positions past the one-hot range
         */
        public final float fvalue(int index) {
            final int oneHotSize = _catMap.length;
            if (index < oneHotSize) {
                return _catValues[_catMap[index]] == index ? 1.0f : _notHot;
            }
            return _numValues[index - oneHotSize];
        }
    }

    /**
     * Per-chunk scorer that evaluates rows with the Java {@link Predictor}.
     */
    private static class XGboostJavaBigScoreChunkPredict
    implements Model.BigScoreChunkPredict {
        private final XGBoostOutput _output;
        private final double _threshold;
        private final Predictor _predictor;
        /** Row buffer reused for every scored row. */
        private final MutableOneHotEncoderFVec _row;

        /**
         * @param di        layout used to build the reusable row buffer
         * @param output    model output (class count, prior distribution, sparsity flag)
         * @param threshold threshold forwarded to {@code XGBoostMojoModel.toPreds}
         * @param predictor Java predictor used for scoring
         */
        public XGboostJavaBigScoreChunkPredict(DataInfo di, XGBoostOutput output, double threshold, Predictor predictor) {
            _output = output;
            _threshold = threshold;
            _predictor = predictor;
            _row = new MutableOneHotEncoderFVec(di, output._sparse);
        }

        /**
         * Scores one row of the given chunks.
         *
         * @throws UnsupportedOperationException if {@code offset} is not zero
         */
        public double[] score0(Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
            if (offset != 0.0) {
                throw new UnsupportedOperationException("Unsupported: offset != 0");
            }
            assert _output.nfeatures() == tmp.length;
            // Copy the row out of the chunks into the temporary array.
            for (int col = 0; col < tmp.length; col++) {
                tmp[col] = chks[col].atd(row_in_chunk);
            }
            _row.setInput(tmp);
            final float[] rawPrediction = _predictor.predict(_row);
            return XGBoostMojoModel.toPreds(tmp, rawPrediction, preds, _output.nclasses(), _output._priorClassDist, _threshold);
        }

        /** Nothing to release. */
        public void close() {
        }
    }

    /**
     * Factory of Java-based chunk scorers; the predictor is created once from the booster
     * bytes and handed to every chunk scorer.
     */
    private class XGBoostJavaBigScorePredict
    implements Model.BigScorePredict {
        private final DataInfo _di;
        private final XGBoostOutput _output;
        private final double _threshold;
        private final Predictor _predictor;

        /**
         * @param di           data layout passed on to the chunk scorers
         * @param output       model output passed on to the chunk scorers
         * @param threshold    threshold passed on to the chunk scorers
         * @param boosterBytes serialized booster used to create the predictor
         */
        XGBoostJavaBigScorePredict(DataInfo di, XGBoostOutput output, double threshold, byte[] boosterBytes) {
            _di = di;
            _output = output;
            _threshold = threshold;
            _predictor = PredictorFactory.makePredictor(boosterBytes);
        }

        /**
         * @param fr   not used
         * @param chks not used
         * @return a new chunk scorer built from the stored state
         */
        public Model.BigScoreChunkPredict initMap(Frame fr, Chunk[] chks) {
            final Model.BigScoreChunkPredict chunkScorer =
                    new XGboostJavaBigScoreChunkPredict(_di, _output, _threshold, _predictor);
            return chunkScorer;
        }
    }

    /**
     * Per-chunk scorer backed by predictions that were already computed for the whole chunk.
     */
    private static class XGBoostBigScoreChunkPredict
    implements Model.BigScoreChunkPredict {
        private final int _nclasses;
        /** Raw predictions, indexed by row within the chunk. */
        private final float[][] _preds;
        private final double _threshold;

        private XGBoostBigScoreChunkPredict(int nclasses, float[][] preds, double threshold) {
            _nclasses = nclasses;
            _preds = preds;
            _threshold = threshold;
        }

        /**
         * Copies the row into {@code tmp} and converts the stored raw prediction of that row
         * into the final prediction array. The {@code offset} argument is not used.
         */
        public double[] score0(Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
            for (int col = 0; col < tmp.length; col++) {
                tmp[col] = chks[col].atd(row_in_chunk);
            }
            final float[] rawPrediction = _preds[row_in_chunk];
            // No prior class distribution is supplied on this path.
            return XGBoostMojoModel.toPreds(tmp, rawPrediction, preds, _nclasses, null, _threshold);
        }

        /** Nothing to release. */
        public void close() {
        }
    }

    /**
     * Factory of native chunk scorers: each chunk is scored in one go by
     * {@code XGBoostScoreTask.scoreChunk} and the results are served row by row.
     */
    private class XGBoostBigScorePredict
    implements Model.BigScorePredict {
        private final BoosterParms _boosterParms;

        private XGBoostBigScorePredict(BoosterParms boosterParms) {
            _boosterParms = boosterParms;
        }

        /**
         * @param fr   frame being scored
         * @param chks chunks to score
         * @return a chunk scorer holding the predictions computed for {@code chks}
         */
        public Model.BigScoreChunkPredict initMap(Frame fr, Chunk[] chks) {
            final float[][] chunkPreds = scoreChunk(fr, chks);
            final int nclasses = XGBoostModel.this._output.nclasses();
            final double threshold = XGBoostModel.this.defaultThreshold();
            return new XGBoostBigScoreChunkPredict(nclasses, chunkPreds, threshold);
        }

        /** Scores the given chunks with the enclosing model's info, parameters and output. */
        private float[][] scoreChunk(Frame fr, Chunk[] chks) {
            final XGBoostModel model = XGBoostModel.this;
            return XGBoostScoreTask.scoreChunk(model.model_info(), model._parms, _boosterParms, model._output, fr, chks);
        }
    }

    /**
     * Parameters of the XGBoost model.
     *
     * <p>Several settings are present twice, once under an H2O-style name and once under an
     * XGBoost-style alias (for example {@code _ntrees} / {@code _n_estimators}); how the two
     * are reconciled is decided in {@code XGBoostModel.createParams}.
     */
    public static class XGBoostParameters
    extends Model.Parameters {
        public boolean _quiet_mode = true;
        public MissingValuesHandling _missing_values_handling;

        // Number of trees and its alias.
        public int _ntrees = 50;
        public int _n_estimators;

        public int _max_depth = 6;

        // Minimum child weight and its alias.
        public double _min_rows = 1.0;
        public double _min_child_weight = 1.0;

        // Learning rate and its alias.
        public double _learn_rate = 0.3;
        public double _eta = 0.3;
        public double _learn_rate_annealing = 1.0;

        // Row and column sampling, each followed by its alias.
        public double _sample_rate = 1.0;
        public double _subsample = 1.0;
        public double _col_sample_rate = 1.0;
        public double _colsample_bylevel = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        public double _colsample_bytree = 1.0;

        public KeyValue[] _monotone_constraints;

        // Maximum delta step and its alias.
        public float _max_abs_leafnode_pred = 0.0f;
        public float _max_delta_step = 0.0f;

        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;

        // Minimum split improvement and its alias.
        public float _min_split_improvement = 0.0f;
        public float _gamma;

        /** -1 means "use the maximum available" (see {@code createParams}). */
        public int _nthread = -1;

        public int _max_bins = 256;
        public int _max_leaves = 0;
        public float _min_sum_hessian_in_leaf = 100.0f;
        public float _min_data_in_leaf = 0.0f;
        public TreeMethod _tree_method = TreeMethod.auto;
        public GrowPolicy _grow_policy = GrowPolicy.depthwise;
        public Booster _booster = Booster.gbtree;
        public DMatrixType _dmatrix_type = DMatrixType.auto;
        public float _reg_lambda = 1.0f;
        public float _reg_alpha = 0.0f;

        // Settings only forwarded when the dart booster is selected.
        public DartSampleType _sample_type = DartSampleType.uniform;
        public DartNormalizeType _normalize_type = DartNormalizeType.tree;
        public float _rate_drop = 0.0f;
        public boolean _one_drop = false;
        public float _skip_drop = 0.0f;

        public int _gpu_id = 0;
        public Backend _backend = Backend.auto;

        /** @return the algorithm name, {@code "XGBoost"} */
        public String algoName() {
            return "XGBoost";
        }

        /** @return the full algorithm name, {@code "XGBoost"} */
        public String fullName() {
            return "XGBoost";
        }

        /** @return the fully qualified name of {@link XGBoostModel} */
        public String javaName() {
            return XGBoostModel.class.getName();
        }

        /** @return the number of trees, used as the amount of progress units */
        public long progressUnits() {
            return _ntrees;
        }

        /**
         * Lists the parameter settings that rule out the GPU backend.
         *
         * @return map from parameter name to the offending value/description; empty when
         *         nothing prevents GPU use
         */
        Map<String, Object> gpuIncompatibleParams() {
            final Map<String, Object> blockers = new HashMap<String, Object>();
            final boolean depthOutOfRange = _max_depth < 1 || _max_depth > 15;
            if (depthOutOfRange) {
                blockers.put("max_depth", _max_depth + " . Max depth must be greater than 0 and lower than 16 for GPU backend.");
            }
            if (_grow_policy == GrowPolicy.lossguide) {
                blockers.put("grow_policy", GrowPolicy.lossguide);
            }
            return blockers;
        }

        /**
         * Converts {@code _monotone_constraints} into a map from feature name to direction.
         * Entries with value 0 are skipped; negative values map to -1, all others to 1.
         *
         * @return the constraint map; an immutable empty map when no constraints are defined
         * @throws IllegalStateException if a feature has more than one non-zero constraint
         */
        Map<String, Integer> monotoneConstraints() {
            final KeyValue[] declared = _monotone_constraints;
            if (declared == null || declared.length == 0) {
                return Collections.emptyMap();
            }
            final Map<String, Integer> directions = new HashMap<String, Integer>(declared.length);
            for (KeyValue entry : declared) {
                final double value = entry.getValue();
                if (value != 0.0) {
                    final String feature = entry.getKey();
                    if (directions.containsKey(feature)) {
                        throw new IllegalStateException("Duplicate definition of constraint for feature '" + feature + "'.");
                    }
                    directions.put(feature, value < 0.0 ? -1 : 1);
                }
            }
            return directions;
        }

        public enum Backend {
            auto,
            gpu,
            cpu
        }

        public enum DMatrixType {
            auto,
            dense,
            sparse
        }

        public enum DartNormalizeType {
            tree,
            forest
        }

        public enum DartSampleType {
            uniform,
            weighted
        }

        public enum MissingValuesHandling {
            MeanImputation,
            Skip
        }

        public enum Booster {
            gbtree,
            gblinear,
            dart
        }

        public enum GrowPolicy {
            depthwise,
            lossguide
        }

        public enum TreeMethod {
            auto,
            exact,
            approx,
            hist
        }
    }
}

