/*
 * Decompiled with CFR 0.152.
 */
package com.datasqrl.flinkrunner.stdlib.openai;

import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import lombok.Generated;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode;

/**
 * Minimal client for an OpenAI-style chat-completions HTTP endpoint.
 *
 * <p>The endpoint is read from the {@code OPENAI_COMPLETIONS_API_URL} environment variable and
 * falls back to {@code https://api.openai.com/v1/chat/completions}. The bearer token is read from
 * the {@code OPENAI_API_KEY} environment variable on every call.
 */
public class OpenAiCompletions {
    /** Temperature sent when the request does not specify one. */
    private static final double TEMPERATURE_DEFAULT = 1.0;
    /** top_p sent when the request does not specify one. */
    private static final double TOP_P_DEFAULT = 1.0;
    /** Environment variable that overrides the completions endpoint. */
    private static final String API_URL_ENV = "OPENAI_COMPLETIONS_API_URL";
    /** Environment variable holding the API key used for the Authorization header. */
    private static final String API_KEY_ENV = "OPENAI_API_KEY";
    /** Endpoint used when {@link #API_URL_ENV} is not set. */
    private static final String DEFAULT_API_URL = "https://api.openai.com/v1/chat/completions";
    private static final ObjectMapper objectMapper = new ObjectMapper();
    private final HttpClient httpClient;

    /** Creates a client backed by a default {@link HttpClient}. */
    public OpenAiCompletions() {
        this(HttpClient.newHttpClient());
    }

    /**
     * Creates a client that sends its requests through the given HTTP client.
     *
     * @param httpClient the HTTP client used for every call
     */
    public OpenAiCompletions(HttpClient httpClient) {
        this.httpClient = httpClient;
    }

    /**
     * Sends the request to the completions endpoint and returns the text of the first choice.
     *
     * @param request the completion parameters
     * @return the trimmed content of the first choice, or {@code null} when the request has no
     *     prompt or no model name (no HTTP call is made in that case)
     * @throws IOException if the HTTP exchange fails, the status code is not 200, or the response
     *     body does not contain the expected completion content
     * @throws InterruptedException if the thread is interrupted while waiting for the response
     */
    public String callCompletions(CompletionsRequest request) throws IOException, InterruptedException {
        if (request.getPrompt() == null || request.getModelName() == null) {
            return null;
        }
        HttpRequest httpRequest = OpenAiCompletions.buildHttpRequest(this.createRequestBody(request));
        HttpResponse<String> response = this.httpClient.send(httpRequest, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() == 200) {
            return this.extractContent(response.body());
        }
        throw new IOException(String.format("Failed to get completion: HTTP status code %d Message: %s", response.statusCode(), response.body()));
    }

    /** Builds the JSON POST request, resolving endpoint and API key from the environment. */
    private static HttpRequest buildHttpRequest(ObjectNode requestBody) {
        String endpoint = Optional.ofNullable(System.getenv(API_URL_ENV)).orElse(DEFAULT_API_URL);
        // NOTE: when OPENAI_API_KEY is unset the header value becomes the literal "Bearer null".
        String authorization = "Bearer " + System.getenv(API_KEY_ENV);
        return HttpRequest.newBuilder()
                .uri(URI.create(endpoint))
                .header("Authorization", authorization)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString(), StandardCharsets.UTF_8))
                .build();
    }

    /**
     * Translates the request into the JSON payload sent to the endpoint.
     *
     * <p>When JSON output is required, a {@code response_format} entry is added (with the parsed
     * schema if one was supplied) and a system message asking for minified JSON precedes the user
     * prompt.
     *
     * @throws RuntimeException if the supplied JSON schema cannot be parsed
     */
    private ObjectNode createRequestBody(CompletionsRequest request) {
        ObjectNode requestBody = objectMapper.createObjectNode();
        requestBody.put("model", request.getModelName());

        ArrayNode messagesArray = objectMapper.createArrayNode();
        if (request.isRequireJsonOutput()) {
            ObjectNode responseFormat = requestBody.putObject("response_format");
            if (request.getJsonSchema() != null) {
                JsonNode schemaNode;
                try {
                    schemaNode = objectMapper.readTree(request.getJsonSchema());
                }
                catch (JsonProcessingException e) {
                    throw new RuntimeException("Failed to parse JSON schema", e);
                }
                responseFormat.put("type", "json_schema");
                ObjectNode jsonSchema = responseFormat.putObject("json_schema");
                jsonSchema.put("name", "extract_json_schema_name");
                jsonSchema.put("strict", true);
                jsonSchema.set("schema", schemaNode);
            } else {
                responseFormat.put("type", "json_object");
            }
            messagesArray.add(OpenAiCompletions.createMessage("system", "You are a helpful assistant designed to output minified JSON."));
        }
        messagesArray.add(OpenAiCompletions.createMessage("user", request.getPrompt()));
        requestBody.set("messages", messagesArray);

        // Fall back to the class-level defaults when the caller left sampling parameters unset.
        Double requestedTemperature = request.getTemperature();
        double temperature = requestedTemperature == null ? TEMPERATURE_DEFAULT : requestedTemperature;
        requestBody.put("temperature", temperature);
        Double requestedTopP = request.getTopP();
        double topP = requestedTopP == null ? TOP_P_DEFAULT : requestedTopP;
        requestBody.put("top_p", topP);
        requestBody.put("n", 1);
        if (request.getMaxOutputTokens() != null) {
            requestBody.put("max_tokens", request.getMaxOutputTokens());
        }
        return requestBody;
    }

    /**
     * Reads {@code choices[0].message.content} from the response body.
     *
     * @return the trimmed content text
     * @throws IOException if the body is not valid JSON or the content field is missing or null
     */
    private String extractContent(String jsonResponse) throws IOException {
        JsonNode root = objectMapper.readTree(jsonResponse);
        // path(...) yields a "missing" node instead of null, so an unexpected response shape is
        // reported as an IOException rather than surfacing as a NullPointerException.
        JsonNode content = root == null ? null : root.path("choices").path(0).path("message").path("content");
        if (content == null || content.isMissingNode() || content.isNull()) {
            throw new IOException(String.format("Completion response did not contain choices[0].message.content: %s", jsonResponse));
        }
        return content.asText().trim();
    }

    /** Creates a single chat message object with the given role and content. */
    private static ObjectNode createMessage(String role, String prompt) {
        ObjectNode message = objectMapper.createObjectNode();
        message.put("role", role);
        message.put("content", prompt);
        return message;
    }

    /**
     * Immutable set of parameters for one completion call; create instances through
     * {@link #builder()}.
     */
    public static class CompletionsRequest {
        /** Text sent as the user message. */
        private final String prompt;
        /** Value sent as the {@code model} field. */
        private final String modelName;
        /** Whether a JSON {@code response_format} and system message are added to the request. */
        private final boolean requireJsonOutput;
        /** Optional JSON schema (as a JSON string) used when JSON output is required. */
        private final String jsonSchema;
        /** Optional value for {@code max_tokens}; omitted from the request when null. */
        private final Integer maxOutputTokens;
        /** Optional sampling temperature; null selects the client default. */
        private final Double temperature;
        /** Optional top_p value; null selects the client default. */
        private final Double topP;

        @Generated
        CompletionsRequest(String prompt, String modelName, boolean requireJsonOutput, String jsonSchema, Integer maxOutputTokens, Double temperature, Double topP) {
            this.prompt = prompt;
            this.modelName = modelName;
            this.requireJsonOutput = requireJsonOutput;
            this.jsonSchema = jsonSchema;
            this.maxOutputTokens = maxOutputTokens;
            this.temperature = temperature;
            this.topP = topP;
        }

        /** Returns a new, empty builder. */
        @Generated
        public static CompletionsRequestBuilder builder() {
            return new CompletionsRequestBuilder();
        }

        @Generated
        public String getPrompt() {
            return this.prompt;
        }

        @Generated
        public String getModelName() {
            return this.modelName;
        }

        @Generated
        public boolean isRequireJsonOutput() {
            return this.requireJsonOutput;
        }

        @Generated
        public String getJsonSchema() {
            return this.jsonSchema;
        }

        @Generated
        public Integer getMaxOutputTokens() {
            return this.maxOutputTokens;
        }

        @Generated
        public Double getTemperature() {
            return this.temperature;
        }

        @Generated
        public Double getTopP() {
            return this.topP;
        }

        /** Fluent builder for {@link CompletionsRequest}; every setter returns this builder. */
        @Generated
        public static class CompletionsRequestBuilder {
            @Generated
            private String prompt;
            @Generated
            private String modelName;
            @Generated
            private boolean requireJsonOutput;
            @Generated
            private String jsonSchema;
            @Generated
            private Integer maxOutputTokens;
            @Generated
            private Double temperature;
            @Generated
            private Double topP;

            @Generated
            CompletionsRequestBuilder() {
            }

            @Generated
            public CompletionsRequestBuilder prompt(String prompt) {
                this.prompt = prompt;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder modelName(String modelName) {
                this.modelName = modelName;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder requireJsonOutput(boolean requireJsonOutput) {
                this.requireJsonOutput = requireJsonOutput;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder jsonSchema(String jsonSchema) {
                this.jsonSchema = jsonSchema;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder maxOutputTokens(Integer maxOutputTokens) {
                this.maxOutputTokens = maxOutputTokens;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder temperature(Double temperature) {
                this.temperature = temperature;
                return this;
            }

            @Generated
            public CompletionsRequestBuilder topP(Double topP) {
                this.topP = topP;
                return this;
            }

            /** Creates the request from the values set so far; unset fields keep their Java defaults. */
            @Generated
            public CompletionsRequest build() {
                return new CompletionsRequest(this.prompt, this.modelName, this.requireJsonOutput, this.jsonSchema, this.maxOutputTokens, this.temperature, this.topP);
            }

            @Override
            @Generated
            public String toString() {
                return "OpenAiCompletions.CompletionsRequest.CompletionsRequestBuilder(prompt=" + this.prompt + ", modelName=" + this.modelName + ", requireJsonOutput=" + this.requireJsonOutput + ", jsonSchema=" + this.jsonSchema + ", maxOutputTokens=" + this.maxOutputTokens + ", temperature=" + this.temperature + ", topP=" + this.topP + ")";
            }
        }
    }
}

