/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.objective;

import com.tencent.angel.ml.GBDT.algo.RegTree.GradPair;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTDataStore;
import com.tencent.angel.ml.GBDT.objective.LossHelper;
import com.tencent.angel.ml.GBDT.objective.ObjFunc;
import java.util.Objects;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * {@link ObjFunc} implementation that delegates prediction transforms, gradient computation,
 * label validation and metric selection to a {@link LossHelper}.
 *
 * <p>When computing gradients, each instance's weight is taken from the data store; instances
 * whose label is exactly {@code 1.0f} additionally have their weight multiplied by
 * {@code scalePosWeight} (default {@code 1.0f}, i.e. no extra weighting).
 */
public class RegLossObj
implements ObjFunc {
    private static final Log LOG = LogFactory.getLog(RegLossObj.class);
    private final LossHelper loss;
    private final float scalePosWeight;

    /**
     * Creates an objective that applies no extra weighting to positive instances.
     *
     * @param loss helper supplying the prediction transform, gradients and label checks;
     *             must not be {@code null}
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public RegLossObj(LossHelper loss) {
        this(loss, 1.0f);
    }

    /**
     * Creates an objective that multiplies the weight of instances labelled {@code 1.0f}
     * by {@code scalePosWeight}.
     *
     * @param loss           helper supplying the prediction transform, gradients and label
     *                       checks; must not be {@code null}
     * @param scalePosWeight multiplier applied to the weight of positive instances; must be a
     *                       non-negative number
     * @throws NullPointerException     if {@code loss} is {@code null}
     * @throws IllegalArgumentException if {@code scalePosWeight} is NaN or negative
     */
    public RegLossObj(LossHelper loss, float scalePosWeight) {
        this.loss = Objects.requireNonNull(loss, "loss");
        if (Float.isNaN(scalePosWeight) || scalePosWeight < 0.0f) {
            throw new IllegalArgumentException(
                "scalePosWeight must be a non-negative number, got " + scalePosWeight);
        }
        this.scalePosWeight = scalePosWeight;
    }

    /**
     * Computes weighted first- and second-order gradients for every instance.
     *
     * <p>For instance {@code i}, the raw prediction is first passed through
     * {@link LossHelper#transPred}; both gradients are then multiplied by the instance weight
     * (scaled by {@code scalePosWeight} when the label is {@code 1.0f}). If any label fails
     * {@link LossHelper#checkLabel}, a single error is logged after the loop; gradients are
     * still returned for all instances.
     *
     * @param preds     raw predictions, one per instance
     * @param dataStore source of labels and instance weights; {@code labels} must have the same
     *                  length as {@code preds}
     * @param iteration current boosting iteration (not used by this objective)
     * @return one gradient pair per instance, in input order (empty for empty input)
     * @throws IllegalArgumentException if {@code preds} and the labels differ in length
     */
    @Override
    public GradPair[] calGrad(float[] preds, RegTDataStore dataStore, int iteration) {
        int ndata = preds.length;
        // Explicit check rather than assert: asserts are normally disabled, and a mismatch
        // would otherwise either overrun the labels or silently ignore trailing ones.
        if (ndata != dataStore.labels.length) {
            throw new IllegalArgumentException("preds length " + ndata
                + " does not match labels length " + dataStore.labels.length);
        }
        GradPair[] rec = new GradPair[ndata];
        boolean labelCorrect = true;
        for (int i = 0; i < ndata; ++i) {
            float label = dataStore.labels[i];
            float p = this.loss.transPred(preds[i]);
            float w = dataStore.getWeight(i);
            if (label == 1.0f) {
                // Positive instances receive the extra class weight.
                w *= this.scalePosWeight;
            }
            if (!this.loss.checkLabel(label)) {
                labelCorrect = false;
            }
            rec[i] = new GradPair(this.loss.firOrderGrad(p, label) * w,
                this.loss.secOrderGrad(p, label) * w);
        }
        if (!labelCorrect) {
            // Reported once per call rather than once per bad label.
            LOG.error(this.loss.labelErrorMsg());
        }
        return rec;
    }

    /**
     * Applies {@link LossHelper#transPred} to every element of {@code preds}, in place.
     *
     * @param preds raw predictions; overwritten with the transformed values
     */
    @Override
    public void transPred(float[] preds) {
        int ndata = preds.length;
        for (int j = 0; j < ndata; ++j) {
            preds[j] = this.loss.transPred(preds[j]);
        }
    }

    /**
     * Transforms predictions for evaluation; identical to {@link #transPred(float[])}.
     *
     * @param preds raw predictions; overwritten with the transformed values
     */
    @Override
    public void transEval(float[] preds) {
        this.transPred(preds);
    }

    /**
     * Converts a base score to a margin by delegating to {@link LossHelper#prob2Margin}.
     *
     * @param base_score base score to convert
     * @return the value returned by the loss helper
     */
    @Override
    public float prob2Margin(float base_score) {
        return this.loss.prob2Margin(base_score);
    }

    /**
     * Returns the loss helper's default evaluation metric name.
     *
     * @return the value of {@link LossHelper#defaultEvalMetric()}
     */
    @Override
    public String defaultEvalMetric() {
        return this.loss.defaultEvalMetric();
    }
}

