/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.neo4j;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.neo4j.Neo4jEmbeddingUtils;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import lombok.Generated;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Record;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} implementation backed by a Neo4j vector index.
 *
 * <p>Each embedding is stored on a node with label {@code label}. The node is keyed by
 * {@code idProperty}, and a uniqueness constraint is ensured for that property. The vector is
 * written to {@code embeddingProperty} through {@code db.create.setNodeVectorProperty}.
 *
 * <p>Searches call {@code db.index.vector.queryNodes} on the index named {@code indexName},
 * keep rows whose score is at least the requested minimum, and then append
 * {@code retrievalQuery}. The resulting records are mapped back to matches by
 * {@link Neo4jEmbeddingUtils#toEmbeddingMatch}.
 *
 * <p>Construction talks to the database. If no vector index named {@code indexName} exists, one
 * is created (cosine similarity) and awaited. The uniqueness constraint on
 * {@code idProperty} is always ensured.
 */
public class Neo4jEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(Neo4jEmbeddingStore.class);

    /** Timeout passed to {@code db.awaitIndexes} when the caller supplies a non-positive value. */
    private static final long DEFAULT_AWAIT_INDEX_TIMEOUT = 60L;

    private final Driver driver;
    private final SessionConfig config;
    private final int dimension;
    private final long awaitIndexTimeout;
    private final String indexName;
    private final String metadataPrefix;
    private final String embeddingProperty;
    private final String idProperty;
    private final String sanitizedEmbeddingProperty;
    private final String sanitizedIdProperty;
    private final String sanitizedText;
    private final String label;
    private final String sanitizedLabel;
    private final String textProperty;
    private final String databaseName;
    private final String retrievalQuery;
    private final Set<String> notMetaKeys;

    /**
     * Creates the store and ensures the required schema (vector index and uniqueness constraint)
     * exists.
     *
     * @param config            session configuration; defaults to {@code SessionConfig.forDatabase(databaseName)}
     * @param driver            Neo4j driver used for every session; must not be null
     * @param dimension         vector dimension, validated to lie in [0, 4096]; in this class it is
     *                          only used when the vector index has to be created
     * @param label             node label, default {@code "Document"}
     * @param embeddingProperty node property holding the vector, default {@code "embedding"}
     * @param idProperty        node property holding the id, default {@code "id"}
     * @param metadataPrefix    metadata key prefix, default {@code ""} (presumably applied by
     *                          {@link Neo4jEmbeddingUtils} when mapping metadata; verify there)
     * @param textProperty      node property holding the segment text, default {@code "text"}
     * @param indexName         vector index name, default {@code "vector"}
     * @param databaseName      database name, default {@code "neo4j"}
     * @param retrievalQuery    Cypher appended after the vector query; the default returns
     *                          {@code metadata}, id, text, embedding and {@code score}
     * @param awaitIndexTimeout value passed to {@code db.awaitIndexes} after index creation
     *                          (seconds, per the Neo4j procedure docs); non-positive values fall
     *                          back to {@value #DEFAULT_AWAIT_INDEX_TIMEOUT}
     * @throws IllegalStateException if a vector index with {@code indexName} exists for a
     *                               different label or property
     */
    public Neo4jEmbeddingStore(SessionConfig config, Driver driver, int dimension, String label, String embeddingProperty, String idProperty, String metadataPrefix, String textProperty, String indexName, String databaseName, String retrievalQuery, long awaitIndexTimeout) {
        this.driver = ValidationUtils.ensureNotNull(driver, "driver");
        this.dimension = ValidationUtils.ensureBetween(dimension, 0, 4096, "dimension");
        this.databaseName = Utils.getOrDefault(databaseName, "neo4j");
        this.config = Utils.getOrDefault(config, SessionConfig.forDatabase(this.databaseName));
        this.label = Utils.getOrDefault(label, "Document");
        this.embeddingProperty = Utils.getOrDefault(embeddingProperty, "embedding");
        this.idProperty = Utils.getOrDefault(idProperty, "id");
        this.indexName = Utils.getOrDefault(indexName, "vector");
        this.metadataPrefix = Utils.getOrDefault(metadataPrefix, "");
        this.textProperty = Utils.getOrDefault(textProperty, "text");
        // A primitive long can never be null, so getOrDefault(...) would never apply the default.
        // The builder leaves this field at 0 when it is not set, so treat any non-positive value
        // as "not configured".
        this.awaitIndexTimeout = awaitIndexTimeout > 0 ? awaitIndexTimeout : DEFAULT_AWAIT_INDEX_TIMEOUT;
        // Label and property names are interpolated into Cypher text below, so they must be
        // sanitized (or rejected) first.
        this.sanitizedLabel = Neo4jEmbeddingUtils.sanitizeOrThrows(this.label, "label");
        this.sanitizedEmbeddingProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.embeddingProperty, "embeddingProperty");
        this.sanitizedIdProperty = Neo4jEmbeddingUtils.sanitizeOrThrows(this.idProperty, "idProperty");
        this.sanitizedText = Neo4jEmbeddingUtils.sanitizeOrThrows(this.textProperty, "textProperty");
        String defaultRetrievalQuery = String.format("RETURN properties(node) AS metadata, node.%1$s AS %1$s, node.%2$s AS %2$s, node.%3$s AS %3$s, score", this.sanitizedIdProperty, this.sanitizedText, this.sanitizedEmbeddingProperty);
        this.retrievalQuery = Utils.getOrDefault(retrievalQuery, defaultRetrievalQuery);
        this.notMetaKeys = new HashSet<>(Arrays.asList(this.idProperty, this.embeddingProperty, this.textProperty));
        this.createSchema();
    }

    /** Stores {@code embedding} under a freshly generated random UUID and returns that id. */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    /** Stores {@code embedding} under the given {@code id}, without a text segment. */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /** Stores {@code embedding} with its {@code textSegment} under a random UUID and returns that id. */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all {@code embeddings} without text segments; delegates to {@code addAll(embeddings, null)}. */
    public List<String> addAll(List<Embedding> embeddings) {
        return this.addAll(embeddings, null);
    }

    /**
     * Runs a vector similarity search against the configured index.
     *
     * @param request query embedding, {@code maxResults} (top-K passed to the index) and
     *                {@code minScore} (rows with a lower score are filtered out)
     * @return the matches produced by {@code retrievalQuery}, mapped via
     *         {@link Neo4jEmbeddingUtils#toEmbeddingMatch}
     */
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        Value embeddingValue = Values.value(request.queryEmbedding().vector());
        Map<String, Object> params = Map.of(
                "indexName", this.indexName,
                "embeddingValue", embeddingValue,
                "minScore", request.minScore(),
                "maxResults", request.maxResults());
        String query = "CALL db.index.vector.queryNodes($indexName, $maxResults, $embeddingValue)\nYIELD node, score\nWHERE score >= $minScore\n" + this.retrievalQuery;
        try (Session session = this.session()) {
            var matches = session.run(query, params).list(item -> Neo4jEmbeddingUtils.toEmbeddingMatch(this, item));
            return new EmbeddingSearchResult<>(matches);
        }
    }

    /** Wraps a single entry into singleton lists and forwards to {@link #addAll(List, List, List)}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAll(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Upserts the given embeddings (and optional segments) keyed by {@code ids}.
     *
     * <p>If {@code ids} or {@code embeddings} is null or empty, the call only logs at info level
     * and returns.
     *
     * @throws IllegalArgumentException (from {@code ValidationUtils.ensureTrue}, presumably) if
     *                                  the list sizes disagree
     */
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("[do not add empty embeddings to neo4j]");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
        this.bulk(ids, embeddings, embedded);
    }

    /**
     * Writes rows in batches produced by {@link Neo4jEmbeddingUtils#getRowsBatched}.
     *
     * <p>Each batch is sent in its own write transaction that MERGEs nodes on the id property,
     * copies the row's {@code props} map onto the node and sets the vector property from
     * {@code embeddingRow}.
     */
    private void bulk(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        // The statement depends only on store configuration, so build it once instead of per batch.
        String statement = "UNWIND $rows AS row\nMERGE (u:%1$s {%2$s: row.%2$s})\nSET u += row.%3$s\nWITH row, u\nCALL db.create.setNodeVectorProperty(u, $embeddingProperty, row.%4$s)\nRETURN count(*)".formatted(this.sanitizedLabel, this.sanitizedIdProperty, "props", "embeddingRow");
        Stream<List<Map<String, Object>>> rowsBatched = Neo4jEmbeddingUtils.getRowsBatched(this, ids, embeddings, embedded);
        try (Session session = this.session()) {
            rowsBatched.forEach(rows -> {
                Map<String, Object> params = Map.of("rows", rows, "embeddingProperty", this.embeddingProperty);
                session.executeWrite(tx -> tx.run(statement, params).consume());
            });
        }
    }

    /** Creates the vector index if it is missing, then ensures the id uniqueness constraint. */
    private void createSchema() {
        if (!this.indexExists()) {
            this.createIndex();
        }
        this.createUniqueConstraint();
    }

    /** Issues {@code CREATE CONSTRAINT IF NOT EXISTS} requiring the id property to be unique for the label. */
    private void createUniqueConstraint() {
        String query = String.format("CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE", this.sanitizedLabel, this.sanitizedIdProperty);
        try (Session session = this.session()) {
            // Consume explicitly so any server error surfaces here rather than on session close.
            session.run(query).consume();
        }
    }

    /**
     * Checks whether a vector index named {@code indexName} already exists.
     *
     * @return {@code false} if no such vector index exists; {@code true} if it exists on exactly
     *         this label and embedding property
     * @throws IllegalStateException if the index exists but covers different labels or properties
     */
    private boolean indexExists() {
        Map<String, Object> params = Map.of("name", this.indexName);
        try (Session session = this.session()) {
            Result resIndex = session.run("SHOW INDEX WHERE type = 'VECTOR' AND name = $name", params);
            if (!resIndex.hasNext()) {
                return false;
            }
            Record record = resIndex.single();
            List<String> idxLabels = record.get("labelsOrTypes").asList(Value::asString);
            List<Object> idxProps = record.get("properties").asList();
            boolean isIndexDifferent = !idxLabels.equals(Collections.singletonList(this.label))
                    || !idxProps.equals(Collections.singletonList(this.embeddingProperty));
            if (isIndexDifferent) {
                String errMessage = String.format("It's not possible to create an index for the label `%s` and the property `%s`,\nas there is another index with name `%s` with different labels: `%s` and properties `%s`.\nPlease provide another indexName to create the vector index, or delete the existing one", this.label, this.embeddingProperty, this.indexName, idxLabels, idxProps);
                throw new IllegalStateException(errMessage);
            }
            return true;
        }
    }

    /** Creates the cosine vector index via {@code db.index.vector.createNodeIndex} and waits for indexes to come online. */
    private void createIndex() {
        Map<String, Object> params = Map.of("indexName", this.indexName, "label", this.label, "embeddingProperty", this.embeddingProperty, "dimension", this.dimension);
        try (Session session = this.session()) {
            session.run("CALL db.index.vector.createNodeIndex($indexName, $label, $embeddingProperty, $dimension, 'cosine')", params).consume();
            session.run("CALL db.awaitIndexes($timeout)", Map.of("timeout", this.awaitIndexTimeout)).consume();
        }
    }

    /** Opens a new session with the configured {@link SessionConfig}; callers must close it. */
    private Session session() {
        return this.driver.session(this.config);
    }

    @Generated
    public static Neo4jEmbeddingStoreBuilder builder() {
        return new Neo4jEmbeddingStoreBuilder();
    }

    @Generated
    public Driver getDriver() {
        return this.driver;
    }

    @Generated
    public SessionConfig getConfig() {
        return this.config;
    }

    @Generated
    public int getDimension() {
        return this.dimension;
    }

    @Generated
    public long getAwaitIndexTimeout() {
        return this.awaitIndexTimeout;
    }

    @Generated
    public String getIndexName() {
        return this.indexName;
    }

    @Generated
    public String getMetadataPrefix() {
        return this.metadataPrefix;
    }

    @Generated
    public String getEmbeddingProperty() {
        return this.embeddingProperty;
    }

    @Generated
    public String getIdProperty() {
        return this.idProperty;
    }

    @Generated
    public String getSanitizedEmbeddingProperty() {
        return this.sanitizedEmbeddingProperty;
    }

    @Generated
    public String getSanitizedIdProperty() {
        return this.sanitizedIdProperty;
    }

    @Generated
    public String getSanitizedText() {
        return this.sanitizedText;
    }

    @Generated
    public String getLabel() {
        return this.label;
    }

    @Generated
    public String getSanitizedLabel() {
        return this.sanitizedLabel;
    }

    @Generated
    public String getTextProperty() {
        return this.textProperty;
    }

    @Generated
    public String getDatabaseName() {
        return this.databaseName;
    }

    @Generated
    public String getRetrievalQuery() {
        return this.retrievalQuery;
    }

    /** Returns the property names (id, embedding, text) that are not treated as metadata keys. */
    @Generated
    public Set<String> getNotMetaKeys() {
        return this.notMetaKeys;
    }

    /**
     * Builder for {@link Neo4jEmbeddingStore}.
     *
     * <p>Unset reference fields are passed as {@code null} and replaced by the constructor's
     * defaults. An unset {@code awaitIndexTimeout} (0) also falls back to the default.
     */
    public static class Neo4jEmbeddingStoreBuilder {
        @Generated
        private SessionConfig config;
        @Generated
        private Driver driver;
        @Generated
        private int dimension;
        @Generated
        private String label;
        @Generated
        private String embeddingProperty;
        @Generated
        private String idProperty;
        @Generated
        private String metadataPrefix;
        @Generated
        private String textProperty;
        @Generated
        private String indexName;
        @Generated
        private String databaseName;
        @Generated
        private String retrievalQuery;
        @Generated
        private long awaitIndexTimeout;

        /** Convenience: creates a driver for {@code uri} using basic authentication and sets it on this builder. */
        public Neo4jEmbeddingStoreBuilder withBasicAuth(String uri, String user, String password) {
            return this.driver(GraphDatabase.driver(uri, AuthTokens.basic(user, password)));
        }

        @Generated
        Neo4jEmbeddingStoreBuilder() {
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder config(SessionConfig config) {
            this.config = config;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder driver(Driver driver) {
            this.driver = driver;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder dimension(int dimension) {
            this.dimension = dimension;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder label(String label) {
            this.label = label;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder embeddingProperty(String embeddingProperty) {
            this.embeddingProperty = embeddingProperty;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder idProperty(String idProperty) {
            this.idProperty = idProperty;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder metadataPrefix(String metadataPrefix) {
            this.metadataPrefix = metadataPrefix;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder textProperty(String textProperty) {
            this.textProperty = textProperty;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder indexName(String indexName) {
            this.indexName = indexName;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        @Generated
        public Neo4jEmbeddingStoreBuilder retrievalQuery(String retrievalQuery) {
            this.retrievalQuery = retrievalQuery;
            return this;
        }

        /** Sets the value passed to {@code db.awaitIndexes}; non-positive values use the store default. */
        @Generated
        public Neo4jEmbeddingStoreBuilder awaitIndexTimeout(long awaitIndexTimeout) {
            this.awaitIndexTimeout = awaitIndexTimeout;
            return this;
        }

        /** Builds the store; this connects to Neo4j and ensures the schema exists. */
        @Generated
        public Neo4jEmbeddingStore build() {
            return new Neo4jEmbeddingStore(this.config, this.driver, this.dimension, this.label, this.embeddingProperty, this.idProperty, this.metadataPrefix, this.textProperty, this.indexName, this.databaseName, this.retrievalQuery, this.awaitIndexTimeout);
        }

        @Generated
        public String toString() {
            return "Neo4jEmbeddingStore.Neo4jEmbeddingStoreBuilder(config=" + String.valueOf(this.config) + ", driver=" + String.valueOf(this.driver) + ", dimension=" + this.dimension + ", label=" + this.label + ", embeddingProperty=" + this.embeddingProperty + ", idProperty=" + this.idProperty + ", metadataPrefix=" + this.metadataPrefix + ", textProperty=" + this.textProperty + ", indexName=" + this.indexName + ", databaseName=" + this.databaseName + ", retrievalQuery=" + this.retrievalQuery + ", awaitIndexTimeout=" + this.awaitIndexTimeout + ")";
        }
    }
}

