/*
 * Decompiled with CFR 0.152.
 */
package com.datasqrl.flinkrunner.stdlib.openai;

import com.datasqrl.flinkrunner.stdlib.vector.FlinkVectorType;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ObjectNode;

/**
 * Computes text embeddings by calling an OpenAI-compatible embeddings HTTP endpoint and converting
 * the first returned embedding into a {@link FlinkVectorType}.
 *
 * <p>The endpoint URL is read from the {@code OPENAI_EMBEDDING_API_URL} environment variable on
 * every request, falling back to {@code https://api.openai.com/v1/embeddings}. The bearer token is
 * read from {@code OPENAI_API_KEY}, also on every request.
 */
public class OpenAiEmbeddings {
    /** Default input-length limit used by {@link #vectorEmbed(String, String)}. */
    private static final int TOKEN_LIMIT = 8192;
    private static final String EMBEDDING_API_URL_ENV = "OPENAI_EMBEDDING_API_URL";
    private static final String API_KEY_ENV = "OPENAI_API_KEY";
    private static final String DEFAULT_EMBEDDING_API_URL = "https://api.openai.com/v1/embeddings";
    private static final ObjectMapper objectMapper = new ObjectMapper();
    private final HttpClient httpClient;

    /** Creates an instance backed by a new default {@link HttpClient}. */
    public OpenAiEmbeddings() {
        this(HttpClient.newHttpClient());
    }

    /**
     * Creates an instance that sends all embedding requests through the given client.
     *
     * @param httpClient the HTTP client to use for requests
     */
    public OpenAiEmbeddings(HttpClient httpClient) {
        this.httpClient = httpClient;
    }

    /**
     * Embeds {@code text} using {@code modelName}, truncating the input to the default limit of
     * {@value #TOKEN_LIMIT} characters.
     *
     * @param text the text to embed; may be {@code null}
     * @param modelName the embedding model name sent to the API; may be {@code null}
     * @return the embedding vector, or {@code null} if {@code text} or {@code modelName} is null
     * @throws IOException if the request fails, the API returns a non-200 status, or the response
     *     cannot be parsed
     * @throws InterruptedException if interrupted while waiting for the HTTP response
     */
    public FlinkVectorType vectorEmbed(String text, String modelName) throws IOException, InterruptedException {
        return this.vectorEmbed(text, modelName, TOKEN_LIMIT);
    }

    /**
     * Embeds {@code text} using {@code modelName}, first truncating the input to at most
     * {@code tokenLimit} UTF-16 characters.
     *
     * <p>Note: despite the parameter name, the limit is applied to characters, not model tokens;
     * it is only an approximation of the model's real token limit.
     *
     * @param text the text to embed; may be {@code null}
     * @param modelName the embedding model name sent to the API; may be {@code null}
     * @param tokenLimit maximum number of characters of {@code text} to send; must be non-negative
     * @return the embedding vector, or {@code null} if {@code text} or {@code modelName} is null
     * @throws IllegalArgumentException if {@code tokenLimit} is negative
     * @throws IOException if the request fails, the API returns a non-200 status, or the response
     *     cannot be parsed
     * @throws InterruptedException if interrupted while waiting for the HTTP response
     */
    public FlinkVectorType vectorEmbed(String text, String modelName, int tokenLimit) throws IOException, InterruptedException {
        // Same null contract as the two-argument overload instead of an NPE in truncateText.
        if (text == null || modelName == null) {
            return null;
        }
        if (tokenLimit < 0) {
            throw new IllegalArgumentException("tokenLimit must be non-negative: " + tokenLimit);
        }
        ObjectNode requestBody = objectMapper.createObjectNode();
        requestBody.put("input", OpenAiEmbeddings.truncateText(text, tokenLimit));
        requestBody.put("model", modelName);
        // NOTE(review): when OPENAI_API_KEY is unset this sends "Bearer null" (unchanged behavior).
        HttpRequest request = HttpRequest.newBuilder()
                .uri(OpenAiEmbeddings.resolveEndpoint())
                .header("Authorization", "Bearer " + System.getenv(API_KEY_ENV))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(requestBody.toString(), StandardCharsets.UTF_8))
                .build();
        HttpResponse<String> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            // Include the body: the API's error payload is what explains the failure.
            throw new IOException(String.format("Failed to get embedding: HTTP status code %d: %s",
                    response.statusCode(), response.body()));
        }
        return OpenAiEmbeddings.parseEmbeddingVector(response.body());
    }

    /** Resolves the embeddings endpoint from the environment, or the OpenAI default. */
    private static URI resolveEndpoint() {
        return URI.create(Optional.ofNullable(System.getenv(EMBEDDING_API_URL_ENV)).orElse(DEFAULT_EMBEDDING_API_URL));
    }

    /**
     * Extracts {@code data[0].embedding} from a JSON embeddings response.
     *
     * @param responseBody the raw JSON response body
     * @return the first embedding as a vector
     * @throws IOException if the body is not valid JSON, lacks a {@code data[0].embedding} array,
     *     or that array contains a non-numeric element
     */
    private static FlinkVectorType parseEmbeddingVector(String responseBody) throws IOException {
        JsonNode jsonResponse = objectMapper.readTree(responseBody);
        // path() yields a MissingNode rather than null, so a malformed response is reported as an
        // IOException below instead of a NullPointerException / ClassCastException.
        JsonNode embeddingNode = jsonResponse == null
                ? null
                : jsonResponse.path("data").path(0).path("embedding");
        if (embeddingNode == null || !embeddingNode.isArray()) {
            throw new IOException("Unexpected embedding response: missing data[0].embedding array");
        }
        ArrayNode embeddingArray = (ArrayNode) embeddingNode;
        double[] embeddingVector = new double[embeddingArray.size()];
        for (int i = 0; i < embeddingVector.length; ++i) {
            JsonNode value = embeddingArray.get(i);
            // asDouble() would silently turn a non-numeric value into 0.0; reject it instead.
            if (!value.isNumber()) {
                throw new IOException("Unexpected non-numeric embedding value at index " + i);
            }
            embeddingVector[i] = value.asDouble();
        }
        return new FlinkVectorType(embeddingVector);
    }

    /**
     * Truncates {@code text} to at most {@code maxChars} UTF-16 code units without splitting a
     * surrogate pair.
     *
     * @param text the input text, non-null
     * @param maxChars the maximum length, non-negative
     * @return {@code text} itself if short enough, otherwise a prefix of it
     */
    private static String truncateText(String text, int maxChars) {
        if (text.length() <= maxChars) {
            return text;
        }
        int end = maxChars;
        // Cutting between a high and low surrogate would leave an unpaired surrogate, which is
        // replaced with '?' when the request body is encoded as UTF-8.
        if (end > 0 && Character.isHighSurrogate(text.charAt(end - 1))) {
            --end;
        }
        return text.substring(0, end);
    }
}

