/*
 * Decompiled with CFR 0.152.
 */
package io.r2dbc.mssql.message.token;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.r2dbc.mssql.message.ClientMessage;
import io.r2dbc.mssql.message.header.HeaderOptions;
import io.r2dbc.mssql.message.header.Status;
import io.r2dbc.mssql.message.header.Type;
import io.r2dbc.mssql.message.tds.ContextualTdsFragment;
import io.r2dbc.mssql.message.tds.Decode;
import io.r2dbc.mssql.message.tds.Encode;
import io.r2dbc.mssql.message.tds.ProtocolException;
import io.r2dbc.mssql.message.tds.TdsFragment;
import io.r2dbc.mssql.message.token.TokenStream;
import io.r2dbc.mssql.util.Assert;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import reactor.util.annotation.Nullable;

/**
 * TDS PRELOGIN message consisting of an ordered list of {@link Token option tokens}.
 * <p>
 * When encoded, the message is laid out as an option table (one header per token, written by
 * {@link Token#encodeToken(ByteBuf, int)}) followed by the option data of every token in the same order
 * (written by {@link Token#encodeStream(ByteBuf)}).
 */
public final class Prelogin implements TokenStream, ClientMessage {

    private static final HeaderOptions HEADER_OPTIONS = HeaderOptions.create(Type.PRE_LOGIN, Status.of(Status.StatusBit.EOM));

    private final List<? extends Token> tokens;

    /**
     * Create a new {@link Prelogin} message.
     *
     * @param tokens the option tokens, in wire order. The list is referenced, not copied.
     * @throws IllegalArgumentException (presumably, via {@link Assert}) when {@code tokens} is {@code null}
     */
    public Prelogin(List<? extends Token> tokens) {
        Assert.requireNonNull(tokens, "Tokens must not be null");
        this.tokens = tokens;
    }

    /**
     * @return a new {@link Builder} for a client-side PRELOGIN message.
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Decode a PRELOGIN message from {@code buffer}.
     * <p>
     * Option tokens are read until a {@link Terminator} is found or the buffer has no readable bytes left.
     * {@link Version}, {@link Encryption} and {@link InstanceValidation} options are decoded into their typed
     * tokens; every other option type becomes an {@link UnknownToken}. All remaining readable bytes of
     * {@code buffer} are skipped afterwards.
     *
     * @param buffer the buffer to decode from
     * @return the decoded {@link Prelogin}
     */
    public static Prelogin decode(ByteBuf buffer) {
        Assert.requireNonNull(buffer, "ByteBuf must not be null");

        List<Token> decodedTokens = new ArrayList<>();
        TokenDecodingState decodingState = TokenDecodingState.create(buffer);

        while (decodingState.canDecode()) {
            byte type = Decode.asByte(buffer);
            if (type == Terminator.TYPE) {
                decodedTokens.add(Terminator.INSTANCE);
                break;
            }
            decodedTokens.add(decodeToken(type, decodingState));
        }

        // Option data follows the option table; everything has been consumed through absolute offsets.
        buffer.skipBytes(buffer.readableBytes());
        return new Prelogin(decodedTokens);
    }

    /**
     * Decode a single non-terminator option token of the given {@code type}.
     */
    private static Token decodeToken(byte type, TokenDecodingState decodingState) {
        switch (type) {
            case Version.TYPE:
                return Version.decode(decodingState);
            case Encryption.TYPE:
                return Encryption.decode(decodingState);
            case InstanceValidation.TYPE:
                return InstanceValidation.decode(decodingState);
            default:
                return UnknownToken.decode(type, decodingState);
        }
    }

    /**
     * @return the option tokens of this message, in wire order.
     */
    public List<? extends Token> getTokens() {
        return this.tokens;
    }

    /**
     * Find the first token that is an instance of {@code tokenType}.
     *
     * @param tokenType the token type to look up
     * @param <T>       the token type
     * @return the first matching token, or {@link Optional#empty()} if none matches
     */
    public <T extends Token> Optional<T> getToken(Class<? extends T> tokenType) {
        Assert.requireNonNull(tokenType, "Token type must not be null");

        for (Token token : this.tokens) {
            if (tokenType.isInstance(token)) {
                return Optional.of(tokenType.cast(token));
            }
        }
        return Optional.empty();
    }

    /**
     * Find the first token that is an instance of {@code tokenType}.
     *
     * @param tokenType the token type to look up
     * @param <T>       the token type
     * @return the first matching token
     * @throws NoSuchElementException if no token of {@code tokenType} is present
     */
    public <T extends Token> T getRequiredToken(Class<? extends T> tokenType) {
        Assert.requireNonNull(tokenType, "Token type must not be null");

        return this.<T>getToken(tokenType)
            .orElseThrow(() -> new NoSuchElementException(String.format("No token of type [%s] available", tokenType.getName())));
    }

    @Override
    public String getName() {
        return "PRELOGIN";
    }

    @Override
    public TdsFragment encode(ByteBufAllocator allocator, int packetSize) {
        Assert.requireNonNull(allocator, "ByteBufAllocator must not be null");

        ByteBuf buffer = allocator.buffer(getSize(this.tokens));
        encode(buffer);
        return new ContextualTdsFragment(HEADER_OPTIONS, buffer);
    }

    /**
     * Write the option table followed by the option data of all tokens into {@code buffer}.
     */
    void encode(ByteBuf buffer) {
        // Option data starts right after the complete option table.
        int position = 0;
        for (Token token : this.tokens) {
            position += token.getTokenHeaderLength();
        }

        // Option table: each header points at the offset of its token's data.
        for (Token token : this.tokens) {
            token.encodeToken(buffer, position);
            position += token.getDataLength();
        }

        // Option data, in the same order as the table.
        for (Token token : this.tokens) {
            token.encodeStream(buffer);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof Prelogin)) {
            return false;
        }
        Prelogin prelogin = (Prelogin) o;
        return Objects.equals(this.tokens, prelogin.tokens);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.tokens);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getName());
        sb.append(" [tokens=").append(this.tokens);
        sb.append(']');
        return sb.toString();
    }

    /**
     * Compute the initial buffer capacity for encoding {@code tokens}.
     */
    private static int getSize(List<? extends Token> tokens) {
        // NOTE(review): the extra 8 bytes look like headroom for the TDS packet header — confirm.
        int size = 8;
        for (Token token : tokens) {
            size += token.getTotalLength();
        }
        return size;
    }

    /**
     * Tracks how far the option table has been read so that option data can be located via its absolute
     * offset (relative to the reader index at creation time).
     */
    static class TokenDecodingState {

        /**
         * Bytes of one option-table entry: type (1), position (2), length (2).
         */
        private static final int TOKEN_HEADER_LENGTH = 5;

        final ByteBuf buffer;

        final int initialReaderIndex;

        int readPositionOffset;

        TokenDecodingState(ByteBuf buffer) {
            this.initialReaderIndex = buffer.readerIndex();
            this.buffer = buffer;
        }

        public static TokenDecodingState create(ByteBuf byteBuf) {
            return new TokenDecodingState(byteBuf);
        }

        /**
         * @return {@code true} if the buffer has readable bytes left.
         */
        public boolean canDecode() {
            return this.buffer.readableBytes() > 0;
        }

        /**
         * Record how many option-table bytes have been consumed so far.
         */
        void afterTokenDecoded() {
            this.readPositionOffset = this.buffer.readerIndex() - this.initialReaderIndex;
        }

        /**
         * Copy {@code length} bytes of option data located at absolute offset {@code position} into a newly
         * allocated buffer. Must be called right after an option-table entry has been read. The caller owns
         * (and must release) the returned buffer.
         *
         * @throws ProtocolException (via {@code invalidTds}) if {@code position} points into the already-read
         *                           option table
         */
        ByteBuf readBody(int position, short length) {
            int bytesToSkip = position - TOKEN_HEADER_LENGTH - this.readPositionOffset;
            if (bytesToSkip < 0) {
                throw ProtocolException.invalidTds(String.format("Invalid token data position: %s", position));
            }

            this.buffer.skipBytes(bytesToSkip);
            ByteBuf data = this.buffer.alloc().buffer(length);
            this.buffer.readBytes(data, length);
            return data;
        }
    }

    /**
     * Validates the declared data length of an option token before its data is read.
     */
    @FunctionalInterface
    interface LengthValidator {

        LengthValidator IGNORE = length -> {
        };

        void validate(short length);

        /**
         * @return a validator that accepts every length.
         */
        static LengthValidator ignore() {
            return IGNORE;
        }
    }

    /**
     * Creates a token from its declared length and a buffer holding exactly its data.
     */
    @FunctionalInterface
    interface DecodeFunction<T> {

        T decode(short length, ByteBuf body);
    }

    /**
     * Option token of a type this class does not interpret. Its data is skipped on decode and nothing is
     * written on encode.
     */
    public static class UnknownToken extends Token {

        UnknownToken(int type, int length) {
            super(type, length);
        }

        public static UnknownToken decode(byte type, TokenDecodingState toDecode) {
            return Token.decode(toDecode, LengthValidator.ignore(), (length, body) -> new UnknownToken(type, length));
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /**
     * TRACEID option (type {@code 5}, 36 data bytes): connection id, activity id and activity sequence.
     * Absent ids are written as all-zero bytes.
     */
    public static class TraceId extends Token {

        public static final byte TYPE = 5;

        @Nullable
        private final UUID connectionId;

        @Nullable
        private final UUID activityId;

        private final int activitySequence;

        public TraceId(@Nullable UUID connectionId, @Nullable UUID activityId, int activitySequence) {
            super(TYPE, 36);
            this.connectionId = connectionId;
            this.activityId = activityId;
            this.activitySequence = activitySequence;
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
            writeUuid(buffer, this.connectionId);
            writeUuid(buffer, this.activityId);
            Encode.intBigEndian(buffer, this.activitySequence);
        }

        /**
         * Write {@code uuid} as most- then least-significant 64 bits, or 16 zero bytes if {@code null}.
         */
        private static void writeUuid(ByteBuf buffer, @Nullable UUID uuid) {
            if (uuid != null) {
                buffer.writeLong(uuid.getMostSignificantBits());
                buffer.writeLong(uuid.getLeastSignificantBits());
            } else {
                buffer.writeLong(0L);
                buffer.writeLong(0L);
            }
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(getClass().getSimpleName());
            sb.append(" [connectionId=").append(this.connectionId);
            sb.append(", activityId=").append(this.activityId);
            sb.append(", activitySequence=").append(this.activitySequence);
            sb.append(']');
            return sb.toString();
        }
    }

    /**
     * THREADID option (type {@code 3}, 4 data bytes).
     */
    public static class ThreadId extends Token {

        public static final byte TYPE = 3;

        private final int threadId;

        public ThreadId(int threadId) {
            super(TYPE, 4);
            this.threadId = threadId;
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
            buffer.writeInt(this.threadId);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(getClass().getSimpleName());
            sb.append(" [threadId=").append(this.threadId);
            sb.append(']');
            return sb.toString();
        }
    }

    /**
     * ENCRYPTION option (type {@code 1}, 1 data byte).
     */
    public static class Encryption extends Token {

        public static final byte TYPE = 1;

        public static final byte ENCRYPT_OFF = 0;

        public static final byte ENCRYPT_ON = 1;

        public static final byte ENCRYPT_NOT_SUP = 2;

        public static final byte ENCRYPT_REQ = 3;

        private final byte encryption;

        public Encryption(byte encryption) {
            super(TYPE, 1);
            this.encryption = encryption;
        }

        /**
         * Decode an encryption option.
         *
         * @throws ProtocolException (via {@code invalidTds}) if the declared length is not {@code 1}
         */
        public static Encryption decode(TokenDecodingState toDecode) {
            return Token.decode(toDecode, length -> {
                if (length != 1) {
                    throw ProtocolException.invalidTds(String.format("Invalid encryption length: %s", length));
                }
            }, (length, body) -> new Encryption(Decode.asByte(body)));
        }

        public byte getEncryption() {
            return this.encryption;
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
            Encode.asByte(buffer, this.encryption);
        }

        /**
         * @return {@code true} for {@link #ENCRYPT_OFF}, {@link #ENCRYPT_ON} and {@link #ENCRYPT_REQ}.
         */
        public boolean requiresSslHandshake() {
            byte value = getEncryption();
            return value == ENCRYPT_REQ || value == ENCRYPT_OFF || value == ENCRYPT_ON;
        }

        /**
         * @return {@code true} for {@link #ENCRYPT_OFF}.
         */
        public boolean requiresLoginSslHandshake() {
            return getEncryption() == ENCRYPT_OFF;
        }

        /**
         * @return {@code true} for {@link #ENCRYPT_ON} and {@link #ENCRYPT_REQ}.
         */
        public boolean requiresConnectionSslHandshake() {
            byte value = getEncryption();
            return value == ENCRYPT_ON || value == ENCRYPT_REQ;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(getClass().getSimpleName());
            sb.append(" [encryption=").append(this.encryption);
            sb.append(']');
            return sb.toString();
        }
    }

    /**
     * INSTOPT option (type {@code 2}): instance name as UTF-8 bytes followed by a terminating zero byte
     * when created from a {@link String}.
     */
    public static class InstanceValidation extends Token {

        static final String MSSQLSERVER_VALUE = "MSSQLServer";

        public static final byte TYPE = 2;

        private final byte[] instanceName;

        public InstanceValidation(String instanceName) {
            this(toBytes(instanceName));
        }

        private InstanceValidation(byte[] instanceName) {
            super(TYPE, Assert.requireNonNull(instanceName, "Instance name must not be null").length);
            this.instanceName = instanceName;
        }

        /**
         * Decode an instance-validation option; its raw data bytes are kept as-is.
         */
        public static InstanceValidation decode(TokenDecodingState toDecode) {
            return Token.decode(toDecode, LengthValidator.ignore(), (length, body) -> {
                byte[] validation = new byte[length];
                body.readBytes(validation);
                return new InstanceValidation(validation);
            });
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
            buffer.writeBytes(this.instanceName);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(getClass().getSimpleName());
            // Decode with the same charset used by toBytes(String) rather than the platform default.
            sb.append(" [instanceName=").append(new String(this.instanceName, StandardCharsets.UTF_8));
            sb.append(']');
            return sb.toString();
        }

        /**
         * Encode {@code instanceName} as UTF-8 with one trailing zero byte.
         */
        private static byte[] toBytes(String instanceName) {
            Assert.requireNonNull(instanceName, "Instance name must not be null");

            byte[] name = instanceName.getBytes(StandardCharsets.UTF_8);
            byte[] result = new byte[name.length + 1];
            System.arraycopy(name, 0, result, 0, name.length);
            return result;
        }
    }

    /**
     * VERSION option (type {@code 0}, 6 data bytes). Encoded as a dword {@code version} followed by a
     * big-endian short {@code subbuild}.
     */
    public static class Version extends Token {

        public static final byte TYPE = 0;

        private final int version;

        private final short subbuild;

        /**
         * Convenience constructor; {@code subbuild} is narrowed to {@code short}.
         */
        public Version(int version, int subbuild) {
            // Narrow to short (the field/wire width); narrowing via byte would truncate values above 127.
            this(version, (short) subbuild);
        }

        public Version(int version, short subbuild) {
            super(TYPE, 6);
            this.version = version;
            this.subbuild = subbuild;
        }

        /**
         * Decode a version option.
         * <p>
         * NOTE(review): decoding maps the first byte (major) to {@code version} and the following short
         * (build) to {@code subbuild}; the minor byte and the last two data bytes are ignored. This is not the
         * inverse of {@link #encodeStream(ByteBuf)} — confirm this asymmetry is intended.
         *
         * @throws ProtocolException (via {@code invalidTds}) if the declared length is not {@code 6}
         */
        public static Version decode(TokenDecodingState toDecode) {
            return Token.decode(toDecode, length -> {
                if (length != 6) {
                    throw ProtocolException.invalidTds(String.format("Invalid version length: %s", length));
                }
            }, (length, body) -> {
                byte major = Decode.asByte(body);
                body.skipBytes(1); // minor version: not retained
                short build = body.readShort();
                return new Version((int) major, build);
            });
        }

        public int getVersion() {
            return this.version;
        }

        public short getSubbuild() {
            return this.subbuild;
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
            Encode.dword(buffer, this.version);
            Encode.shortBE(buffer, this.subbuild);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(getClass().getSimpleName());
            sb.append(" [version=").append(this.version);
            sb.append(", subbuild=").append(this.subbuild);
            sb.append(']');
            return sb.toString();
        }
    }

    /**
     * TERMINATOR option (type {@code -1}, i.e. {@code 0xFF}): ends the option table. Its table entry is a
     * single byte and it carries no data.
     */
    public static class Terminator extends Token {

        public static final Terminator INSTANCE = new Terminator();

        public static final byte TYPE = -1;

        Terminator() {
            super(TYPE, 0);
        }

        @Override
        public void encodeToken(ByteBuf buffer, int position) {
            buffer.writeByte(getType());
        }

        @Override
        public int getTokenHeaderLength() {
            return 1;
        }

        @Override
        public void encodeStream(ByteBuf buffer) {
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + " []";
        }
    }

    /**
     * Base class for PRELOGIN option tokens: a one-byte type and a data length.
     */
    public abstract static class Token {

        private final byte type;

        private final int length;

        /**
         * @throws IllegalArgumentException if {@code type} does not fit into a signed byte
         */
        Token(int type, int length) {
            if (type > Byte.MAX_VALUE || type < Byte.MIN_VALUE) {
                throw new IllegalArgumentException("Type " + type + " exceeds byte value");
            }
            this.type = (byte) type;
            this.length = length;
        }

        /**
         * Decode one option token whose type byte has already been consumed.
         * <p>
         * Reads the option-table entry (data position and length), validates the length, copies the option
         * data into a temporary buffer and hands it to {@code decoder}. The reader index is restored to the
         * end of the option-table entry afterwards, and the temporary buffer is released, even if decoding
         * fails.
         */
        static <T extends Token> T decode(TokenDecodingState toDecode, LengthValidator validator, DecodeFunction<T> decoder) {
            Assert.requireNonNull(toDecode, "TokenDecodingState must not be null");
            Assert.requireNonNull(validator, "LengthValidator must not be null");
            Assert.requireNonNull(decoder, "DecodeFunction must not be null");

            ByteBuf buffer = toDecode.buffer;
            // Positions are written as unsigned shorts (see encodeToken), so read them back unsigned.
            int position = buffer.readUnsignedShort();
            short length = buffer.readShort();
            validator.validate(length);

            T result;
            buffer.markReaderIndex();
            try {
                ByteBuf data = toDecode.readBody(position, length);
                try {
                    result = decoder.decode(length, data);
                } finally {
                    data.release();
                }
            } finally {
                // Continue with the next option-table entry regardless of where the data lived.
                buffer.resetReaderIndex();
            }

            toDecode.afterTokenDecoded();
            return result;
        }

        /**
         * Write this token's option-table entry: type, data {@code position} and data length.
         */
        public void encodeToken(ByteBuf buffer, int position) {
            Encode.asByte(buffer, this.type);
            Encode.uShortBE(buffer, position);
            Encode.uShortBE(buffer, this.length);
        }

        byte getType() {
            return this.type;
        }

        int getLength() {
            return this.length;
        }

        /**
         * Write this token's option data.
         */
        public abstract void encodeStream(ByteBuf buffer);

        final int getTotalLength() {
            return getDataLength() + getTokenHeaderLength();
        }

        int getTokenHeaderLength() {
            return 5;
        }

        int getDataLength() {
            return this.length;
        }
    }

    /**
     * Builder for a client PRELOGIN message. The built message always contains {@link Version} (0, 0),
     * {@link Encryption}, {@link InstanceValidation} and a {@link Terminator}; {@link ThreadId} and
     * {@link TraceId} are added only when a thread id respectively a connection id was set.
     */
    public static class Builder {

        @Nullable
        private Integer threadId;

        @Nullable
        private UUID connectionId;

        @Nullable
        private UUID activityId;

        private int activitySequence;

        private byte encryption = Encryption.ENCRYPT_OFF;

        private String instanceName = InstanceValidation.MSSQLSERVER_VALUE;

        private Builder() {
        }

        public Builder withConnectionId(UUID connectionId) {
            Assert.requireNonNull(connectionId, "ConnectionID must not be null");
            this.connectionId = connectionId;
            return this;
        }

        public Builder withActivityId(UUID activityId) {
            Assert.requireNonNull(activityId, "Activity ID must not be null");
            this.activityId = activityId;
            return this;
        }

        public Builder withActivitySequence(int activitySequence) {
            this.activitySequence = activitySequence;
            return this;
        }

        public Builder withThreadId(int threadId) {
            this.threadId = threadId;
            return this;
        }

        public Builder withEncryptionDisabled() {
            this.encryption = Encryption.ENCRYPT_OFF;
            return this;
        }

        public Builder withEncryptionEnabled() {
            this.encryption = Encryption.ENCRYPT_ON;
            return this;
        }

        public Builder withEncryptionNotSupported() {
            this.encryption = Encryption.ENCRYPT_NOT_SUP;
            return this;
        }

        public Builder withInstanceName(String instanceName) {
            Assert.requireNonNull(instanceName, "Instance name must not be null");
            this.instanceName = instanceName;
            return this;
        }

        /**
         * @return a new {@link Prelogin} built from the current builder state.
         */
        public Prelogin build() {
            List<Token> tokens = new ArrayList<>();
            tokens.add(new Version(0, 0));
            tokens.add(new Encryption(this.encryption));
            tokens.add(new InstanceValidation(this.instanceName));

            if (this.threadId != null) {
                tokens.add(new ThreadId(this.threadId));
            }
            if (this.connectionId != null) {
                tokens.add(new TraceId(this.connectionId, this.activityId, this.activitySequence));
            }

            tokens.add(Terminator.INSTANCE);
            return new Prelogin(tokens);
        }
    }
}

