/*
 * Decompiled with CFR 0.152.
 */
package org.cojen.tupl.repl;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Consumer;
import org.cojen.tupl.diag.EventListener;
import org.cojen.tupl.io.Utils;
import org.cojen.tupl.repl.ChannelInputStream;
import org.cojen.tupl.repl.ChannelManager;
import org.cojen.tupl.repl.EncodingOutputStream;
import org.cojen.tupl.repl.ErrorCodes;
import org.cojen.tupl.repl.GroupFile;
import org.cojen.tupl.repl.JoinException;

/**
 * Client side of the group join/unjoin handshake. A request is sent to every seed address
 * concurrently (non-blocking connect and write via a {@link Selector}), and each reply is then
 * read in blocking mode. A reply may carry another address to contact, a success result, or an
 * error code.
 *
 * <p>Each public operation opens its own selector and channels and closes them before
 * returning, whether it succeeds or fails.
 */
class GroupJoiner {
    // Operation codes written to, and read from, the remote member.
    static final int OP_NOP = 0;
    static final int OP_ERROR = 1;
    static final int OP_ADDRESS = 2;
    static final int OP_JOINED = 3;
    static final int OP_UNJOIN_ADDRESS = 4;
    static final int OP_UNJOIN_MEMBER = 5;
    static final int OP_UNJOINED = 6;

    private final EventListener mEventListener;
    private final File mFile;
    private final long mGroupToken1;
    private final long mGroupToken2;
    private final SocketAddress mLocalAddress;
    private final SocketAddress mBindAddress;

    private Selector mSelector;
    private SocketChannel[] mSeedChannels;
    private SocketChannel mLeaderChannel;

    // Results, assigned by processReply. The group file, terms and position are only assigned
    // by an OP_JOINED reply; mReplySuccess is set by OP_JOINED and by OP_UNJOINED.
    GroupFile mGroupFile;
    boolean mReplySuccess;
    long mPrevTerm;
    long mTerm;
    long mPosition;

    /**
     * @param eventListener passed through to {@code GroupFile.open}
     * @param groupFile file which receives the group file contents sent by an OP_JOINED reply
     * @param groupToken1 first group token, sent in the connect header
     * @param groupToken2 second group token, sent in the connect header
     * @param localAddress address sent to the group by {@link #join}
     * @param listenAddress when an {@link InetSocketAddress}, outgoing channels are bound to
     * its IP address with port 0 (system-chosen port); otherwise channels aren't explicitly
     * bound
     */
    GroupJoiner(EventListener eventListener, File groupFile, long groupToken1, long groupToken2, SocketAddress localAddress, SocketAddress listenAddress) {
        this.mEventListener = eventListener;
        this.mFile = groupFile;
        this.mGroupToken1 = groupToken1;
        this.mGroupToken2 = groupToken2;
        this.mLocalAddress = localAddress;
        InetSocketAddress bindAddr = null;
        if (listenAddress instanceof InetSocketAddress) {
            InetSocketAddress isa = (InetSocketAddress)listenAddress;
            bindAddr = new InetSocketAddress(isa.getAddress(), 0);
        }
        this.mBindAddress = bindAddr;
    }

    /**
     * Creates a joiner with no event listener, file, local address or bind address. The
     * {@link #join} operation needs a local address and a file, neither of which is available
     * to an instance created this way.
     */
    GroupJoiner(long groupToken1, long groupToken2) {
        this(null, null, groupToken1, groupToken2, null, null);
    }

    /**
     * Requests to join the group by sending the local address to the given seeds. On success,
     * {@code mGroupFile}, {@code mPrevTerm}, {@code mTerm} and {@code mPosition} are assigned.
     *
     * @param seeds addresses to contact; must not be null
     * @param timeoutMillis overall timeout; a value of zero or less fails immediately
     * @throws JoinException if no seed (or redirected address) replied with success in time
     * @throws IOException if the selector or a channel could not be set up
     */
    void join(Set<SocketAddress> seeds, int timeoutMillis) throws IOException {
        try {
            this.doJoin(seeds, timeoutMillis, out -> {
                out.write(OP_ADDRESS);
                out.encodeStr(this.mLocalAddress.toString());
            });
        } finally {
            this.close();
        }
    }

    /**
     * Requests removal of a member from the group, identified by member id.
     *
     * @param seeds addresses to contact; must not be null
     * @param timeoutMillis overall timeout; a value of zero or less fails immediately
     * @param memberId id of the member to remove
     * @throws JoinException if no seed (or redirected address) replied with success in time
     * @throws IOException if the selector or a channel could not be set up
     */
    void unjoin(Set<SocketAddress> seeds, int timeoutMillis, long memberId) throws IOException {
        try {
            this.doJoin(seeds, timeoutMillis, out -> {
                out.write(OP_UNJOIN_MEMBER);
                out.encodeLongLE(memberId);
            });
        } finally {
            this.close();
        }
    }

    /**
     * Requests removal of a member from the group, identified by address.
     *
     * @param seeds addresses to contact; must not be null
     * @param timeoutMillis overall timeout; a value of zero or less fails immediately
     * @param memberAddr address of the member to remove
     * @throws JoinException if no seed (or redirected address) replied with success in time
     * @throws IOException if the selector or a channel could not be set up
     */
    void unjoin(Set<SocketAddress> seeds, int timeoutMillis, SocketAddress memberAddr) throws IOException {
        try {
            this.doJoin(seeds, timeoutMillis, out -> {
                out.write(OP_UNJOIN_ADDRESS);
                out.encodeStr(memberAddr.toString());
            });
        } finally {
            this.close();
        }
    }

    /**
     * Sends a command to all seeds and waits for a success reply. Returns normally only when
     * {@code mReplySuccess} has been set; otherwise a JoinException is thrown which combines
     * the distinct failure messages, or says "timed out" when there are none.
     *
     * @param cout appends the command body after the connect header
     * @throws IllegalArgumentException if seeds is null
     * @throws IllegalStateException if seed channels are already open
     */
    private void doJoin(Set<SocketAddress> seeds, long timeoutMillis, Consumer<EncodingOutputStream> cout) throws IOException {
        if (seeds == null) {
            throw new IllegalArgumentException();
        }
        if (this.mSeedChannels != null) {
            throw new IllegalStateException();
        }

        EncodingOutputStream out = new EncodingOutputStream();
        out.write(ChannelManager.newConnectHeader(0L, 0L, 2, this.mGroupToken1, this.mGroupToken2));
        cout.accept(out);
        byte[] command = out.toByteArray();

        this.mSelector = Selector.open();
        this.mSeedChannels = new SocketChannel[seeds.size()];
        int i = 0;
        for (SocketAddress addr : seeds) {
            SocketChannel channel = SocketChannel.open();
            // Store before preparing, so that close() releases it if preparing fails.
            this.mSeedChannels[i++] = channel;
            this.prepareChannel(channel, addr);
        }

        // Number of channels which haven't yet replied or failed.
        int expected = seeds.size();

        // Sorted sets drop duplicates and give the combined message a stable order.
        TreeSet<String> joinFailureMessages = new TreeSet<>();
        TreeSet<String> connectFailureMessages = new TreeSet<>();

        long end = System.currentTimeMillis() + timeoutMillis;

        while (expected > 0 && timeoutMillis > 0L) {
            this.mSelector.select(timeoutMillis);
            Set<SelectionKey> keys = this.mSelector.selectedKeys();

            for (SelectionKey key : keys) {
                SocketChannel channel = (SocketChannel)key.channel();
                try {
                    if (key.isConnectable()) {
                        channel.finishConnect();
                        key.interestOps(SelectionKey.OP_WRITE);
                    } else if (key.isWritable()) {
                        // NOTE(review): a single write is assumed to consume the whole
                        // command; a short write would send a truncated request.
                        channel.write(ByteBuffer.wrap(command));
                        key.interestOps(SelectionKey.OP_READ);
                    } else {
                        // Reply is available. Deregister and read it in blocking mode.
                        key.cancel();
                        channel.configureBlocking(true);
                        SocketAddress addr = this.processReply(channel.socket(), timeoutMillis);
                        --expected;
                        if (addr != null && this.mLeaderChannel == null && !seeds.contains(addr)) {
                            // Redirected to an address which isn't a seed. Contact it too, but
                            // only one such extra channel is ever opened.
                            ++expected;
                            channel = SocketChannel.open();
                            this.prepareChannel(channel, addr);
                            this.mLeaderChannel = channel;
                        }
                    }
                } catch (JoinException e) {
                    Utils.closeQuietly(channel);
                    --expected;
                    joinFailureMessages.add(e.getMessage());
                } catch (IOException e) {
                    Utils.closeQuietly(channel);
                    --expected;
                    connectFailureMessages.add(e.toString());
                }
            }

            keys.clear();

            if (this.mReplySuccess) {
                return;
            }

            timeoutMillis = end - System.currentTimeMillis();
        }

        // Join failures are reported ahead of connect failures.
        Set<String> failureMessages = new LinkedHashSet<>(joinFailureMessages);
        failureMessages.addAll(connectFailureMessages);

        String fullMessage = failureMessages.isEmpty()
            ? "timed out" : String.join("; ", failureMessages);

        throw new JoinException(fullMessage);
    }

    /**
     * Optionally binds the channel, switches it to non-blocking mode, registers it with the
     * selector for connect readiness, and then initiates the connection.
     */
    private void prepareChannel(SocketChannel channel, SocketAddress addr) throws IOException {
        if (this.mBindAddress != null) {
            channel.bind(this.mBindAddress);
        }
        channel.configureBlocking(false);
        channel.register(this.mSelector, SelectionKey.OP_CONNECT);
        channel.connect(addr);
    }

    /**
     * Reads and applies one reply from a blocking socket, and then closes the socket. If an
     * exception is thrown, the socket is left for the caller to close.
     *
     * @param timeoutMillis read timeout; zero is rounded up to 1ms because a socket timeout of
     * zero means infinite, and a negative value leaves the socket timeout untouched
     * @return address carried by an OP_ADDRESS reply, or null if the reply was something else
     * or the address isn't a usable internet address
     * @throws JoinException if the reply is OP_ERROR
     */
    private SocketAddress processReply(Socket s, long timeoutMillis) throws IOException {
        if (timeoutMillis >= 0L) {
            int intTimeout = timeoutMillis == 0L ? 1 : (int)Math.min(timeoutMillis, Integer.MAX_VALUE);
            s.setSoTimeout(intTimeout);
        }

        SocketAddress addr = null;
        byte[] header = ChannelManager.readHeader(s, false, 0L, this.mGroupToken1, this.mGroupToken2);

        if (header != null) {
            ChannelInputStream cin = new ChannelInputStream(s.getInputStream(), 1000, false);
            int op = cin.read();
            switch (op) {
            case OP_ADDRESS: {
                addr = GroupFile.parseSocketAddress(cin.readStr(cin.readIntLE()));
                if (!(addr instanceof InetSocketAddress)) {
                    addr = null;
                } else {
                    InetSocketAddress isa = (InetSocketAddress)addr;
                    // getAddress() is null for an unresolved address, which can't be
                    // connected to. A wildcard address isn't a usable destination either.
                    if (isa.getAddress() == null || isa.getAddress().isAnyLocalAddress()) {
                        addr = null;
                    }
                }
                break;
            }
            case OP_JOINED: {
                this.mPrevTerm = cin.readLongLE();
                this.mTerm = cin.readLongLE();
                this.mPosition = cin.readLongLE();
                // The remainder of the reply is the group file contents.
                try (FileOutputStream out = new FileOutputStream(this.mFile)) {
                    cin.drainTo(out);
                }
                this.mGroupFile = GroupFile.open(this.mEventListener, this.mFile, this.mLocalAddress, false);
                this.mReplySuccess = true;
                break;
            }
            case OP_UNJOINED:
                this.mReplySuccess = true;
                break;
            case OP_ERROR:
                throw new JoinException(ErrorCodes.toString(cin.readByte()));
            default:
                // Any other reply (including end of stream) is ignored.
                break;
            }
        }

        Utils.closeQuietly(s);
        return addr;
    }

    /**
     * Closes the selector and all channels, and clears the references such that another
     * operation can be started.
     */
    private void close() {
        Utils.closeQuietly(this.mSelector);
        this.mSelector = null;
        if (this.mSeedChannels != null) {
            for (SocketChannel channel : this.mSeedChannels) {
                Utils.closeQuietly(channel);
            }
            this.mSeedChannels = null;
        }
        Utils.closeQuietly(this.mLeaderChannel);
        this.mLeaderChannel = null;
    }
}

