/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.model.bedrock.internal.Json;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.output.Response;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;

/**
 * Shared configuration and helpers for Bedrock chat models that build a text prompt out of
 * human/assistant turn markers and report errors to {@link ChatModelListener}s.
 *
 * <p>Instances are created through {@link AbstractSharedBedrockChatModelBuilder}. Every property that is
 * not set on the builder falls back to the corresponding {@code $default$*} value declared in this class.
 */
public abstract class AbstractSharedBedrockChatModel {
    private static final Logger log = LoggerFactory.getLogger(AbstractSharedBedrockChatModel.class);
    /** Default marker prepended to user messages. */
    protected static final String HUMAN_PROMPT = "Human:";
    /** Default marker prepended to AI messages and used to close the prompt. */
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    /** Default value sent as {@code anthropic_version} in the request body. */
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";
    protected final String humanPrompt;
    protected final String assistantPrompt;
    protected final Integer maxRetries;
    protected final Region region;
    protected final AwsCredentialsProvider credentialsProvider;
    protected final int maxTokens;
    protected final double temperature;
    protected final float topP;
    /** Private copy of the configured stop sequences; may be {@code null} if explicitly set so. */
    protected final String[] stopSequences;
    protected final int topK;
    protected final Duration timeout;
    protected final String anthropicVersion;
    /** Never {@code null}; an explicit {@code null} from the builder is replaced by an empty list. */
    protected final List<ChatModelListener> listeners;

    /**
     * Renders a single chat message as prompt text.
     *
     * @param message the message to render
     * @return the text of a system message unchanged, or the text of a user/AI message prefixed with the
     *         configured human/assistant marker and a single space
     * @throws IllegalArgumentException for tool-execution results, or for any other message type
     */
    protected String chatMessageToString(ChatMessage message) {
        switch (message.type()) {
            case SYSTEM:
                return message.text();
            case USER:
                return this.humanPrompt + " " + message.text();
            case AI:
                return this.assistantPrompt + " " + message.text();
            case TOOL_EXECUTION_RESULT:
                throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
            default:
                throw new IllegalArgumentException("Unknown message type: " + message.type());
        }
    }

    /**
     * Builds the JSON request body for a conversation.
     *
     * <p>System messages are joined with {@code "\n"} into a leading context section. All other messages are
     * rendered with {@link #chatMessageToString(ChatMessage)} and joined with {@code "\n"}. The final prompt is
     * {@code context + "\n\n" + conversation + "\n" + assistantPrompt}. It is passed to
     * {@link #getRequestParameters(String)} and serialized with {@code Json.toJson}.
     *
     * @param messages the conversation to encode
     * @return the serialized request body
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        String context = messages.stream()
                .filter(message -> message.type() == ChatMessageType.SYSTEM)
                .map(ChatMessage::text)
                .collect(Collectors.joining("\n"));
        String conversation = messages.stream()
                .filter(message -> message.type() != ChatMessageType.SYSTEM)
                .map(this::chatMessageToString)
                .collect(Collectors.joining("\n"));
        // Close with the *configured* assistant marker so the trailing turn matches the marker that
        // chatMessageToString uses for AI messages (the constant ignored a custom assistantPrompt).
        String prompt = String.format("%s\n\n%s\n%s", context, conversation, this.assistantPrompt);
        Map<String, Object> requestParameters = this.getRequestParameters(prompt);
        return Json.toJson(requestParameters);
    }

    /**
     * Assembles the request parameters for the given prompt.
     *
     * @param prompt the fully rendered prompt text
     * @return a mutable map with the keys {@code prompt}, {@code max_tokens_to_sample}, {@code temperature},
     *         {@code top_k}, {@code top_p}, {@code stop_sequences} and {@code anthropic_version}
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        Map<String, Object> parameters = new HashMap<>(7);
        parameters.put("prompt", prompt);
        parameters.put("max_tokens_to_sample", this.getMaxTokens());
        parameters.put("temperature", this.getTemperature());
        parameters.put("top_k", this.topK);
        parameters.put("top_p", this.getTopP());
        parameters.put("stop_sequences", this.getStopSequences());
        parameters.put("anthropic_version", this.anthropicVersion);
        return parameters;
    }

    /**
     * Notifies every registered listener that a request failed.
     *
     * <p>If {@code e}'s cause is an {@link SdkClientException}, that cause is reported instead of {@code e}.
     * Exceptions thrown by a listener are logged at WARN level and do not stop the remaining listeners.
     *
     * @param e the failure
     * @param modelListenerRequest the request that was being executed
     * @param attributes listener attributes passed through to the error context
     */
    protected void listenerErrorResponse(Throwable e, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
        Throwable error = e.getCause() instanceof SdkClientException ? e.getCause() : e;
        ChatModelErrorContext errorContext = new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
        this.listeners.forEach(listener -> {
            try {
                listener.onError(errorContext);
            } catch (Exception listenerFailure) {
                log.warn("Exception while calling model listener", listenerFailure);
            }
        });
    }

    /**
     * Builds the listener view of a synchronous request.
     *
     * @param invokeModelRequest the request whose {@code modelId()} is reported
     * @param messages the conversation being sent
     * @param toolSpecifications the tool specifications being sent
     * @return the listener request
     */
    protected ChatModelRequest createModelListenerRequest(InvokeModelRequest invokeModelRequest, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return ChatModelRequest.builder()
                .model(invokeModelRequest.modelId())
                .temperature(this.temperature)
                .topP((double) this.topP)
                .maxTokens(this.maxTokens)
                .messages(messages)
                .toolSpecifications(toolSpecifications)
                .build();
    }

    /**
     * Builds the listener view of a streaming request.
     *
     * <p>NOTE(review): unlike the synchronous overload, {@code invokeModelRequest} is not read here; the
     * reported model id comes from {@link #getModelId()}.
     *
     * @param invokeModelRequest the streaming request (currently unused)
     * @param messages the conversation being sent
     * @param toolSpecifications the tool specifications being sent
     * @return the listener request
     */
    protected ChatModelRequest createModelListenerRequest(InvokeModelWithResponseStreamRequest invokeModelRequest, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return ChatModelRequest.builder()
                .model(this.getModelId())
                .temperature(this.temperature)
                .topP((double) this.topP)
                .maxTokens(this.maxTokens)
                .messages(messages)
                .toolSpecifications(toolSpecifications)
                .build();
    }

    /**
     * Builds the listener view of a completed response.
     *
     * @param responseId the response id to report
     * @param responseModel the model name to report
     * @param response the model response, may be {@code null}
     * @return the listener response, or {@code null} when {@code response} is {@code null}
     */
    protected ChatModelResponse createModelListenerResponse(String responseId, String responseModel, Response<AiMessage> response) {
        if (response == null) {
            return null;
        }
        return ChatModelResponse.builder()
                .id(responseId)
                .model(responseModel)
                .tokenUsage(response.tokenUsage())
                .finishReason(response.finishReason())
                .aiMessage(response.content())
                .build();
    }

    /** @return the identifier of the Bedrock model this instance targets */
    protected abstract String getModelId();

    private static String $default$humanPrompt() {
        return HUMAN_PROMPT;
    }

    private static String $default$assistantPrompt() {
        return ASSISTANT_PROMPT;
    }

    private static Integer $default$maxRetries() {
        return 5;
    }

    private static Region $default$region() {
        return Region.US_EAST_1;
    }

    private static AwsCredentialsProvider $default$credentialsProvider() {
        return DefaultCredentialsProvider.builder().build();
    }

    private static int $default$maxTokens() {
        return 300;
    }

    private static double $default$temperature() {
        return 1.0;
    }

    private static float $default$topP() {
        return 0.999f;
    }

    private static String[] $default$stopSequences() {
        return new String[0];
    }

    private static int $default$topK() {
        return 250;
    }

    private static Duration $default$timeout() {
        return Duration.ofMinutes(1L);
    }

    private static String $default$anthropicVersion() {
        return DEFAULT_ANTHROPIC_VERSION;
    }

    private static List<ChatModelListener> $default$listeners() {
        return Collections.emptyList();
    }

    /**
     * Copies the builder's state, substituting the {@code $default$*} value for every unset property.
     *
     * @param b the builder to read from
     */
    protected AbstractSharedBedrockChatModel(AbstractSharedBedrockChatModelBuilder<?, ?> b) {
        this.humanPrompt = b.humanPrompt$set ? b.humanPrompt$value : AbstractSharedBedrockChatModel.$default$humanPrompt();
        this.assistantPrompt = b.assistantPrompt$set ? b.assistantPrompt$value : AbstractSharedBedrockChatModel.$default$assistantPrompt();
        this.maxRetries = b.maxRetries$set ? b.maxRetries$value : AbstractSharedBedrockChatModel.$default$maxRetries();
        this.region = b.region$set ? b.region$value : AbstractSharedBedrockChatModel.$default$region();
        this.credentialsProvider = b.credentialsProvider$set ? b.credentialsProvider$value : AbstractSharedBedrockChatModel.$default$credentialsProvider();
        this.maxTokens = b.maxTokens$set ? b.maxTokens$value : AbstractSharedBedrockChatModel.$default$maxTokens();
        this.temperature = b.temperature$set ? b.temperature$value : AbstractSharedBedrockChatModel.$default$temperature();
        this.topP = b.topP$set ? b.topP$value : AbstractSharedBedrockChatModel.$default$topP();
        String[] configuredStops = b.stopSequences$set ? b.stopSequences$value : AbstractSharedBedrockChatModel.$default$stopSequences();
        // Copy so later changes to the caller's array cannot alter this model's stop sequences.
        this.stopSequences = configuredStops == null ? null : configuredStops.clone();
        this.topK = b.topK$set ? b.topK$value : AbstractSharedBedrockChatModel.$default$topK();
        this.timeout = b.timeout$set ? b.timeout$value : AbstractSharedBedrockChatModel.$default$timeout();
        this.anthropicVersion = b.anthropicVersion$set ? b.anthropicVersion$value : AbstractSharedBedrockChatModel.$default$anthropicVersion();
        // An explicit null list would make listenerErrorResponse throw NPE; treat it as "no listeners".
        this.listeners = b.listeners$set && b.listeners$value != null ? b.listeners$value : AbstractSharedBedrockChatModel.$default$listeners();
    }

    public String getHumanPrompt() {
        return this.humanPrompt;
    }

    public String getAssistantPrompt() {
        return this.assistantPrompt;
    }

    public Integer getMaxRetries() {
        return this.maxRetries;
    }

    public Region getRegion() {
        return this.region;
    }

    public AwsCredentialsProvider getCredentialsProvider() {
        return this.credentialsProvider;
    }

    public int getMaxTokens() {
        return this.maxTokens;
    }

    public double getTemperature() {
        return this.temperature;
    }

    public float getTopP() {
        return this.topP;
    }

    /** @return a copy of the configured stop sequences, or {@code null} if they were explicitly set to null */
    public String[] getStopSequences() {
        return this.stopSequences == null ? null : this.stopSequences.clone();
    }

    public int getTopK() {
        return this.topK;
    }

    public Duration getTimeout() {
        return this.timeout;
    }

    public String getAnthropicVersion() {
        return this.anthropicVersion;
    }

    public List<ChatModelListener> getListeners() {
        return this.listeners;
    }

    /**
     * Fluent builder shared by all subclasses. Each setter records the value and marks it as explicitly set,
     * so the model constructor can tell "unset" (use the default) from "set".
     *
     * @param <C> the concrete model type built
     * @param <B> the concrete builder type, returned by every setter for chaining
     */
    public static abstract class AbstractSharedBedrockChatModelBuilder<C extends AbstractSharedBedrockChatModel, B extends AbstractSharedBedrockChatModelBuilder<C, B>> {
        private boolean humanPrompt$set;
        private String humanPrompt$value;
        private boolean assistantPrompt$set;
        private String assistantPrompt$value;
        private boolean maxRetries$set;
        private Integer maxRetries$value;
        private boolean region$set;
        private Region region$value;
        private boolean credentialsProvider$set;
        private AwsCredentialsProvider credentialsProvider$value;
        private boolean maxTokens$set;
        private int maxTokens$value;
        private boolean temperature$set;
        private double temperature$value;
        private boolean topP$set;
        private float topP$value;
        private boolean stopSequences$set;
        private String[] stopSequences$value;
        private boolean topK$set;
        private int topK$value;
        private boolean timeout$set;
        private Duration timeout$value;
        private boolean anthropicVersion$set;
        private String anthropicVersion$value;
        private boolean listeners$set;
        private List<ChatModelListener> listeners$value;

        public B humanPrompt(String humanPrompt) {
            this.humanPrompt$value = humanPrompt;
            this.humanPrompt$set = true;
            return this.self();
        }

        public B assistantPrompt(String assistantPrompt) {
            this.assistantPrompt$value = assistantPrompt;
            this.assistantPrompt$set = true;
            return this.self();
        }

        public B maxRetries(Integer maxRetries) {
            this.maxRetries$value = maxRetries;
            this.maxRetries$set = true;
            return this.self();
        }

        public B region(Region region) {
            this.region$value = region;
            this.region$set = true;
            return this.self();
        }

        public B credentialsProvider(AwsCredentialsProvider credentialsProvider) {
            this.credentialsProvider$value = credentialsProvider;
            this.credentialsProvider$set = true;
            return this.self();
        }

        public B maxTokens(int maxTokens) {
            this.maxTokens$value = maxTokens;
            this.maxTokens$set = true;
            return this.self();
        }

        public B temperature(double temperature) {
            this.temperature$value = temperature;
            this.temperature$set = true;
            return this.self();
        }

        public B topP(float topP) {
            this.topP$value = topP;
            this.topP$set = true;
            return this.self();
        }

        public B stopSequences(String[] stopSequences) {
            this.stopSequences$value = stopSequences;
            this.stopSequences$set = true;
            return this.self();
        }

        public B topK(int topK) {
            this.topK$value = topK;
            this.topK$set = true;
            return this.self();
        }

        public B timeout(Duration timeout) {
            this.timeout$value = timeout;
            this.timeout$set = true;
            return this.self();
        }

        public B anthropicVersion(String anthropicVersion) {
            this.anthropicVersion$value = anthropicVersion;
            this.anthropicVersion$set = true;
            return this.self();
        }

        public B listeners(List<ChatModelListener> listeners) {
            this.listeners$value = listeners;
            this.listeners$set = true;
            return this.self();
        }

        /** @return this builder typed as the concrete subclass, for fluent chaining */
        protected abstract B self();

        /** @return a new model instance configured from this builder */
        public abstract C build();

        @Override
        public String toString() {
            return "AbstractSharedBedrockChatModel.AbstractSharedBedrockChatModelBuilder(humanPrompt$value=" + this.humanPrompt$value + ", assistantPrompt$value=" + this.assistantPrompt$value + ", maxRetries$value=" + this.maxRetries$value + ", region$value=" + this.region$value + ", credentialsProvider$value=" + this.credentialsProvider$value + ", maxTokens$value=" + this.maxTokens$value + ", temperature$value=" + this.temperature$value + ", topP$value=" + this.topP$value + ", stopSequences$value=" + Arrays.deepToString(this.stopSequences$value) + ", topK$value=" + this.topK$value + ", timeout$value=" + this.timeout$value + ", anthropicVersion$value=" + this.anthropicVersion$value + ", listeners$value=" + this.listeners$value + ")";
        }
    }
}

