/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.PdfFileContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.TextFileContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.bedrock.AwsDocumentConverter;
import dev.langchain4j.model.bedrock.AwsLoggingInterceptor;
import dev.langchain4j.model.bedrock.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.ToolChoice;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.AnyToolChoice;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentFormat;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ImageSource;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;

public class BedrockChatModel
implements ChatLanguageModel {
    /** AWS region used when this class creates its own client; the constructor falls back to {@code US_EAST_1}. */
    private final Region region;
    /** Model id used when a request does not carry its own model name; validated as non-blank by the constructor. */
    private final String modelId;
    /** Value handed to {@code RetryUtils.withRetry} for every Converse call; the constructor falls back to 3. */
    private final Integer maxRetries;
    /** API call timeout applied to a client created by this class; the constructor falls back to one minute. */
    private final Duration timeout;
    /** Bedrock runtime client: the one supplied through the builder, or one built by {@code createClient}. */
    private final BedrockRuntimeClient client;
    /** Default request parameters; the constructor always sets their model name to {@link #modelId}. */
    private final ChatRequestParameters defaultRequestParameters;

    /**
     * Creates a model for the given Bedrock model id, leaving every other builder
     * setting unset so the builder-based constructor applies its defaults.
     *
     * @param modelId id of the Bedrock model to call; rejected by the delegate constructor if blank
     */
    public BedrockChatModel(String modelId) {
        this(builder().modelId(modelId));
    }

    /**
     * Creates a model from builder settings. Unset values fall back to: region
     * {@code US_EAST_1}, 3 retries, a one-minute timeout, and request/response logging off.
     * When no model id is set on the builder, the model name of the default request
     * parameters (if any) is used instead.
     *
     * <p>The default request parameters are validated <em>before</em> a client is created,
     * so an unsupported configuration cannot leave a freshly built, never-closed AWS client
     * behind.
     *
     * @param builder source of the configuration; a client supplied here is used as-is and
     *                region, timeout and logging settings are then not applied to it
     * @throws UnsupportedFeatureException if the default request parameters use an option
     *                                     rejected by {@link #validate(ChatRequestParameters)}
     */
    public BedrockChatModel(Builder builder) {
        ChatRequestParameters configuredDefaults = builder.defaultRequestParameters;
        String fallbackModelId = configuredDefaults != null ? configuredDefaults.modelName() : null;

        this.region = dev.langchain4j.internal.Utils.getOrDefault(builder.region, Region.US_EAST_1);
        this.modelId = ValidationUtils.ensureNotBlank(
                dev.langchain4j.internal.Utils.getOrDefault(builder.modelId, fallbackModelId), "modelId");
        this.maxRetries = dev.langchain4j.internal.Utils.getOrDefault(builder.maxRetries, 3);
        this.timeout = dev.langchain4j.internal.Utils.getOrDefault(builder.timeout, Duration.ofMinutes(1L));

        // Fail fast on unsupported defaults while no client resources have been allocated yet.
        if (configuredDefaults != null) {
            BedrockChatModel.validate(configuredDefaults);
        }

        if (builder.client != null) {
            this.client = builder.client;
        } else {
            boolean logRequests = dev.langchain4j.internal.Utils.getOrDefault(builder.logRequests, false);
            boolean logResponses = dev.langchain4j.internal.Utils.getOrDefault(builder.logResponses, false);
            // createClient reads this.region and this.timeout, both assigned above.
            this.client = this.createClient(logRequests, logResponses);
        }

        if (configuredDefaults != null) {
            // The resolved model id always wins over whatever model name the defaults carried.
            this.defaultRequestParameters = ChatRequestParameters.builder().overrideWith(configuredDefaults).modelName(this.modelId).build();
        } else {
            this.defaultRequestParameters = ChatRequestParameters.builder().modelName(this.modelId).build();
        }
    }

    /**
     * Generates a reply for the conversation without offering any tools to the model.
     *
     * @param messages conversation so far
     * @return the model's reply, as produced by the tool-list overload with an empty tool list
     */
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        List<ToolSpecification> noTools = Collections.emptyList();
        return generate(messages, noTools);
    }

    /**
     * Generates a reply for the conversation, offering exactly one tool to the model.
     *
     * @param messages          conversation so far
     * @param toolSpecification the single tool to offer; {@code List.of} rejects {@code null}
     * @return the model's reply, as produced by the tool-list overload
     */
    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        List<ToolSpecification> singleTool = List.of(toolSpecification);
        return generate(messages, singleTool);
    }

    /**
     * Generates a reply for the conversation, offering the given tools, using only the
     * model's default request parameters.
     *
     * @param messages           conversation so far
     * @param toolSpecifications tools to offer; may be empty
     * @return the reply together with token usage, finish reason and an {@code "id"} metadata
     *         entry holding the AWS request id
     */
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        // No per-request parameters on this path, hence the null third argument.
        ConverseRequest converseRequest = buildConverseRequest(messages, toolSpecifications, null);
        ConverseResponse converseResponse = RetryUtils.withRetry(() -> this.client.converse(converseRequest), this.maxRetries);

        AiMessage reply = aiMessageFrom(converseResponse);
        TokenUsage usage = tokenUsageFrom(converseResponse.usage());
        FinishReason finishReason = finishReasonFrom(converseResponse.stopReason());
        return Response.from(reply, usage, finishReason, Map.of("id", converseResponse.responseMetadata().requestId()));
    }

    /**
     * Sends a chat request through the Bedrock Converse call and maps the result to a
     * {@link ChatResponse}.
     *
     * @param request messages plus request parameters (including tool specifications)
     * @return the AI message with metadata: AWS request id, finish reason, token usage and
     *         the model id that was actually put on the Converse request
     */
    public ChatResponse chat(ChatRequest request) {
        ChatRequestParameters requestParameters = request.parameters();
        ConverseRequest convRequest = buildConverseRequest(request.messages(), requestParameters.toolSpecifications(), requestParameters);
        ConverseResponse converseResponse = RetryUtils.withRetry(() -> this.client.converse(convRequest), this.maxRetries);

        // Map the message before the metadata so mapping failures surface in the same order as before.
        AiMessage reply = aiMessageFrom(converseResponse);
        ChatResponseMetadata metadata = ChatResponseMetadata.builder()
                .id(converseResponse.responseMetadata().requestId())
                .finishReason(finishReasonFrom(converseResponse.stopReason()))
                .tokenUsage(tokenUsageFrom(converseResponse.usage()))
                .modelName(convRequest.modelId())
                .build();
        return ChatResponse.builder().aiMessage(reply).metadata(metadata).build();
    }

    /**
     * Assembles the Converse request: model id, inference configuration, system prompts,
     * conversation messages and tool configuration.
     *
     * @param messages   full conversation, system messages included
     * @param toolSpecs  tools to offer; may be {@code null} or empty
     * @param parameters per-request parameters, or {@code null} to rely on the defaults
     * @return the request ready to be sent
     * @throws UnsupportedFeatureException if {@code parameters} use an unsupported option
     */
    private ConverseRequest buildConverseRequest(List<ChatMessage> messages, List<ToolSpecification> toolSpecs, ChatRequestParameters parameters) {
        // A model name on the request overrides the model configured on this instance.
        String model = this.modelId;
        if (parameters != null && parameters.modelName() != null) {
            model = parameters.modelName();
        }
        if (parameters != null) {
            BedrockChatModel.validate(parameters);
        }
        return ConverseRequest.builder()
                .modelId(model)
                .inferenceConfig(inferenceConfigurationFrom(parameters))
                .system(extractSystemMessages(messages))
                .messages(extractRegularMessages(messages))
                .toolConfig(extractToolConfigurationFrom(toolSpecs, parameters))
                .build();
    }

    /**
     * Rejects request parameters this provider does not support: {@code topK},
     * {@code frequencyPenalty}, {@code presencePenalty} and the JSON response format.
     * Checks run in that order and the first offending option is reported.
     *
     * @param parameters parameters to check; must not be {@code null}
     * @throws UnsupportedFeatureException naming the first unsupported option found
     */
    static void validate(ChatRequestParameters parameters) {
        String unsupportedFeature = null;
        if (parameters.topK() != null) {
            unsupportedFeature = "'topK' parameter";
        } else if (parameters.frequencyPenalty() != null) {
            unsupportedFeature = "'frequencyPenalty' parameter";
        } else if (parameters.presencePenalty() != null) {
            unsupportedFeature = "'presencePenalty' parameter";
        } else if (parameters.responseFormat() != null && parameters.responseFormat().type().equals(ResponseFormatType.JSON)) {
            unsupportedFeature = "JSON response format";
        }

        if (unsupportedFeature != null) {
            throw new UnsupportedFeatureException(String.format("%s is not supported yet by this model provider", unsupportedFeature));
        }
    }

    /**
     * Collects the text of every message whose type is {@code SYSTEM} into Bedrock system
     * content blocks, preserving conversation order.
     *
     * @param messages full conversation
     * @return unmodifiable list of system blocks; empty when there are no system messages
     */
    private List<SystemContentBlock> extractSystemMessages(List<ChatMessage> messages) {
        List<SystemContentBlock> systemBlocks = new ArrayList<>();
        for (ChatMessage message : messages) {
            if (message.type() != ChatMessageType.SYSTEM) {
                continue;
            }
            String systemText = ((SystemMessage) message).text();
            systemBlocks.add(SystemContentBlock.builder().text(systemText).build());
        }
        return Collections.unmodifiableList(systemBlocks);
    }

    /**
     * Converts the non-system part of the conversation to Bedrock messages. Tool execution
     * results are delegated to {@code handleToolResult}, which groups consecutive results
     * into one message; system messages are skipped here.
     *
     * @param messages full conversation
     * @return Bedrock messages in conversation order
     */
    private List<Message> extractRegularMessages(List<ChatMessage> messages) {
        List<Message> converted = new ArrayList<>();
        // Accumulates tool results across iterations until handleToolResult flushes them.
        List<ContentBlock> pendingToolResults = new ArrayList<>();
        for (int index = 0; index < messages.size(); index++) {
            ChatMessage current = messages.get(index);
            if (current instanceof ToolExecutionResultMessage toolResultMessage) {
                handleToolResult(toolResultMessage, pendingToolResults, converted, index, messages);
            } else if (!(current instanceof SystemMessage)) {
                converted.add(convertToBedRockMessage(current));
            }
        }
        return converted;
    }

    /**
     * Adds one tool result to the pending blocks and, when the next conversation message is
     * not another tool result (or there is no next message), emits all pending blocks as a
     * single {@code USER} message and empties the pending list.
     *
     * @param toolResult      tool result at {@code currentIndex}
     * @param blocks          pending tool-result blocks, shared across calls
     * @param bedrockMessages output list receiving the grouped message
     * @param currentIndex    position of {@code toolResult} in {@code allMessages}
     * @param allMessages     full conversation, used to look one message ahead
     */
    private void handleToolResult(ToolExecutionResultMessage toolResult, List<ContentBlock> blocks, List<Message> bedrockMessages, int currentIndex, List<ChatMessage> allMessages) {
        blocks.add(createToolResultBlock(toolResult));

        int nextIndex = currentIndex + 1;
        boolean anotherToolResultFollows = nextIndex < allMessages.size()
                && allMessages.get(nextIndex) instanceof ToolExecutionResultMessage;
        if (anotherToolResultFollows) {
            return;
        }

        // NOTE(review): clearing right after build() presumes the SDK builder copied `blocks` - confirm.
        bedrockMessages.add(Message.builder().role(ConversationRole.USER).content(blocks).build());
        blocks.clear();
    }

    /**
     * Wraps a tool execution result in a Bedrock tool-result content block carrying the
     * tool-use id and the result text.
     *
     * @param toolResult result to convert
     * @return content block holding the tool result
     */
    private ContentBlock createToolResultBlock(ToolExecutionResultMessage toolResult) {
        ToolResultContentBlock resultText = ToolResultContentBlock.builder().text(toolResult.text()).build();
        ToolResultBlock resultBlock = ToolResultBlock.builder()
                .toolUseId(toolResult.id())
                .content(resultText)
                .build();
        return ContentBlock.builder().toolResult(resultBlock).build();
    }

    /**
     * Converts a user or AI message to its Bedrock counterpart.
     *
     * @param message message to convert
     * @return the Bedrock message
     * @throws IllegalArgumentException if the message is neither a user nor an AI message
     */
    private Message convertToBedRockMessage(ChatMessage message) {
        if (message instanceof UserMessage userMessage) {
            return createUserMessage(userMessage);
        }
        if (message instanceof AiMessage aiMessage) {
            return createAiMessage(aiMessage);
        }
        throw new IllegalArgumentException("Unsupported message type: " + message.getClass());
    }

    /**
     * Builds a {@code USER} Bedrock message from the contents of a user message.
     *
     * @param message user message to convert
     * @return Bedrock message with role {@code USER}
     */
    private Message createUserMessage(UserMessage message) {
        List<ContentBlock> contentBlocks = convertContents(message.contents());
        return Message.builder()
                .role(ConversationRole.USER)
                .content(contentBlocks)
                .build();
    }

    /**
     * Builds an {@code ASSISTANT} Bedrock message from an AI message: the text block first
     * (when text is present), followed by one tool-use block per tool execution request.
     *
     * @param message AI message to convert
     * @return Bedrock message with role {@code ASSISTANT}
     */
    private Message createAiMessage(AiMessage message) {
        List<ContentBlock> contentBlocks = new ArrayList<>();

        String text = message.text();
        if (text != null) {
            contentBlocks.add(ContentBlock.builder().text(text).build());
        }
        if (message.hasToolExecutionRequests()) {
            contentBlocks.addAll(convertToolRequests(message.toolExecutionRequests()));
        }

        return Message.builder()
                .role(ConversationRole.ASSISTANT)
                .content(contentBlocks)
                .build();
    }

    /**
     * Converts tool execution requests to Bedrock tool-use content blocks, turning each
     * request's JSON arguments into an AWS document.
     *
     * @param requests tool execution requests of an AI message
     * @return one tool-use block per request, in the same order
     */
    private List<ContentBlock> convertToolRequests(List<ToolExecutionRequest> requests) {
        return requests.stream()
                .map(request -> {
                    ToolUseBlock toolUse = ToolUseBlock.builder()
                            .name(request.name())
                            .toolUseId(request.id())
                            .input(AwsDocumentConverter.documentFromJson(request.arguments()))
                            .build();
                    return ContentBlock.builder().toolUse(toolUse).build();
                })
                .toList();
    }

    /**
     * Converts user-message contents to Bedrock content blocks.
     *
     * @param contents contents to convert; may be {@code null} or empty
     * @return converted blocks in the same order, or an empty list when there is nothing to convert
     */
    private List<ContentBlock> convertContents(List<Content> contents) {
        boolean hasContents = contents != null && !contents.isEmpty();
        if (!hasContents) {
            return Collections.emptyList();
        }
        return contents.stream()
                .map(content -> convertContent(content))
                .toList();
    }

    /**
     * Converts a single content item to a Bedrock content block. Text becomes a text block,
     * text and PDF files become document blocks ({@code TXT} / {@code PDF}), and images are
     * delegated to {@code createImageBlock}.
     *
     * @param content content item to convert
     * @return the matching Bedrock content block
     * @throws IllegalArgumentException for any other content type
     */
    private ContentBlock convertContent(Content content) {
        if (content instanceof TextContent textContent) {
            return ContentBlock.builder().text(textContent.text()).build();
        }
        if (content instanceof TextFileContent textFileContent) {
            return toDocumentBlock(DocumentFormat.TXT, textFileContent.textFile().base64Data(), textFileContent.textFile().url());
        }
        if (content instanceof PdfFileContent pdfFileContent) {
            return toDocumentBlock(DocumentFormat.PDF, pdfFileContent.pdfFile().base64Data(), pdfFileContent.pdfFile().url());
        }
        if (content instanceof ImageContent imageContent) {
            return createImageBlock(imageContent);
        }
        throw new IllegalArgumentException("Unsupported content type: " + content.getClass());
    }

    /**
     * Builds a document block of the given format. The bytes come from the base64 payload
     * when present, otherwise they are read from the URL; the document name is derived from
     * the URL.
     */
    private ContentBlock toDocumentBlock(DocumentFormat format, String base64Data, URI url) {
        byte[] rawBytes = base64Data != null
                ? Base64.getDecoder().decode(base64Data)
                : dev.langchain4j.internal.Utils.readBytes(String.valueOf(url));
        SdkBytes bytes = SdkBytes.fromByteArray(rawBytes);
        DocumentBlock document = DocumentBlock.builder()
                .format(format)
                .source(DocumentSource.builder().bytes(bytes).build())
                .name(BedrockChatModel.extractFilenameWithoutExtensionFromUri(url))
                .build();
        return ContentBlock.builder().document(document).build();
    }

    /**
     * Returns the clean file name that {@code Utils.extractCleanFileName} derives from the
     * URI, or a random UUID string when that yields {@code null} or an empty string.
     *
     * @param uri location of the file
     * @return a non-empty name for the document
     */
    private static String extractFilenameWithoutExtensionFromUri(URI uri) {
        String cleanName = Utils.extractCleanFileName(uri);
        boolean nameMissing = dev.langchain4j.internal.Utils.isNullOrEmpty(cleanName);
        return nameMissing ? UUID.randomUUID().toString() : cleanName;
    }

    /**
     * Builds a Bedrock image block. The bytes come from the image's base64 payload when
     * present, otherwise they are read from its URL; the format is resolved by
     * {@code Utils.extractAndValidateFormat} after the bytes have been loaded.
     *
     * @param imageContent image content to convert
     * @return content block holding the image
     */
    private ContentBlock createImageBlock(ImageContent imageContent) {
        byte[] rawBytes;
        if (imageContent.image().base64Data() != null) {
            rawBytes = Base64.getDecoder().decode(imageContent.image().base64Data());
        } else {
            rawBytes = dev.langchain4j.internal.Utils.readBytes(String.valueOf(imageContent.image().url()));
        }
        SdkBytes bytes = SdkBytes.fromByteArray(rawBytes);
        String imgFormat = Utils.extractAndValidateFormat(imageContent.image());

        ImageBlock imageBlock = ImageBlock.builder()
                .format(imgFormat)
                .source(ImageSource.builder().bytes(bytes).build())
                .build();
        return ContentBlock.builder().image(imageBlock).build();
    }

    /**
     * Builds the Bedrock tool configuration from the given tool specifications.
     *
     * @param toolSpecifications tools to offer; may be {@code null} or empty
     * @param parameters         request parameters, or {@code null}; a tool choice of
     *                           {@code REQUIRED} is sent to Bedrock as the "any" tool choice
     * @return the tool configuration, or {@code null} when there are no tools
     */
    private ToolConfiguration extractToolConfigurationFrom(List<ToolSpecification> toolSpecifications, ChatRequestParameters parameters) {
        if (toolSpecifications == null || toolSpecifications.isEmpty()) {
            return null;
        }

        List<Tool> bedrockTools = new ArrayList<>();
        for (ToolSpecification specification : toolSpecifications) {
            ToolInputSchema inputSchema = ToolInputSchema.builder()
                    .json(AwsDocumentConverter.convertJsonObjectSchemaToDocument(specification))
                    .build();
            software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification bedrockSpecification =
                    software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification.builder()
                            .name(specification.name())
                            .description(specification.description())
                            .inputSchema(inputSchema)
                            .build();
            bedrockTools.add(Tool.builder().toolSpec(bedrockSpecification).build());
        }

        ToolConfiguration.Builder configuration = ToolConfiguration.builder();
        configuration.tools(bedrockTools);
        boolean toolUseRequired = parameters != null && ToolChoice.REQUIRED.equals(parameters.toolChoice());
        if (toolUseRequired) {
            configuration.toolChoice(software.amazon.awssdk.services.bedrockruntime.model.ToolChoice.fromAny(AnyToolChoice.builder().build()));
        }
        return configuration.build();
    }

    /**
     * Maps the Converse response message to an {@link AiMessage}.
     *
     * <p>Tool-use blocks become tool execution requests. Text blocks are accumulated in
     * response order and joined with a newline, so a reply made of several text blocks keeps
     * all of its text instead of only the last block. A reply with a single text block
     * yields exactly that text.
     *
     * @param converseResponse response returned by the Converse call
     * @return an AI message with text, tool execution requests, or both
     * @throws IllegalArgumentException if the response contains a block that is neither text
     *                                  nor tool use
     */
    private AiMessage aiMessageFrom(ConverseResponse converseResponse) {
        List<ToolExecutionRequest> toolExecRequests = new ArrayList<>();
        List<String> textSegments = new ArrayList<>();
        for (ContentBlock cBlock : converseResponse.output().message().content()) {
            if (cBlock.type() == ContentBlock.Type.TOOL_USE) {
                toolExecRequests.add(ToolExecutionRequest.builder()
                        .name(cBlock.toolUse().name())
                        .id(cBlock.toolUse().toolUseId())
                        .arguments(AwsDocumentConverter.documentToJson(cBlock.toolUse().input()))
                        .build());
                continue;
            }
            if (cBlock.type() == ContentBlock.Type.TEXT) {
                String segment = cBlock.text();
                // Empty segments carry no content; skipping them avoids stray separators.
                if (!dev.langchain4j.internal.Utils.isNullOrEmpty(segment)) {
                    textSegments.add(segment);
                }
                continue;
            }
            throw new IllegalArgumentException("Unsupported content in LLM response. Content type: " + cBlock.type());
        }

        String textAnswer = String.join("\n", textSegments);
        if (!toolExecRequests.isEmpty()) {
            if (dev.langchain4j.internal.Utils.isNullOrEmpty(textAnswer)) {
                return AiMessage.aiMessage(toolExecRequests);
            }
            return AiMessage.aiMessage(textAnswer, toolExecRequests);
        }
        return AiMessage.aiMessage(textAnswer);
    }

    /**
     * Converts Bedrock token usage to the langchain4j {@link TokenUsage}.
     *
     * @param tokenUsage usage reported by Bedrock; may be {@code null}
     * @return usage with input, output and total counts, or a {@code TokenUsage} created by
     *         its no-argument constructor when Bedrock reported none
     */
    private TokenUsage tokenUsageFrom(software.amazon.awssdk.services.bedrockruntime.model.TokenUsage tokenUsage) {
        if (tokenUsage == null) {
            return new TokenUsage();
        }
        return new TokenUsage(tokenUsage.inputTokens(), tokenUsage.outputTokens(), tokenUsage.totalTokens());
    }

    /**
     * Maps a Bedrock stop reason to a langchain4j {@link FinishReason}.
     *
     * <ul>
     *   <li>{@code END_TURN}, {@code STOP_SEQUENCE} -> {@code STOP}</li>
     *   <li>{@code MAX_TOKENS} -> {@code LENGTH}</li>
     *   <li>{@code TOOL_USE} -> {@code TOOL_EXECUTION}</li>
     *   <li>{@code GUARDRAIL_INTERVENED}, {@code CONTENT_FILTERED} -> {@code CONTENT_FILTER}</li>
     * </ul>
     *
     * <p>The content-filter mappings let a reply that was stopped by a guardrail or content
     * filter be returned to the caller instead of being discarded through an exception.
     *
     * @param stopReason stop reason reported by the Converse call
     * @return the corresponding finish reason
     * @throws IllegalArgumentException if the stop reason is {@code null} or not one of the above
     */
    private FinishReason finishReasonFrom(StopReason stopReason) {
        if (stopReason == StopReason.END_TURN || stopReason == StopReason.STOP_SEQUENCE) {
            return FinishReason.STOP;
        }
        if (stopReason == StopReason.MAX_TOKENS) {
            return FinishReason.LENGTH;
        }
        if (stopReason == StopReason.TOOL_USE) {
            return FinishReason.TOOL_EXECUTION;
        }
        if (stopReason == StopReason.GUARDRAIL_INTERVENED || stopReason == StopReason.CONTENT_FILTERED) {
            return FinishReason.CONTENT_FILTER;
        }
        throw new IllegalArgumentException("Unknown stop reason: " + String.valueOf(stopReason));
    }

    /**
     * Creates a new, empty {@link Builder}; settings left unset receive their defaults in
     * the {@code BedrockChatModel(Builder)} constructor.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds a Bedrock runtime client for this model's region using the default credentials
     * provider, with this model's timeout as the API call timeout. A logging interceptor is
     * registered only when at least one of the logging flags is set.
     *
     * @param logRequests  whether requests should be logged
     * @param logResponses whether responses should be logged
     * @return the configured client
     */
    private BedrockRuntimeClient createClient(boolean logRequests, boolean logResponses) {
        boolean loggingEnabled = logRequests || logResponses;
        return BedrockRuntimeClient.builder()
                .region(this.region)
                .credentialsProvider(DefaultCredentialsProvider.create())
                .overrideConfiguration(config -> {
                    config.apiCallTimeout(this.timeout);
                    if (loggingEnabled) {
                        config.addExecutionInterceptor(new AwsLoggingInterceptor(logRequests, logResponses));
                    }
                })
                .build();
    }

    /**
     * Builds the inference configuration (max tokens, temperature, top-p, stop sequences).
     * Each value starts from this model's default request parameters; when request
     * parameters are given, {@code getOrDefault} picks the request value over the default.
     *
     * @param chatRequestParameters per-request parameters, or {@code null} to use defaults only
     * @return the inference configuration for the Converse request
     */
    private InferenceConfiguration inferenceConfigurationFrom(ChatRequestParameters chatRequestParameters) {
        Integer maxTokens = this.defaultRequestParameters.maxOutputTokens();
        Double temperature = this.defaultRequestParameters.temperature();
        Double topP = this.defaultRequestParameters.topP();
        List<String> stopSequences = this.defaultRequestParameters.stopSequences();

        if (chatRequestParameters != null) {
            maxTokens = dev.langchain4j.internal.Utils.getOrDefault(chatRequestParameters.maxOutputTokens(), maxTokens);
            temperature = dev.langchain4j.internal.Utils.getOrDefault(chatRequestParameters.temperature(), temperature);
            topP = dev.langchain4j.internal.Utils.getOrDefault(chatRequestParameters.topP(), topP);
            stopSequences = dev.langchain4j.internal.Utils.getOrDefault(chatRequestParameters.stopSequences(), stopSequences);
        }

        return InferenceConfiguration.builder()
                .maxTokens(maxTokens)
                .temperature(BedrockChatModel.dblToFloat(temperature))
                .topP(BedrockChatModel.dblToFloat(topP))
                .stopSequences(stopSequences)
                .build();
    }

    /**
     * Narrows a boxed double to a boxed float, passing {@code null} through unchanged.
     *
     * @param d value to narrow; may be {@code null}
     * @return {@code d} as a {@code Float}, or {@code null} when {@code d} is {@code null}
     */
    public static Float dblToFloat(Double d) {
        return (d == null) ? null : Float.valueOf(d.floatValue());
    }

    /**
     * Fluent builder for {@link BedrockChatModel}. Every setter stores its argument and
     * returns this builder; defaults for unset values are applied by the
     * {@code BedrockChatModel(Builder)} constructor.
     */
    public static class Builder {
        private Region region;
        private String modelId;
        private Integer maxRetries;
        private Duration timeout;
        private BedrockRuntimeClient client;
        private ChatRequestParameters defaultRequestParameters;
        private Boolean logRequests;
        private Boolean logResponses;

        /**
         * Sets the AWS region used when the model creates its own client.
         *
         * @param region region to use; when left {@code null} the constructor uses {@code US_EAST_1}
         * @return this builder
         */
        public Builder region(Region region) {
            this.region = region;
            return this;
        }

        /**
         * Sets the Bedrock model id.
         *
         * @param modelId model id; when left {@code null} the constructor falls back to the
         *                model name of the default request parameters, and rejects a blank result
         * @return this builder
         */
        public Builder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        /**
         * Sets the retry count handed to {@code RetryUtils.withRetry}.
         *
         * @param maxRetries retry count; when left {@code null} the constructor uses 3
         * @return this builder
         */
        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * Sets the API call timeout for a client created by the model.
         *
         * @param timeout timeout; when left {@code null} the constructor uses one minute
         * @return this builder
         */
        public Builder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Supplies a ready-made client. When set, the constructor uses it as-is and does not
         * create one, so region, timeout and logging settings are not applied to it.
         *
         * @param client client to use
         * @return this builder
         */
        public Builder client(BedrockRuntimeClient client) {
            this.client = client;
            return this;
        }

        /**
         * Enables or disables request logging on a client created by the model.
         *
         * @param logRequests logging flag; when left {@code null} the constructor uses {@code false}
         * @return this builder
         */
        public Builder logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        /**
         * Enables or disables response logging on a client created by the model.
         *
         * @param logResponses logging flag; when left {@code null} the constructor uses {@code false}
         * @return this builder
         */
        public Builder logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        /**
         * Sets the default request parameters. The constructor validates them and overrides
         * their model name with the resolved model id.
         *
         * @param defaultRequestParameters defaults to apply to requests
         * @return this builder
         */
        public Builder defaultRequestParameters(ChatRequestParameters defaultRequestParameters) {
            this.defaultRequestParameters = defaultRequestParameters;
            return this;
        }

        /**
         * Creates the model from the current builder state.
         *
         * @return a new {@link BedrockChatModel}
         */
        public BedrockChatModel build() {
            return new BedrockChatModel(this);
        }
    }
}

