/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.model.chatglm;

import dev.langchain4j.community.model.chatglm.ChatCompletionRequest;
import dev.langchain4j.community.model.chatglm.ChatCompletionResponse;
import dev.langchain4j.community.model.chatglm.ChatGlmClient;
import dev.langchain4j.community.model.chatglm.spi.ChatGlmChatModelBuilderFactory;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.spi.ServiceHelper;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

public class ChatGlmChatModel
implements ChatLanguageModel {
    private final ChatGlmClient client;
    private final Double temperature;
    private final Double topP;
    private final Integer maxLength;
    private final Integer maxRetries;

    public ChatGlmChatModel(String baseUrl, Duration timeout, Double temperature, Integer maxRetries, Double topP, Integer maxLength, boolean logRequests, boolean logResponses) {
        baseUrl = (String)ValidationUtils.ensureNotNull((Object)baseUrl, (String)"baseUrl");
        timeout = (Duration)Utils.getOrDefault((Object)timeout, (Object)Duration.ofSeconds(60L));
        this.temperature = (Double)Utils.getOrDefault((Object)temperature, (Object)0.7);
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)3);
        this.topP = topP;
        this.maxLength = maxLength;
        this.client = ChatGlmClient.builder().baseUrl(baseUrl).timeout(timeout).logRequests(logRequests).logResponses(logResponses).build();
    }

    public Response<AiMessage> generate(List<ChatMessage> messages) {
        ChatMessage lastMessage = messages.get(messages.size() - 1);
        if (!(lastMessage instanceof UserMessage)) {
            throw new RuntimeException("Last message must be UserMessage, but is: " + String.valueOf(lastMessage.type()));
        }
        UserMessage userMessage = (UserMessage)lastMessage;
        String prompt = userMessage.singleText();
        List<List<String>> history = this.toHistory(messages.subList(0, messages.size() - 1));
        ChatCompletionRequest request = ChatCompletionRequest.builder().prompt(prompt).temperature(this.temperature).topP(this.topP).maxLength(this.maxLength).history(history).build();
        ChatCompletionResponse response = (ChatCompletionResponse)RetryUtils.withRetry(() -> this.client.chatCompletion(request), (int)this.maxRetries);
        return Response.from((Object)AiMessage.from((String)response.getResponse()));
    }

    private List<List<String>> toHistory(List<ChatMessage> historyMessages) {
        if (this.containsSystemMessage(historyMessages)) {
            throw new IllegalArgumentException("ChatGLM does not support system prompt");
        }
        if (historyMessages.size() % 2 != 0) {
            throw new IllegalArgumentException("History must be divisible by 2 because it's order User - AI - User - AI ...");
        }
        ArrayList<List<String>> history = new ArrayList<List<String>>();
        for (int i = 0; i < historyMessages.size() / 2; ++i) {
            history.add(historyMessages.subList(i * 2, i * 2 + 2).stream().map(chatMessage -> {
                if (chatMessage instanceof UserMessage) {
                    UserMessage userMessage = (UserMessage)chatMessage;
                    return userMessage.singleText();
                }
                if (chatMessage instanceof AiMessage) {
                    AiMessage aiMessage = (AiMessage)chatMessage;
                    return aiMessage.text();
                }
                if (chatMessage instanceof SystemMessage) {
                    SystemMessage systemMessage = (SystemMessage)chatMessage;
                    return systemMessage.text();
                }
                throw new RuntimeException("Unexpected message type: " + String.valueOf(chatMessage.getClass()));
            }).collect(Collectors.toList()));
        }
        return history;
    }

    private boolean containsSystemMessage(List<ChatMessage> messages) {
        return messages.stream().anyMatch(message -> message.type() == ChatMessageType.SYSTEM);
    }

    public static ChatGlmChatModelBuilder builder() {
        Iterator iterator = ServiceHelper.loadFactories(ChatGlmChatModelBuilderFactory.class).iterator();
        if (iterator.hasNext()) {
            ChatGlmChatModelBuilderFactory factory = (ChatGlmChatModelBuilderFactory)iterator.next();
            return (ChatGlmChatModelBuilder)factory.get();
        }
        return new ChatGlmChatModelBuilder();
    }

    public static class ChatGlmChatModelBuilder {
        private String baseUrl;
        private Duration timeout;
        private Double temperature;
        private Integer maxRetries;
        private Double topP;
        private Integer maxLength;
        private boolean logRequests;
        private boolean logResponses;

        public ChatGlmChatModelBuilder baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this;
        }

        public ChatGlmChatModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        public ChatGlmChatModelBuilder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public ChatGlmChatModelBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public ChatGlmChatModelBuilder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public ChatGlmChatModelBuilder maxLength(Integer maxLength) {
            this.maxLength = maxLength;
            return this;
        }

        public ChatGlmChatModelBuilder logRequests(boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public ChatGlmChatModelBuilder logResponses(boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public ChatGlmChatModel build() {
            return new ChatGlmChatModel(this.baseUrl, this.timeout, this.temperature, this.maxRetries, this.topP, this.maxLength, this.logRequests, this.logResponses);
        }
    }
}

