/*
 * Decompiled with CFR 0.152.
 */
package dev.miku.r2dbc.mysql;

import dev.miku.r2dbc.mysql.ConnectionContext;
import dev.miku.r2dbc.mysql.ExceptionFactory;
import dev.miku.r2dbc.mysql.ServerVersion;
import dev.miku.r2dbc.mysql.authentication.MySqlAuthProvider;
import dev.miku.r2dbc.mysql.client.Client;
import dev.miku.r2dbc.mysql.constant.SslMode;
import dev.miku.r2dbc.mysql.message.client.AuthResponse;
import dev.miku.r2dbc.mysql.message.client.ClientMessage;
import dev.miku.r2dbc.mysql.message.client.HandshakeResponse;
import dev.miku.r2dbc.mysql.message.client.SslRequest;
import dev.miku.r2dbc.mysql.message.server.AuthMoreDataMessage;
import dev.miku.r2dbc.mysql.message.server.ChangeAuthMessage;
import dev.miku.r2dbc.mysql.message.server.ErrorMessage;
import dev.miku.r2dbc.mysql.message.server.HandshakeHeader;
import dev.miku.r2dbc.mysql.message.server.HandshakeRequest;
import dev.miku.r2dbc.mysql.message.server.OkMessage;
import dev.miku.r2dbc.mysql.message.server.ServerMessage;
import dev.miku.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import java.util.Collections;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.SynchronousSink;
import reactor.util.annotation.Nullable;

final class InitHandler
implements BiConsumer<ServerMessage, SynchronousSink<Void>>,
Predicate<ServerMessage> {
    private static final Logger logger = LoggerFactory.getLogger(InitHandler.class);
    private static final Map<String, String> ATTRIBUTES = Collections.emptyMap();
    private static final String CLI_SPECIFIC = "HY000";
    private static final int HANDSHAKE_VERSION = 10;
    private final EmitterProcessor<ClientMessage> requests;
    private final Client client;
    private final SslMode sslMode;
    private final String database;
    private final String user;
    @Nullable
    private final CharSequence password;
    private final ConnectionContext context;
    private boolean handshake = true;
    private MySqlAuthProvider authProvider;
    private byte[] salt;
    private boolean sslCompleted;

    InitHandler(EmitterProcessor<ClientMessage> requests, Client client, SslMode sslMode, String database, String user, @Nullable CharSequence password, ConnectionContext context) {
        this.requests = requests;
        this.client = client;
        this.sslMode = sslMode;
        this.database = database;
        this.user = user;
        this.password = password;
        this.context = context;
        this.sslCompleted = sslMode == SslMode.TUNNEL;
    }

    @Override
    public void accept(ServerMessage message, SynchronousSink<Void> sink) {
        if (message instanceof ErrorMessage) {
            sink.error((Throwable)ExceptionFactory.createException((ErrorMessage)message, null));
            return;
        }
        if (this.handshake) {
            this.handshake = false;
            if (message instanceof HandshakeRequest) {
                int capabilities = this.initHandshake((HandshakeRequest)message);
                if ((capabilities & 0x800) == 0) {
                    this.requests.onNext((Object)this.createHandshakeResponse(capabilities));
                } else {
                    this.requests.onNext((Object)SslRequest.from(capabilities, this.context.getClientCollation().getId()));
                }
            } else {
                sink.error((Throwable)new R2dbcPermissionDeniedException("Unexpected message type '" + message.getClass().getSimpleName() + "' in init phase"));
            }
            return;
        }
        if (message instanceof OkMessage) {
            this.requests.onComplete();
            this.client.loginSuccess();
        } else if (message instanceof SyntheticSslResponseMessage) {
            this.sslCompleted = true;
            this.requests.onNext((Object)this.createHandshakeResponse(this.context.getCapabilities()));
        } else if (message instanceof AuthMoreDataMessage) {
            if (((AuthMoreDataMessage)message).isFailed()) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Connection (id {}) fast authentication failed, auto-try to use full authentication", (Object)this.context.getConnectionId());
                }
                this.requests.onNext((Object)this.createAuthResponse("full authentication"));
            }
        } else if (message instanceof ChangeAuthMessage) {
            ChangeAuthMessage msg = (ChangeAuthMessage)message;
            this.authProvider = MySqlAuthProvider.build(msg.getAuthType());
            this.salt = msg.getSalt();
            this.requests.onNext((Object)this.createAuthResponse("change authentication"));
        } else {
            sink.error((Throwable)new R2dbcPermissionDeniedException("Unexpected message type '" + message.getClass().getSimpleName() + "' in login phase"));
        }
    }

    @Override
    public boolean test(ServerMessage message) {
        return message instanceof ErrorMessage || message instanceof OkMessage;
    }

    private AuthResponse createAuthResponse(String phase) {
        MySqlAuthProvider authProvider = this.getAndNextProvider();
        if (authProvider.isSslNecessary() && !this.sslCompleted) {
            throw new R2dbcPermissionDeniedException(InitHandler.formatAuthFails(authProvider.getType(), phase), CLI_SPECIFIC);
        }
        return new AuthResponse(authProvider.authentication(this.password, this.salt, this.context.getClientCollation()));
    }

    private int clientCapabilities(int serverCapabilities) {
        int capabilities = serverCapabilities & 0x413FFA0F;
        if (this.sslMode == SslMode.TUNNEL) {
            capabilities &= 0xFFFFF7FF;
        } else if ((capabilities & 0x800) == 0) {
            if (this.sslMode.requireSsl()) {
                throw new R2dbcPermissionDeniedException("Server version '" + this.context.getServerVersion() + "' does not support SSL but mode '" + (Object)((Object)this.sslMode) + "' requires SSL", CLI_SPECIFIC);
            }
            if (this.sslMode.startSsl()) {
                this.client.sslUnsupported();
            }
        } else {
            if (!this.sslMode.startSsl()) {
                capabilities &= 0xFFFFF7FF;
            }
            if (!this.sslMode.verifyCertificate()) {
                capabilities &= 0xBFFFFFFF;
            }
        }
        if (this.database.isEmpty() && (capabilities & 8) != 0) {
            capabilities &= 0xFFFFFFF7;
        }
        if (ATTRIBUTES.isEmpty() && (capabilities & 0x100000) != 0) {
            capabilities &= 0xFFEFFFFF;
        }
        return capabilities;
    }

    private int initHandshake(HandshakeRequest message) {
        HandshakeHeader header = message.getHeader();
        short handshakeVersion = header.getProtocolVersion();
        ServerVersion serverVersion = header.getServerVersion();
        if (handshakeVersion < 10) {
            logger.warn("The MySQL server use old handshake V{}, server version is {}, maybe most features are not available", (Object)handshakeVersion, (Object)serverVersion);
        }
        int capabilities = this.clientCapabilities(message.getServerCapabilities());
        this.context.init(header.getConnectionId(), serverVersion, capabilities);
        this.authProvider = MySqlAuthProvider.build(message.getAuthType());
        this.salt = message.getSalt();
        return capabilities;
    }

    private MySqlAuthProvider getAndNextProvider() {
        MySqlAuthProvider authProvider = this.authProvider;
        this.authProvider = authProvider.next();
        return authProvider;
    }

    private HandshakeResponse createHandshakeResponse(int capabilities) {
        MySqlAuthProvider authProvider = this.getAndNextProvider();
        if (authProvider.isSslNecessary() && !this.sslCompleted) {
            throw new R2dbcPermissionDeniedException(InitHandler.formatAuthFails(authProvider.getType(), "handshake"), CLI_SPECIFIC);
        }
        byte[] authorization = authProvider.authentication(this.password, this.salt, this.context.getClientCollation());
        String authType = authProvider.getType();
        if ("".equals(authType)) {
            authType = "caching_sha2_password";
        }
        return HandshakeResponse.from(capabilities, this.context.getClientCollation().getId(), this.user, authorization, authType, this.database, ATTRIBUTES);
    }

    private static String formatAuthFails(String authType, String phase) {
        return "Authentication type '" + authType + "' must require SSL in " + phase + " phase";
    }
}

