/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism.parameterserver;

import io.aeron.driver.MediaDriver;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.parameterserver.ParameterServerTrainer;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.parameterserver.client.ParameterServerClient;
import org.nd4j.parameterserver.node.ParameterServerNode;

public class ParameterServerTrainerContext
implements TrainerContext {
    private ParameterServerNode parameterServerNode;
    private MediaDriver mediaDriver;
    private MediaDriver.Context mediaDriverContext;
    private int statusServerPort = 33000;
    private int numUpdatesPerEpoch = 1;
    private String[] parameterServerArgs;
    private int numWorkers = 1;

    public void init(Model model, Object ... args) {
        this.mediaDriverContext = new MediaDriver.Context();
        this.mediaDriver = MediaDriver.launchEmbedded((MediaDriver.Context)this.mediaDriverContext);
        this.parameterServerNode = new ParameterServerNode(this.mediaDriver, this.statusServerPort, this.numWorkers);
        if (this.parameterServerArgs == null) {
            this.parameterServerArgs = new String[]{"-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", this.mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(this.statusServerPort), "-u", String.valueOf(this.numUpdatesPerEpoch)};
        }
    }

    public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, WorkspaceMode mode, int averagingFrequency) {
        return ParameterServerTrainer.builder().originalModel(model).parameterServerClient(ParameterServerClient.builder().aeron(this.parameterServerNode.getAeron()).ndarrayRetrieveUrl(this.parameterServerNode.getSubscriber()[threadId].getResponder().connectionUrl()).ndarraySendUrl(this.parameterServerNode.getSubscriber()[threadId].getSubscriber().connectionUrl()).subscriberHost("localhost").masterStatusHost("localhost").masterStatusPort(this.statusServerPort).subscriberPort(40625 + threadId).subscriberStream(12 + threadId).build()).replicatedModel(model).threadId(threadId).parallelWrapper(wrapper).useMDS(useMDS).build();
    }

    public void finalizeRound(Model originalModel, Model ... models) {
    }

    public void finalizeTraining(Model originalModel, Model ... models) {
    }
}

