/*
 * Decompiled with CFR 0.152.
 */
package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.OpenAI;
import apoc.ml.RagConfig;
import apoc.result.StringResult;
import apoc.util.CollectionUtils;
import apoc.util.Util;
import apoc.util.collection.Iterators;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.text.WordUtils;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.QueryExecutionException;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedures in package {@code apoc.ml} that use an OpenAI chat-completions call to explain a graph
 * schema, explain Cypher, generate Cypher from a natural-language question, and answer questions over
 * supplied graph data (RAG).
 */
@Extended
public class Prompt {
    /** Key in {@code conf} whose value is forwarded as the API key of the chat-completions request. */
    public static final String API_KEY_CONF = "apiKey";
    @Context
    public Transaction tx;
    @Context
    public GraphDatabaseService db;
    @Context
    public Log log;
    @Context
    public ApocConfig apocConfig;
    @Context
    public ProcedureCallContext procedureCallContext;
    @Context
    public URLAccessChecker urlAccessChecker;
    public static final String BACKTICKS = "```";
    public static final String UNKNOWN_ANSWER = "Sorry, I don't know";
    static final String RAG_BASE_PROMPT = "You are a customer service agent that helps a customer with answering questions about a service.\nUse the following context to answer the `user question` at the end. Make sure not to make any changes to the context if possible when prepare answers so as to provide accurate responses.\nIf you don't know the answer, just say `%s`, don't try to make up an answer.\n\n---- Start context ----\n%s\n---- End context ----\n";
    public static final String EXPLAIN_SCHEMA_PROMPT = "You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.\nExplain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.\nTry to explain as much as possible the nodes, relationships and properties.\nKeep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.\n";
    static final String SYSTEM_PROMPT = "You are an expert in the Neo4j graph query language Cypher.\nGiven a graph database schema of entities (nodes) with labels and attributes and\nrelationships with start- and end-node, relationship-type, direction and properties\nyou are able to develop read only matching Cypher statements that express a user question as a graph database query.\nOnly answer with a single Cypher statement in triple backticks, if you can't determine a statement, answer with an empty response.\nDo not explain, apologize or provide additional detail, otherwise people will come to harm.\n";
    static final String FROM_CYPHER_PROMPT = "You are an expert in the Neo4j graph query language Cypher.\nGiven a graph database schema of entities (nodes) with labels and attributes and\nrelationships with start-node and end-node, relationship-type, direction and properties,\nand given a read only matching Cypher query statement,\nyou are able to explain the Cypher query statement in plain language,\nproviding useful details of each entity.\nDo not explain, apologize or provide additional detail about the schema, otherwise people will come to harm.\n";
    private static final String SCHEMA_FROM_META_DATA = "\nWITH label, elementType,\n     apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties,\n     collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\nwith  elementType as type,\napoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\nreturn collect(case type when \"relationship\" then entities end)[0] as relationships,\ncollect(case type when \"node\" then entities end)[0] as nodes,\ncollect(case type when \"node\" then patterns end)[0] as patterns\n";
    private static final String SCHEMA_QUERY = "call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})\nYIELD label, other, elementType, type, property\n\nWITH label, elementType,\n     apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties,\n     collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\nwith  elementType as type,\napoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\nreturn collect(case type when \"relationship\" then entities end)[0] as relationships,\ncollect(case type when \"node\" then entities end)[0] as nodes,\ncollect(case type when \"node\" then patterns end)[0] as patterns\n";
    private static final String GRAPH_QUERY = "UNWIND $queries AS query\nCALL apoc.meta.data.of(query, {maxRels: 10, sample: $sample})\nYIELD label, other, elementType, type, property\nWITH DISTINCT label, other, elementType, type, property\n\nWITH label, elementType,\n     apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties,\n     collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\nwith  elementType as type,\napoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\nreturn collect(case type when \"relationship\" then entities end)[0] as relationships,\ncollect(case type when \"node\" then entities end)[0] as nodes,\ncollect(case type when \"node\" then patterns end)[0] as patterns\n";
    private static final String SCHEMA_PROMPT = "nodes:\n```\n%s\n```\n\nrelationships:\n```\n%s\n```\n\npatterns:\n```\n%s\n```\n";

    /**
     * Answers {@code question} using graph data as context (retrieval-augmented generation).
     *
     * @param paths either a list of nodes / relationships / paths whose properties form the context,
     *     or a string handed to the configured embeddings strategy, whose result rows' values form the context
     * @param attributes property keys copied from each entity into the context
     * @param question the natural-language question
     * @param conf configuration parsed by {@link RagConfig}; also forwarded to the model call
     * @return a single row with the model's answer
     * @throws RuntimeException if {@code paths} is neither a List nor a String, or contains unsupported items
     */
    @Procedure(mode = Mode.READ)
    @Description(value = "Takes a query in cypher and in natural language and returns the results in natural language")
    public Stream<StringResult> rag(@Name(value = "paths") Object paths, @Name(value = "attributes") List<String> attributes, @Name(value = "question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws Exception {
        RagConfig config = new RagConfig(conf);
        String[] arrayAttrs = attributes.toArray(String[]::new);
        StringBuilder context = new StringBuilder();
        if (paths instanceof List<?> pathList) {
            for (Object listItem : pathList) {
                augment(config, arrayAttrs, context, listItem);
            }
        } else if (paths instanceof String queryOrIndex) {
            config.getEmbeddings()
                    .getQuery(queryOrIndex, question, tx, config)
                    .forEachRemaining(row -> row.values().forEach(val -> augment(config, arrayAttrs, context, val)));
        } else {
            throw new RuntimeException("The first parameter must be a List or a String");
        }
        String contextPrompt = "\n---- Start context ----\n%s\n---- End context ----\n".formatted(context);
        String prompt = config.getBasePrompt() + contextPrompt;
        String result = prompt("\nQuestion:" + question, prompt, null, null, conf, List.of());
        return Stream.of(new StringResult(result));
    }

    /** Appends the context text of a Path (each of its entities) or of a single Entity. */
    private void augment(RagConfig config, String[] objects, StringBuilder context, Object listItem) {
        if (listItem instanceof Path path) {
            for (Entity entity : path) {
                augmentEntity(config, objects, context, entity);
            }
        } else if (listItem instanceof Entity entity) {
            augmentEntity(config, objects, context, entity);
        } else {
            throw new RuntimeException("The list `%s` must have node/type/path items".formatted(listItem));
        }
    }

    /**
     * Appends the selected, non-null properties of {@code entity} as {@code key: value} lines.
     * When {@code config.isGetLabelTypes()} is set, the labels (node) or type (relationship) are added
     * under the key {@code "context description"}.
     */
    private void augmentEntity(RagConfig config, String[] objects, StringBuilder context, Entity entity) {
        Map<String, Object> props = entity.getProperties(objects);
        if (config.isGetLabelTypes()) {
            String labelsOrType = entity instanceof Node node
                    ? Util.joinLabels(node.getLabels(), ",")
                    : ((Relationship) entity).getType().name();
            props.put("context description", WordUtils.capitalize(labelsOrType, '_'));
        }
        // NOTE(review): "\n---\n" separates the properties of one entity, while consecutive entities are
        // appended with no separator between them - confirm this is the intended context layout.
        String obj = props.entrySet().stream()
                .filter(entry -> entry.getValue() != null)
                .map(entry -> entry.getKey() + ": " + entry.getValue() + "\n")
                .collect(Collectors.joining("\n---\n"));
        context.append(obj);
    }

    /**
     * Asks the model to explain {@code cypher} in plain language, given the database schema.
     *
     * @param cypher the statement to explain
     * @param conf model options ({@code apiKey}, {@code model}, {@code additionalPrompts}) and {@code sample}
     * @return a single row with the explanation
     */
    @Procedure(mode = Mode.READ)
    @Description(value = "Takes a query in cypher and in natural language and returns the results in natural language")
    public Stream<StringResult> fromCypher(@Name(value = "cypher") String cypher, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
        String schemaAndCypher = "%s\nwhile the cypher query is:\n%s\n".formatted(loadSchema(tx, conf), cypher);
        String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", FROM_CYPHER_PROMPT, "This is the Cypher query statement explanation: \n", schemaAndCypher, conf, List.of());
        return Stream.of(new StringResult(schemaExplanation));
    }

    /**
     * Generates a Cypher statement for {@code question}, executes it and streams its rows.
     *
     * <p>If execution fails with a {@link QueryExecutionException}, a new statement is generated and
     * tried again, up to {@code conf.retries} attempts in total (default 3). With
     * {@code conf.retryWithError} the previous error message is added to the conversation so the model
     * can correct itself. The last failure is rethrown once the attempts are exhausted.
     *
     * <p>NOTE(review): the generated statement runs in its own transaction (from {@code db.beginTx()})
     * which is committed, so a non-read-only statement returned by the model would be persisted even
     * though this procedure is declared {@code Mode.READ}. Consider enforcing read-only execution.
     *
     * @return the result rows; the generated statement is attached only when the caller yields {@code query}
     */
    @Procedure(mode = Mode.READ)
    public Stream<PromptMapResult> query(@Name(value = "question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
        String schema = loadSchema(tx, conf);
        long retries = (Long) conf.getOrDefault("retries", 3L);
        boolean retryWithError = Util.toBoolean(conf.get("retryWithError"));
        boolean containsField = procedureCallContext.outputFields().anyMatch("query"::equals);
        List<Map<String, String>> otherPrompts = new ArrayList<>();
        String query = "";
        while (true) {
            try (Transaction transaction = db.beginTx()) {
                QueryResult queryResult = tryQuery(question, conf, schema, otherPrompts);
                query = queryResult.query;
                // materialize the rows before the transaction is closed
                List<Map<String, Object>> maps = Iterators.asList(transaction.execute(queryResult.query));
                transaction.commit();
                return maps.stream().map(row -> containsField
                        ? new PromptMapResult(this, row, queryResult.query)
                        : new PromptMapResult(this, row));
            } catch (QueryExecutionException quee) {
                if (log.isDebugEnabled()) {
                    log.debug("Generated query for question %s\n%s\nfailed with %s".formatted(question, query, quee.getMessage()));
                }
                if (retryWithError) {
                    otherPrompts.addAll(List.of(
                            Map.of("role", "user", "content", "The previous Cypher Statement throws the following error, consider it to return the correct statement: `%s`".formatted(quee.getMessage())),
                            Map.of("role", "assistant", "content", "Cypher Statement (in backticks):")));
                }
                // always count the attempt, otherwise a persistently failing query would loop (and call the model) forever
                if (--retries <= 0L) {
                    throw quee;
                }
            }
        }
    }

    /**
     * Asks the model to explain the database schema (as computed by {@code apoc.meta.data}).
     *
     * @param conf model options and optional {@code sample}
     * @return a single row with the explanation
     */
    @Procedure
    public Stream<StringResult> schema(@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
        String schema = loadSchema(tx, conf);
        String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", EXPLAIN_SCHEMA_PROMPT, "This database schema ", schema, conf, List.of());
        return Stream.of(new StringResult(schemaExplanation));
    }

    /**
     * Generates {@code conf.count} (default 1) candidate Cypher statements for {@code question}
     * without executing them.
     */
    @Procedure(mode = Mode.READ)
    public Stream<QueryResult> cypher(@Name(value = "question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
        String schema = loadSchema(tx, conf);
        long count = (Long) conf.getOrDefault("count", 1L);
        return LongStream.rangeClosed(1L, count).mapToObj(i -> tryQuery(question, conf, schema, List.of()));
    }

    /**
     * Asks the model for a Cypher statement; failures of the model call are captured in the returned
     * {@link QueryResult} instead of being thrown (the query text is then empty).
     */
    @NotNull
    private QueryResult tryQuery(String question, Map<String, Object> conf, String schema, List<Map<String, String>> otherPrompts) {
        String query = "";
        try {
            query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf, otherPrompts);
            return new QueryResult(this, query, null, null);
        } catch (QueryExecutionException e) {
            return new QueryResult(this, query, e.getMessage(), e.getStatusCode());
        } catch (Exception e) {
            return new QueryResult(this, query, e.getMessage(), e.getClass().getSimpleName());
        }
    }

    /**
     * Builds the chat message list and sends it to the {@code chat/completions} endpoint.
     *
     * <p>Message order: system prompt, schema (as a second system message), {@code conf.additionalPrompts},
     * user question, assistant prefix, then {@code otherPromptsFromRetries}. Blank/null parts are skipped.
     *
     * @return the non-blank contents of all returned choices, each stripped of its code fence, joined by a
     *     space, with runs of blank lines collapsed into a single newline
     */
    @SuppressWarnings("unchecked")
    private String prompt(String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map<String, Object> conf, List<Map<String, String>> otherPromptsFromRetries) throws JsonProcessingException, MalformedURLException {
        List<Map<String, String>> messages = new ArrayList<>();
        if (systemPrompt != null && !systemPrompt.isBlank()) {
            messages.add(Map.of("role", "system", "content", systemPrompt));
        }
        if (schema != null && !schema.isBlank()) {
            messages.add(Map.of("role", "system", "content", "The graph database schema consists of these elements\n" + schema));
        }
        // caller-supplied messages are forwarded as-is; the cast is unchecked
        List<Map<String, String>> additionalPrompts = (List<Map<String, String>>) conf.get("additionalPrompts");
        if (CollectionUtils.isNotEmpty(additionalPrompts)) {
            messages.addAll(additionalPrompts);
        }
        if (userQuestion != null && !userQuestion.isBlank()) {
            messages.add(Map.of("role", "user", "content", userQuestion));
        }
        if (assistantPrompt != null && !assistantPrompt.isBlank()) {
            messages.add(Map.of("role", "assistant", "content", assistantPrompt));
        }
        messages.addAll(otherPromptsFromRetries);
        String apiKey = (String) conf.get(API_KEY_CONF);
        String model = (String) conf.getOrDefault("model", "gpt-4o");
        String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions", model, "messages", messages, "$", apocConfig, urlAccessChecker)
                .map(response -> (Map<String, Object>) response)
                .flatMap(response -> ((List<Map<String, Object>>) response.get("choices")).stream())
                .map(choice -> (String) ((Map<String, Object>) choice.get("message")).get("content"))
                .filter(content -> content != null && !content.isBlank())
                .map(Prompt::stripCodeFence)
                .collect(Collectors.joining(" "))
                .replaceAll("\n\n+", "\n");
        if (log.isDebugEnabled()) {
            log.debug("Generated query for question %s\n%s".formatted(userQuestion, result));
        }
        return result;
    }

    /**
     * Returns the text between the first and the last {@link #BACKTICKS} fence, or {@code answer}
     * unchanged when it contains no fence. An unterminated fence (e.g. a truncated completion) yields
     * everything after the opening fence instead of failing with an index exception.
     */
    private static String stripCodeFence(String answer) {
        int open = answer.indexOf(BACKTICKS);
        if (open < 0) {
            return answer;
        }
        int start = open + BACKTICKS.length();
        int close = answer.lastIndexOf(BACKTICKS);
        return close >= start ? answer.substring(start, close) : answer.substring(start);
    }

    /**
     * Asks the model to explain the schema of the data touched by {@code queries}
     * (computed with {@code apoc.meta.data.of}).
     */
    @Procedure
    public Stream<StringResult> fromQueries(@Name(value = "queries") List<String> queries, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
        String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", EXPLAIN_SCHEMA_PROMPT, "This database schema ", loadSchemaQueries(tx, queries, conf), conf, List.of());
        return Stream.of(new StringResult(schemaExplanation));
    }

    /** Renders the schema derived from {@code queries}; {@code conf.sample} defaults to 1000. */
    private String loadSchemaQueries(Transaction tx, List<String> queries, Map<String, Object> conf) {
        Map<String, Object> params = Map.of("sample", conf.getOrDefault("sample", 1000L), "queries", queries);
        return tx.execute(GRAPH_QUERY, params).stream()
                .map(Prompt::toSchemaPrompt)
                .collect(Collectors.joining("\n"));
    }

    /** Renders the database schema; {@code conf.sample} may be absent or null. */
    private String loadSchema(Transaction tx, Map<String, Object> conf) {
        // HashMap rather than Map.of: the sample value may be null, which SCHEMA_QUERY coalesces
        Map<String, Object> params = new HashMap<>();
        params.put("sample", conf.get("sample"));
        return tx.execute(SCHEMA_QUERY, params).stream()
                .map(Prompt::toSchemaPrompt)
                .collect(Collectors.joining("\n"));
    }

    /** Formats one schema-query row (nodes, relationships, patterns) with {@link #SCHEMA_PROMPT}. */
    private static String toSchemaPrompt(Map<String, Object> row) {
        return SCHEMA_PROMPT.formatted(row.get("nodes"), row.get("relationships"), row.get("patterns"));
    }

    /**
     * Output row holding one generated Cypher statement.
     *
     * <p>NOTE(review): the leading {@code Prompt} constructor parameter is unused (apparently a
     * decompiler artifact of the implicit outer instance), and {@code error}/{@code type} are accepted
     * but discarded, so {@link #hasError()} always returns {@code false}.
     */
    public class QueryResult {
        public final String query;

        public QueryResult(Prompt this$0, String query, String error, String type) {
            this.query = query;
        }

        public boolean hasError() {
            return false;
        }
    }

    /**
     * Output row of {@code query}: one result row plus, optionally, the generated statement.
     * The leading {@code Prompt} constructor parameter is unused.
     */
    public class PromptMapResult {
        public final Map<String, Object> value;
        public final String query;

        public PromptMapResult(Prompt this$0, Map<String, Object> value, String query) {
            this.value = value;
            this.query = query;
        }

        public PromptMapResult(Prompt this$0, Map<String, Object> value) {
            this.value = value;
            this.query = null;
        }
    }
}

