package com.xzchaoo.commons.basic.consistenthash;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;

import static java.util.Collections.emptyList;

/**
 * Default {@link ConsistentHash} implementation.
 * <p>
 * Every physical node is expanded into {@code vNodeCount} virtual nodes which are placed on a
 * hash ring (a {@link TreeMap} keyed by virtual-node hash). A lookup returns the node owning the
 * first virtual node whose hash is greater than or equal to the requested hash, wrapping around
 * to the smallest virtual node when there is none.
 * <p>
 * Thread-safety: the ring is rebuilt as a whole by {@link #update(List)} and published through a
 * {@code volatile} field, so readers always work on one fully built snapshot.
 * <p>
 * created at 2020/7/26
 *
 * @param <N> physical node type
 * @author xzchaoo
 */
public class DefaultConsistentHash<N> implements ConsistentHash<N> {
    /** Number of virtual nodes created for every physical node; always positive. */
    private final    int             vNodeCount;
    /** Supplies per-node hashes and the tie-break order used when virtual-node hashes collide. */
    private final    NodeFunction<N> nodeFunction;
    /** Current snapshot; never mutated after construction, only replaced by {@link #update(List)}. */
    private volatile State           state;

    /**
     * Creates an instance with an empty ring; call {@link #update(List)} to add nodes.
     *
     * @param vNodeCopyCount virtual node count for every physical node, must be positive
     * @param nodeFunction   node function used to hash and order nodes
     * @throws IllegalArgumentException if {@code vNodeCopyCount} is not positive
     * @throws NullPointerException     if {@code nodeFunction} is null
     */
    public DefaultConsistentHash(int vNodeCopyCount,
        NodeFunction<N> nodeFunction) {
        this(vNodeCopyCount, nodeFunction, emptyList());
    }

    /**
     * Creates an instance whose ring is initially built from {@code initNodes}.
     *
     * @param vNodeCount   virtual node count for every physical node, must be positive
     * @param nodeFunction node function used to hash and order nodes
     * @param initNodes    initial physical nodes; the list is copied
     * @throws IllegalArgumentException if {@code vNodeCount} is not positive
     * @throws NullPointerException     if {@code nodeFunction} or {@code initNodes} is null
     */
    public DefaultConsistentHash(int vNodeCount, NodeFunction<N> nodeFunction,
        List<N> initNodes) {
        if (vNodeCount <= 0) {
            // A non-positive count would silently produce an empty ring where select() returns null.
            throw new IllegalArgumentException("vNodeCount must be positive, got " + vNodeCount);
        }
        this.vNodeCount = vNodeCount;
        this.nodeFunction = Objects.requireNonNull(nodeFunction, "nodeFunction");
        // Assign directly instead of calling the overridable update(): a subclass override must
        // not run while this object is still being constructed.
        this.state = new State(initNodes);
    }

    /**
     * Selects the node responsible for {@code hash}.
     *
     * @param hash hash of the key being routed
     * @return the owner of the first virtual node with hash {@code >= hash}, wrapping around to the
     *     smallest virtual node; {@code null} if the ring is empty
     */
    @Override
    public N select(int hash) {
        // Read the volatile field once so the whole lookup uses a single snapshot.
        State state = this.state;
        TreeMap<Integer, VNode<N>> tree = state.tree;
        if (tree.isEmpty()) {
            return null;
        }
        Map.Entry<Integer, VNode<N>> e = tree.ceilingEntry(hash);
        if (e != null) {
            return e.getValue().node;
        }
        // No virtual node at or after hash: wrap around the ring.
        return state.min.node;
    }

    /**
     * Returns the physical and virtual node counts of the current ring.
     *
     * @return a new {@link Stat} describing the current snapshot
     */
    @Override
    public Stat stat() {
        State state = this.state;
        Stat stat = new Stat();
        stat.setNodeSize(state.nodes.size());
        stat.setVnodeSize(state.tree.size());
        return stat;
    }

    /**
     * Returns the physical nodes of the current ring.
     *
     * @return an unmodifiable snapshot of the nodes last passed to the constructor or
     *     {@link #update(List)}
     */
    public List<N> getNodes() {
        return state.nodes;
    }

    /**
     * Replaces all physical nodes and rebuilds the ring.
     *
     * @param nodes new physical nodes; copied, so later changes to the list are not observed
     * @throws NullPointerException if {@code nodes} is null
     */
    public void update(List<N> nodes) {
        this.state = new State(nodes);
    }

    /**
     * Orders virtual nodes by physical node (per {@link NodeFunction}), then by virtual index.
     */
    private int compare(VNode<N> a, VNode<N> b) {
        int r = nodeFunction.compare(a.node, b.node);
        if (r != 0) {
            return r;
        }
        return Integer.compare(a.vIndex, b.vIndex);
    }

    /**
     * Snapshot of this instance: physical nodes plus the ring built from them.
     */
    private class State {
        /** Unmodifiable copy of the physical nodes. */
        private final List<N>                    nodes;
        /**
         * Consistent hash ring: virtual-node hash to virtual node.
         */
        private final TreeMap<Integer, VNode<N>> tree = new TreeMap<>();
        /**
         * Virtual node with the smallest hash, used for wrap-around; null if the ring is empty.
         */
        private final VNode<N>                   min;

        State(List<N> nodes) {
            Objects.requireNonNull(nodes, "nodes");
            this.nodes = Collections.unmodifiableList(new ArrayList<>(nodes));
            // Build from the defensive copy so the ring always matches this.nodes, even if the
            // caller mutates its list concurrently.
            for (N node : this.nodes) {
                int nodeHash = nodeFunction.hash(node);
                for (int i = 0; i < vNodeCount; i++) {
                    // NOTE(review): vNode.hash is presumably derived from nodeHash and i inside
                    // VNode — confirm.
                    VNode<N> vNode = new VNode<>(node, nodeHash, i);
                    // On a hash collision keep the smaller virtual node per compare(), so the
                    // winner does not depend on input order (assuming compare() is a total order).
                    tree.merge(vNode.hash, vNode, (a, b) -> compare(a, b) < 0 ? a : b);
                }
            }
            Map.Entry<Integer, VNode<N>> first = tree.firstEntry();
            this.min = first != null ? first.getValue() : null;
        }
    }
}
