/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.genmodel.PredictContributions;
import hex.genmodel.PredictContributionsFactory;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import java.util.ArrayList;

/**
 * Base class for shared-tree MOJO models that can produce per-feature
 * prediction contributions through {@link TreeSHAPPredictor}s.
 *
 * <p>The class assembles one {@link TreeSHAP} per subgraph of the model's
 * tree graph, wraps them in a {@link TreeSHAPEnsemble}, and delegates the
 * creation of the final {@link ContributionsPredictor} to the subclass via
 * {@link #getContributionsPredictor(TreeSHAPPredictor)}.
 */
public abstract class SharedTreeMojoModelWithContributions
extends SharedTreeMojoModel
implements PredictContributionsFactory {
    /**
     * Passes all arguments unchanged to the {@link SharedTreeMojoModel} constructor.
     *
     * @param columns        forwarded to the superclass
     * @param domains        forwarded to the superclass
     * @param responseColumn forwarded to the superclass
     */
    protected SharedTreeMojoModelWithContributions(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Creates a contributions predictor backed by an ensemble of {@link TreeSHAP}
     * instances, one for each subgraph of the graph returned by {@code computeGraph(-1)}.
     *
     * @return the predictor produced by {@link #getContributionsPredictor(TreeSHAPPredictor)}
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    @Override
    public PredictContributions makeContributionsPredictor() {
        if (this._nclasses > 2) {
            throw new UnsupportedOperationException("Predicting contributions for multinomial classification problems is not yet supported.");
        }
        // NOTE(review): -1 presumably asks computeGraph for all trees - confirm against SharedTreeMojoModel.
        SharedTreeGraph graph = this.computeGraph(-1);
        // Zero-length prototype that only tells toArray which array type to allocate.
        SharedTreeNode[] empty = new SharedTreeNode[0];
        // NOTE(review): the type arguments below assume TreeSHAP<R, N extends INode<R>, S extends INodeStat>
        // implements TreeSHAPPredictor<R>, and that TreeSHAPEnsemble accepts a collection of
        // TreeSHAPPredictor<R> - confirm against those classes.
        ArrayList<TreeSHAPPredictor<double[]>> treeSHAPs =
                new ArrayList<TreeSHAPPredictor<double[]>>(graph.subgraphArray.size());
        for (SharedTreeSubgraph tree : graph.subgraphArray) {
            SharedTreeNode[] nodes = tree.nodesArray.toArray(empty);
            // The same node array is supplied as both the node and the node-stat argument;
            // the trailing 0 is presumably the root node index - TODO confirm.
            treeSHAPs.add(new TreeSHAP<double[], SharedTreeNode, SharedTreeNode>(nodes, nodes, 0));
        }
        TreeSHAPPredictor<double[]> predictor =
                new TreeSHAPEnsemble<double[]>(treeSHAPs, (float) this.getInitF());
        return this.getContributionsPredictor(predictor);
    }

    /**
     * Returns the value handed (narrowed to {@code float}) to the {@link TreeSHAPEnsemble}
     * built by {@link #makeContributionsPredictor()}.
     *
     * @return {@code 0.0} in this base implementation
     */
    @Override
    public double getInitF() {
        return 0.0;
    }

    /**
     * Wraps the assembled TreeSHAP predictor into the model-specific contributions predictor.
     *
     * @param var1 predictor covering the model's trees, as built by {@link #makeContributionsPredictor()}
     * @return the contributions predictor exposed to callers
     */
    protected abstract ContributionsPredictor getContributionsPredictor(TreeSHAPPredictor<double[]> var1);

    /**
     * {@link PredictContributions} implementation that delegates to a {@link TreeSHAPPredictor}
     * and lets subclasses post-process the raw result through {@link #getContribs(float[])}.
     */
    protected static class ContributionsPredictor
    implements PredictContributions {
        private final int _nfeatures;
        private final TreeSHAPPredictor<double[]> _treeSHAPPredictor;
        // Deliberately per-instance rather than static: the cached workspace is produced by this
        // instance's own predictor, so it must never be handed to a different predictor that
        // happens to run on the same thread.
        private final ThreadLocal<Object> _workspace = new ThreadLocal<Object>();

        /**
         * @param model             model whose {@code _nfeatures} determines the size of the output array
         * @param treeSHAPPredictor predictor that performs the actual contribution calculation
         */
        public ContributionsPredictor(SharedTreeMojoModel model, TreeSHAPPredictor<double[]> treeSHAPPredictor) {
            this._nfeatures = model._nfeatures;
            this._treeSHAPPredictor = treeSHAPPredictor;
        }

        /**
         * Computes contributions for a single input row.
         *
         * @param input the row passed through to the underlying {@link TreeSHAPPredictor}
         * @return the result of {@link #getContribs(float[])} applied to a freshly allocated
         *         array of length {@code _nfeatures + 1} filled by the predictor
         */
        @Override
        public final float[] calculateContributions(double[] input) {
            // One slot more than the number of features; presumably the last one holds the
            // bias term - TODO confirm against TreeSHAPPredictor.
            float[] contribs = new float[this._nfeatures + 1];
            this._treeSHAPPredictor.calculateContributions(input, contribs, 0, -1, this.getWorkspace());
            return this.getContribs(contribs);
        }

        /**
         * Returns the calling thread's workspace for this predictor, creating it on first use.
         */
        private Object getWorkspace() {
            Object workspace = this._workspace.get();
            if (workspace == null) {
                workspace = this._treeSHAPPredictor.makeWorkspace();
                this._workspace.set(workspace);
            }
            return workspace;
        }

        /**
         * Hook for subclasses to transform the raw contributions before they are returned.
         *
         * @param contribs array filled in by the underlying predictor
         * @return the same array, unchanged, in this base implementation
         */
        public float[] getContribs(float[] contribs) {
            return contribs;
        }
    }
}

