/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.objective;

import com.tencent.angel.ml.GBDT.algo.RegTree.GradPair;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTDataStore;
import com.tencent.angel.ml.GBDT.objective.ObjFunc;
import com.tencent.angel.ml.GBDT.param.RegTParam;
import com.tencent.angel.ml.core.utils.Maths;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Softmax objective for multi-class GBDT training.
 *
 * <p>Predictions are laid out instance-major: the {@code numClass} scores of instance
 * {@code i} occupy {@code preds[i * numClass .. i * numClass + numClass - 1]}.
 */
public class SoftmaxMultiClassObj implements ObjFunc {
    private static final Log LOG = LogFactory.getLog(SoftmaxMultiClassObj.class);

    /** Training parameters this objective was built from. */
    public RegTParam param;
    /** Number of classes, copied from {@code param.numClass} at construction time. */
    public int numClass;
    /** Value passed as the {@code prob} flag of {@link #transform} by {@link #transPred}. */
    private final boolean outputProb;

    /**
     * @param param      training parameters; {@code param.numClass} is read once here
     * @param outputProb flag forwarded to {@link #transform} by {@link #transPred}
     */
    public SoftmaxMultiClassObj(RegTParam param, boolean outputProb) {
        this.param = param;
        this.numClass = param.numClass;
        this.outputProb = outputProb;
    }

    /** Equivalent to {@code new SoftmaxMultiClassObj(param, true)}. */
    public SoftmaxMultiClassObj(RegTParam param) {
        this(param, true);
    }

    /**
     * Computes one gradient pair per (instance, class) slot of {@code preds}.
     *
     * <p>For each instance the class scores are copied out and passed through
     * {@code Maths.softmax} (presumably an in-place softmax; its result is read back from the
     * same buffer). With {@code p} the resulting value for class {@code k} and {@code wt} the
     * instance weight, the gradient is {@code (p - 1) * wt} for the labelled class and
     * {@code p * wt} otherwise; the second-order term is {@code 2 * p * (1 - p) * wt}.
     *
     * <p>A label outside {@code [0, numClass)} is treated as class 0 for the gradient
     * computation, and a single error is logged after the loop reporting the most recently
     * seen invalid label.
     *
     * @param preds     raw scores, length {@code numClass * dataStore.labels.length}
     * @param dataStore source of labels and instance weights
     * @param iteration not used by this objective
     * @return array parallel to {@code preds} holding the gradient pairs
     */
    @Override
    public GradPair[] calGrad(float[] preds, RegTDataStore dataStore, int iteration) {
        assert (preds.length == this.numClass * dataStore.labels.length);
        int ndata = preds.length / this.numClass;
        GradPair[] rec = new GradPair[preds.length];
        // A separate flag is required: a sentinel value in labelError cannot be told apart
        // from a genuinely invalid label such as -1.
        boolean invalidLabelSeen = false;
        int labelError = -1;
        float[] tmp = new float[this.numClass];
        for (int insIdx = 0; insIdx < ndata; ++insIdx) {
            int offset = insIdx * this.numClass;
            System.arraycopy(preds, offset, tmp, 0, this.numClass);
            Maths.softmax(tmp);
            int label = (int) dataStore.labels[insIdx];
            if (label < 0 || label >= this.numClass) {
                invalidLabelSeen = true;
                labelError = label;
                // Fall back to class 0 so the gradient array is still fully populated.
                label = 0;
            }
            float wt = dataStore.getWeight(insIdx);
            for (int k = 0; k < this.numClass; ++k) {
                float p = tmp[k];
                float h = 2.0f * p * (1.0f - p) * wt;
                float g = (label == k ? p - 1.0f : p) * wt;
                rec[offset + k] = new GradPair(g, h);
            }
        }
        if (invalidLabelSeen) {
            LOG.error(String.format("SoftmaxMultiClassObj: label must be in [0, num_class), numClass = %d, but found %d in label", this.numClass, labelError));
        }
        return rec;
    }

    /**
     * Transforms raw per-class scores.
     *
     * <p>When {@code prob} is {@code true}, each instance's slice of {@code preds} is
     * overwritten with the output of {@code Maths.softmax} and the returned array is left at
     * its default zeros. When {@code prob} is {@code false}, {@code preds} is not modified and
     * the returned array holds {@code Maths.findMaxIndex} of each instance's scores.
     *
     * @param preds raw scores, {@code numClass} consecutive entries per instance
     * @param prob  selects between the two modes described above
     * @return array with one entry per instance
     */
    public float[] transform(float[] preds, boolean prob) {
        int ndata = preds.length / this.numClass;
        float[] rec = new float[ndata];
        float[] tmp = new float[this.numClass];
        for (int insIdx = 0; insIdx < ndata; ++insIdx) {
            int offset = insIdx * this.numClass;
            System.arraycopy(preds, offset, tmp, 0, this.numClass);
            if (!prob) {
                rec[insIdx] = Maths.findMaxIndex(tmp);
                continue;
            }
            Maths.softmax(tmp);
            System.arraycopy(tmp, 0, preds, offset, this.numClass);
        }
        return rec;
    }

    /** @return the metric name {@code "merror"} */
    @Override
    public String defaultEvalMetric() {
        return "merror";
    }

    /**
     * Applies {@link #transform} with the {@code outputProb} flag given at construction.
     *
     * <p>NOTE(review): the array returned by {@code transform} is discarded here, so when
     * {@code outputProb} is {@code false} this call leaves {@code preds} unchanged — confirm
     * that callers expect this.
     */
    @Override
    public void transPred(float[] preds) {
        this.transform(preds, this.outputProb);
    }

    /** Applies {@link #transform} with {@code prob = true}, overwriting {@code preds} in place. */
    @Override
    public void transEval(float[] preds) {
        this.transform(preds, true);
    }

    /** Always returns {@code 0.0f}; {@code base_score} is ignored. */
    @Override
    public float prob2Margin(float base_score) {
        return 0.0f;
    }
}

