/*
 * Decompiled with CFR 0.152.
 */
package it.auties.whatsapp.crypto;

import it.auties.curve25519.Curve25519;
import it.auties.whatsapp.controller.Keys;
import it.auties.whatsapp.crypto.AesCbc;
import it.auties.whatsapp.crypto.GroupCipher;
import it.auties.whatsapp.crypto.Hkdf;
import it.auties.whatsapp.crypto.Hmac;
import it.auties.whatsapp.crypto.SessionBuilder;
import it.auties.whatsapp.exception.HmacValidationException;
import it.auties.whatsapp.model.signal.keypair.SignalKeyPair;
import it.auties.whatsapp.model.signal.message.SignalMessage;
import it.auties.whatsapp.model.signal.message.SignalPreKeyMessage;
import it.auties.whatsapp.model.signal.session.Session;
import it.auties.whatsapp.model.signal.session.SessionAddress;
import it.auties.whatsapp.model.signal.session.SessionChain;
import it.auties.whatsapp.model.signal.session.SessionState;
import it.auties.whatsapp.util.BytesHelper;
import it.auties.whatsapp.util.KeyHelper;
import it.auties.whatsapp.util.Validate;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.NonNull;

/**
 * Signal session cipher bound to a single remote {@link SessionAddress}.
 *
 * <p>Outgoing payloads are encrypted with the sending chain of the current session state.
 * Incoming {@link SignalMessage}s and {@link SignalPreKeyMessage}s are decrypted, and the
 * ratchet is stepped whenever the sender presents an ephemeral key with no matching chain.
 * Sessions are looked up in (and, for new pre-key sessions, stored into) {@link Keys}.
 * Session states and chains are mutated in place.
 *
 * @param address the remote address whose session is used
 * @param keys    the local store providing identity keys, sessions and trust decisions
 */
public record SessionCipher(@NonNull SessionAddress address, @NonNull Keys keys) {
    /** Maximum number of message keys that may be derived ahead of the chain in one call. */
    private static final int MAX_MESSAGE_KEYS_GAP = 2000;

    /** Length of the truncated HMAC-SHA256 tag carried by a {@link SignalMessage}. */
    private static final int MAC_LENGTH = 8;

    /** Length of the AES-CBC initialisation vector taken from the derived secrets. */
    private static final int IV_LENGTH = 16;

    /**
     * Compact canonical constructor: validates the components. The record fields are assigned
     * implicitly, which the previous non-compact form without assignments did not do.
     */
    public SessionCipher {
        if (address == null) {
            throw new NullPointerException("address is marked non-null but is null");
        }
        if (keys == null) {
            throw new NullPointerException("keys is marked non-null but is null");
        }
    }

    /**
     * Encrypts {@code data} with the next key of the current sending chain.
     *
     * @param data the plaintext, may be {@code null}
     * @return the serialized ciphertext with its type: {@code "pkmsg"} while the session still
     *     has a pending pre-key, {@code "msg"} otherwise. If {@code data} is {@code null}, the
     *     payload is {@code null} and the type is {@code "unavailable"}.
     * @throws NoSuchElementException if there is no session or no sending chain for the address
     * @throws SecurityException if the remote identity key is not trusted
     */
    public GroupCipher.CipheredMessageResult encrypt(byte[] data) {
        if (data == null) {
            return new GroupCipher.CipheredMessageResult(null, "unavailable");
        }
        SessionState currentState = this.loadSession()
                .currentState()
                .orElseThrow(() -> new NoSuchElementException("Missing session for address %s".formatted(this.address)));
        Validate.isTrue(this.keys.hasTrust(this.address, currentState.remoteIdentityKey()), "Untrusted key", SecurityException.class);
        SessionChain chain = currentState.findChain(currentState.ephemeralKeyPair().encodedPublicKey())
                .orElseThrow(() -> new NoSuchElementException("Missing chain for %s".formatted(this.address)));
        // Advance the sending chain by exactly one step; the new key is used once and dropped.
        this.fillMessageKeys(chain, chain.counter().get() + 1);
        int messageIndex = chain.counter().get();
        byte[] currentKey = chain.messageKeys().get(messageIndex);
        byte[][] secrets = deriveMessageSecrets(currentKey);
        chain.messageKeys().remove(messageIndex);
        byte[] iv = Arrays.copyOf(secrets[2], IV_LENGTH);
        byte[] encrypted = AesCbc.encrypt(iv, data, secrets[0]);
        String encryptedMessageType = this.getMessageType(currentState);
        byte[] encryptedMessage = this.encrypt(currentState, chain, secrets[1], encrypted);
        return new GroupCipher.CipheredMessageResult(encryptedMessage, encryptedMessageType);
    }

    /** Returns {@code "pkmsg"} while a pre-key is still pending, {@code "msg"} otherwise. */
    private String getMessageType(SessionState currentState) {
        return currentState.hasPreKey() ? "pkmsg" : "msg";
    }

    /**
     * Serializes the ciphertext as a {@link SignalMessage}. The result is wrapped in a
     * {@link SignalPreKeyMessage} while the state still has a pending pre-key.
     */
    private byte[] encrypt(SessionState state, SessionChain chain, byte[] key, byte[] encrypted) {
        SignalMessage message = new SignalMessage(state.ephemeralKeyPair().encodedPublicKey(), chain.counter().get(), state.previousCounter(), encrypted, encodedMessage -> this.createMessageSignature(state, key, (byte[]) encodedMessage));
        byte[] serializedMessage = message.serialized();
        if (!state.hasPreKey()) {
            return serializedMessage;
        }
        SignalPreKeyMessage preKeyMessage = new SignalPreKeyMessage(state.pendingPreKey().preKeyId(), state.pendingPreKey().baseKey(), this.keys.identityKeyPair().encodedPublicKey(), serializedMessage, this.keys.registrationId(), state.pendingPreKey().signedKeyId());
        return preKeyMessage.serialized();
    }

    /**
     * Computes the truncated MAC over our identity key, the remote identity key and the
     * encoded message.
     */
    private byte[] createMessageSignature(SessionState state, byte[] key, byte[] encodedMessage) {
        byte[] macInput = BytesHelper.concat(this.keys.identityKeyPair().encodedPublicKey(), state.remoteIdentityKey(), encodedMessage);
        byte[] sha256 = Hmac.calculateSha256(macInput, key);
        return Arrays.copyOfRange(sha256, 0, MAC_LENGTH);
    }

    /**
     * Advances {@code chain} until its counter reaches {@code counter}, storing each derived
     * message key under its index. Does nothing if the chain is already at or past
     * {@code counter}.
     *
     * @throws RuntimeException (via {@code Validate}) if more than {@link #MAX_MESSAGE_KEYS_GAP}
     *     keys would be derived, or if the chain has been closed (its key is {@code null})
     */
    private void fillMessageKeys(SessionChain chain, int counter) {
        if (chain.counter().get() >= counter) {
            return;
        }
        int gap = counter - chain.counter().get();
        Validate.isTrue(gap <= MAX_MESSAGE_KEYS_GAP, "Message overflow: expected <= %s, got %s", MAX_MESSAGE_KEYS_GAP, gap);
        Validate.isTrue(chain.key().get() != null, "Closed chain");
        // Iterative form of the chain KDF: message key = HMAC(ck, 0x01), next ck = HMAC(ck, 0x02).
        while (chain.counter().get() < counter) {
            byte[] chainKey = chain.key().get();
            chain.messageKeys().put(chain.counter().get() + 1, Hmac.calculateSha256(new byte[]{1}, chainKey));
            chain.key().set(Hmac.calculateSha256(new byte[]{2}, chainKey));
            chain.counter().getAndIncrement();
        }
    }

    /**
     * Decrypts a pre-key message. If no session exists for the address, a new one is created
     * and stored before the incoming session is built.
     *
     * @throws NoSuchElementException if no state matches the message's version and base key
     */
    public byte[] decrypt(SignalPreKeyMessage message) {
        Session session = this.loadSession(this::createSession);
        SessionBuilder builder = new SessionBuilder(this.address, this.keys);
        builder.createIncoming(session, message);
        SessionState state = session.findState(message.version(), message.baseKey()).orElseThrow(() -> new NoSuchElementException("Missing state"));
        return this.decrypt(message.signalMessage(), state);
    }

    /** Creates an empty session, registers it for {@link #address} and returns it. */
    private Optional<Session> createSession() {
        Session newSession = new Session();
        this.keys.putSession(this.address, newSession);
        return Optional.of(newSession);
    }

    /**
     * Decrypts a regular Signal message by trying each state of the stored session in turn.
     *
     * @throws NoSuchElementException if no state can decrypt the message; the failure of each
     *     attempted state is attached as a suppressed exception
     */
    public byte[] decrypt(SignalMessage message) {
        Session session = this.loadSession();
        List<Exception> failures = new ArrayList<>();
        return session.states()
                .stream()
                .map(state -> this.tryDecrypt(message, state, failures))
                .flatMap(Optional::stream)
                .findFirst()
                .orElseThrow(() -> {
                    NoSuchElementException error = new NoSuchElementException("Cannot decrypt message: no suitable session found");
                    failures.forEach(error::addSuppressed);
                    return error;
                });
    }

    /**
     * Attempts decryption with a single state. Any exception is recorded in {@code failures}
     * and reported as an empty result. Errors (e.g. {@code OutOfMemoryError}) are no longer
     * swallowed.
     */
    private Optional<byte[]> tryDecrypt(SignalMessage message, SessionState state, List<Exception> failures) {
        try {
            Validate.isTrue(this.keys.hasTrust(this.address, state.remoteIdentityKey()), "Untrusted key");
            return Optional.of(this.decrypt(message, state));
        } catch (Exception exception) {
            failures.add(exception);
            return Optional.empty();
        }
    }

    /**
     * Decrypts {@code message} using {@code state}: steps the ratchet if needed, derives the
     * message key, verifies the MAC and decrypts with AES-CBC. The message key is removed after
     * successful decryption so it cannot be reused, and the pending pre-key is cleared.
     */
    private byte[] decrypt(SignalMessage message, SessionState state) {
        this.maybeStepRatchet(message, state);
        SessionChain chain = state.findChain(message.ephemeralPublicKey()).orElseThrow(() -> new NoSuchElementException("Invalid chain"));
        this.fillMessageKeys(chain, message.counter());
        Validate.isTrue(chain.hasMessageKey(message.counter()), "Key used already or never filled");
        byte[] messageKey = chain.messageKeys().get(message.counter());
        byte[][] secrets = deriveMessageSecrets(messageKey);
        // MAC input = remote identity | our identity | serialized message without its trailing MAC.
        byte[] hmacValue = BytesHelper.concat(state.remoteIdentityKey(), this.keys.identityKeyPair().encodedPublicKey(), message.serialized());
        byte[] hmacInput = Arrays.copyOfRange(hmacValue, 0, hmacValue.length - MAC_LENGTH);
        byte[] hmac = Arrays.copyOf(Hmac.calculateSha256(hmacInput, secrets[1]), MAC_LENGTH);
        // Constant-time comparison avoids leaking how many MAC bytes matched.
        Validate.isTrue(MessageDigest.isEqual(message.signature(), hmac), "message_decryption", HmacValidationException.class);
        byte[] iv = Arrays.copyOf(secrets[2], IV_LENGTH);
        byte[] plaintext = AesCbc.decrypt(iv, message.ciphertext(), secrets[0]);
        // Consume the key only once the message is authenticated and decrypted, so a replay
        // of the same counter is rejected by the check above.
        chain.messageKeys().remove(message.counter());
        state.pendingPreKey(null);
        return plaintext;
    }

    /**
     * Steps the ratchet when the message uses an ephemeral key with no chain in {@code state}:
     * closes the previous receiving chain, derives a new receiving chain, rotates our ephemeral
     * key pair and derives a new sending chain.
     */
    private void maybeStepRatchet(SignalMessage message, SessionState state) {
        if (state.hasChain(message.ephemeralPublicKey())) {
            return;
        }
        Optional<SessionChain> previousRatchet = state.findChain(state.lastRemoteEphemeralKey());
        previousRatchet.ifPresent(chain -> {
            // The previous receiving chain must be filled up to the length of the SENDER's
            // previous sending chain. The sender transmits that length in the message (see the
            // third SignalMessage constructor argument in encrypt). state.previousCounter()
            // tracks OUR sending chain and is unrelated here.
            Integer remotePreviousCounter = message.previousCounter();
            if (remotePreviousCounter != null) {
                this.fillMessageKeys(chain, remotePreviousCounter);
            }
            chain.key().set(null);
        });
        this.calculateRatchet(message, state, false);
        Optional<SessionChain> previousCounter = state.findChain(state.ephemeralKeyPair().encodedPublicKey());
        previousCounter.ifPresent(chain -> {
            state.previousCounter(chain.counter().get());
            state.removeChain(state.ephemeralKeyPair().encodedPublicKey());
        });
        state.ephemeralKeyPair(SignalKeyPair.random());
        this.calculateRatchet(message, state, true);
        state.lastRemoteEphemeralKey(message.ephemeralPublicKey());
    }

    /**
     * Derives a new chain and root key from the DH output of the message's ephemeral public key
     * and our ephemeral private key. The chain is registered under our public key when
     * {@code sending}, otherwise under the remote one.
     */
    private void calculateRatchet(SignalMessage message, SessionState state, boolean sending) {
        byte[] sharedSecret = Curve25519.sharedKey(KeyHelper.withoutHeader(message.ephemeralPublicKey()), state.ephemeralKeyPair().privateKey());
        byte[][] masterKey = Hkdf.deriveSecrets(sharedSecret, state.rootKey(), "WhisperRatchet".getBytes(StandardCharsets.UTF_8), 2);
        byte[] chainKey = sending ? state.ephemeralKeyPair().encodedPublicKey() : message.ephemeralPublicKey();
        state.addChain(chainKey, new SessionChain(-1, masterKey[1]));
        state.rootKey(masterKey[0]);
    }

    /** Expands a single message key into the cipher key, MAC key and IV material. */
    private static byte[][] deriveMessageSecrets(byte[] messageKey) {
        return Hkdf.deriveSecrets(messageKey, "WhisperMessageKeys".getBytes(StandardCharsets.UTF_8));
    }

    /** Loads the session for {@link #address}, falling back to the same name with device id 0. */
    private Session loadSession() {
        return this.loadSession(() -> this.keys.findSessionByAddress(new SessionAddress(this.address.name(), 0)));
    }

    /**
     * Loads the session for {@link #address}, falling back to {@code defaultSupplier}.
     *
     * @throws NoSuchElementException if neither source yields a session
     */
    private Session loadSession(Supplier<Optional<Session>> defaultSupplier) {
        return this.keys.findSessionByAddress(this.address).or(defaultSupplier).orElseThrow(() -> new NoSuchElementException("Missing session for: %s".formatted(this.address)));
    }
}

