/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml.aws;

import apoc.Description;
import apoc.Extended;
import apoc.ml.aws.AWSConfig;
import apoc.ml.aws.AwsSignatureV4Generator;
import apoc.ml.aws.BedrockGetModelsConfig;
import apoc.ml.aws.BedrockInvokeConfig;
import apoc.ml.aws.BedrockInvokeResult;
import apoc.ml.aws.ModelItemResult;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import apoc.util.Util;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedures ({@code apoc.ml.bedrock.*}) that call Amazon Bedrock over HTTP.
 *
 * <p>Each request goes to the endpoint described by the procedure's {@link AWSConfig}. Unless the
 * caller supplies an {@code Authorization} header, the request is signed via
 * {@link AwsSignatureV4Generator} for the {@code bedrock} service. Maps passed in by callers
 * (configuration, bodies, messages) are copied before being modified, so caller input is never
 * mutated and immutable maps are accepted.
 */
@Extended
public class Bedrock {
    /** Injected by Neo4j and passed through to {@link JsonUtil#loadJson} for every request. */
    @Context
    public URLAccessChecker urlAccessChecker;

    /** Default model of {@code apoc.ml.bedrock.completion}. */
    public static final String JURASSIC_2_ULTRA = "ai21.j2-ultra-v1";
    /** Default model of {@code apoc.ml.bedrock.embedding}. */
    public static final String TITAN_EMBED_TEXT = "amazon.titan-embed-text-v1";
    /** Default model of {@code apoc.ml.bedrock.chat}. */
    public static final String ANTHROPIC_CLAUDE_V2 = "anthropic.claude-v2";
    /** Default model of {@code apoc.ml.bedrock.image}. */
    public static final String STABILITY_STABLE_DIFFUSION_XL = "stability.stable-diffusion-xl-v0";

    /** Message of the exception thrown when a required procedure input is {@code null}. */
    private static final String NULL_INPUT_ERROR =
            "Null, blank or empty input provided. Please specify a valid input";

    /**
     * Lists the models returned by the endpoint configured through {@link BedrockGetModelsConfig}.
     *
     * @param configuration procedure configuration; {@code jsonPath} defaults to
     *     {@code modelSummaries[*]}, and each element of the extracted list becomes one row
     * @return one {@link ModelItemResult} per extracted element
     */
    @Procedure(value="apoc.ml.bedrock.list")
    @Description(value="apoc.ml.bedrock.list($conf) - lists the available models")
    @SuppressWarnings("unchecked") // loadJson yields untyped JSON values
    public Stream<ModelItemResult> list(@Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("jsonPath", "modelSummaries[*]");
        BedrockGetModelsConfig conf = new BedrockGetModelsConfig(config);
        // The list call carries no request body.
        return this.executeRequestCommon(null, conf)
                .flatMap(i -> ((List<Map<String, Object>>) i).stream())
                .map(ModelItemResult::new);
    }

    /**
     * Sends a caller-built request body to the endpoint described by {@code configuration}.
     *
     * @param body request payload; entries of the configured body override same-named keys.
     *     When {@code null}, no payload is sent.
     * @param configuration procedure configuration, used as-is
     * @return one row per JSON result
     */
    @Procedure(value="apoc.ml.bedrock.custom")
    @Description(value="To create a customizable bedrock call")
    public Stream<MapResult> custom(@Name(value="body") Map<String, Object> body, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        BedrockInvokeConfig conf = new BedrockInvokeConfig(configuration);
        return this.executeRequestReturningMap(body, conf).map(MapResult::new);
    }

    /**
     * Sends each message as a separate request, defaulting the model to
     * {@link #ANTHROPIC_CLAUDE_V2}.
     *
     * <p>Each message is copied before it is adjusted. In OpenAI-compatible mode the copy is
     * rewritten into a {@code prompt}. {@code max_tokens_to_sample} defaults to 200.
     *
     * @param messages request bodies, one per call; must not be {@code null}
     * @param configuration procedure configuration
     * @return the rows of all requests, in message order
     * @throws RuntimeException if {@code messages} is {@code null}
     */
    @Procedure(value="apoc.ml.bedrock.chat")
    @Description(value="apoc.ml.bedrock.chat(messages, $conf) - prompts the completion API")
    public Stream<MapResult> chatCompletion(@Name(value="messages") List<Map<String, Object>> messages, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        if (messages == null) {
            throw new RuntimeException(NULL_INPUT_ERROR);
        }
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("model", ANTHROPIC_CLAUDE_V2);
        BedrockInvokeConfig conf = new BedrockInvokeConfig(config);
        return messages.stream().flatMap(message -> {
            // Work on a copy: the caller's map must not be altered and may be immutable.
            Map<String, Object> request = new HashMap<String, Object>(message);
            if (conf.isOpenAICompatible()) {
                this.transformOpenAiToBedrockRequestBody(request);
            }
            request.putIfAbsent("max_tokens_to_sample", 200);
            return this.executeRequestReturningMap(request, conf).map(MapResult::new);
        });
    }

    /**
     * Rewrites an OpenAI-style message in place into a body holding only a {@code prompt}.
     *
     * <p>The prompt is the message's {@code content}, wrapped in the {@code "\n\nHuman:"} and
     * {@code "\n\nAssistant:"} markers when they are missing. All other keys (e.g. {@code role})
     * are dropped.
     *
     * @param message mutable message map; {@code content} is expected to be a String
     */
    private void transformOpenAiToBedrockRequestBody(Map<String, Object> message) {
        String content = (String) message.get("content");
        content = StringUtils.prependIfMissing(content, "\n\nHuman:");
        content = StringUtils.appendIfMissing(content, "\n\nAssistant:");
        message.clear();
        message.put("prompt", content);
    }

    /**
     * Sends a single {@code prompt}, defaulting the model to {@link #JURASSIC_2_ULTRA}.
     *
     * @param prompt prompt text; must not be {@code null}
     * @param configuration procedure configuration
     * @return one row per JSON result
     * @throws RuntimeException if {@code prompt} is {@code null}
     */
    @Procedure(value="apoc.ml.bedrock.completion")
    @Description(value="apoc.ml.bedrock.completion(prompt, $conf) - prompts the completion API")
    public Stream<MapResult> completion(@Name(value="prompt") String prompt, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        if (prompt == null) {
            throw new RuntimeException(NULL_INPUT_ERROR);
        }
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("model", JURASSIC_2_ULTRA);
        BedrockInvokeConfig conf = new BedrockInvokeConfig(config);
        Map<String, Object> body = Util.map("prompt", prompt);
        return this.executeRequestReturningMap(body, conf).map(MapResult::new);
    }

    /**
     * Requests one embedding per text, defaulting the model to {@link #TITAN_EMBED_TEXT}.
     *
     * @param texts texts to embed, one request each; the list must not be {@code null}
     * @param configuration procedure configuration
     * @return one embedding row per JSON result, tagged with its source text
     * @throws RuntimeException if {@code texts} is {@code null}
     */
    @Procedure(value="apoc.ml.bedrock.embedding")
    @Description(value="apoc.ml.bedrock.embedding([texts], $configuration) - returns the embeddings for a given text")
    public Stream<BedrockInvokeResult.Embedding> embedding(@Name(value="texts") List<String> texts, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        if (texts == null) {
            throw new RuntimeException(NULL_INPUT_ERROR);
        }
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("model", TITAN_EMBED_TEXT);
        BedrockInvokeConfig conf = new BedrockInvokeConfig(config);
        return texts.stream().flatMap(text -> {
            Map<String, Object> body = Util.map("inputText", text);
            return this.executeRequestReturningMap(body, conf).map(i -> BedrockInvokeResult.Embedding.from(i, text));
        });
    }

    /**
     * Sends an image request, defaulting the model to {@link #STABILITY_STABLE_DIFFUSION_XL} and
     * {@code jsonPath} to {@code $.artifacts[0]}.
     *
     * @param body request payload; must not be {@code null}
     * @param configuration procedure configuration; copied, never modified
     * @return one image row per JSON result
     * @throws RuntimeException if {@code body} is {@code null}
     */
    @Procedure(value="apoc.ml.bedrock.image")
    @Description(value="apoc.ml.bedrock.image(body, $conf) - sends an image request to the configured model")
    public Stream<BedrockInvokeResult.Image> image(@Name(value="body") Map<String, Object> body, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        if (body == null) {
            throw new RuntimeException(NULL_INPUT_ERROR);
        }
        // Copy like the other procedures do: the caller's map must not be mutated and may be immutable.
        HashMap<String, Object> config = new HashMap<String, Object>(configuration);
        config.putIfAbsent("model", STABILITY_STABLE_DIFFUSION_XL);
        config.putIfAbsent("jsonPath", "$.artifacts[0]");
        BedrockInvokeConfig conf = new BedrockInvokeConfig(config);
        return this.executeRequestReturningMap(body, conf).map(BedrockInvokeResult.Image::from);
    }

    /**
     * Runs the request and exposes each extracted JSON value as a map.
     * NOTE(review): assumes the configured jsonPath always selects JSON objects.
     */
    @SuppressWarnings("unchecked")
    private Stream<Map<String, Object>> executeRequestReturningMap(Map<String, Object> body, AWSConfig config) {
        return this.executeRequestCommon(body, config).map(i -> (Map<String, Object>) i);
    }

    /**
     * Builds the payload and headers, signs the request when needed, and loads the JSON response.
     *
     * @param body request payload, or {@code null} to send none; never modified
     * @param conf endpoint, headers, extra body entries and jsonPath of the call
     * @return the JSON values extracted from the response
     * @throws RuntimeException wrapping any {@link IOException} (serialization or transport)
     */
    private Stream<Object> executeRequestCommon(Map<String, Object> body, AWSConfig conf) {
        try {
            String bodyString = null;
            if (body != null) {
                // Merge into a copy so the caller's map is untouched; configured entries win.
                Map<String, Object> payload = new HashMap<String, Object>(body);
                payload.putAll(conf.getBody());
                bodyString = JsonUtil.OBJECT_MAPPER.writeValueAsString(payload);
            }
            HashMap<String, Object> headers = new HashMap<String, Object>(conf.getHeaders());
            headers.putIfAbsent("Content-Type", "application/json");
            headers.putIfAbsent("accept", "*/*");
            // An explicit Authorization header bypasses request signing.
            if (!headers.containsKey("Authorization")) {
                AwsSignatureV4Generator.calculateAuthorizationHeaders(conf, bodyString, headers, "bedrock");
            }
            return JsonUtil.loadJson((Object) conf.getEndpoint(), headers, bodyString, (String) conf.getJsonPath(), true, List.of(), this.urlAccessChecker);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}

