/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import com.google.common.collect.ObjectArrays;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.HashMap;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;

/**
 * Scoring task that runs a serialized XGBoost {@link Booster} over the chunks of a frame.
 *
 * <p>Each output row contains the model's scoring columns (named by {@code Model.makeScoringNames})
 * followed by a trailing {@code "label"} column. That column holds the label which
 * {@code XGBoostUtils.convertChunksToDMatrix} extracted for the row. {@link #runScoreTask} detaches
 * the label column from the returned frame and uses it only for metric computation.
 */
public class XGBoostScoreTask extends MRTask<XGBoostScoreTask> {
    private final XGBoostModelInfo _sharedmodel;
    private final XGBoostOutput _output;
    private final XGBoostModel.XGBoostParameters _parms;
    // Serialized booster bytes (from XGBoost.getRawArray); every map() call re-loads its own Booster.
    private byte[] rawBooster;

    /**
     * Scores {@code data} with {@code booster} and optionally computes model metrics.
     *
     * @param sharedmodel    model info; its {@code _dataInfoKey} is used to build the DMatrix per chunk set
     * @param output         model output (number of classes, class names, class priors)
     * @param parms          model parameters (response/fold column names, booster parameters)
     * @param booster        trained booster; serialized here and re-loaded inside each {@code map} call
     * @param destinationKey key of the prediction frame to create
     * @param data           frame to score
     * @param computeMetrics whether to compute {@link ModelMetrics} against the extracted labels
     * @return the prediction frame (without the label column) and, if requested, the metrics
     */
    public static XGBoostScoreTaskResult runScoreTask(XGBoostModelInfo sharedmodel, XGBoostOutput output, XGBoostModel.XGBoostParameters parms, Booster booster, Key<Frame> destinationKey, Frame data, boolean computeMetrics) {
        XGBoostScoreTask task = new XGBoostScoreTask(sharedmodel, output, parms, booster)
                .doAll(outputTypes(output), data);
        String[] names = ObjectArrays.concat(Model.makeScoringNames(output), new String[]{"label"}, String.class);
        Frame preds = task.outputFrame(destinationKey, names, makeDomains(output, names));
        // Detach the trailing label column. It is only needed for metrics and must be deleted
        // afterwards, even when metric computation fails.
        Vec resp = preds.lastVec();
        preds.remove(preds.vecs().length - 1);
        XGBoostScoreTaskResult res = new XGBoostScoreTaskResult();
        try {
            if (computeMetrics) {
                res.mm = makeMetrics(output, preds, resp);
            }
        } finally {
            if (resp != null) {
                resp.remove();
            }
        }
        res.preds = preds;
        assert !computeMetrics || res.mm != null;
        assert "predict".equals(preds.name(0));
        return res;
    }

    /**
     * Builds metrics for the prediction frame against the detached label column.
     *
     * <ul>
     *   <li>Regression uses the {@code predict} column.</li>
     *   <li>Binomial uses column 2.</li>
     *   <li>Multinomial uses every column except {@code predict}.</li>
     * </ul>
     */
    private static ModelMetrics makeMetrics(XGBoostOutput output, Frame preds, Vec resp) {
        int nclasses = output.nclasses();
        if (nclasses == 1) {
            return ModelMetricsRegression.make(preds.vec(0), resp, DistributionFamily.gaussian);
        }
        // Attach the model's class names to the label column before building classification metrics.
        resp.setDomain(output.classNames());
        if (nclasses == 2) {
            return ModelMetricsBinomial.make(preds.vec(2), resp);
        }
        Frame probs = new Frame(preds);
        probs.remove(0);
        Scope.enter();
        try {
            return ModelMetricsMultinomial.make(probs, resp, resp.toCategoricalVec().domain());
        } finally {
            // Always leave the scope. Presumably this also cleans up the temporary Vec created by
            // toCategoricalVec(); NOTE(review): confirm.
            Scope.exit();
        }
    }

    /** Column types of the task output: scoring columns followed by the numeric label column. */
    private static byte[] outputTypes(XGBoostOutput output) {
        if (output.nclasses() == 1) {
            return new byte[]{Vec.T_NUM, Vec.T_NUM};
        }
        if (output.nclasses() == 2) {
            return new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM};
        }
        byte[] types = new byte[output.nclasses() + 2];
        Arrays.fill(types, Vec.T_NUM);
        return types;
    }

    /**
     * Column domains of the task output.
     *
     * @return {@code null} for regression. For classification, an array with the class names on the
     *         {@code predict} column (binomial and multinomial alike).
     */
    private static String[][] makeDomains(XGBoostOutput output, String[] names) {
        if (output.nclasses() == 1) {
            return null;
        }
        String[][] domains = new String[names.length][];
        domains[0] = output.classNames();
        if (output.nclasses() == 2) {
            // The binomial label column also receives a domain; runScoreTask re-sets it when computing metrics.
            domains[names.length - 1] = output.classNames();
        }
        return domains;
    }

    private XGBoostScoreTask(XGBoostModelInfo sharedmodel, XGBoostOutput output, XGBoostModel.XGBoostParameters parms, Booster booster) {
        this._sharedmodel = sharedmodel;
        this._output = output;
        this._parms = parms;
        this.rawBooster = XGBoost.getRawArray(booster);
    }

    /**
     * Converts the chunks to a DMatrix, scores them with the re-loaded booster, and appends the
     * scoring columns plus the label to {@code ncs}.
     *
     * <p>The Booster and DMatrix are disposed and Rabit is shut down on every exit path.
     */
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        DMatrix data = null;
        Booster booster = null;
        try {
            HashMap<String, Object> params = XGBoostModel.createParams(this._parms, this._output);
            HashMap<String, String> rabitEnv = new HashMap<String, String>();
            Rabit.init(rabitEnv);
            data = XGBoostUtils.convertChunksToDMatrix(this._sharedmodel._dataInfoKey, cs,
                    this._fr.find(this._parms._response_column), -1,
                    this._fr.find(this._parms._fold_column), this._output._sparse);
            if (data.rowNum() == 0L) {
                // Nothing to score for these chunks.
                return;
            }
            try {
                booster = Booster.loadModel(new ByteArrayInputStream(this.rawBooster));
                booster.setParams(params);
            } catch (IOException e) {
                throw new IllegalStateException("Failed to load the booster.", e);
            }
            float[][] preds = booster.predict(data);
            float[] labels = data.getLabel();
            float[] weights = data.getWeight();
            int rows = cs[0]._len;
            if (this._output.nclasses() == 1) {
                writeRegression(preds, labels, ncs, rows);
            } else if (this._output.nclasses() == 2) {
                writeBinomial(preds, labels, weights, ncs, rows);
            } else {
                writeMultinomial(preds, labels, ncs, rows);
            }
        } catch (XGBoostError xgBoostError) {
            throw new IllegalStateException("Failed to score with XGBoost.", xgBoostError);
        } finally {
            // Release the booster and the DMatrix explicitly instead of leaving them to GC.
            if (booster != null) {
                booster.dispose();
            }
            if (data != null) {
                data.dispose();
            }
            // NOTE(review): an exception thrown here masks any exception from the try block.
            try {
                Rabit.shutdown();
            } catch (XGBoostError xgBoostError) {
                throw new IllegalStateException("Failed Rabit shutdown. A hanging RabitTracker task might be present on the driver node.", xgBoostError);
            }
        }
    }

    /** Regression output: column 0 = prediction, column 1 = label. */
    private static void writeRegression(float[][] preds, float[] labels, NewChunk[] ncs, int rows) {
        for (int i = 0; i < rows; ++i) {
            ncs[0].addNum((double) preds[i][0]);
            ncs[1].addNum((double) labels[i]);
        }
    }

    /**
     * Binomial output, where {@code p} is the booster's single output per row:
     *
     * <ul>
     *   <li>column 0 = predicted label</li>
     *   <li>column 1 = 1 - p</li>
     *   <li>column 2 = p</li>
     *   <li>column 3 = label</li>
     * </ul>
     */
    private void writeBinomial(float[][] preds, float[] labels, float[] weights, NewChunk[] ncs, int rows) {
        if (weights.length > 0) {
            for (int j = 0; j < preds.length; ++j) {
                assert (double) weights[j] == 1.0;
            }
        }
        double threshold = Model.defaultThreshold(this._output);
        for (int i = 0; i < rows; ++i) {
            double p = preds[i][0];
            ncs[1].addNum(1.0 - p);
            ncs[2].addNum(p);
            double[] row = new double[]{0.0, 1.0 - p, p};
            ncs[0].addNum((double) GenModel.getPrediction(row, this._output._priorClassDist, null, threshold));
            ncs[3].addNum((double) labels[i]);
        }
    }

    /**
     * Multinomial output:
     *
     * <ul>
     *   <li>column 0 = predicted label</li>
     *   <li>columns 1..nclasses = booster outputs</li>
     *   <li>last column = label</li>
     * </ul>
     */
    private void writeMultinomial(float[][] preds, float[] labels, NewChunk[] ncs, int rows) {
        double threshold = Model.defaultThreshold(this._output);
        int labelCol = ncs.length - 1;
        for (int i = 0; i < rows; ++i) {
            // row[0] stays 0 and the per-class values fill row[1..], mirroring the binomial {0, 1-p, p} layout.
            double[] row = new double[labelCol];
            for (int c = 1; c < labelCol; ++c) {
                double val = preds[i][c - 1];
                ncs[c].addNum(val);
                row[c] = val;
            }
            ncs[0].addNum((double) GenModel.getPrediction(row, this._output._priorClassDist, null, threshold));
            ncs[labelCol].addNum((double) labels[i]);
        }
    }

    /** Result of {@link #runScoreTask}. */
    public static class XGBoostScoreTaskResult {
        /** Prediction frame, without the trailing label column. */
        public Frame preds;
        /** Model metrics; set only when metrics were requested. */
        public ModelMetrics mm;
    }
}

