/*
 * Decompiled with CFR 0.152.
 */
package io.nats.client.impl;

import io.nats.client.Options;
import io.nats.client.impl.DataPort;
import io.nats.client.impl.NatsConnection;
import io.nats.client.support.NatsInetAddress;
import io.nats.client.support.NatsUri;
import io.nats.client.support.WebSocket;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketException;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import org.jspecify.annotations.NonNull;

/**
 * {@link DataPort} implementation that talks to the server over a {@link Socket}.
 *
 * <p>The socket may be upgraded in place to TLS (see {@link #upgradeToSecure()}) and,
 * for {@code ws}/{@code wss} URIs, is wrapped in a {@link WebSocket}.
 */
public class SocketDataPort
implements DataPort {
    /** Delay added per additional candidate address when racing connections (0 ms, 250 ms, 500 ms, ...). */
    private static final int CONNECT_DELAY_MILLIS = 250;
    /** Size requested for both the send and the receive buffer of every socket created here (2 MiB). */
    private static final int SOCKET_BUFFER_SIZE = 0x200000;
    /** Number of nanoseconds in one millisecond. */
    private static final long NANOS_PER_MILLI = 1000000L;

    protected NatsConnection connection;
    protected String host;
    protected int port;
    protected Socket socket;
    protected boolean isSecure = false;
    protected int soLinger;
    protected InputStream in;
    protected OutputStream out;

    /**
     * Captures the SO_LINGER setting from the options for use by later connects.
     *
     * @param options the connection options
     */
    @Override
    public void afterConstruct(Options options) {
        this.soLinger = options.getSocketSoLinger();
    }

    /**
     * Parses {@code serverURI} and delegates to
     * {@link #connect(NatsConnection, NatsUri, long)}.
     *
     * @param serverURI    the server URI as a string
     * @param conn         the owning connection
     * @param timeoutNanos the connect timeout in nanoseconds
     * @throws IOException if the URI is malformed (wrapping the {@link URISyntaxException})
     *                     or the connection attempt fails
     */
    @Override
    public void connect(@NonNull String serverURI, @NonNull NatsConnection conn, long timeoutNanos) throws IOException {
        try {
            this.connect(conn, new NatsUri(serverURI), timeoutNanos);
        }
        catch (URISyntaxException e) {
            throw new IOException(e);
        }
    }

    /**
     * Opens the socket to the host and port of {@code nuri}, applies the configured
     * SO_LINGER and read timeout, performs the TLS/WebSocket upgrade for
     * {@code ws}/{@code wss} schemes and captures the input and output streams.
     *
     * <p>On any failure the socket is closed, the {@code socket} field is reset to
     * {@code null}, and the failure is rethrown as an {@link IOException}
     * (non-I/O exceptions are wrapped).
     *
     * @param conn         the owning connection; its options drive the socket setup
     * @param nuri         the server URI
     * @param timeoutNanos the connect timeout in nanoseconds
     * @throws IOException if the connection cannot be established
     */
    @Override
    public void connect(@NonNull NatsConnection conn, @NonNull NatsUri nuri, long timeoutNanos) throws IOException {
        this.connection = conn;
        Options options = this.connection.getOptions();
        int timeoutMillis = toConnectTimeoutMillis(timeoutNanos);
        this.host = nuri.getHost();
        this.port = nuri.getPort();
        try {
            if (options.isEnableFastFallback()) {
                this.socket = this.connectToFastestIp(options, this.host, this.port, timeoutMillis);
            } else {
                this.socket = this.createSocket(options);
                this.socket.connect(new InetSocketAddress(this.host, this.port), timeoutMillis);
            }
            if (this.soLinger > -1) {
                this.socket.setSoLinger(true, this.soLinger);
            }
            if (options.getSocketReadTimeoutMillis() > 0) {
                this.socket.setSoTimeout(options.getSocketReadTimeoutMillis());
            }
            if (SocketDataPort.isWebsocketScheme(nuri.getScheme())) {
                if ("wss".equalsIgnoreCase(nuri.getScheme())) {
                    this.upgradeToSecure();
                }
                // If the wrapper cannot be built, this.socket still refers to the
                // underlying socket, which the catch block below closes.
                this.socket = new WebSocket(this.socket, this.host, options.getHttpRequestInterceptors());
            }
            this.in = this.socket.getInputStream();
            this.out = this.socket.getOutputStream();
        }
        catch (Exception e) {
            closeQuietly(this.socket);
            this.socket = null;
            if (e instanceof IOException) {
                throw (IOException)e;
            }
            throw new IOException(e);
        }
    }

    /**
     * Layers TLS over the current socket using the SSL context from the connection
     * options, then waits up to the configured connection timeout for the
     * handshake-completed notification.
     *
     * <p>If that wait fails (timeout, interruption or execution failure) the problem is
     * handed to {@link NatsConnection#handleCommunicationIssue(Exception)} and the method
     * returns without replacing the socket or streams. On success the socket and both
     * streams are replaced by their TLS counterparts and the port is marked secure.
     *
     * @throws IOException if creating the TLS socket or starting the handshake fails
     */
    @Override
    public void upgradeToSecure() throws IOException {
        Options options = this.connection.getOptions();
        SSLContext context = options.getSslContext();
        SSLSocketFactory factory = context.getSocketFactory();
        Duration timeout = options.getConnectionTimeout();
        SSLSocket sslSocket = (SSLSocket)factory.createSocket(this.socket, this.host, this.port, true);
        sslSocket.setUseClientMode(true);
        CompletableFuture<Void> waitForHandshake = new CompletableFuture<Void>();
        HandshakeCompletedListener hcl = evt -> waitForHandshake.complete(null);
        sslSocket.addHandshakeCompletedListener(hcl);
        sslSocket.startHandshake();
        try {
            waitForHandshake.get(timeout.toNanos(), TimeUnit.NANOSECONDS);
        }
        catch (Exception ex) {
            this.connection.handleCommunicationIssue(ex);
            return;
        }
        finally {
            sslSocket.removeHandshakeCompletedListener(hcl);
        }
        this.socket = sslSocket;
        this.in = sslSocket.getInputStream();
        this.out = sslSocket.getOutputStream();
        this.isSecure = true;
    }

    /**
     * Reads from the socket's input stream.
     *
     * @param dst the destination buffer
     * @param off the offset in {@code dst} at which to start storing bytes
     * @param len the maximum number of bytes to read
     * @return the value returned by {@link InputStream#read(byte[], int, int)}
     * @throws IOException if the read fails
     */
    @Override
    public int read(byte[] dst, int off, int len) throws IOException {
        return this.in.read(dst, off, len);
    }

    /**
     * Writes the first {@code toWrite} bytes of {@code src} to the socket's output stream.
     *
     * @param src     the source buffer
     * @param toWrite the number of bytes to write, starting at index 0
     * @throws IOException if the write fails
     */
    @Override
    public void write(byte[] src, int toWrite) throws IOException {
        this.out.write(src, 0, toWrite);
    }

    /**
     * Shuts down the input side of the socket. Does nothing when the port has been
     * upgraded to TLS or no socket is present.
     *
     * @throws IOException if shutting down the input fails
     */
    @Override
    public void shutdownInput() throws IOException {
        // NOTE(review): secure sockets are skipped, presumably because SSLSocket
        // does not reliably support shutdownInput - confirm.
        if (!this.isSecure && this.socket != null) {
            this.socket.shutdownInput();
        }
    }

    /**
     * Closes the socket if one is present.
     *
     * @throws IOException if closing the socket fails
     */
    @Override
    public void close() throws IOException {
        if (this.socket != null) {
            this.socket.close();
        }
    }

    /**
     * Closes the socket after setting SO_LINGER to zero. Failure to set the
     * linger option is ignored so that the close is still attempted.
     *
     * @throws IOException if closing the socket fails
     */
    @Override
    public void forceClose() throws IOException {
        if (this.socket != null) {
            try {
                this.socket.setSoLinger(true, 0);
            }
            catch (SocketException ignored) {
                // best effort: the socket is closed regardless
            }
            this.close();
        }
    }

    /**
     * Flushes the socket's output stream.
     *
     * @throws IOException if the flush fails
     */
    @Override
    public void flush() throws IOException {
        this.out.flush();
    }

    /**
     * Tells whether {@code scheme} is {@code ws} or {@code wss}, ignoring case.
     *
     * @param scheme the URI scheme, may be {@code null}
     * @return {@code true} for a websocket scheme
     */
    protected static boolean isWebsocketScheme(String scheme) {
        return "ws".equalsIgnoreCase(scheme) || "wss".equalsIgnoreCase(scheme);
    }

    /**
     * Converts a nanosecond timeout into the millisecond {@code int} expected by
     * {@link Socket#connect(java.net.SocketAddress, int)}.
     *
     * <p>The value saturates instead of wrapping when it does not fit in an {@code int},
     * and a positive sub-millisecond timeout is rounded up to 1 ms because a timeout of
     * 0 means "wait indefinitely" for {@code Socket.connect}.
     */
    private static int toConnectTimeoutMillis(long timeoutNanos) {
        long millis = timeoutNanos / NANOS_PER_MILLI;
        if (millis == 0L && timeoutNanos > 0L) {
            return 1;
        }
        if (millis > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (millis < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int)millis;
    }

    /** Closes {@code candidate} if it is non-null, ignoring any failure. */
    private static void closeQuietly(Socket candidate) {
        if (candidate == null) {
            return;
        }
        try {
            candidate.close();
        }
        catch (Exception ignored) {
            // best effort: nothing useful can be done if close fails
        }
    }

    /**
     * Resolves every address of {@code hostname} and races connection attempts on the
     * executor from the options, starting the i-th attempt after
     * {@code i * CONNECT_DELAY_MILLIS} milliseconds. The first socket to connect wins.
     *
     * <p>Attempts that fail, or that are cancelled once a winner has been found, close
     * their own socket so that no file descriptor is leaked.
     *
     * @return the first successfully connected socket
     * @throws IOException if no address could be connected; the cause is the
     *                     underlying failure when one is available
     */
    private Socket connectToFastestIp(Options options, String hostname, int port, int timeoutMillis) throws IOException {
        List<InetAddress> ips = Arrays.asList(NatsInetAddress.getAllByName(hostname));
        String failureMessage = "No responsive IP found for " + hostname;
        if (ips.isEmpty()) {
            // invokeAny rejects an empty task list with IllegalArgumentException
            throw new IOException(failureMessage);
        }
        ExecutorService executor = options.getExecutor();
        List<Callable<Socket>> connectionTasks = new ArrayList<Callable<Socket>>(ips.size());
        for (int i = 0; i < ips.size(); ++i) {
            InetAddress ip = ips.get(i);
            int delayMillis = i * CONNECT_DELAY_MILLIS;
            connectionTasks.add(() -> {
                if (delayMillis > 0) {
                    // invokeAny interrupts the remaining tasks once one has succeeded;
                    // let the InterruptedException abort this attempt instead of
                    // opening a connection nobody will use.
                    Thread.sleep(delayMillis);
                }
                Socket candidate = this.createSocket(options);
                try {
                    candidate.connect(new InetSocketAddress(ip, port), timeoutMillis);
                    if (Thread.currentThread().isInterrupted()) {
                        // cancelled while connecting: the result would be discarded
                        throw new IOException("Connection attempt to " + ip + " was cancelled");
                    }
                    return candidate;
                }
                catch (IOException | RuntimeException e) {
                    closeQuietly(candidate);
                    throw e;
                }
            });
        }
        try {
            return executor.invokeAny(connectionTasks);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException(failureMessage, e);
        }
        catch (ExecutionException e) {
            Throwable cause = e.getCause();
            throw new IOException(failureMessage, cause != null ? cause : e);
        }
    }

    /**
     * Creates an unconnected socket (through the proxy from the options when one is
     * configured) with TCP_NODELAY enabled and the send/receive buffer sizes requested.
     * The socket is closed again if applying any of these options fails.
     *
     * @throws SocketException if a socket option cannot be applied
     */
    private Socket createSocket(Options options) throws SocketException {
        Socket created = options.getProxy() != null ? new Socket(options.getProxy()) : new Socket();
        try {
            created.setTcpNoDelay(true);
            created.setReceiveBufferSize(SOCKET_BUFFER_SIZE);
            created.setSendBufferSize(SOCKET_BUFFER_SIZE);
            return created;
        }
        catch (SocketException | RuntimeException e) {
            closeQuietly(created);
            throw e;
        }
    }
}

