/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.OpenAIRequestHandler;
import apoc.result.MapResult;
import apoc.util.ExtendedMapUtils;
import apoc.util.ExtendedUtil;
import apoc.util.JsonUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * APOC procedures that send requests to the OpenAI REST API, or to another provider selected
 * through the {@code apiType} configuration key (see {@link OpenAIRequestHandler.Type}), to
 * compute embeddings, text completions and chat completions.
 */
@Extended
public class OpenAI {
    public static final String API_TYPE_CONF_KEY = "apiType";
    public static final String APIKEY_CONF_KEY = "apiKey";
    public static final String JSON_PATH_CONF_KEY = "jsonPath";
    public static final String PATH_CONF_KEY = "path";
    public static final String GPT_4O_MODEL = "gpt-4o";
    public static final String FAIL_ON_ERROR_CONF = "failOnError";
    public static final String ENABLE_BACK_OFF_RETRIES_CONF_KEY = "enableBackOffRetries";
    public static final String ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY = "exponentialBackoff";
    public static final String BACK_OFF_RETRIES_CONF_KEY = "backOffRetries";
    public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";

    /** Message of the exception raised for null, blank or empty procedure input. */
    private static final String INVALID_INPUT_MESSAGE =
            "Null, blank or empty input provided. Please specify a valid input";

    @Context
    public ApocConfig apocConfig;

    @Context
    public URLAccessChecker urlAccessChecker;

    /**
     * Builds a JSON request for the configured provider, sends it and returns the values extracted
     * from the response.
     *
     * @param apiKey fallback API key, used when neither {@code configuration.apiKey} nor the
     *     {@code apoc.openai.key} setting provides one
     * @param configuration user configuration; apart from {@code endpoint}, {@code apiType},
     *     {@code apiVersion} and {@code apiKey}, its entries become the JSON payload. The map is
     *     copied and never modified.
     * @param path default URL path, overridable with {@code configuration.path}
     * @param model model put in the payload when the configuration does not specify one
     * @param key payload key under which {@code inputs} is sent
     * @param inputs request input (prompt, texts or messages)
     * @param jsonPath default JSON path extracted from the response, overridable with
     *     {@code configuration.jsonPath}
     * @param apocConfig APOC settings used for defaults (API key, API type, URL)
     * @param urlAccessChecker checker passed to the HTTP loader
     * @return the values extracted from the response
     * @throws IllegalArgumentException if no API key is available or {@code apiType} is unknown
     * @throws JsonProcessingException if the payload cannot be serialized
     * @throws MalformedURLException if the resulting URL is invalid
     */
    static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
        // Work on a copy: provider handling below adds defaults (path, jsonPath), and the
        // caller's map may be immutable or shared.
        Map<String, Object> conf = configuration == null ? new HashMap<>() : new HashMap<>(configuration);

        String resolvedApiKey = (String) conf.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString("apoc.openai.key", apiKey));
        boolean enableBackOffRetries = Util.toBoolean(conf.get(ENABLE_BACK_OFF_RETRIES_CONF_KEY));
        Integer backOffRetries = Util.toInteger(conf.getOrDefault(BACK_OFF_RETRIES_CONF_KEY, 5));
        boolean exponentialBackoff = Util.toBoolean(conf.get(ENABLE_EXPONENTIAL_BACK_OFF_CONF_KEY));
        if (resolvedApiKey == null || resolvedApiKey.isBlank()) {
            throw new IllegalArgumentException("API Key must not be empty");
        }

        String apiTypeString = (String) conf.getOrDefault(API_TYPE_CONF_KEY, apocConfig.getString("apoc.ml.openai.type", OpenAIRequestHandler.Type.OPENAI.name()));
        OpenAIRequestHandler.Type type = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH));

        // The JSON payload is the configuration minus the connection-related keys.
        Map<String, Object> configForPayload = new HashMap<>(conf);
        Stream.of("endpoint", API_TYPE_CONF_KEY, "apiVersion", APIKEY_CONF_KEY).forEach(configForPayload::remove);
        Map<String, Object> headers = new HashMap<>();
        handleAPIProvider(type, conf, path, model, key, inputs, configForPayload, headers);

        // Read after handleAPIProvider, which may have supplied provider-specific defaults.
        String resolvedPath = (String) conf.getOrDefault(PATH_CONF_KEY, path);
        String resolvedJsonPath = (String) conf.getOrDefault(JSON_PATH_CONF_KEY, jsonPath);

        OpenAIRequestHandler apiType = type.get();
        headers.put("Content-Type", "application/json");
        apiType.addApiKey(headers, resolvedApiKey);
        String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(configForPayload);
        String url = apiType.getFullUrl(resolvedPath, conf, apocConfig);

        return ExtendedUtil.withBackOffRetries(
                () -> JsonUtil.loadJson(url, headers, payload, resolvedJsonPath, true, List.of(), urlAccessChecker),
                enableBackOffRetries,
                backOffRetries,
                exponentialBackoff,
                exception -> {
                    // Errors whose message does not mention 429 (rate limiting) are rethrown;
                    // presumably withBackOffRetries retries the remaining ones. Guard against a
                    // null message, which would otherwise surface as an NPE.
                    String message = exception.getMessage();
                    if (message == null || !message.contains("429")) {
                        throw new RuntimeException(exception);
                    }
                });
    }

    /**
     * Applies provider-specific adjustments to the request before it is sent.
     *
     * @param type the selected provider
     * @param configuration mutable working copy of the configuration; may receive default
     *     {@code path} / {@code jsonPath} entries
     * @param path the default path requested by the caller
     * @param model default model for providers that take it from the payload
     * @param key payload key for {@code inputs}
     * @param inputs request input
     * @param configForPayload mutable JSON payload
     * @param headers mutable HTTP headers
     */
    private static void handleAPIProvider(OpenAIRequestHandler.Type type, Map<String, Object> configuration, String path, String model, String key, Object inputs, Map<String, Object> configForPayload, Map<String, Object> headers) {
        switch (type) {
            case MIXEDBREAD_CUSTOM -> {
                // Payload is sent exactly as configured: no model or input is injected.
            }
            case HUGGINGFACE -> {
                configForPayload.putIfAbsent("inputs", inputs);
                configuration.putIfAbsent(PATH_CONF_KEY, "");
                headers.putIfAbsent("method", "POST");
                configuration.putIfAbsent(JSON_PATH_CONF_KEY, "$[0]");
            }
            case ANTHROPIC -> {
                headers.putIfAbsent("anthropic-version", configuration.getOrDefault("anthropic-version", "2023-06-01"));
                if ("completions".equals(path)) {
                    configuration.putIfAbsent(PATH_CONF_KEY, "complete");
                    configForPayload.putIfAbsent("max_tokens_to_sample", 1000);
                    configForPayload.putIfAbsent("model", "claude-2.1");
                } else {
                    configuration.putIfAbsent(PATH_CONF_KEY, "messages");
                    configForPayload.putIfAbsent("max_tokens", 1000);
                    configForPayload.putIfAbsent("model", "claude-3-5-sonnet-20240620");
                }
                // Sent as a header above, so it must not also appear in the body.
                configForPayload.remove("anthropic-version");
                configForPayload.put(key, inputs);
            }
            default -> {
                configForPayload.putIfAbsent("model", model);
                configForPayload.put(key, inputs);
            }
        }
    }

    /**
     * Returns one embedding per non-blank input text.
     *
     * @param texts the texts to embed; blank entries are dropped
     * @param apiKey fallback API key
     * @param configuration request configuration; {@code failOnError} (default {@code true})
     *     controls whether invalid input raises an error or yields an empty result
     * @return the embedding rows
     */
    @SuppressWarnings("unchecked")
    @Procedure(value="apoc.ml.openai.embedding")
    @Description(value="apoc.ml.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text")
    public Stream<EmbeddingResult> getEmbedding(@Name(value="texts") List<String> texts, @Name(value="api_key") String apiKey, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        boolean failOnError = isFailOnError(configuration);
        if (checkNullInput(texts, failOnError)) {
            return Stream.empty();
        }
        List<String> nonBlankTexts = texts.stream().filter(StringUtils::isNotBlank).toList();
        if (checkEmptyInput(nonBlankTexts, failOnError)) {
            return Stream.empty();
        }
        return getEmbeddingResult(nonBlankTexts, apiKey, configuration, this.apocConfig, this.urlAccessChecker,
                (map, text) -> new EmbeddingResult(((Number) map.get("index")).longValue(), text, (List<Double>) map.get("embedding")),
                text -> new EmbeddingResult(-1L, text, List.of()));
    }

    /**
     * Requests embeddings for the non-null entries of {@code texts}. Null entries are not sent and
     * are mapped through {@code nullMapping} instead.
     *
     * @param texts input texts; must not be {@code null} itself
     * @param embeddingMapping maps one response entry (a map with at least {@code index}) and its
     *     source text to a result
     * @param nullMapping maps a null input text to a result
     * @return results for the non-null texts followed by results for the null texts
     * @throws RuntimeException if {@code texts} is {@code null}
     */
    static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker, BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
        if (texts == null) {
            throw new RuntimeException(INVALID_INPUT_MESSAGE);
        }
        // partitioningBy always yields both keys, unlike groupingBy (which returned null here
        // when every text was null).
        Map<Boolean, List<String>> partitioned = texts.stream().collect(Collectors.partitioningBy(Objects::nonNull));
        List<String> nonNullTexts = partitioned.get(true);
        Stream<T> nullResults = partitioned.get(false).stream().map(nullMapping);
        if (nonNullTexts.isEmpty()) {
            // Nothing to embed: avoid sending a request with an empty/null input.
            return nullResults;
        }
        Stream<T> embeddingResults = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker)
                .flatMap(response -> ((List<?>) response).stream())
                .map(item -> {
                    Map<?, ?> entry = (Map<?, ?>) item;
                    // "index" refers back to the position in the submitted (non-null) list.
                    int index = ((Number) entry.get("index")).intValue();
                    return embeddingMapping.apply(entry, nonNullTexts.get(index));
                });
        return Stream.concat(embeddingResults, nullResults);
    }

    /**
     * Sends {@code prompt} to the text completion endpoint.
     *
     * @return the response rows; empty when the prompt is blank and {@code failOnError} is false
     */
    @SuppressWarnings("unchecked")
    @Procedure(value="apoc.ml.openai.completion")
    @Description(value="apoc.ml.openai.completion(prompt, api_key, configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name(value="prompt") String prompt, @Name(value="api_key") String apiKey, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        boolean failOnError = isFailOnError(configuration);
        if (checkBlankInput(prompt, failOnError)) {
            return Stream.empty();
        }
        return executeRequest(apiKey, configuration, "completions", "gpt-3.5-turbo-instruct", "prompt", prompt, "$", this.apocConfig, this.urlAccessChecker)
                .map(response -> new MapResult((Map<String, Object>) response));
    }

    /**
     * Sends {@code messages} to the chat completion endpoint; empty message maps are dropped.
     *
     * @return the response rows; empty when no valid message remains and {@code failOnError} is false
     */
    @SuppressWarnings("unchecked")
    @Procedure(value="apoc.ml.openai.chat")
    @Description(value="apoc.ml.openai.chat(messages, api_key, configuration) - prompts the chat completion API")
    public Stream<MapResult> chatCompletion(@Name(value="messages") List<Map<String, Object>> messages, @Name(value="api_key") String apiKey, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        boolean failOnError = isFailOnError(configuration);
        if (checkNullInput(messages, failOnError)) {
            return Stream.empty();
        }
        List<Map<String, Object>> nonEmptyMessages = messages.stream().filter(ExtendedMapUtils::isNotEmpty).toList();
        if (checkEmptyInput(nonEmptyMessages, failOnError)) {
            return Stream.empty();
        }
        return executeRequest(apiKey, configuration, "chat/completions", GPT_4O_MODEL, "messages", nonEmptyMessages, "$", this.apocConfig, this.urlAccessChecker)
                .map(response -> new MapResult((Map<String, Object>) response));
    }

    /** Reads {@code failOnError} from the configuration, defaulting to {@code true}. */
    private static boolean isFailOnError(Map<String, Object> configuration) {
        return Util.toBoolean(configuration.getOrDefault(FAIL_ON_ERROR_CONF, true));
    }

    /** Returns {@code true} (or throws when {@code failOnError}) if {@code input} is null. */
    static boolean checkNullInput(Object input, boolean failOnError) {
        return checkInput(failOnError, () -> Objects.isNull(input));
    }

    /** Returns {@code true} (or throws when {@code failOnError}) if {@code input} is empty. */
    static boolean checkEmptyInput(Collection<?> input, boolean failOnError) {
        return checkInput(failOnError, () -> input.isEmpty());
    }

    /** Returns {@code true} (or throws when {@code failOnError}) if {@code input} is blank. */
    static boolean checkBlankInput(String input, boolean failOnError) {
        return checkInput(failOnError, () -> StringUtils.isBlank(input));
    }

    /**
     * Evaluates an invalid-input check.
     *
     * @param failOnError whether an invalid input raises an exception
     * @param checkFunction returns {@code true} when the input is invalid
     * @return {@code true} if the input is invalid and {@code failOnError} is false
     * @throws RuntimeException if the input is invalid and {@code failOnError} is true
     */
    private static boolean checkInput(boolean failOnError, Supplier<Boolean> checkFunction) {
        if (!checkFunction.get()) {
            return false;
        }
        if (failOnError) {
            throw new RuntimeException(INVALID_INPUT_MESSAGE);
        }
        return true;
    }

    /** One embedding row returned by {@code apoc.ml.openai.embedding}. */
    public static class EmbeddingResult {
        /** Position of the text in the submitted list, or -1 for a null text. */
        public final long index;
        public final String text;
        public final List<Double> embedding;

        public EmbeddingResult(long index, String text, List<Double> embedding) {
            this.index = index;
            this.text = text;
            this.embedding = embedding;
        }
    }
}

