/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.algo.RegTree;

import com.tencent.angel.ml.GBDT.algo.GBDTController;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradPair;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradStats;
import com.tencent.angel.ml.GBDT.algo.tree.SplitEntry;
import com.tencent.angel.ml.GBDT.param.GBDTParam;
import com.tencent.angel.ml.core.conf.MLConf;
import com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage;
import com.tencent.angel.ml.math2.storage.IntDoubleVectorStorage;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;
import com.tencent.angel.ml.math2.vector.IntFloatVector;
import com.tencent.angel.ps.storage.vector.ServerIntDoubleRow;
import com.tencent.angel.worker.WorkerContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Builds gradient histograms for one tree node and searches histograms for the best split.
 *
 * <p>Histogram layout (produced by {@link #buildHistogram} and read by the split finders): for each
 * sampled feature position {@code p} (an index into {@code controller.fSet}) there is a segment of
 * {@code 2 * numSplit} doubles starting at {@code 2 * numSplit * p}. The first {@code numSplit}
 * entries of a segment are per-bin gradient sums, the next {@code numSplit} the per-bin hessian
 * sums. Split candidates of the feature with id {@code fid} live in the flat
 * {@code controller.sketches} array at {@code [fid * numSplit, fid * numSplit + numSplit - 1]}.
 */
public class GradHistHelper {
    private static final Log LOG = LogFactory.getLog(GradHistHelper.class);
    private final GBDTController controller;
    private final int nid;

    /**
     * @param controller training state (sampled features, sketches, gradients, instances)
     * @param nid id of the tree node this helper works on
     */
    public GradHistHelper(GBDTController controller, int nid) {
        this.controller = controller;
        this.nid = nid;
    }

    /**
     * Builds the gradient/hessian histogram of this node from the instances referenced by
     * {@code controller.instancePos[insStart..insEnd]} (both inclusive).
     *
     * <p>Only the stored (non-zero) values of each instance are visited. An instance's gradient
     * pair is added to the bin of its value and subtracted from the bin that contains 0.0. After
     * all instances, the node's total gradient/hessian is added to the zero bin of every sampled
     * feature, so instances with an implicit zero for a feature end up counted in that feature's
     * zero bin.
     *
     * @param insStart first index into {@code controller.instancePos} (inclusive)
     * @param insEnd last index into {@code controller.instancePos} (inclusive)
     * @return dense histogram of size {@code fSet.length * 2 * numSplit}
     */
    public IntDoubleVector buildHistogram(int insStart, int insEnd) {
        int featureNum = this.controller.fSet.length;
        int splitNum = this.controller.param.numSplit;
        int histSize = featureNum * 2 * splitNum;
        IntDoubleVector histogram = new IntDoubleVector(histSize, (IntDoubleVectorStorage) new IntDoubleDenseVectorStorage(new double[histSize]));
        LOG.debug(String.format("Build histogram of node[%d]: size[%d] instance span [%d - %d]", this.nid, histogram.getDim(), insStart, insEnd));
        // Node totals are accumulated in double: summing many float gradients in float loses precision.
        double gradSum = 0.0;
        double hessSum = 0.0;
        long parseInstanceTime = 0L;
        long startTime = System.currentTimeMillis();
        for (int idx = insStart; idx <= insEnd; ++idx) {
            int insIdx = this.controller.instancePos[idx];
            GradPair gradPair = this.controller.gradPairs[insIdx];
            gradSum += gradPair.getGrad();
            hessSum += gradPair.getHess();
            IntFloatVector instance = this.controller.trainDataStore.instances[insIdx];
            long tmpTime = System.currentTimeMillis();
            int[] indices = instance.getStorage().getIndices();
            float[] values = instance.getStorage().getValues();
            parseInstanceTime += System.currentTimeMillis() - tmpTime;
            for (int i = 0; i < indices.length; ++i) {
                int fid = indices[i];
                int fPos = this.controller.fPos[fid];
                if (fPos == -1) {
                    // Feature is not in the sampled feature set.
                    continue;
                }
                int start = fid * splitNum;
                int end = this.sketchEnd(fid, start, splitNum);
                int fValueIdx = GradHistHelper.findFvaluePlace(this.controller.sketches, values[i], start, end);
                assert (fValueIdx >= 0 && fValueIdx < splitNum);
                GradHistHelper.addToBin(histogram, splitNum, fPos, fValueIdx, gradPair.getGrad(), gradPair.getHess());
                // Undo the contribution the zero bin will receive from the node totals below,
                // because this instance's value for the feature is not zero.
                int fZeroValueIdx = GradHistHelper.findFvaluePlace(this.controller.sketches, 0.0f, start, end);
                assert (fZeroValueIdx >= 0 && fZeroValueIdx < splitNum);
                GradHistHelper.addToBin(histogram, splitNum, fPos, fZeroValueIdx, -gradPair.getGrad(), -gradPair.getHess());
            }
        }
        // Credit the node totals to the zero bin of every sampled feature. Iterate over sampled
        // positions and locate the zero bin in the sketch range of the feature's *true* id,
        // consistent with the per-instance loop above.
        for (int fPos = 0; fPos < featureNum; ++fPos) {
            int fid = this.controller.fSet[fPos];
            int start = fid * splitNum;
            int end = this.sketchEnd(fid, start, splitNum);
            int fZeroValueIdx = GradHistHelper.findFvaluePlace(this.controller.sketches, 0.0f, start, end);
            GradHistHelper.addToBin(histogram, splitNum, fPos, fZeroValueIdx, gradSum, hessSum);
        }
        LOG.debug(String.format("Build histogram cost %d ms, parse instance cost %d ms", System.currentTimeMillis() - startTime, parseInstanceTime));
        return histogram;
    }

    /**
     * Returns the inclusive end index, in {@code controller.sketches}, of the split candidates of
     * feature {@code fid}. Categorical features listed in {@code controller.cateFeatNum} use only
     * their configured number of candidates.
     */
    private int sketchEnd(int fid, int start, int splitNum) {
        return this.controller.cateFeatNum.containsKey(fid)
                ? start + this.controller.cateFeatNum.get(fid) - 1
                : start + splitNum - 1;
    }

    /** Adds {@code grad}/{@code hess} to bin {@code bin} of the feature segment at position {@code fPos}. */
    private static void addToBin(IntDoubleVector histogram, int splitNum, int fPos, int bin, double grad, double hess) {
        int gradIdx = 2 * splitNum * fPos + bin;
        int hessIdx = gradIdx + splitNum;
        histogram.set(gradIdx, histogram.get(gradIdx) + grad);
        histogram.set(hessIdx, histogram.get(hessIdx) + hess);
    }

    /**
     * Finds the best split of this node over all sampled features of {@code histogram}.
     *
     * <p>Root statistics are taken from the first feature segment. For node 0 they are stored on
     * the controller. When a split is found, the left/right child stats are stored for nodes
     * {@code 2 * nid + 1} and {@code 2 * nid + 2}.
     *
     * @param histogram histogram built by {@link #buildHistogram}; {@code null} is logged and
     *     yields an empty split entry
     * @return the best split found (an empty {@link SplitEntry} if none)
     */
    public SplitEntry findBestSplit(IntDoubleVector histogram) throws Exception {
        LOG.debug(String.format("------To find the best split of node[%d]------", this.nid));
        SplitEntry splitEntry = new SplitEntry();
        LOG.debug(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
        GradStats rootStats = null;
        if (histogram != null) {
            rootStats = this.calGradStats(histogram);
            if (this.nid == 0) {
                this.controller.updateNodeGradStats(this.nid, rootStats);
            }
        } else {
            LOG.error("null histogram.");
        }
        if (rootStats == null) {
            LOG.error("null root stat.");
            return splitEntry;
        }
        int segmentSize = 2 * this.controller.param.numSplit;
        for (int fPos = 0; fPos < this.controller.fSet.length; ++fPos) {
            int trueFid = this.controller.fSet[fPos];
            SplitEntry curSplit = this.findBestSplitOfOneFeature(trueFid, histogram, segmentSize * fPos, rootStats);
            splitEntry.update(curSplit);
        }
        // NOTE(review): repeats the root update above; kept in case updateNodeGradStats is not
        // idempotent — confirm and remove.
        if (this.nid == 0) {
            this.controller.updateNodeGradStats(this.nid, rootStats);
        }
        if (splitEntry.fid != -1) {
            this.controller.updateNodeGradStats(2 * this.nid + 1, splitEntry.leftGradStat);
            this.controller.updateNodeGradStats(2 * this.nid + 2, splitEntry.rightGradStat);
        }
        LOG.debug(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
        return splitEntry;
    }

    /**
     * Scans one feature's histogram segment left to right and keeps the split with the largest
     * loss change whose children both satisfy {@code minChildWeight}.
     *
     * @param fid true feature id (used to look up split values in {@code controller.sketches})
     * @param histogram node histogram
     * @param startIdx start of the feature's segment in {@code histogram}
     * @param rootStats gradient statistics of the whole node
     * @return split entry for this feature; child stats are unset if the segment is out of range
     */
    public SplitEntry findBestSplitOfOneFeature(int fid, IntDoubleVector histogram, int startIdx, GradStats rootStats) {
        SplitEntry splitEntry = new SplitEntry();
        splitEntry.setFid(fid);
        GradStats bestLeftStat = new GradStats();
        GradStats bestRightStat = new GradStats();
        int splitNum = this.controller.param.numSplit;
        if (startIdx + 2 * splitNum > histogram.getDim()) {
            LOG.error("index out of grad histogram size.");
            return splitEntry;
        }
        float rootGain = rootStats.calcGain(this.controller.param);
        GradStats leftStats = new GradStats();
        GradStats rightStats = new GradStats();
        // The last bin is not tried: splitting after it would leave the right child empty.
        for (int histIdx = startIdx; histIdx < startIdx + splitNum - 1; ++histIdx) {
            float grad = (float) histogram.get(histIdx);
            float hess = (float) histogram.get(splitNum + histIdx);
            leftStats.add(grad, hess);
            // Written as !(a >= b) rather than a < b so NaN hessians are skipped.
            if (!(leftStats.sumHess >= this.controller.param.minChildWeight)) {
                continue;
            }
            rightStats.setSubstract(rootStats, leftStats);
            if (!(rightStats.sumHess >= this.controller.param.minChildWeight)) {
                continue;
            }
            float lossChg = leftStats.calcGain(this.controller.param) + rightStats.calcGain(this.controller.param) - rootGain;
            float splitValue = this.controller.sketches[fid * splitNum + histIdx - startIdx];
            if (splitEntry.update(lossChg, fid, splitValue)) {
                bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
                bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
            }
        }
        splitEntry.leftGradStat = bestLeftStat;
        splitEntry.rightGradStat = bestRightStat;
        return splitEntry;
    }

    /**
     * Picks the best split among per-partition candidates computed on the parameter servers.
     *
     * <p>The vector is split into {@code angel.ps.number} equal chunks. Each chunk starts with:
     * feature position (-1 = no split), split index, loss change, left sumGrad, left sumHess, right
     * sumGrad, right sumHess.
     */
    public SplitEntry findBestFromServerSplit(IntDoubleVector histogram) throws Exception {
        LOG.debug(String.format("------To find the best split of node[%d]------", this.nid));
        SplitEntry splitEntry = new SplitEntry();
        LOG.debug(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
        int partitionNum = WorkerContext.get().getConf().getInt("angel.ps.number", 1);
        int colPerPartition = histogram.getDim() / partitionNum;
        assert (histogram.getDim() == partitionNum * colPerPartition);
        for (int pid = 0; pid < partitionNum; ++pid) {
            int startIdx = pid * colPerPartition;
            int splitFid = (int) histogram.get(startIdx);
            if (splitFid == -1) {
                continue;
            }
            int trueSplitFid = this.controller.fSet[splitFid];
            int splitIdx = (int) histogram.get(startIdx + 1);
            float splitValue = this.controller.sketches[trueSplitFid * this.controller.param.numSplit + splitIdx];
            float lossChg = (float) histogram.get(startIdx + 2);
            float leftSumGrad = (float) histogram.get(startIdx + 3);
            float leftSumHess = (float) histogram.get(startIdx + 4);
            float rightSumGrad = (float) histogram.get(startIdx + 5);
            float rightSumHess = (float) histogram.get(startIdx + 6);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("The best split of the %d-th partition: split feature[%d], split index[%d], split value[%f], loss gain[%f], left sumGrad[%f], left sumHess[%f], right sumGrad[%f], right sumHess[%f]", pid, trueSplitFid, splitIdx, splitValue, lossChg, leftSumGrad, leftSumHess, rightSumGrad, rightSumHess));
            }
            SplitEntry curSplitEntry = new SplitEntry(trueSplitFid, splitValue, lossChg);
            curSplitEntry.leftGradStat = new GradStats(leftSumGrad, leftSumHess);
            curSplitEntry.rightGradStat = new GradStats(rightSumGrad, rightSumHess);
            splitEntry.update(curSplitEntry);
        }
        LOG.debug(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
        return splitEntry;
    }

    /** Logs the gradient half of the segment at position {@code fid} (debugging aid). */
    private void printHistogram(IntDoubleVector histogram, int fid, int splitnum) {
        int start = 2 * fid * splitnum;
        int end = start + splitnum - 1;
        StringBuilder sb = new StringBuilder();
        for (int i = start; i <= end; ++i) {
            sb.append(histogram.get(i) + ", ");
        }
        LOG.info(String.format("Histogram of feature %d: %s", fid, sb.toString()));
    }

    /** Sums the first feature segment of {@code histogram}, i.e. the totals of the whole node. */
    private GradStats calGradStats(IntDoubleVector histogram) {
        return GradHistHelper.calGradStats(histogram, 0, this.controller.param.numSplit);
    }

    /** Sums the gradient and hessian bins of the segment starting at {@code startIdx}. */
    private static GradStats calGradStats(IntDoubleVector histogram, int startIdx, int splitNum) {
        float sumGrad = 0.0f;
        float sumHess = 0.0f;
        for (int i = startIdx; i < startIdx + splitNum; ++i) {
            sumGrad = (float) ((double) sumGrad + histogram.get(i));
            sumHess = (float) ((double) sumHess + histogram.get(splitNum + i));
        }
        return new GradStats(sumGrad, sumHess);
    }

    /** Server-row variant of {@link #calGradStats(IntDoubleVector, int, int)}. */
    private static GradStats calGradStats(ServerIntDoubleRow row, int startIdx, int splitNum) {
        float sumGrad = 0.0f;
        float sumHess = 0.0f;
        for (int i = startIdx; i < startIdx + splitNum; ++i) {
            sumGrad = (float) ((double) sumGrad + row.get(i));
            sumHess = (float) ((double) sumHess + row.get(splitNum + i));
        }
        return new GradStats(sumGrad, sumHess);
    }

    /**
     * Binary-searches the ascending array {@code fset} for {@code fid}.
     *
     * @return the index of {@code fid}, or -1 if absent
     */
    private static int findFidPlace(int[] fset, int fid) {
        int low = 0;
        int high = fset.length - 1;
        while (high >= low) {
            int middle = (low + high) >>> 1;
            if (fset[middle] == fid) {
                return middle;
            }
            if (fset[middle] > fid) {
                high = middle - 1;
            } else {
                low = middle + 1;
            }
        }
        return -1;
    }

    /**
     * Finds the bin of {@code fvalue} among the split candidates {@code sketch[start..end]}
     * (assumed ascending), relative to {@code start}.
     *
     * <p>Bin {@code k} covers {@code [sketch[start + k], sketch[start + k + 1])`. Values below the
     * first candidate map to bin 0 and values at or above the last candidate map to the last bin.
     * Only indices inside {@code [start, end]} are read.
     */
    private static int findFvaluePlace(float[] sketch, float fvalue, int start, int end) {
        int left = start;
        int right = end;
        while (left < right) {
            // Upper middle: mid > left, so mid - 1 stays inside the range.
            int mid = right + (left - right) / 2;
            if (sketch[mid] > fvalue) {
                if (sketch[mid - 1] < fvalue) {
                    return mid - 1 - start;
                }
                right = mid - 1;
            } else if (sketch[mid] < fvalue) {
                // At the last candidate there is no upper neighbour to check; reading sketch[end + 1]
                // would hit the next feature's candidates or run past the array.
                if (mid == end || sketch[mid + 1] > fvalue) {
                    return mid - start;
                }
                left = mid + 1;
            } else {
                return mid - start;
            }
        }
        return Math.min(left, right) - start;
    }

    /**
     * Finds the best split of a full (un-sampled) histogram, using feature count and split number
     * from the worker configuration.
     */
    public static SplitEntry findBestSplitHelper(IntDoubleVector histogram) throws InterruptedException {
        LOG.debug(String.format("------To find the best split of histogram size[%d]------", histogram.getDim()));
        SplitEntry splitEntry = new SplitEntry();
        LOG.debug(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
        int featureNum = WorkerContext.get().getConf().getInt(MLConf.ML_FEATURE_INDEX_RANGE(), MLConf.DEFAULT_ML_FEATURE_INDEX_RANGE());
        int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
        if (histogram.getDim() != featureNum * 2 * splitNum) {
            LOG.debug("The size of histogram is not equal to 2 * featureNum*splitNum.");
            return splitEntry;
        }
        for (int fid = 0; fid < featureNum; ++fid) {
            SplitEntry curSplit = GradHistHelper.findBestSplitOfOneFeatureHelper(fid, histogram, 2 * splitNum * fid);
            splitEntry.update(curSplit);
        }
        LOG.debug(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
        return splitEntry;
    }

    /**
     * Scans one feature segment of {@code histogram}. The resulting entry's {@code fvalue} holds a
     * split <em>index</em> ({@code histIdx - startIdx + 1}), not a sketch value.
     *
     * <p>NOTE(review): uses a default-constructed {@link GBDTParam} for gain/minChildWeight rather
     * than the configured parameters, and its split-index convention (+1) differs from
     * {@link #findSplitOfFeature} — confirm both are intended.
     */
    public static SplitEntry findBestSplitOfOneFeatureHelper(int fid, IntDoubleVector histogram, int startIdx) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, histogram.getDim(), startIdx));
        }
        int splitNum = WorkerContext.get().getConf().getInt(MLConf.ML_GBDT_SPLIT_NUM(), MLConf.DEFAULT_ML_GBDT_SPLIT_NUM());
        SplitEntry splitEntry = new SplitEntry();
        // Check bounds before reading the segment, so a bad index is reported instead of failing.
        if (startIdx + 2 * splitNum > histogram.getDim()) {
            LOG.error("index out of grad histogram size.");
            return splitEntry;
        }
        GradStats bestLeftStat = new GradStats();
        GradStats bestRightStat = new GradStats();
        GradStats rootStats = GradHistHelper.calGradStats(histogram, startIdx, splitNum);
        GBDTParam param = new GBDTParam();
        float rootGain = rootStats.calcGain(param);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Feature[%d]: sumGrad[%f], sumHess[%f], gain[%f]", fid, rootStats.sumGrad, rootStats.sumHess, rootGain));
        }
        GradStats leftStats = new GradStats();
        GradStats rightStats = new GradStats();
        for (int histIdx = startIdx; histIdx < startIdx + splitNum - 1; ++histIdx) {
            float grad = (float) histogram.get(histIdx);
            float hess = (float) histogram.get(splitNum + histIdx);
            leftStats.add(grad, hess);
            if (!(leftStats.sumHess >= param.minChildWeight)) {
                continue;
            }
            rightStats.setSubstract(rootStats, leftStats);
            if (!(rightStats.sumHess >= param.minChildWeight)) {
                continue;
            }
            float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
            int splitIdx = histIdx - startIdx + 1;
            if (splitEntry.update(lossChg, fid, splitIdx)) {
                bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
                bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
            }
        }
        splitEntry.leftGradStat = bestLeftStat;
        splitEntry.rightGradStat = bestRightStat;
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Find best split for fid[%d], split feature[%d]: split index[%f], lossChg[%f]", fid, splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
        }
        return splitEntry;
    }

    /**
     * Finds the best split among the whole feature segments stored in one server-row partition.
     * Reported feature ids are segment positions derived from the row's column range.
     */
    public static SplitEntry findSplitOfServerRow(ServerIntDoubleRow row, GBDTParam param) {
        LOG.debug(String.format("------To find the best split from server row[%d], cols[%d-%d]------", row.getRowId(), row.getStartCol(), row.getEndCol()));
        SplitEntry splitEntry = new SplitEntry();
        splitEntry.leftGradStat = new GradStats();
        splitEntry.rightGradStat = new GradStats();
        LOG.debug(String.format("The best split before looping the histogram: fid[%d], fvalue[%f]", splitEntry.fid, splitEntry.fvalue));
        int segmentSize = 2 * param.numSplit;
        int startCol = (int) row.getStartCol();
        int startFid = startCol / segmentSize;
        int endFid = (int) row.getEndCol() / segmentSize - 1;
        LOG.debug(String.format("Row split col[%d-%d), start feature[%d], end feature[%d]", row.getStartCol(), row.getEndCol(), startFid, endFid));
        for (int i = 0; startFid + i <= endFid; ++i) {
            int startIdx = segmentSize * i + startCol;
            SplitEntry curSplit = GradHistHelper.findSplitOfFeature(startFid + i, row, startIdx, param);
            splitEntry.update(curSplit);
        }
        LOG.debug(String.format("The best split after looping the histogram: fid[%d], fvalue[%f], loss gain[%f]", splitEntry.fid, splitEntry.fvalue, splitEntry.lossChg));
        return splitEntry;
    }

    /**
     * Scans one feature segment of a server row. The resulting entry's {@code fvalue} holds the
     * split index {@code histIdx - startIdx}.
     *
     * <p>NOTE(review): unlike the worker-side scans, this loop also tries the last bin (which puts
     * every instance on the left) — confirm this is intended.
     */
    public static SplitEntry findSplitOfFeature(int fid, ServerIntDoubleRow row, int startIdx, GBDTParam param) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Find best split for fid[%d] in histogram size[%d], startIdx[%d]", fid, row.size(), startIdx));
        }
        SplitEntry splitEntry = new SplitEntry();
        splitEntry.setFid(fid);
        // Check bounds before reading the segment, so a bad index is reported instead of failing.
        if ((long) (startIdx + 2 * param.numSplit) > row.getEndCol()) {
            LOG.error("index out of grad histogram size.");
            return splitEntry;
        }
        GradStats bestLeftStat = new GradStats();
        GradStats bestRightStat = new GradStats();
        GradStats rootStats = GradHistHelper.calGradStats(row, startIdx, param.numSplit);
        float rootGain = rootStats.calcGain(param);
        GradStats leftStats = new GradStats();
        GradStats rightStats = new GradStats();
        for (int histIdx = startIdx; histIdx < startIdx + param.numSplit; ++histIdx) {
            float grad = (float) row.get(histIdx);
            float hess = (float) row.get(param.numSplit + histIdx);
            leftStats.add(grad, hess);
            if (!(leftStats.sumHess >= param.minChildWeight)) {
                continue;
            }
            rightStats.setSubstract(rootStats, leftStats);
            if (!(rightStats.sumHess >= param.minChildWeight)) {
                continue;
            }
            float lossChg = leftStats.calcGain(param) + rightStats.calcGain(param) - rootGain;
            int splitIdx = histIdx - startIdx;
            if (splitEntry.update(lossChg, fid, splitIdx)) {
                bestLeftStat.update(leftStats.sumGrad, leftStats.sumHess);
                bestRightStat.update(rightStats.sumGrad, rightStats.sumHess);
            }
        }
        splitEntry.leftGradStat = bestLeftStat;
        splitEntry.rightGradStat = bestRightStat;
        return splitEntry;
    }

    /** Manual smoke check of {@link #findFvaluePlace}. */
    public static void main(String[] args) {
        float[] sketch = new float[]{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
        System.out.println("Result:" + GradHistHelper.findFvaluePlace(sketch, 0.7f, 0, 6));
    }
}

