/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.rabit;

import hex.tree.xgboost.rabit.RabitWorker;
import hex.tree.xgboost.rabit.util.LinkMap;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import water.H2O;
import water.util.Log;

/**
 * H2O-side implementation of the XGBoost Rabit tracker.
 *
 * <p>{@link #start(long)} binds a {@link ServerSocketChannel} on {@code H2O.SELF_ADDRESS}, probing
 * ports upwards from 9091 to 9999, and launches a background thread that accepts Rabit worker
 * connections, assigns ranks / link topology and counts "shutdown" signals until every worker has
 * reported one. {@link #waitFor(long)} blocks until that thread exits and {@link #stop()} interrupts
 * it and releases all sockets.
 *
 * <p>{@link #getWorkerEnvs()} reads the current port, so it reflects the bound port only after
 * {@link #start(long)} has succeeded.
 */
public class RabitTrackerH2O
implements IRabitTracker {
    public static final int MAGIC = 65433;
    /** First port probed by {@link #start(long)}; {@link #stop()} restores it for the next start. */
    private static final int FIRST_PORT = 9091;
    /** Last port probed by {@link #start(long)} before giving up. */
    private static final int LAST_PORT = 9999;

    private ServerSocketChannel sock;
    private int port = FIRST_PORT;
    private final int workers;
    private final Map<String, String> envs = new HashMap<String, String>();
    private volatile RabitTrackerH2OThread trackerThread;

    /**
     * Creates a tracker for a fixed number of workers. No socket is opened until {@link #start(long)}.
     *
     * @param workers number of Rabit workers expected to connect; must be at least 1
     * @throws IllegalStateException if {@code workers < 1}
     */
    public RabitTrackerH2O(int workers) {
        if (workers < 1) {
            throw new IllegalStateException("workers must be greater than or equal to one (1).");
        }
        this.workers = workers;
    }

    /**
     * Fills and returns the environment variables Rabit workers need to locate this tracker.
     *
     * @return the tracker's own (mutable) environment map
     */
    public Map<String, String> getWorkerEnvs() {
        this.envs.put("DMLC_NUM_WORKER", String.valueOf(this.workers));
        this.envs.put("DMLC_NUM_SERVER", "0");
        this.envs.put("DMLC_TRACKER_URI", H2O.SELF_ADDRESS.getHostAddress());
        this.envs.put("DMLC_TRACKER_PORT", Integer.toString(this.port));
        this.envs.put("rabit_world_size", Integer.toString(this.workers));
        return this.envs;
    }

    /**
     * Binds the tracker socket to the first free port in 9091-9999 and starts the tracker thread.
     *
     * @param timeout unused by this implementation
     * @return always {@code true}
     * @throws IllegalStateException if the tracker is already running
     * @throws RuntimeException if no port in the range could be bound
     */
    public boolean start(long timeout) {
        // Check before binding: binding first would replace (and leak) the socket the running
        // tracker thread is accepting on.
        if (null != this.trackerThread) {
            throw new IllegalStateException("Rabit tracker already started.");
        }
        this.sock = this.bindFirstFreePort();
        Log.debug((Object[])new Object[]{"Rabit tracker started on port ", this.port});
        RabitTrackerH2OThread thread = new RabitTrackerH2OThread(this);
        this.trackerThread = thread;
        thread.start();
        return true;
    }

    /**
     * Opens a server channel bound to {@code H2O.SELF_ADDRESS}, trying {@code this.port} and then
     * each following port up to {@link #LAST_PORT}. On success {@code this.port} holds the bound port.
     */
    private ServerSocketChannel bindFirstFreePort() {
        while (true) {
            ServerSocketChannel channel = null;
            try {
                channel = ServerSocketChannel.open();
                channel.socket().setReceiveBufferSize(65536);
                channel.socket().bind(new InetSocketAddress(H2O.SELF_ADDRESS, this.port));
                return channel;
            }
            catch (IOException e) {
                // open() itself may have failed, in which case there is nothing to close.
                if (channel != null) {
                    try {
                        channel.close();
                    }
                    catch (IOException socketCloseException) {
                        Log.warn((Object[])new Object[]{"Failed to close Rabit Tracker socket on port ", this.port});
                    }
                }
                ++this.port;
                if (this.port > LAST_PORT) {
                    // Reset so that a later start() scans the whole range again.
                    this.port = FIRST_PORT;
                    throw new RuntimeException("Failed to bind Rabit tracker to a socket in range 9091-9999", e);
                }
            }
        }
    }

    /**
     * Interrupts the tracker thread, closes every accepted worker channel and the server socket.
     * Calling it when the tracker is not running (never started or already stopped) does nothing.
     */
    public void stop() {
        // Read the volatile field once; a concurrent stop() (e.g. from uncaughtException) may null it.
        RabitTrackerH2OThread thread = this.trackerThread;
        if (thread == null) {
            return;
        }
        try {
            thread.interrupt();
        }
        catch (SecurityException e) {
            Log.err((Object[])new Object[]{"Could not interrupt a thread in RabitTrackerH2O: " + thread.toString()});
        }
        thread.terminateSocketChannels();
        this.trackerThread = null;
        try {
            this.sock.close();
        }
        catch (IOException e) {
            Log.err((Object[])new Object[]{"Failed to close Rabit tracker socket.", e});
        }
        finally {
            this.port = FIRST_PORT;
        }
    }

    /**
     * Blocks until the tracker thread has terminated (or {@link #stop()} cleared it). Each join uses
     * {@code timeout}, but the loop keeps waiting while the thread is alive. If the calling thread is
     * interrupted meanwhile, waiting continues and the interrupt flag is restored before returning.
     *
     * @param timeout per-join timeout in milliseconds (0 = wait indefinitely per join)
     * @return always 0
     */
    public int waitFor(long timeout) {
        boolean interrupted = false;
        RabitTrackerH2OThread thread;
        while (null != (thread = this.trackerThread) && thread.isAlive()) {
            try {
                thread.join(timeout);
            }
            catch (InterruptedException e) {
                interrupted = true;
                Log.debug((Object[])new Object[]{"Rabit tracker thread got suddenly interrupted.", e});
            }
        }
        if (interrupted) {
            // Preserve the caller's interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        return 0;
    }

    /** Stops the tracker when a thread it is registered with dies from an uncaught exception. */
    public void uncaughtException(Thread t, Throwable e) {
        this.stop();
    }

    /**
     * Background thread that accepts Rabit worker connections on the tracker's server socket and
     * dispatches on each worker's command ("print", "shutdown", "start", "recover").
     */
    private class RabitTrackerH2OThread
    extends Thread {
        private final RabitTrackerH2O tracker;
        private LinkMap linkMap;
        private final Map<String, Integer> jobToRankMap = new HashMap<String, Integer>();
        /** Every accepted worker channel; appended by this thread and drained by stop() on another. */
        private final List<SocketChannel> socketChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
        private static final String PRINT_CMD = "print";
        private static final String SHUTDOWN_CMD = "shutdown";
        private static final String START_CMD = "start";
        private static final String RECOVER_CMD = "recover";
        private static final String NULL_STR = "null";

        private RabitTrackerH2OThread(RabitTrackerH2O tracker) {
            // 9 == Thread.MAX_PRIORITY - 1
            this.setPriority(9);
            this.setName("TCP-" + tracker.sock);
            this.tracker = tracker;
        }

        /** Closes every worker channel accepted so far, logging (not throwing) on failure. */
        private void terminateSocketChannels() {
            // Iteration over a synchronizedList must hold its monitor to avoid racing with add().
            synchronized (this.socketChannels) {
                for (SocketChannel channel : this.socketChannels) {
                    try {
                        channel.close();
                    }
                    catch (IOException e) {
                        Log.warn((Object[])new Object[]{"Unable to close RabitTrackerH2O SocketChannel on port ", channel.socket().getPort()});
                    }
                }
            }
        }

        /**
         * Accept loop. Runs until interrupted, until the server socket is closed, or until every
         * worker rank has sent a shutdown signal.
         */
        @Override
        public void run() {
            HashSet<Integer> shutdown = new HashSet<Integer>();
            HashMap<Integer, RabitWorker> waitConn = new HashMap<Integer, RabitWorker>();
            ArrayList<RabitWorker> pending = new ArrayList<RabitWorker>();
            ArrayDeque<Integer> todoNodes = new ArrayDeque<Integer>(this.tracker.workers);
            while (!Thread.interrupted() && shutdown.size() != this.tracker.workers) {
                try {
                    SocketChannel channel = this.tracker.sock.accept();
                    this.socketChannels.add(channel);
                    RabitWorker worker = new RabitWorker(channel);
                    if (PRINT_CMD.equals(worker.cmd)) {
                        String msg = worker.receiver().getStr();
                        Log.warn((Object[])new Object[]{"Rabit worker: ", msg});
                        continue;
                    }
                    if (SHUTDOWN_CMD.equals(worker.cmd)) {
                        assert worker.rank >= 0 && !shutdown.contains(worker.rank);
                        // waitConn is keyed by rank, so the rank (not the worker object) must be checked.
                        assert !waitConn.containsKey(worker.rank);
                        shutdown.add(worker.rank);
                        channel.socket().close();
                        Log.debug((Object[])new Object[]{"Received ", worker.cmd, " signal from ", worker.rank});
                        continue;
                    }
                    assert START_CMD.equals(worker.cmd) || RECOVER_CMD.equals(worker.cmd);
                    if (null == this.linkMap) {
                        // First "start" seen: build the topology and queue every rank for assignment.
                        assert START_CMD.equals(worker.cmd);
                        this.linkMap = new LinkMap(this.tracker.workers);
                        for (int i = 0; i < this.tracker.workers; ++i) {
                            todoNodes.add(i);
                        }
                    } else {
                        assert worker.worldSize == -1 || worker.worldSize == this.tracker.workers;
                    }
                    if (RECOVER_CMD.equals(worker.cmd)) {
                        assert worker.rank >= 0;
                    }
                    int rank = worker.decideRank(this.jobToRankMap);
                    if (-1 == rank) {
                        // No rank yet: hold the worker until every unassigned rank has a candidate,
                        // then hand ranks out in the workers' sorted order.
                        assert !todoNodes.isEmpty();
                        pending.add(worker);
                        if (pending.size() == todoNodes.size()) {
                            Collections.sort(pending);
                            for (RabitWorker p : pending) {
                                int assigned = todoNodes.poll();
                                if (!NULL_STR.equals(p.jobId)) {
                                    this.jobToRankMap.put(p.jobId, assigned);
                                }
                                p.assignRank(assigned, waitConn, this.linkMap);
                                if (p.waitAccept > 0) {
                                    waitConn.put(assigned, p);
                                }
                                Log.debug((Object[])new Object[]{"Received " + p.cmd + " signal from " + p.host + ":" + p.workerPort + ". Assigned rank " + p.rank});
                            }
                            pending.clear();
                        }
                        if (todoNodes.isEmpty()) {
                            Log.debug((Object[])new Object[]{"All " + this.tracker.workers + " Rabit workers are getting started."});
                        }
                        continue;
                    }
                    worker.assignRank(rank, waitConn, this.linkMap);
                    if (worker.waitAccept > 0) {
                        waitConn.put(rank, worker);
                    }
                }
                catch (IOException e) {
                    if (!this.tracker.sock.isOpen()) {
                        // The server channel is closed (e.g. by stop() or by an interrupt during
                        // accept()); every further accept() would fail immediately, so leave.
                        break;
                    }
                    Log.err((Object[])new Object[]{"Exception in Rabit tracker.", e});
                }
            }
            Log.debug((Object[])new Object[]{"All Rabit nodes finished."});
        }
    }
}

