/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import hex.tree.xgboost.predict.XGBoostPredict;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

public class XGBoostScoreTask
extends MRTask<XGBoostScoreTask> {
    private final XGBoostOutput _output;
    private final int _weightsChunkId;
    private final XGBoostModel _model;
    private final boolean _isTrain;
    private final double _threshold;
    public ModelMetrics.MetricBuilder _metricBuilder;
    private transient XGBoostBigScorePredict _predict;

    public XGBoostScoreTask(XGBoostOutput output, int weightsChunkId, boolean isTrain, XGBoostModel model) {
        this._output = output;
        this._weightsChunkId = weightsChunkId;
        this._model = model;
        this._isTrain = isTrain;
        this._threshold = model.defaultThreshold();
    }

    private ModelMetrics.MetricBuilder createMetricsBuilder(int responseClassesNum, String[] responseDomain) {
        switch (responseClassesNum) {
            case 1: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
            case 2: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(responseDomain);
            }
        }
        return new ModelMetricsMultinomial.MetricBuilderMultinomial(responseClassesNum, responseDomain);
    }

    protected void setupLocal() {
        this._predict = this._model.setupBigScorePredict(this._isTrain);
    }

    public void map(Chunk[] cs, NewChunk[] ncs) {
        this._metricBuilder = this.createMetricsBuilder(this._output.nclasses(), this._output.classNames());
        XGBoostPredict predictor = this._predict.initMap(this._fr, cs);
        float[][] preds = predictor.predict(cs);
        if (preds.length == 0) {
            return;
        }
        assert (preds.length == cs[0]._len);
        Chunk responseChunk = cs[this._output.responseIdx()];
        if (this._output.nclasses() == 1) {
            double[] currentPred = new double[1];
            float[] yact = new float[1];
            for (int j = 0; j < preds.length; ++j) {
                currentPred[0] = preds[j][0];
                yact[0] = (float)responseChunk.atd(j);
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(j) : 1.0;
                this._metricBuilder.perRow(currentPred, yact, weight, 0.0, (Model)this._model);
            }
            for (int i = 0; i < cs[0]._len; ++i) {
                ncs[0].addNum((double)preds[i][0]);
            }
        } else if (this._output.nclasses() == 2) {
            double[] row = new double[3];
            float[] yact = new float[1];
            for (int i = 0; i < cs[0]._len; ++i) {
                double p = preds[i][0];
                row[1] = 1.0 - p;
                row[2] = p;
                row[0] = GenModel.getPrediction((double[])row, (double[])this._output._priorClassDist, null, (double)this._threshold);
                ncs[0].addNum(row[0]);
                ncs[1].addNum(row[1]);
                ncs[2].addNum(row[2]);
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(i) : 1.0;
                yact[0] = (float)responseChunk.atd(i);
                this._metricBuilder.perRow(row, yact, weight, 0.0, (Model)this._model);
            }
        } else {
            float[] yact = new float[1];
            double[] row = MemoryManager.malloc8d((int)ncs.length);
            for (int i = 0; i < cs[0]._len; ++i) {
                for (int j = 1; j < row.length; ++j) {
                    double val = preds[i][j - 1];
                    ncs[j].addNum(val);
                    row[j] = val;
                }
                row[0] = GenModel.getPrediction((double[])row, (double[])this._output._priorClassDist, null, (double)this._threshold);
                ncs[0].addNum(row[0]);
                yact[0] = (float)responseChunk.atd(i);
                double weight = this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(i) : 1.0;
                this._metricBuilder.perRow(row, yact, weight, 0.0, (Model)this._model);
            }
        }
    }

    public void reduce(XGBoostScoreTask mrt) {
        super.reduce((MRTask)mrt);
        this._metricBuilder.reduce(mrt._metricBuilder);
    }
}

