/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.azure.synapse.ml.lightgbm;

import com.microsoft.azure.synapse.ml.core.env.StreamUtilities$;
import com.microsoft.azure.synapse.ml.core.utils.ClusterUtil$;
import com.microsoft.azure.synapse.ml.core.utils.FaultToleranceUtils$;
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants$;
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils$;
import com.microsoft.azure.synapse.ml.lightgbm.NetworkManager;
import com.microsoft.azure.synapse.ml.lightgbm.NetworkParams;
import com.microsoft.azure.synapse.ml.lightgbm.NetworkTopologyInfo;
import com.microsoft.azure.synapse.ml.lightgbm.PartitionTaskContext;
import com.microsoft.azure.synapse.ml.lightgbm.TaskInstrumentationMeasures;
import com.microsoft.azure.synapse.ml.lightgbm.TaskMessageInfo;
import com.microsoft.azure.synapse.ml.lightgbm.TrainingContext;
import com.microsoft.ml.lightgbm.lightgbmlib;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.spark.BarrierTaskContext;
import org.apache.spark.BarrierTaskContext$;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Some;
import scala.Tuple6;
import scala.collection.Seq;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.concurrent.ExecutionContext$;
import scala.concurrent.ExecutionContextExecutor;
import scala.concurrent.duration.Duration;
import scala.concurrent.duration.Duration$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

public final class NetworkManager$
implements scala.Serializable {
    /** Singleton instance of this companion object; assigned by the private constructor. */
    public static NetworkManager$ MODULE$;

    static {
        new NetworkManager$();
    }

    /**
     * Opens the driver-side server socket and wraps it in a new {@link NetworkManager}.
     *
     * @param numTasks number of tasks passed through to the NetworkManager
     * @param spark active Spark session, used only to resolve the driver host name
     * @param driverListenPort port for the driver server socket (per {@code ServerSocket} semantics,
     *                         0 selects an ephemeral port; the actual port is read back afterwards)
     * @param timeout accept timeout in seconds; a non-finite value leaves the socket without a timeout
     * @param useBarrierExecutionMode passed through to the NetworkManager
     * @return a NetworkManager owning the opened server socket
     */
    public NetworkManager create(int numTasks, SparkSession spark, int driverListenPort, double timeout, boolean useBarrierExecutionMode) {
        ServerSocket driverServerSocket = new ServerSocket(driverListenPort);
        try {
            Duration duration = Duration$.MODULE$.apply(timeout, TimeUnit.SECONDS);
            if (duration.isFinite()) {
                // setSoTimeout takes an int; clamp so a very large finite timeout cannot overflow to a negative value.
                driverServerSocket.setSoTimeout((int)Math.min(duration.toMillis(), (long)Integer.MAX_VALUE));
            }
            String host = ClusterUtil$.MODULE$.getDriverHost(spark);
            int port = driverServerSocket.getLocalPort();
            return new NetworkManager(numTasks, driverServerSocket, host, port, timeout, useBarrierExecutionMode);
        }
        catch (Throwable t) {
            // The socket is only handed off on success; on any failure release it so the port is not held open.
            try {
                driverServerSocket.close();
            }
            catch (IOException suppressed) {
                t.addSuppressed(suppressed);
            }
            throw t;
        }
    }

    /**
     * Binds a local port for this task and obtains the network topology from the driver.
     * The driver exchange is wrapped in {@code FaultToleranceUtils.retryWithTimeout}. The work is
     * bracketed by {@code measures.markNetworkInitializationStart/Stop}. The stop mark is not
     * recorded if an exception escapes.
     */
    public NetworkTopologyInfo getGlobalNetworkInfo(TrainingContext ctx, Logger log, long taskId, int partitionId, boolean shouldExecuteTraining, TaskInstrumentationMeasures measures) {
        measures.markNetworkInitializationStart();
        NetworkParams networkParams = ctx.networkParams();
        // The bound socket is passed to StreamUtilities.using, which presumably closes it when the block ends (TODO confirm).
        NetworkTopologyInfo out = (NetworkTopologyInfo)StreamUtilities$.MODULE$.using((AutoCloseable)this.findOpenPort(ctx, log).get(), (Function1 & Serializable & scala.Serializable)openPort -> {
            int localListenPort = openPort.getLocalPort();
            log.info(new StringBuilder(43).append("LightGBM task ").append(taskId).append(" connecting to host: ").append(networkParams.ipAddress()).append(", port: ").append(networkParams.port()).toString());
            return (NetworkTopologyInfo)FaultToleranceUtils$.MODULE$.retryWithTimeout(FaultToleranceUtils$.MODULE$.retryWithTimeout$default$1(), (Function0 & Serializable & scala.Serializable)() -> MODULE$.getNetworkTopologyInfoFromDriver(networkParams, taskId, partitionId, localListenPort, log, shouldExecuteTraining));
        }).get();
        measures.markNetworkInitializationStop();
        return out;
    }

    /**
     * Connects to the driver, sends this task's status line, and reads back the topology.
     * The status line is a TaskMessageInfo (enabled/ignore status, local host address, local listen
     * port, partition id, executor id). In barrier execution mode all tasks wait on the barrier
     * first, and partition 0 then sends the finished status. The driver's reply is read as two
     * lines: the LightGBM machine list, then the partitions-by-executor string.
     *
     * @throws Exception if either reply line is missing (null)
     */
    private NetworkTopologyInfo getNetworkTopologyInfoFromDriver(NetworkParams networkParams, long taskId, int partitionId, int localListenPort, Logger log, boolean shouldExecuteTraining) {
        return (NetworkTopologyInfo)StreamUtilities$.MODULE$.using((AutoCloseable)new Socket(networkParams.ipAddress(), networkParams.port()), (Function1 & Serializable & scala.Serializable)driverSocket -> (NetworkTopologyInfo)StreamUtilities$.MODULE$.usingMany((Seq)new .colon.colon((Object)new BufferedReader(new InputStreamReader(driverSocket.getInputStream())), (List)new .colon.colon((Object)new BufferedWriter(new OutputStreamWriter(driverSocket.getOutputStream())), (List)Nil$.MODULE$)), (Function1 & Serializable & scala.Serializable)io -> {
            BufferedReader driverInput = (BufferedReader)io.head();
            BufferedWriter driverOutput = (BufferedWriter)io.apply(1);
            TaskMessageInfo taskStatus = new TaskMessageInfo(shouldExecuteTraining ? LightGBMConstants$.MODULE$.EnabledTask() : LightGBMConstants$.MODULE$.IgnoreStatus(), driverSocket.getLocalAddress().getHostAddress(), localListenPort, partitionId, LightGBMUtils$.MODULE$.getExecutorId());
            String message = taskStatus.toString();
            log.info(new StringBuilder(41).append("task ").append(taskId).append(" sending status message to driver: ").append(message).append(" ").toString());
            driverOutput.write(new StringBuilder(1).append(message).append("\n").toString());
            driverOutput.flush();
            if (networkParams.barrierExecutionMode()) {
                BarrierTaskContext context = BarrierTaskContext$.MODULE$.get();
                context.barrier();
                // Only one task (partition 0) reports the finished status once every task has passed the barrier.
                if (context.partitionId() == 0) {
                    MODULE$.setFinishedStatus(networkParams, log);
                }
            }
            String lightGbmMachineList = driverInput.readLine();
            String partitionsByExecutorStr = driverInput.readLine();
            if (partitionsByExecutorStr == null || lightGbmMachineList == null) {
                String message2 = new StringBuilder(111).append("Received bad network information. Task ").append(taskId).append(", partition ").append(partitionId).append(" received ").append("partition topology: '").append(partitionsByExecutorStr).append("', nodes for network init: '").append(lightGbmMachineList).append("'").toString();
                throw new Exception(message2);
            }
            log.info(new StringBuilder(49).append("task ").append(taskId).append(", partition ").append(partitionId).append(" received partition topology: '").append(partitionsByExecutorStr).append("'").toString());
            log.info(new StringBuilder(53).append("task ").append(taskId).append(", partition ").append(partitionId).append(" received nodes for network init: '").append(lightGbmMachineList).append("'").toString());
            int[] executorPartitionIds = MODULE$.parseExecutorPartitionList(partitionsByExecutorStr, taskStatus.executorId(), log);
            return new NetworkTopologyInfo(lightGbmMachineList, executorPartitionIds, localListenPort);
        }).get()).get();
    }

    /**
     * Extracts this executor's partition ids from a string shaped like
     * {@code "exec1=0,3:exec2=1,2"} (entries separated by ':', key and ids by '=', ids by ',').
     *
     * @return the executor's partition ids, sorted ascending
     * @throws Exception if no entry starts with {@code executorId + "="}, or the entry lists no ids
     */
    private int[] parseExecutorPartitionList(String partitionsByExecutorStr, String executorId, Logger log) {
        String[] partitionsByExecutor = partitionsByExecutorStr.split(":");
        Option executorListStr = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])partitionsByExecutor)).find((Function1 & Serializable & scala.Serializable)line -> BoxesRunTime.boxToBoolean((boolean)line.startsWith(new StringBuilder(1).append(executorId).append("=").toString())));
        if (executorListStr.isEmpty()) {
            throw new Exception(new StringBuilder(47).append("Could not find partitions for executor ").append(executorId).append(". List: ").append(partitionsByExecutorStr).toString());
        }
        String executorEntry = (String)executorListStr.get();
        // Log the matched entry itself; logging the Option would render as "Some(...)".
        log.info(new StringBuilder(33).append("executor ").append(executorId).append(" received partitions: '").append(executorEntry).append("'").toString());
        String[] keyAndPartitions = executorEntry.split("=");
        if (keyAndPartitions.length < 2) {
            // e.g. "exec1=" with nothing after '=': fail with context instead of an ArrayIndexOutOfBoundsException.
            throw new Exception(new StringBuilder(42).append("No partitions listed for executor ").append(executorId).append(". List: ").append(partitionsByExecutorStr).toString());
        }
        String partitionList = keyAndPartitions[1];
        return (int[])new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])partitionList.split(","))).map((Function1 & Serializable & scala.Serializable)str -> BoxesRunTime.boxToInteger((int)NetworkManager$.$anonfun$parseExecutorPartitionList$2(str)), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int())))).sorted((Ordering)Ordering.Int$.MODULE$);
    }

    /**
     * Calls {@code LGBM_NetworkInit} with the context's machine list, local listen port,
     * {@code DefaultListenTimeout}, and machine count.
     *
     * On failure, it retries up to {@code retry} more times. It sleeps {@code delay} ms before the
     * first retry and doubles the delay for each later attempt. When no retries remain, the last
     * failure is rethrown.
     *
     * @param retry remaining retries; values &lt;= 0 mean "do not retry"
     * @param delay milliseconds to sleep before the next attempt
     */
    public void initLightGBMNetwork(PartitionTaskContext ctx, Logger log, int retry, long delay) {
        log.info(new StringBuilder(46).append("Calling NetworkInit on local port ").append(ctx.localListenPort()).append(" with value ").append(ctx.lightGBMNetworkString()).toString());
        try {
            LightGBMUtils$.MODULE$.validate(lightgbmlib.LGBM_NetworkInit((String)ctx.lightGBMNetworkString(), (int)ctx.localListenPort(), (int)LightGBMConstants$.MODULE$.DefaultListenTimeout(), (int)ctx.lightGBMNetworkMachineCount()), "Network init");
            log.info(new StringBuilder(51).append("NetworkInit succeeded. LightGBM task listening on: ").append(ctx.localListenPort()).toString());
        }
        catch (Throwable throwable) {
            log.info(new StringBuilder(65).append("NetworkInit failed with exception on local port ").append(ctx.localListenPort()).append(" with exception: ").append(throwable).toString());
            // "<= 0" (not "== 0") so a negative retry budget cannot recurse forever.
            if (retry <= 0) {
                log.info(new StringBuilder(49).append("NetworkInit reached maximum exceptions on retry: ").append(throwable).toString());
                throw throwable;
            }
            // Only back off when another attempt will actually follow.
            Thread.sleep(delay);
            log.info(new StringBuilder(37).append("Retrying NetworkInit with local port ").append(ctx.localListenPort()).toString());
            // A successful retry returns normally here; the original failure must not be rethrown afterwards.
            this.initLightGBMNetwork(ctx, log, retry - 1, delay * 2L);
        }
    }

    /** Default value for the {@code retry} argument of initLightGBMNetwork. */
    public int initLightGBMNetwork$default$3() {
        return LightGBMConstants$.MODULE$.NetworkRetries();
    }

    /** Default value for the {@code delay} argument of initLightGBMNetwork. */
    public long initLightGBMNetwork$default$4() {
        return LightGBMConstants$.MODULE$.InitialDelay();
    }

    /**
     * Returns the port of the first {@code host:port} entry in a comma-separated node list.
     *
     * @throws Exception if the list splits to nothing or the first entry is not exactly {@code host:port}
     */
    public int getMainWorkerPort(String nodes, Logger log) {
        String[] nodesList = nodes.split(",");
        if (nodesList.length == 0) {
            throw new Exception("Error: could not split nodes list correctly");
        }
        String mainNode = nodesList[0];
        String[] hostAndPort = mainNode.split(":");
        if (hostAndPort.length != 2) {
            throw new Exception("Error: could not parse main worker host and port correctly");
        }
        String mainHost = hostAndPort[0];
        String mainPort = hostAndPort[1];
        log.info(new StringBuilder(46).append("LightGBM setting main worker host: ").append(mainHost).append(" and port: ").append(mainPort).toString());
        return new StringOps(Predef$.MODULE$.augmentString(mainPort)).toInt();
    }

    /**
     * Binds a local socket for this task, starting at
     * {@code defaultListenPort + workerId * numTasksPerExecutor} and scanning upward.
     *
     * @return the bound socket wrapped in an Option
     * @throws Exception if the base port exceeds MaxPort or no port can be bound
     */
    private Option<Socket> findOpenPort(TrainingContext ctx, Logger log) {
        int defaultListenPort = ctx.networkParams().defaultListenPort();
        int basePort = defaultListenPort + LightGBMUtils$.MODULE$.getWorkerId() * ctx.numTasksPerExecutor();
        if (basePort > LightGBMConstants$.MODULE$.MaxPort()) {
            throw new Exception(new StringBuilder(78).append("Error: port ").append(basePort).append(" out of range, possibly due to too many executors or unknown error").toString());
        }
        IntRef localListenPort = IntRef.create((int)basePort);
        ObjectRef taskServerSocket = ObjectRef.create((Object)None$.MODULE$);
        this.findPort$1(taskServerSocket, localListenPort, log, basePort);
        log.info(new StringBuilder(27).append("Successfully bound to port ").append(localListenPort.elem).toString());
        return (Option)taskServerSocket.elem;
    }

    /** Opens a new connection to the driver and writes the FinishedStatus line on it. */
    private void setFinishedStatus(NetworkParams networkParams, Logger log) {
        StreamUtilities$.MODULE$.using((AutoCloseable)new Socket(networkParams.ipAddress(), networkParams.port()), (Function1 & Serializable & scala.Serializable)driverSocket -> {
            NetworkManager$.$anonfun$setFinishedStatus$1(log, driverSocket);
            return BoxedUnit.UNIT;
        }).get();
    }

    /**
     * Parses a ':'-separated worker message. If the first component is FinishedStatus, only the
     * status is used. Otherwise exactly five components are expected:
     * {@code status:host:port:partitionId:executorId}.
     *
     * @throws Exception if a non-finished message does not have five components
     */
    public TaskMessageInfo parseWorkerMessage(String message) {
        String[] components = message.split(":");
        String status = components[0];
        // String.split never yields null elements, so a plain equals() is sufficient.
        if (status.equals(LightGBMConstants$.MODULE$.FinishedStatus())) {
            return new TaskMessageInfo(status);
        }
        if (components.length != 5) {
            throw new Exception(new StringBuilder(20).append("Unexpected message: ").append(message).toString());
        }
        String host = components[1];
        int port = new StringOps(Predef$.MODULE$.augmentString(components[2])).toInt();
        int partitionId = new StringOps(Predef$.MODULE$.augmentString(components[3])).toInt();
        String executorId = components[4];
        return new TaskMessageInfo(status, host, port, partitionId, executorId);
    }

    /** Case-class style factory; equivalent to calling the NetworkManager constructor. */
    public NetworkManager apply(int numTasks, ServerSocket driverServerSocket, String host, int port, double timeout, boolean useBarrierExecutionMode) {
        return new NetworkManager(numTasks, driverServerSocket, host, port, timeout, useBarrierExecutionMode);
    }

    /** Case-class style extractor: returns the six constructor fields (primitives boxed), or None for null. */
    public Option<Tuple6<Object, ServerSocket, String, Object, Object, Object>> unapply(NetworkManager x$0) {
        if (x$0 == null) {
            return None$.MODULE$;
        }
        return new Some((Object)new Tuple6((Object)BoxesRunTime.boxToInteger((int)x$0.numTasks()), (Object)x$0.driverServerSocket(), (Object)x$0.host(), (Object)BoxesRunTime.boxToInteger((int)x$0.port()), (Object)BoxesRunTime.boxToDouble((double)x$0.timeout()), (Object)BoxesRunTime.boxToBoolean((boolean)x$0.useBarrierExecutionMode())));
    }

    /** Keeps the singleton unique across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Parses one partition id from the comma-separated list. */
    public static final /* synthetic */ int $anonfun$parseExecutorPartitionList$2(String str) {
        return new StringOps(Predef$.MODULE$.augmentString(str)).toInt();
    }

    /**
     * Tries successive ports starting at {@code localListenPort$3.elem} until a fresh Socket binds.
     * On success, the bound socket is left in {@code taskServerSocket$1.elem} and the port in
     * {@code localListenPort$3.elem}.
     *
     * @throws Exception once the port passes MaxPort or more than 1000 ports past the base have been tried
     */
    private final void findPort$1(ObjectRef taskServerSocket$1, IntRef localListenPort$3, Logger log$3, int basePort$1) {
        while (true) {
            Socket candidate = new Socket();
            taskServerSocket$1.elem = Option$.MODULE$.apply((Object)candidate);
            try {
                candidate.bind(new InetSocketAddress(localListenPort$3.elem));
                return;
            }
            catch (IOException iOException) {
                // Release the unbound socket before trying the next port; otherwise every failed attempt leaks a descriptor.
                try {
                    candidate.close();
                }
                catch (IOException ignored) {
                    // Best-effort cleanup of a socket that never bound; nothing useful to do on failure.
                }
                log$3.warn(new StringBuilder(26).append("Could not bind to port ").append(localListenPort$3.elem).append("...").toString());
                ++localListenPort$3.elem;
                if (localListenPort$3.elem > LightGBMConstants$.MODULE$.MaxPort()) {
                    throw new Exception(new StringBuilder(72).append("Error: port ").append(basePort$1).append(" out of range, possibly due to networking or firewall issues").toString());
                }
                // Previously unreachable (it sat after an unconditional continue/throw); now enforces the 1k-attempt cap.
                if (localListenPort$3.elem - basePort$1 > 1000) {
                    throw new Exception("Error: Could not find open port after 1k tries");
                }
            }
        }
    }

    /** Writes the FinishedStatus line to the driver and flushes it. */
    public static final /* synthetic */ void $anonfun$setFinishedStatus$2(Logger log$4, BufferedWriter driverOutput) {
        log$4.info("sending finished status to driver");
        driverOutput.write(new StringBuilder(1).append(LightGBMConstants$.MODULE$.FinishedStatus()).append("\n").toString());
        driverOutput.flush();
    }

    /** Wraps the driver socket's output stream in a BufferedWriter and sends the finished status through it. */
    public static final /* synthetic */ void $anonfun$setFinishedStatus$1(Logger log$4, Socket driverSocket) {
        StreamUtilities$.MODULE$.using((AutoCloseable)new BufferedWriter(new OutputStreamWriter(driverSocket.getOutputStream())), (Function1 & Serializable & scala.Serializable)driverOutput -> {
            NetworkManager$.$anonfun$setFinishedStatus$2(log$4, driverOutput);
            return BoxedUnit.UNIT;
        }).get();
    }

    private NetworkManager$() {
        MODULE$ = this;
    }
}

