/*
 * Decompiled with CFR 0.152.
 */
package ai.knowly.langtoch.llm.providers.openai;

import ai.knowly.langtoch.llm.base.chatmodel.BaseChatModel;
import ai.knowly.langtoch.llm.message.AssistantMessage;
import ai.knowly.langtoch.llm.message.BaseChatMessage;
import ai.knowly.langtoch.llm.message.Role;
import ai.knowly.langtoch.llm.message.SystemMessage;
import ai.knowly.langtoch.llm.message.UserMessage;
import ai.knowly.langtoch.llm.providers.openai.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.FluentLogger;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import java.util.List;
import java.util.Locale;
import javax.inject.Inject;

/**
 * {@link BaseChatModel} implementation that sends conversations to the OpenAI chat-completion
 * endpoint through an {@link OpenAiService}.
 *
 * <p>Request settings (model, max tokens, temperature) live in a single mutable
 * {@link ChatCompletionRequest.ChatCompletionRequestBuilder} that every {@code run} call reuses.
 * Each setter therefore affects all later calls, and each {@code run} overwrites the builder's
 * messages. NOTE(review): the builder is shared and unsynchronized, so an instance is presumably
 * not safe to configure or run from several threads at once.
 */
public class OpenAIChat
extends BaseChatModel {
    private static final FluentLogger logger = FluentLogger.forEnclosingClass();
    private static final int DEFAULT_MAX_TOKENS = 2048;
    private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
    private final OpenAiService openAiService;
    // Shared by every run(...) call; the messages are replaced on each request.
    private final ChatCompletionRequest.ChatCompletionRequestBuilder completionRequest =
            ChatCompletionRequest.builder().maxTokens(DEFAULT_MAX_TOKENS).model(DEFAULT_MODEL);

    /**
     * Creates a chat model that uses an externally supplied service (dependency-injection entry point).
     *
     * @param openAiService client used to issue chat-completion requests
     */
    @Inject
    OpenAIChat(OpenAiService openAiService) {
        this.openAiService = openAiService;
    }

    /**
     * Creates a chat model authenticated with the given API key.
     *
     * @param apiKey OpenAI API key; it is passed to {@code Utils.logPartialApiKey} before the
     *     service is created
     */
    public OpenAIChat(String apiKey) {
        Utils.logPartialApiKey(logger, apiKey);
        this.openAiService = new OpenAiService(apiKey);
    }

    /**
     * Creates a chat model whose API key comes from {@code Utils.getApiKeyFromEnv}
     * (presumably an environment variable; see that helper).
     */
    public OpenAIChat() {
        this.openAiService = new OpenAiService(Utils.getApiKeyFromEnv(logger));
    }

    /**
     * Returns the lower-case role name used on the OpenAI wire format for {@code role}.
     * Both request construction and response decoding use this helper, so the two directions
     * always agree.
     */
    private static String toOpenAIRole(Role role) {
        // Locale.ROOT is required: with a Turkish default locale, "ASSISTANT".toLowerCase()
        // yields "assıstant" (dotless i), which the API would not recognise.
        return role.name().toLowerCase(Locale.ROOT);
    }

    /** Converts a library chat message into the OpenAI client's {@link ChatMessage}. */
    private static ChatMessage toChatMessage(BaseChatMessage message) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setContent(message.getMessage());
        chatMessage.setRole(toOpenAIRole(message.getRole()));
        return chatMessage;
    }

    /** Sends {@code chatMessages} using the current request settings and returns the raw result. */
    private ChatCompletionResult complete(List<ChatMessage> chatMessages) {
        return this.openAiService.createChatCompletion(this.completionRequest.messages(chatMessages).build());
    }

    /**
     * Extracts the message of the first choice in {@code completion}.
     *
     * @throws IllegalStateException if the completion contains no choices
     */
    private static ChatMessage firstChoiceMessage(ChatCompletionResult completion) {
        List<ChatCompletionChoice> choices = completion.getChoices();
        if (choices == null || choices.isEmpty()) {
            throw new IllegalStateException("OpenAI chat completion returned no choices");
        }
        return choices.get(0).getMessage();
    }

    /**
     * Sets the maximum number of tokens requested for subsequent completions.
     *
     * @param maxTokens token limit forwarded to the request builder
     * @return this instance, for chaining
     */
    public OpenAIChat setMaxTokens(int maxTokens) {
        this.completionRequest.maxTokens(maxTokens);
        return this;
    }

    /**
     * Sets the model name used for subsequent completions (default {@code "gpt-3.5-turbo"}).
     *
     * @param model OpenAI model identifier
     * @return this instance, for chaining
     */
    public OpenAIChat setModel(String model) {
        this.completionRequest.model(model);
        return this;
    }

    /**
     * Sets the sampling temperature for subsequent completions.
     *
     * @param temperature value forwarded to the request builder
     * @return this instance, for chaining
     */
    public OpenAIChat setTemperature(double temperature) {
        this.completionRequest.temperature(temperature);
        return this;
    }

    /**
     * Sends a conversation and returns the first reply, converted back to a library message type.
     *
     * @param messages conversation history, in order
     * @return a {@link UserMessage}, {@link SystemMessage} or {@link AssistantMessage}, matching the
     *     role reported for the first choice
     * @throws IllegalStateException if the response has no choices or reports an unrecognised role
     */
    @Override
    public BaseChatMessage run(List<BaseChatMessage> messages) {
        ImmutableList<ChatMessage> chatMessages =
                messages.stream().map(OpenAIChat::toChatMessage).collect(ImmutableList.toImmutableList());
        ChatMessage reply = firstChoiceMessage(complete(chatMessages));
        String role = reply.getRole();
        String content = reply.getContent();
        if (toOpenAIRole(Role.USER).equals(role)) {
            return UserMessage.builder().setMessage(content).build();
        }
        if (toOpenAIRole(Role.SYSTEM).equals(role)) {
            return SystemMessage.builder().setMessage(content).build();
        }
        if (toOpenAIRole(Role.ASSISTANT).equals(role)) {
            return AssistantMessage.builder().setMessage(content).build();
        }
        throw new IllegalStateException(String.format("Unknown role %s with message: %s ", role, content));
    }

    /**
     * Sends a single user message and returns the text of the first reply.
     *
     * @param message user prompt
     * @return content of the first choice's message
     * @throws IllegalStateException if the response has no choices
     */
    @Override
    public String run(String message) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setRole(toOpenAIRole(Role.USER));
        chatMessage.setContent(message);
        return firstChoiceMessage(complete(ImmutableList.of(chatMessage))).getContent();
    }
}

