/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.broker.internal.state;

import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;
import io.a2a.spec.Message;
import java.io.Serializable;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
import javax.inject.Inject;
import org.mule.runtime.api.lock.LockFactory;
import org.mule.runtime.api.store.ObjectDoesNotExistException;
import org.mule.runtime.api.store.ObjectStore;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.func.CheckedSupplier;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;

public class ConversationService {
    @Inject
    private LockFactory lockFactory;

    public Conversation getOrCreateConversation(Message message, String fallbackContextId, AgentsBroker broker) {
        ObjectStore<Conversation> store = broker.getConversationStore();
        if (message.getContextId() != null) {
            return (Conversation)this.withLock(this.getConversationKey(message.getContextId(), broker), () -> {
                try {
                    return (Conversation)store.retrieve(this.getConversationKey(message.getContextId(), broker));
                }
                catch (ObjectDoesNotExistException e) {
                    throw new ModuleException("contextId '%s' not found".formatted(message.getContextId()), (ErrorTypeDefinition)BrokerErrorTypes.A2A, (Throwable)e);
                }
            });
        }
        return (Conversation)this.withLock(this.getConversationKey(fallbackContextId, broker), () -> {
            Conversation conversation = new Conversation(fallbackContextId);
            store.store(this.getConversationKey(conversation.getId(), broker), (Serializable)conversation);
            return conversation;
        });
    }

    public TaskContext getTaskFor(Message message, Conversation conversation, String fallbackTaskId, AgentsBroker broker) {
        TaskContext task;
        ObjectStore<TaskContext> store = broker.getTasksObjectStore();
        if (message.getTaskId() != null) {
            if (message.getContextId() == null) {
                throw new ModuleException("Invalid A2A message: contextId must be specified when referring to a taskId", (ErrorTypeDefinition)BrokerErrorTypes.A2A);
            }
            String taskKey = this.getTaskKey(message.getTaskId(), broker);
            TaskContext prevTaskContext = (TaskContext)this.withLock(taskKey, () -> {
                try {
                    return (TaskContext)store.retrieve(taskKey);
                }
                catch (ObjectDoesNotExistException e) {
                    throw new ModuleException("Invalid A2A message: Task '%s' does not exists".formatted(message.getTaskId()), (ErrorTypeDefinition)BrokerErrorTypes.A2A, (Throwable)e);
                }
                catch (ObjectStoreException e) {
                    throw new ModuleException("Error accessing task store", (ErrorTypeDefinition)BrokerErrorTypes.A2A, (Throwable)e);
                }
            });
            if (!prevTaskContext.getConversationId().equals(conversation.getId())) {
                throw new ModuleException("Invalid A2A message: Task '%s' is not part of this conversation".formatted(message.getTaskId()), (ErrorTypeDefinition)BrokerErrorTypes.A2A);
            }
            switch (prevTaskContext.getConversationState().getTaskState()) {
                case INPUT_REQUIRED: 
                case AUTH_REQUIRED: {
                    task = prevTaskContext;
                    break;
                }
                default: {
                    task = prevTaskContext.createContinuation(UUID.randomUUID().toString());
                    break;
                }
            }
        } else {
            task = new TaskContext(fallbackTaskId, conversation.getId());
        }
        if (message.getReferenceTaskIds() != null) {
            message.getReferenceTaskIds().forEach(task.getReferencedTaskIds()::add);
        }
        try {
            this.upsert(task, broker);
        }
        catch (Exception e) {
            throw new ModuleException("Error accessing task store", (ErrorTypeDefinition)BrokerErrorTypes.A2A, (Throwable)e);
        }
        return task;
    }

    public void upsert(TaskContext taskContext, AgentsBroker broker) {
        this.doUpsert(this.getTaskKey(taskContext.getTaskId(), broker), taskContext, broker.getTasksObjectStore());
    }

    private <T extends Serializable> void doUpsert(String key, T value, ObjectStore<T> store) {
        this.withLock(key, () -> {
            if (store.contains(key)) {
                store.remove(key);
            }
            store.store(key, value);
            return null;
        });
    }

    public <T> T synchronizedConversation(String conversationId, AgentsBroker broker, CheckedSupplier<T> supplier) {
        return this.withLock(this.getConversationKey(conversationId, broker), () -> supplier.get());
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private <T> T withLock(String lockId, CheckedSupplier<T> supplier) {
        Lock lock = this.lockFactory.createLock(lockId);
        lock.lock();
        try {
            Object object = supplier.get();
            return (T)object;
        }
        finally {
            lock.unlock();
        }
    }

    private String getConversationKey(String conversationId, AgentsBroker broker) {
        return "[" + broker.getConfigName() + "]-" + conversationId;
    }

    private String getTaskKey(String taskId, AgentsBroker broker) {
        return "[" + broker.getConfigName() + "]-" + taskId;
    }
}

