// Code in this file started as a 1-1 conversion of the native XGBoost implementation in C++ to Java
// please see:
//    https://github.com/dmlc/xgboost/blob/master/src/tree/tree_model.cc
// All credit for this implementation goes to XGBoost Contributors:
//    https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
// Licensed under Apache License Version 2.0
package biz.k11i.xgboost.tree;

import biz.k11i.xgboost.util.FVec;

/**
 * Computes per-feature contribution values for a single regression tree by walking the tree
 * recursively and maintaining a weighted "unique path" of the features split on so far
 * (a port of the TreeSHAP routine from XGBoost's {@code tree_model.cc}).
 *
 * <p>Warning: Experimental (alpha-level) code - needs testing for correctness &amp; benchmarking!
 *
 * <p>NOT thread safe: the {@link PathElement} workspace is allocated once in the constructor and
 * is mutated by every call to {@link #calculateContributions}, so that it does not have to be
 * re-allocated for each observation.
 */
public class TreeSHAP {

  private final RegTreeImpl.Node[] nodes;
  private final RegTreeImpl.RTreeNodeStat[] stats;
  /** Hessian-weighted mean of the leaf values, computed once from node 0. */
  private final float expectedTreeValue;

  /** Shared scratch space holding the unique paths of all active recursion levels. */
  private final PathElement[] unique_path_data; // shared state

  /**
   * Captures the node and statistics arrays of the tree and preallocates the path workspace.
   *
   * @param tree tree whose nodes and per-node statistics are used for all later calls
   */
  public TreeSHAP(RegTreeImpl tree) {
    this.nodes = tree.getNodes();
    this.stats = tree.getStats();
    this.expectedTreeValue = treeMeanValue();

    // Preallocate space for the unique path data. Each recursion level at depth d places its
    // path d + 1 slots (at most) after its parent's, so a triangular number of slots is needed.
    final int maxd = treeDepth() + 2;
    this.unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
    for (int i = 0; i < unique_path_data.length; i++) {
      this.unique_path_data[i] = new PathElement();
    }
  }

  /** One entry of the unique path: a feature together with its fractions and path weight. */
  private static final class PathElement {
    int feature_index;
    float zero_fraction;
    float one_fraction;
    float pweight;

    /** Sets all fields back to zero. */
    void reset() {
      feature_index = 0;
      zero_fraction = 0;
      one_fraction = 0;
      pweight = 0;
    }
  }

  /**
   * Extends the decision path with a fraction of one and zero extensions.
   *
   * <p>Writes the new element at index {@code unique_depth} and re-weights the {@code pweight}
   * of every element in {@code [0, unique_depth]}.
   *
   * @param unique_path   path to extend (modified in place)
   * @param unique_depth  index at which the new element is stored
   * @param zero_fraction zero fraction stored in the new element
   * @param one_fraction  one fraction stored in the new element
   * @param feature_index feature stored in the new element
   */
  private void extendPath(PathPointer unique_path, int unique_depth,
                          float zero_fraction, float one_fraction,
                          int feature_index) {
    final PathElement added = unique_path.get(unique_depth);
    added.feature_index = feature_index;
    added.zero_fraction = zero_fraction;
    added.one_fraction = one_fraction;
    added.pweight = (unique_depth == 0 ? 1.0f : 0.0f);
    // walk backwards so that element i is still un-updated when it feeds element i + 1
    for (int i = unique_depth - 1; i >= 0; i--) {
      final PathElement cur = unique_path.get(i);
      final PathElement next = unique_path.get(i + 1);
      next.pweight += one_fraction * cur.pweight * (i + 1)
              / (float) (unique_depth + 1);
      cur.pweight = zero_fraction * cur.pweight * (unique_depth - i)
              / (float) (unique_depth + 1);
    }
  }

  /**
   * Undoes a previous extension of the decision path.
   *
   * <p>Reverts the weight update made for the element at {@code path_index} and then shifts the
   * feature/fraction data of the following elements down by one slot. The caller is expected to
   * decrement its own {@code unique_depth} afterwards.
   *
   * @param unique_path  path to modify in place
   * @param unique_depth index of the last valid element of the path
   * @param path_index   index of the element to remove
   */
  private void unwindPath(PathPointer unique_path, int unique_depth,
                          int path_index) {
    final float one_fraction = unique_path.get(path_index).one_fraction;
    final float zero_fraction = unique_path.get(path_index).zero_fraction;
    float next_one_portion = unique_path.get(unique_depth).pweight;

    for (int i = unique_depth - 1; i >= 0; --i) {
      final PathElement cur = unique_path.get(i);
      if (one_fraction != 0) {
        final float tmp = cur.pweight;
        cur.pweight = next_one_portion * (unique_depth + 1)
                / ((i + 1) * one_fraction);
        next_one_portion = tmp - cur.pweight * zero_fraction * (unique_depth - i)
                / (float) (unique_depth + 1);
      } else {
        cur.pweight = (cur.pweight * (unique_depth + 1))
                / (zero_fraction * (unique_depth - i));
      }
    }

    // close the gap left by the removed element (pweight is intentionally not shifted)
    for (int i = path_index; i < unique_depth; ++i) {
      final PathElement dst = unique_path.get(i);
      final PathElement src = unique_path.get(i + 1);
      dst.feature_index = src.feature_index;
      dst.zero_fraction = src.zero_fraction;
      dst.one_fraction = src.one_fraction;
    }
  }

  /**
   * Determines what the total permutation weight would be if a previous extension in the
   * decision path were unwound; the path itself is not modified.
   *
   * @param unique_path  path to inspect
   * @param unique_depth index of the last valid element of the path
   * @param path_index   index of the element whose removal is simulated
   * @return sum of the path weights that would result from unwinding {@code path_index}
   * @throws IllegalStateException if both fractions of the element are zero but some path
   *                               element still carries a non-zero weight
   */
  private float unwoundPathSum(final PathPointer unique_path, int unique_depth,
                               int path_index) {
    final float one_fraction = unique_path.get(path_index).one_fraction;
    final float zero_fraction = unique_path.get(path_index).zero_fraction;
    float next_one_portion = unique_path.get(unique_depth).pweight;
    float total = 0;
    for (int i = unique_depth - 1; i >= 0; --i) {
      if (one_fraction != 0) {
        final float tmp = next_one_portion * (unique_depth + 1)
                / ((i + 1) * one_fraction);
        total += tmp;
        next_one_portion = unique_path.get(i).pweight - tmp * zero_fraction * ((unique_depth - i)
                / (float)(unique_depth + 1));
      } else if (zero_fraction != 0) {
        total += (unique_path.get(i).pweight / zero_fraction) / ((unique_depth - i)
                / (float)(unique_depth + 1));
      } else {
        if (unique_path.get(i).pweight != 0) {
          throw new IllegalStateException("Unique path " + i + " must have zero weight");
        }
      }
    }
    return total;
  }

  /**
   * Recursive computation of SHAP values for a decision tree.
   *
   * @param feat                 feature vector of the observation being explained
   * @param phi                  output array, indexed by feature, that contributions are added to
   * @param node_index           node currently being visited
   * @param unique_depth         index at which this node's path element belongs
   * @param parent_unique_path   path of the caller; copied, never modified by this call
   * @param parent_zero_fraction zero fraction of the edge leading to this node
   * @param parent_one_fraction  one fraction of the edge leading to this node
   * @param parent_feature_index feature the parent split on ({@code -1} for the root call)
   * @param condition            {@code 0} for no conditioning; a positive value follows only the
   *                             hot branch at splits on {@code condition_feature}; a negative
   *                             value weights both branches by their zero fractions there
   * @param condition_feature    feature the computation is conditioned on
   * @param condition_fraction   weight reaching this node under the conditioning
   */
  private void treeShap(final FVec feat, float[] phi,
                        int node_index, int unique_depth,
                        PathPointer parent_unique_path,
                        float parent_zero_fraction,
                        float parent_one_fraction, int parent_feature_index,
                        int condition, int condition_feature,
                        float condition_fraction) {

    // stop if we have no weight coming down to us
    if (condition_fraction == 0) {
      return;
    }

    final RegTreeImpl.Node node = nodes[node_index];

    // Extend the unique path. unique_depth + 1 elements (indices 0..unique_depth) are copied to
    // a slice starting unique_depth + 1 slots after the parent's. Copying one element fewer
    // would leave index unique_depth uninitialised whenever extendPath is skipped below, and
    // would make the slice overlap the parent's live elements in that same case.
    PathPointer unique_path = parent_unique_path.move(unique_depth + 1);

    if (condition == 0 || condition_feature != parent_feature_index) {
      extendPath(unique_path, unique_depth, parent_zero_fraction,
              parent_one_fraction, parent_feature_index);
    }
    final int split_index = node.split_index();

    // leaf node
    if (node.is_leaf()) {
      // index 0 is skipped: it holds the root call's placeholder element (feature -1)
      for (int i = 1; i <= unique_depth; ++i) {
        final float w = unwoundPathSum(unique_path, unique_depth, i);
        final PathElement el = unique_path.get(i);
        phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction)
                * node.leaf_value * condition_fraction;
      }

      // internal node
    } else {
      // find which branch is "hot" (meaning x would follow it)
      final float fvalue = feat.fvalue(split_index);
      final int hot_index;
      if (Float.isNaN(fvalue)) {
        hot_index = node.cdefault();
      } else if (fvalue < node.split_cond) {
        hot_index = node.cleft_;
      } else {
        hot_index = node.cright_;
      }
      final int cold_index = hot_index == node.cleft_ ? node.cright_ : node.cleft_;
      final float w = stats[node_index].sum_hess;
      final float hot_zero_fraction = stats[hot_index].sum_hess / w;
      final float cold_zero_fraction = stats[cold_index].sum_hess / w;
      float incoming_zero_fraction = 1;
      float incoming_one_fraction = 1;

      // see if we have already split on this feature,
      // if so we undo that split so we can redo it for this node
      int path_index = 0;
      for (; path_index <= unique_depth; ++path_index) {
        if (unique_path.get(path_index).feature_index == split_index) {
          break;
        }
      }
      if (path_index != unique_depth + 1) {
        incoming_zero_fraction = unique_path.get(path_index).zero_fraction;
        incoming_one_fraction = unique_path.get(path_index).one_fraction;
        unwindPath(unique_path, unique_depth, path_index);
        unique_depth -= 1;
      }

      // divide up the condition_fraction among the recursive calls
      float hot_condition_fraction = condition_fraction;
      float cold_condition_fraction = condition_fraction;
      if (condition > 0 && split_index == condition_feature) {
        cold_condition_fraction = 0;
        unique_depth -= 1;
      } else if (condition < 0 && split_index == condition_feature) {
        hot_condition_fraction *= hot_zero_fraction;
        cold_condition_fraction *= cold_zero_fraction;
        unique_depth -= 1;
      }

      treeShap(feat, phi, hot_index, unique_depth + 1, unique_path,
              hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
              split_index, condition, condition_feature, hot_condition_fraction);

      treeShap(feat, phi, cold_index, unique_depth + 1, unique_path,
              cold_zero_fraction * incoming_zero_fraction, 0,
              split_index, condition, condition_feature, cold_condition_fraction);
    }
  }

  /** A view into the shared {@link PathElement} array that starts at a fixed offset. */
  private static final class PathPointer {
    final PathElement[] path;
    final int position;

    PathPointer(PathElement[] path) {
      this(path, 0);
    }

    PathPointer(PathElement[] path, int position) {
      this.path = path;
      this.position = position;
    }

    /** Returns the element {@code i} slots after this view's start. */
    PathElement get(int i) {
      return path[position + i];
    }

    /**
     * Copies the first {@code len} elements of this view into the {@code len} slots directly
     * following them and returns a view that starts at the copy.
     *
     * @param len number of elements to copy, and also the offset of the returned view
     * @return view positioned {@code len} slots after this one
     */
    PathPointer move(int len) {
      for (int i = 0; i < len; i++) {
        final PathElement src = path[position + i];
        final PathElement dst = path[position + len + i];
        dst.feature_index = src.feature_index;
        dst.zero_fraction = src.zero_fraction;
        dst.one_fraction = src.one_fraction;
        dst.pweight = src.pweight;
      }
      return new PathPointer(path, position + len);
    }
  }

  /**
   * Adds this tree's contributions for one observation to {@code out_contribs}.
   *
   * <p>Values are accumulated (added), so the caller decides whether the array starts zeroed.
   * When {@code condition == 0} the tree's expected value is added to the last element.
   *
   * @param feat              feature vector of the observation
   * @param root_id           index of the node to start the traversal from
   * @param out_contribs      output array indexed by feature; its last slot receives the
   *                          expected tree value
   * @param condition         {@code 0} for no conditioning, otherwise a positive or negative
   *                          value selecting how splits on {@code condition_feature} are treated
   * @param condition_feature feature to condition on; only consulted when
   *                          {@code condition != 0}
   */
  public void calculateContributions(final FVec feat,
                                     int root_id, float[] out_contribs,
                                     int condition,
                                     int condition_feature) {

    // find the expected value of the tree's predictions
    if (condition == 0) {
      out_contribs[out_contribs.length - 1] += expectedTreeValue;
    }

    PathPointer uniquePathWorkspace = getWorkspace();

    treeShap(feat, out_contribs, root_id, 0, uniquePathWorkspace,
            1, 1, -1, condition, condition_feature, 1);
  }

  /** Resets the first workspace element and returns a view starting at offset 0. */
  private PathPointer getWorkspace() {
    this.unique_path_data[0].reset();
    return new PathPointer(this.unique_path_data);
  }

  /** Returns the number of nodes on the longest path from node 0 to a leaf. */
  private int treeDepth() {
    return nodeDepth(nodes, 0);
  }

  /** Returns the number of nodes on the longest path from {@code node} down to a leaf. */
  private static int nodeDepth(RegTreeImpl.Node[] nodes, int node) {
    final RegTreeImpl.Node n = nodes[node];
    if (n.is_leaf()) {
      return 1;
    } else {
      return 1 + Math.max(nodeDepth(nodes, n.cleft_), nodeDepth(nodes, n.cright_));
    }
  }

  /** Returns the hessian-weighted mean leaf value of the tree rooted at node 0. */
  private float treeMeanValue() {
    return nodeMeanValue(nodes, stats, 0);
  }

  /**
   * Returns the leaf value for a leaf, otherwise the mean of the two children's values weighted
   * by each child's {@code sum_hess} and divided by this node's {@code sum_hess}.
   */
  private static float nodeMeanValue(RegTreeImpl.Node[] nodes, RegTreeImpl.RTreeNodeStat[] stats, int node) {
    final RegTreeImpl.Node n = nodes[node];
    if (n.is_leaf()) {
      return n.leaf_value;
    } else {
      return (stats[n.cleft_].sum_hess * nodeMeanValue(nodes, stats, n.cleft_) +
              stats[n.cright_].sum_hess * nodeMeanValue(nodes, stats, n.cright_)) / stats[node].sum_hess;
    }
  }

}
