/*
 * Decompiled with CFR 0.152.
 */
package org.iris_events.consumer;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.DeliverCallback;
import com.rabbitmq.client.Delivery;
import com.rabbitmq.client.Envelope;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
import io.quarkus.security.AuthenticationFailedException;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Context;
import io.vertx.core.Vertx;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.spi.CDI;
import java.lang.annotation.Annotation;
import java.lang.invoke.MethodHandle;
import java.security.Principal;
import java.util.Optional;
import org.iris_events.annotations.Message;
import org.iris_events.auth.IrisJwtValidator;
import org.iris_events.common.MDCEnricher;
import org.iris_events.context.EventContext;
import org.iris_events.context.IrisContext;
import org.iris_events.context.MethodHandleContext;
import org.iris_events.producer.EventProducer;
import org.iris_events.routing.RoutingDetailsProvider;
import org.iris_events.runtime.AnnotationValueExtractor;
import org.iris_events.runtime.IrisExceptionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds RabbitMQ {@link DeliverCallback}s that dispatch incoming deliveries to one Iris
 * event-handler method, invoked through a {@link MethodHandle}.
 *
 * <p>For every delivery the produced callback:
 * <ol>
 *   <li>switches to a newly created duplicated Vert.x context,</li>
 *   <li>activates a CDI request context if none is active (and terminates it afterwards),</li>
 *   <li>stores the AMQP properties/envelope in the {@link EventContext} and enriches the MDC,</li>
 *   <li>authenticates the message when an {@link IrisJwtValidator} is configured,</li>
 *   <li>deserializes the body into the handler's event class and invokes the handler,</li>
 *   <li>publishes the handler's result either as an RPC reply or as a forwarded event, and</li>
 *   <li>acknowledges the delivery.</li>
 * </ol>
 * Any {@link Throwable} raised while processing is passed to the {@link IrisExceptionHandler};
 * this class itself only acknowledges deliveries that were processed successfully.
 */
public class DeliverCallbackProvider {
    private final EventContext eventContext;
    private final ObjectMapper objectMapper;
    private final EventProducer producer;
    private final IrisContext irisContext;
    private final Object eventHandlerInstance;
    private final MethodHandle methodHandle;
    private final MethodHandleContext methodHandleContext;
    /** Optional; when {@code null} no JWT authentication is performed. */
    private final IrisJwtValidator jwtValidator;
    private final IrisExceptionHandler errorHandler;
    private final RoutingDetailsProvider routingDetailsProvider;
    private static final Logger log = LoggerFactory.getLogger(DeliverCallbackProvider.class);

    /**
     * @param objectMapper           mapper used to deserialize message bodies
     * @param producer               producer used to forward results and publish RPC replies
     * @param irisContext            handler metadata (RPC flag, allowed roles, ...)
     * @param eventContext           per-message context populated from the AMQP delivery
     * @param eventHandlerInstance   bean instance the handler method is invoked on
     * @param methodHandle           handle to the handler method
     * @param methodHandleContext    handler/event/return class information for {@code methodHandle}
     * @param jwtValidator           JWT validator, or {@code null} to skip authentication
     * @param errorHandler           receives every failure raised while handling a delivery
     * @param routingDetailsProvider resolves routing details for RPC replies
     */
    public DeliverCallbackProvider(ObjectMapper objectMapper, EventProducer producer, IrisContext irisContext, EventContext eventContext, Object eventHandlerInstance, MethodHandle methodHandle, MethodHandleContext methodHandleContext, IrisJwtValidator jwtValidator, IrisExceptionHandler errorHandler, RoutingDetailsProvider routingDetailsProvider) {
        this.objectMapper = objectMapper;
        this.producer = producer;
        this.irisContext = irisContext;
        this.eventHandlerInstance = eventHandlerInstance;
        this.methodHandle = methodHandle;
        this.methodHandleContext = methodHandleContext;
        this.jwtValidator = jwtValidator;
        this.eventContext = eventContext;
        this.errorHandler = errorHandler;
        this.routingDetailsProvider = routingDetailsProvider;
    }

    /**
     * Creates a callback that handles each delivery on a freshly created duplicated Vert.x context.
     *
     * @param channel channel used to acknowledge successfully handled deliveries
     * @return the delivery callback
     */
    public DeliverCallback createDeliverCallback(Channel channel) {
        return (consumerTag, message) -> {
            Context duplicatedContext = VertxContext.createNewDuplicatedContext();
            runOnContext(duplicatedContext, () -> handleMessage(channel, message));
        };
    }

    /** Runs {@code runnable} directly if already on {@code context}, otherwise schedules it there. */
    private static void runOnContext(Context context, Runnable runnable) {
        if (Vertx.currentContext() == context) {
            runnable.run();
        } else {
            context.runOnContext(ignored -> runnable.run());
        }
    }

    /**
     * Handles one delivery inside a CDI request context and routes any failure to the error handler.
     * A request context is only activated (and later terminated) here if none was active already,
     * so an externally owned context is never terminated by this method. The MDC is always cleared.
     */
    private void handleMessage(Channel channel, Delivery message) {
        ManagedContext requestContext = Arc.container().requestContext();
        boolean activatedHere = !requestContext.isActive();
        if (activatedHere) {
            requestContext.activate();
        }
        try {
            doInvoke(channel, message);
        } catch (Throwable e) {
            errorHandler.handleException(irisContext, message, channel, e);
        } finally {
            if (activatedHere) {
                requestContext.terminate();
            }
            MDCEnricher.clear();
        }
    }

    /**
     * Populates the event context, authenticates, invokes the handler, publishes its result and
     * acknowledges the delivery.
     *
     * @throws Throwable anything raised by deserialization, authentication, the handler itself,
     *                   publishing, or the acknowledgement
     */
    private void doInvoke(Channel channel, Delivery message) throws Throwable {
        AMQP.BasicProperties properties = message.getProperties();
        Envelope envelope = message.getEnvelope();
        eventContext.setBasicProperties(properties);
        eventContext.setEnvelope(envelope);
        MDCEnricher.enrichMDC(properties);

        if (jwtValidator != null) {
            authorizeMessage();
        }

        Class<?> returnEventClass = methodHandleContext.getReturnEventClass();
        boolean rpc = irisContext.isRpc();
        if (rpc) {
            // Validate up front: an RPC request that can never be answered must not run the
            // handler (and its side effects) before failing.
            validateRpcPreconditions(returnEventClass);
        }

        // Keep these static types as Object: MethodHandle.invoke is signature-polymorphic and
        // the call-site types determine the invocation type.
        Object handlerClassInstance = methodHandleContext.getHandlerClass().cast(eventHandlerInstance);
        Object messageObject = objectMapper.readValue(message.getBody(), methodHandleContext.getEventClass());
        Object invocationResult = methodHandle.invoke(handlerClassInstance, messageObject);

        if (rpc) {
            replyMessage(invocationResult, returnEventClass, eventContext.getAmqpBasicProperties().getReplyTo());
        } else if (returnEventClass != null) {
            forwardMessage(invocationResult, returnEventClass);
        }
        channel.basicAck(envelope.getDeliveryTag(), false);
    }

    /**
     * Ensures an RPC request can be answered.
     *
     * @param returnEventClass the handler's return event class, or {@code null} for a void handler
     * @throws IllegalStateException if the message has no id or the handler returns nothing
     */
    private void validateRpcPreconditions(Class<?> returnEventClass) {
        log.trace("DeliverCallbackProvider handling RPC message!");
        if (eventContext.getMessageId().isEmpty()) {
            throw new IllegalStateException("RPC event without requestId can not be processed");
        }
        if (returnEventClass == null) {
            throw new IllegalStateException("RPC message handler without non-void return class can not be processed");
        }
    }

    public IrisContext getIrisContext() {
        return irisContext;
    }

    /**
     * Authenticates the current message against the handler's allowed roles and installs the
     * resulting identity in the {@link CurrentIdentityAssociation}. The principal name, if any,
     * is added to the MDC as {@code userId}.
     *
     * @throws RuntimeException the exception produced by
     *         {@link IrisExceptionHandler#getSecurityException} for any {@link SecurityException}
     */
    private void authorizeMessage() {
        try {
            SecurityIdentity securityIdentity = jwtValidator.authenticate(irisContext.getHandlerRolesAllowed());
            Instance<CurrentIdentityAssociation> association = CDI.current().select(CurrentIdentityAssociation.class);
            if (!association.isResolvable()) {
                throw new AuthenticationFailedException("JWT identity association not resolvable.");
            }
            Optional.ofNullable(securityIdentity)
                    .map(SecurityIdentity::getPrincipal)
                    .map(Principal::getName)
                    .ifPresent(subject -> MDCEnricher.put("userId", subject));
            association.get().setIdentity(securityIdentity);
        } catch (SecurityException securityException) {
            throw IrisExceptionHandler.getSecurityException(securityException);
        }
    }

    /** Sends a non-null handler result as a new event; a {@code null} result is ignored. */
    private void forwardMessage(Object invocationResult, Class<?> returnEventClass) {
        if (invocationResult == null) {
            return;
        }
        Object returnClassInstance = returnEventClass.cast(invocationResult);
        producer.send(returnClassInstance);
    }

    /** Casts the handler result to its declared return class and publishes it as the RPC reply. */
    private void replyMessage(Object invocationResult, Class<?> returnEventClass, String replyTo) {
        Object returnClassInstance = returnEventClass.cast(invocationResult);
        sendRpcResponse(returnClassInstance, replyTo);
    }

    /** Publishes {@code message} using RPC routing derived from its {@link Message} annotation and {@code replyTo}. */
    private void sendRpcResponse(Object message, String replyTo) {
        log.trace("Sending RPC response");
        Message messageAnnotation = AnnotationValueExtractor.getMessageAnnotation(message);
        producer.publish(message, routingDetailsProvider.getRpcRoutingDetails(messageAnnotation, replyTo));
    }
}

