/*
 * Decompiled with CFR 0.152.
 */
package dev.miku.r2dbc.mysql;

import dev.miku.r2dbc.mysql.ExceptionFactory;
import dev.miku.r2dbc.mysql.authentication.MySqlAuthProvider;
import dev.miku.r2dbc.mysql.client.Client;
import dev.miku.r2dbc.mysql.constant.SslMode;
import dev.miku.r2dbc.mysql.message.client.AuthResponse;
import dev.miku.r2dbc.mysql.message.client.ExchangeableMessage;
import dev.miku.r2dbc.mysql.message.client.HandshakeResponse;
import dev.miku.r2dbc.mysql.message.client.SslRequest;
import dev.miku.r2dbc.mysql.message.server.AuthMoreDataMessage;
import dev.miku.r2dbc.mysql.message.server.ChangeAuthMessage;
import dev.miku.r2dbc.mysql.message.server.ErrorMessage;
import dev.miku.r2dbc.mysql.message.server.HandshakeHeader;
import dev.miku.r2dbc.mysql.message.server.HandshakeRequest;
import dev.miku.r2dbc.mysql.message.server.OkMessage;
import dev.miku.r2dbc.mysql.message.server.ServerMessage;
import dev.miku.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
import dev.miku.r2dbc.mysql.util.AssertUtils;
import dev.miku.r2dbc.mysql.util.ConnectionContext;
import dev.miku.r2dbc.mysql.util.ServerVersion;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Mono;
import reactor.util.annotation.Nullable;

/**
 * Drives the MySQL connection phase (handshake, optional SSL upgrade, authentication and
 * authentication switching) as a small state machine over a {@link Client}.
 * <p>
 * Each {@link State} performs one exchange with the server and yields the next state; the flow
 * ends when {@link State#COMPLETED} is reached or when any state reports an error.
 */
final class LoginFlow {

    private static final Logger logger = LoggerFactory.getLogger(LoginFlow.class);

    /** Connection attributes sent in the handshake response; currently always empty. */
    private static final Map<String, String> ATTRIBUTES = Collections.emptyMap();

    /** Handshake protocol version this client expects; an older server version only triggers a warning. */
    private static final int CURRENT_HANDSHAKE_VERSION = 10;

    /** Capability flags this client is willing to negotiate; server capabilities are masked with it. */
    private static final int ALL_SUPPORTED_CAPABILITIES = 0x413FFA0F;

    /** MySQL capability flag {@code CLIENT_CONNECT_WITH_DB}: a database name is sent at login. */
    private static final int CLIENT_CONNECT_WITH_DB = 0x8;

    /** MySQL capability flag {@code CLIENT_SSL}: the connection is upgraded to SSL before authenticating. */
    private static final int CLIENT_SSL = 0x800;

    /** MySQL capability flag {@code CLIENT_CONNECT_ATTRS}: connection attributes are sent at login. */
    private static final int CLIENT_CONNECT_ATTRS = 0x100000;

    /** MySQL capability flag {@code CLIENT_SSL_VERIFY_SERVER_CERT}. */
    private static final int CLIENT_SSL_VERIFY_SERVER_CERT = 0x40000000;

    /**
     * First byte of an {@link AuthMoreDataMessage} payload indicating that fast authentication
     * succeeded; any other value makes this flow fall back to full authentication.
     */
    private static final int FAST_AUTH_SUCCESS = 3;

    private final Client client;

    private final ConnectionContext context;

    private final SslMode sslMode;

    private final String database;

    /** Set once the SSL exchange has finished; auth providers that need SSL check it. */
    private volatile boolean sslCompleted;

    // The fields below hold credentials / intermediate auth state and are cleared after login.

    private volatile MySqlAuthProvider authProvider;

    private volatile String username;

    private volatile CharSequence password;

    private volatile byte[] salt;

    private LoginFlow(Client client, SslMode sslMode, String database, ConnectionContext context, String username, @Nullable CharSequence password) {
        this.client = AssertUtils.requireNonNull(client, "client must not be null");
        this.sslMode = AssertUtils.requireNonNull(sslMode, "sslMode must not be null");
        this.database = AssertUtils.requireNonNull(database, "database must not be null");
        this.context = AssertUtils.requireNonNull(context, "context must not be null");
        this.username = AssertUtils.requireNonNull(username, "username must not be null");
        this.password = password;
    }

    /**
     * Records the server's initial handshake: connection id, server version, negotiated
     * capabilities, the initial auth provider and the salt.
     *
     * @param message the handshake request sent by the server
     * @throws R2dbcPermissionDeniedException if SSL is required but the server does not support it
     */
    private void initHandshake(HandshakeRequest message) {
        HandshakeHeader header = message.getHeader();
        short handshakeVersion = header.getProtocolVersion();
        ServerVersion serverVersion = header.getServerVersion();

        if (handshakeVersion < CURRENT_HANDSHAKE_VERSION) {
            logger.warn("The MySQL server use old handshake V{}, server version is {}, maybe most features are not available", handshakeVersion, serverVersion);
        }

        this.context.setConnectionId(header.getConnectionId());
        // Server version must be set before capabilities: the SSL error message below reads it.
        this.context.setServerVersion(serverVersion);
        this.context.setCapabilities(calculateClientCapabilities(message.getServerCapabilities()));
        this.authProvider = MySqlAuthProvider.build(message.getAuthType());
        this.salt = message.getSalt();
    }

    /** Returns {@code true} when the negotiated capabilities include {@code CLIENT_SSL}. */
    private boolean useSsl() {
        return (this.context.getCapabilities() & CLIENT_SSL) != 0;
    }

    /** Switches to the auth provider and salt requested by the server. */
    private void changeAuth(ChangeAuthMessage message) {
        this.authProvider = MySqlAuthProvider.build(message.getAuthType());
        this.salt = message.getSalt();
    }

    private SslRequest createSslRequest() {
        return SslRequest.from(this.context.getCapabilities(), this.context.getCollation().getId());
    }

    /** Returns the current auth provider and replaces it with {@code next()} for a later round. */
    private MySqlAuthProvider getAndNextProvider() {
        MySqlAuthProvider current = this.authProvider;
        this.authProvider = current.next();
        return current;
    }

    /**
     * Lazily builds the handshake response using the current auth provider.
     *
     * @return a {@code Mono} that errors with {@link R2dbcPermissionDeniedException} if the provider
     *     requires SSL but SSL was not completed, or with {@link IllegalStateException} if the
     *     credentials have already been cleared
     */
    private Mono<HandshakeResponse> createHandshakeResponse() {
        return Mono.fromSupplier(() -> {
            MySqlAuthProvider provider = getAndNextProvider();

            if (provider.isSslNecessary() && !this.sslCompleted) {
                throw new R2dbcPermissionDeniedException(formatAuthFails(provider.getType(), "handshake"), "HY000");
            }

            String user = this.username;
            if (user == null) {
                throw new IllegalStateException("username must not be null when login");
            }

            byte[] authorization = provider.authentication(this.password, this.salt, this.context.getCollation());
            String authType = provider.getType();

            if ("".equals(authType)) {
                authType = "caching_sha2_password";
            }

            return HandshakeResponse.from(this.context.getCapabilities(), this.context.getCollation().getId(), user, authorization, authType, this.database, ATTRIBUTES);
        });
    }

    /**
     * Lazily builds an auth response (change-auth or full-auth round) with the current provider.
     *
     * @param phase phase name used in the error message when SSL is required but missing
     */
    private Mono<AuthResponse> createAuthResponse(String phase) {
        return Mono.fromSupplier(() -> {
            MySqlAuthProvider provider = getAndNextProvider();

            if (provider.isSslNecessary() && !this.sslCompleted) {
                throw new R2dbcPermissionDeniedException(formatAuthFails(provider.getType(), phase), "HY000");
            }

            return new AuthResponse(provider.authentication(this.password, this.salt, this.context.getCollation()));
        });
    }

    /**
     * Computes the client capability flags from the server's flags and the configured SSL mode,
     * database and attributes.
     *
     * @throws R2dbcPermissionDeniedException if the server lacks SSL but {@link #sslMode} requires it
     */
    private int calculateClientCapabilities(int serverCapabilities) {
        int clientCapabilities = serverCapabilities & ALL_SUPPORTED_CAPABILITIES;

        if ((clientCapabilities & CLIENT_SSL) == 0) {
            if (this.sslMode.requireSsl()) {
                String message = String.format("Server version '%s' does not support SSL but mode '%s' requires SSL", this.context.getServerVersion(), this.sslMode);
                throw new R2dbcPermissionDeniedException(message, "HY000");
            }
            if (this.sslMode.startSsl()) {
                // SSL was preferred but is unavailable; let the client know it will stay plain.
                this.client.sslUnsupported();
            }
        } else {
            if (!this.sslMode.startSsl()) {
                clientCapabilities &= ~CLIENT_SSL;
            }
            if (!this.sslMode.verifyCertificate()) {
                clientCapabilities &= ~CLIENT_SSL_VERIFY_SERVER_CERT;
            }
        }

        // Clearing an unset bit is a no-op, so no "is it set" guard is needed.
        if (this.database.isEmpty()) {
            clientCapabilities &= ~CLIENT_CONNECT_WITH_DB;
        }
        if (ATTRIBUTES.isEmpty()) {
            clientCapabilities &= ~CLIENT_CONNECT_ATTRS;
        }

        return clientCapabilities;
    }

    /** Drops references to credentials and intermediate auth state. */
    private void clearAuthentication() {
        this.username = null;
        this.password = null;
        this.salt = null;
        this.authProvider = null;
    }

    private void loginSuccess() {
        clearAuthentication();
        this.client.loginSuccess();
    }

    private void loginFailed() {
        clearAuthentication();
        // Fire-and-forget: the login error is what gets propagated to the caller.
        this.client.forceClose().subscribe();
    }

    /**
     * Runs the login state machine on {@code client}.
     *
     * @return a {@code Mono} emitting {@code client} once {@link State#COMPLETED} is reached; on
     *     error the credentials are cleared and the client is force-closed
     */
    static Mono<Client> login(Client client, SslMode sslMode, String database, ConnectionContext context, String username, @Nullable CharSequence password) {
        LoginFlow flow = new LoginFlow(client, sslMode, database, context, username, password);
        EmitterProcessor<State> stateMachine = EmitterProcessor.create(1, true);

        stateMachine.onNext(State.INIT);

        Consumer<State> onStateNext = state -> {
            if (state == State.COMPLETED) {
                logger.debug("Login succeed, cleanup intermediate variables");
                flow.loginSuccess();
                stateMachine.onComplete();
            } else {
                stateMachine.onNext(state);
            }
        };
        Consumer<Throwable> onStateError = stateMachine::onError;

        return stateMachine.doOnNext(state -> {
            logger.debug("Login state {} handling", state);
            state.handle(flow).subscribe(onStateNext, onStateError);
        }).doOnError(ignored -> flow.loginFailed())
            .then(Mono.just(client));
    }

    private static String formatAuthFails(String authType, String phase) {
        return String.format("Authentication type '%s' must require SSL in %s phase", authType, phase);
    }

    /**
     * Returns {@code true} if {@code message} is an {@link AuthMoreDataMessage} whose status byte is
     * not {@link #FAST_AUTH_SUCCESS}, i.e. the flow must continue with full authentication.
     */
    private static boolean isFastAuthFailed(ServerMessage message) {
        return message instanceof AuthMoreDataMessage
            && ((AuthMoreDataMessage) message).getAuthMethodData()[0] != FAST_AUTH_SUCCESS;
    }

    /** Login states; each performs one server exchange and yields the next state. */
    private enum State {

        INIT {
            @Override
            Mono<State> handle(LoginFlow flow) {
                // The server speaks first, so nothing is sent in this state.
                return flow.client.receiveOnly().<State>handle((message, sink) -> {
                    if (message instanceof ErrorMessage) {
                        sink.error(ExceptionFactory.createException((ErrorMessage) message, null));
                    } else if (message instanceof HandshakeRequest) {
                        flow.initHandshake((HandshakeRequest) message);
                        sink.next(flow.useSsl() ? SSL : HANDSHAKE);
                        sink.complete();
                    } else {
                        sink.error(new IllegalStateException(String.format("Unexpected message type '%s' in handshake init phase", message.getClass().getSimpleName())));
                    }
                });
            }
        },

        SSL {
            private final Predicate<ServerMessage> complete = message -> message instanceof ErrorMessage || message instanceof SyntheticSslResponseMessage;

            @Override
            Mono<State> handle(LoginFlow flow) {
                return flow.client.exchange(flow.createSslRequest(), this.complete).<State>handle((message, sink) -> {
                    if (message instanceof ErrorMessage) {
                        sink.error(ExceptionFactory.createException((ErrorMessage) message, null));
                    } else if (message instanceof SyntheticSslResponseMessage) {
                        flow.sslCompleted = true;
                        sink.next(HANDSHAKE);
                    } else {
                        sink.error(new IllegalStateException(String.format("Unexpected message type '%s' in SSL handshake phase", message.getClass().getSimpleName())));
                    }
                }).last();
            }
        },

        HANDSHAKE {
            // A successful fast-auth AuthMoreData does not end the exchange: the OK that follows does.
            private final Predicate<ServerMessage> complete = message -> message instanceof ErrorMessage || message instanceof OkMessage || isFastAuthFailed(message) || message instanceof ChangeAuthMessage;

            @Override
            Mono<State> handle(LoginFlow flow) {
                return flow.createHandshakeResponse()
                    .flatMapMany(response -> flow.client.exchange((ExchangeableMessage) response, this.complete))
                    .<State>handle((message, sink) -> {
                        if (message instanceof ErrorMessage) {
                            sink.error(ExceptionFactory.createException((ErrorMessage) message, null));
                        } else if (message instanceof OkMessage) {
                            sink.next(COMPLETED);
                        } else if (message instanceof AuthMoreDataMessage) {
                            if (isFastAuthFailed(message)) {
                                if (logger.isDebugEnabled()) {
                                    logger.debug("Connection (id {}) fast authentication failed, auto-try to use full authentication", flow.context.getConnectionId());
                                }
                                sink.next(FULL_AUTH);
                            }
                            // Otherwise fast auth succeeded; wait for the following OK/error.
                        } else if (message instanceof ChangeAuthMessage) {
                            flow.changeAuth((ChangeAuthMessage) message);
                            sink.next(CHANGE_AUTH);
                        } else {
                            sink.error(new IllegalStateException(String.format("Unexpected message type '%s' in handshake response phase", message.getClass().getSimpleName())));
                        }
                    }).last();
            }
        },

        CHANGE_AUTH {
            private final Predicate<ServerMessage> complete = message -> message instanceof ErrorMessage || message instanceof OkMessage || isFastAuthFailed(message);

            @Override
            Mono<State> handle(LoginFlow flow) {
                return flow.createAuthResponse("change authentication")
                    .flatMapMany(response -> flow.client.exchange((ExchangeableMessage) response, this.complete))
                    .<State>handle((message, sink) -> {
                        if (message instanceof ErrorMessage) {
                            sink.error(ExceptionFactory.createException((ErrorMessage) message, null));
                        } else if (message instanceof OkMessage) {
                            sink.next(COMPLETED);
                        } else if (message instanceof AuthMoreDataMessage) {
                            if (isFastAuthFailed(message)) {
                                if (logger.isDebugEnabled()) {
                                    logger.debug("Connection (id {}) fast authentication failed, auto-try to use full authentication", flow.context.getConnectionId());
                                }
                                sink.next(FULL_AUTH);
                            }
                            // Otherwise fast auth succeeded; wait for the following OK/error.
                        } else {
                            sink.error(new IllegalStateException(String.format("Unexpected message type '%s' in change authentication phase", message.getClass().getSimpleName())));
                        }
                    }).last();
            }
        },

        FULL_AUTH {
            private final Predicate<ServerMessage> complete = message -> message instanceof ErrorMessage || message instanceof OkMessage;

            @Override
            Mono<State> handle(LoginFlow flow) {
                return flow.createAuthResponse("full authentication")
                    .flatMapMany(response -> flow.client.exchange((ExchangeableMessage) response, this.complete))
                    .<State>handle((message, sink) -> {
                        if (message instanceof ErrorMessage) {
                            sink.error(ExceptionFactory.createException((ErrorMessage) message, null));
                        } else if (message instanceof OkMessage) {
                            sink.next(COMPLETED);
                        } else {
                            sink.error(new IllegalStateException(String.format("Unexpected message type '%s' in full authentication phase", message.getClass().getSimpleName())));
                        }
                    }).last();
            }
        },

        COMPLETED {
            @Override
            Mono<State> handle(LoginFlow flow) {
                return Mono.just(COMPLETED);
            }
        };

        /**
         * Performs this state's exchange.
         *
         * @return a {@code Mono} emitting the next state, or an error that aborts the login
         */
        abstract Mono<State> handle(LoginFlow flow);
    }
}

