/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.ai.functions;

import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import com.google.common.net.MediaType;
import com.google.inject.Inject;
import io.airlift.http.client.BodyGenerator;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.http.client.JsonBodyGenerator;
import io.airlift.http.client.JsonResponseHandler;
import io.airlift.http.client.Request;
import io.airlift.http.client.ResponseHandler;
import io.airlift.json.JsonCodec;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Scope;
import io.opentelemetry.semconv.incubating.GenAiIncubatingAttributes;
import io.trino.plugin.ai.functions.AbstractAiClient;
import io.trino.plugin.ai.functions.AiConfig;
import io.trino.plugin.ai.functions.AiErrorCode;
import io.trino.plugin.ai.functions.ForAiClient;
import io.trino.plugin.ai.functions.OpenAiConfig;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.TrinoException;
import java.net.URI;
import java.util.List;
import java.util.Objects;

/**
 * {@link AbstractAiClient} backed by an OpenAI-style chat completions endpoint.
 *
 * <p>Each completion request is sent as a single {@code user} message to
 * {@code <endpoint>/v1/chat/completions} with a fixed seed of {@code 0}. Every call is wrapped
 * in an OpenTelemetry {@link SpanKind#CLIENT} span annotated with GenAI semantic-convention
 * attributes.
 */
public class OpenAiClient extends AbstractAiClient {
    private static final JsonCodec<ChatRequest> CHAT_REQUEST_CODEC = JsonCodec.jsonCodec(ChatRequest.class);
    private static final JsonCodec<ChatResponse> CHAT_RESPONSE_CODEC = JsonCodec.jsonCodec(ChatResponse.class);
    private static final String CHAT_COMPLETIONS_PATH = "/v1/chat/completions";
    // Fixed seed sent with every request; presumably intended to make results as repeatable
    // as the provider permits -- NOTE(review): confirm intent
    private static final int SEED = 0;

    private final HttpClient httpClient;
    private final Tracer tracer;
    private final URI endpoint;
    private final String apiKey;

    @Inject
    public OpenAiClient(@ForAiClient HttpClient httpClient, Tracer tracer, OpenAiConfig openAiConfig, AiConfig aiConfig) {
        super(aiConfig);
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.tracer = Objects.requireNonNull(tracer, "tracer is null");
        Objects.requireNonNull(openAiConfig, "openAiConfig is null");
        this.endpoint = openAiConfig.getEndpoint();
        this.apiKey = openAiConfig.getApiKey();
    }

    /**
     * Sends {@code prompt} to the chat completions endpoint and returns the first choice's text.
     *
     * @param model the provider model name, sent as the request's {@code model} field
     * @param prompt the text sent as the single {@code user} message
     * @return the {@code content} of the first returned choice, exactly as the provider sent it
     * @throws TrinoException with {@code AI_ERROR} in these cases: the HTTP call or response
     *         decoding fails; the response has no choices or no message; the provider refused
     */
    @Override
    protected String generateCompletion(String model, String prompt) {
        URI uri = HttpUriBuilder.uriBuilderFrom(endpoint).appendPath(CHAT_COMPLETIONS_PATH).build();
        ChatRequest body = new ChatRequest(model, List.of(new ChatRequest.Message("user", prompt)), SEED);
        BodyGenerator bodyGenerator = JsonBodyGenerator.jsonBodyGenerator(CHAT_REQUEST_CODEC, body);
        Request request = Request.Builder.preparePost()
                .setUri(uri)
                .setHeader("Authorization", "Bearer " + apiKey)
                .setHeader("Content-Type", MediaType.JSON_UTF_8.toString())
                .setBodyGenerator(bodyGenerator)
                .build();
        ResponseHandler<ChatResponse, RuntimeException> responseHandler =
                JsonResponseHandler.createJsonResponseHandler(CHAT_RESPONSE_CODEC);

        Span span = tracer.spanBuilder("chat " + model)
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_OPERATION_NAME, "chat")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_SYSTEM, "openai")
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_REQUEST_MODEL, model)
                .setAttribute(GenAiIncubatingAttributes.GEN_AI_REQUEST_SEED, body.seed())
                .setSpanKind(SpanKind.CLIENT)
                .startSpan();
        try (Scope ignored = span.makeCurrent()) {
            ChatResponse response = httpClient.execute(request, responseHandler);
            if (response == null) {
                throw noResponse(uri, model);
            }
            recordResponse(span, response);
            // Validate inside the span so empty/refused responses are also traced as failures
            return extractContent(response, uri, model);
        }
        catch (TrinoException e) {
            // Raised by the response checks above; already descriptive, so propagate as-is
            markFailed(span, e);
            throw e;
        }
        catch (RuntimeException e) {
            markFailed(span, e);
            throw new TrinoException(AiErrorCode.AI_ERROR, "Request to AI provider at %s for model %s failed".formatted(uri, model), e);
        }
        finally {
            span.end();
        }
    }

    // Copies response metadata and token usage onto the span. Null values are skipped because
    // the OpenTelemetry API leaves null attribute values undefined.
    private static void recordResponse(Span span, ChatResponse response) {
        if (response.id() != null) {
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_ID, response.id());
        }
        if (response.model() != null) {
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_RESPONSE_MODEL, response.model());
        }
        if (response.serviceTier() != null) {
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER, response.serviceTier());
        }
        if (response.systemFingerprint() != null) {
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_OPENAI_RESPONSE_SYSTEM_FINGERPRINT, response.systemFingerprint());
        }
        ChatResponse.Usage usage = response.usage();
        if (usage != null) {
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_INPUT_TOKENS, usage.promptTokens());
            span.setAttribute(GenAiIncubatingAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, usage.completionTokens());
        }
    }

    // Returns the first choice's content. Throws a TrinoException when there is no usable
    // choice, or when the provider reported a refusal.
    private static String extractContent(ChatResponse response, URI uri, String model) {
        List<ChatResponse.Choice> choices = response.choices();
        if (choices == null || choices.isEmpty() || choices.getFirst().message() == null) {
            throw noResponse(uri, model);
        }
        ChatResponse.Choice.Message message = choices.getFirst().message();
        if (message.refusal() != null) {
            throw new TrinoException(AiErrorCode.AI_ERROR, "AI provider at %s for model %s refused to generate response: %s".formatted(uri, model, message.refusal()));
        }
        return message.content();
    }

    // Builds the error used when the provider returned nothing usable
    private static TrinoException noResponse(URI uri, String model) {
        return new TrinoException(AiErrorCode.AI_ERROR, "No response from AI provider at %s for model %s".formatted(uri, model));
    }

    // Marks the span as failed and attaches the exception as a span event
    private static void markFailed(Span span, RuntimeException e) {
        span.setStatus(StatusCode.ERROR, e.getMessage());
        span.recordException(e);
    }

    /**
     * JSON request body for the chat completions endpoint.
     *
     * @param model the model name
     * @param messages the conversation; this client always sends a single {@code user} message
     * @param seed the sampling seed sent to the provider
     */
    public record ChatRequest(String model, List<Message> messages, int seed) {

        /** One conversation message: a role (e.g. {@code "user"}) and its text. */
        public record Message(String role, String content) {
        }
    }

    /**
     * Subset of the chat completions response that this client reads. JSON property names are
     * mapped from snake_case (for example {@code service_tier} maps to {@code serviceTier}).
     */
    @JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
    public record ChatResponse(String id, String model, List<Choice> choices, Usage usage, String serviceTier, String systemFingerprint) {

        /** Token accounting, read from {@code prompt_tokens} and {@code completion_tokens}. */
        @JsonNaming(value = PropertyNamingStrategies.SnakeCaseStrategy.class)
        public record Usage(int promptTokens, int completionTokens) {
        }

        /** One generated alternative; only the first is used. */
        public record Choice(Message message) {

            /** Assistant output; {@code refusal} is non-null when the provider declined. */
            public record Message(String content, String refusal) {
            }
        }
    }
}

