/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.agentexecutor;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncCommandAction;
import org.bsc.langgraph4j.action.AsyncNodeActionWithConfig;
import org.bsc.langgraph4j.action.Command;
import org.bsc.langgraph4j.action.InterruptionMetadata;
import org.bsc.langgraph4j.agent.AgentEx;
import org.bsc.langgraph4j.agentexecutor.AgentExecutorBuilder;
import org.bsc.langgraph4j.agentexecutor.CallModel;
import org.bsc.langgraph4j.langchain4j.serializer.jackson.LC4jJacksonStateSerializer;
import org.bsc.langgraph4j.langchain4j.serializer.std.LC4jStateSerializer;
import org.bsc.langgraph4j.langchain4j.tool.LC4jToolService;
import org.bsc.langgraph4j.prebuilt.MessagesState;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.state.Channels;
import org.bsc.langgraph4j.utils.CollectionsUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface AgentExecutorEx {
    /** Logger used by the node actions declared in this interface (trace-level entry logging). */
    public static final Logger log = LoggerFactory.getLogger(AgentExecutorEx.class);

    /**
     * Creates the node action that executes one tool.
     *
     * <p>The node takes the last message of the state. It must be an AI message carrying tool
     * execution requests. The node keeps only the requests whose name equals {@code actionName},
     * runs each through {@code toolService}, and writes the produced results under
     * {@code "tool_execution_results"}.
     *
     * <p>The method name (capital {@code L}) is kept as-is for API compatibility.
     *
     * @param toolService service used to execute the tool requests
     * @param actionName  name of the tool whose requests this node executes
     * @return the asynchronous node action
     * @throws IllegalArgumentException (from the node, when run) if the last message is not an
     *                                  AI message with tool execution requests
     */
    public static AsyncNodeActionWithConfig<State> executeTooL(LC4jToolService toolService, String actionName) {
        return AsyncNodeActionWithConfig.node_async((state, config) -> {
            log.trace("ExecuteTool");
            List<ToolExecutionRequest> toolExecutionRequests = state.lastMessage()
                    .filter(m -> ChatMessageType.AI == m.type())
                    .map(AiMessage.class::cast)
                    .filter(AiMessage::hasToolExecutionRequests)
                    .map(AiMessage::toolExecutionRequests)
                    .map(requests -> requests.stream()
                            .filter(req -> Objects.equals(req.name(), actionName))
                            .toList())
                    .orElseThrow(() -> new IllegalArgumentException("no tool execution request found!"));
            // NOTE(review): requests for which the service yields an empty Optional are dropped
            // silently. dispatchTools matches results by tool name, so it could route back to this
            // node again. Confirm execute() always yields a result for a registered tool.
            List<ToolExecutionResultMessage> results = toolExecutionRequests.stream()
                    .map(request -> toolService.execute(request))
                    .flatMap(Optional::stream)
                    .toList();
            return Map.of("tool_execution_results", results);
        });
    }

    /**
     * Creates the node that decides which tool, if any, must run next for the latest AI message.
     *
     * <p>Tool requests are matched against the collected {@code "tool_execution_results"} by
     * tool name. The node produces one of three outcomes:
     * <ul>
     *   <li>the last message is not an AI message with tool requests: writes
     *       {@code "agent_response"} with an explanatory text;</li>
     *   <li>some requested tool has no result yet: writes {@code "next_action"} with the first
     *       such tool name, prefixed with {@code "approval_"} when the name is in
     *       {@code approvals};</li>
     *   <li>every requested tool has a result: writes the collected results under
     *       {@code "messages"}, marks {@code "tool_execution_results"} for reset and
     *       {@code "next_action"} for removal.</li>
     * </ul>
     *
     * @param approvals names of the tools that must go through an approval node first
     * @return the asynchronous node action
     */
    private static AsyncNodeActionWithConfig<State> dispatchTools(Set<String> approvals) {
        return AsyncNodeActionWithConfig.node_async((state, config) -> {
            log.trace("DispatchTools");
            Optional<List<ToolExecutionRequest>> toolExecutionRequests = state.lastMessage()
                    .filter(m -> ChatMessageType.AI == m.type())
                    .map(AiMessage.class::cast)
                    .filter(AiMessage::hasToolExecutionRequests)
                    .map(AiMessage::toolExecutionRequests);
            if (toolExecutionRequests.isEmpty()) {
                return Map.of(State.FINAL_RESPONSE, "no tool execution request found!");
            }
            // Loop-invariant: the state does not change while this node runs.
            List<ToolExecutionResultMessage> executed = state.toolExecutionResults();
            Optional<String> pendingTool = toolExecutionRequests.get().stream()
                    .map(ToolExecutionRequest::name)
                    .filter(name -> executed.stream().noneMatch(r -> Objects.equals(r.toolName(), name)))
                    .findFirst();
            if (pendingTool.isPresent()) {
                String toolName = pendingTool.get();
                String actionId = approvals.contains(toolName)
                        ? String.format("approval_%s", toolName)
                        : toolName;
                return Map.of("next_action", actionId);
            }
            // All requested tools ran: publish their results and clear the per-turn bookkeeping.
            return CollectionsUtils.mapOf(
                    "messages", executed,
                    "tool_execution_results", AgentState.MARK_FOR_RESET,
                    "next_action", AgentState.MARK_FOR_REMOVAL);
        });
    }

    /**
     * Creates the conditional edge evaluated for an approval node.
     *
     * <p>The edge reads the {@code "approval_result"} state value:
     * <ul>
     *   <li>equal to {@code APPROVED}: completes with a {@link Command} labelled
     *       {@code APPROVED} that removes {@code "approval_result"};</li>
     *   <li>any other value: answers every request for the gated tool with an
     *       {@code "execution has been DENIED!"} {@link ToolExecutionResultMessage}. It
     *       completes with a {@link Command} labelled with that value, which writes the denial
     *       messages under {@code "messages"} and removes {@code "approval_result"}.</li>
     * </ul>
     * The gated tool name is {@code "next_action"} with its leading {@code "approval_"}
     * prefix removed.
     *
     * <p>Every failure is reported by completing the returned future exceptionally with an
     * {@link IllegalStateException}. Failures are a missing {@code "approval_result"}, a missing
     * {@code "next_action"}, or no matching tool request.
     *
     * @return the asynchronous command action
     */
    private static AsyncCommandAction<State> approvalAction() {
        return (state, config) -> {
            CompletableFuture<Command> result = new CompletableFuture<>();
            Optional<String> approvalResult = state.value("approval_result");
            if (approvalResult.isEmpty()) {
                result.completeExceptionally(new IllegalStateException(String.format("resume property '%s' not found!", "approval_result")));
                return result;
            }
            String resumeState = approvalResult.get();
            if (Objects.equals(resumeState, AgentEx.ApprovalState.APPROVED.name())) {
                result.complete(new Command(resumeState, Map.of("approval_result", AgentState.MARK_FOR_REMOVAL)));
                return result;
            }
            // Strip only the leading "approval_" prefix added by dispatchTools. String.replace
            // would also delete occurrences inside the tool name itself.
            final String approvalPrefix = "approval_";
            Optional<String> actionName = state.nextAction()
                    .map(next -> next.startsWith(approvalPrefix) ? next.substring(approvalPrefix.length()) : next);
            if (actionName.isEmpty()) {
                result.completeExceptionally(new IllegalStateException("no next action found!"));
                return result;
            }
            List<ToolExecutionRequest> tools = state.toolExecutionRequestsByName(actionName.get());
            if (tools.isEmpty()) {
                result.completeExceptionally(new IllegalStateException("no tool execution request found!"));
                return result;
            }
            List<ToolExecutionResultMessage> toolResponses = tools.stream()
                    .map(toolRequest -> ToolExecutionResultMessage.from(toolRequest, "execution has been DENIED!"))
                    .toList();
            // NOTE(review): this appends a plain String to the "tool_execution_results" channel.
            // That channel's readers (State.toolExecutionResults(), dispatchTools) expect
            // ToolExecutionResultMessage elements, so a later dispatchTools pass could throw
            // ClassCastException unless the channel is reset first. Left unchanged because the
            // right value depends on where AgentEx routes non-APPROVED results. Confirm and fix.
            result.complete(new Command(resumeState, Map.of("messages", toolResponses, "tool_execution_results", "execution has been DENIED!", "approval_result", AgentState.MARK_FOR_REMOVAL)));
            return result;
        };
    }

    /**
     * Creates the "should continue" conditional edge.
     *
     * @return an edge yielding {@code "end"} when {@link State#finalResponse()} holds a value,
     *         and {@code "continue"} otherwise
     */
    private static AsyncCommandAction<State> shouldContinue() {
        return AsyncCommandAction.command_async((state, config) -> {
            final String label = state.finalResponse()
                    .map(response -> "end")
                    .orElse("continue");
            return new Command(label);
        });
    }

    /**
     * Creates the edge that follows the dispatcher's decision.
     *
     * @return an edge yielding the current {@code "next_action"} value, or {@code "model"}
     *         when no next action is set
     */
    private static AsyncCommandAction<State> dispatchAction() {
        return AsyncCommandAction.command_async((state, config) -> {
            final String target = state.nextAction().orElse("model");
            return new Command(target);
        });
    }

    /**
     * Starts configuring a new agent-executor graph.
     *
     * @return a fresh, unconfigured {@link Builder}
     */
    public static Builder builder() {
        final Builder freshBuilder = new Builder();
        return freshBuilder;
    }

    /**
     * Builder for the agent-executor {@link StateGraph}.
     *
     * <p>Model, tool and serializer settings come from {@link AgentExecutorBuilder}. This
     * subclass adds optional approval gates on individual tools.
     */
    public static class Builder
    extends AgentExecutorBuilder<State, Builder> {
        // Approval node per gated action id, kept in registration order.
        private final Map<String, AgentEx.ApprovalNodeAction<ChatMessage, State>> approvals = new LinkedHashMap<String, AgentEx.ApprovalNodeAction<ChatMessage, State>>();

        /**
         * Routes the given action through an approval node before it executes.
         *
         * <p>Calling this again with the same {@code actionId} replaces the previous registration.
         *
         * @param actionId                     name of the tool that requires approval
         * @param interruptionMetadataProvider provider of the interruption metadata for the
         *                                     approval node
         * @return this builder
         */
        public Builder approvalOn(String actionId, BiFunction<String, State, InterruptionMetadata<State>> interruptionMetadataProvider) {
            AgentEx.ApprovalNodeAction action = AgentEx.ApprovalNodeAction.builder().interruptionMetadataProvider(interruptionMetadataProvider).build();
            this.approvals.put(actionId, (AgentEx.ApprovalNodeAction<ChatMessage, State>)action);
            return this;
        }

        /**
         * Assembles the agent-executor graph from the current configuration.
         *
         * <p>If no state serializer was configured, {@link Serializers#STD} is stored on this
         * builder and used.
         *
         * @return the configured state graph
         * @throws IllegalArgumentException if both, or neither, of the chat model and the
         *                                  streaming chat model are set
         * @throws GraphStateException      propagated from {@code AgentEx} graph assembly
         */
        public StateGraph<State> build() throws GraphStateException {
            if (this.streamingChatModel != null && this.chatModel != null) {
                throw new IllegalArgumentException("chatLanguageModel and streamingChatLanguageModel are mutually exclusive!");
            }
            if (this.streamingChatModel == null && this.chatModel == null) {
                throw new IllegalArgumentException("a chatLanguageModel or streamingChatLanguageModel is required!");
            }
            if (this.stateSerializer == null) {
                this.stateSerializer = Serializers.STD.object();
            }
            Map tools = this.toolMap();
            LC4jToolService toolService = new LC4jToolService(tools);
            // Snapshot the approval ids. Passing the live keySet() view would let later
            // approvalOn() calls on this builder change routing inside an already-built graph,
            // which has no node for the newly added approval.
            Set<String> approvalIds = new LinkedHashSet<>(this.approvals.keySet());
            return AgentEx.builder()
                    .stateSerializer(this.stateSerializer)
                    .schema(State.SCHEMA)
                    .toolName(ToolSpecification::name)
                    .callModelAction(new CallModel<State>(this))
                    .dispatchToolsAction(AgentExecutorEx.dispatchTools(approvalIds))
                    .executeToolFactory(toolName -> AgentExecutorEx.executeTooL(toolService, toolName))
                    .shouldContinueEdge(AgentExecutorEx.shouldContinue())
                    .approvalActionEdge(AgentExecutorEx.approvalAction())
                    .dispatchActionEdge(AgentExecutorEx.dispatchAction())
                    .build(tools.keySet(), this.approvals);
        }
    }

    /**
     * Graph state for the agent executor.
     *
     * <p>Extends {@link MessagesState} with an appender channel named
     * {@code "tool_execution_results"}. Also provides typed accessors for the keys read by the
     * executor's node and edge actions.
     */
    public static class State
    extends MessagesState<ChatMessage> {
        // MessagesState schema plus an appender channel (initialised with an ArrayList) for tool results.
        static final Map<String, Channel<?>> SCHEMA = CollectionsUtils.mergeMap((Map)MessagesState.SCHEMA, Map.of("tool_execution_results", Channels.appender(ArrayList::new)));
        /** State key under which the agent's final response is stored. */
        public static final String FINAL_RESPONSE = "agent_response";

        /**
         * Creates a state from initial data.
         *
         * @param initData initial key/value pairs
         */
        public State(Map<String, Object> initData) {
            super(initData);
        }

        /**
         * Returns the tool results collected so far for the current turn.
         *
         * @return the value of {@code "tool_execution_results"}
         * @throws RuntimeException if {@code "tool_execution_results"} is absent from the state
         */
        public List<ToolExecutionResultMessage> toolExecutionResults() {
            return this.<List<ToolExecutionResultMessage>>value("tool_execution_results")
                    .orElseThrow(() -> new RuntimeException("tool_execution_results not found"));
        }

        /**
         * Returns the next action id chosen by the dispatcher, if any.
         *
         * @return the value of {@code "next_action"}
         */
        public Optional<String> nextAction() {
            return this.value("next_action");
        }

        /**
         * Returns the agent's final response, if one has been set.
         *
         * @return the value of {@link #FINAL_RESPONSE}
         */
        public Optional<String> finalResponse() {
            return this.value(FINAL_RESPONSE);
        }

        /**
         * Returns the tool requests named {@code actionName} from the last message.
         *
         * @param actionName tool name to match
         * @return the matching requests. The list is empty when the last message is not an AI
         *         message with tool execution requests.
         */
        private List<ToolExecutionRequest> toolExecutionRequestsByName(String actionName) {
            return this.lastMessage().filter(m -> ChatMessageType.AI == m.type()).map(AiMessage.class::cast).filter(AiMessage::hasToolExecutionRequests).map(AiMessage::toolExecutionRequests).map(requests -> requests.stream().filter(req -> Objects.equals(req.name(), actionName)).toList()).orElseGet(List::of);
        }
    }

    /**
     * Ready-made {@link StateSerializer} instances for {@link State}.
     */
    public static enum Serializers {
        /** Serializer backed by {@code LC4jStateSerializer}. */
        STD((StateSerializer<State>)new LC4jStateSerializer(State::new)),
        /** Serializer backed by {@code LC4jJacksonStateSerializer}. */
        JSON((StateSerializer<State>)new LC4jJacksonStateSerializer(State::new));

        // The serializer instance shared by every caller of object().
        private final StateSerializer<State> delegate;

        Serializers(StateSerializer<State> delegate) {
            this.delegate = delegate;
        }

        /**
         * Returns the serializer bound to this constant.
         *
         * @return the shared serializer instance
         */
        public StateSerializer<State> object() {
            return delegate;
        }
    }
}

