package io.github.javpower.vectorex.keynote.index.bm25;

import io.github.javpower.vectorex.keynote.bm25.BM25;
import io.github.javpower.vectorex.keynote.core.DbData;
import io.github.javpower.vectorex.keynote.core.VectorSearchResult;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

public class Bm25IndexManager {

    // BM25 算法实例
    private final BM25 bm25 = new BM25(1.2, 0.75);

    // 字段 -> id -> 文本（用于 BM25 计算）
    private final Map<String, Map<String, String>> bm25Index = new ConcurrentHashMap<>();

    public Bm25IndexManager() {
    }

    /**
     * 添加或更新文档
     *
     * @param data 文档数据
     */
    public void index(DbData data) {
        String id = data.getId();
        Map<String, Object> metadata = data.getMetadata();
        // 遍历文档的元数据字段
        for (Map.Entry<String, Object> entry : metadata.entrySet()) {
            String field = entry.getKey();
            Object value = entry.getValue();
            // 只处理文本字段
            if (value instanceof String) {
                String text = (String) value;
                // 获取或创建字段对应的索引
                Map<String, String> fieldIndex = bm25Index.computeIfAbsent(field, k -> new ConcurrentHashMap<>());
                // 更新字段索引
                fieldIndex.put(id, text);
            }
        }
    }

    /**
     * 删除文档
     *
     * @param id 文档 ID
     */
    public void remove(String id) {
        // 遍历所有字段的索引，移除该文档 ID
        for (Map<String, String> fieldIndex : bm25Index.values()) {
            fieldIndex.remove(id);
        }
    }

    /**
     * 更新文档
     *
     * @param data 文档数据
     */
    public void update(DbData data) {
        remove(data.getId()); // 先删除旧数据
        index(data);         // 再添加新数据
    }

    /**
     * 搜索文档
     *
     * @param annsField    搜索字段
     * @param queryVector  查询文本
     * @param k            返回结果数量
     * @return 搜索结果列表
     */
    public List<VectorSearchResult> search(String annsField, String queryVector, int k) {
        return search(annsField, queryVector, k, null);
    }

    /**
     * 搜索文档（带过滤条件）
     *
     * @param annsField    搜索字段
     * @param queryVector  查询文本
     * @param k            返回结果数量
     * @param includedIds  允许的文档 ID 集合（过滤条件）
     * @return 搜索结果列表
     */
    public List<VectorSearchResult> search(String annsField, String queryVector, int k, Set<String> includedIds) {
        // 获取字段对应的索引
        Map<String, String> fieldIndex = bm25Index.get(annsField);
        if (fieldIndex == null || fieldIndex.isEmpty()) return Collections.emptyList();

        // 如果设置了过滤条件，只保留 includedIds 中的文档
        Map<String, String> filteredCorpora = new HashMap<>();
        for (Map.Entry<String, String> entry : fieldIndex.entrySet()) {
            String id = entry.getKey();
            String text = entry.getValue();
            if (includedIds == null || includedIds.contains(id)) {
                filteredCorpora.put(id, text);
            }
        }
        // 调用 BM25 进行排序
        Map<String, Double> rankBM25 = bm25.rankBM25(queryVector, filteredCorpora, k);
        // 转换为搜索结果列表
        List<VectorSearchResult> results = new ArrayList<>();
        for (Map.Entry<String, Double> entry : rankBM25.entrySet()) {
            VectorSearchResult vectorSearchResult = new VectorSearchResult();
            vectorSearchResult.setId(entry.getKey());
            vectorSearchResult.setScore(entry.getValue().floatValue());
            results.add(vectorSearchResult);
        }
        return results;
    }
}