package shz.st.bst.jxx;

import shz.queue.JLinkedQueue;
import shz.st.bst.RedBlackBST;

/**
 * A red-black tree whose keys are primitive {@code long} values.
 *
 * <p>On top of {@link RedBlackBST} this class adds {@code long}-keyed rank and range queries:
 * counting keys at or below / at or above a bound or inside a closed interval, floor and ceiling
 * lookups, and in-order enumeration of keys. The concrete node type is supplied by subclasses,
 * which must extend {@link Node}.
 */
public abstract class JXXRedBlackBST extends RedBlackBST<Long> {
    /**
     * Base node type for this tree: a {@link RedBlackBST.Node} that additionally stores a
     * primitive {@code long} key.
     */
    protected static abstract class Node extends RedBlackBST.Node {
        /** The key stored in this node. */
        protected long key;

        /**
         * @param key the node's key
         * @param red the color flag forwarded to {@link RedBlackBST.Node}
         */
        protected Node(long key, boolean red) {
            super(red);
            this.key = key;
        }
    }

    /**
     * @param root the initial root node handed to {@link RedBlackBST}
     */
    protected JXXRedBlackBST(Node root) {
        super(root);
    }

    /**
     * Returns the key of {@code h}, boxed as a {@link Long}.
     * NOTE(review): presumably this is the key accessor required by {@link RedBlackBST}; confirm
     * against the parent class (and add {@code @Override} if it is an override).
     *
     * @param h a node of this tree; it is cast to {@link Node}, so it must be one
     * @return the node's key
     */
    protected final Long key(RedBlackBST.Node h) {
        return ((Node) h).key;
    }

    /**
     * Returns the number of keys that are less than or equal to {@code hi}.
     *
     * @param hi inclusive upper bound
     * @return count of keys {@code <= hi}
     */
    public final int sizeLe(long hi) {
        return sizeLe(root(), hi);
    }

    /**
     * Counts keys {@code <= hi} in the subtree rooted at {@code h}.
     *
     * <p>Walks a single root-to-leaf path. Whenever a node's key is within the bound, that node
     * and its entire left subtree are counted at once, and the walk continues to the right.
     *
     * @param h  subtree root (may be {@code null})
     * @param hi inclusive upper bound
     * @return count of keys {@code <= hi} in the subtree
     */
    protected final int sizeLe(Node h, long hi) {
        int res = 0;
        while (h != null) {
            if (hi < h.key) h = h.left();
            else {
                res += 1 + size(h.left);
                h = h.right();
            }
        }
        return res;
    }

    /**
     * Returns the number of keys that are greater than or equal to {@code lo}.
     *
     * @param lo inclusive lower bound
     * @return count of keys {@code >= lo}
     */
    public final int sizeGe(long lo) {
        return sizeGe(root(), lo);
    }

    /**
     * Counts keys {@code >= lo} in the subtree rooted at {@code h}.
     *
     * <p>This mirrors {@link #sizeLe(Node, long)}. Whenever a node's key is within the bound,
     * that node and its entire right subtree are counted at once, and the walk continues to the
     * left.
     *
     * @param h  subtree root (may be {@code null})
     * @param lo inclusive lower bound
     * @return count of keys {@code >= lo} in the subtree
     */
    protected final int sizeGe(Node h, long lo) {
        int res = 0;
        while (h != null) {
            if (lo > h.key) h = h.right();
            else {
                res += 1 + size(h.right);
                h = h.left();
            }
        }
        return res;
    }

    /**
     * Returns the number of keys in the closed interval {@code [lo, hi]}.
     *
     * @param lo inclusive lower bound
     * @param hi inclusive upper bound
     * @return count of keys {@code k} with {@code lo <= k <= hi}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final int size(long lo, long hi) {
        if (lo > hi) {
            throw new IllegalArgumentException("lo (" + lo + ") must not be greater than hi (" + hi + ")");
        }
        return size(root(), lo, hi);
    }

    /**
     * Counts keys in {@code [lo, hi]} within the subtree rooted at {@code h}.
     *
     * <p>Descends until it reaches the first node whose key lies in the interval (the split
     * point). From there:
     * <ul>
     *   <li>the left subtree only needs the lower bound checked;</li>
     *   <li>the right subtree only needs the upper bound checked.</li>
     * </ul>
     * If the split node's key equals {@code lo}, nothing in its left subtree can qualify, so the
     * left subtree is skipped. Expects {@code lo <= hi}.
     *
     * @param h  subtree root (may be {@code null})
     * @param lo inclusive lower bound
     * @param hi inclusive upper bound
     * @return count of keys in the interval within the subtree
     */
    protected final int size(Node h, long lo, long hi) {
        int res = 0;
        while (h != null) {
            if (lo > h.key) h = h.right();
            else if (lo == h.key) {
                res += 1 + sizeLe(h.right(), hi);
                break;
            } else if (hi < h.key) h = h.left();
            else {
                res += sizeGe(h.left(), lo) + 1 + sizeLe(h.right(), hi);
                break;
            }
        }
        return res;
    }

    /**
     * Returns the largest key less than or equal to {@code key}.
     *
     * @param key the probe key
     * @return the floor key, or {@code null} if every key is greater than {@code key}
     */
    public final Long floor(long key) {
        Node h = floor(root(), key);
        return h == null ? null : h.key;
    }

    /**
     * Finds the node holding the largest key {@code <= key} in the subtree rooted at {@code h}.
     *
     * @param h   subtree root (may be {@code null})
     * @param key the probe key
     * @return the floor node, or {@code null} if none exists
     */
    protected final Node floor(Node h, long key) {
        Node res = null;
        while (h != null) {
            if (key == h.key) return h;
            if (key < h.key) h = h.left();
            else {
                // h is a candidate; a closer (larger) one can only be in the right subtree.
                res = h;
                h = h.right();
            }
        }
        return res;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}.
     *
     * @param key the probe key
     * @return the ceiling key, or {@code null} if every key is less than {@code key}
     */
    public final Long ceil(long key) {
        Node h = ceil(root(), key);
        return h == null ? null : h.key;
    }

    /**
     * Finds the node holding the smallest key {@code >= key} in the subtree rooted at {@code h}.
     *
     * @param h   subtree root (may be {@code null})
     * @param key the probe key
     * @return the ceiling node, or {@code null} if none exists
     */
    protected final Node ceil(Node h, long key) {
        Node res = null;
        while (h != null) {
            if (key == h.key) return h;
            if (key > h.key) h = h.right();
            else {
                // h is a candidate; a closer (smaller) one can only be in the left subtree.
                res = h;
                h = h.left();
            }
        }
        return res;
    }

    /**
     * Returns all keys in ascending order, collected by an in-order traversal.
     *
     * @return a freshly built queue of every key
     */
    public final Iterable<Long> keys() {
        JLinkedQueue queue = JLinkedQueue.of();
        keys(root(), queue);
        return queue;
    }

    /**
     * Appends every key of the subtree rooted at {@code h} to {@code queue}, in order.
     *
     * @param h     subtree root (may be {@code null})
     * @param queue destination queue
     */
    protected final void keys(Node h, JLinkedQueue queue) {
        if (h == null) return;
        keys(h.left(), queue);
        queue.offer(h.key);
        keys(h.right(), queue);
    }

    /**
     * Returns all keys less than or equal to {@code hi}, in ascending order.
     *
     * @param hi inclusive upper bound
     * @return a freshly built queue of the matching keys
     */
    public final Iterable<Long> keysLe(long hi) {
        JLinkedQueue queue = JLinkedQueue.of();
        keysLe(root(), queue, hi);
        return queue;
    }

    /**
     * Appends, in order, every key {@code <= hi} of the subtree rooted at {@code h}.
     *
     * @param h     subtree root (may be {@code null})
     * @param queue destination queue
     * @param hi    inclusive upper bound
     */
    protected final void keysLe(Node h, JLinkedQueue queue, long hi) {
        if (h == null) return;
        if (hi < h.key) keysLe(h.left(), queue, hi);
        else {
            // The whole left subtree is <= h.key <= hi, so it is emitted unfiltered.
            keys(h.left(), queue);
            queue.offer(h.key);
            keysLe(h.right(), queue, hi);
        }
    }

    /**
     * Returns all keys greater than or equal to {@code lo}, in ascending order.
     *
     * @param lo inclusive lower bound
     * @return a freshly built queue of the matching keys
     */
    public final Iterable<Long> keysGe(long lo) {
        JLinkedQueue queue = JLinkedQueue.of();
        keysGe(root(), queue, lo);
        return queue;
    }

    /**
     * Appends, in order, every key {@code >= lo} of the subtree rooted at {@code h}.
     *
     * @param h     subtree root (may be {@code null})
     * @param queue destination queue
     * @param lo    inclusive lower bound
     */
    protected final void keysGe(Node h, JLinkedQueue queue, long lo) {
        if (h == null) return;
        if (lo > h.key) keysGe(h.right(), queue, lo);
        else {
            keysGe(h.left(), queue, lo);
            queue.offer(h.key);
            // The whole right subtree is >= h.key >= lo, so it is emitted unfiltered.
            keys(h.right(), queue);
        }
    }

    /**
     * Returns all keys in the closed interval {@code [lo, hi]}, in ascending order.
     *
     * @param lo inclusive lower bound
     * @param hi inclusive upper bound
     * @return a freshly built queue of the matching keys
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final Iterable<Long> keys(long lo, long hi) {
        if (lo > hi) {
            throw new IllegalArgumentException("lo (" + lo + ") must not be greater than hi (" + hi + ")");
        }
        JLinkedQueue queue = JLinkedQueue.of();
        keys(root(), queue, lo, hi);
        return queue;
    }

    /**
     * Appends, in order, every key in {@code [lo, hi]} of the subtree rooted at {@code h}.
     *
     * <p>Descends to the split node and then delegates each side to a one-sided helper, in the
     * same way as {@link #size(Node, long, long)}. Expects {@code lo <= hi}.
     *
     * @param h     subtree root (may be {@code null})
     * @param queue destination queue
     * @param lo    inclusive lower bound
     * @param hi    inclusive upper bound
     */
    protected final void keys(Node h, JLinkedQueue queue, long lo, long hi) {
        if (h == null) return;
        if (lo > h.key) {
            keys(h.right(), queue, lo, hi);
        } else if (lo == h.key) {
            // Nothing in the left subtree can be >= lo, so the left subtree is skipped.
            queue.offer(h.key);
            keysLe(h.right(), queue, hi);
        } else {
            if (hi >= h.key) {
                keysGe(h.left(), queue, lo);
                queue.offer(h.key);
                keysLe(h.right(), queue, hi);
            } else keys(h.left(), queue, lo, hi);
        }
    }
}
