/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml.aws;

import apoc.Description;
import apoc.Extended;
import apoc.ml.aws.AWSConfig;
import apoc.ml.aws.AwsSignatureV4Generator;
import apoc.ml.aws.SageMakerConfig;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
public class SageMaker {
    @Context
    public URLAccessChecker urlAccessChecker;

    @Procedure(value="apoc.ml.sagemaker.custom")
    @Description(value="apoc.ml.sagemaker.chat(body, $conf) - To create a customizable SageMaker call")
    public Stream<MapResult> custom(@Name(value="body") Object body, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        SageMakerConfig conf = new SageMakerConfig(configuration);
        return this.executeRequestReturningMap(body, conf).map(MapResult::new);
    }

    @Procedure(value="apoc.ml.sagemaker.chat")
    @Description(value="apoc.ml.sagemaker.chat(messages, $conf) - Prompts the chat completion API")
    public Stream<MapResult> chatCompletion(@Name(value="messages") List<Map<String, String>> messages, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("endpointName", "Endpoint-Distilbart-xsum-1-1-1");
        config.putIfAbsent("headers", Util.map((Object[])new String[]{"Content-Type", "application/x-text"}));
        SageMakerConfig conf = new SageMakerConfig(config);
        return messages.stream().flatMap(message -> {
            Map body = message.containsKey("content") ? message.get("content") : message;
            return this.executeRequestReturningMap(body, conf).map(MapResult::new);
        });
    }

    @Procedure(value="apoc.ml.sagemaker.completion")
    @Description(value="apoc.ml.sagemaker.completion(prompt, $conf) - Prompts the completion API")
    public Stream<MapResult> completion(@Name(value="prompt") String prompt, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("endpointName", "Endpoint-GPT-2-1");
        config.putIfAbsent("headers", Map.of("Content-Type", "application/x-text"));
        SageMakerConfig conf = new SageMakerConfig(config);
        return this.executeRequestReturningMap(prompt, conf).map(MapResult::new);
    }

    @Procedure(value="apoc.ml.sagemaker.embedding")
    @Description(value="apoc.ml.sagemaker.embedding([texts], $configuration) - Returns the embeddings for a given text")
    public Stream<EmbeddingResult> embedding(@Name(value="texts") List<String> texts, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("endpointName", "Endpoint-Jina-Embeddings-v2-Base-en-1");
        config.putIfAbsent("jsonPath", "data[*]");
        SageMakerConfig conf = new SageMakerConfig(config);
        List<Map> inputs = texts.stream().map(text -> Map.of("text", text)).toList();
        Map<String, List<Map>> data2 = Map.of("data", inputs);
        AtomicInteger idx = new AtomicInteger();
        return this.executeRequestCommon(data2, conf).flatMap(v -> ((List)v).stream()).map(i -> {
            int index = idx.getAndIncrement();
            return new EmbeddingResult(index, (String)texts.get(index), (List)i.get("embedding"));
        });
    }

    private Stream<Map<String, Object>> executeRequestReturningMap(Object body, AWSConfig config) {
        return this.executeRequestCommon(body, config).map(i -> (Map)i);
    }

    private Stream<Object> executeRequestCommon(Object body, AWSConfig conf) {
        try {
            String string;
            String bodyString = body instanceof String ? (string = (String)body) : JsonUtil.OBJECT_MAPPER.writeValueAsString(body);
            Map<String, Object> headers = conf.getHeaders();
            headers.putIfAbsent("Content-Type", "application/json");
            headers.putIfAbsent("accept", "*/*");
            if (!headers.containsKey("Authorization")) {
                AwsSignatureV4Generator.calculateAuthorizationHeaders(conf, bodyString, headers, "sagemaker");
            }
            return JsonUtil.loadJson((Object)conf.getEndpoint(), conf.getHeaders(), (String)bodyString, (String)conf.getJsonPath(), (boolean)true, List.of(), (URLAccessChecker)this.urlAccessChecker);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public record EmbeddingResult(long index, String text, List<Double> embedding) {
    }
}

