/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import hex.DataInfo;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.BoosterHelper;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;

public class XGBoostScoreTask
extends MRTask<XGBoostScoreTask> {
    private final XGBoostModelInfo _sharedmodel;
    private final XGBoostOutput _output;
    private final XGBoostModel.XGBoostParameters _parms;
    private final BoosterParms _boosterParms;
    private final boolean _isTrain;
    private final boolean _computeMetrics;
    private final int _weightsChunkId;
    private final Model _model;
    private final double _threshold;
    private ModelMetrics.MetricBuilder _metricBuilder;

    public static XGBoostScoreTaskResult runScoreTask(XGBoostModelInfo sharedmodel, XGBoostOutput output, XGBoostModel.XGBoostParameters parms, Key<Frame> destinationKey, Frame data, Frame originalData, boolean isTrain, boolean computeMetrics, Model m) {
        BoosterParms boosterParms = XGBoostModel.createParams(parms, output.nclasses(), sharedmodel.dataInfo().coefNames());
        XGBoostScoreTask task = (XGBoostScoreTask)new XGBoostScoreTask(sharedmodel, output, parms, boosterParms, isTrain, computeMetrics, data.find(parms._weights_column), m).doAll(XGBoostScoreTask.outputTypes(output), data);
        String[] names = Model.makeScoringNames((Model.Output)output);
        Frame preds = task.outputFrame(destinationKey, names, XGBoostScoreTask.makeDomains(output, names));
        XGBoostScoreTaskResult res = new XGBoostScoreTaskResult();
        if (output.nclasses() == 1) {
            Vec pred = preds.vec(0);
            if (computeMetrics) {
                res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, new Frame(new Vec[]{pred}));
            }
        } else if (output.nclasses() == 2) {
            Vec p1 = preds.vec(2);
            if (computeMetrics) {
                res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, new Frame(new Vec[]{p1}));
            }
        } else if (computeMetrics) {
            Frame pp = new Frame(preds);
            pp.remove(0);
            Scope.enter();
            res.mm = task._metricBuilder.makeModelMetrics(m, originalData, data, pp);
            Scope.exit((Key[])new Key[0]);
        }
        res.preds = preds;
        assert ("predict".equals(preds.name(0)));
        return res;
    }

    private static byte[] outputTypes(XGBoostOutput output) {
        if (output.nclasses() == 1) {
            return new byte[]{3};
        }
        if (output.nclasses() == 2) {
            return new byte[]{4, 3, 3};
        }
        byte[] types = new byte[output.nclasses() + 1];
        Arrays.fill(types, (byte)3);
        return types;
    }

    private static String[][] makeDomains(XGBoostOutput output, String[] names) {
        if (output.nclasses() == 1) {
            return null;
        }
        String[][] domains = new String[names.length][];
        domains[0] = output.classNames();
        return domains;
    }

    private XGBoostScoreTask(XGBoostModelInfo sharedmodel, XGBoostOutput output, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms, boolean isTrain, boolean computeMetrics, int weightsChunkId, Model model) {
        this._sharedmodel = sharedmodel;
        this._output = output;
        this._parms = parms;
        this._boosterParms = boosterParms;
        this._isTrain = isTrain;
        this._computeMetrics = computeMetrics;
        this._weightsChunkId = weightsChunkId;
        this._model = model;
        this._threshold = Model.defaultThreshold((Model.Output)this._output);
    }

    private ModelMetrics.MetricBuilder createMetricsBuilder(int responseClassesNum, String[] responseDomain) {
        switch (responseClassesNum) {
            case 1: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
            case 2: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(responseDomain);
            }
        }
        return new ModelMetricsMultinomial.MetricBuilderMultinomial(responseClassesNum, responseDomain);
    }

    /*
     * Loose catch block
     */
    private static ScoreResult scoreChunkExt(XGBoostModelInfo sharedmodel, DataInfo dataInfo, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms, XGBoostOutput output, Frame fr, Chunk[] cs, OutputType outputType) {
        Booster booster;
        DMatrix data;
        block15: {
            data = null;
            booster = null;
            HashMap rabitEnv = new HashMap();
            Rabit.init(rabitEnv);
            data = XGBoostUtils.convertChunksToDMatrix(dataInfo, cs, fr.find(parms._response_column), -1, fr.find(parms._fold_column), output._sparse);
            if (data.rowNum() != 0L) break block15;
            ScoreResult scoreResult = null;
            BoosterHelper.dispose((Object[])new Object[]{booster, data});
            try {
                Rabit.shutdown();
            }
            catch (XGBoostError xgBoostError) {
                throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
            }
            return scoreResult;
        }
        booster = sharedmodel.deserializeBooster();
        booster.setParams(boosterParms.get());
        ScoreResult result = new ScoreResult();
        switch (outputType) {
            case PREDICT: {
                result._preds = booster.predict(data);
                result._labels = data.getLabel();
                break;
            }
            case PREDICT_CONTRIB_APPROX: {
                result._preds = booster.predictContrib(data, 0);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unsupported output type: " + (Object)((Object)outputType));
            }
        }
        ScoreResult scoreResult = result;
        BoosterHelper.dispose((Object[])new Object[]{booster, data});
        try {
            Rabit.shutdown();
        }
        catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
        }
        return scoreResult;
        catch (XGBoostError xgBoostError) {
            try {
                throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
            }
            catch (Throwable throwable) {
                BoosterHelper.dispose((Object[])new Object[]{booster, data});
                try {
                    Rabit.shutdown();
                }
                catch (XGBoostError xgBoostError2) {
                    throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError2);
                }
                throw throwable;
            }
        }
    }

    public static float[][] scoreChunk(XGBoostModelInfo sharedmodel, DataInfo dataInfo, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms, XGBoostOutput output, Frame fr, Chunk[] cs) {
        ScoreResult r = XGBoostScoreTask.scoreChunkExt(sharedmodel, dataInfo, parms, boosterParms, output, fr, cs, OutputType.PREDICT);
        return r == null ? new float[][]{} : (float[][])r._preds;
    }

    public static float[][] scoreChunkContribApprox(XGBoostModelInfo sharedmodel, DataInfo dataInfo, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms, XGBoostOutput output, Frame fr, Chunk[] cs) {
        ScoreResult r = XGBoostScoreTask.scoreChunkExt(sharedmodel, dataInfo, parms, boosterParms, output, fr, cs, OutputType.PREDICT_CONTRIB_APPROX);
        return r == null ? new float[][]{} : (float[][])r._preds;
    }

    public void map(Chunk[] cs, NewChunk[] ncs) {
        this._metricBuilder = this._computeMetrics ? this.createMetricsBuilder(this._output.nclasses(), this._output.classNames()) : null;
        DataInfo di = this._sharedmodel.scoringInfo(this._isTrain);
        ScoreResult r = XGBoostScoreTask.scoreChunkExt(this._sharedmodel, di, this._parms, this._boosterParms, this._output, this._fr, cs, OutputType.PREDICT);
        if (r == null) {
            return;
        }
        if (this._output.nclasses() == 1) {
            double[] currentPred = new double[1];
            float[] yact = new float[1];
            for (int j = 0; j < r._preds.length; ++j) {
                currentPred[0] = r._preds[j][0];
                if (!this._computeMetrics) continue;
                yact[0] = r._labels[j];
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(j) : 1.0;
                this._metricBuilder.perRow(currentPred, yact, weight, 0.0, this._model);
            }
            for (int i = 0; i < cs[0]._len; ++i) {
                ncs[0].addNum((double)r._preds[i][0]);
            }
        } else if (this._output.nclasses() == 2) {
            double[] row = new double[3];
            float[] yact = new float[1];
            for (int i = 0; i < cs[0]._len; ++i) {
                double p = r._preds[i][0];
                row[1] = 1.0 - p;
                row[2] = p;
                row[0] = GenModel.getPrediction((double[])row, (double[])this._output._priorClassDist, null, (double)this._threshold);
                ncs[0].addNum(row[0]);
                ncs[1].addNum(row[1]);
                ncs[2].addNum(row[2]);
                if (!this._computeMetrics) continue;
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(i) : 1.0;
                yact[0] = r._labels[i];
                this._metricBuilder.perRow(row, yact, weight, 0.0, this._model);
            }
        } else {
            float[] yact = new float[1];
            double[] row = MemoryManager.malloc8d((int)ncs.length);
            for (int i = 0; i < cs[0]._len; ++i) {
                for (int j = 1; j < row.length; ++j) {
                    double val = r._preds[i][j - 1];
                    ncs[j].addNum(val);
                    row[j] = val;
                }
                row[0] = GenModel.getPrediction((double[])row, (double[])this._output._priorClassDist, null, (double)this._threshold);
                ncs[0].addNum(row[0]);
                if (!this._computeMetrics) continue;
                yact[0] = r._labels[i];
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(i) : 1.0;
                this._metricBuilder.perRow(row, yact, weight, 0.0, this._model);
            }
        }
    }

    public void reduce(XGBoostScoreTask mrt) {
        super.reduce((MRTask)mrt);
        if (this._computeMetrics) {
            this._metricBuilder.reduce(mrt._metricBuilder);
        }
    }

    private static class ScoreResult {
        float[][] _preds;
        float[] _labels;

        private ScoreResult() {
        }
    }

    public static class XGBoostScoreTaskResult {
        public Frame preds;
        public ModelMetrics mm;
    }

    static enum OutputType {
        PREDICT,
        PREDICT_CONTRIB_APPROX;

    }
}

