/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.spring.ai.generators;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.bsc.async.AsyncGenerator;
import org.bsc.async.FlowGenerator;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.streaming.StreamingOutput;
import org.reactivestreams.FlowAdapters;
import org.reactivestreams.Publisher;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import reactor.core.publisher.Flux;

public interface StreamingChatGenerator {

    /**
     * Creates a new, unconfigured {@link Builder}.
     *
     * @param <State> the agent state type carried by the emitted outputs
     * @return a fresh builder
     */
    static <State extends AgentState> Builder<State> builder() {
        return new Builder<>();
    }

    /**
     * Adapts a streaming {@link Flux} of {@link ChatResponse} chunks into an {@link AsyncGenerator}.
     *
     * <p>Each chunk that carries an output message is re-emitted as a {@link StreamingOutput}
     * holding the chunk's text. At the same time the chunks are folded into a single merged
     * {@link ChatResponse}. That merged response is passed to {@code mapResult} to produce the
     * generator's result map.
     *
     * @param <State> the agent state type carried by the emitted outputs
     */
    class Builder<State extends AgentState> {
        private Function<ChatResponse, Map<String, Object>> mapResult;
        private String startingNode;
        private State startingState;

        /**
         * Sets the function that turns the merged {@link ChatResponse} into the generator's
         * result map. Required before {@link #build(Flux)}.
         *
         * <p>The function receives {@code null} if the flux emitted no chunk with a non-null
         * result and output message.
         *
         * @param mapResult conversion from the merged response to the result map
         * @return this builder
         */
        public Builder<State> mapResult(Function<ChatResponse, Map<String, Object>> mapResult) {
            this.mapResult = mapResult;
            return this;
        }

        /**
         * Sets the node name attached to every emitted {@link StreamingOutput}.
         *
         * @param node the node name
         * @return this builder
         */
        public Builder<State> startingNode(String node) {
            this.startingNode = node;
            return this;
        }

        /**
         * Sets the state attached to every emitted {@link StreamingOutput}.
         *
         * @param state the agent state
         * @return this builder
         */
        public Builder<State> startingState(State state) {
            this.startingState = state;
            return this;
        }

        /**
         * Builds the generator over the given chat-response stream.
         *
         * <p>Chunks whose {@code getResult()} or {@code getResult().getOutput()} is null are
         * dropped. Every remaining chunk is emitted as a {@link StreamingOutput} and merged into
         * the running response.
         *
         * @param flux the streaming chat responses; must not be null
         * @return a generator of {@link StreamingOutput}s whose result is
         *     {@code mapResult.apply(mergedResponse)}
         * @throws NullPointerException if {@code flux} is null or {@code mapResult} was not set
         */
        public AsyncGenerator<? extends NodeOutput<State>> build(Flux<ChatResponse> flux) {
            Objects.requireNonNull(flux, "flux cannot be null");
            Objects.requireNonNull(this.mapResult, "mapResult cannot be null");

            // Running merge of every accepted chunk. It stays null if no chunk gets through the filter.
            final AtomicReference<ChatResponse> merged = new AtomicReference<>(null);

            final Flux<StreamingOutput<State>> outputs = flux
                    .filter(response -> response.getResult() != null
                            && response.getResult().getOutput() != null)
                    .doOnNext(current -> merged.updateAndGet(previous ->
                            previous == null ? current : this.mergeResponses(previous, current)))
                    .map(chunk -> new StreamingOutput<State>(
                            chunk.getResult().getOutput().getText(),
                            this.startingNode,
                            this.startingState));

            final Flow.Publisher<StreamingOutput<State>> publisher =
                    FlowAdapters.toFlowPublisher(outputs);

            // Reads the accumulated response lazily, when FlowGenerator evaluates the supplier.
            // Presumably this happens after the publisher completes; confirm against FlowGenerator.
            return FlowGenerator.fromPublisher(publisher, () -> this.mapResult.apply(merged.get()));
        }

        /**
         * Merges a newly streamed chunk into the response accumulated so far.
         *
         * <p>Text is concatenated and tool calls are combined via {@link #mergeToolCalls}.
         * Message properties, media, generation metadata and response metadata are taken from
         * {@code current} only.
         *
         * @param last the response accumulated so far
         * @param current the newly received chunk
         * @return a new single-generation response representing both
         */
        private ChatResponse mergeResponses(ChatResponse last, ChatResponse current) {
            final AssistantMessage lastMessage = last.getResult().getOutput();
            final AssistantMessage currentMessage = current.getResult().getOutput();

            // mergeText yields null when neither message has text, e.g. tool-call-only chunks.
            // Fall back to "" rather than failing the whole stream with an NPE.
            final String mergedText = Objects.requireNonNullElse(
                    this.mergeText(lastMessage.getText(), currentMessage.getText()), "");

            final AssistantMessage mergedMessage = AssistantMessage.builder()
                    .content(mergedText)
                    .properties(currentMessage.getMetadata())
                    .toolCalls(this.mergeToolCalls(lastMessage.getToolCalls(), currentMessage.getToolCalls()))
                    .media(currentMessage.getMedia())
                    .build();

            final Generation newGeneration = new Generation(mergedMessage, current.getResult().getMetadata());
            return new ChatResponse(List.of(newGeneration), current.getMetadata());
        }

        /**
         * Concatenates two possibly-null text fragments.
         *
         * @return the non-null fragment if only one is null, {@code null} if both are null,
         *     otherwise {@code lastText + currentText}
         */
        private String mergeText(String lastText, String currentText) {
            if (lastText == null) {
                return currentText;
            }
            if (currentText == null) {
                return lastText;
            }
            return lastText.concat(currentText);
        }

        /**
         * Combines the tool calls seen so far with those in a new chunk, deduplicated by id.
         *
         * <p>Insertion order is preserved. If the same id appears in both lists, the entry from
         * {@code lastToolCalls} is kept. NOTE(review): calls with a {@code null} id all collapse
         * onto one entry; confirm whether the upstream model can emit such ids.
         *
         * @return the combined tool calls; never null
         */
        private List<AssistantMessage.ToolCall> mergeToolCalls(
                List<AssistantMessage.ToolCall> lastToolCalls,
                List<AssistantMessage.ToolCall> currentToolCalls) {
            if (lastToolCalls == null || lastToolCalls.isEmpty()) {
                return currentToolCalls != null ? currentToolCalls : List.of();
            }
            if (currentToolCalls == null || currentToolCalls.isEmpty()) {
                return lastToolCalls;
            }
            final Map<String, AssistantMessage.ToolCall> byId = new LinkedHashMap<>();
            lastToolCalls.forEach(toolCall -> byId.put(toolCall.id(), toolCall));
            currentToolCalls.forEach(toolCall -> byId.putIfAbsent(toolCall.id(), toolCall));
            return List.copyOf(byId.values());
        }
    }
}

