/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.azure.openai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionChoice;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.Delta;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.InternalOpenAiHelper;
import dev.langchain4j.model.openai.OpenAiStreamingResponseBuilder;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
import java.net.Proxy;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.jboss.logging.Logger;

/**
 * Streaming chat model that sends chat-completion requests to an Azure OpenAI endpoint and
 * forwards streamed text tokens to a {@link StreamingResponseHandler}.
 *
 * <p>Registered {@link ChatModelListener}s are notified before the request is sent, after the
 * stream completes, and when the stream fails. Exceptions thrown by a listener are logged and
 * never interrupt the call.
 */
public class AzureOpenAiStreamingChatModel implements StreamingChatLanguageModel, TokenCountEstimator {
    private static final Logger log = Logger.getLogger(AzureOpenAiStreamingChatModel.class);

    /** Timeout applied to call, connect, read and write when none is configured. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60L);
    /** Sampling temperature sent when none is configured. */
    private static final Double DEFAULT_TEMPERATURE = 0.7;

    private final OpenAiClient client;
    private final Double temperature;
    private final Double topP;
    private final Integer maxTokens;
    private final Double presencePenalty;
    private final Double frequencyPenalty;
    private final Tokenizer tokenizer;
    private final ResponseFormat responseFormat;
    private final List<ChatModelListener> listeners;

    /**
     * Creates the model and its underlying HTTP client.
     *
     * @param endpoint Azure OpenAI base URL; must not be blank
     * @param apiVersion API version passed to the client
     * @param apiKey Azure API key (alternative to {@code adToken})
     * @param adToken Azure AD token (alternative to {@code apiKey})
     * @param tokenizer optional tokenizer, required only by {@link #estimateTokenCount(List)}
     * @param temperature sampling temperature; defaults to 0.7 when {@code null}
     * @param topP nucleus-sampling parameter, sent as-is
     * @param maxTokens maximum tokens to generate, sent as-is
     * @param presencePenalty presence penalty, sent as-is
     * @param frequencyPenalty frequency penalty, sent as-is
     * @param timeout call/connect/read/write timeout; defaults to 60 seconds when {@code null}
     * @param proxy optional HTTP proxy
     * @param responseFormat optional {@link ResponseFormatType} name, matched case-insensitively;
     *     {@code null} or blank means "not set"
     * @param logRequests whether the client logs requests
     * @param logResponses whether the client logs streaming responses
     * @param configName Quarkus configuration name for the client
     * @param listeners listeners to notify; {@code null} is treated as no listeners
     * @throws IllegalArgumentException if {@code endpoint} is blank or {@code responseFormat} is
     *     not a known format
     */
    public AzureOpenAiStreamingChatModel(String endpoint, String apiVersion, String apiKey, String adToken,
            Tokenizer tokenizer, Double temperature, Double topP, Integer maxTokens, Double presencePenalty,
            Double frequencyPenalty, Duration timeout, Proxy proxy, String responseFormat, Boolean logRequests,
            Boolean logResponses, String configName, List<ChatModelListener> listeners) {
        // A null list would otherwise cause an NPE on every generate call.
        this.listeners = listeners == null ? Collections.<ChatModelListener> emptyList() : listeners;
        Duration effectiveTimeout = Utils.getOrDefault(timeout, DEFAULT_TIMEOUT);
        // The cast gives access to the Quarkus-specific builder methods (Azure auth, config name).
        this.client = ((QuarkusOpenAiClient.Builder) OpenAiClient.builder()
                .baseUrl(ValidationUtils.ensureNotBlank(endpoint, "endpoint"))
                .apiVersion(apiVersion)
                .callTimeout(effectiveTimeout)
                .connectTimeout(effectiveTimeout)
                .readTimeout(effectiveTimeout)
                .writeTimeout(effectiveTimeout)
                .proxy(proxy)
                .logRequests(logRequests)
                .logStreamingResponses(logResponses))
                .userAgent("langchain4j-quarkus-azure-openai")
                .azureAdToken(adToken)
                .azureApiKey(apiKey)
                .configName(configName)
                .build();
        this.temperature = Utils.getOrDefault(temperature, DEFAULT_TEMPERATURE);
        this.topP = topP;
        this.maxTokens = maxTokens;
        this.presencePenalty = presencePenalty;
        this.frequencyPenalty = frequencyPenalty;
        this.tokenizer = tokenizer;
        this.responseFormat = toResponseFormat(responseFormat);
    }

    /**
     * Parses a user-supplied response-format name into a {@link ResponseFormat}.
     *
     * @return {@code null} when {@code responseFormat} is null or blank
     * @throws IllegalArgumentException when the name does not match a {@link ResponseFormatType}
     */
    private static ResponseFormat toResponseFormat(String responseFormat) {
        if (Utils.isNullOrBlank(responseFormat)) {
            return null;
        }
        String normalized = responseFormat.trim().toUpperCase(Locale.ROOT);
        try {
            return ResponseFormat.builder().type(ResponseFormatType.valueOf(normalized)).build();
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Unsupported response format '" + responseFormat + "'", e);
        }
    }

    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        generate(messages, null, null, handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications,
            StreamingResponseHandler<AiMessage> handler) {
        generate(messages, toolSpecifications, null, handler);
    }

    @Override
    public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification,
            StreamingResponseHandler<AiMessage> handler) {
        generate(messages, null, toolSpecification, handler);
    }

    /**
     * Sends one streaming chat-completion request.
     *
     * <p>Flow: listeners get {@code onRequest}. Each partial response is accumulated and its text
     * forwarded to {@code handler.onNext}. On completion, listeners get {@code onResponse} and
     * {@code handler.onComplete} receives the aggregated message. On failure, listeners get
     * {@code onError} with whatever was accumulated, then {@code handler.onError} is called.
     */
    private void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications,
            ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler<AiMessage> handler) {
        ChatCompletionRequest request = buildRequest(messages, toolSpecifications, toolThatMustBeExecuted);
        ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
        // The same attribute map is shared across request, response and error contexts.
        ConcurrentHashMap<Object, Object> attributes = new ConcurrentHashMap<>();
        ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
        notifyListeners(listener -> listener.onRequest(requestContext));

        OpenAiStreamingResponseBuilder responseBuilder = new OpenAiStreamingResponseBuilder();
        // Written from the streaming callbacks and read on completion or error.
        AtomicReference<String> responseId = new AtomicReference<>();
        AtomicReference<String> responseModel = new AtomicReference<>();

        client.chatCompletion(request)
                .onPartialResponse(partialResponse -> {
                    responseBuilder.append(partialResponse);
                    handle(partialResponse, handler);
                    if (!Utils.isNullOrBlank(partialResponse.id())) {
                        responseId.set(partialResponse.id());
                    }
                    if (!Utils.isNullOrBlank(partialResponse.model())) {
                        responseModel.set(partialResponse.model());
                    }
                })
                .onComplete(() -> {
                    ChatResponse response = responseBuilder.build();
                    ChatModelResponse modelListenerResponse = createModelListenerResponse(
                            responseId.get(), responseModel.get(), response);
                    ChatModelResponseContext responseContext = new ChatModelResponseContext(
                            modelListenerResponse, modelListenerRequest, attributes);
                    notifyListeners(listener -> listener.onResponse(responseContext));
                    Response<AiMessage> aiResponse = Response.from(
                            response.aiMessage(), response.tokenUsage(), response.finishReason());
                    handler.onComplete(aiResponse);
                })
                .onError(error -> {
                    // Report whatever was streamed before the failure as a partial response.
                    ChatResponse response = responseBuilder.build();
                    ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
                            responseId.get(), responseModel.get(), response);
                    ChatModelErrorContext errorContext = new ChatModelErrorContext(
                            error, modelListenerRequest, modelListenerPartialResponse, attributes);
                    notifyListeners(listener -> listener.onError(errorContext));
                    handler.onError(error);
                })
                .execute();
    }

    /**
     * Builds the streaming request from the configured sampling parameters and optional tools.
     * A forced tool takes precedence over the general tool list.
     */
    private ChatCompletionRequest buildRequest(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications,
            ToolSpecification toolThatMustBeExecuted) {
        ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
                .stream(Boolean.TRUE)
                .messages(InternalOpenAiHelper.toOpenAiMessages(messages))
                .temperature(temperature)
                .topP(topP)
                .maxTokens(maxTokens)
                .presencePenalty(presencePenalty)
                .frequencyPenalty(frequencyPenalty)
                .responseFormat(responseFormat);
        if (toolThatMustBeExecuted != null) {
            requestBuilder.functions(InternalOpenAiHelper.toFunctions(Collections.singletonList(toolThatMustBeExecuted)));
            requestBuilder.functionCall(toolThatMustBeExecuted.name());
        } else if (!Utils.isNullOrEmpty(toolSpecifications)) {
            requestBuilder.functions(InternalOpenAiHelper.toFunctions(toolSpecifications));
        }
        return requestBuilder.build();
    }

    /**
     * Runs {@code action} for every listener. Exceptions thrown by a listener are logged so that
     * one faulty listener cannot break the call or the other listeners.
     */
    private void notifyListeners(Consumer<ChatModelListener> action) {
        for (ChatModelListener listener : listeners) {
            try {
                action.accept(listener);
            } catch (Exception e) {
                log.warn("Exception while calling model listener", e);
            }
        }
    }

    /**
     * Forwards the text content of the first choice of a streamed chunk to the handler. Chunks
     * without choices, without a delta, or without content are ignored.
     */
    private static void handle(ChatCompletionResponse partialResponse, StreamingResponseHandler<AiMessage> handler) {
        List<ChatCompletionChoice> choices = partialResponse.choices();
        if (choices == null || choices.isEmpty()) {
            return;
        }
        Delta delta = choices.get(0).delta();
        if (delta == null) {
            return;
        }
        String content = delta.content();
        if (content != null) {
            handler.onNext(content);
        }
    }

    /**
     * Estimates the token count of {@code messages} using the configured tokenizer.
     *
     * @throws IllegalStateException if no tokenizer was configured
     */
    @Override
    public int estimateTokenCount(List<ChatMessage> messages) {
        if (tokenizer == null) {
            throw new IllegalStateException("Cannot estimate token count: no tokenizer was configured");
        }
        return tokenizer.estimateTokenCountInMessages(messages);
    }

    /** Maps the outgoing request to the listener-facing request representation. */
    private ChatModelRequest createModelListenerRequest(ChatCompletionRequest request, List<ChatMessage> messages,
            List<ToolSpecification> toolSpecifications) {
        return ChatModelRequest.builder()
                .model(request.model())
                .temperature(request.temperature())
                .topP(request.topP())
                .maxTokens(request.maxTokens())
                .messages(messages)
                .toolSpecifications(toolSpecifications)
                .build();
    }

    /**
     * Maps the aggregated streamed response to the listener-facing response representation.
     *
     * @return {@code null} when {@code response} is {@code null}
     */
    private ChatModelResponse createModelListenerResponse(String responseId, String responseModel,
            ChatResponse response) {
        if (response == null) {
            return null;
        }
        return ChatModelResponse.builder()
                .id(responseId)
                .model(responseModel)
                .tokenUsage(response.tokenUsage())
                .finishReason(response.finishReason())
                .aiMessage(response.aiMessage())
                .build();
    }

    /** Returns a new builder; every setting is optional except {@code endpoint}. */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link AzureOpenAiStreamingChatModel}; each setter returns {@code this}. */
    public static class Builder {
        private String endpoint;
        private String apiVersion;
        private String apiKey;
        private String adToken;
        private Tokenizer tokenizer;
        private Double temperature;
        private Double topP;
        private Integer maxTokens;
        private Double presencePenalty;
        private Double frequencyPenalty;
        private Duration timeout;
        private Proxy proxy;
        private String responseFormat;
        private Boolean logRequests;
        private Boolean logResponses;
        private String configName;
        private List<ChatModelListener> listeners = Collections.emptyList();

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder apiVersion(String apiVersion) {
            this.apiVersion = apiVersion;
            return this;
        }

        public Builder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        public Builder adToken(String adToken) {
            this.adToken = adToken;
            return this;
        }

        public Builder tokenizer(Tokenizer tokenizer) {
            this.tokenizer = tokenizer;
            return this;
        }

        public Builder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public Builder maxTokens(Integer maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public Builder presencePenalty(Double presencePenalty) {
            this.presencePenalty = presencePenalty;
            return this;
        }

        public Builder frequencyPenalty(Double frequencyPenalty) {
            this.frequencyPenalty = frequencyPenalty;
            return this;
        }

        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public Builder proxy(Proxy proxy) {
            this.proxy = proxy;
            return this;
        }

        public Builder responseFormat(String responseFormat) {
            this.responseFormat = responseFormat;
            return this;
        }

        public Builder logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public Builder logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public Builder configName(String configName) {
            this.configName = configName;
            return this;
        }

        /** Sets the listeners to notify; {@code null} is treated as an empty list by the model. */
        public Builder listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        public AzureOpenAiStreamingChatModel build() {
            return new AzureOpenAiStreamingChatModel(endpoint, apiVersion, apiKey, adToken, tokenizer, temperature,
                    topP, maxTokens, presencePenalty, frequencyPenalty, timeout, proxy, responseFormat, logRequests,
                    logResponses, configName, listeners);
        }
    }
}

