/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Computes Friedman and Popescu's H statistic for the interaction among a set of variables of a
 * tree ensemble.
 *
 * <p>For every non-empty subset of the requested variables a centered partial dependence function
 * is evaluated on the unique value combinations of that subset ({@link #computeFValues}). The H
 * value then combines these functions with alternating signs over the unique rows of the full
 * variable set ({@link #computeHValue}).
 */
public class FriedmanPopescusH {

    /** Absolute tolerance used when matching a row's values against the rows of an f-value frame. */
    private static final double FVALUE_EPS = 1.0E-5;

    /**
     * Computes the H statistic for the joint interaction among {@code vars}.
     *
     * @param frame               frame containing the columns named in {@code vars}; a column's index
     *                            in this frame is compared against the trees' split column ids, so the
     *                            frame layout is assumed to match the model's
     * @param vars                names of the variables whose interaction is measured
     * @param learnRate           factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [treeIndex][class]}
     * @return the H value, or {@link Double#NaN} when the numerator is not smaller than the denominator
     * @throws IllegalArgumentException if {@code vars} is null or empty
     * @throws RuntimeException         if a name in {@code vars} is not a column of {@code frame}
     */
    public static double h(Frame frame, String[] vars, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        if (vars == null || vars.length == 0) {
            throw new IllegalArgumentException("At least one variable is required to compute the H statistic");
        }
        // Validate the names first so a missing column produces a descriptive error.
        int[] modelIds = getModelIds(frame.names(), vars);
        Frame filteredFrame = filterFrame(frame, vars);
        Map<String, Frame> fValues = new HashMap<>();
        int numCols = filteredFrame.numCols();
        int[] colIds = new int[numCols];
        for (int c = 0; c < numCols; ++c) {
            colIds[c] = c;
        }
        // One centered f-value frame per non-empty subset, keyed by Arrays.toString(subset).
        for (int size = numCols; size > 0; --size) {
            for (int[] combination : combinations(colIds, size)) {
                String[] cols = getCurrCombinationCols(combination, vars);
                Integer[] combinationModelIds = getCurrentCombinationModelIds(combination, modelIds);
                fValues.put(Arrays.toString(combination),
                        computeFValues(combinationModelIds, filteredFrame, cols, learnRate, sharedTreeSubgraphs));
            }
        }
        return computeHValue(fValues, filteredFrame, colIds);
    }

    /**
     * Maps each position of {@code currCombination} to its model column id.
     *
     * @param currCombination indices into {@code modelIds}
     * @param modelIds        model column id of every requested variable
     * @return the boxed model column ids selected by {@code currCombination}, in the same order
     */
    static Integer[] getCurrentCombinationModelIds(int[] currCombination, int[] modelIds) {
        Integer[] currCombinationCols = new Integer[currCombination.length];
        for (int i = 0; i < currCombination.length; ++i) {
            currCombinationCols[i] = modelIds[currCombination[i]];
        }
        return currCombinationCols;
    }

    /**
     * Combines the per-subset f-values into the H value.
     *
     * <p>For every unique row r of {@code filteredFrame} (with multiplicity {@code nrow_r}), the
     * numerator element is the sum over all non-empty subsets S of {@code inds} of
     * {@code (-1)^(|inds|-|S|) * f_S(r)}. The denominator element is {@code f_inds(r)}. The result is
     * {@code sqrt(sum nrow_r * numer_r^2 / sum nrow_r * denom_r^2)}, or NaN when that ratio is not
     * below 1.
     *
     * @param fValues       f-value frames keyed by {@code Arrays.toString(subset)}
     * @param filteredFrame frame restricted to the requested variables; a key is assigned when absent
     * @param inds          indices of all requested variables within {@code filteredFrame}
     * @return the H value, or {@link Double#NaN}
     */
    static double computeHValue(Map<String, Frame> fValues, Frame filteredFrame, int[] inds) {
        if (filteredFrame._key == null) {
            filteredFrame._key = Key.make();
        }
        Frame uniqueWithCounts = uniqueRowsWithCounts(filteredFrame);
        long uniqHeight = uniqueWithCounts.numRows();

        // All non-empty subsets, from the full set down to singletons; loop-invariant across rows.
        List<int[]> subsets = new ArrayList<>();
        for (int n = inds.length; n > 0; --n) {
            subsets.addAll(combinations(inds, n));
        }

        Vec numerEls = Vec.makeZero(uniqHeight);
        Vec denomEls = Vec.makeZero(uniqHeight);
        try {
            for (long i = 0L; i < uniqHeight; ++i) {
                double numerEl = 0.0;
                double denomEl = Double.NaN;
                for (int[] subset : subsets) {
                    // Feature values must come from the same unique row whose count weights this term.
                    double fValue = findFValue(i, subset, fValues.get(Arrays.toString(subset)), uniqueWithCounts);
                    int sign = (inds.length - subset.length) % 2 == 0 ? 1 : -1;
                    numerEl += sign * fValue;
                    if (subset.length == inds.length) {
                        denomEl = fValue;
                    }
                }
                numerEls.set(i, numerEl);
                denomEls.set(i, denomEl);
            }
            Vec counts = uniqueWithCounts.vec("nrow");
            double numer = new Transform(2).doAll(new Vec[]{numerEls, counts}).result;
            double denom = new Transform(2).doAll(new Vec[]{denomEls, counts}).result;
            return numer < denom ? Math.sqrt(numer / denom) : Double.NaN;
        } finally {
            numerEls.remove();
            denomEls.remove();
        }
    }

    /**
     * Reads row {@code i} of the columns listed in {@code currCombination}.
     *
     * @param currCombination column indices to read
     * @param filteredFrame   frame to read from
     * @param i               row index
     * @return the values, in {@code currCombination} order
     */
    static double[] getValueToFindFValueFor(int[] currCombination, Frame filteredFrame, long i) {
        int combinationLength = currCombination.length;
        double[] value = new double[combinationLength];
        for (int j = 0; j < combinationLength; ++j) {
            value[j] = filteredFrame.vec(currCombination[j]).at(i);
        }
        return value;
    }

    /**
     * Looks up the f-value for the values found in row {@code i} of {@code rowSource}.
     *
     * @param i               row of {@code rowSource} providing the values to match
     * @param currCombination column indices (into {@code rowSource}) forming the subset
     * @param currFValues     f-value frame for this subset; column 0 holds the f-values and the
     *                        remaining columns are matched by name
     * @param rowSource       frame whose names and row values are matched
     * @return the f-value of the first matching row
     * @throws RuntimeException if no row of {@code currFValues} matches within {@link #FVALUE_EPS}
     */
    static double findFValue(long i, int[] currCombination, Frame currFValues, Frame rowSource) {
        double[] valueToFindFValueFor = getValueToFindFValueFor(currCombination, rowSource, i);
        String[] currNames = getCurrCombinationNames(currCombination, rowSource.names());
        FindFValue findFValueTask = new FindFValue(valueToFindFValueFor, currNames, currFValues._names, FVALUE_EPS);
        Frame result = findFValueTask.doAll(Vec.T_NUM, currFValues).outputFrame();
        try {
            if (result.numRows() == 0L) {
                throw new RuntimeException("FValue was not found! Combination: " + Arrays.toString(currCombination)
                        + ", value: " + Arrays.toString(valueToFindFValueFor));
            }
            return result.vec(0).at(0L);
        } finally {
            // The match result is a temporary; do not leave it in the DKV.
            result.vec(0).remove();
        }
    }

    /**
     * Selects the names at the positions listed in {@code currCombination}.
     *
     * @param currCombination indices into {@code names}
     * @param names           candidate names
     * @return the selected names, in {@code currCombination} order
     */
    static String[] getCurrCombinationNames(int[] currCombination, String[] names) {
        return getCurrCombinationCols(currCombination, names);
    }

    /**
     * Selects the variables at the positions listed in {@code currCombination}.
     *
     * @param currCombination indices into {@code vars}
     * @param vars            candidate variable names
     * @return the selected variables, in {@code currCombination} order
     */
    static String[] getCurrCombinationCols(int[] currCombination, String[] vars) {
        String[] currCombinationCols = new String[currCombination.length];
        for (int i = 0; i < currCombination.length; ++i) {
            currCombinationCols[i] = vars[currCombination[i]];
        }
        return currCombinationCols;
    }

    /**
     * Finds the first numeric column.
     *
     * @param frame frame to inspect
     * @return index of the first column whose vec is numeric, or -1 if there is none
     */
    static int findFirstNumericalColumn(Frame frame) {
        for (int i = 0; i < frame.numCols(); ++i) {
            if (frame.vec(i).isNumeric()) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Returns the unique rows of {@code frame} together with a {@code "nrow"} count column, computed
     * with a Rapids {@code GB} over all columns.
     *
     * <p>If {@code frame} has no numeric column, no grouping is performed. Instead a {@code "nrow"}
     * column of ones is appended to {@code frame} itself (mutating it) and {@code frame} is returned.
     * On the grouping path a key is assigned to {@code frame} when it has none. The frame is
     * published to the DKV only for the duration of the Rapids call.
     *
     * @param frame frame to deduplicate
     * @return grouping columns in their original order followed by {@code "nrow"}
     */
    static Frame uniqueRowsWithCounts(Frame frame) {
        int countColumn = findFirstNumericalColumn(frame);
        if (countColumn == -1) {
            frame.add("nrow", Vec.makeOne(frame.numRows()));
            return frame;
        }
        if (frame._key == null) {
            frame._key = Key.make();
        }
        StringBuilder sb = new StringBuilder("(GB ");
        sb.append(frame._key.toString()).append(" [");
        int ncols = frame.names().length;
        for (int c = 0; c < ncols; ++c) {
            if (c != 0) {
                sb.append(",");
            }
            sb.append(c);
        }
        sb.append("] ");
        sb.append(" nrow ").append(countColumn).append(" \"all\")");
        DKV.put(frame);
        try {
            Val val = Rapids.exec(sb.toString());
            return val.getFrame();
        } finally {
            DKV.remove(frame._key);
        }
    }

    /**
     * Computes centered partial dependence values for one subset of variables.
     *
     * <p>The partial dependence (class 0 only) is evaluated on each unique combination of
     * {@code cols}. The count-weighted mean over all rows of {@code filteredFrame} is then
     * subtracted.
     *
     * @param modelIds            model column ids of {@code cols}, in the same order
     * @param filteredFrame       frame containing at least {@code cols}
     * @param cols                names of the subset's columns
     * @param learnRate           factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [treeIndex][class]}
     * @return a frame whose column 0 holds the centered f-values, followed by the unique-row columns
     *         and {@code "nrow"}
     */
    static Frame computeFValues(Integer[] modelIds, Frame filteredFrame, String[] cols, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame subset = filterFrame(filteredFrame, cols);
        subset = new Frame(Key.make(), subset.names(), subset.vecs());
        Frame uniqueWithCounts = uniqueRowsWithCounts(subset);
        Frame pdps = partialDependence(modelIds, uniqueWithCounts, learnRate, sharedTreeSubgraphs);
        // Only the class-0 partial dependence is used; free the others instead of leaking them.
        for (int c = 1; c < pdps.numCols(); ++c) {
            pdps.vec(c).remove();
        }
        Frame uncenteredFvalues = new Frame(new Vec[]{pdps.vec(0)});
        Vec fv = uncenteredFvalues.vec(0);
        VecUtils.DotProduct multiply = new VecUtils.DotProduct().doAll(new Vec[]{uniqueWithCounts.vec("nrow"), fv});
        double meanUncenteredFValue = multiply.result / (double) subset.numRows();
        long nrows = uncenteredFvalues.numRows();
        for (long i = 0L; i < nrows; ++i) {
            fv.set(i, fv.at(i) - meanUncenteredFValue);
        }
        return uncenteredFvalues.add(uniqueWithCounts);
    }

    /**
     * Sums the per-tree partial dependence over all trees, separately for every class.
     *
     * @param modelIds            model column ids of the grid columns, in grid column order
     * @param uniqueWithCounts    grid whose first columns hold the subset's values
     * @param learnRate           factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [treeIndex][class]}
     * @return a frame with one column {@code "pdp_C<class>"} per class
     */
    static Frame partialDependence(Integer[] modelIds, Frame uniqueWithCounts, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame result = new Frame(new Vec[0]);
        int nclasses = sharedTreeSubgraphs[0].length;
        int ntrees = sharedTreeSubgraphs.length;
        long nrows = uniqueWithCounts.numRows();
        // One traversal stack sized for the largest tree, reused for every (row, tree) pair.
        int capacity = 0;
        for (SharedTreeSubgraph[] perClass : sharedTreeSubgraphs) {
            for (SharedTreeSubgraph tree : perClass) {
                capacity = Math.max(capacity, tree.nodesArray.size() * 2);
            }
        }
        SharedTreeNode[] nodeStack = new SharedTreeNode[capacity];
        double[] weightStack = new double[capacity];
        for (int treeClass = 0; treeClass < nclasses; ++treeClass) {
            Vec pdp = Vec.makeZero(nrows);
            for (long row = 0L; row < nrows; ++row) {
                double sum = 0.0;
                for (int t = 0; t < ntrees; ++t) {
                    sum += partialDependenceTreeRow(sharedTreeSubgraphs[t][treeClass], modelIds, learnRate,
                            uniqueWithCounts, row, nodeStack, weightStack);
                }
                pdp.set(row, sum);
            }
            result.add("pdp_C" + treeClass, pdp);
        }
        return result;
    }

    /**
     * Element-wise sum of two arrays, truncated to the shorter length.
     *
     * @param first  first addend
     * @param second second addend
     * @return {@code first[i] + second[i]} for every index present in both arrays
     */
    public static double[] add(double[] first, double[] second) {
        int length = Math.min(first.length, second.length);
        double[] result = new double[length];
        for (int i = 0; i < length; ++i) {
            result[i] = first[i] + second[i];
        }
        return result;
    }

    /**
     * Builds a new, unkeyed frame holding the named columns of {@code frame} (the vecs are shared,
     * not copied).
     *
     * @param frame source frame
     * @param cols  column names to keep, in output order
     * @return the new frame
     */
    static Frame filterFrame(Frame frame, String[] cols) {
        Frame frame1 = new Frame(new Vec[0]);
        frame1.add(cols, frame.vecs(cols));
        return frame1;
    }

    /**
     * Resolves each variable name to its column index in {@code frameNames}.
     *
     * @param frameNames column names of the source frame
     * @param vars       variable names to resolve
     * @return for each variable, the index of its (last) occurrence in {@code frameNames}
     * @throws RuntimeException if a variable is not present in {@code frameNames}
     */
    static int[] getModelIds(String[] frameNames, String[] vars) {
        int[] modelIds = new int[vars.length];
        Arrays.fill(modelIds, -1);
        for (int i = 0; i < vars.length; ++i) {
            for (int j = 0; j < frameNames.length; ++j) {
                if (vars[i].equals(frameNames[j])) {
                    modelIds[i] = j;
                }
            }
            if (modelIds[i] == -1) {
                throw new RuntimeException("Column " + vars[i] + " is not present in the input frame!");
            }
        }
        return modelIds;
    }

    /**
     * Lists all size-{@code combinationSize} combinations of {@code vals}, preserving element order
     * and in lexicographic order of positions.
     *
     * @param vals            values to combine
     * @param combinationSize number of elements per combination
     * @return the combinations, each as a fresh array
     */
    static List<int[]> combinations(int[] vals, int combinationSize) {
        List<int[]> overallResult = new ArrayList<>();
        combinations(vals, combinationSize, 0, new int[combinationSize], overallResult);
        return overallResult;
    }

    /** Recursive helper: fills the remaining {@code len} slots of {@code result} from {@code arr[startPosition..]}. */
    private static void combinations(int[] arr, int len, int startPosition, int[] result, List<int[]> overallResult) {
        if (len == 0) {
            overallResult.add(result.clone());
            return;
        }
        for (int i = startPosition; i <= arr.length - len; ++i) {
            result[result.length - len] = arr[i];
            combinations(arr, len - 1, i + 1, result, overallResult);
        }
    }

    /**
     * Computes one tree's partial dependence for every row of {@code grid}.
     *
     * @param tree          tree to traverse
     * @param targetFeature model column ids of the grid columns, in grid column order
     * @param learnRate     factor applied to every leaf prediction
     * @param grid          frame whose column k holds values for {@code targetFeature[k]}
     * @return a new vec with one partial dependence value per grid row
     * @throws RuntimeException if the reached leaf weights of a row do not sum to about 1
     */
    public static Vec partialDependenceTree(SharedTreeSubgraph tree, Integer[] targetFeature, double learnRate, Frame grid) {
        Vec outVec = Vec.makeZero(grid.numRows());
        int capacity = tree.nodesArray.size() * 2;
        SharedTreeNode[] nodeStack = new SharedTreeNode[capacity];
        double[] weightStack = new double[capacity];
        for (long i = 0L; i < grid.numRows(); ++i) {
            outVec.set(i, partialDependenceTreeRow(tree, targetFeature, learnRate, grid, i, nodeStack, weightStack));
        }
        return outVec;
    }

    /**
     * Computes one tree's partial dependence for a single grid row.
     *
     * <p>On splits over a target feature, the branch selected by the grid value is followed and the
     * weight is kept. On other splits, both branches are followed and the weight is divided in
     * proportion to the children's node weights.
     *
     * @return the weighted sum of {@code predValue * learnRate} over the reached leaves
     * @throws RuntimeException if the reached leaf weights do not sum to about 1
     */
    private static double partialDependenceTreeRow(SharedTreeSubgraph tree, Integer[] targetFeature, double learnRate,
                                                   Frame grid, long row, SharedTreeNode[] nodeStack, double[] weightStack) {
        int stackSize = 1;
        nodeStack[0] = tree.rootNode;
        weightStack[0] = 1.0;
        double value = 0.0;
        double totalWeight = 0.0;
        while (stackSize > 0) {
            SharedTreeNode currNode = nodeStack[--stackSize];
            double currWeight = weightStack[stackSize];
            if (currNode.isLeaf()) {
                value += currWeight * (double) currNode.getPredValue() * learnRate;
                totalWeight += currWeight;
                continue;
            }
            int featureId = indexOfColumn(targetFeature, currNode.getColId());
            if (featureId != -1) {
                nodeStack[stackSize] = grid.vec(featureId).at(row) <= (double) currNode.getSplitValue()
                        ? currNode.getLeftChild() : currNode.getRightChild();
                weightStack[stackSize] = currWeight;
                ++stackSize;
                continue;
            }
            double leftSampleFrac = currNode.getLeftChild().getWeight() / currNode.getWeight();
            nodeStack[stackSize] = currNode.getLeftChild();
            weightStack[stackSize] = currWeight * leftSampleFrac;
            nodeStack[++stackSize] = currNode.getRightChild();
            weightStack[stackSize] = currWeight * (1.0 - leftSampleFrac);
            ++stackSize;
        }
        if (!(0.999 < totalWeight && totalWeight < 1.001)) {
            throw new RuntimeException("Total weight should be 1.0 but was " + totalWeight);
        }
        return value;
    }

    /** Returns the first index k with {@code ids[k] == colId}, or -1 if none. */
    private static int indexOfColumn(Integer[] ids, int colId) {
        for (int k = 0; k < ids.length; ++k) {
            if (ids[k] != null && ids[k] == colId) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Emits column 0 of every row whose named columns all lie within {@code eps} of
     * {@code valueToFindFValueFor}.
     */
    static class FindFValue
    extends MRTask<FindFValue> {
        double[] valueToFindFValueFor;
        String[] currNames;
        String[] currFValuesNames;
        double eps;
        /** Column index (in the input frame) of each entry of {@code currNames}; resolved once. */
        int[] colIds;

        FindFValue(double[] valueToFindFValueFor, String[] currNames, String[] currFValuesNames, double eps) {
            this.valueToFindFValueFor = valueToFindFValueFor;
            this.currNames = currNames;
            this.currFValuesNames = currFValuesNames;
            this.eps = eps;
            this.colIds = new int[valueToFindFValueFor.length];
            for (int k = 0; k < colIds.length; ++k) {
                colIds[k] = ArrayUtils.find(currFValuesNames, currNames[k]);
                if (colIds[k] == -1) {
                    throw new IllegalArgumentException("Column " + currNames[k] + " is not present in the f-value frame");
                }
            }
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            for (int iRow = 0; iRow < cs[0].len(); ++iRow) {
                if (matches(cs, iRow)) {
                    nc[0].addNum(cs[0].atd(iRow));
                }
            }
        }

        /** True when every target column of {@code iRow} is within {@code eps} of its wanted value. */
        private boolean matches(Chunk[] cs, int iRow) {
            for (int k = 0; k < valueToFindFValueFor.length; ++k) {
                if (!(Math.abs(valueToFindFValueFor[k] - cs[colIds[k]].atd(iRow)) < eps)) {
                    return false;
                }
            }
            return true;
        }
    }

    /** Sums {@code x^power * w} over all rows, where x is the first and w the second input vec. */
    private static class Transform
    extends MRTask<Transform> {
        double result;
        int power;

        Transform(int power) {
            this.power = power;
        }

        @Override
        public void map(Chunk[] bvs) {
            this.result = 0.0;
            int len = bvs[0]._len;
            for (int i = 0; i < len; ++i) {
                this.result += Math.pow(bvs[0].atd(i), this.power) * bvs[1].atd(i);
            }
        }

        @Override
        public void reduce(Transform mrt) {
            this.result += mrt.result;
        }
    }
}

