/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.psf;

import com.tencent.angel.ml.GBDT.algo.RegTree.GradHistHelper;
import com.tencent.angel.ml.GBDT.algo.RegTree.GradStats;
import com.tencent.angel.ml.GBDT.algo.tree.SplitEntry;
import com.tencent.angel.ml.GBDT.param.GBDTParam;
import com.tencent.angel.ml.GBDT.psf.GBDTGradHistGetRowResult;
import com.tencent.angel.ml.GBDT.psf.HistAggrParam;
import com.tencent.angel.ml.matrix.RowType;
import com.tencent.angel.ml.matrix.psf.get.base.GetFunc;
import com.tencent.angel.ml.matrix.psf.get.base.GetParam;
import com.tencent.angel.ml.matrix.psf.get.base.GetResult;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetParam;
import com.tencent.angel.ml.matrix.psf.get.base.PartitionGetResult;
import com.tencent.angel.ml.matrix.psf.get.getrow.PartitionGetRowResult;
import com.tencent.angel.ps.storage.vector.ServerIntDoubleRow;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.psagent.matrix.ResponseType;
import com.tencent.angel.psagent.matrix.transport.router.RouterType;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class GBDTGradHistGetRowFunc
extends GetFunc {
    private static final Log LOG = LogFactory.getLog(GBDTGradHistGetRowFunc.class);

    public GBDTGradHistGetRowFunc() {
        super(null);
    }

    public GBDTGradHistGetRowFunc(HistAggrParam param) {
        super((GetParam)param);
    }

    public PartitionGetResult partitionGet(PartitionGetParam partParam) {
        HistAggrParam.HistPartitionAggrParam param = (HistAggrParam.HistPartitionAggrParam)partParam;
        LOG.info((Object)"For the gradient histogram of GBDT, we use PS to find the optimal split");
        GBDTParam gbtparam = new GBDTParam();
        gbtparam.numSplit = param.getSplitNum();
        gbtparam.minChildWeight = param.getMinChildWeight();
        gbtparam.regAlpha = param.getRegAlpha();
        gbtparam.regLambda = param.getRegLambda();
        ServerIntDoubleRow row = (ServerIntDoubleRow)this.psContext.getMatrixStorageManager().getRow(param.getMatrixId(), param.getRowId(), param.getPartKey().getPartitionId());
        SplitEntry splitEntry = GradHistHelper.findSplitOfServerRow(row, gbtparam);
        int fid = splitEntry.getFid();
        int splitIndex = (int)splitEntry.getFvalue();
        double lossGain = splitEntry.getLossChg();
        GradStats leftGradStat = splitEntry.leftGradStat;
        GradStats rightGradStat = splitEntry.rightGradStat;
        double leftSumGrad = leftGradStat.sumGrad;
        double leftSumHess = leftGradStat.sumHess;
        double rightSumGrad = rightGradStat.sumGrad;
        double rightSumHess = rightGradStat.sumHess;
        LOG.info((Object)String.format("split of matrix[%d] part[%d] row[%d]: fid[%d], split index[%d], loss gain[%f], left sumGrad[%f], left sum hess[%f], right sumGrad[%f], right sum hess[%f]", param.getMatrixId(), param.getPartKey().getPartitionId(), param.getRowId(), fid, splitIndex, lossGain, leftSumGrad, leftSumHess, rightSumGrad, rightSumHess));
        int startFid = (int)row.getStartCol() / (2 * gbtparam.numSplit);
        int sendStartCol = startFid * 7;
        int sendEndCol = sendStartCol + 7;
        ServerIntDoubleRow sendRow = new ServerIntDoubleRow(param.getRowId(), RowType.T_DOUBLE_DENSE, sendStartCol, sendEndCol, sendEndCol - sendStartCol, RouterType.RANGE);
        LOG.info((Object)String.format("Create server row of split result: row id[%d], start col[%d], end col[%d]", param.getRowId(), sendStartCol, sendEndCol));
        sendRow.set(0 + sendStartCol, (double)fid);
        sendRow.set(1 + sendStartCol, (double)splitIndex);
        sendRow.set(2 + sendStartCol, lossGain);
        sendRow.set(3 + sendStartCol, leftSumGrad);
        sendRow.set(4 + sendStartCol, leftSumHess);
        sendRow.set(5 + sendStartCol, rightSumGrad);
        sendRow.set(6 + sendStartCol, rightSumHess);
        return new PartitionGetRowResult((ServerRow)sendRow);
    }

    public GetResult merge(List<PartitionGetResult> partResults) {
        int size = partResults.size();
        ArrayList<ServerRow> rowSplits = new ArrayList<ServerRow>(size);
        for (int i = 0; i < size; ++i) {
            rowSplits.add(((PartitionGetRowResult)partResults.get(i)).getRowSplit());
        }
        SplitEntry splitEntry = new SplitEntry();
        for (int i = 0; i < size; ++i) {
            ServerIntDoubleRow row = (ServerIntDoubleRow)((PartitionGetRowResult)partResults.get(i)).getRowSplit();
            int fid = (int)row.get(0 + (int)row.getStartCol());
            if (fid == -1) continue;
            int splitIndex = (int)row.get(1 + (int)row.getStartCol());
            float lossGain = (float)row.get(2 + (int)row.getStartCol());
            float leftSumGrad = (float)row.get(3 + (int)row.getStartCol());
            float leftSumHess = (float)row.get(4 + (int)row.getStartCol());
            float rightSumGrad = (float)row.get(5 + (int)row.getStartCol());
            float rightSumHess = (float)row.get(6 + (int)row.getStartCol());
            LOG.info((Object)String.format("psFunc: the best split after looping a split: fid[%d], fvalue[%d], loss gain[%f], leftSumGrad[%f], leftSumHess[%f], rightSumGrad[%f], rightSumHess[%f]", fid, splitIndex, Float.valueOf(lossGain), Float.valueOf(leftSumGrad), Float.valueOf(leftSumHess), Float.valueOf(rightSumGrad), Float.valueOf(rightSumHess)));
            GradStats curLeftGradStat = new GradStats(leftSumGrad, leftSumHess);
            GradStats curRightGradStat = new GradStats(rightSumGrad, rightSumHess);
            SplitEntry curSplitEntry = new SplitEntry(fid, splitIndex, lossGain);
            curSplitEntry.leftGradStat = curLeftGradStat;
            curSplitEntry.rightGradStat = curRightGradStat;
            splitEntry.update(curSplitEntry);
        }
        return new GBDTGradHistGetRowResult(ResponseType.SUCCESS, splitEntry);
    }
}

