/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */

package com.mulesoft.modules.agent.broker.internal.state;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.A2A;

import static java.util.UUID.randomUUID;

import static io.a2a.spec.TaskState.INPUT_REQUIRED;

import org.mule.runtime.api.lock.LockFactory;
import org.mule.runtime.api.store.ObjectDoesNotExistException;
import org.mule.runtime.api.store.ObjectStore;
import org.mule.runtime.api.store.ObjectStoreException;
import org.mule.runtime.core.api.util.func.CheckedSupplier;
import org.mule.runtime.extension.api.exception.ModuleException;

import com.mulesoft.modules.agent.broker.internal.extension.AgentsBroker;
import com.mulesoft.modules.agent.broker.internal.state.model.Conversation;
import com.mulesoft.modules.agent.broker.internal.state.model.TaskContext;

import java.io.Serializable;
import java.util.concurrent.locks.Lock;

import javax.inject.Inject;

import io.a2a.spec.Message;

/**
 * Reads and writes the {@link Conversation} and {@link TaskContext} state kept in the object stores exposed by an
 * {@link AgentsBroker}.
 * <p>
 * Every store access performed by this class runs while holding a lock obtained from the injected {@link LockFactory}. Lock
 * ids and store keys share the same format: the broker's config name in square brackets, a dash, and the entity id.
 */
public class ConversationService {

  @Inject
  private LockFactory lockFactory;

  /**
   * Resolves the conversation a message belongs to.
   * <p>
   * When the message carries a {@code contextId}, the matching conversation is retrieved from the broker's conversation store.
   * Otherwise a new conversation is created from {@code fallbackContextId} and stored.
   *
   * @param message           the incoming A2A message
   * @param fallbackContextId the id handed to the new {@link Conversation} when the message has no {@code contextId}
   * @param broker            the broker that owns the conversation store and whose config name prefixes the key
   * @return the existing conversation, or the newly stored one
   * @throws ModuleException if the message names a {@code contextId} that is not present in the store
   */
  public Conversation getOrCreateConversation(Message message, String fallbackContextId, AgentsBroker broker) {
    final var store = broker.getConversationStore();
    final var contextId = message.getContextId();

    if (contextId != null) {
      final var existingKey = getConversationKey(contextId, broker);
      return withLock(existingKey, () -> {
        try {
          return store.retrieve(existingKey);
        } catch (ObjectDoesNotExistException e) {
          throw new ModuleException("contextId '%s' not found".formatted(contextId), A2A, e);
        }
      });
    }

    // Lock on the key that is actually written. Building the lock id from the (null) message contextId would make every
    // conversation-creating request contend on one shared lock while leaving the real key unguarded.
    final var conversation = new Conversation(fallbackContextId);
    final var newKey = getConversationKey(conversation.getId(), broker);
    return withLock(newKey, () -> {
      store.store(newKey, conversation);
      return conversation;
    });
  }

  /**
   * Resolves the task a message should be processed under and persists it.
   * <ul>
   * <li>No {@code taskId} in the message: a new task with id {@code fallbackTaskId} is created for the conversation.</li>
   * <li>{@code taskId} present and the stored task is in {@code INPUT_REQUIRED} state: the stored task is reused.</li>
   * <li>{@code taskId} present and the stored task is in any other state: a new task with a random UUID id is created and
   * marked as refining the stored one.</li>
   * </ul>
   * Any reference task ids carried by the message are added to the resolved task before it is upserted.
   * <p>
   * Note that the lock taken to retrieve the stored task is released before the validations and the final upsert run, so the
   * sequence as a whole is not atomic.
   *
   * @param message        the incoming A2A message
   * @param conversation   the conversation the message was resolved to
   * @param fallbackTaskId the id used for the new task when the message has no {@code taskId}
   * @param broker         the broker that owns the tasks store and whose config name prefixes the key
   * @return the task to use for this message, already written to the tasks store
   * @throws ModuleException if the message has a {@code taskId} but no {@code contextId}, if the task does not exist, if the
   *                         task belongs to a different conversation, or if the tasks store cannot be accessed
   */
  public TaskContext getTaskFor(Message message,
                                Conversation conversation,
                                String fallbackTaskId,
                                AgentsBroker broker) {
    final var store = broker.getTasksObjectStore();
    final TaskContext task;

    if (message.getTaskId() != null) {

      if (message.getContextId() == null) {
        throw new ModuleException("Invalid A2A message: contextId must be specified when referring to a taskId", A2A);
      }

      final var taskKey = getTaskKey(message.getTaskId(), broker);
      final TaskContext prevTaskContext = withLock(taskKey, () -> {
        try {
          return store.retrieve(taskKey);
        } catch (ObjectDoesNotExistException e) {
          throw new ModuleException("Invalid A2A message: Task '%s' does not exists".formatted(message.getTaskId()), A2A, e);
        } catch (ObjectStoreException e) {
          throw new ModuleException("Error accessing task store", A2A, e);
        }
      });

      if (!prevTaskContext.getConversationId().equals(conversation.getId())) {
        throw new ModuleException("Invalid A2A message: Task '%s' is not part of this conversation"
            .formatted(message.getTaskId()), A2A);
      }

      if (prevTaskContext.getConversationState().getTaskState() != INPUT_REQUIRED) {
        // The referenced task is not waiting for input, so the message starts a follow-up task that refines it.
        task = createTask(randomUUID().toString(), conversation);
        task.refines(prevTaskContext);
      } else {
        task = prevTaskContext;
      }
    } else {
      task = createTask(fallbackTaskId, conversation);
    }

    if (message.getReferenceTaskIds() != null) {
      message.getReferenceTaskIds().forEach(task.getReferencedTaskIds()::add);
    }

    try {
      upsert(task, broker);
    } catch (Exception e) {
      // Deliberately broad: failures raised inside the locked supplier may surface as unchecked exceptions.
      throw new ModuleException("Error accessing task store", A2A, e);
    }
    return task;
  }

  /**
   * Builds a new, not yet persisted task bound to the given conversation.
   *
   * @param id           the id of the new task
   * @param conversation the conversation whose id the task is associated with
   * @return the new task
   */
  private TaskContext createTask(String id, Conversation conversation) {
    return new TaskContext(id, conversation.getId());
  }

  /**
   * Writes the conversation to the broker's conversation store, replacing any entry already stored under its key.
   *
   * @param conversation the conversation to persist
   * @param broker       the broker that owns the conversation store
   * @throws ObjectStoreException declared for callers; see {@link #doUpsert(String, Serializable, ObjectStore)}
   */
  public void upsert(Conversation conversation, AgentsBroker broker) throws ObjectStoreException {
    doUpsert(getConversationKey(conversation.getId(), broker), conversation, broker.getConversationStore());
  }

  /**
   * Writes the task to the broker's tasks store, replacing any entry already stored under its key.
   *
   * @param taskContext the task to persist
   * @param broker      the broker that owns the tasks store
   * @throws ObjectStoreException declared for callers; see {@link #doUpsert(String, Serializable, ObjectStore)}
   */
  public void upsert(TaskContext taskContext, AgentsBroker broker) throws ObjectStoreException {
    doUpsert(getTaskKey(taskContext.getTaskId(), broker), taskContext, broker.getTasksObjectStore());
  }

  /**
   * Stores {@code value} under {@code key} while holding the lock for that key, removing a previous entry first if one is
   * present.
   *
   * @param key   the store key, also used as the lock id
   * @param value the value to persist
   * @param store the target store
   * @param <T>   the type of the stored value
   * @throws ObjectStoreException declared by the signature; NOTE(review): the store calls run inside a
   *                              {@link CheckedSupplier}, so failures presumably reach the caller as unchecked exceptions
   *                              instead - confirm against {@code CheckedSupplier#get()}
   */
  private <T extends Serializable> void doUpsert(String key, T value, ObjectStore<T> store) throws ObjectStoreException {
    withLock(key, () -> {
      // Remove-then-store: presumably because store() does not overwrite an existing key - confirm against ObjectStore.
      if (store.contains(key)) {
        store.remove(key);
      }

      store.store(key, value);
      return null;
    });

  }

  /**
   * Runs {@code supplier} while holding the lock of the given conversation.
   *
   * @param conversationId the id of the conversation to lock on
   * @param broker         the broker whose config name prefixes the lock id
   * @param supplier       the work to execute under the lock
   * @param <T>            the type of the supplier's result
   * @return whatever the supplier returns
   */
  public <T> T synchronizedConversation(String conversationId, AgentsBroker broker, CheckedSupplier<T> supplier) {
    return withLock(getConversationKey(conversationId, broker), supplier);
  }

  /**
   * Acquires the lock identified by {@code lockId}, runs the supplier and always releases the lock afterwards.
   *
   * @param lockId   the id handed to the {@link LockFactory}
   * @param supplier the work to execute under the lock
   * @param <T>      the type of the supplier's result
   * @return whatever the supplier returns
   */
  private <T> T withLock(String lockId, CheckedSupplier<T> supplier) {
    final Lock lock = lockFactory.createLock(lockId);
    // Acquire outside the try block so that unlock() only runs when the lock is actually held.
    lock.lock();
    try {
      return supplier.get();
    } finally {
      lock.unlock();
    }
  }

  /**
   * Builds the store key / lock id of a conversation: {@code [<configName>]-<conversationId>}.
   */
  private String getConversationKey(String conversationId, AgentsBroker broker) {
    return "[" + broker.getConfigName() + "]" + "-" + conversationId;
  }

  /**
   * Builds the store key / lock id of a task: {@code [<configName>]-<taskId>}. The format is the same one used for
   * conversation keys.
   */
  private String getTaskKey(String taskId, AgentsBroker broker) {
    return "[" + broker.getConfigName() + "]" + "-" + taskId;
  }
}
