/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import ai.h2o.xgboost4j.java.Booster;
import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.Rabit;
import ai.h2o.xgboost4j.java.XGBoostError;
import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.predict.XGBoostPredict;
import hex.tree.xgboost.predict.XGBoostPredictContrib;
import hex.tree.xgboost.util.BoosterHelper;
import java.util.HashMap;
import org.apache.log4j.Logger;
import water.fvec.Chunk;
import water.fvec.Frame;

public class XGBoostNativeBigScoreChunkPredict
implements XGBoostPredictContrib,
Model.BigScoreChunkPredict {
    private static final Logger LOG = Logger.getLogger(XGBoostNativeBigScoreChunkPredict.class);
    private final double _threshold;
    private final int _responseIndex;
    private final int _offsetIndex;
    private final XGBoostModelInfo _modelInfo;
    private final XGBoostModel.XGBoostParameters _parms;
    private final DataInfo _dataInfo;
    private final BoosterParms _boosterParms;
    private final XGBoostOutput _output;
    private final float[][] _preds;

    public XGBoostNativeBigScoreChunkPredict(XGBoostModelInfo modelInfo, XGBoostModel.XGBoostParameters parms, DataInfo di, BoosterParms boosterParms, double threshold, XGBoostOutput output, Frame fr, Chunk[] chks) {
        this._modelInfo = modelInfo;
        this._parms = parms;
        this._dataInfo = di;
        this._boosterParms = boosterParms;
        this._threshold = threshold;
        this._output = output;
        this._responseIndex = fr.find(this._parms._response_column);
        this._offsetIndex = fr.find(this._parms._offset_column);
        this._preds = this.scoreChunk(chks, XGBoostPredict.OutputType.PREDICT);
    }

    public double[] score0(Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        return XGBoostMojoModel.toPreds((double[])tmp, (float[])this._preds[row_in_chunk], (double[])preds, (int)this._output.nclasses(), null, (double)this._threshold);
    }

    @Override
    public float[][] predictContrib(Chunk[] cs) {
        return this.scoreChunk(cs, XGBoostPredict.OutputType.PREDICT_CONTRIB_APPROX);
    }

    @Override
    public float[][] predict(Chunk[] cs) {
        return this.scoreChunk(cs, XGBoostPredict.OutputType.PREDICT);
    }

    /*
     * Loose catch block
     */
    private float[][] scoreChunk(Chunk[] cs, XGBoostPredict.OutputType outputType) {
        float[][] preds;
        Booster booster;
        DMatrix data;
        block16: {
            data = null;
            booster = null;
            Rabit.init(new HashMap());
            data = XGBoostUtils.convertChunksToDMatrix(this._dataInfo, cs, this._responseIndex, this._output._sparse, this._offsetIndex);
            if (data.rowNum() != 0L) break block16;
            float[][] fArrayArray = new float[][]{};
            BoosterHelper.dispose(booster, data);
            try {
                Rabit.shutdown();
            }
            catch (XGBoostError xgBoostError) {
                LOG.error((Object)"Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", (Throwable)xgBoostError);
            }
            return fArrayArray;
        }
        booster = BoosterHelper.loadModel(this._modelInfo._boosterBytes);
        booster.setParams(this._boosterParms.get());
        int treeLimit = 0;
        if (this._parms._booster == XGBoostModel.XGBoostParameters.Booster.dart) {
            treeLimit = this._parms._ntrees;
        }
        switch (outputType) {
            case PREDICT: {
                preds = booster.predict(data, false, treeLimit);
                break;
            }
            case PREDICT_CONTRIB_APPROX: {
                preds = booster.predictContrib(data, treeLimit);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported output type: " + (Object)((Object)outputType));
            }
        }
        Object object = preds == null ? (Object)new float[0][] : preds;
        BoosterHelper.dispose(booster, data);
        try {
            Rabit.shutdown();
        }
        catch (XGBoostError xgBoostError) {
            LOG.error((Object)"Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", (Throwable)xgBoostError);
        }
        return object;
        catch (XGBoostError xgBoostError) {
            try {
                throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
            }
            catch (Throwable throwable) {
                BoosterHelper.dispose(booster, data);
                try {
                    Rabit.shutdown();
                }
                catch (XGBoostError xgBoostError2) {
                    LOG.error((Object)"Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", (Throwable)xgBoostError2);
                }
                throw throwable;
            }
        }
    }

    public void close() {
    }
}

