/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.Model;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.SharedTreeModel;
import java.util.ArrayList;
import java.util.Arrays;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/**
 * Base class for shared-tree models that can compute per-row feature contributions using TreeSHAP.
 *
 * <p>Only models with at most two classes are supported; multinomial models are rejected by
 * {@link #scoreContributions(Frame, Key, Job)}.
 */
public abstract class SharedTreeModelWithContributions<M extends SharedTreeModel<M, P, O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput>
extends SharedTreeModel<M, P, O>
implements Model.Contributions {

    /**
     * Column type code passed to {@code doAll} for every column of the contributions frame.
     * The values written are numbers (see {@code addNum} in the task below).
     */
    // NOTE(review): 3 is believed to be water.fvec.Vec.T_NUM - confirm and reference it directly.
    private static final byte CONTRIBUTION_COLUMN_TYPE = 3;

    /** Name of the extra, last output column that holds the bias term. */
    private static final String BIAS_TERM_COLUMN = "BiasTerm";

    public SharedTreeModelWithContributions(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Computes feature contributions for every row of {@code frame}, without job progress reporting.
     *
     * @param frame           frame to explain
     * @param destination_key key of the resulting contributions frame
     * @return frame with one contribution column per feature column plus a trailing "BiasTerm" column
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return scoreContributions(frame, destination_key, null);
    }

    /**
     * Computes feature contributions for every row of {@code frame}.
     *
     * <p>The input is adapted to the training layout. Then the response, fold, weights and offset
     * columns are dropped. Each remaining column gets one contribution column in the output, followed
     * by a final "BiasTerm" column.
     *
     * @param frame           frame to explain
     * @param destination_key key of the resulting contributions frame
     * @param j               job used for progress updates after each map call; may be {@code null}
     *                        (the two-argument overload passes {@code null})
     * @return the contributions frame
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j) {
        // O is bounded by SharedTreeOutput, so no cast is needed here.
        if (_output.nclasses() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
        Frame adaptFrm = new Frame(frame);
        // NOTE(review): adaptTestForTrain(..., true, ...) may add adapted/temporary vecs to adaptFrm
        // that are never removed here - confirm whether they need cleanup after scoring.
        adaptTestForTrain(adaptFrm, true, false);
        // Non-feature columns must not reach the trees; they would otherwise get contribution columns too.
        adaptFrm.remove(_parms._response_column);
        adaptFrm.remove(_parms._fold_column);
        adaptFrm.remove(_parms._weights_column);
        adaptFrm.remove(_parms._offset_column);
        String[] outputNames = (String[]) ArrayUtils.append((Object[]) adaptFrm.names(), (Object[]) new String[]{BIAS_TERM_COLUMN});
        ScoreContributionsTask task = getScoreContributionsTask(this);
        return ((ScoreContributionsTask) task
                .withPostMapAction((MRTask.PostMapAction) JobUpdatePostMap.forJob(j))
                .doAll(outputNames.length, CONTRIBUTION_COLUMN_TYPE, adaptFrm))
                .outputFrame(destination_key, outputNames, null);
    }

    /**
     * Creates the (not yet executed) task that computes contributions.
     *
     * @param model model whose trees are explained; {@link #scoreContributions(Frame, Key, Job)} passes {@code this}
     * @return a new contributions task
     */
    protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model);

    /**
     * MRTask that computes TreeSHAP contributions row by row.
     *
     * <p>Each output row contains one value per input column followed by the bias term. Only
     * {@code _modelKey} is non-transient; the model, its output and the TreeSHAP ensemble are
     * rebuilt in {@link #setupLocal()}.
     */
    public class ScoreContributionsTask
    extends MRTask<ScoreContributionsTask> {
        private final Key<SharedTreeModel> _modelKey;
        private transient SharedTreeModel _model;
        private transient SharedTreeModel.SharedTreeOutput _output;
        private transient TreeSHAPPredictor<double[]> _treeSHAP;

        /**
         * @param model model to explain; only its key is stored
         */
        public ScoreContributionsTask(SharedTreeModel model) {
            this._modelKey = model._key;
        }

        /**
         * Loads the model from its key and builds a TreeSHAP ensemble over all of its trees.
         *
         * @throws IllegalStateException if the model or its output cannot be obtained
         */
        @Override
        protected void setupLocal() {
            this._model = (SharedTreeModel) this._modelKey.get();
            // Explicit checks instead of asserts: with assertions disabled these cases would
            // otherwise surface as an unexplained NullPointerException further down.
            if (this._model == null) {
                throw new IllegalStateException("Model " + this._modelKey + " not found; cannot compute contributions.");
            }
            this._output = (SharedTreeModel.SharedTreeOutput) this._model._output;
            if (this._output == null) {
                throw new IllegalStateException("Model " + this._modelKey + " has no output; cannot compute contributions.");
            }
            SharedTreeNode[] empty = new SharedTreeNode[]{};
            ArrayList<TreeSHAP> treeSHAPs = new ArrayList<TreeSHAP>(this._output._ntrees);
            for (int treeIdx = 0; treeIdx < this._output._ntrees; ++treeIdx) {
                for (int treeClass = 0; treeClass < this._output._treeKeys[treeIdx].length; ++treeClass) {
                    // Slots without a tree key are skipped.
                    if (this._output._treeKeys[treeIdx][treeClass] == null) {
                        continue;
                    }
                    SharedTreeSubgraph tree = this._model.getSharedTreeSubgraph(treeIdx, treeClass);
                    SharedTreeNode[] nodes = tree.nodesArray.toArray(empty);
                    // The same node array supplies both the tree structure and the node statistics.
                    treeSHAPs.add(new TreeSHAP((INode[]) nodes, (INodeStat[]) nodes, 0));
                }
            }
            // Exactly one tree per iteration is expected; scoreContributions rejects models with
            // more than two classes.
            assert treeSHAPs.size() == this._output._ntrees;
            this._treeSHAP = new TreeSHAPEnsemble(treeSHAPs, (float) this._output._init_f);
        }

        /**
         * Computes contributions for each row of the chunk and appends them to {@code nc}.
         *
         * @param chks input columns (features only)
         * @param nc   output columns: one per input column plus the trailing bias-term column
         */
        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            // The extra output column holds the bias term.
            assert chks.length == nc.length - 1;
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(nc.length);
            // Workspace and buffers are allocated once per chunk and reused for every row.
            Object workspace = this._treeSHAP.makeWorkspace();
            int rows = chks[0]._len;
            for (int row = 0; row < rows; ++row) {
                for (int col = 0; col < chks.length; ++col) {
                    input[col] = chks[col].atd(row);
                }
                // The contribution buffer is reused across rows, so clear it first.
                Arrays.fill(contribs, 0.0f);
                // NOTE(review): 0 / -1 presumably mean "no conditioning" - confirm against TreeSHAPPredictor.
                this._treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                addContribToNewChunk(contribs, nc);
            }
        }

        /**
         * Appends one row of contributions: {@code contribs[i]} is written to output column {@code i}.
         * Subclasses may override this to post-process the values before they are stored.
         *
         * @param contribs contributions for the current row, one per output column
         * @param nc       output columns
         */
        protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            for (int i = 0; i < nc.length; ++i) {
                nc[i].addNum((double) contribs[i]);
            }
        }
    }
}

