/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ai.functions;

import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.google.common.net.MediaType;
import com.google.inject.Inject;
import io.airlift.http.client.BodyGenerator;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.http.client.JsonBodyGenerator;
import io.airlift.http.client.JsonResponseHandler;
import io.airlift.http.client.Request;
import io.airlift.http.client.ResponseHandler;
import io.airlift.json.JsonCodec;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;
import io.opentelemetry.semconv.incubating.GenAiIncubatingAttributes;
import io.trino.plugin.ai.functions.AbstractAiClient;
import io.trino.plugin.ai.functions.AiConfig;
import io.trino.plugin.ai.functions.AiErrorCode;
import io.trino.plugin.ai.functions.AnthropicConfig;
import io.trino.plugin.ai.functions.ForAiClient;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.TrinoException;
import java.net.URI;
import java.util.List;
import java.util.Objects;

/**
 * AI client that obtains completions by POSTing a JSON {@link MessageRequest} to the
 * {@code /v1/messages} path under the configured endpoint and reading back a
 * {@link MessageResponse}.
 *
 * <p>Every request is wrapped in a client-kind tracing span named {@code "chat <model>"}
 * that carries the request model and, on success, the response id, response model and
 * token usage reported by the provider.
 */
public class AnthropicClient
        extends AbstractAiClient {
    private static final JsonCodec<MessageRequest> MESSAGE_REQUEST_CODEC = JsonCodec.jsonCodec(MessageRequest.class);
    private static final JsonCodec<MessageResponse> MESSAGE_RESPONSE_CODEC = JsonCodec.jsonCodec(MessageResponse.class);

    /** Path appended to the configured endpoint for every completion request. */
    private static final String MESSAGES_PATH = "/v1/messages";
    /** Value sent in the {@code Anthropic-Version} request header. */
    private static final String API_VERSION = "2023-06-01";
    /** Value sent as {@code max_tokens} in every request body. */
    private static final int MAX_TOKENS = 4096;

    private final HttpClient httpClient;
    private final Tracer tracer;
    private final URI endpoint;
    private final String apiKey;

    /**
     * Creates a client bound to the endpoint and API key taken from {@code anthropicConfig}.
     *
     * @param httpClient HTTP client used to execute requests
     * @param tracer tracer used to create one span per request
     * @param anthropicConfig source of the endpoint URI and API key
     * @param aiConfig configuration handed to the superclass
     * @throws NullPointerException if {@code httpClient}, {@code tracer}, {@code anthropicConfig},
     *         or the endpoint / API key it supplies is null
     */
    @Inject
    public AnthropicClient(@ForAiClient HttpClient httpClient, Tracer tracer, AnthropicConfig anthropicConfig, AiConfig aiConfig) {
        super(aiConfig);
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.tracer = Objects.requireNonNull(tracer, "tracer is null");
        Objects.requireNonNull(anthropicConfig, "anthropicConfig is null");
        // Fail at construction time rather than on the first request.
        this.endpoint = Objects.requireNonNull(anthropicConfig.getEndpoint(), "endpoint is null");
        this.apiKey = Objects.requireNonNull(anthropicConfig.getApiKey(), "apiKey is null");
    }

    /**
     * Sends {@code prompt} as a single {@code "user"} message and returns the text of the
     * first content block in the response that carries text.
     *
     * @param model model identifier placed in the request body and in the span attributes
     * @param prompt text sent as the content of the user message
     * @return the text of the first response content block whose text is present
     * @throws TrinoException with {@link AiErrorCode#AI_ERROR} if executing the request throws a
     *         {@link RuntimeException}, or if the response contains no content block with text
     */
    @Override
    protected String generateCompletion(String model, String prompt) {
        URI uri = HttpUriBuilder.uriBuilderFrom(endpoint)
                .appendPath(MESSAGES_PATH)
                .build();

        MessageRequest.Message userMessage = new MessageRequest.Message("user", prompt);
        MessageRequest body = new MessageRequest(model, MAX_TOKENS, List.of(userMessage));

        Request request = Request.Builder.preparePost()
                .setUri(uri)
                .setHeader("X-Api-Key", apiKey)
                .setHeader("Anthropic-Version", API_VERSION)
                .setHeader("Content-Type", MediaType.JSON_UTF_8.toString())
                .setBodyGenerator(JsonBodyGenerator.jsonBodyGenerator(MESSAGE_REQUEST_CODEC, body))
                .build();

        MessageResponse response = executeTraced(request, uri, model);

        // content() is never null (normalized by the record); skip blocks without text so a
        // leading text-less block does not turn into a null return value.
        // NOTE(review): presumably text-less blocks are non-text block types - confirm against the API.
        return response.content().stream()
                .filter(block -> block != null && block.text() != null)
                .map(MessageResponse.Content::text)
                .findFirst()
                .orElseThrow(() -> new TrinoException(
                        AiErrorCode.AI_ERROR,
                        "No response from AI provider at %s for model %s".formatted(uri, model)));
    }

    /**
     * Executes {@code request} inside a tracing span and records response metadata on it.
     * The span is always ended; a failure is recorded on the span before being rethrown.
     */
    private MessageResponse executeTraced(Request request, URI uri, String model) {
        Span span = tracer.spanBuilder("chat " + model)
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_OPERATION_NAME, "chat")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_SYSTEM, "anthropic")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_REQUEST_MODEL, model)
                .setSpanKind(SpanKind.CLIENT)
                .startSpan();
        try (Scope ignored = span.makeCurrent()) {
            MessageResponse response = httpClient.execute(request, JsonResponseHandler.createJsonResponseHandler(MESSAGE_RESPONSE_CODEC));
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_ID, response.id());
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_MODEL, response.model());
            MessageResponse.Usage usage = response.usage();
            // Token usage is telemetry only; its absence must not fail an otherwise good response.
            if (usage != null) {
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_INPUT_TOKENS, usage.inputTokens());
                span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, usage.outputTokens());
            }
            return response;
        }
        catch (RuntimeException e) {
            span.setStatus(StatusCode.ERROR, e.getMessage());
            span.recordException(e);
            throw new TrinoException(AiErrorCode.AI_ERROR, "Request to AI provider at %s for model %s failed".formatted(uri, model), e);
        }
        finally {
            span.end();
        }
    }

    /**
     * Request body, serialized with snake_case property names
     * ({@code model}, {@code max_tokens}, {@code messages}).
     */
    @JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
    public record MessageRequest(String model, int maxTokens, List<Message> messages) {

        /** A single message in the request: a role and its text content. */
        public record Message(String role, String content) {
        }
    }

    /**
     * Response body. A null {@code content} list (for example, the field missing from the
     * JSON) is normalized to an empty list so callers never observe null.
     */
    public record MessageResponse(String id, String model, List<Content> content, Usage usage) {

        public MessageResponse {
            if (content == null) {
                content = List.of();
            }
        }

        /** Token counts, read from snake_case properties ({@code input_tokens}, {@code output_tokens}). */
        @JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
        public record Usage(int inputTokens, int outputTokens) {
        }

        /** One content block of the response; only its {@code text} property is mapped. */
        public record Content(String text) {
        }
    }
}

