/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import java.util.Arrays;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/**
 * Task that computes TreeSHAP contributions for every row of its input chunks, using an
 * {@link XGBoostJavaMojoModel} built from the model's booster bytes.
 *
 * <p>Each output {@link NewChunk} receives exactly one value per input row. By default the raw
 * contribution vector (sized {@code _di.fullN() + 1}) is written column by column. When the
 * compact output format is requested, the raw vector is first folded through
 * {@link MutableOneHotEncoderFVec#decodeAggregate}. Its last entry is then copied unchanged into
 * the last output column; presumably this is the bias term (TODO confirm).
 */
public class PredictTreeSHAPTask extends MRTask<PredictTreeSHAPTask> {
    private final DataInfo _di;
    private final XGBoostModelInfo _modelInfo;
    private final XGBoostOutput _output;
    /** True when {@code ContributionsOutputFormat.Compact} was requested by the caller. */
    private final boolean _outputAggregated;
    /** Scoring model. It is transient, so it is not shipped with the task and is built in {@link #setupLocal()}. */
    private transient XGBoostJavaMojoModel _mojo;

    /**
     * Creates the task.
     *
     * @param di        data layout used to one-hot encode each input row
     * @param modelInfo holds the serialized booster bytes used to build the scoring model
     * @param output    model output metadata (column names, domains, response name, sparsity flag)
     * @param options   contribution options; only {@code _outputFormat} is consulted here
     */
    public PredictTreeSHAPTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output, Model.Contributions.ContributionsOptions options) {
        _di = di;
        _modelInfo = modelInfo;
        _output = output;
        _outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
    }

    /** Builds the scoring model from the booster bytes and output metadata. */
    protected void setupLocal() {
        // NOTE(review): meaning of the trailing boolean argument is not visible here; confirm
        // against the XGBoostJavaMojoModel constructor.
        _mojo = new XGBoostJavaMojoModel(_modelInfo._boosterBytes, _output._names, _output._domains, _output.responseName(), true);
    }

    /**
     * Computes contributions row by row and appends one value per row to each output chunk.
     *
     * @param chks input chunks, one per input column; all are read at the same row index
     * @param nc   output chunks; {@code nc[i]} receives element {@code i} of the per-row result
     */
    public void map(Chunk[] chks, NewChunk[] nc) {
        final MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(_di, _output._sparse);
        final double[] rowValues = MemoryManager.malloc8d(chks.length);
        final float[] contribs = MemoryManager.malloc4f(_di.fullN() + 1);
        // In the non-aggregated case the raw contribution buffer is emitted directly.
        final float[] result;
        if (_outputAggregated) {
            result = MemoryManager.malloc4f(nc.length);
        } else {
            result = contribs;
        }
        final TreeSHAPPredictor.Workspace workspace = _mojo.makeContributionsWorkspace();

        final int rowCount = chks[0]._len;
        for (int r = 0; r < rowCount; r++) {
            readRowInto(chks, r, rowValues);
            // The contribution buffer is reused across rows, so clear it before each calculation.
            Arrays.fill(contribs, 0.0f);
            rowFVec.setInput(rowValues);
            _mojo.calculateContributions(rowFVec, contribs, workspace);
            if (_outputAggregated) {
                rowFVec.decodeAggregate(contribs, result);
                // The last raw entry is carried over unchanged into the last output column.
                result[result.length - 1] = contribs[contribs.length - 1];
            }
            writeRow(nc, result);
        }
    }

    /** Copies row {@code row} of every input chunk into {@code dest}, one value per chunk. */
    private static void readRowInto(Chunk[] chks, int row, double[] dest) {
        for (int c = 0; c < chks.length; c++) {
            dest[c] = chks[c].atd(row);
        }
    }

    /** Appends {@code values[c]} to output chunk {@code c}, for every output chunk. */
    private static void writeRow(NewChunk[] nc, float[] values) {
        for (int c = 0; c < nc.length; c++) {
            nc[c].addNum((double) values[c]);
        }
    }
}

