/*
 * Decompiled with CFR 0.152.
 */
package dev.embeddings4j;

import com.github.jelmerk.knn.DistanceFunction;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import dev.embeddings4j.Embedding;
import dev.embeddings4j.SearchNearestQuery;
import dev.embeddings4j.SearchNearestResult;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Base class for an in-memory vector database backed by an hnswlib {@link HnswIndex}.
 *
 * <p>Every embedding is checked against the index's dimensionality before it is added, and
 * nearest-neighbour queries are answered from the same index.
 *
 * @param <IdType> type of the embedding identifiers
 * @param <ContentType> type of the content carried by each embedding
 * @param <VectorType> element type of the vectors; also used as the distance type of the index
 */
public abstract class AbstractInMemoryVectorDatabase<IdType, ContentType, VectorType extends Comparable<VectorType>> {

    /** Underlying index: vectors are {@code List<VectorType>}, distances are {@code VectorType}. */
    private final HnswIndex<IdType, List<VectorType>, Embedding<IdType, ContentType, VectorType>, VectorType> index;

    /**
     * Creates an empty database backed by a freshly built HNSW index.
     *
     * @param dimensions number of dimensions every stored vector must have
     * @param maxSize maximum number of items, passed to {@code HnswIndex.newBuilder}
     * @param distanceFunction function the index uses to measure the distance between two vectors
     */
    protected AbstractInMemoryVectorDatabase(int dimensions, int maxSize,
            DistanceFunction<List<VectorType>, VectorType> distanceFunction) {
        this.index = HnswIndex.newBuilder(dimensions, distanceFunction, maxSize).build();
    }

    /**
     * Validates a single embedding and adds it to the index.
     *
     * @param embedding the embedding to add
     * @throws IllegalArgumentException if {@code embedding} is null or its dimensions differ from the DB's
     */
    public void insert(Embedding<IdType, ContentType, VectorType> embedding) {
        validate(embedding);
        index.add(embedding);
    }

    /**
     * Validates all given embeddings, then adds them to the index one by one.
     *
     * @param embeddings the embeddings to add
     * @throws NullPointerException if the varargs array itself is null
     * @throws IllegalArgumentException if any embedding is null or has mismatching dimensions
     */
    @SafeVarargs
    public final void insert(Embedding<IdType, ContentType, VectorType>... embeddings) {
        Objects.requireNonNull(embeddings, "embeddings must not be null");
        List<Embedding<IdType, ContentType, VectorType>> embeddingList = Arrays.asList(embeddings);
        // Validate everything first so a validation failure leaves the index untouched.
        embeddingList.forEach(this::validate);
        for (Embedding<IdType, ContentType, VectorType> embedding : embeddingList) {
            index.add(embedding);
        }
    }

    /**
     * Validates all embeddings in the collection, then hands the whole collection to the index.
     *
     * @param embeddings the embeddings to add
     * @throws NullPointerException if {@code embeddings} is null
     * @throws IllegalArgumentException if any embedding is null or has mismatching dimensions
     * @throws InterruptedException propagated from {@code HnswIndex#addAll}
     */
    public void insertAll(Collection<Embedding<IdType, ContentType, VectorType>> embeddings)
            throws InterruptedException {
        Objects.requireNonNull(embeddings, "embeddings must not be null");
        embeddings.forEach(this::validate);
        index.addAll(embeddings);
    }

    /**
     * Ensures an embedding is non-null and has the same dimensionality as the index.
     *
     * @throws IllegalArgumentException if either check fails
     */
    private void validate(Embedding<IdType, ContentType, VectorType> embedding) {
        if (embedding == null) {
            throw new IllegalArgumentException("Embedding must not be null");
        }
        if (embedding.dimensions() != index.getDimensions()) {
            throw new IllegalArgumentException(String.format(
                    "Dimensions of vector (%s) should match dimensions of DB (%s)",
                    embedding.dimensions(), index.getDimensions()));
        }
    }

    /**
     * Finds the embeddings nearest to the query's reference vector.
     *
     * <p>{@code query.maxResults()} neighbours are requested from the index. The results are
     * returned sorted by ascending distance.
     *
     * @param query the search query
     * @return the matching embeddings paired with their distance to the reference vector
     * @throws NullPointerException if {@code query} is null
     * @throws IllegalArgumentException if the reference vector's size differs from the DB's dimensions
     */
    public List<SearchNearestResult<IdType, ContentType, VectorType>> execute(
            SearchNearestQuery<IdType, ContentType, VectorType> query) {
        Objects.requireNonNull(query, "query must not be null");
        List<VectorType> referenceVector = query.reference().vector();
        // Fail fast instead of letting the distance function compare vectors of different lengths.
        // NOTE(review): assumes a vector's dimensionality equals its list size, as for stored
        // embeddings' dimensions() — confirm against Embedding#dimensions.
        if (referenceVector.size() != index.getDimensions()) {
            throw new IllegalArgumentException(String.format(
                    "Dimensions of query vector (%s) should match dimensions of DB (%s)",
                    referenceVector.size(), index.getDimensions()));
        }
        List<SearchResult<Embedding<IdType, ContentType, VectorType>, VectorType>> searchResults =
                index.findNearest(referenceVector, query.maxResults().intValue());
        return searchResults.stream()
                .sorted(Comparator.comparing(SearchResult::distance))
                .map(result -> new SearchNearestResult<IdType, ContentType, VectorType>(
                        result.item(), result.distance()))
                .collect(Collectors.toList());
    }

    /**
     * Returns the number of items currently held by the index.
     *
     * @return the index size
     */
    public int size() {
        return index.size();
    }
}

