package io.github.javpower.vectorex.keynote.utils;

import io.github.javpower.vectorex.keynote.VectorDB;
import io.github.javpower.vectorex.keynote.core.DbData;
import io.github.javpower.vectorex.keynote.core.VectorData;
import io.github.javpower.vectorex.keynote.core.VectorSearchResult;
import io.github.javpower.vectorex.keynote.model.MetricType;
import io.github.javpower.vectorex.keynote.model.VectorFiled;
import io.github.javpower.vectorex.keynote.storage.MapDBStorage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Manual smoke and performance harness for text search on a {@link MapDBStorage} collection.
 *
 * <p>The {@code main} entry point is disabled. The helpers below are private and are only
 * reachable from a driver added to this class. The driver previously used here was equivalent to:
 *
 * <pre>{@code
 * VectorDB vectorDB = new VectorDB("vectorex/vectorex.db");
 * vectorDB.init();
 *
 * VectorFiled vectorFiled = new VectorFiled();
 * vectorFiled.setDimensions(3);
 * vectorFiled.setName("vector");
 * vectorFiled.setMetricType(MetricType.FLOAT_COSINE_DISTANCE);
 * VectorDB.createCollection("face", 1000, null);
 *
 * MapDBStorage face = VectorDB.getDataStore("face");
 * addEnglishData(face);
 * addChineseData(face);
 *
 * System.out.println("英文搜索测试：");
 * searchAndPrintResults(face, "category", "ani", 1);
 * System.out.println("中文搜索测试：");
 * searchAndPrintResults(face, "category", "动物", 1);
 * System.out.println("性能测试：");
 * performanceTest(face, "category", "ani", 100);
 * }</pre>
 *
 * <p>NOTE(review): that driver configured {@code vectorFiled} but passed {@code null} as the third
 * argument of {@code createCollection}, so the field definition was never used. Confirm which was
 * intended before re-enabling it.
 */
public class VectorTest {

    /**
     * Saves the two English-language test records (ids "1" and "2").
     *
     * @param face the storage to write into
     */
    private static void addEnglishData(MapDBStorage face) {
        saveRecord(face, "1", "animal", "small", new float[]{1.0f, 2.0f, 3.0f});
        saveRecord(face, "2", "animal vv", "large", new float[]{4.0f, 5.0f, 6.0f});
    }

    /**
     * Saves the two Chinese-language test records (ids "3" and "4").
     *
     * @param face the storage to write into
     */
    private static void addChineseData(MapDBStorage face) {
        saveRecord(face, "3", "动物", "小", new float[]{7.0f, 8.0f, 9.0f});
        saveRecord(face, "4", "动物 世界", "大", new float[]{10.0f, 11.0f, 12.0f});
    }

    /**
     * Builds one test record with {@code id}, {@code category} and {@code size} metadata and saves it.
     *
     * @param face     the storage to write into
     * @param id       record id, also stored under the "id" metadata key
     * @param category value stored under the "category" metadata key (the field the searches query)
     * @param size     value stored under the "size" metadata key
     * @param vector   the vector intended for this record's "vector" field (see note below)
     */
    private static void saveRecord(MapDBStorage face, String id, String category, String size, float[] vector) {
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("id", id);
        metadata.put("category", category);
        metadata.put("size", size);

        // NOTE(review): this VectorData is built but never attached to the record. The original
        // attach call (dbData.setVectorFiled(Lists.newArrayList(data))) was commented out, and
        // Guava's Lists is not imported. Records are therefore saved with metadata only.
        // Confirm the intended DbData API before re-enabling the attachment.
        VectorData vectorData = new VectorData(id, vector);
        vectorData.setName("vector");

        DbData dbData = new DbData();
        dbData.setId(id);
        dbData.setMetadata(metadata);
        face.save(dbData);
    }

    /**
     * Runs one search and prints each hit's id and score, or a "no match" line.
     *
     * @param face  the storage to search
     * @param field metadata field to search on
     * @param query query text
     * @param topK  forwarded as the third argument of {@code search} (presumably the result limit)
     */
    private static void searchAndPrintResults(MapDBStorage face, String field, String query, int topK) {
        // The fourth search argument is null here (presumably "no filter" — TODO confirm).
        List<VectorSearchResult> results = face.search(field, query, topK, null);
        // Guard against a null result list as well as an empty one.
        if (results == null || results.isEmpty()) {
            System.out.println("未找到匹配结果。");
        } else {
            for (VectorSearchResult result : results) {
                System.out.println("找到匹配结果: ID = " + result.getId() + ", 得分 = " + result.getScore());
            }
        }
        System.out.println();
    }

    /**
     * Times repeated searches with a result limit of 2.
     *
     * <p>Prints every sample and the mean, both in nanoseconds.
     *
     * @param face       the storage to search
     * @param field      metadata field to search on
     * @param query      query text
     * @param iterations number of timed searches; values {@code <= 0} produce an empty run
     */
    private static void performanceTest(MapDBStorage face, String field, String query, int iterations) {
        performanceTest(face, field, query, iterations, 2);
    }

    /**
     * Times repeated searches and prints every sample and the mean, both in nanoseconds.
     *
     * <p>No warm-up pass is performed, so early samples may include JIT and cache warm-up.
     *
     * @param face       the storage to search
     * @param field      metadata field to search on
     * @param query      query text
     * @param iterations number of timed searches; values {@code <= 0} produce an empty run
     * @param topK       forwarded as the third argument of {@code search}
     */
    private static void performanceTest(MapDBStorage face, String field, String query, int iterations, int topK) {
        // Primitive array avoids boxing every sample. The clamp keeps a negative count from throwing.
        long[] timings = new long[Math.max(iterations, 0)];
        for (int i = 0; i < timings.length; i++) {
            long startTime = System.nanoTime();
            face.search(field, query, topK, null);
            timings[i] = System.nanoTime() - startTime;
        }

        // Print the per-search samples.
        System.out.println("每次搜索耗时（纳秒）：");
        System.out.println(Arrays.stream(timings).mapToObj(String::valueOf).collect(Collectors.joining(",")));

        // Mean of the samples; 0.0 when no searches ran.
        double averageTime = Arrays.stream(timings).average().orElse(0);
        System.out.println("平均耗时（纳秒）：" + averageTime);
    }
}