/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.GBDT;

import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.GBDT.GBDTModel;
import com.tencent.angel.ml.GBDT.GBDTModel$;
import com.tencent.angel.ml.GBDT.algo.GBDTController;
import com.tencent.angel.ml.GBDT.algo.GBDTPhase;
import com.tencent.angel.ml.GBDT.algo.RegTree.RegTDataStore;
import com.tencent.angel.ml.GBDT.param.GBDTParam;
import com.tencent.angel.ml.GBDT.param.RegTParam;
import com.tencent.angel.ml.core.MLLearner;
import com.tencent.angel.ml.core.conf.MLConf$;
import com.tencent.angel.ml.feature.LabeledData;
import com.tencent.angel.ml.metric.ErrorMetric$;
import com.tencent.angel.ml.model.MLModel;
import com.tencent.angel.worker.storage.DataBlock;
import com.tencent.angel.worker.task.TaskContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple1;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u00055a\u0001B\u0001\u0003\u00015\u00111b\u0012\"E)2+\u0017M\u001d8fe*\u00111\u0001B\u0001\u0005\u000f\n#EK\u0003\u0002\u0006\r\u0005\u0011Q\u000e\u001c\u0006\u0003\u000f!\tQ!\u00198hK2T!!\u0003\u0006\u0002\u000fQ,gnY3oi*\t1\"A\u0002d_6\u001c\u0001a\u0005\u0002\u0001\u001dA\u0011qBE\u0007\u0002!)\u0011\u0011\u0003B\u0001\u0005G>\u0014X-\u0003\u0002\u0014!\tIQ\n\u0014'fCJtWM\u001d\u0005\t+\u0001\u0011)\u0019!C!-\u0005\u00191\r\u001e=\u0016\u0003]\u0001\"\u0001G\u000f\u000e\u0003eQ!AG\u000e\u0002\tQ\f7o\u001b\u0006\u00039\u0019\taa^8sW\u0016\u0014\u0018B\u0001\u0010\u001a\u0005-!\u0016m]6D_:$X\r\u001f;\t\u0013\u0001\u0002!\u0011!Q\u0001\n]\t\u0013\u0001B2uq\u0002J!!\u0006\n\t\u000b\r\u0002A\u0011\u0001\u0013\u0002\rqJg.\u001b;?)\t)s\u0005\u0005\u0002'\u00015\t!\u0001C\u0003\u0016E\u0001\u0007q\u0003C\u0004*\u0001\t\u0007I\u0011\u0001\u0016\u0002\u00071{u)F\u0001,!\taS'D\u0001.\u0015\tqs&A\u0004m_\u001e<\u0017N\\4\u000b\u0005A\n\u0014aB2p[6|gn\u001d\u0006\u0003eM\na!\u00199bG\",'\"\u0001\u001b\u0002\u0007=\u0014x-\u0003\u00027[\t\u0019Aj\\4\t\ra\u0002\u0001\u0015!\u0003,\u0003\u0011auj\u0012\u0011\t\u000fi\u0002!\u0019!C\u0001w\u0005)\u0001/\u0019:b[V\tA\b\u0005\u0002>\u007f5\taH\u0003\u0002;\u0005%\u0011\u0001I\u0010\u0002\n\u000f\n#E\u000bU1sC6DaA\u0011\u0001!\u0002\u0013a\u0014A\u00029be\u0006l\u0007\u0005C\u0004E\u0001\t\u0007I\u0011A#\u0002\u000b5|G-\u001a7\u0016\u0003\u0019\u0003\"AJ$\n\u0005!\u0013!!C$C\tRku\u000eZ3m\u0011\u0019Q\u0005\u0001)A\u0005\r\u00061Qn\u001c3fY\u0002BQ\u0001\u0014\u0001\u0005\u00025\u000b\u0011\"\u001b8jiB\u000b'/Y7\u0015\u00039\u0003\"a\u0014*\u000e\u0003AS\u0011!U\u0001\u0006g\u000e\fG.Y\u0005\u0003'B\u0013A!\u00168ji\")Q\u000b\u0001C\u0001-\u0006a\u0011N\\5u\t\u0006$\u0018-T3uCR\u0019qkX7\u0011\u0005akV\"A-\u000b\u0005i[\u0016a\u0002*fOR\u0013X-\u001a\u0006\u00039\n\tA!\u00197h_&\u0011a,\u0017\u0002\u000e%\u0016<G\u000bR1uCN#xN]3\t\u000b\u0001$\u0006\u0019A1\u0002\u000f\u0011\fG/Y*fiB\u0019!-Z4\u000e\u0003\rT!\u0001Z\u000e\u0002\u000fM$xN]1hK&\u0011am\u0019\u0002\n\t\u0006$\u0018M\u00117pG.\u0004\"\u0001[6\u000e\u0003%T!A\u001b\u0003\u0002\u000f\u0019,\u0017\r^;sK&\u0011A.\u001b\u0002\f\u0019\u0006\u0014W\r\\3e\t\u0006$\u0018\rC\u0003;)\u0002\u0007a\u000e\u0005\u0002>_&\u0011\u0001O\u0010\u0002\n%\u0016<G\u000bU1sC6DQA\u001d\u0001\u0005\u0002M\fQ\"\u001e9eCR,W*\u001a;sS\u000e\u001cHC\u0001(u\u0011\u0015)\u0018\u000f1\u0001w\u0003)\u0019wN\u001c;s_2dWM\u001d\t\u0003obl\u0011aW\u0005\u0003sn\u0013ab\u0012\"E)\u000e{g\u000e\u001e:pY2,'\u000fC\u0003|\u0001\u0011\u0005C0A\u0003ue\u0006Lg\u000eF\u0003~\u0003\u000b\tI\u0001E\u0002\u007f\u0003\u0003i\u0011a \u0006\u0003\t\u0012I1!a\u0001\u0000\u0005\u001diE*T8eK2Da!a\u0002{\u0001\u0004\t\u0017!\u0003;sC&tG)\u0019;b\u0011\u0019\tYA\u001fa\u0001C\u0006qa/\u00197jI\u0006$\u0018n\u001c8ECR\f\u0007")
public class GBDTLearner
extends MLLearner {
    private final Log LOG = LogFactory.getLog(GBDTLearner.class);
    private final GBDTParam param = new GBDTParam();
    private final GBDTModel model;

    @Override
    public TaskContext ctx() {
        return super.ctx();
    }

    public Log LOG() {
        return this.LOG;
    }

    public GBDTParam param() {
        return this.param;
    }

    public GBDTModel model() {
        return this.model;
    }

    public void initParam() {
        this.param().taskType = this.conf().get(MLConf$.MODULE$.ML_GBDT_TASK_TYPE(), MLConf$.MODULE$.DEFAULT_ML_GBDT_TASK_TYPE());
        this.param().numFeature = this.conf().getInt(MLConf$.MODULE$.ML_FEATURE_INDEX_RANGE(), MLConf$.MODULE$.DEFAULT_ML_FEATURE_INDEX_RANGE());
        this.param().numNonzero = this.conf().getInt(MLConf$.MODULE$.ML_MODEL_SIZE(), MLConf$.MODULE$.DEFAULT_ML_MODEL_SIZE());
        this.param().numSplit = this.conf().getInt(MLConf$.MODULE$.ML_GBDT_SPLIT_NUM(), MLConf$.MODULE$.DEFAULT_ML_GBDT_SPLIT_NUM());
        this.param().treeNum = this.conf().getInt(MLConf$.MODULE$.ML_GBDT_TREE_NUM(), MLConf$.MODULE$.DEFAULT_ML_GBDT_TREE_NUM());
        this.param().maxDepth = this.conf().getInt(MLConf$.MODULE$.ML_GBDT_TREE_DEPTH(), MLConf$.MODULE$.DEFAULT_ML_GBDT_TREE_DEPTH());
        this.param().colSample = this.conf().getFloat(MLConf$.MODULE$.ML_GBDT_SAMPLE_RATIO(), (float)MLConf$.MODULE$.DEFAULT_ML_GBDT_SAMPLE_RATIO());
        this.param().learningRate = this.conf().getFloat(MLConf$.MODULE$.ML_LEARN_RATE(), (float)MLConf$.MODULE$.DEFAULT_ML_LEARN_RATE());
        this.param().maxThreadNum = this.conf().getInt(MLConf$.MODULE$.ML_GBDT_THREAD_NUM(), MLConf$.MODULE$.DEFAULT_ML_GBDT_THREAD_NUM());
        this.param().batchSize = this.conf().getInt(MLConf$.MODULE$.ML_GBDT_BATCH_SIZE(), MLConf$.MODULE$.DEFAULT_ML_GBDT_BATCH_SIZE());
        this.param().isServerSplit = this.conf().getBoolean(MLConf$.MODULE$.ML_GBDT_SERVER_SPLIT(), MLConf$.MODULE$.DEFAULT_ML_GBDT_SERVER_SPLIT());
        this.param().sketchName = GBDTModel$.MODULE$.SKETCH_MAT();
        this.param().gradHistNamePrefix = GBDTModel$.MODULE$.GRAD_HIST_MAT_PREFIX();
        this.param().activeTreeNodesName = GBDTModel$.MODULE$.ACTIVE_NODE_MAT();
        this.param().sampledFeaturesName = GBDTModel$.MODULE$.FEAT_SAMPLE_MAT();
        this.param().cateFeatureName = GBDTModel$.MODULE$.FEAT_CATEGORY_MAT();
        this.param().splitFeaturesName = GBDTModel$.MODULE$.SPLIT_FEAT_MAT();
        this.param().splitValuesName = GBDTModel$.MODULE$.SPLIT_VALUE_MAT();
        this.param().splitGainsName = GBDTModel$.MODULE$.SPLIT_GAIN_MAT();
        this.param().nodeGradStatsName = GBDTModel$.MODULE$.NODE_GRAD_MAT();
        this.param().nodePredsName = GBDTModel$.MODULE$.NODE_PRED_MAT();
    }

    public RegTDataStore initDataMeta(DataBlock<LabeledData> dataSet, RegTParam param) {
        int numFeature = param.numFeature;
        int numNonzero = param.numNonzero;
        this.LOG().info((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Create data meta, numFeature=", ", nonzero=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)numFeature), BoxesRunTime.boxToInteger((int)numNonzero)})));
        RegTDataStore dataStore = new RegTDataStore(param);
        dataStore.init(dataSet);
        this.LOG().info((Object)new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Finish creating data meta, numRow=", ", "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)dataStore.numRow)}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"numCol=", ", nonzero=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)dataStore.numCol), BoxesRunTime.boxToInteger((int)dataStore.numNonzero)}))).toString());
        return dataStore;
    }

    public void updateMetrics(GBDTController controller) {
        Tuple1<Double> trainMetrics = controller.eval();
        Tuple1<Double> validMetrics = controller.predict();
        this.globalMetrics().metric(MLConf$.MODULE$.TRAIN_ERROR(), Predef$.MODULE$.Double2double((Double)trainMetrics._1()));
        this.globalMetrics().metric(MLConf$.MODULE$.VALID_ERROR(), Predef$.MODULE$.Double2double((Double)validMetrics._1()));
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public MLModel train(DataBlock<LabeledData> trainData, DataBlock<LabeledData> validationData) {
        void var9_7;
        this.LOG().debug((Object)"------GBDT starts training------");
        this.LOG().info((Object)"1. initialize");
        this.initParam();
        long dataGenStartTs = System.currentTimeMillis();
        RegTDataStore trainDataStore = this.initDataMeta(trainData, this.param());
        RegTDataStore validDataStore = this.initDataMeta(validationData, this.param());
        this.LOG().info((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Build data info cost ", " ms"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToLong((long)(System.currentTimeMillis() - dataGenStartTs))})));
        this.LOG().info((Object)"2.train");
        long trainStartTs = System.currentTimeMillis();
        GBDTController controller = new GBDTController(this.ctx(), this.param(), trainDataStore, validDataStore, this.model());
        controller.init();
        this.globalMetrics().addMetric(MLConf$.MODULE$.TRAIN_ERROR(), ErrorMetric$.MODULE$.apply(trainDataStore.numRow));
        this.globalMetrics().addMetric(MLConf$.MODULE$.VALID_ERROR(), ErrorMetric$.MODULE$.apply(validDataStore.numRow));
        while (true) {
            BoxedUnit boxedUnit;
            GBDTPhase gBDTPhase = controller.phase;
            GBDTPhase gBDTPhase2 = GBDTPhase.FINISHED;
            if (!(gBDTPhase != null ? !((Object)((Object)gBDTPhase)).equals((Object)gBDTPhase2) : gBDTPhase2 != null)) {
                this.LOG().info((Object)new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Task[", "] finishes training, "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.ctx().getTaskIndex())}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"train phase cost ", " ms, "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToLong((long)(System.currentTimeMillis() - trainStartTs))}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"total clock ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)controller.clock)}))).toString());
                return this.model();
            }
            this.LOG().info((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"******Current phase: ", ", clock[", "]******"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{var9_7.phase, BoxesRunTime.boxToInteger((int)var9_7.clock)})));
            GBDTPhase gBDTPhase3 = var9_7.phase;
            if (((Object)((Object)GBDTPhase.CREATE_SKETCH)).equals((Object)gBDTPhase3)) {
                var9_7.createSketch();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.GET_SKETCH)).equals((Object)gBDTPhase3)) {
                var9_7.getSketch();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.SAMPLE_FEATURE)).equals((Object)gBDTPhase3)) {
                var9_7.sampleFeature();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.NEW_TREE)).equals((Object)gBDTPhase3)) {
                var9_7.createNewTree();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.RUN_ACTIVE)).equals((Object)gBDTPhase3)) {
                var9_7.runActiveNode();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.FIND_SPLIT)).equals((Object)gBDTPhase3)) {
                var9_7.findSplit();
                boxedUnit = BoxedUnit.UNIT;
            } else if (((Object)((Object)GBDTPhase.AFTER_SPLIT)).equals((Object)gBDTPhase3)) {
                var9_7.afterSplit();
                boxedUnit = BoxedUnit.UNIT;
            } else {
                if (!((Object)((Object)GBDTPhase.FINISH_TREE)).equals((Object)gBDTPhase3)) break;
                var9_7.finishCurrentTree();
                this.updateMetrics((GBDTController)var9_7);
                boxedUnit = BoxedUnit.UNIT;
            }
            var9_7.updatePhase();
            var9_7.incrementClock();
            this.ctx().incEpoch();
        }
        throw new AngelException(new StringBuilder().append((Object)"Unrecognizable GBDT phase: ").append((Object)var9_7.phase).toString());
    }

    public GBDTLearner(TaskContext ctx) {
        super(ctx);
        this.model = new GBDTModel(this.conf(), ctx);
    }
}

