package io.github.javpower.vectorex.keynote.storage;

import com.google.common.collect.Lists;
import io.github.javpower.vectorex.keynote.core.DbData;
import io.github.javpower.vectorex.keynote.core.VectorData;
import io.github.javpower.vectorex.keynote.core.VectorSearchResult;
import io.github.javpower.vectorex.keynote.index.bm25.Bm25IndexManager;
import io.github.javpower.vectorex.keynote.index.scalar.ScalarIndexManager;
import io.github.javpower.vectorex.keynote.index.vector.VectorIndexManager;
import io.github.javpower.vectorex.keynote.model.VectorFiled;
import io.github.javpower.vectorex.keynote.query.ConditionBuilder;
import org.mapdb.DB;
import org.mapdb.HTreeMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

public class MapDBStorage implements DataStore {
    private static final Logger logger = LoggerFactory.getLogger(MapDBStorage.class);
    private final DB db;
    private final HTreeMap<String, DbData> map;
    private final ScalarIndexManager scalarIndexManager;
    private final Map<String,VectorIndexManager> vectorIndexManagers=new ConcurrentHashMap<>();
    private final Bm25IndexManager bm25IndexManager;

    public MapDBStorage(DB db, HTreeMap<String, DbData> map, int maxDataCount, List<VectorFiled> vectorFiled) {
        try {
            this.db = db;
            this.map=map;
            scalarIndexManager = new ScalarIndexManager();
            bm25IndexManager=new Bm25IndexManager();
            if(vectorFiled!=null){
                for (VectorFiled filed : vectorFiled) {
                    VectorIndexManager vectorIndexManager = new VectorIndexManager(filed.getDimensions(), maxDataCount, filed.getMetricType());
                    vectorIndexManagers.put(filed.getName(), vectorIndexManager);
                }
            }
            logger.info("MapDB storage initialized");
        } catch (Exception e) {
            logger.error("Failed to initialize MapDB storage", e);
            throw new RuntimeException("Failed to initialize MapDB storage", e);
        }
    }
   public HTreeMap<String, DbData> getMap() {
        return map;
    }
    @Override
    public void save(DbData data) {
        try {
            map.put(data.getId(),data);
            scalarIndexManager.index(data);
            bm25IndexManager.index(data);
            if(data.getVectorFiled()!=null){
                for (VectorData vectorData : data.getVectorFiled()) {
                    vectorIndexManagers.get(vectorData.getName()).index(vectorData);
                }
            }
            db.commit();
            logger.debug("Saved vector data with ID: {}", data.getId());
        } catch (Exception e) {
            logger.error("Failed to save vector data with ID: {}", data.getId(), e);
            throw new RuntimeException("Failed to save vector data", e);
        }
    }

    @Override
    public void saveAll(List<DbData> dataList) {
        try {
            dataList.forEach(data -> {
                map.put(data.getId(), data);
                scalarIndexManager.index(data);
                bm25IndexManager.index(data);
                if(data.getVectorFiled()!=null){
                    for (VectorData vectorData : data.getVectorFiled()) {
                        vectorIndexManagers.get(vectorData.getName()).index(vectorData);
                    }
                }
            });
            db.commit();
            logger.debug("Saved {} vector data entries", dataList.size());
        } catch (Exception e) {
            logger.error("Failed to save batch vector data", e);
            throw new RuntimeException("Failed to save batch vector data", e);
        }
    }

    @Override
    public void update(DbData data) {
        try {
            DbData oldData = map.get(data.getId());
            if (oldData != null) {
                scalarIndexManager.remove(oldData.getId());
                bm25IndexManager.remove(oldData.getId());
                if(oldData.getVectorFiled()!=null){
                    for (VectorData vectorData : oldData.getVectorFiled()) {
                        vectorIndexManagers.get(vectorData.getName()).remove(vectorData.id());
                    }
                }
            }
            map.put(data.getId(), data);
            scalarIndexManager.index(data);
            bm25IndexManager.index(data);
            if(data.getVectorFiled()!=null){
                for (VectorData vectorData : data.getVectorFiled()) {
                    vectorIndexManagers.get(vectorData.getName()).index(vectorData);
                }
            }

            db.commit();
            logger.debug("Updated vector data with ID: {}", data.getId());
        } catch (Exception e) {
            logger.error("Failed to update vector data with ID: {}", data.getId(), e);
            throw new RuntimeException("Failed to update vector data", e);
        }
    }

    @Override
    public DbData get(String id) {
        DbData data = map.get(id);
        if (data == null) {
            logger.warn("Vector data with ID {} not found", id);
        }
        return data;
    }

    @Override
    public List<DbData> getAll() {
        return new ArrayList<>(map.values());
    }

    @Override
    public void delete(String id) {
        try {
            DbData data = map.get(id);
            if (data != null) {
                scalarIndexManager.remove(data.getId());
                bm25IndexManager.remove(data.getId());
                if(data.getVectorFiled()!=null){
                    for (VectorData vectorData : data.getVectorFiled()) {
                        vectorIndexManagers.get(vectorData.getName()).remove(vectorData.id());
                    }
                }
                map.remove(id);
                db.commit();
                logger.debug("Deleted vector data with ID: {}", id);
            }
        } catch (Exception e) {
            logger.error("Failed to delete vector data with ID: {}", id, e);
            throw new RuntimeException("Failed to delete vector data", e);
        }
    }

    @Override
    public void deleteAll(List<String> ids) {
        try {
            ids.forEach(id -> {
                DbData data = map.get(id);
                if (data != null) {
                    scalarIndexManager.remove(data.getId());
                    bm25IndexManager.remove(data.getId());
                    if(data.getVectorFiled()!=null){
                        for (VectorData vectorData : data.getVectorFiled()) {
                            vectorIndexManagers.get(vectorData.getName()).remove(vectorData.id());
                        }
                    }
                    map.remove(id);
                }
            });
            db.commit();
            logger.debug("Deleted {} vector data entries", ids.size());
        } catch (Exception e) {
            logger.error("Failed to delete batch vector data", e);
            throw new RuntimeException("Failed to delete batch vector data", e);
        }
    }

    @Override
    public void close() {
        try {
            db.close();
            logger.info("MapDB storage closed");
        } catch (Exception e) {
            logger.error("Failed to close MapDB storage", e);
            throw new RuntimeException("Failed to close MapDB storage", e);
        }
    }
    @Override
    public List<VectorSearchResult> search(String annsField, List<Float> queryVector, int k, ConditionBuilder conditionBuilder) {
        if(conditionBuilder == null){
            VectorIndexManager vectorIndexManager = vectorIndexManagers.get(annsField);
            return vectorIndexManager.search(queryVector, k).stream().peek(v-> v.setData(scalarIndexManager.getDbDataById(v.getId()))).collect(Collectors.toList());
        }
        //标量过滤
        List<String> filteredIds = scalarIndexManager.search(conditionBuilder);
        if(filteredIds!=null&&filteredIds.size()>0){
            //在过滤后的 ID 集合中进行向量搜索
            return vectorIndexManagers.get(annsField).search(queryVector, k, new HashSet<>(filteredIds)).stream().peek(v-> v.setData(scalarIndexManager.getDbDataById(v.getId()))).collect(Collectors.toList());
        }
        return Lists.newArrayList();
    }
    @Override
    public List<VectorSearchResult> search(String annsField,String queryVector, int k, ConditionBuilder conditionBuilder) {
        if(conditionBuilder == null){
            return bm25IndexManager.search(annsField,queryVector, k).stream().peek(v-> v.setData(scalarIndexManager.getDbDataById(v.getId()))).collect(Collectors.toList());
        }
        //标量过滤
        List<String> filteredIds = scalarIndexManager.search(conditionBuilder);
        if(filteredIds!=null&&filteredIds.size()>0){
            //在过滤后的 ID 集合中进行向量搜索
            return bm25IndexManager.search(annsField,queryVector, k, new HashSet<>(filteredIds)).stream().peek(v-> v.setData(scalarIndexManager.getDbDataById(v.getId()))).collect(Collectors.toList());
        }
        return Lists.newArrayList();
    }

    @Override
    public List<VectorSearchResult> query(ConditionBuilder conditionBuilder) {
        List<String> filteredIds = scalarIndexManager.search(conditionBuilder);
        return filteredIds.stream()
                .map(scalarIndexManager::getDbDataById).map(v->{
                    VectorSearchResult result=new VectorSearchResult();
                    result.setData(v);
                    result.setId(v.getId());
                    result.setScore(0.0f);
                    return result;
                })
                .collect(Collectors.toList());
    }

    @Override
    public void loadDataIntoManagers(DbData data) {
        scalarIndexManager.index(data);
        bm25IndexManager.index(data);
        for (VectorData vectorData : data.getVectorFiled()) {
            vectorIndexManagers.get(vectorData.getName()).index(vectorData);
        }
    }

}