package io.github.javpower.vectorex.keynote.index.vector;

import io.github.javpower.vectorex.keynote.knn.DistanceFunction;
import io.github.javpower.vectorex.keynote.knn.DistanceFunctions;
import io.github.javpower.vectorex.keynote.knn.SearchResult;
import io.github.javpower.vectorex.keynote.knn.hnsw.HnswIndex;
import io.github.javpower.vectorex.keynote.core.VectorData;
import io.github.javpower.vectorex.keynote.core.VectorSearchResult;
import io.github.javpower.vectorex.keynote.model.MetricType;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
 * Thin adapter around {@link HnswIndex} keyed by {@code String} ids and storing
 * {@link VectorData} items with {@code float[]} vectors and {@code Float} distances.
 *
 * <p>The underlying index is built with fixed builder settings: M = 16, ef = 200,
 * efConstruction = 200, with removal enabled.
 */
public class HNSWIndex {

    private final HnswIndex<String, float[], VectorData, Float> index;

    /**
     * Creates an index for vectors of the given dimensionality.
     *
     * @param dimensions   vector dimensionality passed to the index builder
     * @param maxDataCount capacity passed to the index builder
     * @param type         distance metric to use
     * @throws IllegalArgumentException if {@code type} is {@code null} or has no mapped
     *                                  distance function
     */
    public HNSWIndex(int dimensions, int maxDataCount, MetricType type) {
        DistanceFunction<float[], Float> distanceFunction = getDistanceFunction(type);
        if (distanceFunction == null) {
            throw new IllegalArgumentException("Unsupported distance function type: " + type);
        }
        this.index = HnswIndex.newBuilder(dimensions, distanceFunction, maxDataCount)
                .withM(16)
                .withEf(200)
                .withEfConstruction(200)
                .withRemoveEnabled()
                .build();
    }

    /**
     * Adds an item to the underlying index.
     *
     * @param data item to add
     */
    public void add(VectorData data) {
        index.add(data);
    }

    /**
     * Removes the item with the given id if the index currently contains it; does nothing
     * otherwise.
     *
     * @param id id of the item to remove
     */
    public void remove(String id) {
        // The current wall-clock time is passed as the removal version.
        // NOTE(review): presumably HnswIndex uses this to order removals against item
        // versions — confirm against HnswIndex.remove.
        index.get(id).ifPresent(vector -> index.remove(vector.id(), System.currentTimeMillis()));
    }

    /**
     * Maps a metric type to its distance function.
     *
     * @param type metric type; may be {@code null}
     * @return the matching distance function, or {@code null} when {@code type} is
     *         {@code null} or not mapped (the constructor turns that into an
     *         {@link IllegalArgumentException})
     */
    private DistanceFunction<float[], Float> getDistanceFunction(MetricType type) {
        if (type == null) {
            // A switch on a null enum would throw a bare NullPointerException; report it
            // through the same "unsupported" path as any other unmapped type instead.
            return null;
        }
        switch (type) {
            case FLOAT_COSINE_DISTANCE:
                return DistanceFunctions.FLOAT_COSINE_DISTANCE;
            case FLOAT_INNER_PRODUCT:
                return DistanceFunctions.FLOAT_INNER_PRODUCT;
            case FLOAT_EUCLIDEAN_DISTANCE:
                return DistanceFunctions.FLOAT_EUCLIDEAN_DISTANCE;
            case FLOAT_CANBERRA_DISTANCE:
                return DistanceFunctions.FLOAT_CANBERRA_DISTANCE;
            case FLOAT_BRAY_CURTIS_DISTANCE:
                return DistanceFunctions.FLOAT_BRAY_CURTIS_DISTANCE;
            case FLOAT_CORRELATION_DISTANCE:
                return DistanceFunctions.FLOAT_CORRELATION_DISTANCE;
            case FLOAT_MANHATTAN_DISTANCE:
                return DistanceFunctions.FLOAT_MANHATTAN_DISTANCE;
            default:
                return null;
        }
    }

    /**
     * Returns the nearest neighbours of {@code queryVector} as reported by
     * {@link HnswIndex#findNearest}.
     *
     * @param queryVector query vector; a {@code null} element causes a
     *                    {@link NullPointerException} on unboxing
     * @param k           number of neighbours requested
     * @return one result per neighbour, with {@code score} set to the distance reported
     *         by the underlying index and {@code id} set to the item's id
     */
    public List<VectorSearchResult> search(List<Float> queryVector, int k) {
        List<SearchResult<VectorData, Float>> results = index.findNearest(toFloatArray(queryVector), k);
        return toVectorSearchResults(results);
    }

    /**
     * Returns nearest neighbours of {@code queryVector} via
     * {@link HnswIndex#findNearestByIds}.
     *
     * <p>NOTE(review): presumably candidates are restricted to {@code includedIds} —
     * confirm against HnswIndex.findNearestByIds.
     *
     * @param queryVector query vector; a {@code null} element causes a
     *                    {@link NullPointerException} on unboxing
     * @param k           number of neighbours requested
     * @param includedIds ids forwarded to {@code findNearestByIds}
     * @return one result per neighbour, with {@code score} set to the distance reported
     *         by the underlying index and {@code id} set to the item's id
     */
    public List<VectorSearchResult> search(List<Float> queryVector, int k, Set<String> includedIds) {
        List<SearchResult<VectorData, Float>> results =
                index.findNearestByIds(toFloatArray(queryVector), k, includedIds);
        return toVectorSearchResults(results);
    }

    /** Unboxes a list of floats into a primitive array. */
    private static float[] toFloatArray(List<Float> vector) {
        float[] array = new float[vector.size()];
        int i = 0;
        // Iterate instead of calling get(i) so non-RandomAccess lists (e.g. LinkedList)
        // are converted in linear rather than quadratic time.
        for (Float value : vector) {
            array[i++] = value;
        }
        return array;
    }

    /** Converts raw index hits into the public result type, preserving order. */
    private static List<VectorSearchResult> toVectorSearchResults(
            List<SearchResult<VectorData, Float>> results) {
        List<VectorSearchResult> converted = new ArrayList<>(results.size());
        for (SearchResult<VectorData, Float> result : results) {
            VectorSearchResult res = new VectorSearchResult();
            res.setScore(result.distance());
            res.setId(result.item().id());
            converted.add(res);
        }
        return converted;
    }
}