/*
 * Decompiled with CFR 0.152.
 */
package com.rethinkdb.net;

import com.rethinkdb.gen.exc.ReqlAuthError;
import com.rethinkdb.gen.exc.ReqlDriverError;
import com.rethinkdb.gen.proto.Protocol;
import com.rethinkdb.gen.proto.Version;
import com.rethinkdb.net.ConnectionSocket;
import com.rethinkdb.utils.Internals;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.util.Arrays;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
import javax.crypto.Mac;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import org.jetbrains.annotations.Nullable;

class HandshakeProtocol {
    private static final HandshakeProtocol FINISHED = new HandshakeProtocol();
    public static final Version VERSION = Version.V1_0;
    public static final Long SUB_PROTOCOL_VERSION = 0L;
    public static final Protocol PROTOCOL = Protocol.JSON;
    public static final String CLIENT_KEY = "Client Key";
    public static final String SERVER_KEY = "Server Key";

    private HandshakeProtocol() {
    }

    static void doHandshake(ConnectionSocket socket, String username, String password, Long timeout) {
        for (HandshakeProtocol handshake = new WaitingForProtocolRange(username, password); handshake != FINISHED; handshake = ((HandshakeProtocol)handshake).nextState(socket.readCString(timeout))) {
            ByteBuffer toWrite = ((HandshakeProtocol)handshake).toSend();
            if (toWrite == null) continue;
            socket.write(toWrite);
        }
    }

    @Nullable
    protected ByteBuffer toSend() {
        throw new IllegalStateException();
    }

    protected HandshakeProtocol nextState(String response) {
        throw new IllegalStateException();
    }

    private static void throwIfFailure(Map<String, Object> json) {
        if (!((Boolean)json.get("success")).booleanValue()) {
            Long errorCode = (Long)json.get("error_code");
            if (errorCode >= 10L && errorCode <= 20L) {
                throw new ReqlAuthError((String)json.get("error"));
            }
            throw new ReqlDriverError((String)json.get("error"));
        }
    }

    static class ScramAttributes {
        @Nullable
        String _authIdentity;
        @Nullable
        String _username;
        @Nullable
        String _nonce;
        @Nullable
        String _headerAndChannelBinding;
        @Nullable
        byte[] _salt;
        @Nullable
        Integer _iterationCount;
        @Nullable
        String _clientProof;
        @Nullable
        byte[] _serverSignature;
        @Nullable
        String _error;
        @Nullable
        String _originalString;

        ScramAttributes() {
        }

        static ScramAttributes from(ScramAttributes other) {
            ScramAttributes out = new ScramAttributes();
            out._authIdentity = other._authIdentity;
            out._username = other._username;
            out._nonce = other._nonce;
            out._headerAndChannelBinding = other._headerAndChannelBinding;
            out._salt = other._salt;
            out._iterationCount = other._iterationCount;
            out._clientProof = other._clientProof;
            out._serverSignature = other._serverSignature;
            out._error = other._error;
            return out;
        }

        static ScramAttributes from(String input) {
            ScramAttributes sa = new ScramAttributes();
            sa._originalString = input;
            for (String section : input.split(",")) {
                String[] keyVal = section.split("=", 2);
                sa.setAttribute(keyVal[0], keyVal[1]);
            }
            return sa;
        }

        private void setAttribute(String key, String val) {
            switch (key) {
                case "a": {
                    this._authIdentity = val;
                    break;
                }
                case "n": {
                    this._username = val;
                    break;
                }
                case "r": {
                    this._nonce = val;
                    break;
                }
                case "m": {
                    throw new ReqlAuthError("m field disallowed");
                }
                case "c": {
                    this._headerAndChannelBinding = val;
                    break;
                }
                case "s": {
                    this._salt = Base64.getDecoder().decode(val);
                    break;
                }
                case "i": {
                    this._iterationCount = Integer.parseInt(val);
                    break;
                }
                case "p": {
                    this._clientProof = val;
                    break;
                }
                case "v": {
                    this._serverSignature = Base64.getDecoder().decode(val);
                    break;
                }
                case "e": {
                    this._error = val;
                    break;
                }
            }
        }

        public String toString() {
            if (this._originalString != null) {
                return this._originalString;
            }
            StringJoiner j = new StringJoiner(",");
            if (this._username != null) {
                j.add("n=" + this._username);
            }
            if (this._nonce != null) {
                j.add("r=" + this._nonce);
            }
            if (this._headerAndChannelBinding != null) {
                j.add("c=" + this._headerAndChannelBinding);
            }
            if (this._clientProof != null) {
                j.add("p=" + this._clientProof);
            }
            return j.toString();
        }

        ScramAttributes username(String username) {
            ScramAttributes next = ScramAttributes.from(this);
            next._username = username.replace("=", "=3D").replace(",", "=2C");
            return next;
        }

        ScramAttributes nonce(String nonce) {
            ScramAttributes next = ScramAttributes.from(this);
            next._nonce = nonce;
            return next;
        }

        ScramAttributes headerAndChannelBinding(String hacb) {
            ScramAttributes next = ScramAttributes.from(this);
            next._headerAndChannelBinding = hacb;
            return next;
        }

        ScramAttributes clientProof(byte[] clientProof) {
            ScramAttributes next = ScramAttributes.from(this);
            next._clientProof = Base64.getEncoder().encodeToString(clientProof);
            return next;
        }
    }

    static class WaitingForAuthSuccess
    extends HandshakeProtocol {
        private final byte[] serverSignature;
        private final ScramAttributes auth;

        public WaitingForAuthSuccess(byte[] serverSignature, ScramAttributes auth) {
            this.serverSignature = serverSignature;
            this.auth = auth;
        }

        @Override
        public ByteBuffer toSend() {
            byte[] authJson = ("{\"authentication\":\"" + this.auth + "\"}").getBytes(StandardCharsets.UTF_8);
            return ByteBuffer.allocate(authJson.length + 1).order(ByteOrder.LITTLE_ENDIAN).put(authJson).put(new byte[1]);
        }

        @Override
        public HandshakeProtocol nextState(String response) {
            Map<String, Object> json = Internals.readJson(response);
            HandshakeProtocol.throwIfFailure(json);
            ScramAttributes auth = ScramAttributes.from((String)json.get("authentication"));
            if (!MessageDigest.isEqual(auth._serverSignature, this.serverSignature)) {
                throw new ReqlAuthError("Invalid server signature");
            }
            return FINISHED;
        }
    }

    static class WaitingForAuthResponse
    extends HandshakeProtocol {
        private final String nonce;
        private final byte[] password;
        private final ScramAttributes clientFirstMessageBare;

        WaitingForAuthResponse(String nonce, byte[] password, ScramAttributes clientFirstMessageBare) {
            this.nonce = nonce;
            this.password = password;
            this.clientFirstMessageBare = clientFirstMessageBare;
        }

        @Override
        public ByteBuffer toSend() {
            return null;
        }

        @Override
        public HandshakeProtocol nextState(String response) {
            Map<String, Object> json = Internals.readJson(response);
            HandshakeProtocol.throwIfFailure(json);
            ScramAttributes serverScram = ScramAttributes.from((String)json.get("authentication"));
            if (!Objects.requireNonNull(serverScram._nonce).startsWith(this.nonce)) {
                throw new ReqlAuthError("Invalid nonce from server");
            }
            ScramAttributes clientScram = new ScramAttributes().headerAndChannelBinding("biws").nonce(serverScram._nonce);
            byte[] saltedPassword = PBKDF2.compute(this.password, serverScram._salt, serverScram._iterationCount);
            byte[] clientKey = WaitingForAuthResponse.hmac(saltedPassword, HandshakeProtocol.CLIENT_KEY);
            byte[] storedKey = WaitingForAuthResponse.sha256(clientKey);
            String authMessage = this.clientFirstMessageBare + "," + serverScram + "," + clientScram;
            byte[] clientSignature = WaitingForAuthResponse.hmac(storedKey, authMessage);
            byte[] clientProof = WaitingForAuthResponse.xor(clientKey, clientSignature);
            byte[] serverKey = WaitingForAuthResponse.hmac(saltedPassword, HandshakeProtocol.SERVER_KEY);
            byte[] serverSignature = WaitingForAuthResponse.hmac(serverKey, authMessage);
            return new WaitingForAuthSuccess(serverSignature, clientScram.clientProof(clientProof));
        }

        static byte[] sha256(byte[] clientKey) {
            try {
                return MessageDigest.getInstance("SHA-256").digest(clientKey);
            }
            catch (NoSuchAlgorithmException e) {
                throw new ReqlDriverError(e);
            }
        }

        static byte[] hmac(byte[] key, String string) {
            try {
                Mac mac = Mac.getInstance("HmacSHA256");
                mac.init(new SecretKeySpec(key, "HmacSHA256"));
                return mac.doFinal(string.getBytes(StandardCharsets.UTF_8));
            }
            catch (InvalidKeyException | NoSuchAlgorithmException e) {
                throw new ReqlDriverError(e);
            }
        }

        static byte[] xor(byte[] a, byte[] b) {
            if (a.length != b.length) {
                throw new ReqlDriverError("arrays must be the same length");
            }
            byte[] result = new byte[a.length];
            for (int i = 0; i < result.length; ++i) {
                result[i] = (byte)(a[i] ^ b[i]);
            }
            return result;
        }

        private static class PBKDF2 {
            private static final Map<PBKDF2, byte[]> cache = new ConcurrentHashMap<PBKDF2, byte[]>();
            final byte[] password;
            final byte[] salt;
            final int iterations;

            static byte[] compute(byte[] password, byte[] salt, Integer iterationCount) {
                return cache.computeIfAbsent(new PBKDF2(password, salt, iterationCount), PBKDF2::compute);
            }

            PBKDF2(byte[] password, byte[] salt, int iterations) {
                this.password = password;
                this.salt = salt;
                this.iterations = iterations;
            }

            public boolean equals(Object o) {
                if (this == o) {
                    return true;
                }
                if (o == null || this.getClass() != o.getClass()) {
                    return false;
                }
                PBKDF2 that = (PBKDF2)o;
                if (this.iterations != that.iterations) {
                    return false;
                }
                if (!Arrays.equals(this.password, that.password)) {
                    return false;
                }
                return Arrays.equals(this.salt, that.salt);
            }

            public int hashCode() {
                int result = Arrays.hashCode(this.password);
                result = 31 * result + Arrays.hashCode(this.salt);
                result = 31 * result + this.iterations;
                return result;
            }

            public byte[] compute() {
                PBEKeySpec spec = new PBEKeySpec(new String(this.password, StandardCharsets.UTF_8).toCharArray(), this.salt, this.iterations, 256);
                try {
                    return SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256").generateSecret(spec).getEncoded();
                }
                catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
                    throw new ReqlDriverError(e);
                }
            }
        }
    }

    static class WaitingForProtocolRange
    extends HandshakeProtocol {
        private static final SecureRandom secureRandom = new SecureRandom();
        private static final int NONCE_BYTES = 18;
        private final String nonce;
        private final ScramAttributes clientFirstMessageBare;
        private final byte[] password;

        WaitingForProtocolRange(String username, String password) {
            this.password = password.getBytes(StandardCharsets.UTF_8);
            this.nonce = WaitingForProtocolRange.makeNonce();
            this.clientFirstMessageBare = new ScramAttributes().username(username).nonce(this.nonce);
        }

        @Override
        public ByteBuffer toSend() {
            byte[] jsonBytes = ("{\"protocol_version\":" + SUB_PROTOCOL_VERSION + ",\"authentication_method\":\"SCRAM-SHA-256\",\"authentication\":\"n,," + this.clientFirstMessageBare + "\"}").getBytes(StandardCharsets.UTF_8);
            return ByteBuffer.allocate(4 + jsonBytes.length + 1).order(ByteOrder.LITTLE_ENDIAN).putInt(WaitingForProtocolRange.VERSION.value).put(jsonBytes).put(new byte[1]);
        }

        @Override
        public HandshakeProtocol nextState(String response) {
            Map<String, Object> json = Internals.readJson(response);
            HandshakeProtocol.throwIfFailure(json);
            long minVersion = (Long)json.get("min_protocol_version");
            long maxVersion = (Long)json.get("max_protocol_version");
            if (SUB_PROTOCOL_VERSION < minVersion || SUB_PROTOCOL_VERSION > maxVersion) {
                throw new ReqlDriverError("Unsupported protocol version " + SUB_PROTOCOL_VERSION + ", expected between " + minVersion + " and " + maxVersion);
            }
            return new WaitingForAuthResponse(this.nonce, this.password, this.clientFirstMessageBare);
        }

        static String makeNonce() {
            byte[] rawNonce = new byte[18];
            secureRandom.nextBytes(rawNonce);
            return Base64.getEncoder().encodeToString(rawNonce);
        }
    }
}

