/*
 * Decompiled with CFR 0.152.
 */
package com.google.adk.flows.llmflows;

import com.google.adk.Telemetry;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.Part;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.MaybeSource;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

/**
 * Static helpers for executing the function (tool) calls contained in model response events and for
 * building the function-response events that are sent back to the model.
 */
public final class Functions {
    /** Prefix of function-call ids generated on the client side. */
    private static final String AF_FUNCTION_CALL_ID_PREFIX = "adk-";

    /**
     * Generates a new client-side function-call id.
     *
     * @return {@code "adk-"} followed by a random UUID
     */
    public static String generateClientFunctionCallId() {
        return AF_FUNCTION_CALL_ID_PREFIX + UUID.randomUUID();
    }

    /**
     * Assigns a generated client-side id to every function call in the event's content that has no id.
     *
     * <p>The event is mutated in place (via {@code setContent}) only when at least one id was added.
     * Parts that are not function calls, or that already carry an id, are kept unchanged.
     *
     * @param modelResponseEvent the model response event to update
     * @throws IllegalStateException if an id was added but the content has no role
     */
    public static void populateClientFunctionCallId(Event modelResponseEvent) {
        Optional<Content> originalContentOptional = modelResponseEvent.content();
        if (originalContentOptional.isEmpty()) {
            return;
        }
        Content originalContent = originalContentOptional.get();
        List<Part> originalParts = originalContent.parts().orElse(ImmutableList.of());
        if (originalParts.stream().noneMatch(part -> part.functionCall().isPresent())) {
            return;
        }
        List<Part> newParts = new ArrayList<>(originalParts.size());
        boolean modified = false;
        for (Part part : originalParts) {
            Optional<FunctionCall> functionCall = part.functionCall();
            if (functionCall.isPresent() && functionCall.get().id().isEmpty()) {
                FunctionCall updatedFunctionCall =
                        functionCall.get().toBuilder().id(generateClientFunctionCallId()).build();
                // Rebuild from the original part so that any other fields it carries are preserved.
                newParts.add(part.toBuilder().functionCall(updatedFunctionCall).build());
                modified = true;
            } else {
                newParts.add(part);
            }
        }
        if (modified) {
            String role = originalContent.role().orElseThrow(() -> new IllegalStateException("Content role is missing in event: " + modelResponseEvent.id()));
            Content newContent = Content.builder().role(role).parts(newParts).build();
            modelResponseEvent.setContent(Optional.of(newContent));
        }
    }

    /**
     * Executes every function call in {@code functionCallEvent} and merges the resulting responses.
     *
     * <p>For each call the agent's before-tool callback (if any) runs first; when it yields no result the
     * tool itself is invoked. The after-tool callback (if any) may then replace the result. A long-running
     * tool that ends up with no result produces no response. The per-call work is subscribed to eagerly,
     * but the merged response keeps the parts in the same order as the function calls.
     *
     * @param invocationContext the current invocation context
     * @param functionCallEvent the event containing the function calls to execute
     * @param tools the available tools, keyed by name
     * @return the (possibly merged) function-response event, or empty when no call produced a response
     * @throws RuntimeException if a function call has no name or names a tool that is not in {@code tools}
     */
    public static Maybe<Event> handleFunctionCalls(InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
        ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
        List<Maybe<Event>> functionResponseEvents = new ArrayList<>(functionCalls.size());
        for (FunctionCall functionCall : functionCalls) {
            String toolName = functionCall.name().orElseThrow(() -> new RuntimeException("Function call is missing a tool name"));
            BaseTool tool = tools.get(toolName);
            if (tool == null) {
                throw new RuntimeException("Tool not found: " + toolName);
            }
            ToolContext toolContext = ToolContext.builder(invocationContext).functionCallId(functionCall.id().orElse("")).build();
            Map<String, Object> functionArgs = functionCall.args().orElse(new HashMap<>());
            functionResponseEvents.add(invokeToolWithCallbacks(invocationContext, tool, functionArgs, toolContext));
        }
        // concatEager subscribes to all sources up front (parallel execution) but emits in source order,
        // so the merged response lists the parts in the same order as the function calls.
        return Maybe.concatEager(functionResponseEvents).toList().flatMapMaybe(events -> {
            Event mergedEvent = mergeParallelFunctionResponseEvents(events);
            if (mergedEvent == null) {
                return Maybe.<Event>empty();
            }
            if (events.size() > 1) {
                traceMergedToolResponse(invocationContext, mergedEvent);
            }
            return Maybe.just(mergedEvent);
        });
    }

    /**
     * Collects the ids of the function calls that target long-running tools.
     *
     * <p>Calls without a name, or whose name is not in {@code tools}, are ignored. A matching call without
     * an id contributes the empty string.
     *
     * @param functionCalls the function calls to inspect
     * @param tools the available tools, keyed by name
     * @return a mutable set of function-call ids
     */
    public static Set<String> getLongRunningFunctionCalls(List<FunctionCall> functionCalls, Map<String, BaseTool> tools) {
        Set<String> longRunningFunctionCalls = new HashSet<>();
        for (FunctionCall functionCall : functionCalls) {
            Optional<BaseTool> tool = functionCall.name().map(tools::get);
            if (tool.isPresent() && tool.get().longRunning()) {
                longRunningFunctionCalls.add(functionCall.id().orElse(""));
            }
        }
        return longRunningFunctionCalls;
    }

    /**
     * Runs one tool call through the before/after tool callbacks and builds its response event.
     *
     * @return the response event, or empty for a long-running tool that produced no result
     */
    private static Maybe<Event> invokeToolWithCallbacks(InvocationContext invocationContext, BaseTool tool, Map<String, Object> functionArgs, ToolContext toolContext) {
        // The tool itself only runs when the before-tool callback yields no result.
        Maybe<Map<String, Object>> maybeFunctionResult = maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext)
                .switchIfEmpty(Maybe.defer(() -> callTool(tool, functionArgs, toolContext)));
        return maybeFunctionResult
                .map(Optional::of)
                .defaultIfEmpty(Optional.empty())
                .flatMapMaybe(optionalInitialResult -> {
                    Map<String, Object> initialFunctionResult = optionalInitialResult.orElse(null);
                    Maybe<Map<String, Object>> afterToolResultMaybe =
                            maybeInvokeAfterToolCall(invocationContext, tool, functionArgs, toolContext, initialFunctionResult);
                    // An empty after-tool callback result keeps the initial result.
                    return afterToolResultMaybe
                            .map(Optional::of)
                            .defaultIfEmpty(Optional.ofNullable(initialFunctionResult))
                            .flatMapMaybe(finalOptionalResult -> {
                                Map<String, Object> finalFunctionResult = finalOptionalResult.orElse(null);
                                if (tool.longRunning() && finalFunctionResult == null) {
                                    return Maybe.<Event>empty();
                                }
                                return Maybe.just(buildResponseEvent(tool, finalFunctionResult, toolContext, invocationContext));
                            });
                });
    }

    /** Records a {@code "tool_response"} span for a response event merged from several tool calls. */
    private static void traceMergedToolResponse(InvocationContext invocationContext, Event mergedEvent) {
        Span mergedSpan = Telemetry.getTracer().spanBuilder("tool_response").startSpan();
        try (Scope scope = mergedSpan.makeCurrent()) {
            Telemetry.traceToolResponse(invocationContext, mergedEvent.id(), mergedEvent);
        } finally {
            mergedSpan.end();
        }
    }

    /**
     * Merges the response events of parallel function calls into a single event.
     *
     * <p>A single event is returned as-is. Otherwise a new event is built with:
     * <ul>
     *   <li>a fresh id;</li>
     *   <li>the invocation id, author, branch and timestamp of the first event;</li>
     *   <li>all parts concatenated in list order under role {@code "user"};</li>
     *   <li>all event actions merged.</li>
     * </ul>
     *
     * @return the merged event, or {@code null} when {@code functionResponseEvents} is empty
     */
    private static Event mergeParallelFunctionResponseEvents(List<Event> functionResponseEvents) {
        if (functionResponseEvents.isEmpty()) {
            return null;
        }
        if (functionResponseEvents.size() == 1) {
            return functionResponseEvents.get(0);
        }
        Event baseEvent = functionResponseEvents.get(0);
        List<Part> mergedParts = new ArrayList<>();
        EventActions.Builder mergedActionsBuilder = EventActions.builder();
        for (Event event : functionResponseEvents) {
            event.content().flatMap(Content::parts).ifPresent(mergedParts::addAll);
            mergedActionsBuilder.merge(event.actions());
        }
        return Event.builder()
                .id(Event.generateEventId())
                .invocationId(baseEvent.invocationId())
                .author(baseEvent.author())
                .branch(baseEvent.branch())
                .content(Optional.of(Content.builder().role("user").parts(mergedParts).build()))
                .actions(mergedActionsBuilder.build())
                .timestamp(baseEvent.timestamp())
                .build();
    }

    /**
     * Invokes the agent's before-tool callback when the agent is an {@link LlmAgent} that has one.
     *
     * @return the callback's result, or empty when there is no callback to run
     */
    private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(InvocationContext invocationContext, BaseTool tool, Map<String, Object> functionArgs, ToolContext toolContext) {
        if (invocationContext.agent() instanceof LlmAgent) {
            LlmAgent agent = (LlmAgent) invocationContext.agent();
            return agent.beforeToolCallback().map(callback -> callback.call(invocationContext, tool, functionArgs, toolContext)).orElse(Maybe.empty());
        }
        return Maybe.empty();
    }

    /**
     * Invokes the agent's after-tool callback when the agent is an {@link LlmAgent} that has one.
     *
     * @param functionResult the result produced so far; may be {@code null}
     * @return the callback's result, or empty when there is no callback to run
     */
    private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(InvocationContext invocationContext, BaseTool tool, Map<String, Object> functionArgs, ToolContext toolContext, Map<String, Object> functionResult) {
        if (invocationContext.agent() instanceof LlmAgent) {
            LlmAgent agent = (LlmAgent) invocationContext.agent();
            return agent.afterToolCallback().map(callback -> callback.call(invocationContext, tool, functionArgs, toolContext, functionResult)).orElse(Maybe.empty());
        }
        return Maybe.empty();
    }

    /**
     * Lazily runs the tool inside a {@code "tool_call [<name>]"} span.
     *
     * <p>The span is ended when the returned {@link Maybe} terminates or is disposed, and records any
     * error it emits. A {@link RuntimeException} thrown while starting the tool is recorded on the span
     * and surfaced as an error {@code Maybe} wrapping it.
     */
    private static Maybe<Map<String, Object>> callTool(BaseTool tool, Map<String, Object> args, ToolContext toolContext) {
        Tracer tracer = Telemetry.getTracer();
        return Maybe.defer(() -> {
            Span span = tracer.spanBuilder("tool_call [" + tool.name() + "]").startSpan();
            try (Scope scope = span.makeCurrent()) {
                Telemetry.traceToolCall(args);
                return tool.runAsync(args, toolContext)
                        .toMaybe()
                        .doOnError(span::recordException)
                        .doFinally(span::end);
            } catch (RuntimeException e) {
                span.recordException(e);
                span.end();
                return Maybe.error(new RuntimeException("Failed to call tool: " + tool.name(), e));
            }
        });
    }

    /**
     * Builds the function-response event for a single tool call inside a {@code "tool_response [<name>]"}
     * span.
     *
     * @param response the tool result; {@code null} is sent as an empty map
     * @return an event authored by the current agent with one {@code "user"} function-response part that
     *     carries the tool context's event actions
     */
    private static Event buildResponseEvent(BaseTool tool, Map<String, Object> response, ToolContext toolContext, InvocationContext invocationContext) {
        Span span = Telemetry.getTracer().spanBuilder("tool_response [" + tool.name() + "]").startSpan();
        try (Scope scope = span.makeCurrent()) {
            Map<String, Object> responsePayload = response == null ? new HashMap<>() : response;
            FunctionResponse functionResponse = FunctionResponse.builder()
                    .id(toolContext.functionCallId().orElse(""))
                    .name(tool.name())
                    .response(responsePayload)
                    .build();
            Part partFunctionResponse = Part.builder().functionResponse(functionResponse).build();
            Event event = Event.builder()
                    .id(Event.generateEventId())
                    .invocationId(invocationContext.invocationId())
                    .author(invocationContext.agent().name())
                    .branch(invocationContext.branch())
                    .content(Optional.of(Content.builder().role("user").parts(Collections.singletonList(partFunctionResponse)).build()))
                    .actions(toolContext.eventActions())
                    .build();
            Telemetry.traceToolResponse(invocationContext, event.id(), event);
            return event;
        } finally {
            span.end();
        }
    }

    private Functions() {
    }
}

