/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.lda;

import com.tencent.angel.client.AngelClient;
import com.tencent.angel.client.AngelClientFactory;
import com.tencent.angel.data.inputformat.BalanceInputFormat;
import com.tencent.angel.ml.core.MLRunner;
import com.tencent.angel.ml.core.MLRunner$class;
import com.tencent.angel.ml.core.conf.MLConf$;
import com.tencent.angel.ml.lda.LDAModel;
import com.tencent.angel.ml.lda.LDAModel$;
import com.tencent.angel.ml.lda.LDAPredictTask;
import com.tencent.angel.ml.lda.LDATrainTask;
import com.tencent.angel.ml.model.MLModel;
import com.tencent.angel.ml.model.PSModel;
import com.tencent.angel.worker.task.BaseTask;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001A3A!\u0001\u0002\u0001\u001b\tIA\nR!Sk:tWM\u001d\u0006\u0003\u0007\u0011\t1\u0001\u001c3b\u0015\t)a!\u0001\u0002nY*\u0011q\u0001C\u0001\u0006C:<W\r\u001c\u0006\u0003\u0013)\tq\u0001^3oG\u0016tGOC\u0001\f\u0003\r\u0019w.\\\u0002\u0001'\r\u0001aB\u0006\t\u0003\u001fQi\u0011\u0001\u0005\u0006\u0003#I\tA\u0001\\1oO*\t1#\u0001\u0003kCZ\f\u0017BA\u000b\u0011\u0005\u0019y%M[3diB\u0011qCG\u0007\u00021)\u0011\u0011\u0004B\u0001\u0005G>\u0014X-\u0003\u0002\u001c1\tAQ\n\u0014*v]:,'\u000fC\u0003\u001e\u0001\u0011\u0005a$\u0001\u0004=S:LGO\u0010\u000b\u0002?A\u0011\u0001\u0005A\u0007\u0002\u0005!9!\u0005\u0001b\u0001\n\u0003\u0019\u0013a\u0001'P\u000fV\tA\u0005\u0005\u0002&]5\taE\u0003\u0002(Q\u00059An\\4hS:<'BA\u0015+\u0003\u001d\u0019w.\\7p]NT!a\u000b\u0017\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005i\u0013aA8sO&\u0011qF\n\u0002\u0004\u0019><\u0007BB\u0019\u0001A\u0003%A%\u0001\u0003M\u001f\u001e\u0003\u0003\"B\u001a\u0001\t\u0003!\u0014\u0001C:fi\u000e{gNZ:\u0015\u0005UZ\u0004C\u0001\u001c:\u001b\u00059$\"\u0001\u001d\u0002\u000bM\u001c\u0017\r\\1\n\u0005i:$\u0001B+oSRDQ\u0001\u0010\u001aA\u0002u\nAaY8oMB\u0011aHQ\u0007\u0002\u007f)\u0011A\b\u0011\u0006\u0003\u0003*\na\u0001[1e_>\u0004\u0018BA\"@\u00055\u0019uN\u001c4jOV\u0014\u0018\r^5p]\")Q\t\u0001C\u0001\r\u0006Q1/\u001a;K-6{\u0005\u000f^:\u0015\u0005U:\u0005\"\u0002\u001fE\u0001\u0004i\u0004\"B%\u0001\t\u0003R\u0015!\u0002;sC&tGCA\u001bL\u0011\u0015a\u0004\n1\u0001>\u0011\u0015i\u0005\u0001\"\u0011O\u0003\u001d\u0001(/\u001a3jGR$\"!N(\t\u000bqb\u0005\u0019A\u001f")
/**
 * Runner that configures and launches LDA training and prediction jobs on Angel.
 *
 * <p>Both entry points first adjust the job configuration ({@link #setConfs}) and the
 * worker/PS JVM options ({@link #setJVMOpts}). They then start the parameter servers and run
 * {@link LDATrainTask} or {@link LDAPredictTask}. The Angel client is stopped in a
 * {@code finally} block, so it is released even when a step fails.
 */
public class LDARunner
implements MLRunner {
    private final Log LOG;

    @Override
    public final void train(Configuration conf, MLModel model, Class<? extends BaseTask<?, ?, ?>> taskClass) {
        MLRunner$class.train(this, conf, model, taskClass);
    }

    @Override
    public final void predict(Configuration conf, MLModel model, Class<? extends BaseTask<?, ?, ?>> taskClass) {
        MLRunner$class.predict(this, conf, model, taskClass);
    }

    @Override
    public void submit(Configuration conf) throws Exception {
        MLRunner$class.submit(this, conf);
    }

    /** Returns the logger used by this runner. */
    public Log LOG() {
        return this.LOG;
    }

    /**
     * Applies LDA-specific job settings to {@code conf} and logs the effective parameters.
     *
     * <p>It forces a single attempt and a single task per worker and selects
     * {@link BalanceInputFormat} as the input format. The topic count, worker/PS counts,
     * thread count, alpha and beta are only read here for logging; they are not written
     * back to {@code conf}.
     *
     * @param conf job configuration, modified in place
     */
    public void setConfs(Configuration conf) {
        conf.setInt("angel.worker.max-attempts", 1);
        conf.setInt("angel.worker.task.number", 1);
        conf.set("angel.input.format", BalanceInputFormat.class.getName());

        int numTopics = conf.getInt(LDAModel$.MODULE$.TOPIC_NUM(), 10);
        int numWorkers = conf.getInt("angel.workergroup.number", 10);
        int numServers = conf.getInt("angel.ps.number", 10);
        int numThreads = conf.getInt(MLConf$.MODULE$.ANGEL_WORKER_THREAD_NUM(), 2);
        // Default alpha is 50 / numTopics (float division).
        float alpha = conf.getFloat(LDAModel$.MODULE$.ALPHA(), 50.0f / numTopics);
        float beta = conf.getFloat(LDAModel$.MODULE$.BETA(), 0.01f);

        this.LOG().info("numTopics=" + numTopics
                + " numWorkers=" + numWorkers
                + " numPs=" + numServers
                + " numThreads=" + numThreads
                + " alpha=" + alpha
                + " beta=" + beta);
    }

    /**
     * Computes and stores JVM options for workers ({@code angel.worker.java.opts}) and
     * parameter servers ({@code angel.ps.java.opts}).
     *
     * <p>Workers split their memory into a heap part ({@code angel.jvm.heap.fraction},
     * default 0.8) and a direct-memory part (the remainder). Parameter servers use twice the
     * worker direct fraction for direct memory, capped at 0.5; the rest goes to the heap.
     *
     * @param conf job configuration, modified in place
     */
    public void setJVMOpts(Configuration conf) {
        int workerMemoryMb = memoryMb(conf, "angel.worker.memory.mb", "angel.worker.memory.gb");
        double workerHeapFraction = conf.getDouble("angel.jvm.heap.fraction", 0.8);
        double workerDirectFraction = 1.0 - workerHeapFraction;
        String workerOpts = javaOpts(
                (int) (workerMemoryMb * workerHeapFraction),
                (int) (workerMemoryMb * workerDirectFraction));
        this.LOG().info("worker JVM settings: " + workerOpts);
        conf.set("angel.worker.java.opts", workerOpts);

        int psMemoryMb = memoryMb(conf, "angel.ps.memory.mb", "angel.ps.memory.gb");
        double psDirectFraction = Math.min(0.5, workerDirectFraction * 2);
        double psHeapFraction = 1.0 - psDirectFraction;
        String psOpts = javaOpts(
                (int) (psMemoryMb * psHeapFraction),
                (int) (psMemoryMb * psDirectFraction));
        conf.set("angel.ps.java.opts", psOpts);
        this.LOG().info("ps JVM settings: " + psOpts);
    }

    /**
     * Reads a memory size in MB from {@code mbKey}, falling back to {@code gbKey} (default 1).
     * Only an absent value or an explicit {@code -1} under {@code mbKey} triggers the fallback.
     */
    private static int memoryMb(Configuration conf, String mbKey, String gbKey) {
        int mb = conf.getInt(mbKey, -1);
        // NOTE(review): GB values are converted with a factor of 1000, not 1024 — confirm intended.
        return mb == -1 ? conf.getInt(gbKey, 1) * 1000 : mb;
    }

    /** Builds the JVM option string shared by workers and parameter servers. */
    private static String javaOpts(int heapMb, int directMb) {
        return "-Xmx" + heapMb + "M "
                + "-XX:+UseConcMarkSweepGC "
                + "-XX:+PrintGCTimeStamps "
                + "-XX:+PrintGCDetails "
                + "-XX:MaxDirectMemorySize=" + directMb + "M";
    }

    /**
     * Runs LDA training.
     *
     * <p>If {@code angel.save.model.path} is unset, the LDA-specific save path is used
     * instead, when one is configured. The client is always stopped afterwards.
     *
     * @param conf job configuration, modified in place
     */
    @Override
    public void train(Configuration conf) {
        if (conf.get("angel.save.model.path") == null) {
            String ldaSavePath = conf.get(LDAModel$.MODULE$.SAVE_PATH());
            // Hadoop's Configuration.set rejects null values, so only copy a configured path.
            if (ldaSavePath != null) {
                conf.set("angel.save.model.path", ldaSavePath);
            }
        }
        this.setConfs(conf);
        this.setJVMOpts(conf);
        AngelClient client = AngelClientFactory.get(conf);
        try {
            client.startPSServer();
            client.loadModel((MLModel) new LDAModel(conf, LDAModel$.MODULE$.$lessinit$greater$default$2()));
            client.runTask(LDATrainTask.class);
            client.waitForCompletion();
        } finally {
            client.stop();
        }
    }

    /**
     * Runs LDA prediction.
     *
     * <p>Registers the model's matrices with the client and creates them while
     * {@code angel.load.model.path} is temporarily unset. It then restores the path (if one
     * was set) and runs {@link LDAPredictTask}. The client is always stopped afterwards.
     *
     * @param conf job configuration, modified in place
     */
    @Override
    public void predict(Configuration conf) {
        this.setConfs(conf);
        this.setJVMOpts(conf);
        AngelClient client = AngelClientFactory.get(conf);
        try {
            client.startPSServer();
            LDAModel model = new LDAModel(conf, LDAModel$.MODULE$.$lessinit$greater$default$2());
            String path = conf.get("angel.load.model.path");
            for (Object psModel : model.getPSModels().values()) {
                client.addMatrix(((PSModel) psModel).getContext());
            }
            // NOTE(review): presumably hides the load path so createMatrices() does not load
            // matrix data itself, leaving the load to the predict task — confirm.
            conf.unset("angel.load.model.path");
            client.createMatrices();
            // Configuration.set rejects null values; only restore a path that was present.
            if (path != null) {
                conf.set("angel.load.model.path", path);
            }
            client.runTask(LDAPredictTask.class);
            client.waitForCompletion();
        } finally {
            client.stop();
        }
    }

    public LDARunner() {
        MLRunner$class.$init$(this);
        this.LOG = LogFactory.getLog(LDARunner.class);
    }
}

