/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictorFactory;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

public class UpdateAuxTreeWeightsTask
extends MRTask<UpdateAuxTreeWeightsTask> {
    private final DistributionFamily _dist;
    private final Predictor _p;
    private final DataInfo _di;
    private final boolean _sparse;
    private double[][] _nodeWeights;

    public UpdateAuxTreeWeightsTask(DistributionFamily dist, DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output) {
        this._dist = dist;
        this._p = PredictorFactory.makePredictor(modelInfo._boosterBytes, null, false);
        this._di = di;
        this._sparse = output._sparse;
        if (this._p.getNumClass() > 2) {
            throw new UnsupportedOperationException("Updating tree weights is currently not supported for multinomial models.");
        }
        if (this._dist != DistributionFamily.gaussian && this._dist != DistributionFamily.bernoulli) {
            throw new UnsupportedOperationException("Updating tree weights is currently not supported for distribution " + this._dist + ".");
        }
    }

    private double[][] initNodeWeights() {
        GBTree gbTree = (GBTree)this._p.getBooster();
        RegTree[] trees = gbTree.getGroupedTrees()[0];
        double[][] nodeWeights = new double[trees.length][];
        for (int i = 0; i < trees.length; ++i) {
            nodeWeights[i] = new double[trees[i].getStats().length];
        }
        return nodeWeights;
    }

    public void map(Chunk[] chks, NewChunk[] idx) {
        this._nodeWeights = this.initNodeWeights();
        LinkFunction logit = LinkFunctionFactory.getLinkFunction((LinkFunctionType)LinkFunctionType.logit);
        RegTree[] trees = ((GBTree)this._p.getBooster()).getGroupedTrees()[0];
        MutableOneHotEncoderFVec inputVec = new MutableOneHotEncoderFVec(this._di, this._sparse);
        int inputLength = chks.length - 1;
        int weightIndex = chks.length - 1;
        double[] input = new double[inputLength];
        for (int row = 0; row < chks[0]._len; ++row) {
            double weight = chks[weightIndex].atd(row);
            if (weight == 0.0 || Double.isNaN(weight)) continue;
            for (int i = 0; i < input.length; ++i) {
                input[i] = chks[i].atd(row);
            }
            inputVec.setInput(input);
            int ntrees = this._nodeWeights.length;
            int[] leafIdx = this._p.getBooster().predictLeaf((FVec)inputVec, ntrees);
            assert (leafIdx.length == ntrees) : "Leaf indices (#idx=" + leafIdx.length + ") were not returned for all trees (#trees=" + ntrees + ").";
            if (this._dist == DistributionFamily.gaussian) {
                for (int i = 0; i < leafIdx.length; ++i) {
                    double[] dArray = this._nodeWeights[i];
                    int n = leafIdx[i];
                    dArray[n] = dArray[n] + weight;
                }
                continue;
            }
            assert (this._dist == DistributionFamily.bernoulli);
            double f = -this._p.getBaseScore();
            for (int i = 0; i < leafIdx.length; ++i) {
                RegTreeNode[] nodes = trees[i].getNodes();
                double p = logit.linkInv(f);
                double hessian = p * (1.0 - p);
                double[] dArray = this._nodeWeights[i];
                int n = leafIdx[i];
                dArray[n] = dArray[n] + weight * hessian;
                f += (double)nodes[leafIdx[i]].getLeafValue();
            }
        }
    }

    public void reduce(UpdateAuxTreeWeightsTask mrt) {
        ArrayUtils.add((double[][])this._nodeWeights, (double[][])mrt._nodeWeights);
    }

    protected void postGlobal() {
        GBTree gbTree = (GBTree)this._p.getBooster();
        RegTree[] trees = gbTree.getGroupedTrees()[0];
        for (int i = 0; i < trees.length; ++i) {
            RegTreeNode[] nodes = trees[i].getNodes();
            for (int j = nodes.length - 1; j >= 0; --j) {
                RegTreeNode node = nodes[j];
                int parentId = node.getParentIndex();
                if (parentId < 0) continue;
                assert (parentId < j) : "Broken tree #" + i + ". Tree rollups assume parentId (=" + parentId + " < childId (=" + j + ").";
                RegTreeNode parent = nodes[parentId];
                this._nodeWeights[i][parentId] = this._nodeWeights[i][parent.getLeftChildIndex()] + this._nodeWeights[i][parent.getRightChildIndex()];
            }
        }
    }

    public double[][] getNodeWeights() {
        return this._nodeWeights;
    }
}

