/*
 * Decompiled with CFR 0.152.
 */
package io.asyncer.r2dbc.mysql.client;

import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.asyncer.r2dbc.mysql.client.CompressionDuplexCodec;
import io.asyncer.r2dbc.mysql.client.PacketEvent;
import io.asyncer.r2dbc.mysql.client.SslState;
import io.asyncer.r2dbc.mysql.client.WriteSubscriber;
import io.asyncer.r2dbc.mysql.client.ZlibCompressor;
import io.asyncer.r2dbc.mysql.client.ZstdCompressor;
import io.asyncer.r2dbc.mysql.internal.util.AssertUtils;
import io.asyncer.r2dbc.mysql.internal.util.OperatorUtils;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedFetchMessage;
import io.asyncer.r2dbc.mysql.message.client.SslRequest;
import io.asyncer.r2dbc.mysql.message.server.ColumnCountMessage;
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
import io.asyncer.r2dbc.mysql.message.server.DecodeContext;
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
import io.asyncer.r2dbc.mysql.message.server.PreparedOkMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessageDecoder;
import io.asyncer.r2dbc.mysql.message.server.ServerStatusMessage;
import io.asyncer.r2dbc.mysql.message.server.SyntheticMetadataMessage;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.jetbrains.annotations.Nullable;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;

/**
 * Netty duplex handler for the MySQL wire protocol.
 * <p>
 * Inbound: splits the byte stream into frames (3-byte little-endian payload length followed by a
 * 1-byte sequence id) and decodes each frame into a {@link ServerMessage} using the current
 * {@link DecodeContext}. Outbound: encodes {@link ClientMessage}s and envelopes them using the
 * sequence id counter that inbound decoding also updates.
 * <p>
 * It also switches the {@link DecodeContext} according to the messages seen in both directions,
 * and installs a compression codec when a {@link PacketEvent#USE_COMPRESSION} event arrives.
 */
final class MessageDuplexCodec
extends ByteToMessageDecoder
implements ChannelOutboundHandler {
    static final String NAME = "R2dbcMySqlMessageDuplexCodec";
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(MessageDuplexCodec.class);
    // Next sequence id: set from received frames, consumed by the outbound envelope step.
    private final AtomicInteger sequenceId = new AtomicInteger(0);
    // Tells the message decoder how to interpret the next frame; starts in the login phase.
    private DecodeContext decodeContext = DecodeContext.login();
    private final ConnectionContext context;
    private final ServerMessageDecoder decoder = new ServerMessageDecoder();
    // Total length (header included) of the frame being accumulated, or -1 when none is pending.
    private int frameLength = -1;

    MessageDuplexCodec(ConnectionContext context) {
        this.context = AssertUtils.requireNonNull(context, "context must not be null");
    }

    /**
     * Decodes at most one complete frame from {@code in} and, if the frame yields a message,
     * passes it on through {@code out}.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
        ByteBuf frame = this.decode(in);
        if (frame == null) {
            return;
        }

        ServerMessage message = this.decoder.decode(frame, this.context, this.decodeContext);
        if (message != null) {
            this.handleDecoded(out, message);
        }
    }

    /**
     * Encodes a {@link ClientMessage} and writes the resulting packets. The given promise is handed
     * to the {@link WriteSubscriber}.
     * <p>
     * Any other message type is logged, released, and its promise is failed, so that nobody
     * waiting on the write hangs forever.
     */
    @Override
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
        if (!(msg instanceof ClientMessage)) {
            Class<?> type = msg.getClass();
            if (logger.isWarnEnabled()) {
                logger.warn("Unknown message type {} on writing", type);
            }
            ReferenceCountUtil.release(msg);
            // The message is dropped here, so this handler must complete the promise itself.
            promise.tryFailure(new IllegalArgumentException("Unsupported message type on writing: " + type.getName()));
            return;
        }

        ByteBufAllocator allocator = ctx.alloc();
        ClientMessage message = (ClientMessage)msg;
        Flux encoded = Flux.from(message.encode(allocator, this.context));
        OperatorUtils.envelope((Flux<? extends ByteBuf>)encoded, allocator, this.sequenceId, message.isCumulative()).subscribe((CoreSubscriber)new WriteSubscriber(ctx, promise));

        // Some requests change how the server's reply must be decoded.
        if (msg instanceof PrepareQueryMessage) {
            this.setDecodeContext(DecodeContext.prepareQuery());
        } else if (msg instanceof PreparedFetchMessage) {
            this.setDecodeContext(DecodeContext.fetch());
        } else if (msg instanceof SslRequest) {
            ctx.channel().pipeline().fireUserEventTriggered(SslState.BRIDGING);
        }
    }

    /**
     * Handles {@link PacketEvent}s (reset the sequence id, optionally enable compression).
     * Every event, including ones handled here, is then passed on down the pipeline.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
        if (evt instanceof PacketEvent) {
            switch ((PacketEvent)evt) {
                case RESET_SEQUENCE: {
                    logger.trace("Reset sequence id");
                    this.sequenceId.set(0);
                    break;
                }
                case USE_COMPRESSION: {
                    logger.trace("Reset sequence id");
                    this.sequenceId.set(0);
                    // Zstd takes precedence when both capabilities are present.
                    if (this.context.getCapability().isZstdCompression()) {
                        MessageDuplexCodec.enableZstdCompression(ctx);
                    } else if (this.context.getCapability().isZlibCompression()) {
                        MessageDuplexCodec.enableZlibCompression(ctx);
                    } else {
                        logger.warn("Unexpected event compression triggered, no capability found");
                    }
                    break;
                }
            }
        }
        ctx.fireUserEventTriggered(evt);
    }

    @Override
    public void flush(ChannelHandlerContext ctx) {
        ctx.flush();
    }

    /**
     * Disposes the message decoder, then forwards the inactive event.
     * <p>
     * NOTE(review): this does not call {@code super.channelInactive}, so the decode-last pass of
     * {@link ByteToMessageDecoder} is skipped. Its cumulation buffer is presumably released later
     * in {@code handlerRemoved}.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) {
        this.decoder.dispose();
        ctx.fireChannelInactive();
    }

    /**
     * Cuts one complete frame out of {@code in}.
     *
     * @param in the accumulated inbound bytes
     * @return the frame payload (header stripped, retained slice), or {@code null} if the frame is
     *         not complete yet
     */
    @Nullable
    private ByteBuf decode(ByteBuf in) {
        if (this.frameLength == -1) {
            // Need the 3-byte length field before the frame size is known.
            if (in.readableBytes() < 3) {
                return null;
            }
            // Payload length plus the 4-byte header (3-byte length + 1-byte sequence id).
            this.frameLength = in.getUnsignedMediumLE(in.readerIndex()) + 4;
        }
        if (in.readableBytes() < this.frameLength) {
            return null;
        }
        in.skipBytes(3);
        short sequenceId = in.readUnsignedByte();
        ByteBuf frame = in.readRetainedSlice(this.frameLength - 4);
        logger.trace("Decoded frame with sequence id: {}, total size: {}", (Object)sequenceId, (Object)this.frameLength);
        // The next outbound packet continues from the received sequence id.
        this.sequenceId.set(sequenceId + 1);
        this.frameLength = -1;
        return frame;
    }

    /**
     * Updates server status and the decode context from a decoded message, then adds the message
     * to {@code out}. {@link ColumnCountMessage}s only switch the context and are not passed on.
     */
    private void handleDecoded(List<Object> out, ServerMessage msg) {
        if (msg instanceof ServerStatusMessage) {
            this.context.setServerStatuses(((ServerStatusMessage)msg).getServerStatuses());
        }

        if (msg instanceof CompleteMessage) {
            this.setDecodeContext(DecodeContext.command());
        } else if (msg instanceof SyntheticMetadataMessage) {
            if (((SyntheticMetadataMessage)msg).isCompleted()) {
                this.setDecodeContext(DecodeContext.command());
            }
        } else if (msg instanceof ColumnCountMessage) {
            this.setDecodeContext(DecodeContext.result(this.context.getCapability().isEofDeprecated(), ((ColumnCountMessage)msg).getTotalColumns()));
            return;
        } else if (msg instanceof PreparedOkMessage) {
            PreparedOkMessage message = (PreparedOkMessage)msg;
            int columns = message.getTotalColumns();
            int parameters = message.getTotalParameters();
            // Same as "columns + parameters > 0" without risking overflow in the addition, i.e.
            // some metadata follows (assuming both counts are non-negative).
            if (columns > -parameters) {
                this.setDecodeContext(DecodeContext.preparedMetadata(this.context.getCapability().isEofDeprecated(), columns, parameters));
            } else {
                this.setDecodeContext(DecodeContext.command());
            }
        } else if (msg instanceof ErrorMessage) {
            this.setDecodeContext(DecodeContext.command());
        }
        out.add(msg);
    }

    private void setDecodeContext(DecodeContext context) {
        this.decodeContext = context;
        if (logger.isDebugEnabled()) {
            logger.debug("Decode context change to {}", (Object)context);
        }
    }

    @Override
    public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
        ctx.bind(localAddress, promise);
    }

    @Override
    public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
        ctx.connect(remoteAddress, localAddress, promise);
    }

    @Override
    public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
        ctx.disconnect(promise);
    }

    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
        ctx.close(promise);
    }

    @Override
    public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
        ctx.deregister(promise);
    }

    @Override
    public void read(ChannelHandlerContext ctx) {
        ctx.read();
    }

    /**
     * Installs a zstd {@link CompressionDuplexCodec} before this handler, unless a compression
     * codec is already present. The compressor is only created when it is actually installed.
     */
    private static void enableZstdCompression(ChannelHandlerContext ctx) {
        if (ctx.pipeline().get("R2dbcMysqlCompressionDuplexCodec") != null) {
            logger.warn("Unexpected event, compression already enabled");
            return;
        }
        logger.debug("Compression zstd enabled for subsequent packets");
        CompressionDuplexCodec handler = new CompressionDuplexCodec(new ZstdCompressor(3));
        ctx.pipeline().addBefore(NAME, "R2dbcMysqlCompressionDuplexCodec", (ChannelHandler)handler);
    }

    /**
     * Installs a zlib {@link CompressionDuplexCodec} before this handler, unless a compression
     * codec is already present. The compressor is only created when it is actually installed.
     */
    private static void enableZlibCompression(ChannelHandlerContext ctx) {
        if (ctx.pipeline().get("R2dbcMysqlCompressionDuplexCodec") != null) {
            logger.warn("Unexpected event, compression already enabled");
            return;
        }
        logger.debug("Compression zlib enabled for subsequent packets");
        CompressionDuplexCodec handler = new CompressionDuplexCodec(new ZlibCompressor());
        ctx.pipeline().addBefore(NAME, "R2dbcMysqlCompressionDuplexCodec", (ChannelHandler)handler);
    }
}

