/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.broker.internal.extension.connection.session;

import static com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes.LLM_ERROR;
import static com.mulesoft.modules.agent.broker.internal.util.ToolUtils.openAiLLMOutputSchema;

import static com.openai.models.responses.ResponseInputItem.ofResponseOutputMessage;
import static com.openai.models.responses.Tool.ofFunction;
import static com.openai.models.responses.ToolChoiceOptions.REQUIRED;
import static io.a2a.util.Utils.OBJECT_MAPPER;
import static org.mule.runtime.api.functional.Either.left;
import static org.mule.runtime.api.functional.Either.right;

import org.mule.runtime.api.functional.Either;
import org.mule.runtime.extension.api.exception.ModuleException;

import com.mulesoft.modules.agent.broker.internal.extension.connection.openai.OpenAISettings;
import com.mulesoft.modules.agent.broker.internal.llm.LLMRequest;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolSelection;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

import com.openai.client.OpenAIClient;
import com.openai.models.responses.FunctionTool;
import com.openai.models.responses.ResponseCreateParams;
import com.openai.models.responses.ResponseFunctionToolCall;
import com.openai.models.responses.ResponseOutputItem;
import com.openai.models.responses.Tool;

/**
 * LLM session that requires the model to answer exclusively through function calls.
 *
 * <p>Besides the caller-supplied tools, the session registers an extra {@code provide_structured_answer} function
 * whose parameters follow the LLM output schema, and sets {@code tool_choice} to {@code required}. A call to that
 * function is turned into an {@link LLMOutput} (left side of the result); any other call resolved by the base class
 * as a tool selection is returned as a {@link ToolSelection} (right side).
 */
public class FunctionBasedSession extends BaseLLMSession {

  /** Name of the synthetic function the model calls to deliver its final, structured answer. */
  private static final String PROVIDE_STRUCTURED_ANSWER = "provide_structured_answer";

  /**
   * Strict function-tool definition whose parameters are the LLM output JSON schema, so the model's final answer
   * always arrives as the JSON arguments of a function call.
   */
  private static final Tool STRUCTURED_OUTPUT_FUNCTION = ofFunction(FunctionTool.builder()
      .name(PROVIDE_STRUCTURED_ANSWER)
      .description("Provides an answer to the user's question, either when the goal is completed, has failed or additional " +
          "input that cannot be obtained through any other function is required. This ensures the output is always a JSON structure")
      .strict(true)
      .parameters(FunctionTool.Parameters.builder()
          .putAllAdditionalProperties(openAiLLMOutputSchema())
          .build())
      .build());

  /**
   * Returns a new list containing {@code tools} followed by {@link #STRUCTURED_OUTPUT_FUNCTION}. The input list is
   * not modified.
   *
   * @param tools the caller-supplied tools
   * @return a new list with the structured-answer function appended
   */
  private static List<Tool> enhanceTools(List<Tool> tools) {
    var enhanced = new ArrayList<Tool>(tools.size() + 1);
    enhanced.addAll(tools);
    enhanced.add(STRUCTURED_OUTPUT_FUNCTION);

    return enhanced;
  }

  /**
   * Creates a session that exposes {@code tools} plus the {@code provide_structured_answer} function to the model.
   *
   * @param client     the OpenAI client used to issue requests
   * @param llmRequest the request this session serves
   * @param settings   OpenAI connection settings
   * @param tools      the caller-supplied tools; copied, not modified
   */
  public FunctionBasedSession(OpenAIClient client, LLMRequest llmRequest, OpenAISettings settings, List<Tool> tools) {
    super(client, llmRequest, settings, enhanceTools(tools));
  }

  /**
   * Sends the current request and converts the first function or custom tool call in the response into the result.
   *
   * <p>Reasoning items are passed to {@code onReasoningItem} and scanning continues. Assistant messages are appended
   * to {@code inputs}, and scanning also continues, because a model may emit a message (e.g. a preamble) before the
   * function call in the same response. Only when no call follows is the message reported as an error.
   *
   * @return a future completing with either the structured answer or the tool the model selected. If the response
   *         contains no function or custom tool call, the future completes exceptionally with a
   *         {@link ModuleException} ({@code LLM_ERROR}) as the cause.
   */
  @Override
  protected CompletableFuture<Either<LLMOutput, ToolSelection>> doGetNext() {
    return client.async().responses().create(builder.build()).thenApply(response -> {
      boolean receivedMessage = false;

      for (ResponseOutputItem item : response.output()) {
        if (item.isReasoning()) {
          onReasoningItem(item.asReasoning());
        } else if (item.isFunctionCall()) {
          return onFunctionCall(item.asFunctionCall());
        } else if (item.isCustomToolCall()) {
          return onCustomCall(item.asCustomToolCall());
        } else if (item.isMessage()) {
          // Keep the message in the conversation history, but keep scanning: a call may still follow it.
          inputs.add(ofResponseOutputMessage(item.asMessage()));
          receivedMessage = true;
        }
      }

      if (receivedMessage) {
        throw new ModuleException("LLM returned a message but a function call was expected", LLM_ERROR);
      }
      throw new ModuleException("LLM didn't provide a valid response", LLM_ERROR);
    });
  }

  /**
   * Converts a function call into the session result.
   *
   * <p>The call is first resolved by the base implementation. A left result is returned unchanged. A tool selection
   * for {@code provide_structured_answer} has its input deserialized into an {@link LLMOutput}. That output is given
   * the selection's internal reasoning and, when present, the call's id, and is returned as the left value. Any other
   * tool selection is returned as the right value.
   *
   * @param toolCall the function call emitted by the model
   * @return the structured answer (left) or the selected tool (right)
   * @throws ModuleException ({@code LLM_ERROR}) if the structured answer cannot be deserialized
   */
  @Override
  protected Either<LLMOutput, ToolSelection> onFunctionCall(ResponseFunctionToolCall toolCall) {
    return super.onFunctionCall(toolCall).reduce(Either::left, selection -> {
      if (PROVIDE_STRUCTURED_ANSWER.equals(selection.getToolId())) {
        try {
          var output = OBJECT_MAPPER.readValue(selection.getInput(), LLMOutput.class);
          output.setInternalReasoning(selection.getInternalReasoning());
          toolCall.id().ifPresent(output::setId);

          return left(output);
        } catch (Exception e) {
          throw new ModuleException("Could not deserialize LLM output from JSON response", LLM_ERROR, e);
        }
      }

      return right(selection);
    });
  }

  /**
   * Extends the base request builder with {@code tool_choice = required}, so every turn must produce a tool call
   * (either a real tool or {@code provide_structured_answer}).
   *
   * @return the configured request builder
   */
  @Override
  protected ResponseCreateParams.Builder newRequestBuilder() {
    return super.newRequestBuilder().toolChoice(REQUIRED);
  }
}
