/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.algo;

import com.tencent.angel.ml.GBDT.GBDTModel;
import com.tencent.angel.ml.GBDT.algo.AfterSplitThread;
import com.tencent.angel.ml.GBDT.algo.GBDTPhase;
import com.tencent.angel.ml.GBDT.algo.HistCalThread;
import com.tencent.angel.ml.GBDT.algo.HistSubThread;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradHistHelper;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradPair;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradStats;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTDataStore;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTree;
import com.tencent.angel.ml.GBDT.algo.tree.SplitEntry;
import com.tencent.angel.ml.GBDT.algo.tree.TYahooSketchSplit;
import com.tencent.angel.ml.GBDT.metric.EvalMetric;
import com.tencent.angel.ml.GBDT.objective.ObjFunc;
import com.tencent.angel.ml.GBDT.param.GBDTParam;
import com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowFunc;
import com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowResult;
import com.tencent.angel.ml.GBDT.psf.HistAggrParam;
import com.tencent.angel.ml.core.conf.MLConf;
import com.tencent.angel.ml.core.utils.Maths;
import com.tencent.angel.ml.math2.storage.IntDoubleDenseVectorStorage;
import com.tencent.angel.ml.math2.storage.IntDoubleSparseVectorStorage;
import com.tencent.angel.ml.math2.storage.IntDoubleVectorStorage;
import com.tencent.angel.ml.math2.storage.IntIntDenseVectorStorage;
import com.tencent.angel.ml.math2.storage.IntIntVectorStorage;
import com.tencent.angel.ml.math2.vector.IntDoubleVector;
import com.tencent.angel.ml.math2.vector.IntIntVector;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.update.base.UpdateFunc;
import com.tencent.angel.ml.model.PSModel;
import com.tencent.angel.ml.psf.compress.QuantifyDoubleFunc;
import com.tencent.angel.worker.task.TaskContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import scala.Tuple1;

public class GBDTController {
    private static final Log LOG = LogFactory.getLog(GBDTController.class);
    /** Task context; supplies the configuration, this task's index and the total task count. */
    public TaskContext taskContext;
    /** Model holding the parameter-server matrices, looked up by name with {@code getPSModel}. */
    public GBDTModel model;
    /** Training hyper-parameters and the names of the PS matrices. */
    public GBDTParam param;
    /** Training rows (row count, labels, current predictions, feature metadata). */
    public RegTDataStore trainDataStore;
    /** Validation rows; routed through each new split by {@code updateValidInsPos}. */
    public RegTDataStore validDataStore;
    /** Trees built so far; one slot per boosting round ({@code param.treeNum} slots). */
    public RegTree[] forest;
    /** Current stage of the training state machine; advanced by {@code updatePhase}. */
    public GBDTPhase phase;
    /** Counter incremented by {@code incrementClock}. */
    public int clock;
    /** Index into {@link #forest} of the tree currently being built. */
    public int currentTree;
    /** Depth currently being grown in the current tree; reset to 1 for every new tree. */
    public int currentDepth;
    /** Node capacity of a full binary tree with {@code param.maxDepth} levels (2^maxDepth - 1). */
    public int maxNodeNum;
    /** Loss function used to compute the per-row gradient pairs. */
    public ObjFunc objfunc;
    /** Gradient/hessian pair of each training row. */
    public GradPair[] gradPairs;
    /**
     * Candidate split values, flattened: feature {@code fid} uses slots
     * {@code [fid * numSplit, (fid + 1) * numSplit)}.
     */
    public float[] sketches;
    /** Ids of the categorical features. */
    public List<Integer> cateFeatList;
    /** Categorical feature id -> number of split values found in its sketch (set in getSketch). */
    public Map<Integer, Integer> cateFeatNum;
    /** Feature ids used by the current tree (the column sample when {@code colSample < 1}). */
    public int[] fSet;
    /** Feature id -> position in {@link #fSet}, or -1 when the feature is not in the sample. */
    public int[] fPos;
    /** Per-node flag; 1 marks a node handled by runActiveNode/findSplit/afterSplit this round. */
    public int[] activeNode;
    /** Per-node status used by {@code afterSplit} to wait for its worker threads; 1 = running. */
    public AtomicInteger[] activeNodeStat;
    /**
     * Training row indices; node {@code nid} owns the inclusive slice
     * {@code [nodePosStart[nid], nodePosEnd[nid]]} of this array.
     */
    public int[] instancePos;
    /** Inclusive start of each node's slice of {@link #instancePos}; -1 when unset. */
    public int[] nodePosStart;
    /** Inclusive end of each node's slice of {@link #instancePos}; -1 when unset. */
    public int[] nodePosEnd;
    /** Split feature id of each node; -1 when the node has no split. */
    public int[] splitFeats;
    /** Split threshold of each node; rows whose value is <= the threshold go to the left child. */
    public double[] splitValues;
    /** Per-node values; presumably node predictions (not used in this part of the file). */
    public double[] treePreds;
    /** Current tree node of each validation row. */
    public int[] validInsPos;
    /** Locally computed gradient histogram of each node; parent entries are released after use. */
    public IntDoubleVector[] histCache;
    /** Fixed-size pool ({@code param.maxThreadNum} threads) running histogram and split workers. */
    private ExecutorService threadPool;

    /**
     * Binds a controller to one training task. Only references are stored here; the training
     * state itself (forest, per-node arrays, thread pool) is allocated by {@code init()}.
     *
     * @param taskContext    context of the task this controller runs in
     * @param param          training hyper-parameters
     * @param trainDataStore training rows
     * @param validDataStore validation rows
     * @param model          model holding the parameter-server matrices
     */
    public GBDTController(TaskContext taskContext, GBDTParam param, RegTDataStore trainDataStore, RegTDataStore validDataStore, GBDTModel model) {
        this.model = model;
        this.param = param;
        this.taskContext = taskContext;
        this.trainDataStore = trainDataStore;
        this.validDataStore = validDataStore;
    }

    /**
     * Allocates the per-training state of this controller: the forest, the per-node bookkeeping
     * arrays, the sketch buffer, the categorical-feature list and the worker thread pool.
     *
     * <p>Categorical features come from {@code MLConf.ML_GBDT_CATE_FEAT()}: {@code "all"} marks
     * every feature categorical, {@code "none"} marks none, and any other value is read as a
     * comma-separated list of {@code fid:num} entries. Surrounding whitespace and empty entries
     * are ignored.
     *
     * @throws IllegalArgumentException if a categorical entry is malformed (including a
     *     {@link NumberFormatException}) or names a feature id outside {@code [0, numFeature)}
     */
    public void init() throws Exception {
        this.forest = new RegTree[this.param.treeNum];
        this.phase = GBDTPhase.CREATE_SKETCH;
        this.clock = 0;
        this.currentTree = 0;
        this.currentDepth = 1;
        this.objfunc = this.param.getLossFunc();
        this.gradPairs = new GradPair[this.trainDataStore.numRow];
        this.sketches = new float[this.param.numFeature * this.param.numSplit];
        String cateFeatStr = this.taskContext.getConf().get(MLConf.ML_GBDT_CATE_FEAT(), MLConf.DEFAULT_ML_GBDT_CATE_FEAT());
        this.cateFeatList = new ArrayList<Integer>();
        this.cateFeatNum = new HashMap<Integer, Integer>();
        this.parseCateFeatConfig(cateFeatStr);
        this.fPos = new int[this.param.numFeature];
        // A full binary tree of maxDepth levels has 2^maxDepth - 1 nodes; node nid has the
        // children 2*nid+1 and 2*nid+2.
        this.maxNodeNum = Maths.pow(2, this.param.maxDepth) - 1;
        this.activeNode = new int[this.maxNodeNum];
        this.activeNodeStat = new AtomicInteger[this.maxNodeNum];
        Arrays.setAll(this.activeNodeStat, i -> new AtomicInteger(0));
        // Before any split the root owns every training row, in original order.
        this.instancePos = new int[this.trainDataStore.numRow];
        Arrays.setAll(this.instancePos, i -> i);
        this.nodePosStart = new int[this.maxNodeNum];
        this.nodePosEnd = new int[this.maxNodeNum];
        this.nodePosStart[0] = 0;
        this.nodePosEnd[0] = this.instancePos.length - 1;
        // -1 means "no split feature"; the remaining arrays rely on Java's zero initialisation.
        this.splitFeats = new int[this.maxNodeNum];
        Arrays.fill(this.splitFeats, -1);
        this.splitValues = new double[this.maxNodeNum];
        this.treePreds = new double[this.maxNodeNum];
        this.validInsPos = new int[this.validDataStore.numRow];
        this.histCache = new IntDoubleVector[this.maxNodeNum];
        this.threadPool = Executors.newFixedThreadPool(this.param.maxThreadNum);
    }

    /**
     * Fills {@code cateFeatList} from the categorical-feature configuration value.
     *
     * @param cateFeatStr {@code "all"}, {@code "none"} or a comma-separated {@code fid:num} list
     * @throws IllegalArgumentException on a malformed entry or an out-of-range feature id
     */
    private void parseCateFeatConfig(String cateFeatStr) {
        String spec = cateFeatStr.trim();
        switch (spec) {
            case "all": {
                for (int fid = 0; fid < this.param.numFeature; ++fid) {
                    this.cateFeatList.add(fid);
                }
                break;
            }
            case "none": {
                break;
            }
            default: {
                for (String rawEntry : spec.split(",")) {
                    String entry = rawEntry.trim();
                    if (entry.isEmpty()) {
                        // Tolerate stray commas such as ",3:5" or "3:5,,4:2".
                        continue;
                    }
                    String[] fidAndNum = entry.split(":");
                    if (fidAndNum.length < 2) {
                        throw new IllegalArgumentException("Invalid categorical feature entry '" + entry + "', expected fid:num");
                    }
                    int fid = Integer.parseInt(fidAndNum[0].trim());
                    int num = Integer.parseInt(fidAndNum[1].trim());
                    // Fail here with a clear message instead of indexing the sketch out of range later.
                    if (fid < 0 || fid >= this.param.numFeature) {
                        throw new IllegalArgumentException("Categorical feature id " + fid + " is outside [0, " + this.param.numFeature + ")");
                    }
                    assert (num < this.param.numSplit);
                    if (!this.cateFeatList.contains(fid)) {
                        this.cateFeatList.add(fid);
                    }
                }
            }
        }
    }

    /**
     * Sends a clock to every PS matrix of the model, asking only the matrices named in
     * {@code needFlushMatrices} to flush their pending updates.
     *
     * @param needFlushMatrices names of the matrices whose updates must be flushed
     * @param wait              when true, block until every clock request has completed
     */
    private void clockAllMatrix(Set<String> needFlushMatrices, boolean wait) throws Exception {
        long begin = System.currentTimeMillis();
        ArrayList<Future> pending = new ArrayList<Future>();
        for (Map.Entry entry : this.model.getPSModels().entrySet()) {
            boolean flush = needFlushMatrices.contains(entry.getKey());
            pending.add(((PSModel)entry.getValue()).clock(flush));
        }
        if (wait) {
            for (Future clockFuture : pending) {
                clockFuture.get();
            }
        }
        LOG.info((Object)String.format("clock and flush matrices %s cost %d ms", needFlushMatrices, System.currentTimeMillis() - begin));
    }

    /**
     * Advances the training state machine by one step.
     *
     * <p>Transitions: CREATE_SKETCH -> GET_SKETCH -> SAMPLE_FEATURE -> NEW_TREE -> RUN_ACTIVE
     * -> FIND_SPLIT -> AFTER_SPLIT. After AFTER_SPLIT the machine loops back to RUN_ACTIVE while
     * active nodes remain, otherwise it goes to FINISH_TREE. From FINISH_TREE it moves to FINISHED
     * once training is done, else it starts the next tree at SAMPLE_FEATURE. Any other phase is
     * left unchanged.
     */
    public void updatePhase() {
        final GBDTPhase next;
        switch (this.phase) {
            case CREATE_SKETCH:
                next = GBDTPhase.GET_SKETCH;
                break;
            case GET_SKETCH:
                next = GBDTPhase.SAMPLE_FEATURE;
                break;
            case SAMPLE_FEATURE:
                next = GBDTPhase.NEW_TREE;
                break;
            case NEW_TREE:
                next = GBDTPhase.RUN_ACTIVE;
                break;
            case RUN_ACTIVE:
                next = GBDTPhase.FIND_SPLIT;
                break;
            case FIND_SPLIT:
                next = GBDTPhase.AFTER_SPLIT;
                break;
            case AFTER_SPLIT:
                next = this.hasActiveTNode() ? GBDTPhase.RUN_ACTIVE : GBDTPhase.FINISH_TREE;
                break;
            case FINISH_TREE:
                next = this.isFinished() ? GBDTPhase.FINISHED : GBDTPhase.SAMPLE_FEATURE;
                break;
            default:
                // Phases not listed above have no successor here.
                return;
        }
        this.setPhase(next);
    }

    /** Bumps the controller's clock counter by one. */
    public void incrementClock() {
        this.clock += 1;
    }

    /**
     * Recomputes the gradient/hessian pair of every training row from the current predictions
     * and stores them in {@code gradPairs}. In debug mode the pair of row 0 is logged as a sample.
     */
    private void calGradPairs() {
        LOG.info((Object)"------Calculate grad pairs------");
        this.gradPairs = this.objfunc.calGrad(this.trainDataStore.preds, this.trainDataStore, 0);
        // String.format arguments are evaluated eagerly, so the unguarded call indexed row 0 even
        // with debug logging off, wasting work per tree and crashing on an empty partition.
        if (LOG.isDebugEnabled() && this.gradPairs != null && this.gradPairs.length > 0) {
            LOG.debug((Object)String.format("Instance[%d]: label[%f], pred[%f], gradient[%f], hessien[%f]", 0, Float.valueOf(this.trainDataStore.labels[0]), Float.valueOf(this.trainDataStore.preds[0]), Float.valueOf(this.gradPairs[0].getGrad()), Float.valueOf(this.gradPairs[0].getHess())));
        }
    }

    /**
     * Builds the candidate split values (the sketch) and publishes them to the PS.
     *
     * <p>Only task 0 computes anything. Numerical feature {@code fid} is written to slots
     * {@code [fid * numSplit, fid * numSplit + splits[fid].length)} of row 0 of the sketch matrix.
     * Categorical features are skipped there. Instead they are packed into a separate vector in
     * ascending feature-id order: block {@code i} holds the values of the i-th categorical feature,
     * and the copy stops at the first 0.0 after slot 0. That vector is added to row
     * {@code getTaskIndex()} (0 here) of the categorical-feature matrix. Every task then clocks all
     * matrices, flushing the sketch and categorical-feature matrices.
     */
    public void createSketch() throws Exception {
        PSModel sketch = this.model.getPSModel(this.param.sketchName);
        PSModel cateFeat = this.model.getPSModel(this.param.cateFeatureName);
        if (this.taskContext.getTaskIndex() == 0) {
            float[][] splits;
            LOG.info((Object)"------Create sketch------");
            long startTime = System.currentTimeMillis();
            IntDoubleVector sketchVec = new IntDoubleVector(this.param.numFeature * this.param.numSplit, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(new double[this.param.numFeature * this.param.numSplit]));
            IntDoubleVector cateFeatVec = null;
            if (!this.cateFeatList.isEmpty()) {
                cateFeatVec = new IntDoubleVector(this.cateFeatList.size() * this.param.numSplit, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(new double[this.cateFeatList.size() * this.param.numSplit]));
            }
            // Only the shape of the first feature's splits is checked; NOTE(review): splits[0]
            // throws if getSplitValue returns an empty array.
            if ((splits = TYahooSketchSplit.getSplitValue(this.trainDataStore, this.param.numSplit, this.cateFeatList)).length == this.param.numFeature && splits[0].length == this.param.numSplit) {
                for (int fid = 0; fid < splits.length; ++fid) {
                    // Categorical features are handled by the block below.
                    if (this.cateFeatList.contains(fid)) continue;
                    for (int j = 0; j < splits[fid].length; ++j) {
                        sketchVec.set(fid * this.param.numSplit + j, (double)splits[fid][j]);
                    }
                }
            } else {
                LOG.error((Object)"Incompatible sketches size.");
            }
            // NOTE(review): this block still runs after the "Incompatible sketches size." error
            // above and may then index splits[fid] out of range.
            if (!this.cateFeatList.isEmpty()) {
                // Sorting fixes the block order (block i = i-th smallest categorical id) that
                // mergeCateFeatSketch reads back on this task.
                Collections.sort(this.cateFeatList);
                for (int i = 0; i < this.cateFeatList.size(); ++i) {
                    int fid = this.cateFeatList.get(i);
                    int start = i * this.param.numSplit;
                    // Copy values until the first 0.0 after slot 0 (slot 0 is always copied).
                    for (int j = 0; j < splits[fid].length && (splits[fid][j] != 0.0f || j <= 0); ++j) {
                        cateFeatVec.set(start + j, (double)splits[fid][j]);
                    }
                }
            }
            sketch.increment(0, (Vector)sketchVec);
            if (null != cateFeatVec) {
                cateFeat.increment(this.taskContext.getTaskIndex(), (Vector)cateFeatVec);
            }
            LOG.info((Object)String.format("Create sketch cost: %d ms", System.currentTimeMillis() - startTime));
        }
        // All tasks take part in the clock so the PS sees a consistent iteration.
        HashSet<String> needFlushMatrixSet = new HashSet<String>(1);
        needFlushMatrixSet.add(this.param.sketchName);
        needFlushMatrixSet.add(this.param.cateFeatureName);
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Merges the categorical-feature values collected in the categorical-feature matrix into the
     * sketch.
     *
     * <p>Only task 0 does work, and only when categorical features exist. It reads rows
     * {@code 0 .. workerNum-1} of the categorical-feature matrix, where {@code workerNum} comes from
     * {@code angel.workergroup.actual.number} with a default of 1. It collects the distinct values
     * found in block {@code i} of each row for the i-th categorical feature. Those values are then
     * written, sorted ascending, into slots {@code [fid * numSplit, ...)} of row 0 of the sketch
     * matrix. Every task then clocks all matrices; the sketch is flushed only when it was updated.
     *
     * <p>NOTE(review): block {@code i} relies on {@code cateFeatList} having the order that
     * createSketch gave it on this controller (it sorts the list); confirm that createSketch always
     * runs first.
     */
    public void mergeCateFeatSketch() throws Exception {
        LOG.info((Object)"------Merge categorical features------");
        HashSet<String> needFlushMatrixSet = new HashSet<String>(1);
        if (!this.cateFeatList.isEmpty() && this.taskContext.getTaskIndex() == 0) {
            int j;
            PSModel cateFeat = this.model.getPSModel(this.param.cateFeatureName);
            PSModel sketch = this.model.getPSModel(this.param.sketchName);
            // One set of distinct values per categorical feature (index i in cateFeatList).
            HashSet[] featSet = new HashSet[this.cateFeatList.size()];
            for (int i = 0; i < this.cateFeatList.size(); ++i) {
                featSet[i] = new HashSet();
            }
            int workerNum = this.taskContext.getConf().getInt("angel.workergroup.actual.number", 1);
            // NOTE(review): the createSketch code in this file writes only row 0. Unless other rows
            // are filled elsewhere, they read back as zeros, so 0.0 ends up in every set. Zero
            // padding inside row 0 has the same effect.
            for (int worker = 0; worker < workerNum; ++worker) {
                IntDoubleVector vec = (IntDoubleVector)cateFeat.getRow(worker);
                for (int i = 0; i < this.cateFeatList.size(); ++i) {
                    int fid = this.cateFeatList.get(i); // unused
                    int start = i * this.param.numSplit;
                    for (j = 0; j < this.param.numSplit; ++j) {
                        double fvalue = vec.get(start + j);
                        featSet[i].add(fvalue);
                    }
                }
            }
            IntDoubleVector cateFeatVec = new IntDoubleVector(this.param.numFeature * this.param.numSplit, (IntDoubleVectorStorage)new IntDoubleSparseVectorStorage(this.param.numFeature * this.param.numSplit));
            for (int i = 0; i < this.cateFeatList.size(); ++i) {
                int fid = this.cateFeatList.get(i);
                int start = fid * this.param.numSplit;
                ArrayList sortedValue = new ArrayList(featSet[i]);
                Collections.sort(sortedValue);
                // NOTE(review): Java assertions are usually disabled. If a feature has numSplit or
                // more distinct values, the loop below writes into the next feature's slots.
                assert (sortedValue.size() < this.param.numSplit);
                for (j = 0; j < sortedValue.size(); ++j) {
                    cateFeatVec.set(start + j, ((Double)sortedValue.get(j)).doubleValue());
                }
            }
            // createSketch left the categorical slots of the sketch at zero, so this increment
            // effectively sets them.
            sketch.increment(0, (Vector)cateFeatVec);
            needFlushMatrixSet.add(this.param.sketchName);
        }
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Pulls the merged sketch (candidate split values) from row 0 of the sketch matrix into
     * {@code sketches}, then counts the split values of every categorical feature into
     * {@code cateFeatNum}.
     *
     * <p>Feature {@code fid} occupies slots {@code [fid * numSplit, (fid + 1) * numSplit)}. For a
     * categorical feature the count is 1 plus the length of the strictly increasing run of
     * adjacent values starting at the first slot, so it is at most {@code numSplit}.
     */
    public void getSketch() throws Exception {
        PSModel sketch = this.model.getPSModel(this.param.sketchName);
        LOG.info((Object)"------Get sketch from PS------");
        long startTime = System.currentTimeMillis();
        IntDoubleVector sketchVector = (IntDoubleVector)sketch.getRow(0);
        LOG.info((Object)String.format("Get sketch cost: %d ms", System.currentTimeMillis() - startTime));
        for (int i = 0; i < sketchVector.getDim(); ++i) {
            this.sketches[i] = (float)sketchVector.get(i);
        }
        int numSplit = this.param.numSplit;
        for (int i = 0; i < this.cateFeatList.size(); ++i) {
            int fid = this.cateFeatList.get(i);
            int start = fid * numSplit;
            int splitNum = 1;
            // Compare slot j with slot j + 1 only while both lie inside this feature's block.
            // The old bound (j < numSplit) read one slot past the block. That slot belongs to the
            // next feature, or lies past the end of the array for the last feature.
            for (int j = 0; j + 1 < numSplit && this.sketches[start + j + 1] > this.sketches[start + j]; ++j) {
                ++splitNum;
            }
            this.cateFeatNum.put(fid, splitNum);
        }
        LOG.info((Object)("Number of splits of categorical features: " + this.cateFeatNum.entrySet().toString()));
    }

    /**
     * Draws the column (feature) sample for the current tree when {@code param.colSample < 1}.
     *
     * <p>Only task 0 samples. It adds the sampled feature ids to row {@code currentTree} of the
     * sampled-features matrix. Every task then clocks all PS matrices, flushing that matrix only
     * if it was written.
     */
    public void sampleFeature() throws Exception {
        LOG.info((Object)"------Sample feature------");
        PSModel featSample = this.model.getPSModel(this.param.sampledFeaturesName);
        HashSet<String> needFlushMatrixSet = new HashSet<String>(1);
        boolean sampleColumns = this.param.colSample < 1.0f;
        if (sampleColumns && this.taskContext.getTaskIndex() == 0) {
            long begin = System.currentTimeMillis();
            int[] sampledFids = this.trainDataStore.featureMeta.sampleCol(this.param.colSample);
            IntIntVector sampledVec = new IntIntVector(sampledFids.length, (IntIntVectorStorage)new IntIntDenseVectorStorage(sampledFids));
            featSample.increment(this.currentTree, (Vector)sampledVec);
            needFlushMatrixSet.add(this.param.sampledFeaturesName);
            LOG.info((Object)String.format("Sample feature cost: %d ms", System.currentTimeMillis() - begin));
        }
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Starts a new tree in slot {@code currentTree} of the forest.
     *
     * <p>The steps are:
     * <ul>
     *   <li>Load the feature set: the sampled feature ids from the PS when column sampling is on,
     *       otherwise (once) the identity set of all features.</li>
     *   <li>Reset the per-node bookkeeping and mark only the root active, owning every training
     *       row.</li>
     *   <li>Send every validation row back to the root.</li>
     *   <li>Recompute the gradient pairs.</li>
     * </ul>
     */
    public void createNewTree() throws Exception {
        LOG.info((Object)"------Create new tree------");
        long startTime = System.currentTimeMillis();
        RegTree tree = new RegTree(this.param);
        tree.initTreeNodes();
        this.currentDepth = 1;
        this.forest[this.currentTree] = tree;
        if (this.param.colSample < 1.0f) {
            PSModel featSample = this.model.getPSModel(this.param.sampledFeaturesName);
            IntIntVector sampled = (IntIntVector)featSample.getRow(this.currentTree);
            this.fSet = sampled.getStorage().getValues();
            this.calfPos();
        } else if (this.fSet == null) {
            // No sampling: every feature is used and sits at its own index.
            int featNum = this.trainDataStore.featureMeta.numFeature;
            this.fSet = new int[featNum];
            this.fPos = new int[featNum];
            for (int fid = 0; fid < featNum; ++fid) {
                this.fSet[fid] = fid;
                this.fPos[fid] = fid;
            }
        }
        for (int node = 0; node < this.maxNodeNum; ++node) {
            this.resetActiveTNodes(node);
        }
        this.addActiveNode(0);
        // Every node except the root starts with no rows.
        Arrays.fill(this.nodePosStart, -1);
        Arrays.fill(this.nodePosEnd, -1);
        this.nodePosStart[0] = 0;
        this.nodePosEnd[0] = this.instancePos.length - 1;
        Arrays.fill(this.validInsPos, 0);
        this.calGradPairs();
        LOG.info((Object)String.format("Create new tree cost: %d ms", System.currentTimeMillis() - startTime));
    }

    /**
     * Rebuilds {@code fPos} as the inverse of {@code fSet}: {@code fPos[fSet[k]] = k}. Features
     * that are not in {@code fSet} map to -1.
     */
    public void calfPos() {
        Arrays.fill(this.fPos, -1);
        for (int pos = 0; pos < this.fSet.length; ++pos) {
            this.fPos[this.fSet[pos]] = pos;
        }
    }

    /**
     * Builds the gradient histograms of all active nodes and pushes them to their PS matrices
     * ({@code gradHistNamePrefix + nid}).
     *
     * <p>For each pair of active siblings, the one with fewer rows (ties: lower node id) is put in
     * {@code calNodes}. Its histogram is computed from its rows by {@code HistCalThread} tasks,
     * {@code param.batchSize} rows per task. The other sibling goes to {@code subNodes}. Its
     * histogram starts as a clone of the parent's cached histogram and is then finished by
     * {@code HistSubThread} (presumably parent minus sibling; confirm in HistSubThread). The root is
     * always computed directly. Histograms are pushed with {@code ANGEL_COMPRESS_BYTES} bytes per
     * item; 8 means uncompressed. Cached parent histograms are then dropped.
     */
    public void runActiveNode() throws Exception {
        int nid;
        int nid2;
        LOG.info((Object)"------Run active node------");
        long startTime = System.currentTimeMillis();
        HashSet<String> needFlushMatrixSet = new HashSet<String>();
        HashSet<Integer> calNodes = new HashSet<Integer>();
        HashSet<Integer> subNodes = new HashSet<Integer>();
        for (int nid3 = 0; nid3 < this.maxNodeNum; ++nid3) {
            boolean ltSibling;
            if (this.activeNode[nid3] != 1) continue;
            if (nid3 == 0) {
                calNodes.add(nid3);
                continue;
            }
            int sampleNum = this.nodePosEnd[nid3] - this.nodePosStart[nid3] + 1;
            int parentNid = (nid3 - 1) / 2;
            // For nid = 2p+1 this gives 2p+2, and vice versa: the other child of the same parent.
            int siblingNid = 4 * parentNid + 3 - nid3;
            int siblingSampleNum = this.nodePosEnd[siblingNid] - this.nodePosStart[siblingNid] + 1;
            boolean bl = ltSibling = sampleNum < siblingSampleNum || sampleNum == siblingSampleNum && nid3 < siblingNid;
            if (ltSibling) {
                calNodes.add(nid3);
                subNodes.add(siblingNid);
                continue;
            }
            calNodes.add(siblingNid);
            subNodes.add(nid3);
        }
        // Phase 1: compute the histograms of calNodes directly, in row batches.
        HashMap calFutures = new HashMap();
        Iterator parentNid = calNodes.iterator();
        while (parentNid.hasNext()) {
            nid2 = (Integer)parentNid.next();
            // NOTE(review): the dimension uses fSet.length but the storage is sized with
            // numFeature; with column sampling the storage is larger than the vector's dim.
            this.histCache[nid2] = new IntDoubleVector(this.fSet.length * 2 * this.param.numSplit, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(new double[this.param.numFeature * 2 * this.param.numSplit]));
            calFutures.put(nid2, new ArrayList());
            int nodeStart = this.nodePosStart[nid2];
            int nodeEnd = this.nodePosEnd[nid2];
            // ceil(rowCount / batchSize)
            int batchNum = (nodeEnd - nodeStart + 1) / this.param.batchSize + ((nodeEnd - nodeStart + 1) % this.param.batchSize == 0 ? 0 : 1);
            LOG.info((Object)String.format("Node[%d], start[%d], end[%d], batch[%d]", nid2, nodeStart, nodeEnd, batchNum));
            for (int batch = 0; batch < batchNum; ++batch) {
                int start = nodeStart + batch * this.param.batchSize;
                // NOTE(review): each batch's end equals the next batch's start, but the last end is
                // clamped to the inclusive nodeEnd. Confirm whether HistCalThread treats end as
                // inclusive or exclusive, or a boundary row may be double-counted or dropped.
                int end = nodeStart + (batch + 1) * this.param.batchSize;
                if (end > nodeEnd) {
                    end = nodeEnd;
                }
                LOG.info((Object)String.format("Calculate thread: nid[%d], start[%d], end[%d]", nid2, start, end));
                Future<Boolean> future = this.threadPool.submit(new HistCalThread(this, nid2, start, end));
                ((List)calFutures.get(nid2)).add(future);
            }
        }
        parentNid = calNodes.iterator();
        while (parentNid.hasNext()) {
            nid2 = (Integer)parentNid.next();
            for (Future future : (List)calFutures.get(nid2)) {
                future.get();
            }
        }
        // Phase 2: derive subNodes' histograms from their parent's cached histogram.
        HashMap<Integer, Future<Boolean>> subFutures = new HashMap<Integer, Future<Boolean>>();
        Iterator nid4 = subNodes.iterator();
        while (nid4.hasNext()) {
            int nid5 = (Integer)nid4.next();
            int parentId = (nid5 - 1) / 2;
            this.histCache[nid5] = this.histCache[parentId].clone();
            LOG.info((Object)String.format("Subtract thread: nid[%d]", nid5));
            Future<Boolean> future = this.threadPool.submit(new HistSubThread(this, nid5));
            subFutures.put(nid5, future);
        }
        nid4 = subNodes.iterator();
        while (nid4.hasNext()) {
            int nid6 = (Integer)nid4.next();
            ((Future)subFutures.get(nid6)).get();
        }
        // Phase 3: push every new histogram to the PS, optionally quantized.
        HashSet<Integer> pushNodes = new HashSet<Integer>(calNodes);
        pushNodes.addAll(subNodes);
        int bytesPerItem = this.taskContext.getConf().getInt(MLConf.ANGEL_COMPRESS_BYTES(), MLConf.DEFAULT_ANGEL_COMPRESS_BYTES());
        if (bytesPerItem < 1 || bytesPerItem > 8) {
            LOG.info((Object)("Invalid compress configuration: " + bytesPerItem + ", it should be [1,8]."));
            bytesPerItem = MLConf.DEFAULT_ANGEL_COMPRESS_BYTES();
        }
        Iterator iterator = pushNodes.iterator();
        while (iterator.hasNext()) {
            int nid7 = (Integer)iterator.next();
            this.pushHistogram(nid7, bytesPerItem);
            needFlushMatrixSet.add(this.param.gradHistNamePrefix + nid7);
        }
        // Release the parents' cached histograms; they are no longer needed. The loop stops at
        // node 0, which is only ever active on its own (first depth).
        iterator = calNodes.iterator();
        while (iterator.hasNext() && (nid = ((Integer)iterator.next()).intValue()) != 0) {
            int parentId = (nid - 1) / 2;
            this.histCache[parentId] = null;
        }
        LOG.info((Object)String.format("Run active node cost: %d ms", System.currentTimeMillis() - startTime));
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Sends node {@code nid}'s cached histogram to its PS matrix
     * ({@code gradHistNamePrefix + nid}).
     *
     * <p>With 8 bytes per item the histogram is added as-is. Otherwise it is sent through
     * {@code QuantifyDoubleFunc} using {@code bytesPerItem * 8} bits per value. Any failure is
     * logged and not rethrown.
     *
     * @param nid          node whose histogram is pushed
     * @param bytesPerItem bytes per histogram value, expected in [1, 8]
     */
    private void pushHistogram(int nid, int bytesPerItem) {
        String histParaName = this.param.gradHistNamePrefix + nid;
        PSModel histMat = this.model.getPSModel(histParaName);
        try {
            IntDoubleVector hist = this.histCache[nid];
            if (bytesPerItem != 8) {
                histMat.update((UpdateFunc)new QuantifyDoubleFunc(histMat.getMatrixId(), 0, hist, bytesPerItem * 8));
            } else {
                histMat.increment(0, (Vector)hist);
            }
        } catch (Exception e) {
            LOG.error((Object)(histParaName + " increment failed, "), (Throwable)e);
        }
    }

    /**
     * Finds the best split of every active node this task is responsible for and publishes the
     * results to the PS.
     *
     * <p>Active nodes are dealt out round-robin: the k-th active node (in node-id order) goes to
     * task {@code k % totalTaskNum}. For each of its nodes the task does one of two things,
     * depending on {@code MLConf.ML_GBDT_SERVER_SPLIT()}:
     * <ul>
     *   <li>server split: asks the PS to compute the split;</li>
     *   <li>local split: pulls the node's gradient histogram and searches it locally.</li>
     * </ul>
     * Split features, values and gains are added to this tree's row of their matrices, and the
     * node's histogram matrix is zeroed.
     */
    public void findSplit() throws Exception {
        LOG.info((Object)"------Find split------");
        long startTime = System.currentTimeMillis();
        ArrayList<Integer> responsibleTNode = new ArrayList<Integer>();
        int activeTNodeNum = 0;
        for (int nid = 0; nid < this.activeNode.length; ++nid) {
            int isActive = this.activeNode[nid];
            if (isActive != 1) continue;
            if (this.taskContext.getTaskIndex() == activeTNodeNum) {
                responsibleTNode.add(nid);
            }
            if (++activeTNodeNum < this.taskContext.getTotalTaskNum()) continue;
            activeTNodeNum = 0;
        }
        int[] tNodeId = Maths.intList2Arr(responsibleTNode);
        LOG.info((Object)String.format("Task[%d] responsible tree node: %s", this.taskContext.getTaskId().getIndex(), ((Object)responsibleTNode).toString()));
        int[] updatedIndices = new int[tNodeId.length];
        int[] updatedSplitFid = new int[tNodeId.length];
        double[] updatedSplitFvalue = new double[tNodeId.length];
        double[] updatedSplitGain = new double[tNodeId.length];
        // A single flag decides both how a node's split is fetched and how the result is read.
        // The fetch used this conf value while the interpretation used param.isServerSplit, so a
        // mismatch dereferenced a null histogram or a null split entry.
        boolean isServerSplit = this.taskContext.getConf().getBoolean(MLConf.ML_GBDT_SERVER_SPLIT(), MLConf.DEFAULT_ML_GBDT_SERVER_SPLIT());
        for (int i = 0; i < tNodeId.length; ++i) {
            int nid = tNodeId[i];
            LOG.debug((Object)String.format("Task[%d] find best split of tree node: %d", this.taskContext.getTaskIndex(), nid));
            String gradHistName = this.param.gradHistNamePrefix + nid;
            long pullStartTime = System.currentTimeMillis();
            PSModel histMat = this.model.getPSModel(gradHistName);
            IntDoubleVector histogram = null;
            SplitEntry splitEntry = null;
            if (isServerSplit) {
                int matrixId = histMat.getMatrixId();
                GBDTGradHistGetRowFunc func = new GBDTGradHistGetRowFunc(new HistAggrParam(matrixId, 0, this.param.numSplit, this.param.minChildWeight, this.param.regAlpha, this.param.regLambda));
                splitEntry = ((GBDTGradHistGetRowResult)histMat.get((GetFunc)func)).getSplitEntry();
            } else {
                histogram = (IntDoubleVector)histMat.getRow(0);
                LOG.debug((Object)("Get grad histogram without server split mode, histogram size" + histogram.getDim()));
            }
            LOG.info((Object)String.format("Pull histogram from PS cost %d ms", System.currentTimeMillis() - pullStartTime));
            if (isServerSplit) {
                // The PS reports the feature as a position in fSet and the value as a bucket index
                // into the sketch; translate both into the real feature id and split value.
                if (splitEntry.getFid() != -1) {
                    int trueSplitFid = this.fSet[splitEntry.getFid()];
                    int splitIdx = (int)splitEntry.getFvalue();
                    float trueSplitValue = this.sketches[trueSplitFid * this.param.numSplit + splitIdx];
                    LOG.info((Object)String.format("Best split of node[%d]: feature[%d], value[%f], true feature[%d], true value[%f], losschg[%f]", nid, splitEntry.getFid(), Float.valueOf(splitEntry.getFvalue()), trueSplitFid, Float.valueOf(trueSplitValue), Float.valueOf(splitEntry.getLossChg())));
                    splitEntry.setFid(trueSplitFid);
                    splitEntry.setFvalue(trueSplitValue);
                }
                // The root's statistics are the sum of its two children's statistics.
                if (nid == 0) {
                    GradStats rootStats = new GradStats(splitEntry.leftGradStat);
                    rootStats.add(splitEntry.rightGradStat);
                    this.updateNodeGradStats(nid, rootStats);
                }
                if (splitEntry.fid != -1) {
                    this.updateNodeGradStats(2 * nid + 1, splitEntry.leftGradStat);
                    this.updateNodeGradStats(2 * nid + 2, splitEntry.rightGradStat);
                }
            } else {
                GradHistHelper histHelper = new GradHistHelper(this, nid);
                splitEntry = histHelper.findBestSplit(histogram);
                LOG.info((Object)String.format("Best split of node[%d]: feature[%d], value[%f], losschg[%f]", nid, splitEntry.getFid(), Float.valueOf(splitEntry.getFvalue()), Float.valueOf(splitEntry.getLossChg())));
            }
            updatedIndices[i] = nid;
            updatedSplitFid[i] = splitEntry.fid;
            updatedSplitFvalue[i] = splitEntry.fvalue;
            updatedSplitGain[i] = splitEntry.lossChg;
            histMat.zero();
        }
        IntIntVector splitFeatureVector = new IntIntVector(this.activeNode.length, (IntIntVectorStorage)new IntIntDenseVectorStorage(this.activeNode.length));
        IntDoubleVector splitValueVector = new IntDoubleVector(this.activeNode.length, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(this.activeNode.length));
        IntDoubleVector splitGainVector = new IntDoubleVector(this.activeNode.length, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(this.activeNode.length));
        for (int i = 0; i < updatedIndices.length; ++i) {
            splitFeatureVector.set(updatedIndices[i], updatedSplitFid[i]);
            splitValueVector.set(updatedIndices[i], updatedSplitFvalue[i]);
            splitGainVector.set(updatedIndices[i], updatedSplitGain[i]);
        }
        PSModel splitFeat = this.model.getPSModel(this.param.splitFeaturesName);
        splitFeat.increment(this.currentTree, (Vector)splitFeatureVector);
        PSModel splitValue = this.model.getPSModel(this.param.splitValuesName);
        splitValue.increment(this.currentTree, (Vector)splitValueVector);
        PSModel splitGain = this.model.getPSModel(this.param.splitGainsName);
        splitGain.increment(this.currentTree, (Vector)splitGainVector);
        LOG.info((Object)String.format("Find split cost: %d ms", System.currentTimeMillis() - startTime));
        HashSet<String> needFlushMatrixSet = new HashSet<String>(4);
        needFlushMatrixSet.add(this.param.splitFeaturesName);
        needFlushMatrixSet.add(this.param.splitValuesName);
        needFlushMatrixSet.add(this.param.splitGainsName);
        needFlushMatrixSet.add(this.param.nodeGradStatsName);
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Applies the split decisions published by {@code findSplit()} to this task's view of the tree.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Pull this tree's rows of the split-feature, split-value, split-gain and
     *       node-grad-stats matrices.</li>
     *   <li>Record the split of every node that was active on entry.</li>
     *   <li>Run one {@code AfterSplitThread} per such node on the thread pool.</li>
     *   <li>Once they have all finished, route the validation rows through the new splits,
     *       finish the current depth, and clock the PS matrices.</li>
     * </ol>
     *
     * @throws java.util.concurrent.ExecutionException if an {@code AfterSplitThread} fails
     */
    public void afterSplit() throws Exception {
        LOG.info((Object)"------After split------");
        long startTime = System.currentTimeMillis();
        PSModel splitFeatModel = this.model.getPSModel(this.param.splitFeaturesName);
        IntIntVector splitFeatureVec = (IntIntVector)splitFeatModel.getRow(this.currentTree);
        PSModel splitValueModel = this.model.getPSModel(this.param.splitValuesName);
        IntDoubleVector splitValueVec = (IntDoubleVector)splitValueModel.getRow(this.currentTree);
        PSModel splitGainModel = this.model.getPSModel(this.param.splitGainsName);
        IntDoubleVector splitGainVec = (IntDoubleVector)splitGainModel.getRow(this.currentTree);
        PSModel nodeGradStatsModel = this.model.getPSModel(this.param.nodeGradStatsName);
        IntDoubleVector nodeGradStatsVec = (IntDoubleVector)nodeGradStatsModel.getRow(this.currentTree);
        LOG.info((Object)String.format("Get split result from PS cost %d ms", System.currentTimeMillis() - startTime));
        LOG.debug((Object)String.format("Split active node: %s", Arrays.toString(this.activeNode)));
        // Snapshot: the worker threads may change activeNode while we iterate.
        int[] preActiveNode = (int[])this.activeNode.clone();
        List<Future<?>> splitFutures = new ArrayList<Future<?>>();
        for (int nid = 0; nid < this.maxNodeNum; ++nid) {
            if (preActiveNode[nid] != 1) continue;
            this.splitFeats[nid] = splitFeatureVec.get(nid);
            this.splitValues[nid] = splitValueVec.get(nid);
            // 1 = running; the status check below waits until every node has left this state.
            this.activeNodeStat[nid].set(1);
            AfterSplitThread t = new AfterSplitThread(this, nid, splitFeatureVec, splitValueVec, splitGainVec, nodeGradStatsVec);
            splitFutures.add(this.threadPool.submit(t));
        }
        // Block on the futures instead of only busy-polling activeNodeStat. This does not burn a
        // core, and a task that throws now surfaces here as an ExecutionException. Previously such
        // a task left its status at 1 and the polling loop spun forever.
        for (Future<?> future : splitFutures) {
            future.get();
        }
        // Safety net from the original protocol; normally every status has already left 1 here.
        // NOTE(review): presumably AfterSplitThread resets its status on completion — confirm.
        boolean hasRunning = true;
        while (hasRunning) {
            hasRunning = false;
            for (int nid = 0; nid < this.maxNodeNum; ++nid) {
                if (this.activeNodeStat[nid].get() == 1) {
                    hasRunning = true;
                    break;
                }
            }
            if (hasRunning) {
                LOG.debug((Object)"current has running thread");
                Thread.yield();
            }
        }
        this.updateValidInsPos();
        this.finishCurrentDepth();
        LOG.info((Object)String.format("After split cost: %d ms", System.currentTimeMillis() - startTime));
        HashSet<String> needFlushMatrixSet = new HashSet<String>(4);
        needFlushMatrixSet.add(this.param.splitFeaturesName);
        needFlushMatrixSet.add(this.param.splitValuesName);
        needFlushMatrixSet.add(this.param.splitGainsName);
        needFlushMatrixSet.add(this.param.nodeGradStatsName);
        this.clockAllMatrix(needFlushMatrixSet, true);
    }

    /**
     * Moves every validation instance one level down the current tree.
     *
     * <p>For each validation row, the node it currently sits on is looked up in
     * {@code validInsPos}. If that node was split ({@code splitFeats[node] != -1}), the row moves
     * to child {@code 2 * node + 1} when its feature value is {@code <= splitValues[node]}, and to
     * child {@code 2 * node + 2} otherwise. Rows on unsplit nodes stay where they are.
     */
    private void updateValidInsPos() {
        LOG.info((Object)"Update instance position of validation data");
        LOG.info((Object)String.format("Current split features: %s", Arrays.toString(this.splitFeats)));
        LOG.info((Object)String.format("Current split values: %s", Arrays.toString(this.splitValues)));
        // Clamp the preview: copyOfRange pads with zeros past the end of the array, which would
        // otherwise be logged as bogus positions (node 0) when there are fewer than 10 rows.
        int previewLen = Math.min(10, this.validInsPos.length);
        LOG.info((Object)String.format("Old validation data position: %s", Arrays.toString(Arrays.copyOfRange(this.validInsPos, 0, previewLen))));
        for (int insIdx = 0; insIdx < this.validDataStore.numRow; ++insIdx) {
            int curNode = this.validInsPos[insIdx];
            int splitFeat = this.splitFeats[curNode];
            double splitValue = this.splitValues[curNode];
            if (splitFeat == -1) {
                // Node was not split at this depth; the instance stays on it.
                continue;
            }
            boolean goLeft = (double)this.validDataStore.instances[insIdx].get(splitFeat) <= splitValue;
            this.validInsPos[insIdx] = goLeft ? 2 * curNode + 1 : 2 * curNode + 2;
        }
        LOG.info((Object)String.format("New validation data position: %s", Arrays.toString(Arrays.copyOfRange(this.validInsPos, 0, previewLen))));
    }

    /**
     * Partitions the training instances of node {@code nid} between its two children.
     *
     * <p>The instance indices stored in {@code instancePos[nodePosStart[nid] .. nodePosEnd[nid]]}
     * (both bounds inclusive) are rearranged in place with a two-pointer swap. Afterwards every
     * instance whose {@code splitFeature} value is {@code <= splitValue} precedes every instance
     * whose value is greater. Relative order within each side is not preserved.
     *
     * <p>The resulting inclusive ranges are written to {@code nodePosStart}/{@code nodePosEnd} at
     * index {@code 2 * nid + 1} (left child, values {@code <= splitValue}) and
     * {@code 2 * nid + 2} (right child, values {@code > splitValue}).
     *
     * @param nid          index of the node being split
     * @param splitFeature feature index the split tests
     * @param splitValue   split threshold; values equal to it go to the left child
     */
    public void updateTrainInsPos(int nid, int splitFeature, float splitValue) {
        LOG.debug((Object)String.format("------Update instance position of node[%d] split feature[%d] split value[%f]------", nid, splitFeature, Float.valueOf(splitValue)));
        int nodePosStart = this.nodePosStart[nid];
        int nodePosEnd = this.nodePosEnd[nid];
        LOG.debug((Object)String.format("Node[%d] instance positions: [%d-%d]", nid, nodePosStart, nodePosEnd));
        int left = nodePosStart;
        int right = nodePosEnd;
        if (left > right) {
            // Empty range: both children inherit the same empty [left, right] range.
            LOG.debug((Object)("nodePosStart > nodePosEnd, maybe there is no instance on node:" + nid));
            this.nodePosStart[2 * nid + 1] = left;
            this.nodePosEnd[2 * nid + 1] = right;
            LOG.debug((Object)String.format("Node[%d] instance positions: [%d-%d]", 2 * nid + 1, left, right));
            this.nodePosStart[2 * nid + 2] = left;
            this.nodePosEnd[2 * nid + 2] = right;
            LOG.debug((Object)String.format("Node[%d] instance positions: [%d-%d]", 2 * nid + 2, left, right));
            return;
        }
        while (right > left) {
            int leftInsIdx = this.instancePos[left];
            float leftValue = this.trainDataStore.instances[leftInsIdx].get(splitFeature);
            // Advance `left` past instances that already belong to the left child.
            while (leftValue <= splitValue && left < right) {
                leftInsIdx = this.instancePos[++left];
                leftValue = this.trainDataStore.instances[leftInsIdx].get(splitFeature);
            }
            int rightInsIdx = this.instancePos[right];
            float rightValue = this.trainDataStore.instances[rightInsIdx].get(splitFeature);
            // Retreat `right` past instances that already belong to the right child.
            while (rightValue > splitValue && right > left) {
                rightInsIdx = this.instancePos[--right];
                rightValue = this.trainDataStore.instances[rightInsIdx].get(splitFeature);
            }
            // Pointers met: the outer loop condition now fails and the loop ends.
            if (right <= left) continue;
            // Swap the misplaced pair: a right-bound instance at `left`, a left-bound one at `right`.
            this.instancePos[left] = rightInsIdx;
            this.instancePos[right] = leftInsIdx;
        }
        // Here left == right. The value in that slot decides whether it is the first instance of
        // the right child (cut at left) or the last instance of the left child (cut after left).
        int curInsIdx = this.instancePos[left];
        float curValue = this.trainDataStore.instances[curInsIdx].get(splitFeature);
        int cutPos = curValue > splitValue ? left : left + 1;
        this.nodePosStart[2 * nid + 1] = nodePosStart;
        this.nodePosEnd[2 * nid + 1] = cutPos - 1;
        LOG.debug((Object)String.format("Node[%d] instance positions: [%d-%d]", 2 * nid + 1, nodePosStart, cutPos - 1));
        this.nodePosStart[2 * nid + 2] = cutPos;
        this.nodePosEnd[2 * nid + 2] = nodePosEnd;
        LOG.debug((Object)String.format("Node[%d] instance positions: [%d-%d]", 2 * nid + 2, cutPos, nodePosEnd));
    }

    /**
     * Marks node {@code nid} as active and resets its per-node status counter to 0.
     *
     * @param nid node index
     */
    public void addActiveNode(int nid) {
        activeNode[nid] = 1;
        activeNodeStat[nid].set(0);
    }

    /**
     * Turns node {@code nid} of the tree currently being built into a leaf carrying
     * {@code nodeWeight}.
     *
     * @param nid        node index within the current tree
     * @param nodeWeight leaf value to store
     */
    public void setNodeToLeaf(int nid, float nodeWeight) {
        LOG.debug((Object)String.format("Set node[%d] to leaf node, leaf weight[%f]", nid, Float.valueOf(nodeWeight)));
        int tree = currentTree;
        forest[tree].nodes.get(nid).chgToLeaf();
        forest[tree].nodes.get(nid).setLeafValue(nodeWeight);
    }

    /**
     * Marks node {@code nid} as inactive and resets its per-node status counter to 0.
     *
     * @param nid node index
     */
    public void resetActiveTNodes(int nid) {
        activeNode[nid] = 0;
        activeNodeStat[nid].set(0);
    }

    /** Advances the depth counter of the tree being built by one level. */
    public void finishCurrentDepth() {
        currentDepth += 1;
    }

    /**
     * Completes the current tree.
     *
     * <p>First pushes its leaf predictions, then applies them to instance predictions. Finally it
     * moves on to the next tree, starting again at depth 1.
     *
     * @throws Exception propagated from {@code updateLeafPreds()}
     */
    public void finishCurrentTree() throws Exception {
        updateLeafPreds();
        updateInsPreds();
        currentTree += 1;
        currentDepth = 1;
    }

    /**
     * Records the phase the controller is now in.
     *
     * @param phase new phase
     */
    public void setPhase(GBDTPhase phase) {
        this.phase = phase;
    }

    /**
     * Reports whether any tree node is currently flagged active.
     *
     * @return {@code true} if at least one entry of {@code activeNode} equals 1
     */
    public boolean hasActiveTNode() {
        LOG.debug((Object)String.format("Check active node: %s", Arrays.toString(this.activeNode)));
        for (int i = 0; i < activeNode.length; i++) {
            if (activeNode[i] == 1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reports whether the configured number of trees has been built.
     *
     * @return {@code true} once {@code currentTree} reaches {@code param.treeNum}
     */
    public boolean isFinished() {
        boolean finished = currentTree >= param.treeNum;
        LOG.info((Object)String.format("Check if finished, cur tree[%d], max tree[%d]", currentTree, param.treeNum));
        return finished;
    }

    /**
     * Pushes the gradient statistics of node {@code nid} to the PS as an increment on the
     * current tree's row of the node-grad-stats matrix.
     *
     * <p>The increment vector has length {@code 2 * activeNode.length}. The sum of gradients
     * goes at index {@code nid} and the sum of hessians at {@code nid + activeNode.length}.
     *
     * @param nid       node index
     * @param gradStats gradient/hessian sums of the node
     * @throws Exception propagated from the PS increment
     */
    public void updateNodeGradStats(int nid, GradStats gradStats) throws Exception {
        LOG.debug((Object)String.format("Update gradStats of node[%d]: sumGrad[%f], sumHess[%f]", nid, Float.valueOf(gradStats.sumGrad), Float.valueOf(gradStats.sumHess)));
        int hessOffset = activeNode.length;
        int dim = 2 * hessOffset;
        IntDoubleVector delta = new IntDoubleVector(dim, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(dim));
        delta.set(nid, (double)gradStats.sumGrad);
        delta.set(hessOffset + nid, (double)gradStats.sumHess);
        model.getPSModel(param.nodeGradStatsName).increment(currentTree, (Vector)delta);
    }

    /**
     * Adds the current tree's leaf weights, scaled by the learning rate, to the training and
     * validation predictions.
     *
     * <p>Training instances of a leaf are located through
     * {@code instancePos[nodePosStart[nid] .. nodePosEnd[nid]]}. Both bounds are inclusive, the
     * same convention {@link #updateTrainInsPos} uses when it writes them. Validation instances
     * are located through the node index stored in {@code validInsPos}.
     */
    public void updateInsPreds() {
        LOG.info((Object)"------Update instance predictions------");
        long startTime = System.currentTimeMillis();
        int nodeNum = this.forest[this.currentTree].nodes.size();
        for (int nid = 0; nid < nodeNum; ++nid) {
            if (null == this.forest[this.currentTree].nodes.get(nid) || !this.forest[this.currentTree].nodes.get(nid).isLeaf()) continue;
            float weight = this.forest[this.currentTree].nodes.get(nid).getLeafValue();
            int nodePosStart = this.nodePosStart[nid];
            int nodePosEnd = this.nodePosEnd[nid];
            // nodePosEnd is inclusive (an empty node has nodePosStart > nodePosEnd), so iterate
            // with <=; a strict < would skip the last training instance of every leaf.
            for (int i = nodePosStart; i <= nodePosEnd; ++i) {
                int insIdx = this.instancePos[i];
                this.trainDataStore.preds[insIdx] += this.param.learningRate * weight;
            }
        }
        // Clamp the preview so short arrays are not padded with zeros in the log.
        int previewLen = Math.min(10, this.validDataStore.preds.length);
        LOG.info((Object)String.format("Old validation prediction: %s", Arrays.toString(Arrays.copyOfRange(this.validDataStore.preds, 0, previewLen))));
        for (int insIdx = 0; insIdx < this.validDataStore.numRow; ++insIdx) {
            int nid = this.validInsPos[insIdx];
            float weight = this.forest[this.currentTree].nodes.get(nid).getLeafValue();
            this.validDataStore.preds[insIdx] += this.param.learningRate * weight;
        }
        LOG.info((Object)String.format("New validation prediction: %s", Arrays.toString(Arrays.copyOfRange(this.validDataStore.preds, 0, previewLen))));
        LOG.info((Object)String.format("Update instance predictions cost: %d ms", System.currentTimeMillis() - startTime));
    }

    /**
     * Publishes the leaf weights of the current tree to the PS and clocks the matrices.
     *
     * <p>Only the task with index 0 builds the increment. It is a dense vector of length
     * {@code maxNodeNum} holding each leaf's value at its node index, and is applied to the
     * current tree's row of the node-preds matrix. Every task then calls {@code clockAllMatrix},
     * passing the node-preds matrix name only when it pushed an update.
     *
     * @throws Exception propagated from the PS increment or clock
     */
    public void updateLeafPreds() throws Exception {
        LOG.info((Object)"------Update leaf node predictions------");
        long begin = System.currentTimeMillis();
        HashSet<String> toFlush = new HashSet<String>(1);
        boolean isFirstTask = this.taskContext.getTaskIndex() == 0;
        if (isFirstTask) {
            int treeSize = this.forest[this.currentTree].nodes.size();
            IntDoubleVector leafWeights = new IntDoubleVector(this.maxNodeNum, (IntDoubleVectorStorage)new IntDoubleDenseVectorStorage(this.maxNodeNum));
            for (int node = 0; node < treeSize; ++node) {
                boolean isLeaf = this.forest[this.currentTree].nodes.get(node) != null
                        && this.forest[this.currentTree].nodes.get(node).isLeaf();
                if (isLeaf) {
                    float leafValue = this.forest[this.currentTree].nodes.get(node).getLeafValue();
                    LOG.debug((Object)String.format("Leaf weight of node[%d]: %f", node, Float.valueOf(leafValue)));
                    leafWeights.set(node, (double)leafValue);
                }
            }
            this.model.getPSModel(this.param.nodePredsName).increment(this.currentTree, (Vector)leafWeights);
            toFlush.add(this.param.nodePredsName);
        }
        this.clockAllMatrix(toFlush, true);
        LOG.info((Object)String.format("Update leaf node predictions cost: %d ms", System.currentTimeMillis() - begin));
    }

    /**
     * Evaluates the configured metric on the training predictions after the current tree.
     *
     * @return the training error, boxed as a {@code Double}
     */
    public Tuple1<Double> eval() {
        LOG.info((Object)"------Evaluation------");
        long startTime = System.currentTimeMillis();
        EvalMetric evalMetric = this.param.getEvalMetric();
        float error = evalMetric.eval(this.trainDataStore.preds, this.trainDataStore.labels);
        LOG.info((Object)String.format("Error after tree[%d]: %f", this.currentTree, Float.valueOf(error)));
        LOG.info((Object)String.format("Evaluation cost: %d ms", System.currentTimeMillis() - startTime));
        // Box as Double so the element matches the declared Tuple1<Double>. The previous raw
        // `new Tuple1((Object) error)` stored a java.lang.Float, which makes any caller that
        // reads _1 as a Double fail with ClassCastException.
        return new Tuple1<Double>(Double.valueOf((double)error));
    }

    /**
     * Evaluates the configured metric on the validation predictions after the current tree.
     *
     * @return the validation error, boxed as a {@code Double}
     */
    public Tuple1<Double> predict() {
        LOG.info((Object)"------Predict------");
        long startTime = System.currentTimeMillis();
        EvalMetric evalMetric = this.param.getEvalMetric();
        float error = evalMetric.eval(this.validDataStore.preds, this.validDataStore.labels);
        LOG.info((Object)String.format("Error after tree[%d]: %f", this.currentTree, Float.valueOf(error)));
        LOG.info((Object)String.format("Evaluation cost: %d ms", System.currentTimeMillis() - startTime));
        // Box as Double so the element matches the declared Tuple1<Double>. The previous raw
        // `new Tuple1((Object) error)` stored a java.lang.Float, which makes any caller that
        // reads _1 as a Double fail with ClassCastException.
        return new Tuple1<Double>(Double.valueOf((double)error));
    }

    /**
     * Routes one instance down a tree and returns the prediction stored for the node where
     * routing stops.
     *
     * <p>Starting at node 0, the instance goes to child {@code 2 * n + 1} when
     * {@code ins.get(splitFeat) <= splitValue}, otherwise to {@code 2 * n + 2}. Routing stops at
     * a node that is null or a leaf in {@code forest[currentTree]}, whose split feature is
     * {@code -1}, or whose index is not below {@code splitFeatVec.getDim()}.
     *
     * <p>NOTE(review): the tree shape is read from {@code forest[currentTree]}, while split data
     * comes from the vector arguments. Also, the index bound is checked only after the vectors
     * have already been read at that index, so it does not guard those reads. Confirm that
     * callers never route past the last level.
     *
     * @param splitFeatVec  split feature per node
     * @param splitValueVec split threshold per node
     * @param nodePredVec   prediction per node
     * @param ins           feature values of the instance
     * @return the prediction at the node where routing stopped
     */
    public double treePred(IntIntVector splitFeatVec, IntDoubleVector splitValueVec, IntDoubleVector nodePredVec, IntDoubleVector ins) {
        assert (splitFeatVec.getDim() == splitValueVec.getDim() && splitValueVec.getDim() == nodePredVec.getDim());
        int node = 0;
        int feature = splitFeatVec.get(node);
        double threshold = splitValueVec.get(node);
        double prediction = nodePredVec.get(node);
        for (;;) {
            boolean canDescend = this.forest[this.currentTree].nodes.get(node) != null
                    && !this.forest[this.currentTree].nodes.get(node).isLeaf()
                    && feature != -1
                    && node < splitFeatVec.getDim();
            if (!canDescend) {
                return prediction;
            }
            boolean goLeft = ins.get(feature) <= threshold;
            node = goLeft ? 2 * node + 1 : 2 * node + 2;
            feature = splitFeatVec.get(node);
            threshold = splitValueVec.get(node);
            prediction = nodePredVec.get(node);
        }
    }
}

