/*
 * Decompiled with CFR 0.152.
 */
package io.opentelemetry.instrumentation.awssdk.v2_2.internal;

import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.common.Value;
import io.opentelemetry.api.logs.LogRecordBuilder;
import io.opentelemetry.api.logs.Logger;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.context.ImplicitContextKeyed;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.awssdk.v2_2.internal.DocumentTypeJsonMarshaller;
import io.opentelemetry.instrumentation.awssdk.v2_2.internal.Response;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.document.VoidDocumentVisitor;
import software.amazon.awssdk.protocols.json.SdkJsonGenerator;
import software.amazon.awssdk.protocols.json.StructuredJsonGenerator;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import software.amazon.awssdk.thirdparty.jackson.core.JsonFactory;

/**
 * Helpers for instrumenting the AWS SDK v2 Bedrock Runtime {@code Converse} and
 * {@code ConverseStream} operations.
 *
 * <p>The helpers read GenAI request/response attributes from SDK model objects and emit GenAI log
 * events through an OpenTelemetry {@link Logger}. All members are static and the class is not
 * instantiable.
 */
public final class BedrockRuntimeImpl {
    private static final AttributeKey<String> EVENT_NAME = AttributeKey.stringKey("event.name");
    private static final AttributeKey<String> GEN_AI_SYSTEM = AttributeKey.stringKey("gen_ai.system");
    /** Shared factory used by {@link #serializeDocument(Document)}. */
    private static final JsonFactory JSON_FACTORY = new JsonFactory();

    private BedrockRuntimeImpl() {
    }

    /** Returns {@code true} for a {@link ConverseRequest} or a {@link ConverseStreamRequest}. */
    static boolean isBedrockRuntimeRequest(SdkRequest request) {
        return request instanceof ConverseRequest || request instanceof ConverseStreamRequest;
    }

    /**
     * Returns {@code true} only for a non-streaming {@link ConverseResponse}.
     *
     * <p>Streaming results are not detected here. The response getters of this class read them
     * from the {@link TracingConverseStreamResponseHandler} stored in the response's context.
     */
    static boolean isBedrockRuntimeResponse(SdkResponse response) {
        return response instanceof ConverseResponse;
    }

    /** Returns the request's model id, or {@code null} for unsupported request types. */
    @Nullable
    static String getModelId(SdkRequest request) {
        if (request instanceof ConverseRequest) {
            return ((ConverseRequest) request).modelId();
        }
        if (request instanceof ConverseStreamRequest) {
            return ((ConverseStreamRequest) request).modelId();
        }
        return null;
    }

    /** Returns {@code inferenceConfig.maxTokens} widened to {@code Long}, or {@code null}. */
    @Nullable
    static Long getMaxTokens(SdkRequest request) {
        InferenceConfiguration config = getInferenceConfig(request);
        return config != null ? integerToLong(config.maxTokens()) : null;
    }

    /** Returns {@code inferenceConfig.temperature} widened to {@code Double}, or {@code null}. */
    @Nullable
    static Double getTemperature(SdkRequest request) {
        InferenceConfiguration config = getInferenceConfig(request);
        return config != null ? floatToDouble(config.temperature()) : null;
    }

    /** Returns {@code inferenceConfig.topP} widened to {@code Double}, or {@code null}. */
    @Nullable
    static Double getTopP(SdkRequest request) {
        InferenceConfiguration config = getInferenceConfig(request);
        return config != null ? floatToDouble(config.topP()) : null;
    }

    /** Returns {@code inferenceConfig.stopSequences}, or {@code null} without an inference config. */
    @Nullable
    static List<String> getStopSequences(SdkRequest request) {
        InferenceConfiguration config = getInferenceConfig(request);
        return config != null ? config.stopSequences() : null;
    }

    /**
     * Returns the stop reason(s) of a completed call.
     *
     * <p>A non-streaming call yields a single-element list. A streaming call yields the reasons
     * collected by the {@link TracingConverseStreamResponseHandler} found in the response's
     * context. In both cases the result is {@code null} when nothing is available.
     */
    @Nullable
    static List<String> getStopReasons(Response response) {
        SdkResponse sdkResponse = response.getSdkResponse();
        if (sdkResponse instanceof ConverseResponse) {
            StopReason reason = ((ConverseResponse) sdkResponse).stopReason();
            return reason != null ? Collections.singletonList(reason.toString()) : null;
        }
        TracingConverseStreamResponseHandler streamHandler =
                TracingConverseStreamResponseHandler.fromContext(response.otelContext());
        return streamHandler != null ? streamHandler.stopReasons : null;
    }

    /** Returns the input token count of a completed (streaming or non-streaming) call, or {@code null}. */
    @Nullable
    static Long getUsageInputTokens(Response response) {
        TokenUsage usage = getUsage(response);
        return usage != null ? integerToLong(usage.inputTokens()) : null;
    }

    /** Returns the output token count of a completed (streaming or non-streaming) call, or {@code null}. */
    @Nullable
    static Long getUsageOutputTokens(Response response) {
        TokenUsage usage = getUsage(response);
        return usage != null ? integerToLong(usage.outputTokens()) : null;
    }

    /**
     * Emits one log event per message of a {@link ConverseRequest}; other request types are ignored.
     *
     * <p>Tool-result content blocks are emitted as separate {@code gen_ai.tool.message} events. A
     * message consisting only of tool results produces no further event. Other messages produce
     * {@code gen_ai.user.message} or {@code gen_ai.assistant.message}, depending on their role.
     * Messages with any other role (or none) are skipped.
     *
     * @param captureMessageContent whether text, tool arguments and tool output go into event bodies
     */
    static void recordRequestEvents(Context otelContext, Logger eventLogger, SdkRequest request, boolean captureMessageContent) {
        if (!(request instanceof ConverseRequest)) {
            return;
        }
        for (Message message : ((ConverseRequest) request).messages()) {
            long numToolResults = message.content().stream().filter(block -> block.toolResult() != null).count();
            if (numToolResults > 0L) {
                emitToolResultEvents(otelContext, eventLogger, message, captureMessageContent);
                if (numToolResults == message.content().size()) {
                    // Every block was a tool result and has already been emitted above.
                    continue;
                }
            }
            String eventName = messageEventName(message);
            if (eventName == null) {
                continue;
            }
            newEvent(otelContext, eventLogger)
                    .setAttribute(EVENT_NAME, eventName)
                    .setBody(convertMessage(message, -1, null, captureMessageContent))
                    .emit();
        }
    }

    /**
     * Emits a {@code gen_ai.choice} event for a {@link ConverseResponse} that carries an output
     * message. Other response types, and responses without an output message, are ignored.
     */
    static void recordResponseEvents(Context otelContext, Logger eventLogger, SdkResponse response, boolean captureMessageContent) {
        if (!(response instanceof ConverseResponse)) {
            return;
        }
        ConverseResponse converseResponse = (ConverseResponse) response;
        Message message = converseResponse.output() != null ? converseResponse.output().message() : null;
        if (message == null) {
            return;
        }
        newEvent(otelContext, eventLogger)
                .setAttribute(EVENT_NAME, "gen_ai.choice")
                .setBody(convertMessage(message, 0, converseResponse.stopReason(), captureMessageContent))
                .emit();
    }

    /** Returns the inference configuration of a Converse/ConverseStream request, else {@code null}. */
    @Nullable
    private static InferenceConfiguration getInferenceConfig(SdkRequest request) {
        if (request instanceof ConverseRequest) {
            return ((ConverseRequest) request).inferenceConfig();
        }
        if (request instanceof ConverseStreamRequest) {
            return ((ConverseStreamRequest) request).inferenceConfig();
        }
        return null;
    }

    /** Returns token usage from a ConverseResponse or from the stream handler in the context. */
    @Nullable
    private static TokenUsage getUsage(Response response) {
        SdkResponse sdkResponse = response.getSdkResponse();
        if (sdkResponse instanceof ConverseResponse) {
            return ((ConverseResponse) sdkResponse).usage();
        }
        TracingConverseStreamResponseHandler streamHandler =
                TracingConverseStreamResponseHandler.fromContext(response.otelContext());
        return streamHandler != null ? streamHandler.usage : null;
    }

    /** Maps a message role to its event name; returns {@code null} for a null or other role. */
    @Nullable
    private static String messageEventName(Message message) {
        if (message.role() == null) {
            // A switch on a null enum would throw NullPointerException.
            return null;
        }
        switch (message.role()) {
            case ASSISTANT:
                return "gen_ai.assistant.message";
            case USER:
                return "gen_ai.user.message";
            default:
                return null;
        }
    }

    @Nullable
    private static Long integerToLong(@Nullable Integer value) {
        return value == null ? null : Long.valueOf(value.longValue());
    }

    /**
     * Widens a {@code Float} to {@code Double}.
     *
     * <p>The widening is exact in binary, but the decimal form may show float rounding: for
     * example, {@code 0.7f} becomes {@code 0.699999988079071}.
     */
    @Nullable
    private static Double floatToDouble(@Nullable Float value) {
        return value == null ? null : Double.valueOf(value.doubleValue());
    }

    /**
     * Wraps {@code asyncClient} in a dynamic proxy.
     *
     * <p>For {@code converseStream} calls, the proxy replaces the response handler with a
     * {@link TracingConverseStreamResponseHandler}. That handler records stop reasons and token
     * usage from the event stream. All other calls are forwarded unchanged.
     */
    public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClient) {
        return (BedrockRuntimeAsyncClient) Proxy.newProxyInstance(
                asyncClient.getClass().getClassLoader(),
                new Class<?>[] {BedrockRuntimeAsyncClient.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("converseStream")
                            && args != null
                            && args.length >= 2
                            && args[1] instanceof ConverseStreamResponseHandler) {
                        TracingConverseStreamResponseHandler wrapped =
                                new TracingConverseStreamResponseHandler((ConverseStreamResponseHandler) args[1]);
                        args[1] = wrapped;
                        // The handler is current for the duration of the call, presumably so that
                        // code running during it can find the handler via fromContext(...).
                        // TODO confirm against the interceptor.
                        try (Scope ignored = wrapped.makeCurrent()) {
                            return invokeProxyMethod(method, asyncClient, args);
                        }
                    }
                    return invokeProxyMethod(method, asyncClient, args);
                });
    }

    /**
     * Invokes {@code method} on {@code target} reflectively.
     *
     * <p>The {@link InvocationTargetException} is unwrapped, so callers see the target's original
     * exception.
     */
    private static Object invokeProxyMethod(Method method, Object target, Object[] args) throws Throwable {
        try {
            return method.invoke(target, args);
        } catch (InvocationTargetException exception) {
            throw exception.getCause();
        }
    }

    /** Starts a log record bound to {@code otelContext}, tagged with {@code gen_ai.system=aws.bedrock}. */
    private static LogRecordBuilder newEvent(Context otelContext, Logger eventLogger) {
        return eventLogger.logRecordBuilder().setContext(otelContext).setAttribute(GEN_AI_SYSTEM, "aws.bedrock");
    }

    /**
     * Emits one {@code gen_ai.tool.message} event per tool-result block of {@code message}.
     *
     * <p>The body always has {@code id}. When content capture is enabled, it also has
     * {@code content}: the concatenated text and JSON-serialized document parts.
     */
    private static void emitToolResultEvents(Context otelContext, Logger eventLogger, Message message, boolean captureMessageContent) {
        for (ContentBlock content : message.content()) {
            if (content.toolResult() == null) {
                continue;
            }
            Map<String, Value<?>> body = new HashMap<>();
            body.put("id", Value.of(content.toolResult().toolUseId()));
            if (captureMessageContent) {
                StringBuilder text = new StringBuilder();
                for (ToolResultContentBlock toolContent : content.toolResult().content()) {
                    if (toolContent.text() != null) {
                        text.append(toolContent.text());
                    }
                    if (toolContent.json() != null) {
                        text.append(serializeDocument(toolContent.json()));
                    }
                }
                body.put("content", Value.of(text.toString()));
            }
            newEvent(otelContext, eventLogger)
                    .setAttribute(EVENT_NAME, "gen_ai.tool.message")
                    .setBody(Value.of(body))
                    .emit();
        }
    }

    /**
     * Builds an event body for {@code message}.
     *
     * <p>The body may contain:
     * <ul>
     *   <li>{@code content}: concatenated text, only when capture is enabled and text exists;
     *   <li>{@code toolCalls}: one entry per tool-use block;
     *   <li>{@code finish_reason}: when {@code stopReason} is non-null;
     *   <li>{@code index}: when {@code index >= 0}.
     * </ul>
     *
     * @param index choice index to record, or a negative value to omit it
     */
    private static Value<?> convertMessage(Message message, int index, @Nullable StopReason stopReason, boolean captureMessageContent) {
        StringBuilder text = null;
        List<Value<?>> toolCalls = null;
        for (ContentBlock content : message.content()) {
            if (captureMessageContent && content.text() != null) {
                if (text == null) {
                    text = new StringBuilder();
                }
                text.append(content.text());
            }
            if (content.toolUse() != null) {
                if (toolCalls == null) {
                    toolCalls = new ArrayList<>();
                }
                toolCalls.add(convertToolCall(content.toolUse(), captureMessageContent));
            }
        }
        Map<String, Value<?>> body = new HashMap<>();
        if (text != null) {
            body.put("content", Value.of(text.toString()));
        }
        if (toolCalls != null) {
            // NOTE(review): the GenAI semantic conventions spell this field "tool_calls".
            // The key is kept as-is because consumers may depend on it; confirm before renaming.
            body.put("toolCalls", Value.of(toolCalls));
        }
        if (stopReason != null) {
            body.put("finish_reason", Value.of(stopReason.toString()));
        }
        if (index >= 0) {
            body.put("index", Value.of((long) index));
        }
        return Value.of(body);
    }

    /**
     * Builds a tool-call body with {@code id}, {@code name} and {@code type=function}.
     *
     * <p>When capture is enabled and an input document is present, it also includes
     * {@code arguments}: the input serialized as JSON.
     */
    private static Value<?> convertToolCall(ToolUseBlock toolCall, boolean captureMessageContent) {
        Map<String, Value<?>> body = new HashMap<>();
        body.put("id", Value.of(toolCall.toolUseId()));
        body.put("name", Value.of(toolCall.name()));
        body.put("type", Value.of("function"));
        if (captureMessageContent && toolCall.input() != null) {
            body.put("arguments", Value.of(serializeDocument(toolCall.input())));
        }
        return Value.of(body);
    }

    /** Serializes an SDK {@link Document} to a UTF-8 JSON string. */
    private static String serializeDocument(Document document) {
        SdkJsonGenerator generator = new SdkJsonGenerator(JSON_FACTORY, "application/json");
        DocumentTypeJsonMarshaller marshaller = new DocumentTypeJsonMarshaller((StructuredJsonGenerator) generator);
        // The cast selects the void accept(...) overload.
        document.accept((VoidDocumentVisitor) marshaller);
        return new String(generator.getBytes(), StandardCharsets.UTF_8);
    }

    /**
     * Forwards every callback to a delegate {@link ConverseStreamResponseHandler}.
     *
     * <p>While forwarding, it records each {@link MessageStopEvent}'s stop reason and the usage
     * from {@link ConverseStreamMetadataEvent}. It can be stored in and looked up from a
     * {@link Context}.
     */
    public static class TracingConverseStreamResponseHandler
            implements ConverseStreamResponseHandler, ImplicitContextKeyed {
        private static final ContextKey<TracingConverseStreamResponseHandler> KEY =
                ContextKey.named("bedrock-runtime-converse-stream-response-handler");

        private final ConverseStreamResponseHandler delegate;
        // NOTE(review): these fields are written from the event-stream callback and read by the
        // getters without explicit synchronization. Confirm the SDK guarantees visibility.
        @Nullable
        List<String> stopReasons;
        @Nullable
        TokenUsage usage;

        /** Returns the handler stored in {@code context}, or {@code null} if there is none. */
        @Nullable
        public static TracingConverseStreamResponseHandler fromContext(Context context) {
            return context.get(KEY);
        }

        TracingConverseStreamResponseHandler(ConverseStreamResponseHandler delegate) {
            this.delegate = delegate;
        }

        @Override
        public void responseReceived(ConverseStreamResponse converseStreamResponse) {
            delegate.responseReceived(converseStreamResponse);
        }

        @Override
        public void onEventStream(SdkPublisher<ConverseStreamOutput> sdkPublisher) {
            delegate.onEventStream(sdkPublisher.map(this::recordEvent));
        }

        /** Records stop reason or usage from {@code event} and returns it unchanged. */
        private ConverseStreamOutput recordEvent(ConverseStreamOutput event) {
            if (event instanceof MessageStopEvent) {
                String reason = ((MessageStopEvent) event).stopReasonAsString();
                if (reason != null) {
                    if (stopReasons == null) {
                        stopReasons = new ArrayList<>();
                    }
                    stopReasons.add(reason);
                }
            }
            if (event instanceof ConverseStreamMetadataEvent) {
                usage = ((ConverseStreamMetadataEvent) event).usage();
            }
            return event;
        }

        @Override
        public void exceptionOccurred(Throwable throwable) {
            delegate.exceptionOccurred(throwable);
        }

        @Override
        public void complete() {
            delegate.complete();
        }

        @Override
        public Context storeInContext(Context context) {
            return context.with(KEY, this);
        }
    }
}

