package io.github.javpower.vectorex.keynote.graph.core;

import io.github.javpower.vectorex.keynote.graph.entity.EdgeInfo;
import io.github.javpower.vectorex.keynote.graph.entity.Node;
import io.github.javpower.vectorex.keynote.graph.entity.Relationship;
import org.mapdb.DB;
import org.mapdb.DBMaker;
import org.mapdb.HTreeMap;
import org.mapdb.Serializer;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

/**
 * A small property-graph store persisted in a MapDB file database.
 *
 * <p>Nodes, relationships and the secondary indexes (label, relationship type, node property,
 * incoming edges) live in separate {@link HTreeMap}s whose values use {@link Serializer#JAVA}.
 * Because values are stored in serialized form, an object obtained from {@code get} is a copy:
 * mutating it does not change what is stored. Every index update in this class therefore
 * follows a read-modify-write pattern ({@code get}, mutate, {@code put}).
 *
 * <p>NOTE(review): those read-modify-write sequences are not atomic; concurrent writers need
 * external synchronization.
 */
public class GraphDB implements AutoCloseable {
    private final DB db;
    private final HTreeMap<String, Node> nodes;
    private final HTreeMap<String, Relationship> relationships;
    /** label -> ids of the nodes carrying that label. */
    private final HTreeMap<String, Set<String>> labelIndex;
    /** relationship type -> ids of the relationships of that type. */
    private final HTreeMap<String, Set<String>> relTypeIndex;
    /** property key -> (property value -> ids of the nodes having that key/value pair). */
    private final HTreeMap<String, Map<Object, Set<String>>> nodePropertyIndex;
    /** end node id -> relationships pointing at that node. */
    private final HTreeMap<String, List<Relationship>> inEdges;
    private final AdjacencyCache adjacencyCache;

    /**
     * Opens (or creates) the graph database stored in the file at {@code dbPath}.
     *
     * <p>The underlying MapDB database is registered to close on JVM shutdown; call
     * {@link #close()} to release it earlier.
     *
     * @param dbPath path of the MapDB database file
     */
    public GraphDB(String dbPath) {
        db = DBMaker.fileDB(dbPath).closeOnJvmShutdown().make();
        nodes = db.hashMap("nodes", Serializer.STRING, Serializer.JAVA).createOrOpen();
        relationships = db.hashMap("relationships", Serializer.STRING, Serializer.JAVA).createOrOpen();
        labelIndex = db.hashMap("labelIndex", Serializer.STRING, Serializer.JAVA).createOrOpen();
        relTypeIndex = db.hashMap("relTypeIndex", Serializer.STRING, Serializer.JAVA).createOrOpen();
        nodePropertyIndex = db.hashMap("nodePropertyIndex", Serializer.STRING, Serializer.JAVA).createOrOpen();
        inEdges = db.hashMap("inEdges", Serializer.STRING, Serializer.JAVA).createOrOpen();
        // All fields are assigned before `this` is handed to the cache.
        adjacencyCache = new AdjacencyCache(this);
    }

    /**
     * Closes the underlying MapDB database. This instance must not be used afterwards.
     */
    @Override
    public void close() {
        db.close();
    }

    /**
     * Stores {@code node} (inserting or replacing by id) and indexes its labels and properties.
     *
     * <p>When a node with the same id already existed, its old label/property index entries are
     * removed first so the indexes do not keep stale entries.
     *
     * @param node the node to store
     */
    public void addNode(Node node) {
        Node previous = nodes.put(node.getId(), node);
        if (previous != null) {
            removeFromLabelIndex(previous);
            removeFromPropertyIndex(previous);
        }
        updateLabelIndex(node);
        updatePropertyIndex(node);
        adjacencyCache.updateCacheForNode(node.getId());
    }

    /**
     * Stores {@code rel} and wires it into the start node's outgoing edges, the incoming-edge
     * index, the relationship-type index and the adjacency cache.
     *
     * @param rel the relationship to add
     * @throws IllegalArgumentException if the start or end node does not exist
     */
    public void addRelationship(Relationship rel) {
        String startId = rel.getStartNodeId();
        String endId = rel.getEndNodeId();
        if (!nodes.containsKey(startId) || !nodes.containsKey(endId)) {
            throw new IllegalArgumentException("Start or end node does not exist");
        }

        relationships.put(rel.getId(), rel);

        // Outgoing edges live inside the start node: mutate the fetched copy, then write it back.
        Node startNode = nodes.get(startId);
        startNode.getOutgoingEdgesInternal().add(rel);
        nodes.put(startId, startNode);

        // Incoming-edge index: the fetched list is a copy, so it must be put back to persist.
        List<Relationship> incoming = inEdges.get(endId);
        if (incoming == null) {
            incoming = new CopyOnWriteArrayList<>();
        }
        incoming.add(rel);
        inEdges.put(endId, incoming);

        addToIndex(relTypeIndex, rel.getType(), rel.getId());

        adjacencyCache.updateCacheForNode(startId);
        adjacencyCache.updateCacheForNode(endId);
    }

    /**
     * Returns the relationships whose end node is {@code nodeId}.
     *
     * @param nodeId id of the end node
     * @return the incoming relationships, or an empty list when there are none
     */
    public List<Relationship> getInEdges(String nodeId) {
        return inEdges.getOrDefault(nodeId, Collections.emptyList());
    }

    /**
     * Returns every node whose property {@code key} equals {@code value}.
     *
     * <p>This scans all nodes rather than consulting {@code nodePropertyIndex}, so the result
     * does not depend on the index being complete.
     *
     * @param key   property name
     * @param value expected value; must not be {@code null}
     * @return matching nodes (possibly empty)
     */
    public List<Node> findNodesByProperty(String key, Object value) {
        return nodes.values().stream()
                .filter(node -> value.equals(node.getProperty(key)))
                .collect(Collectors.toList());
    }

    /**
     * Returns the nodes carrying {@code label}, according to the label index.
     *
     * @param label the label to look up
     * @return matching nodes; ids that no longer resolve to a node are skipped
     */
    public List<Node> findNodesByLabel(String label) {
        return labelIndex.getOrDefault(label, Collections.emptySet()).stream()
                .map(nodes::get)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Returns the relationships of the given type, according to the relationship-type index.
     *
     * @param type the relationship type
     * @return matching relationships; ids that no longer resolve are skipped
     */
    public List<Relationship> findRelationshipsByType(String type) {
        return relTypeIndex.getOrDefault(type, Collections.emptySet()).stream()
                .map(relationships::get)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Returns the end node of each outgoing edge of {@code nodeId}, in edge order.
     * A neighbor reached through several edges appears once per edge.
     *
     * @param nodeId id of the source node
     * @return neighbor nodes, or an empty list when the node does not exist
     */
    public List<Node> findNeighbors(String nodeId) {
        Node node = nodes.get(nodeId);
        if (node == null) {
            return Collections.emptyList();
        }
        return node.getOutgoingEdges().stream()
                .map(rel -> nodes.get(rel.getEndNodeId()))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Finds a minimum-total-weight path between two nodes using Dijkstra's algorithm over the
     * adjacency cache's edge weights.
     *
     * <p>NOTE(review): Dijkstra is only correct for non-negative weights; confirm that
     * {@code EdgeInfo} weights can never be negative.
     *
     * @param startId id of the first node
     * @param endId   id of the last node
     * @return node ids from start to end (inclusive); a singleton list when the ids are equal
     *     (the node's existence is not checked in that case); an empty list when either node is
     *     missing or {@code endId} is unreachable
     */
    public List<String> shortestPathWithWeights(String startId, String endId) {
        if (startId.equals(endId)) {
            return Collections.singletonList(startId);
        }
        if (!containsNode(startId) || !containsNode(endId)) {
            return Collections.emptyList();
        }

        PriorityQueue<NodeDistance> queue = new PriorityQueue<>();
        // Distances are filled lazily; an absent entry means "not reached yet" (infinity).
        Map<String, Double> distances = new HashMap<>();
        Map<String, String> predecessors = new HashMap<>();
        distances.put(startId, 0.0);
        queue.add(new NodeDistance(startId, 0.0));

        while (!queue.isEmpty()) {
            NodeDistance current = queue.poll();
            if (current.nodeId.equals(endId)) {
                break; // first time the target is polled its distance is final
            }
            // Skip queue entries superseded by a shorter distance found later.
            if (current.distance > distances.get(current.nodeId)) {
                continue;
            }

            for (EdgeInfo edge : adjacencyCache.getOutEdges(current.nodeId)) {
                String target = edge.getTargetId();
                // Guard against cache entries that point at nodes no longer in the graph.
                if (!containsNode(target)) {
                    continue;
                }
                double newDist = current.distance + edge.getWeight();
                if (newDist < distances.getOrDefault(target, Double.MAX_VALUE)) {
                    distances.put(target, newDist);
                    predecessors.put(target, current.nodeId);
                    queue.add(new NodeDistance(target, newDist));
                }
            }
        }

        return buildPath(predecessors, endId);
    }

    /**
     * Walks the predecessor chain back from {@code endId} and returns it in start-to-end order.
     * Returns an empty list when {@code endId} was never reached.
     */
    private List<String> buildPath(Map<String, String> predecessors, String endId) {
        if (!predecessors.containsKey(endId)) {
            return Collections.emptyList();
        }
        List<String> path = new ArrayList<>();
        for (String current = endId; current != null; current = predecessors.get(current)) {
            path.add(current);
        }
        Collections.reverse(path);
        return path;
    }

    /**
     * Refreshes the adjacency cache for {@code nodeId} and for the start node of every
     * relationship pointing at it.
     *
     * @param nodeId id of the node whose cache entries should be rebuilt
     */
    public void updateAdjacencyCache(String nodeId) {
        adjacencyCache.updateCacheForNode(nodeId);
        // Also refresh the nodes that have an edge into this one.
        getInEdges(nodeId).forEach(rel ->
                adjacencyCache.updateCacheForNode(rel.getStartNodeId()));
    }

    /** Returns {@code true} if a node with this id is stored. */
    public boolean containsNode(String id) {
        return nodes.containsKey(id);
    }

    /** Returns the node with this id, or {@code null} when absent. */
    public Node getNode(String id) {
        return nodes.get(id);
    }

    /** Returns the key set of the node map (a view backed by the map, not a snapshot). */
    public Set<String> getAllNodeIds() {
        return nodes.keySet();
    }

    /** Priority-queue entry for Dijkstra: a node id and its tentative distance. */
    private static class NodeDistance implements Comparable<NodeDistance>, Serializable {
        final String nodeId;
        final double distance;

        NodeDistance(String nodeId, double distance) {
            this.nodeId = nodeId;
            this.distance = distance;
        }

        @Override
        public int compareTo(NodeDistance o) {
            return Double.compare(this.distance, o.distance);
        }
    }

    // ---- index maintenance -------------------------------------------------------------

    /** Adds {@code member} to the set stored under {@code key} and writes the set back. */
    private static <K, V> void addToIndex(Map<K, Set<V>> index, K key, V member) {
        Set<V> members = index.get(key);
        if (members == null) {
            members = ConcurrentHashMap.newKeySet();
        }
        if (members.add(member)) {
            index.put(key, members);
        }
    }

    /** Removes {@code member} from the set under {@code key}, writing it back or dropping it if empty. */
    private static <K, V> void removeFromIndex(Map<K, Set<V>> index, K key, V member) {
        Set<V> members = index.get(key);
        if (members == null || !members.remove(member)) {
            return;
        }
        if (members.isEmpty()) {
            index.remove(key);
        } else {
            index.put(key, members);
        }
    }

    /** Records {@code node} under each of its labels. */
    private void updateLabelIndex(Node node) {
        String nodeId = node.getId();
        node.getLabels().forEach(label -> addToIndex(labelIndex, label, nodeId));
    }

    /** Records {@code node} under each of its (key, value) property pairs; null values are not indexed. */
    private void updatePropertyIndex(Node node) {
        String nodeId = node.getId();
        node.getProperties().forEach((key, value) -> {
            if (value == null) {
                return; // ConcurrentHashMap cannot hold a null key
            }
            Map<Object, Set<String>> valueMap = nodePropertyIndex.get(key);
            if (valueMap == null) {
                valueMap = new ConcurrentHashMap<>();
            }
            valueMap.computeIfAbsent(value, k -> ConcurrentHashMap.newKeySet()).add(nodeId);
            nodePropertyIndex.put(key, valueMap);
        });
    }

    // ---- removal -----------------------------------------------------------------------

    /**
     * Removes a node together with all relationships that start or end at it, and drops it from
     * the label and property indexes.
     *
     * @param nodeId id of the node to remove
     * @return the removed node, or {@code null} if it did not exist
     */
    public Node removeNode(String nodeId) {
        Node node = nodes.remove(nodeId);
        if (node == null) {
            return null;
        }
        // Outgoing relationships.
        for (Relationship rel : new ArrayList<>(node.getOutgoingEdges())) {
            removeRelationship(rel.getId());
        }
        // Incoming relationships: go through removeRelationship so each start node's outgoing
        // list, the type index and the adjacency cache are cleaned as well.
        for (Relationship rel : new ArrayList<>(getInEdges(nodeId))) {
            removeRelationship(rel.getId());
        }
        inEdges.remove(nodeId);

        removeFromLabelIndex(node);
        removeFromPropertyIndex(node);
        return node;
    }

    /** Removes {@code node}'s id from the sets of all its labels. */
    private void removeFromLabelIndex(Node node) {
        String nodeId = node.getId();
        node.getLabels().forEach(label -> removeFromIndex(labelIndex, label, nodeId));
    }

    /** Removes {@code node}'s id from the property index entries of all its (key, value) pairs. */
    private void removeFromPropertyIndex(Node node) {
        String nodeId = node.getId();
        node.getProperties().forEach((key, value) -> {
            if (value == null) {
                return; // null values are never indexed
            }
            Map<Object, Set<String>> valueMap = nodePropertyIndex.get(key);
            if (valueMap == null) {
                return;
            }
            Set<String> ids = valueMap.get(value);
            if (ids == null || !ids.remove(nodeId)) {
                return;
            }
            if (ids.isEmpty()) {
                valueMap.remove(value);
            }
            if (valueMap.isEmpty()) {
                nodePropertyIndex.remove(key);
            } else {
                nodePropertyIndex.put(key, valueMap);
            }
        });
    }

    /**
     * Removes a relationship and unlinks it from:
     * <ul>
     *   <li>the start node's outgoing edges;</li>
     *   <li>the incoming-edge index;</li>
     *   <li>the relationship-type index;</li>
     *   <li>the adjacency cache of both endpoints that still exist.</li>
     * </ul>
     *
     * @param relId id of the relationship to remove
     * @return the removed relationship, or {@code null} if it did not exist
     */
    public Relationship removeRelationship(String relId) {
        Relationship rel = relationships.remove(relId);
        if (rel == null) {
            return null;
        }
        String startId = rel.getStartNodeId();
        String endId = rel.getEndNodeId();

        // Outgoing edge: mutate the fetched node copy and write it back.
        Node startNode = nodes.get(startId);
        if (startNode != null) {
            startNode.getOutgoingEdgesInternal().removeIf(e -> e.getId().equals(relId));
            nodes.put(startId, startNode);
        }

        // Incoming-edge index: persist the shortened list, or drop it when empty.
        List<Relationship> incoming = inEdges.get(endId);
        if (incoming != null && incoming.removeIf(r -> r.getId().equals(relId))) {
            if (incoming.isEmpty()) {
                inEdges.remove(endId);
            } else {
                inEdges.put(endId, incoming);
            }
        }

        removeFromIndex(relTypeIndex, rel.getType(), relId);

        refreshCacheIfPresent(startId);
        refreshCacheIfPresent(endId);
        return rel;
    }

    /** Rebuilds the adjacency cache entry for {@code nodeId} only if that node still exists. */
    private void refreshCacheIfPresent(String nodeId) {
        if (nodes.containsKey(nodeId)) {
            adjacencyCache.updateCacheForNode(nodeId);
        }
    }
}