/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT.param;

import com.tencent.angel.ml.GBDT.param.TrainParam;
import com.tencent.angel.ml.core.conf.MLConf;
import com.tencent.angel.ml.core.utils.Maths;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Tree hyper-parameters plus the weight and gain formulas that are evaluated
 * from a gradient sum and a hessian sum.
 *
 * <p>All settings are plain public fields. Several of them are initialised
 * from the defaults exposed by {@link MLConf}; the rest start from the
 * literal values written below.
 */
public class RegTParam implements TrainParam {
    private static final Log LOG = LogFactory.getLog(RegTParam.class);

    /** Learning rate; starts at the {@code MLConf} default. */
    public float learningRate = (float)MLConf.DEFAULT_ML_LEARN_RATE();
    /** Number of classes. */
    public int numClass = 2;
    /** Threshold used by {@link #needPrune(float, int)}. */
    public float minSplitLoss = 0.0f;
    /** Maximum tree depth; starts at the {@code MLConf} default. */
    public int maxDepth = MLConf.DEFAULT_ML_GBDT_TREE_DEPTH();
    /** Number of features. */
    public int numFeature;
    /** Number of non-zero entries. */
    public int numNonzero;
    /** Number of split candidates; starts at the {@code MLConf} default. */
    public int numSplit = MLConf.DEFAULT_ML_GBDT_SPLIT_NUM();
    /** Base weight. */
    public float baseWeight = 0.0f;
    /** Smallest hessian sum for which a weight or gain is computed. */
    public float minChildWeight = (float)MLConf.DEFAULT_ML_GBDT_MIN_CHILD_WEIGHT();
    /** Term added to the hessian sum in the weight and gain formulas. */
    public float regLambda = (float)MLConf.DEFAULT_ML_GBDT_REG_LAMBDA();
    /** Threshold handed to {@code Maths.thresholdL1}; exactly 0 skips that call. */
    public float regAlpha = MLConf.DEFAULT_ML_GBDT_REG_ALPHA();
    /** Default direction. */
    public int defaultDirection;
    /** When non-zero, {@link #calcWeight(float, float)} limits its result with this value. */
    public float maxDeltaStep = 0.0f;
    /** Row sampling ratio. */
    public float rowSample = 1.0f;
    /** Column sampling ratio. */
    public float colSample = 1.0f;
    /** Divisor used by {@link #maxSketchSize()}. */
    public float sketchEps = 0.03f;
    /** Dividend used by {@link #maxSketchSize()}. */
    public float sketchRatio = 2.0f;
    /** Size of the leaf vector. */
    public int sizeLeafVector = 0;
    /** Parallel option. */
    public int parallelOption = 0;
    /** Cache optimisation flag. */
    public boolean cacheOpt = true;
    /** Silent flag. */
    public boolean silent = false;

    /**
     * Logs the depth, minimum split loss and the two sampling ratios at INFO level.
     */
    @Override
    public void printParam() {
        String summary = String.format(
            "Tree hyper-parameters------maxdepth: %d, minSplitLoss: %f, rowSample: %f, colSample: %f",
            this.maxDepth, this.minSplitLoss, this.rowSample, this.colSample);
        LOG.info(summary);
    }

    /**
     * Computes a weight from the given statistics.
     *
     * @param sumGrad sum of gradients
     * @param sumHess sum of hessians
     * @return 0 when {@code sumHess} is below {@link #minChildWeight}; otherwise the
     *         negated (optionally L1-thresholded) gradient sum divided by
     *         {@code sumHess + regLambda}, limited by {@link #maxDeltaStep} when
     *         that field is non-zero
     */
    public float calcWeight(float sumGrad, float sumHess) {
        if (sumHess < this.minChildWeight) {
            return 0.0f;
        }
        float numerator;
        if (this.regAlpha == 0.0f) {
            numerator = sumGrad;
        } else {
            numerator = Maths.thresholdL1(sumGrad, this.regAlpha);
        }
        float step = -numerator / (sumHess + this.regLambda);
        if (this.maxDeltaStep == 0.0f) {
            // No step limit configured.
            return step;
        }
        // The two checks are applied one after the other, upper bound first.
        if (step > this.maxDeltaStep) {
            step = this.maxDeltaStep;
        }
        if (step < -this.maxDeltaStep) {
            step = -this.maxDeltaStep;
        }
        return step;
    }

    /**
     * Computes the gain for the given statistics.
     *
     * @param sumGrad sum of gradients
     * @param sumHess sum of hessians
     * @return 0 when {@code sumHess} is below {@link #minChildWeight}; a closed-form
     *         value when {@link #maxDeltaStep} is 0; otherwise the gain evaluated
     *         at the weight returned by {@link #calcWeight(float, float)}
     */
    public float calcGain(float sumGrad, float sumHess) {
        if (sumHess < this.minChildWeight) {
            return 0.0f;
        }
        if (this.maxDeltaStep != 0.0f) {
            // The weight may have been limited, so evaluate the gain at that weight.
            float weight = this.calcWeight(sumGrad, sumHess);
            return this.gainAtWeight(sumGrad, sumHess, weight);
        }
        if (this.regAlpha == 0.0f) {
            return sumGrad / (sumHess + this.regLambda) * sumGrad;
        }
        return Maths.sqr(Maths.thresholdL1(sumGrad, this.regAlpha)) / (sumHess + this.regLambda);
    }

    /**
     * Computes the gain on one set of statistics using the weight derived from another.
     *
     * @param sumGrad  gradient sum used to derive the weight
     * @param sumHess  hessian sum used to derive the weight
     * @param testGrad gradient sum the gain is evaluated on
     * @param testHess hessian sum the gain is evaluated on
     * @return the gain of the weight from {@code (sumGrad, sumHess)} evaluated on
     *         {@code (testGrad, testHess)}
     */
    public float calcGain(float sumGrad, float sumHess, float testGrad, float testHess) {
        float weight = this.calcWeight(sumGrad, sumHess);
        return this.gainAtWeight(testGrad, testHess, weight);
    }

    /**
     * Evaluates {@code -2 * (grad * w + 0.5 * (hess + regLambda) * w^2)}, adding
     * {@code regAlpha * |w|} inside the parentheses when {@link #regAlpha} is non-zero.
     */
    private float gainAtWeight(float grad, float hess, float w) {
        float objective = grad * w + 0.5f * (hess + this.regLambda) * Maths.sqr(w);
        if (this.regAlpha != 0.0f) {
            objective = objective + this.regAlpha * Math.abs(w);
        }
        return -2.0f * objective;
    }

    /**
     * Tells whether a split should be pruned.
     *
     * @param lossChg loss change of the split
     * @param depth   not used by this implementation
     * @return {@code true} when {@code lossChg} is below {@link #minSplitLoss}
     */
    public boolean needPrune(float lossChg, int depth) {
        float threshold = this.minSplitLoss;
        return lossChg < threshold;
    }

    /**
     * Tells whether a node with the given hessian sum cannot be split.
     *
     * @param sumHess sum of hessians of the node
     * @param depth   not used by this implementation
     * @return {@code true} when {@code sumHess} is below twice {@link #minChildWeight}
     */
    public boolean cannotSplit(float sumHess, int depth) {
        float required = this.minChildWeight * 2.0f;
        return sumHess < required;
    }

    /**
     * @return {@code sketchRatio / sketchEps}, converted to {@code int} by a cast
     */
    public int maxSketchSize() {
        return (int)(this.sketchRatio / this.sketchEps);
    }
}

