/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.modules.agent.broker.internal.extension.connection.session;

import com.mulesoft.modules.agent.broker.internal.error.BrokerErrorTypes;
import com.mulesoft.modules.agent.broker.internal.extension.connection.openai.OpenAISettings;
import com.mulesoft.modules.agent.broker.internal.extension.connection.session.BaseLLMSession;
import com.mulesoft.modules.agent.broker.internal.llm.LLMRequest;
import com.mulesoft.modules.agent.broker.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.broker.internal.state.model.ToolSelection;
import com.mulesoft.modules.agent.broker.internal.util.ToolUtils;
import com.openai.client.OpenAIClient;
import com.openai.models.responses.FunctionTool;
import com.openai.models.responses.ResponseCreateParams;
import com.openai.models.responses.ResponseFunctionToolCall;
import com.openai.models.responses.ResponseInputItem;
import com.openai.models.responses.ResponseOutputItem;
import com.openai.models.responses.ResponseOutputMessage;
import com.openai.models.responses.Tool;
import com.openai.models.responses.ToolChoiceOptions;
import io.a2a.util.Utils;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.mule.runtime.api.functional.Either;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;

public class FunctionBasedSession
extends BaseLLMSession {
    private static final String PROVIDE_STRUCTURED_ANSWER = "provide_structured_answer";
    private static final Tool STRUCTURED_OUTPUT_FUNCTION = Tool.ofFunction((FunctionTool)FunctionTool.builder().name("provide_structured_answer").description("Provides an answer to the user's question, either when the goal is completed, has failed or additional input that cannot be obtained through any other function is required. This ensures the output is always a JSON structure").strict(true).parameters(FunctionTool.Parameters.builder().putAllAdditionalProperties(ToolUtils.openAiLLMOutputSchema()).build()).build());

    private static List<Tool> enhanceTools(List<Tool> tools) {
        ArrayList<Tool> enhanced = new ArrayList<Tool>(tools.size() + 1);
        enhanced.addAll(tools);
        enhanced.add(STRUCTURED_OUTPUT_FUNCTION);
        return enhanced;
    }

    public FunctionBasedSession(OpenAIClient client, LLMRequest llmRequest, OpenAISettings settings, List<Tool> tools) {
        super(client, llmRequest, settings, FunctionBasedSession.enhanceTools(tools));
    }

    @Override
    protected CompletableFuture<Either<LLMOutput, ToolSelection>> doGetNext() {
        return this.client.async().responses().create(this.builder.build()).thenApply(response -> {
            for (ResponseOutputItem item : response.output()) {
                if (item.isReasoning()) {
                    this.onReasoningItem(item.asReasoning());
                    continue;
                }
                if (item.isFunctionCall()) {
                    return this.onFunctionCall(item.asFunctionCall());
                }
                if (item.isCustomToolCall()) {
                    return this.onCustomCall(item.asCustomToolCall());
                }
                if (!item.isMessage()) continue;
                ResponseOutputMessage message = item.asMessage();
                this.inputs.add(ResponseInputItem.ofResponseOutputMessage((ResponseOutputMessage)message));
                throw new ModuleException("LLM returned a message but a function call was expected", (ErrorTypeDefinition)BrokerErrorTypes.LLM_ERROR);
            }
            throw new ModuleException("LLM didn't provide a valid response", (ErrorTypeDefinition)BrokerErrorTypes.LLM_ERROR);
        });
    }

    @Override
    protected Either<LLMOutput, ToolSelection> onFunctionCall(ResponseFunctionToolCall toolCall) {
        return (Either)super.onFunctionCall(toolCall).reduce(Either::left, selection -> {
            if (PROVIDE_STRUCTURED_ANSWER.equals(selection.getToolId())) {
                try {
                    LLMOutput output = (LLMOutput)Utils.OBJECT_MAPPER.readValue(selection.getInput(), LLMOutput.class);
                    output.setInternalReasoning(selection.internalReasoning());
                    return Either.left((Object)output);
                }
                catch (Exception e) {
                    throw new ModuleException("Could not deserialize LLM output from JSON response", (ErrorTypeDefinition)BrokerErrorTypes.LLM_ERROR, (Throwable)e);
                }
            }
            return Either.right((Object)selection);
        });
    }

    @Override
    protected ResponseCreateParams.Builder newRequestBuilder() {
        return super.newRequestBuilder().toolChoice(ToolChoiceOptions.REQUIRED);
    }
}

