/*
 * Decompiled with CFR 0.152.
 */
package io.asyncer.r2dbc.mysql;

import io.asyncer.r2dbc.mysql.Capability;
import io.asyncer.r2dbc.mysql.ServerVersion;
import io.asyncer.r2dbc.mysql.authentication.MySqlAuthProvider;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.client.FluxExchangeable;
import io.asyncer.r2dbc.mysql.constant.CompressionAlgorithm;
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
import io.asyncer.r2dbc.mysql.message.client.SslRequest;
import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage;
import io.asyncer.r2dbc.mysql.message.server.AuthMoreDataMessage;
import io.asyncer.r2dbc.mysql.message.server.ChangeAuthMessage;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader;
import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest;
import io.asyncer.r2dbc.mysql.message.server.OkMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.asyncer.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.spi.R2dbcNonTransientResourceException;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.SynchronousSink;
import reactor.util.concurrent.Queues;

final class HandshakeExchangeable
extends FluxExchangeable<Void> {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(HandshakeExchangeable.class);
    private static final Map<String, String> ATTRIBUTES = Collections.emptyMap();
    private static final String CLI_SPECIFIC = "HY000";
    private static final int HANDSHAKE_VERSION = 10;
    private final Sinks.Many<SubsequenceClientMessage> requests = Sinks.many().unicast().onBackpressureBuffer((Queue)Queues.one().get());
    private final Client client;
    private final SslMode sslMode;
    private final String database;
    private final String user;
    @Nullable
    private final CharSequence password;
    private final Set<CompressionAlgorithm> compressions;
    private final int zstdCompressionLevel;
    private boolean handshake = true;
    private MySqlAuthProvider authProvider;
    private byte[] salt;
    private boolean sslCompleted;

    HandshakeExchangeable(Client client, SslMode sslMode, String database, String user, @Nullable CharSequence password, Set<CompressionAlgorithm> compressions, int zstdCompressionLevel) {
        this.client = client;
        this.sslMode = sslMode;
        this.database = database;
        this.user = user;
        this.password = password;
        this.compressions = compressions;
        this.zstdCompressionLevel = zstdCompressionLevel;
        this.sslCompleted = sslMode == SslMode.TUNNEL;
    }

    public void subscribe(CoreSubscriber<? super ClientMessage> actual) {
        this.requests.asFlux().subscribe(actual);
    }

    @Override
    public void accept(ServerMessage message, SynchronousSink<Void> sink) {
        if (message instanceof ErrorMessage) {
            sink.error((Throwable)((ErrorMessage)message).toException());
            return;
        }
        if (this.handshake) {
            this.handshake = false;
            if (message instanceof HandshakeRequest) {
                HandshakeRequest request = (HandshakeRequest)message;
                Capability capability = this.initHandshake(request);
                if (capability.isSslEnabled()) {
                    this.emitNext(SslRequest.from(capability, this.client.getContext().getClientCollation().getId()), sink);
                } else {
                    this.emitNext(this.createHandshakeResponse(capability), sink);
                }
            } else {
                sink.error((Throwable)new R2dbcPermissionDeniedException("Unexpected message type '" + message.getClass().getSimpleName() + "' in init phase"));
            }
            return;
        }
        if (message instanceof OkMessage) {
            logger.trace("Connection (id {}) login success", (Object)this.client.getContext().getConnectionId());
            this.client.loginSuccess();
            sink.complete();
        } else if (message instanceof SyntheticSslResponseMessage) {
            this.sslCompleted = true;
            this.emitNext(this.createHandshakeResponse(this.client.getContext().getCapability()), sink);
        } else if (message instanceof AuthMoreDataMessage) {
            AuthMoreDataMessage msg = (AuthMoreDataMessage)message;
            if (msg.isFailed()) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Connection (id {}) fast authentication failed, use full authentication", (Object)this.client.getContext().getConnectionId());
                }
                this.emitNext(this.createAuthResponse("full authentication"), sink);
            }
        } else if (message instanceof ChangeAuthMessage) {
            ChangeAuthMessage msg = (ChangeAuthMessage)message;
            this.authProvider = MySqlAuthProvider.build(msg.getAuthType());
            this.salt = msg.getSalt();
            this.emitNext(this.createAuthResponse("change authentication"), sink);
        } else {
            sink.error((Throwable)new R2dbcPermissionDeniedException("Unexpected message type '" + message.getClass().getSimpleName() + "' in login phase"));
        }
    }

    public void dispose() {
        this.requests.tryEmitComplete();
    }

    private void emitNext(SubsequenceClientMessage message, SynchronousSink<Void> sink) {
        Sinks.EmitResult result = this.requests.tryEmitNext((Object)message);
        if (result != Sinks.EmitResult.OK) {
            sink.error((Throwable)new IllegalStateException("Fail to emit a login request due to " + result));
        }
    }

    private AuthResponse createAuthResponse(String phase) {
        MySqlAuthProvider authProvider = this.getAndNextProvider();
        if (authProvider.isSslNecessary() && !this.sslCompleted) {
            throw new R2dbcPermissionDeniedException(HandshakeExchangeable.authFails(authProvider.getType(), phase), CLI_SPECIFIC);
        }
        return new AuthResponse(authProvider.authentication(this.password, this.salt, this.client.getContext().getClientCollation()));
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private Capability clientCapability(Capability serverCapability) {
        Capability.Builder builder = serverCapability.mutate();
        builder.disableSessionTrack();
        builder.disableDatabasePinned();
        builder.disableIgnoreAmbiguitySpace();
        builder.disableInteractiveTimeout();
        if (this.sslMode == SslMode.TUNNEL) {
            builder.disableSsl();
        } else if (!serverCapability.isSslEnabled()) {
            if (this.sslMode.requireSsl()) {
                throw new R2dbcPermissionDeniedException("Server does not support SSL but mode '" + (Object)((Object)this.sslMode) + "' requires SSL", CLI_SPECIFIC);
            }
            if (this.sslMode.startSsl()) {
                this.client.sslUnsupported();
            }
        } else if (!this.sslMode.startSsl()) {
            builder.disableSsl();
        }
        if (this.isZstdAllowed(serverCapability)) {
            if (HandshakeExchangeable.isZstdSupported()) {
                builder.disableZlibCompression();
            } else {
                logger.warn("Server supports zstd, but zstd-jni dependency is missing");
                if (this.isZlibAllowed(serverCapability)) {
                    builder.disableZstdCompression();
                } else {
                    if (!this.compressions.contains((Object)CompressionAlgorithm.UNCOMPRESSED)) throw new R2dbcNonTransientResourceException("Environment does not support a compression algorithm in " + this.compressions + ", config does not allow uncompressed mode", CLI_SPECIFIC);
                    builder.disableCompression();
                }
            }
        } else if (this.isZlibAllowed(serverCapability)) {
            builder.disableZstdCompression();
        } else {
            if (!this.compressions.contains((Object)CompressionAlgorithm.UNCOMPRESSED)) throw new R2dbcPermissionDeniedException("Environment does not support a compression algorithm in " + this.compressions + ", config does not allow uncompressed mode", CLI_SPECIFIC);
            builder.disableCompression();
        }
        if (this.database.isEmpty()) {
            builder.disableConnectWithDatabase();
        }
        if (this.client.getContext().getLocalInfilePath() == null) {
            builder.disableLoadDataLocalInfile();
        }
        if (!ATTRIBUTES.isEmpty()) return builder.build();
        builder.disableConnectAttributes();
        return builder.build();
    }

    private Capability initHandshake(HandshakeRequest message) {
        HandshakeHeader header = message.getHeader();
        short handshakeVersion = header.getProtocolVersion();
        ServerVersion serverVersion = header.getServerVersion();
        if (handshakeVersion < 10) {
            logger.warn("MySQL use handshake V{}, server version is {}, maybe most features are unavailable", (Object)handshakeVersion, (Object)serverVersion);
        }
        Capability capability = this.clientCapability(message.getServerCapability());
        this.client.getContext().initHandshake(header.getConnectionId(), serverVersion, capability);
        this.authProvider = MySqlAuthProvider.build(message.getAuthType());
        this.salt = message.getSalt();
        return capability;
    }

    private MySqlAuthProvider getAndNextProvider() {
        MySqlAuthProvider authProvider = this.authProvider;
        this.authProvider = authProvider.next();
        return authProvider;
    }

    private HandshakeResponse createHandshakeResponse(Capability capability) {
        MySqlAuthProvider authProvider = this.getAndNextProvider();
        if (authProvider.isSslNecessary() && !this.sslCompleted) {
            throw new R2dbcPermissionDeniedException(HandshakeExchangeable.authFails(authProvider.getType(), "handshake"), CLI_SPECIFIC);
        }
        byte[] authorization = authProvider.authentication(this.password, this.salt, this.client.getContext().getClientCollation());
        String authType = authProvider.getType();
        if ("".equals(authType)) {
            authType = "caching_sha2_password";
        }
        return HandshakeResponse.from(capability, this.client.getContext().getClientCollation().getId(), this.user, authorization, authType, this.database, ATTRIBUTES, this.zstdCompressionLevel);
    }

    private boolean isZstdAllowed(Capability capability) {
        return capability.isZstdCompression() && this.compressions.contains((Object)CompressionAlgorithm.ZSTD);
    }

    private boolean isZlibAllowed(Capability capability) {
        return capability.isZlibCompression() && this.compressions.contains((Object)CompressionAlgorithm.ZLIB);
    }

    private static String authFails(String authType, String phase) {
        return "Authentication type '" + authType + "' must require SSL in " + phase + " phase";
    }

    private static boolean isZstdSupported() {
        try {
            ClassLoader loader = AccessController.doPrivileged(() -> {
                ClassLoader cl = Thread.currentThread().getContextClassLoader();
                return cl == null ? ClassLoader.getSystemClassLoader() : cl;
            });
            Class.forName("com.github.luben.zstd.Zstd", false, loader);
            return true;
        }
        catch (ClassNotFoundException e) {
            return false;
        }
    }
}

