/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding.neo4j;

import dev.langchain4j.community.store.embedding.neo4j.Neo4jEmbeddingUtils;
import dev.langchain4j.community.store.embedding.neo4j.Neo4jFilterMapper;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Record;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Neo4jEmbeddingStore
implements EmbeddingStore<TextSegment> {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(Neo4jEmbeddingStore.class);
    /**
     * Cypher template used by {@code bulk} to upsert a batch of rows: it merges a node per row on the id
     * property, copies the row's property map onto the node and stores the vector through
     * {@code db.create.setNodeVectorProperty}. Placeholders: {@code %1$s} label, {@code %2$s} id property,
     * {@code %3$s} row key of the property map, {@code %4$s} row key of the embedding.
     */
    public static final String ENTITIES_CREATION = "UNWIND $rows AS row\nMERGE (u:%1$s {%2$s: row.%2$s})\nSET u += row.%3$s\nWITH row, u\nCALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)\nRETURN count(*)";
    /**
     * Error message template thrown by {@code indexExists} when an index with the configured name already
     * exists but covers a different label or property.
     */
    public static final String INDEX_ALREADY_EXISTS_ERROR = "It's not possible to create an index for the label `%s` and the property `%s`,\nas there is another index with name `%s` with different labels: `%s` and properties `%s`.\nPlease provide another indexName to create the vector index, or delete the existing one";
    /**
     * Cypher template used by {@code createIndex}; placeholders in order: index name, label, embedding
     * property, vector dimension. The similarity function is fixed to cosine.
     */
    public static final String CREATE_VECTOR_INDEX = "CREATE VECTOR INDEX %s IF NOT EXISTS\nFOR (m:%s) ON m.%s\nOPTIONS { indexConfig: {\n    `vector.dimensions`: %s,\n    `vector.similarity_function`: 'cosine'\n}}\n";
    /** Message prefix used when a retrieval query returns columns outside the allowed set. */
    public static final String COLUMNS_NOT_ALLOWED_ERR = "There are columns not allowed in the search query: ";
    /** Driver used to open sessions; validated as non-null in the constructor. */
    private final Driver driver;
    /** Session configuration passed to every {@code driver.session(...)} call. */
    private final SessionConfig config;
    /** Vector dimension; validated in the constructor to lie between 0 and 4096. */
    private final int dimension;
    /** Value passed as {@code $timeout} to {@code db.awaitIndexes} (presumably seconds - confirm against Neo4j docs). */
    private final long awaitIndexTimeout;
    /** Name of the vector index; defaults to {@code "vector"}. */
    private final String indexName;
    /** Metadata prefix exposed through {@link #getMetadataPrefix()}; defaults to the empty string. */
    private final String metadataPrefix;
    /** Raw name of the node property holding the embedding; defaults to {@code "embedding"}. */
    private final String embeddingProperty;
    /** {@link #embeddingProperty} after {@code Neo4jEmbeddingUtils.sanitizeOrThrows}. */
    private final String sanitizedEmbeddingProperty;
    /** Raw name of the node property holding the id; defaults to {@code "id"}. */
    private final String idProperty;
    /** {@link #idProperty} after {@code Neo4jEmbeddingUtils.sanitizeOrThrows}. */
    private final String sanitizedIdProperty;
    /** Raw node label; defaults to {@code "Document"}. */
    private final String label;
    /** {@link #label} after {@code Neo4jEmbeddingUtils.sanitizeOrThrows}. */
    private final String sanitizedLabel;
    /** Raw name of the node property holding the text; defaults to {@code "text"}. */
    private final String textProperty;
    /** Cypher fragment appended to search queries to shape the returned columns. */
    private final String retrievalQuery;
    /** The id, embedding and text property names, collected into a set in the constructor. */
    private final Set<String> notMetaKeys;
    /** Name of the full-text index; defaults to {@code "fulltext"}. */
    private final String fullTextIndexName;
    /** Full-text query string; when non-null, the index-based search adds a full-text UNION branch. */
    private final String fullTextQuery;
    /** Retrieval fragment for the full-text branch; defaults to {@link #retrievalQuery}. */
    private final String fullTextRetrievalQuery;
    /** Whether {@code createFullTextIndex} should actually create the full-text index. */
    private final boolean autoCreateFullText;

    /**
     * Creates a store bound to the given driver and makes sure the required schema exists.
     *
     * <p>Nullable arguments fall back to the defaults noted below. The label and the id, embedding and
     * text property names are passed through {@code Neo4jEmbeddingUtils.sanitizeOrThrows} before they
     * are interpolated into Cypher. The constructor finishes by calling {@code createSchema()}, so it
     * talks to the database.
     *
     * @param config                 session configuration; when {@code null}, a configuration for
     *                               {@code databaseName} is used
     * @param driver                 Neo4j driver, must not be {@code null}
     * @param dimension              vector dimension, validated to be between 0 and 4096
     * @param label                  node label, default {@code "Document"}
     * @param embeddingProperty      embedding property name, default {@code "embedding"}
     * @param idProperty             id property name, default {@code "id"}
     * @param metadataPrefix         metadata prefix, default {@code ""}
     * @param textProperty           text property name, default {@code "text"}
     * @param indexName              vector index name, default {@code "vector"}
     * @param databaseName           database name, default {@code "neo4j"}; only used when
     *                               {@code config} is {@code null}
     * @param retrievalQuery         Cypher fragment appended to search queries; when {@code null} a
     *                               default returning metadata, id, text, embedding and score is used
     * @param awaitIndexTimeout      timeout handed to {@code db.awaitIndexes}; a non-positive value
     *                               (including the {@code 0} an unset builder field yields) falls back to 60
     * @param fullTextIndexName      full-text index name, default {@code "fulltext"}
     * @param fullTextQuery          optional full-text query; may be {@code null}
     * @param fullTextRetrievalQuery retrieval fragment for the full-text branch, default is the
     *                               effective {@code retrievalQuery}
     * @param autoCreateFullText     whether to create the full-text index during schema creation
     */
    public Neo4jEmbeddingStore(SessionConfig config, Driver driver, int dimension, String label, String embeddingProperty, String idProperty, String metadataPrefix, String textProperty, String indexName, String databaseName, String retrievalQuery, long awaitIndexTimeout, String fullTextIndexName, String fullTextQuery, String fullTextRetrievalQuery, boolean autoCreateFullText) {
        this.driver = ValidationUtils.ensureNotNull(driver, "driver");
        this.dimension = ValidationUtils.ensureBetween(dimension, 0, 4096, "dimension");
        String dbName = Utils.getOrDefault(databaseName, "neo4j");
        this.config = Utils.getOrDefault(config, SessionConfig.forDatabase(dbName));
        this.label = Utils.getOrDefault(label, "Document");
        this.embeddingProperty = Utils.getOrDefault(embeddingProperty, "embedding");
        this.idProperty = Utils.getOrDefault(idProperty, "id");
        this.indexName = Utils.getOrDefault(indexName, "vector");
        this.metadataPrefix = Utils.getOrDefault(metadataPrefix, "");
        this.textProperty = Utils.getOrDefault(textProperty, "text");
        // A primitive long can never be null, so a null-based default would never apply and an
        // unset builder value (0) would reach db.awaitIndexes. Treat non-positive as "not set".
        this.awaitIndexTimeout = awaitIndexTimeout > 0L ? awaitIndexTimeout : 60L;
        this.sanitizedLabel = Neo4jEmbeddingUtils.sanitizeOrThrows(this.label, "label");
        this.sanitizedEmbeddingProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.embeddingProperty, "embeddingProperty");
        this.sanitizedIdProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.idProperty, "idProperty");
        String sanitizedText = Neo4jEmbeddingUtils.sanitizeOrThrows(this.textProperty, "textProperty");
        String defaultRetrievalQuery = String.format("RETURN properties(node) AS metadata, node.%1$s AS %1$s, node.%2$s AS %2$s, node.%3$s AS %3$s, score", this.sanitizedIdProperty, sanitizedText, this.sanitizedEmbeddingProperty);
        this.retrievalQuery = Utils.getOrDefault(retrievalQuery, defaultRetrievalQuery);
        this.notMetaKeys = new HashSet<String>(Arrays.asList(this.idProperty, this.embeddingProperty, this.textProperty));
        this.autoCreateFullText = autoCreateFullText;
        this.fullTextIndexName = Utils.getOrDefault(fullTextIndexName, "fulltext");
        this.fullTextQuery = fullTextQuery;
        this.fullTextRetrievalQuery = Utils.getOrDefault(fullTextRetrievalQuery, this.retrievalQuery);
        this.createSchema();
    }

    /**
     * Entry point of the fluent API.
     *
     * @return a fresh {@link Builder} with no option set
     */
    public static Builder builder() {
        final Builder freshBuilder = new Builder();
        return freshBuilder;
    }

    /**
     * Gives access to the set built in the constructor from the id, embedding and text property names.
     * The internal set instance itself is returned, not a copy.
     *
     * @return the property names collected in {@code notMetaKeys}
     */
    public Set<String> getNotMetaKeys() {
        return notMetaKeys;
    }

    /**
     * @return the configured metadata prefix (empty string when none was supplied)
     */
    public String getMetadataPrefix() {
        return metadataPrefix;
    }

    /**
     * @return the raw (unsanitized) name of the text property
     */
    public String getTextProperty() {
        return textProperty;
    }

    /**
     * @return the raw (unsanitized) name of the id property
     */
    public String getIdProperty() {
        return idProperty;
    }

    /**
     * @return the raw (unsanitized) name of the embedding property
     */
    public String getEmbeddingProperty() {
        return embeddingProperty;
    }

    /**
     * @return the name of the vector index used by this store
     */
    public String getIndexName() {
        return indexName;
    }

    /**
     * Stores an embedding under a newly generated random UUID.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    public String add(Embedding embedding) {
        final String generatedId = Utils.randomUUID();
        add(generatedId, embedding);
        return generatedId;
    }

    /**
     * Stores an embedding under the given id, without an associated text segment.
     *
     * @param id        id to store the embedding under
     * @param embedding the embedding to store
     */
    public void add(String id, Embedding embedding) {
        final TextSegment noSegment = null;
        addInternal(id, embedding, noSegment);
    }

    /**
     * Stores an embedding together with its text segment under a newly generated random UUID.
     *
     * @param embedding   the embedding to store
     * @param textSegment the segment the embedding was computed from
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        final String generatedId = Utils.randomUUID();
        addInternal(generatedId, embedding, textSegment);
        return generatedId;
    }

    /**
     * Stores several embeddings without text segments by delegating to the two-argument
     * {@code addAll} with a {@code null} segment list.
     *
     * @param embeddings the embeddings to store
     * @return whatever the delegated {@code addAll(embeddings, null)} call returns
     */
    public List<String> addAll(List<Embedding> embeddings) {
        final List<String> storedIds = addAll(embeddings, null);
        return storedIds;
    }

    /**
     * Detach-deletes every node carrying the configured label, using
     * {@code CALL { ... } IN TRANSACTIONS}.
     */
    public void removeAll() {
        final String deleteByLabel = String.format("CALL { MATCH (n:%1$s) DETACH DELETE n } IN TRANSACTIONS", sanitizedLabel);
        try (Session session = session()) {
            session.run(deleteByLabel);
        }
    }

    /**
     * Detach-deletes the nodes with the configured label whose id property matches one of the given ids.
     *
     * @param ids ids of the nodes to delete; validated with {@code ValidationUtils.ensureNotEmpty}
     */
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        try (Session session = this.session()) {
            String statement = String.format("CALL { UNWIND $ids AS id MATCH (n:%1$s {%2$s: id}) DETACH DELETE n } IN TRANSACTIONS ", this.sanitizedLabel, this.sanitizedIdProperty);
            // Session.run expects Map<String, Object>; a Map<String, Collection<String>> is not assignable.
            Map<String, Object> params = Map.of("ids", ids);
            session.run(statement, params);
        }
    }

    /**
     * Detach-deletes the nodes with the configured label that have an embedding of the configured
     * dimension and satisfy the given filter.
     *
     * <p>The filter is translated by {@link Neo4jFilterMapper} into a Cypher predicate (interpolated
     * into the statement) and a parameter map (sent with it).
     *
     * @param filter the metadata filter, must not be {@code null}
     */
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        AbstractMap.SimpleEntry<String, Map<String, Object>> filterEntry = new Neo4jFilterMapper().map(filter);
        try (Session session = this.session()) {
            // Interpolate the sanitized property name, like the label, rather than the raw one.
            String statement = String.format("CALL { MATCH (n:%1$s) WHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s DETACH DELETE n } IN TRANSACTIONS ", this.sanitizedLabel, this.sanitizedEmbeddingProperty, this.dimension, filterEntry.getKey());
            Map<String, Object> params = filterEntry.getValue();
            session.run(statement, params);
        }
    }

    /**
     * Runs a similarity search for the request's query embedding.
     *
     * <p>With a filter, nodes are matched and scored directly with a cosine-similarity expression;
     * without one, the vector index is queried.
     *
     * @param request the search request (query embedding, optional filter, min score, max results)
     * @return the matches produced by the selected search strategy
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        final Value queryVector = Values.value(request.queryEmbedding().vector());
        try (Session session = session()) {
            final Filter requestedFilter = request.filter();
            if (requestedFilter != null) {
                return getSearchResUsingVectorSimilarity(request, requestedFilter, queryVector, session);
            }
            return getSearchResUsingVectorIndex(request, queryVector, session);
        }
    }

    /**
     * Filtered search: matches nodes by label, embedding size and the translated filter, scores them
     * with {@code vector.similarity.cosine}, keeps those at or above the minimum score, orders by
     * score descending, limits to the maximum number of results and appends the retrieval query.
     *
     * @param request        supplies {@code minScore} and {@code maxResults}
     * @param filter         filter translated by {@link Neo4jFilterMapper}
     * @param embeddingValue query vector, bound as the {@code $embeddingValue} parameter
     * @param session        open session used to execute the query
     * @return the matching segments with their scores
     */
    private EmbeddingSearchResult<TextSegment> getSearchResUsingVectorSimilarity(EmbeddingSearchRequest request, Filter filter, Value embeddingValue, Session session) {
        AbstractMap.SimpleEntry<String, Map<String, Object>> entry = new Neo4jFilterMapper().map(filter);
        // Only the fixed part goes through String.format; the retrieval query is appended afterwards so
        // that a '%' inside a user-supplied retrieval query cannot break formatting.
        String scoringClause = String.format("CYPHER runtime = parallel parallelRuntimeSupport=all\nMATCH (n:%1$s)\nWHERE n.%2$s IS NOT NULL AND size(n.%2$s) = toInteger(%3$s) AND %4$s\nWITH n, vector.similarity.cosine(n.%2$s, $embeddingValue) AS score\nWHERE score >= $minScore\nWITH n AS node, score\nORDER BY score DESC\nLIMIT $maxResults\n", this.sanitizedLabel, this.sanitizedEmbeddingProperty, this.dimension, entry.getKey());
        String query = scoringClause + this.retrievalQuery;
        // Copy the mapper's parameters instead of mutating the map it handed out.
        Map<String, Object> params = new HashMap<>(entry.getValue());
        // The vector is bound as a parameter (as in the index-based search) rather than stringified
        // into the query text.
        params.put("embeddingValue", embeddingValue);
        params.put("minScore", request.minScore());
        params.put("maxResults", request.maxResults());
        return this.getEmbeddingSearchResult(session, query, params);
    }

    /**
     * Unfiltered search through the vector index, optionally UNION-ed with a full-text index lookup
     * when a full-text query is configured.
     *
     * <p>Before executing, the query's result columns are obtained via {@code getColumnNames} and
     * checked against the allowed set (text, embedding and id property names plus {@code score} and
     * {@code metadata}).
     *
     * @param request        supplies {@code minScore} and {@code maxResults}
     * @param embeddingValue query vector, bound as {@code $embeddingValue}
     * @param session        open session used for both the column check and the query
     * @return the matching segments with their scores
     * @throws RuntimeException if the query would return a column outside the allowed set
     */
    private EmbeddingSearchResult<TextSegment> getSearchResUsingVectorIndex(EmbeddingSearchRequest request, Value embeddingValue, Session session) {
        final Map<String, Object> queryParams = new HashMap<>(Map.of(
                "indexName", indexName,
                "embeddingValue", embeddingValue,
                "minScore", request.minScore(),
                "maxResults", request.maxResults()));

        final StringBuilder cypher = new StringBuilder("CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue)\nYIELD node, score\nWHERE score >= $minScore\n");
        cypher.append(retrievalQuery);
        if (fullTextQuery != null) {
            // Hybrid mode: add the full-text branch and the parameters it needs.
            cypher.append("\nUNION\nCALL db.index.fulltext.queryNodes($fullTextIndexName, $fullTextQuery, {limit: $maxResults})\nYIELD node, score\nWHERE score >= $minScore\n");
            cypher.append(fullTextRetrievalQuery);
            queryParams.putAll(Map.of("fullTextIndexName", fullTextIndexName, "fullTextQuery", fullTextQuery));
        }
        final String query = cypher.toString();

        final Set<String> returnedColumns = getColumnNames(session, query);
        final Set<String> permittedColumns = Set.of(textProperty, embeddingProperty, idProperty, "score", "metadata");
        final boolean onlyPermitted = permittedColumns.containsAll(returnedColumns)
                && returnedColumns.size() <= permittedColumns.size();
        if (!onlyPermitted) {
            throw new RuntimeException(COLUMNS_NOT_ALLOWED_ERR + returnedColumns);
        }
        return getEmbeddingSearchResult(session, query, queryParams);
    }

    /**
     * Executes the query in a read transaction and converts every returned record into an embedding
     * match via {@code Neo4jEmbeddingUtils.toEmbeddingMatch}.
     *
     * @param session open session to run the read transaction on
     * @param query   complete Cypher query
     * @param params  parameters referenced by the query
     * @return the search result wrapping the converted matches
     */
    private EmbeddingSearchResult<TextSegment> getEmbeddingSearchResult(Session session, String query, Map<String, Object> params) {
        // Typed list instead of a raw List + cast, so the conversion is checked by the compiler.
        List<EmbeddingMatch<TextSegment>> matches = session.executeRead(tx -> tx.run(query, params).list(item -> Neo4jEmbeddingUtils.toEmbeddingMatch(this, item)));
        return new EmbeddingSearchResult<>(matches);
    }

    /**
     * Determines the column names a query would return by running it with an {@code EXPLAIN} prefix
     * and reading the result keys.
     *
     * <p>The first {@code @<digits>} occurrence in each key is removed and the key is trimmed.
     *
     * @param session open session used to run the EXPLAIN statement
     * @param query   the query whose result columns are of interest
     * @return the distinct, cleaned column names
     */
    private Set<String> getColumnNames(Session session, String query) {
        // Result.keys() is a List<String>; a raw List would make the lambda parameter an Object
        // and the replaceFirst call would not compile.
        List<String> keys = session.run("EXPLAIN " + query).keys();
        return keys.stream().map(key -> key.replaceFirst("@[0-9]+", "").trim()).collect(Collectors.toSet());
    }

    /**
     * Stores a single embedding by wrapping the arguments into singleton lists and delegating to the
     * three-argument {@code addAll}.
     *
     * @param id        id to store the embedding under
     * @param embedding the embedding to store
     * @param embedded  optional text segment; {@code null} is forwarded as a {@code null} list
     */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        final List<TextSegment> segments;
        if (embedded == null) {
            segments = null;
        } else {
            segments = Collections.singletonList(embedded);
        }
        final List<String> singleId = Collections.singletonList(id);
        final List<Embedding> singleEmbedding = Collections.singletonList(embedding);
        addAll(singleId, singleEmbedding, segments);
    }

    /**
     * Stores the given embeddings (and optional text segments) under the given ids.
     *
     * <p>When {@code ids} or {@code embeddings} is null or empty, an info message is logged and
     * nothing is written. Otherwise the list sizes are validated and the rows are written in bulk.
     *
     * @param ids        ids, one per embedding
     * @param embeddings embeddings to store
     * @param embedded   text segments, one per embedding, or {@code null} for none
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        final boolean nothingToAdd = Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings);
        if (nothingToAdd) {
            log.info("[do not add empty embeddings to neo4j]");
            return;
        }

        final boolean idsMatchEmbeddings = ids.size() == embeddings.size();
        ValidationUtils.ensureTrue(idsMatchEmbeddings, "ids size is not equal to embeddings size");

        final boolean segmentsMatchEmbeddings = embedded == null || embeddings.size() == embedded.size();
        ValidationUtils.ensureTrue(segmentsMatchEmbeddings, "embeddings size is not equal to embedded size");

        bulk(ids, embeddings, embedded);
    }

    /**
     * Writes the rows produced by {@code Neo4jEmbeddingUtils.getRowsBatched} to the database, one
     * write transaction per batch, using the {@link #ENTITIES_CREATION} template.
     *
     * @param ids        ids, one per embedding
     * @param embeddings embeddings to store
     * @param embedded   text segments, or {@code null}
     */
    private void bulk(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        Stream<List<Map<String, Object>>> rowsBatched = Neo4jEmbeddingUtils.getRowsBatched(this, ids, embeddings, embedded);
        // The statement does not depend on the batch, so it is built once.
        // "props" and "embeddingRow" are presumably the row keys produced by getRowsBatched - confirm there.
        String statement = String.format(ENTITIES_CREATION, this.sanitizedLabel, this.sanitizedIdProperty, "props", "embeddingRow");
        try (Session session = this.session()) {
            rowsBatched.forEach(rows -> {
                // Values are a list of rows and a String, hence Map<String, Object>.
                Map<String, Object> params = Map.of("rows", rows, "embeddingProperty", this.embeddingProperty);
                session.executeWrite(tx -> tx.run(statement, params).consume());
            });
        }
    }

    /**
     * Ensures the schema objects exist, in this order: vector index (created when missing),
     * full-text index (only acted on when enabled), uniqueness constraint (created when missing).
     */
    private void createSchema() {
        final boolean vectorIndexPresent = indexExists();
        if (!vectorIndexPresent) {
            createIndex();
        }

        createFullTextIndex();

        final boolean constraintPresent = constraintExist();
        if (!constraintPresent) {
            createUniqueConstraint();
        }
    }

    /**
     * Checks whether a {@code NODE_KEY} or {@code UNIQUENESS} constraint exists that covers the
     * configured label and id property.
     *
     * @return {@code true} if {@code SHOW CONSTRAINTS} returns at least one matching row
     */
    private boolean constraintExist() {
        final String showConstraints = "SHOW CONSTRAINTS\nWHERE $label IN labelsOrTypes\nAND $property IN properties\nAND type IN ['NODE_KEY', 'UNIQUENESS']\n";
        final Map<String, Object> lookup = Map.of("label", label, "property", idProperty);
        try (Session session = session()) {
            final List<Record> matching = session.run(showConstraints, lookup).list();
            return !matching.isEmpty();
        }
    }

    /**
     * Creates (if not already present) a uniqueness constraint on the id property of the configured
     * label.
     */
    private void createUniqueConstraint() {
        final String createConstraint = String.format(
                "CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE",
                sanitizedLabel,
                sanitizedIdProperty);
        try (Session session = session()) {
            session.run(createConstraint);
        }
    }

    /**
     * Checks whether a vector index with the configured name exists and covers exactly the configured
     * label and embedding property.
     *
     * @return {@code false} when no index with that name exists, {@code true} when a matching one exists
     * @throws RuntimeException when an index with that name exists for a different label or property
     */
    private boolean indexExists() {
        try (Session session = this.session()) {
            // Session.run takes Map<String, Object>; Map<String, String> is not assignable to it.
            Map<String, Object> params = Map.of("name", this.indexName);
            Result resIndex = session.run("SHOW VECTOR INDEX WHERE name = $name", params);
            if (!resIndex.hasNext()) {
                return false;
            }
            Record record = resIndex.single();
            List<String> idxLabels = record.get("labelsOrTypes").asList(Value::asString);
            List<Object> idxProps = record.get("properties").asList();
            // Compare against the raw (unsanitized) names, which is what the index metadata is compared with here.
            boolean sameLabels = idxLabels.equals(Collections.singletonList(this.label));
            boolean sameProperties = idxProps.equals(Collections.singletonList(this.embeddingProperty));
            if (!sameLabels || !sameProperties) {
                String errMessage = String.format(INDEX_ALREADY_EXISTS_ERROR, this.label, this.embeddingProperty, this.indexName, idxLabels, idxProps);
                throw new RuntimeException(errMessage);
            }
            return true;
        }
    }

    /**
     * Creates the full-text index (if not already present) when {@code autoCreateFullText} is
     * enabled; otherwise does nothing.
     */
    private void createFullTextIndex() {
        if (autoCreateFullText) {
            // NOTE(review): the index is created on the id property, not on the text property -
            // confirm this is intended.
            final String createFullText = String.format(
                    "CREATE FULLTEXT INDEX %s IF NOT EXISTS FOR (n:%s) ON EACH [n.%s]",
                    fullTextIndexName,
                    sanitizedLabel,
                    sanitizedIdProperty);
            try (Session session = session()) {
                session.run(createFullText).consume();
            }
        }
    }

    /**
     * Creates the vector index from the {@link #CREATE_VECTOR_INDEX} template and then waits for
     * indexes to come online via {@code db.awaitIndexes}, passing the configured timeout.
     */
    private void createIndex() {
        try (Session session = this.session()) {
            // NOTE(review): indexName is interpolated as-is, unlike the sanitized label and property.
            String createIndexQuery = String.format(CREATE_VECTOR_INDEX, this.indexName, this.sanitizedLabel, this.sanitizedEmbeddingProperty, this.dimension);
            // The statement is fully built by String.format and contains no $ placeholders, so no
            // parameter map is sent (the previous Map<String, Integer> with String values could not
            // type-check).
            session.run(createIndexQuery);
            Map<String, Object> awaitParams = Map.of("timeout", this.awaitIndexTimeout);
            session.run("CALL db.awaitIndexes($timeout)", awaitParams).consume();
        }
    }

    /**
     * Opens a new driver session with this store's session configuration. Callers are responsible
     * for closing it.
     *
     * @return a new session
     */
    private Session session() {
        final Session opened = driver.session(config);
        return opened;
    }

    /**
     * Fluent builder for {@link Neo4jEmbeddingStore}.
     *
     * <p>Every setter stores its argument and returns the builder. Reference-typed options left
     * unset are passed to the constructor as {@code null}; unset primitive options are passed as
     * {@code 0} / {@code false}.
     */
    public static class Builder {
        private String indexName;
        private String metadataPrefix;
        private String embeddingProperty;
        private String idProperty;
        private String label;
        private String textProperty;
        private String databaseName;
        private String retrievalQuery;
        private SessionConfig config;
        private Driver driver;
        private int dimension;
        private long awaitIndexTimeout;
        private String fullTextIndexName;
        private String fullTextQuery;
        private String fullTextRetrievalQuery;
        private boolean autoCreateFullText;

        /**
         * @param indexName name of the vector index
         * @return this builder
         */
        public Builder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        /**
         * @param metadataPrefix prefix associated with metadata
         * @return this builder
         */
        public Builder metadataPrefix(String metadataPrefix) {
            this.metadataPrefix = metadataPrefix;
            return this;
        }

        /**
         * @param embeddingProperty name of the node property holding the embedding
         * @return this builder
         */
        public Builder embeddingProperty(String embeddingProperty) {
            this.embeddingProperty = embeddingProperty;
            return this;
        }

        /**
         * @param idProperty name of the node property holding the id
         * @return this builder
         */
        public Builder idProperty(String idProperty) {
            this.idProperty = idProperty;
            return this;
        }

        /**
         * @param label node label
         * @return this builder
         */
        public Builder label(String label) {
            this.label = label;
            return this;
        }

        /**
         * @param textProperty name of the node property holding the text
         * @return this builder
         */
        public Builder textProperty(String textProperty) {
            this.textProperty = textProperty;
            return this;
        }

        /**
         * @param databaseName database name handed to the store constructor
         * @return this builder
         */
        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        /**
         * @param retrievalQuery Cypher fragment handed to the store constructor as retrieval query
         * @return this builder
         */
        public Builder retrievalQuery(String retrievalQuery) {
            this.retrievalQuery = retrievalQuery;
            return this;
        }

        /**
         * @param config session configuration
         * @return this builder
         */
        public Builder config(SessionConfig config) {
            this.config = config;
            return this;
        }

        /**
         * @param driver Neo4j driver to use; replaces any driver set earlier on this builder
         * @return this builder
         */
        public Builder driver(Driver driver) {
            this.driver = driver;
            return this;
        }

        /**
         * @param dimension vector dimension
         * @return this builder
         */
        public Builder dimension(int dimension) {
            this.dimension = dimension;
            return this;
        }

        /**
         * @param awaitIndexTimeout timeout handed to the store constructor; stays {@code 0} when
         *                          this method is never called
         * @return this builder
         */
        public Builder awaitIndexTimeout(long awaitIndexTimeout) {
            this.awaitIndexTimeout = awaitIndexTimeout;
            return this;
        }

        /**
         * @param fullTextIndexName name of the full-text index
         * @return this builder
         */
        public Builder fullTextIndexName(String fullTextIndexName) {
            this.fullTextIndexName = fullTextIndexName;
            return this;
        }

        /**
         * @param fullTextQuery full-text query string
         * @return this builder
         */
        public Builder fullTextQuery(String fullTextQuery) {
            this.fullTextQuery = fullTextQuery;
            return this;
        }

        /**
         * @param fullTextRetrievalQuery retrieval fragment for the full-text search branch
         * @return this builder
         */
        public Builder fullTextRetrievalQuery(String fullTextRetrievalQuery) {
            this.fullTextRetrievalQuery = fullTextRetrievalQuery;
            return this;
        }

        /**
         * @param autoCreateFullText whether the store should create the full-text index itself
         * @return this builder
         */
        public Builder autoCreateFullText(boolean autoCreateFullText) {
            this.autoCreateFullText = autoCreateFullText;
            return this;
        }

        /**
         * Creates a driver for the given URI with basic authentication and uses it for the store,
         * replacing any driver set earlier on this builder.
         *
         * @param uri      connection URI
         * @param user     user name
         * @param password password
         * @return this builder
         */
        public Builder withBasicAuth(String uri, String user, String password) {
            final AuthToken basicToken = AuthTokens.basic(user, password);
            this.driver = GraphDatabase.driver(uri, basicToken);
            return this;
        }

        /**
         * Instantiates the store with the collected options; the store constructor is invoked
         * directly, so any work it performs happens during this call.
         *
         * @return the configured store
         */
        public Neo4jEmbeddingStore build() {
            return new Neo4jEmbeddingStore(
                    config,
                    driver,
                    dimension,
                    label,
                    embeddingProperty,
                    idProperty,
                    metadataPrefix,
                    textProperty,
                    indexName,
                    databaseName,
                    retrievalQuery,
                    awaitIndexTimeout,
                    fullTextIndexName,
                    fullTextQuery,
                    fullTextRetrievalQuery,
                    autoCreateFullText);
        }
    }
}

