/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import ai.h2o.a.a.a;
import ai.h2o.a.a.b;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import java.io.Serializable;

/**
 * Computes per-feature SHAP contributions for a single decision tree.
 *
 * <p>The structure follows the path-dependent Tree SHAP algorithm (Lundberg et al.).
 * The tree is walked once, maintaining a "unique path" of the features split on so far
 * together with permutation weights. Contributions are credited to those features at each leaf.
 *
 * @param <R> type of the input row; it is passed to {@code N.next(R)} to pick the branch the row follows
 * @param <N> tree node type
 * @param <S> per-node statistics type; {@code getWeight()} is used as the node's cover/weight
 */
public class TreeSHAP<R, N extends a<R>, S extends b>
implements TreeSHAPPredictor<R> {
    private final int rootNodeId;
    private final N[] nodes;
    private final S[] stats;
    /** Cover-weighted mean of the leaf values of the tree rooted at {@code rootNodeId}. */
    private final float expectedTreeValue;

    /**
     * @param nodes      tree nodes; child indices reported by the nodes index into this array
     * @param stats      per-node statistics, indexed in parallel with {@code nodes}
     * @param rootNodeId index of the root node in {@code nodes} / {@code stats}
     */
    public TreeSHAP(N[] nodes, S[] stats, int rootNodeId) {
        this.rootNodeId = rootNodeId;
        this.nodes = nodes;
        this.stats = stats;
        // Must run after the fields above are assigned: the mean is computed from them.
        this.expectedTreeValue = this.treeMeanValue();
    }

    /**
     * Appends an element at {@code uniqueDepth} to the path and updates the permutation
     * weights of the existing elements {@code 0..uniqueDepth-1}.
     */
    private void extendPath(PathPointer uniquePath, int uniqueDepth, float zeroFraction, float oneFraction, int featureIndex) {
        final PathElement added = uniquePath.get(uniqueDepth);
        added.feature_index = featureIndex;
        added.zero_fraction = zeroFraction;
        added.one_fraction = oneFraction;
        added.pweight = uniqueDepth == 0 ? 1.0f : 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            uniquePath.get(i + 1).pweight += oneFraction * uniquePath.get(i).pweight * (float)(i + 1) / (float)(uniqueDepth + 1);
            uniquePath.get(i).pweight = zeroFraction * uniquePath.get(i).pweight * (float)(uniqueDepth - i) / (float)(uniqueDepth + 1);
        }
    }

    /**
     * Undoes a previous {@link #extendPath} of the element at {@code pathIndex}.
     * It restores the permutation weights, then shifts the later elements one slot left.
     */
    private void unwindPath(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        final float oneFraction = uniquePath.get(pathIndex).one_fraction;
        final float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            final PathElement el = uniquePath.get(i);
            if (oneFraction != 0.0f) {
                final float previous = el.pweight;
                el.pweight = nextOnePortion * (float)(uniqueDepth + 1) / ((float)(i + 1) * oneFraction);
                nextOnePortion = previous - el.pweight * zeroFraction * (float)(uniqueDepth - i) / (float)(uniqueDepth + 1);
            } else {
                el.pweight = el.pweight * (float)(uniqueDepth + 1) / (zeroFraction * (float)(uniqueDepth - i));
            }
        }
        // Remove the element at pathIndex by shifting the feature data left.
        // The pweights were already rewritten in place by the loop above.
        for (int i = pathIndex; i < uniqueDepth; ++i) {
            final PathElement dst = uniquePath.get(i);
            final PathElement src = uniquePath.get(i + 1);
            dst.feature_index = src.feature_index;
            dst.zero_fraction = src.zero_fraction;
            dst.one_fraction = src.one_fraction;
        }
    }

    /**
     * Returns the total permutation weight the path would have if the element at
     * {@code pathIndex} were unwound, without modifying the path.
     *
     * @throws IllegalStateException if both fractions of that element are zero but a
     *                               remaining weight is non-zero
     */
    private float unwoundPathSum(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        final float oneFraction = uniquePath.get(pathIndex).one_fraction;
        final float zeroFraction = uniquePath.get(pathIndex).zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        float total = 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            final float pweight = uniquePath.get(i).pweight;
            if (oneFraction != 0.0f) {
                final float tmp = nextOnePortion * (float)(uniqueDepth + 1) / ((float)(i + 1) * oneFraction);
                total += tmp;
                nextOnePortion = pweight - tmp * zeroFraction * ((float)(uniqueDepth - i) / (float)(uniqueDepth + 1));
            } else if (zeroFraction != 0.0f) {
                total += pweight / zeroFraction / ((float)(uniqueDepth - i) / (float)(uniqueDepth + 1));
            } else if (pweight != 0.0f) {
                throw new IllegalStateException("Unique path " + i + " must have zero getWeight");
            }
        }
        return total;
    }

    /**
     * Recursively walks the subtree rooted at {@code node} and adds each feature's
     * contribution into {@code phi}.
     *
     * <p>The parent's path is copied into a fresh workspace segment ({@link PathPointer#move}),
     * so sibling recursions do not see each other's path changes.
     */
    private void treeShap(R feat, float[] phi, N node, S nodeStat, int uniqueDepth, PathPointer parentUniquePath,
                          float parentZeroFraction, float parentOneFraction, int parentFeatureIndex,
                          int condition, int conditionFeature, float conditionFraction) {
        // No weight flows into this subtree, so there is nothing to attribute.
        if (conditionFraction == 0.0f) {
            return;
        }
        final PathPointer uniquePath = parentUniquePath.move(uniqueDepth);
        // When conditioning, splits on the conditioned feature are not added to the path.
        if (condition == 0 || conditionFeature != parentFeatureIndex) {
            this.extendPath(uniquePath, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex);
        }
        final int splitIndex = node.getSplitIndex();

        if (node.isLeaf()) {
            // Element 0 is the root placeholder (feature index -1), so start at 1.
            for (int i = 1; i <= uniqueDepth; ++i) {
                final float w = this.unwoundPathSum(uniquePath, uniqueDepth, i);
                final PathElement el = uniquePath.get(i);
                phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction) * node.getLeafValue() * conditionFraction;
            }
            return;
        }

        // "Hot" is the branch the input row follows; "cold" is the other one.
        final int hotIndex = node.next(feat);
        final int coldIndex = hotIndex == node.getLeftChildIndex() ? node.getRightChildIndex() : node.getLeftChildIndex();
        final float weight = nodeStat.getWeight();
        final float hotZeroFraction = this.stats[hotIndex].getWeight() / weight;
        final float coldZeroFraction = this.stats[coldIndex].getWeight() / weight;
        float incomingZeroFraction = 1.0f;
        float incomingOneFraction = 1.0f;
        int depth = uniqueDepth;

        // If this feature was already split on higher up, undo that split so it can be
        // redone here with the combined fractions.
        int pathIndex = 0;
        while (pathIndex <= depth && uniquePath.get(pathIndex).feature_index != splitIndex) {
            ++pathIndex;
        }
        if (pathIndex != depth + 1) {
            incomingZeroFraction = uniquePath.get(pathIndex).zero_fraction;
            incomingOneFraction = uniquePath.get(pathIndex).one_fraction;
            this.unwindPath(uniquePath, depth, pathIndex);
            --depth;
        }

        // Divide the condition fraction between the two recursive calls.
        float hotConditionFraction = conditionFraction;
        float coldConditionFraction = conditionFraction;
        if (condition > 0 && splitIndex == conditionFeature) {
            coldConditionFraction = 0.0f;
            --depth;
        } else if (condition < 0 && splitIndex == conditionFeature) {
            hotConditionFraction = conditionFraction * hotZeroFraction;
            coldConditionFraction = conditionFraction * coldZeroFraction;
            --depth;
        }

        this.treeShap(feat, phi, this.nodes[hotIndex], this.stats[hotIndex], depth + 1, uniquePath,
                hotZeroFraction * incomingZeroFraction, incomingOneFraction,
                splitIndex, condition, conditionFeature, hotConditionFraction);
        this.treeShap(feat, phi, this.nodes[coldIndex], this.stats[coldIndex], depth + 1, uniquePath,
                coldZeroFraction * incomingZeroFraction, 0.0f,
                splitIndex, condition, conditionFeature, coldConditionFraction);
    }

    @Override
    public float[] calculateContributions(R feat, float[] out_contribs) {
        return this.calculateContributions(feat, out_contribs, 0, -1, this.makeWorkspace());
    }

    /**
     * Adds this tree's SHAP contributions for {@code feat} into {@code out_contribs}.
     *
     * @param feat              input row
     * @param out_contribs      accumulator indexed by feature index; the last slot is the bias term
     * @param condition         0 for unconditional contributions. If positive, only the branch the
     *                          row follows at {@code condition_feature} splits receives weight. If
     *                          negative, both branches receive weight scaled by their cover fractions.
     * @param condition_feature feature index that {@code condition} applies to
     * @param workspace         scratch space obtained from {@link #makeWorkspace()}
     * @return {@code out_contribs}
     */
    @Override
    public float[] calculateContributions(R feat, float[] out_contribs, int condition, int condition_feature, Object workspace) {
        if (condition == 0) {
            // The last slot receives the bias: the tree's cover-weighted mean prediction.
            out_contribs[out_contribs.length - 1] += this.expectedTreeValue;
        }
        final PathPointer uniquePath = (PathPointer) workspace;
        uniquePath.reset();
        this.treeShap(feat, out_contribs, this.nodes[this.rootNodeId], this.stats[this.rootNodeId], 0, uniquePath,
                1.0f, 1.0f, -1, condition, condition_feature, 1.0f);
        return out_contribs;
    }

    /** Allocates a fresh scratch path of {@link #getWorkspaceSize()} elements. */
    @Override
    public PathPointer makeWorkspace() {
        final int size = this.getWorkspaceSize();
        final PathElement[] elements = new PathElement[size];
        for (int i = 0; i < size; ++i) {
            elements[i] = new PathElement();
        }
        return new PathPointer(elements);
    }

    /**
     * Number of path elements needed for one contribution computation.
     *
     * <p>Each recursion level stores its copy of the path in the next segment of the array,
     * so the size is the triangular number of (tree depth + 2).
     */
    @Override
    public int getWorkspaceSize() {
        final int maxDepth = this.treeDepth() + 2;
        return maxDepth * (maxDepth + 1) / 2;
    }

    /** Depth of the tree rooted at {@code rootNodeId}; a single leaf has depth 1. */
    private int treeDepth() {
        // Start at the configured root, the same node calculateContributions starts from.
        return this.nodeDepth(this.rootNodeId);
    }

    private int nodeDepth(int nodeId) {
        final N node = this.nodes[nodeId];
        if (node.isLeaf()) {
            return 1;
        }
        return 1 + Math.max(this.nodeDepth(node.getLeftChildIndex()), this.nodeDepth(node.getRightChildIndex()));
    }

    /** Cover-weighted mean leaf value of the tree rooted at {@code rootNodeId}. */
    private float treeMeanValue() {
        // Start at the configured root, the same node calculateContributions starts from.
        return this.nodeMeanValue(this.rootNodeId);
    }

    private float nodeMeanValue(int nodeId) {
        final N node = this.nodes[nodeId];
        if (node.isLeaf()) {
            return node.getLeafValue();
        }
        final int left = node.getLeftChildIndex();
        final int right = node.getRightChildIndex();
        return (this.stats[left].getWeight() * this.nodeMeanValue(left)
                + this.stats[right].getWeight() * this.nodeMeanValue(right)) / this.stats[nodeId].getWeight();
    }

    /**
     * A view into a shared {@link PathElement} array starting at {@code position}.
     * Each view is one recursion level's copy of the decision path.
     */
    public static class PathPointer {
        PathElement[] path;
        int position;

        PathPointer(PathElement[] path) {
            this.path = path;
        }

        PathPointer(PathElement[] path, int position) {
            this.path = path;
            this.position = position;
        }

        /** Returns the {@code i}-th element of this view. */
        PathElement get(int i) {
            return this.path[this.position + i];
        }

        /**
         * Copies the first {@code len} elements of this view into the slots directly after them.
         *
         * @return a new view positioned at the start of the copy
         */
        PathPointer move(int len) {
            for (int i = 0; i < len; ++i) {
                final PathElement src = this.path[this.position + i];
                final PathElement dst = this.path[this.position + len + i];
                dst.feature_index = src.feature_index;
                dst.zero_fraction = src.zero_fraction;
                dst.one_fraction = src.one_fraction;
                dst.pweight = src.pweight;
            }
            return new PathPointer(this.path, this.position + len);
        }

        /** Clears the first element of the underlying array; the rest is not touched. */
        void reset() {
            this.path[0].reset();
        }
    }

    /** One entry of the decision path: the feature split on, its fractions and permutation weight. */
    private static class PathElement
    implements Serializable {
        int feature_index;
        float zero_fraction;
        float one_fraction;
        float pweight;

        private PathElement() {
        }

        void reset() {
            this.feature_index = 0;
            this.zero_fraction = 0.0f;
            this.one_fraction = 0.0f;
            this.pweight = 0.0f;
        }
    }
}

