/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.util.FVec;
import hex.ContributionsWithBackgroundFrameTask;
import hex.DataInfo;
import hex.Distribution;
import hex.DistributionFactory;
import hex.Model;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import java.util.Arrays;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;

public class PredictTreeSHAPWithBackgroundTask
extends ContributionsWithBackgroundFrameTask<PredictTreeSHAPWithBackgroundTask> {
    protected final DataInfo _di;
    protected final XGBoostModelInfo _modelInfo;
    protected final XGBoostOutput _output;
    protected final boolean _outputAggregated;
    protected final boolean _outputSpace;
    protected final Distribution _distribution;
    protected transient XGBoostJavaMojoModel _mojo;

    public PredictTreeSHAPWithBackgroundTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output, Model.Contributions.ContributionsOptions options, Frame frame, Frame backgroundFrame, boolean perReference, boolean outputSpace) {
        super(frame._key, backgroundFrame._key, perReference);
        this._di = di;
        this._modelInfo = modelInfo;
        this._output = output;
        this._outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals((Object)options._outputFormat);
        this._outputSpace = outputSpace;
        this._distribution = outputSpace ? (this._modelInfo._parameters.getDistributionFamily().equals((Object)DistributionFamily.AUTO) && this._output.isBinomialClassifier() ? DistributionFactory.getDistribution((DistributionFamily)DistributionFamily.bernoulli) : DistributionFactory.getDistribution((Model.Parameters)this._modelInfo._parameters)) : null;
    }

    protected void setupLocal() {
        this._mojo = new XGBoostJavaMojoModel(this._modelInfo._boosterBytes, this._modelInfo.auxNodeWeightBytes(), this._output._names, this._output._domains, this._output.responseName(), true);
    }

    protected void fillInput(Chunk[] chks, int row, double[] input) {
        for (int i = 0; i < chks.length; ++i) {
            input[i] = chks[i].atd(row);
        }
    }

    protected void addContribToNewChunk(double[] contribs, NewChunk[] nc) {
        double transformationRatio = 1.0;
        double biasTerm = contribs[contribs.length - 1];
        if (this._outputSpace) {
            double linkSpaceX = Arrays.stream(contribs).sum();
            double linkSpaceBg = biasTerm;
            double outSpaceX = this._distribution.linkInv(linkSpaceX);
            double outSpaceBg = this._distribution.linkInv(linkSpaceBg);
            transformationRatio = Math.abs(linkSpaceX - linkSpaceBg) < 1.0E-6 ? 0.0 : (outSpaceX - outSpaceBg) / (linkSpaceX - linkSpaceBg);
            biasTerm = outSpaceBg;
        }
        for (int i = 0; i < nc.length - 1; ++i) {
            nc[i].addNum(contribs[i] * transformationRatio);
        }
        nc[nc.length - 1].addNum(biasTerm);
    }

    protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs) {
        MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        MutableOneHotEncoderFVec rowFBgVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        double[] input = MemoryManager.malloc8d((int)cs.length);
        double[] inputBg = MemoryManager.malloc8d((int)cs.length);
        double[] contribs = MemoryManager.malloc8d((int)(this._outputAggregated ? ncs.length : this._di.fullN() + 1));
        for (int row = 0; row < cs[0]._len; ++row) {
            this.fillInput(cs, row, input);
            rowFVec.setInput(input);
            for (int bgRow = 0; bgRow < bgCs[0]._len; ++bgRow) {
                Arrays.fill(contribs, 0.0);
                this.fillInput(bgCs, bgRow, inputBg);
                rowFBgVec.setInput(inputBg);
                this._mojo.calculateInterventionalContributions((FVec)rowFVec, (FVec)rowFBgVec, contribs, this._outputAggregated ? this._di._catOffsets : null, false);
                this.addContribToNewChunk(contribs, ncs);
            }
        }
    }
}

