/*
 * Decompiled with CFR 0.152.
 */
package apoc.vectordb;

import apoc.Extended;
import apoc.ml.RestAPIConfig;
import apoc.result.ListResult;
import apoc.result.MapResult;
import apoc.util.CollectionUtils;
import apoc.util.UrlResolver;
import apoc.util.Util;
import apoc.vectordb.VectorDb;
import apoc.vectordb.VectorDbHandler;
import apoc.vectordb.VectorDbUtil;
import apoc.vectordb.VectorEmbeddingConfig;
import apoc.vectordb.VectorMappingConfig;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
public class Weaviate {
    public static final VectorDbHandler DB_HANDLER = VectorDbHandler.Type.WEAVIATE.get();
    @Context
    public ProcedureCallContext procedureCallContext;
    @Context
    public Transaction tx;
    @Context
    public GraphDatabaseService db;
    @Context
    public URLAccessChecker urlAccessChecker;

    @Procedure(value="apoc.vectordb.weaviate.info")
    @Description(value="apoc.vectordb.weaviate.info(hostOrKey, collection, $configuration) - Get information about the specified existing collection or throws an error if it does not exist")
    public Stream<MapResult> createCollection(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema/%s");
        VectorDbUtil.methodAndPayloadNull(config);
        Map<String, String> additionalBodies = Map.of("class", collection);
        RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);
        return VectorDb.executeRequest(restAPIConfig, this.urlAccessChecker).map(v -> (Map)v).map(MapResult::new);
    }

    @Procedure(value="apoc.vectordb.weaviate.createCollection")
    @Description(value="apoc.vectordb.weaviate.createCollection(hostOrKey, collection, similarity, size, $configuration) - Creates a collection, with the name specified in the 2nd parameter, and with the specified `similarity` and `size`")
    public Stream<MapResult> createCollection(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="similarity") String similarity, @Name(value="size") Long size, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");
        config.putIfAbsent("method", "POST");
        Map<String, Map<String, Long>> additionalBodies = Map.of("class", collection, "vectorIndexConfig", Map.of("distance", similarity, "size", size));
        RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), additionalBodies);
        return VectorDb.executeRequest(restAPIConfig, this.urlAccessChecker).map(v -> (Map)v).map(MapResult::new);
    }

    @Procedure(value="apoc.vectordb.weaviate.deleteCollection")
    @Description(value="apoc.vectordb.weaviate.deleteCollection(hostOrKey, collection, $configuration) - Deletes a collection with the name specified in the 2nd parameter")
    public Stream<MapResult> deleteCollection(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema/%s");
        config.putIfAbsent("method", "DELETE");
        RestAPIConfig restAPIConfig = new RestAPIConfig(config);
        return VectorDb.executeRequest(restAPIConfig, this.urlAccessChecker).map(v -> (Map)v).map(MapResult::new);
    }

    @Procedure(value="apoc.vectordb.weaviate.upsert")
    @Description(value="apoc.vectordb.weaviate.upsert(hostOrKey, collection, vectors, $configuration) - Upserts, in the collection with the name specified in the 2nd parameter, the vectors [{id: 'id', vector: '<vectorDb>', medatada: '<metadata>'}]")
    public Stream<MapResult> upsert(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="vectors") List<Map<String, Object>> vectors, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/objects");
        config.putIfAbsent("method", "POST");
        HashMap<String, String> body = new HashMap<String, String>();
        body.put("class", collection);
        RestAPIConfig restAPIConfig = new RestAPIConfig(config, Map.of(), body);
        return vectors.stream().flatMap(vector -> {
            try {
                HashMap<String, Object> configBody = new HashMap<String, Object>(restAPIConfig.getBody());
                configBody.putAll((Map<String, Object>)vector);
                configBody.put("properties", vector.remove("metadata"));
                restAPIConfig.setBody(configBody);
                Stream<Object> objectStream = VectorDb.executeRequest(restAPIConfig, this.urlAccessChecker);
                return objectStream;
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }).map(v -> (Map)v).map(MapResult::new);
    }

    @Procedure(value="apoc.vectordb.weaviate.delete")
    @Description(value="apoc.vectordb.weaviate.delete(hostOrKey, collection, ids, $configuration) - Deletes the vectors with the specified `ids`")
    public Stream<ListResult> delete(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="ids") List<Object> ids, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");
        config.putIfAbsent("method", "DELETE");
        RestAPIConfig restAPIConfig = new RestAPIConfig(config, Util.map((Object[])new Object[0]), Util.map((Object[])new Object[0]));
        List<Object> objects = ids.stream().peek(id -> {
            String endpoint = "%s/objects/%s/%s".formatted(restAPIConfig.getBaseUrl(), collection, id);
            restAPIConfig.setEndpoint(endpoint);
            try {
                VectorDb.executeRequest(restAPIConfig, this.urlAccessChecker);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }).toList();
        return Stream.of(new ListResult(objects));
    }

    @Procedure(value="apoc.vectordb.weaviate.getAndUpdate", mode=Mode.WRITE)
    @Description(value="apoc.vectordb.weaviate.getAndUpdate(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`")
    public Stream<VectorDbUtil.EmbeddingResult> getAndUpdate(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="ids") List<Object> ids, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        return this.getCommon(hostOrKey, collection, ids, configuration);
    }

    @Procedure(value="apoc.vectordb.weaviate.get")
    @Description(value="apoc.vectordb.weaviate.get(hostOrKey, collection, ids, $configuration) - Gets the vectors with the specified `ids`")
    public Stream<VectorDbUtil.EmbeddingResult> get(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="ids") List<Object> ids, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        VectorDbUtil.setReadOnlyMappingMode(configuration);
        return this.getCommon(hostOrKey, collection, ids, configuration);
    }

    private Stream<VectorDbUtil.EmbeddingResult> getCommon(String hostOrKey, String collection, List<Object> ids, Map<String, Object> configuration) {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/schema");
        config.putIfAbsent("method", null);
        List fields = this.procedureCallContext.outputFields().toList();
        VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromGet(config, this.procedureCallContext, ids, collection);
        boolean hasEmbedding = fields.contains("vector") && conf.isAllResults();
        boolean hasMetadata = fields.contains("metadata");
        VectorMappingConfig mapping = conf.getMapping();
        String suffix = hasEmbedding ? "?include=vector" : "";
        return ids.stream().flatMap(id -> {
            String endpoint = "%s/objects/%s/%s".formatted(conf.getApiConfig().getBaseUrl(), collection, id) + suffix;
            conf.getApiConfig().setEndpoint(endpoint);
            try {
                return VectorDb.executeRequest(conf.getApiConfig(), this.urlAccessChecker).map(v -> (Map)v).map(m -> VectorDb.getEmbeddingResult(conf, this.tx, hasEmbedding, hasMetadata, mapping, m));
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    @Procedure(value="apoc.vectordb.weaviate.query")
    @Description(value="apoc.vectordb.weaviate.query(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter")
    public Stream<VectorDbUtil.EmbeddingResult> query(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="vector", defaultValue="[]") List<Double> vector, @Name(value="filter", defaultValue="null") Object filter, @Name(value="limit", defaultValue="10") long limit, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        VectorDbUtil.setReadOnlyMappingMode(configuration);
        return this.queryCommon(hostOrKey, collection, vector, filter, limit, configuration);
    }

    @Procedure(value="apoc.vectordb.weaviate.queryAndUpdate", mode=Mode.WRITE)
    @Description(value="apoc.vectordb.weaviate.queryAndUpdate(hostOrKey, collection, vector, filter, limit, $configuration) - Retrieves closest vectors from the defined `vector`, `limit` of results, in the collection with the name specified in the 2nd parameter")
    public Stream<VectorDbUtil.EmbeddingResult> queryAndUpdate(@Name(value="hostOrKey") String hostOrKey, @Name(value="collection") String collection, @Name(value="vector", defaultValue="[]") List<Double> vector, @Name(value="filter", defaultValue="null") Object filter, @Name(value="limit", defaultValue="10") long limit, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) throws Exception {
        return this.queryCommon(hostOrKey, collection, vector, filter, limit, configuration);
    }

    private Stream<VectorDbUtil.EmbeddingResult> queryCommon(String hostOrKey, String collection, List<Double> vector, Object filter, long limit, Map<String, Object> configuration) throws Exception {
        Map<String, Object> config = this.getVectorDbInfo(hostOrKey, collection, configuration, "%s/graphql");
        VectorEmbeddingConfig conf = DB_HANDLER.getEmbedding().fromQuery(config, this.procedureCallContext, vector, filter, limit, collection);
        return VectorDb.getEmbeddingResultStream(conf, this.procedureCallContext, this.urlAccessChecker, this.tx, v -> {
            Map mapResult = (Map)v;
            List errors = (List)mapResult.get("errors");
            if (CollectionUtils.isNotEmpty((Collection)errors)) {
                String message = "An error occurred during Weaviate API response: \n" + StringUtils.join((Iterable)errors, (String)"\n");
                throw new RuntimeException(message);
            }
            Object getValue = ((Map)mapResult.get("data")).get("Get");
            Object collectionValue = ((Map)getValue).get(collection);
            return ((List)collectionValue).stream().map(i -> {
                Map additional = (Map)i.remove("_additional");
                HashMap map = new HashMap();
                map.put(conf.getMetadataKey(), i);
                map.put(conf.getScoreKey(), additional.get("distance"));
                map.put(conf.getIdKey(), additional.get("id"));
                map.put(conf.getVectorKey(), additional.get("vector"));
                return map;
            });
        });
    }

    private Map<String, Object> getVectorDbInfo(String hostOrKey, String collection, Map<String, Object> configuration, String templateUrl) {
        return VectorDbUtil.getCommonVectorDbInfo(hostOrKey, collection, configuration, templateUrl, DB_HANDLER);
    }

    protected String getWeaviateUrl(String hostOrKey) {
        String baseUrl = new UrlResolver("http", "localhost", 8000).getUrl("weaviate", hostOrKey);
        return baseUrl + "/v1";
    }
}

