/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.VarImp;
import hex.genmodel.GenModel;
import hex.genmodel.algos.xgboost.XGBoostNativeMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostMojoWriter;
import hex.tree.xgboost.XGBoostOutput;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import ml.dmlc.xgboost4j.java.XGBoostScoreTask;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.Log;

/**
 * H2O model whose predictions come from a native XGBoost {@link Booster}.
 *
 * <p>The model keeps an {@link XGBoostModelInfo}, which exposes the booster and stores the key of
 * the {@link DataInfo} used to encode the training frame. This class translates H2O-style
 * parameters into XGBoost booster parameters ({@link #createParams}), scores single rows and row
 * batches, and computes model metrics on training, validation and arbitrary frames.
 */
public class XGBoostModel
extends Model<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {
    // The nested type must be qualified in the extends clause: member types of this class are only
    // in scope inside the class body, not in its header.

    /** Native-model state: booster access plus the key of the training {@link DataInfo}. */
    private XGBoostModelInfo model_info;

    /** @return the model-info holder of this model */
    public XGBoostModelInfo model_info() {
        return model_info;
    }

    /**
     * Creates a metric builder matching this model's category.
     *
     * @param domain response domain (class labels); not used for regression
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException the exception produced by {@link H2O#unimpl()} for any other category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (_output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Builds the model shell and publishes the training {@link DataInfo} to the DKV.
     *
     * @param selfKey key under which this model is stored
     * @param parms   model parameters
     * @param output  model output (its class count drives the DataInfo and model-info setup)
     * @param train   training frame
     * @param valid   validation frame (forwarded to {@code XGBoost.makeDataInfo}; may be null there)
     */
    public XGBoostModel(Key<XGBoostModel> selfKey, XGBoostParameters parms, XGBoostOutput output, Frame train, Frame valid) {
        super(selfKey, parms, output);
        DataInfo dinfo = XGBoost.makeDataInfo(train, valid, _parms, output.nclasses());
        DKV.put(dinfo);
        setDataInfoToOutput(dinfo);
        model_info = new XGBoostModelInfo(parms, output.nclasses());
        model_info._dataInfoKey = dinfo._key;
    }

    /**
     * Translates H2O-side parameters into the native XGBoost booster parameter map.
     *
     * <p>Each XGBoost-native alias (for example {@code eta}) that differs from its default wins
     * over the matching H2O parameter (for example {@code learn_rate}). The alias is logged, used,
     * and copied back into the H2O field of {@code p}. Otherwise the H2O field is used. This method
     * therefore mutates {@code p}.
     *
     * @param p        model parameters (mutated: aliased H2O fields may be overwritten)
     * @param nClasses number of response classes: 2 selects {@code binary:logistic}, 1 a regression
     *                 objective chosen from the distribution, anything else {@code multi:softprob}
     * @return booster parameters built from an unmodifiable view of the assembled map
     * @throws UnsupportedOperationException for a regression distribution other than gamma, tweedie,
     *                                       poisson, gaussian or AUTO
     */
    public static BoosterParms createParams(XGBoostParameters p, int nClasses) {
        Map<String, Object> params = new HashMap<String, Object>();

        if (p._n_estimators != 0) {
            Log.info("Using user-provided parameter n_estimators instead of ntrees.");
            params.put("nround", p._n_estimators);
            p._ntrees = p._n_estimators;
        } else {
            params.put("nround", p._ntrees);
        }
        if (p._eta != 0.3) {
            Log.info("Using user-provided parameter eta instead of learn_rate.");
            params.put("eta", p._eta);
            p._learn_rate = p._eta;
        } else {
            params.put("eta", p._learn_rate);
        }
        params.put("max_depth", p._max_depth);
        params.put("silent", p._quiet_mode);
        if (p._subsample != 1.0) {
            Log.info("Using user-provided parameter subsample instead of sample_rate.");
            params.put("subsample", p._subsample);
            p._sample_rate = p._subsample;
        } else {
            params.put("subsample", p._sample_rate);
        }
        if (p._colsample_bytree != 1.0) {
            Log.info("Using user-provided parameter colsample_bytree instead of col_sample_rate_per_tree.");
            params.put("colsample_bytree", p._colsample_bytree);
            p._col_sample_rate_per_tree = p._colsample_bytree;
        } else {
            params.put("colsample_bytree", p._col_sample_rate_per_tree);
        }
        if (p._colsample_bylevel != 1.0) {
            Log.info("Using user-provided parameter colsample_bylevel instead of col_sample_rate.");
            params.put("colsample_bylevel", p._colsample_bylevel);
            p._col_sample_rate = p._colsample_bylevel;
        } else {
            params.put("colsample_bylevel", p._col_sample_rate);
        }
        if (p._max_delta_step != 0.0f) {
            Log.info("Using user-provided parameter max_delta_step instead of max_abs_leafnode_pred.");
            params.put("max_delta_step", p._max_delta_step);
            p._max_abs_leafnode_pred = p._max_delta_step;
        } else {
            params.put("max_delta_step", p._max_abs_leafnode_pred);
        }

        // The native library takes an int seed; fold the long seed into int range.
        params.put("seed", (int) (p._seed % Integer.MAX_VALUE));
        params.put("tree_method", p._tree_method.toString());
        params.put("grow_policy", p._grow_policy.toString());
        if (p._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
            params.put("max_bins", p._max_bins);
            params.put("max_leaves", p._max_leaves);
            params.put("min_sum_hessian_in_leaf", p._min_sum_hessian_in_leaf);
            params.put("min_data_in_leaf", p._min_data_in_leaf);
        }
        params.put("booster", p._booster.toString());
        if (p._booster == XGBoostParameters.Booster.dart) {
            params.put("sample_type", p._sample_type.toString());
            params.put("normalize_type", p._normalize_type.toString());
            params.put("rate_drop", p._rate_drop);
            params.put("one_drop", p._one_drop ? "1" : "0");
            params.put("skip_drop", p._skip_drop);
        }

        // May overwrite tree_method / max_bins set above when the GPU backend is selected.
        putBackendParams(params, p);

        if (p._min_child_weight != 1.0) {
            Log.info("Using user-provided parameter min_child_weight instead of min_rows.");
            params.put("min_child_weight", p._min_child_weight);
            p._min_rows = p._min_child_weight;
        } else {
            params.put("min_child_weight", p._min_rows);
        }
        if (p._gamma != 0.0f) {
            Log.info("Using user-provided parameter gamma instead of min_split_improvement.");
            params.put("gamma", p._gamma);
            p._min_split_improvement = p._gamma;
        } else {
            params.put("gamma", p._min_split_improvement);
        }
        params.put("lambda", p._reg_lambda);
        params.put("alpha", p._reg_alpha);

        putObjectiveParams(params, p, nClasses);
        logParams(params);

        // nthread is added after logging, so it does not appear in the parameter dump.
        params.put("nthread", resolveNThread(p));
        return BoosterParms.fromMap(Collections.unmodifiableMap(params));
    }

    /**
     * Decides between the GPU and CPU backends and, for GPU, adds {@code gpu_id} and updater settings.
     * GPU is attempted only for backend {@code auto}/{@code gpu}, on a single-node cloud, with
     * GPU-compatible parameters and a detected GPU; every other case logs and falls back to CPU.
     */
    private static void putBackendParams(Map<String, Object> params, XGBoostParameters p) {
        if (p._backend != XGBoostParameters.Backend.auto && p._backend != XGBoostParameters.Backend.gpu) {
            assert (p._backend == XGBoostParameters.Backend.cpu);
            Log.info("Using CPU backend.");
            return;
        }
        if (H2O.getCloudSize() > 1) {
            Log.info("GPU backend not supported in distributed mode. Using CPU backend.");
            return;
        }
        Map<String, Object> incompatible = p.gpuIncompatibleParams();
        if (!incompatible.isEmpty()) {
            Log.info("GPU backend not supported for the choice of parameters (" + incompatible + "). Using CPU backend.");
            return;
        }
        if (!XGBoost.hasGPU(H2O.CLOUD.members()[0], p._gpu_id)) {
            Log.info("No GPU (gpu_id: " + p._gpu_id + ") found. Using CPU backend.");
            return;
        }
        Log.info("Using GPU backend (gpu_id: " + p._gpu_id + ").");
        params.put("gpu_id", p._gpu_id);
        if (p._tree_method == XGBoostParameters.TreeMethod.exact) {
            Log.info("Using grow_gpu (exact) updater.");
            params.put("tree_method", "exact");
            params.put("updater", "grow_gpu");
        } else {
            Log.info("Using grow_gpu_hist (approximate) updater.");
            params.put("max_bins", p._max_bins);
            // NOTE(review): tree_method is "exact" for the hist updater too; presumably the explicit
            // updater determines the algorithm - confirm against the bundled XGBoost version.
            params.put("tree_method", "exact");
            params.put("updater", "grow_gpu_hist");
        }
    }

    /**
     * Adds the learning objective for the given class count (plus {@code num_class} or
     * {@code tweedie_variance_power} where needed).
     *
     * @throws UnsupportedOperationException for an unsupported regression distribution
     */
    private static void putObjectiveParams(Map<String, Object> params, XGBoostParameters p, int nClasses) {
        if (nClasses == 2) {
            params.put("objective", "binary:logistic");
        } else if (nClasses == 1) {
            if (p._distribution == DistributionFamily.gamma) {
                params.put("objective", "reg:gamma");
            } else if (p._distribution == DistributionFamily.tweedie) {
                params.put("objective", "reg:tweedie");
                params.put("tweedie_variance_power", p._tweedie_power);
            } else if (p._distribution == DistributionFamily.poisson) {
                params.put("objective", "count:poisson");
            } else if (p._distribution == DistributionFamily.gaussian || p._distribution == DistributionFamily.AUTO) {
                params.put("objective", "reg:linear");
            } else {
                throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
            }
        } else {
            params.put("objective", "multi:softprob");
            params.put("num_class", nClasses);
        }
    }

    /** Logs every assembled booster parameter, one per line. */
    private static void logParams(Map<String, Object> params) {
        Log.info("XGBoost Parameters:");
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            Log.info(" " + entry.getKey() + " = " + entry.getValue());
        }
        Log.info("");
    }

    /**
     * Picks the thread count: the user's {@code _nthread} capped at {@link #getMaxNThread()}, or the
     * maximum when {@code _nthread == -1}. Warns when the user's request had to be reduced.
     */
    private static int resolveNThread(XGBoostParameters p) {
        int nthreadMax = getMaxNThread();
        int nthread = p._nthread != -1 ? Math.min(p._nthread, nthreadMax) : nthreadMax;
        if (nthread < p._nthread) {
            Log.warn("Requested nthread=" + p._nthread + " but the cluster has only " + nthreadMax
                    + " available. Training will use nthread=" + nthreadMax + " instead of the user specified value.");
        }
        return nthread;
    }

    /** @return system property {@code sys.ai.h2o.xgboost.nthread}, defaulting to H2O's nthreads */
    private static int getMaxNThread() {
        return Integer.getInteger("sys.ai.h2o.xgboost.nthread", H2O.ARGS.nthreads);
    }

    /** Scores one row with a zero offset. */
    protected double[] score0(double[] data, double[] preds) {
        return score0(data, preds, 0.0);
    }

    /** Writes the DataInfo key with {@link AutoBuffer#putKey} before the default serialization. */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(model_info._dataInfoKey);
        return super.writeAll_impl(ab);
    }

    /** Counterpart of {@link #writeAll_impl}: reads the DataInfo key, then the default payload. */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(model_info._dataInfoKey, fs);
        return super.readAll_impl(ab, fs);
    }

    /** @return a MOJO writer for this model */
    public XGBoostMojoWriter getMojo() {
        return new XGBoostMojoWriter(this);
    }

    /** Computes metrics on {@code data} and discards the intermediate prediction frame. */
    private ModelMetrics makeMetrics(Booster booster, Frame data, Frame originalData, String description) throws XGBoostError {
        return makeMetrics(booster, data, originalData, description, null);
    }

    /**
     * Scores {@code data} and builds metrics attached to this model and {@code originalData}.
     *
     * @param predFrameKey when null the prediction frame is removed; otherwise it is put into the
     *                     DKV under that key
     */
    private ModelMetrics makeMetrics(Booster booster, Frame data, Frame originalData, String description, Key<Frame> predFrameKey) throws XGBoostError {
        Futures fs = new Futures();
        ModelMetrics[] mms = new ModelMetrics[1];
        Frame predictions = makePreds(booster, data, mms, true, predFrameKey, fs);
        if (predFrameKey == null) {
            predictions.remove(fs);
        } else {
            DKV.put(predictions, fs);
        }
        fs.blockForPending();
        return mms[0].withModelAndFrame(this, originalData).withDescription(description);
    }

    /** Scores {@code data} without metrics and publishes the predictions under {@code destinationKey}. */
    private Frame makePredsOnly(Booster booster, Frame data, Key<Frame> destinationKey) throws XGBoostError {
        Futures fs = new Futures();
        Frame preds = makePreds(booster, data, null, false, destinationKey, fs);
        DKV.put(preds, fs);
        fs.blockForPending();
        return preds;
    }

    /**
     * Runs the XGBoost score task over {@code data}.
     *
     * @param mms when {@code computeMetrics}, a length-1 array that receives the computed metrics
     * @param fs  currently unused
     */
    private Frame makePreds(Booster booster, Frame data, ModelMetrics[] mms, boolean computeMetrics, Key<Frame> destinationKey, Futures fs) throws XGBoostError {
        assert (!computeMetrics || mms != null && mms.length == 1);
        XGBoostScoreTask.XGBoostScoreTaskResult score = XGBoostScoreTask.runScoreTask(model_info(), _output, _parms, booster, destinationKey, data, computeMetrics);
        if (computeMetrics) {
            mms[0] = score.mm;
        }
        return score.preds;
    }

    /**
     * Computes training (and, if {@code _valid} is non-null, validation) metrics with
     * {@code booster}. Each result is stored in the output, copied into the scoring-history slot
     * at index {@code _output._ntrees}, and registered with the model.
     */
    public void doScoring(Booster booster, Frame _train, Frame _trainOrig, Frame _valid, Frame _validOrig) throws XGBoostError {
        ModelMetrics trainMetrics = makeMetrics(booster, _train, _trainOrig, "Metrics reported on training frame");
        _output._training_metrics = trainMetrics;
        _output._scored_train[_output._ntrees].fillFrom(trainMetrics);
        addModelMetrics(trainMetrics);
        if (_valid != null) {
            ModelMetrics validMetrics = makeMetrics(booster, _valid, _validOrig, "Metrics reported on validation frame");
            _output._validation_metrics = validMetrics;
            _output._scored_valid[_output._ntrees].fillFrom(validMetrics);
            addModelMetrics(validMetrics);
        }
    }

    /** Converts per-feature integer scores into {@link VarImp}; an empty map leaves varimp unset. */
    void computeVarImp(Map<String, Integer> varimp) {
        if (varimp.isEmpty()) {
            return;
        }
        float[] viFloat = new float[varimp.size()];
        String[] names = new String[varimp.size()];
        int j = 0;
        for (Map.Entry<String, Integer> it : varimp.entrySet()) {
            viFloat[j] = it.getValue().intValue();
            names[j] = it.getKey();
            ++j;
        }
        _output._varimp = new VarImp(viFloat, names);
    }

    /** Scores one row via {@code XGBoostNativeMojoModel.score0} using the training DataInfo layout. */
    public double[] score0(double[] data, double[] preds, double offset) {
        DataInfo di = (DataInfo) model_info._dataInfoKey.get();
        return XGBoostNativeMojoModel.score0(data, offset, preds, model_info.getBooster(), di._nums, di._cats, di._catOffsets, di._useAllFactorLevels, _output.nclasses(), _output._priorClassDist, defaultThreshold(), _output._sparse);
    }

    /**
     * Scores a batch of rows from {@code chks}.
     *
     * <p>Copies each requested row into {@code tmp[row]} (one value per chunk column) and scores
     * the batch with {@code bulkScore0}. For supervised classifiers it then optionally corrects
     * probabilities (when {@code _balance_classes}) and writes the predicted label into column 0.
     */
    public double[][] score0(Chunk[] chks, double[] offset, int[] rowsInChunk, double[][] tmp, double[][] preds) {
        for (int row = 0; row < rowsInChunk.length; ++row) {
            for (int i = 0; i < tmp[row].length; ++i) {
                tmp[row][i] = chks[i].atd(rowsInChunk[row]);
            }
        }
        DataInfo di = (DataInfo) model_info._dataInfoKey.get();
        double[][] scored = XGBoostNativeMojoModel.bulkScore0(tmp, offset, preds, model_info.getBooster(), di._nums, di._cats, di._catOffsets, di._useAllFactorLevels, _output.nclasses(), _output._priorClassDist, defaultThreshold(), _output._sparse);
        if (isSupervised() && _output.isClassifier()) {
            for (int row = 0; row < rowsInChunk.length; ++row) {
                if (_parms._balance_classes) {
                    GenModel.correctProbabilities(scored[row], _output._priorClassDist, _output._modelClassDist);
                }
                scored[row][0] = GenModel.getPrediction(scored[row], _output._priorClassDist, tmp[row], defaultThreshold());
            }
        }
        return scored;
    }

    /**
     * Returns {@code false}. NOTE(review): presumably this tells the base scorer not to use the
     * bulk prediction path - confirm in {@code Model}.
     */
    protected boolean bulkBigScorePredict() {
        return false;
    }

    /** Copies the adapted-frame layout from {@code dinfo} and the original training-frame schema into the output. */
    private void setDataInfoToOutput(DataInfo dinfo) {
        Frame train = _parms._train.get();
        _output._names = dinfo._adaptedFrame.names();
        _output._domains = dinfo._adaptedFrame.domains();
        _output._origNames = train.names();
        _output._origDomains = train.domains();
        _output._nums = dinfo._nums;
        _output._cats = dinfo._cats;
        _output._catOffsets = dinfo._catOffsets;
        _output._useAllFactorLevels = dinfo._useAllFactorLevels;
    }

    /**
     * Releases the native backend and the DataInfo, then removes the model itself. A DataInfo
     * that is no longer in the DKV is skipped instead of causing a NullPointerException.
     */
    protected Futures remove_impl(Futures fs) {
        model_info().nukeBackend();
        if (model_info()._dataInfoKey != null) {
            DataInfo dinfo = (DataInfo) model_info()._dataInfoKey.get();
            if (dinfo != null) {
                dinfo.remove(fs);
            }
        }
        return super.remove_impl(fs);
    }

    /**
     * Scores {@code fr} into a frame stored under {@code destination_key}.
     *
     * <p>Metrics are computed only if requested and, for supervised models, only if the frame has
     * a non-bad response column. Adaptation warnings are logged.
     *
     * @throws IllegalStateException wrapping an {@link XGBoostError} raised while scoring
     */
    public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics) throws IllegalArgumentException {
        Frame adaptFr = new Frame(fr);
        // Evaluated before adaptation, exactly as the original inline expression was.
        boolean withMetrics = computeMetrics && (!isSupervised() || hasUsableResponse(adaptFr));
        String[] msgs = adaptTestForTrain(adaptFr, true, withMetrics);
        for (String s : msgs) {
            Log.warn(s);
        }
        try {
            Key<Frame> destFrameKey = Key.make(destination_key);
            if (withMetrics) {
                ModelMetrics mm = makeMetrics(model_info().booster(), adaptFr, fr, "Prediction on frame " + fr._key, destFrameKey);
                addModelMetrics(mm);
                DKV.put(this);
            } else {
                makePredsOnly(model_info().booster(), adaptFr, destFrameKey);
            }
            return destFrameKey.get();
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed scoring.", xgBoostError);
        }
    }

    /** @return true when {@code fr} has the response column and that column is not bad */
    private boolean hasUsableResponse(Frame fr) {
        String response = _output.responseName();
        return fr.vec(response) != null && !fr.vec(response).isBad();
    }

    /**
     * User-facing XGBoost parameters. Several fields are XGBoost-native aliases of H2O-style fields.
     * {@link XGBoostModel#createParams} uses an alias instead of its H2O counterpart when the alias
     * differs from its default.
     */
    public static class XGBoostParameters
    extends Model.Parameters {
        /** Sent to XGBoost as {@code silent}. */
        public boolean _quiet_mode = true;
        public MissingValuesHandling _missing_values_handling;
        public int _ntrees = 50;
        /** Alias of {@link #_ntrees}; used instead when non-zero. */
        public int _n_estimators;
        public int _max_depth = 6;
        public double _min_rows = 1.0;
        /** Alias of {@link #_min_rows}; used instead when not 1.0. */
        public double _min_child_weight = 1.0;
        public double _learn_rate = 0.3;
        /** Alias of {@link #_learn_rate}; used instead when not 0.3. */
        public double _eta = 0.3;
        public double _learn_rate_annealing = 1.0;
        public double _sample_rate = 1.0;
        /** Alias of {@link #_sample_rate}; used instead when not 1.0. */
        public double _subsample = 1.0;
        public double _col_sample_rate = 1.0;
        /** Alias of {@link #_col_sample_rate}; used instead when not 1.0. */
        public double _colsample_bylevel = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        /** Alias of {@link #_col_sample_rate_per_tree}; used instead when not 1.0. */
        public double _colsample_bytree = 1.0;
        public float _max_abs_leafnode_pred = 0.0f;
        /** Alias of {@link #_max_abs_leafnode_pred}; used instead when non-zero. */
        public float _max_delta_step = 0.0f;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        public float _min_split_improvement = 0.0f;
        /** Alias of {@link #_min_split_improvement}; used instead when non-zero. */
        public float _gamma;
        /** Requested thread count; -1 means use the maximum available. */
        public int _nthread = -1;
        public int _max_bins = 256;
        public int _max_leaves = 0;
        public float _min_sum_hessian_in_leaf = 100.0f;
        public float _min_data_in_leaf = 0.0f;
        public TreeMethod _tree_method = TreeMethod.auto;
        public GrowPolicy _grow_policy = GrowPolicy.depthwise;
        public Booster _booster = Booster.gbtree;
        public DMatrixType _dmatrix_type = DMatrixType.auto;
        public float _reg_lambda = 0.0f;
        public float _reg_alpha = 0.0f;
        public DartSampleType _sample_type = DartSampleType.uniform;
        public DartNormalizeType _normalize_type = DartNormalizeType.tree;
        public float _rate_drop = 0.0f;
        public boolean _one_drop = false;
        public float _skip_drop = 0.0f;
        public int _gpu_id = 0;
        /** {@code auto} and {@code gpu} try the GPU backend; {@code cpu} always uses the CPU. */
        public Backend _backend = Backend.auto;

        public String algoName() {
            return "XGBoost";
        }

        public String fullName() {
            return "XGBoost";
        }

        public String javaName() {
            return XGBoostModel.class.getName();
        }

        /** @return the number of trees, used as the job's progress unit count */
        public long progressUnits() {
            return _ntrees;
        }

        /**
         * Lists parameter settings that rule out the GPU backend.
         *
         * @return parameter name to offending value/explanation; empty when GPU-compatible
         */
        Map<String, Object> gpuIncompatibleParams() {
            Map<String, Object> incompat = new HashMap<String, Object>();
            if (_max_depth > 15 || _max_depth < 1) {
                incompat.put("max_depth", _max_depth + " . Max depth must be greater than 0 and lower than 16 for GPU backend.");
            }
            if (_grow_policy == GrowPolicy.lossguide) {
                incompat.put("grow_policy", GrowPolicy.lossguide);
            }
            return incompat;
        }

        public static enum Backend {
            auto,
            gpu,
            cpu;
        }

        public static enum DMatrixType {
            auto,
            dense,
            sparse;
        }

        public static enum DartNormalizeType {
            tree,
            forest;
        }

        public static enum DartSampleType {
            uniform,
            weighted;
        }

        public static enum MissingValuesHandling {
            MeanImputation,
            Skip;
        }

        public static enum Booster {
            gbtree,
            gblinear,
            dart;
        }

        public static enum GrowPolicy {
            depthwise,
            lossguide;
        }

        public static enum TreeMethod {
            auto,
            exact,
            approx,
            hist;
        }
    }
}

