package shz.st.bst.lxx;

import java.util.Objects;
import shz.queue.LLinkedQueue;
import shz.st.bst.RedBlackBST;

/**
 * 健为K类型的红黑树
 */
/**
 * Red-black tree whose keys are of type {@code K}.
 *
 * <p>Adds ordered-key queries on top of {@link RedBlackBST}:
 * <ul>
 *   <li>counts of keys below, above, or inside a bound;</li>
 *   <li>{@code floor}/{@code ceil};</li>
 *   <li>ascending key enumeration.</li>
 * </ul>
 * Keys are compared with their natural ordering ({@link Comparable#compareTo}).
 * Public query methods reject {@code null} bounds with {@link NullPointerException}.
 * Closed-interval queries reject {@code lo > hi} with {@link IllegalArgumentException}.
 *
 * @param <K> key type
 */
public abstract class LXXRedBlackBST<K extends Comparable<K>> extends RedBlackBST<K> {
    /**
     * Tree node that stores a key of type {@code K}.
     *
     * <p>The remaining node state is handled by {@link RedBlackBST.Node}.
     */
    protected static abstract class Node<K extends Comparable<K>> extends RedBlackBST.Node {
        /** Key stored in this node; the tree is ordered by this value. */
        protected K key;

        /**
         * @param key the key stored in the node
         * @param red forwarded to the {@link RedBlackBST.Node} constructor;
         *            presumably the node's initial colour (true = red) — confirm in RedBlackBST
         */
        protected Node(K key, boolean red) {
            super(red);
            this.key = key;
        }
    }

    /**
     * @param root initial root node, forwarded to the {@link RedBlackBST} constructor
     */
    protected LXXRedBlackBST(Node<K> root) {
        super(root);
    }

    /**
     * Returns the key stored in {@code h}.
     *
     * <p>The cast is unchecked. Callers must pass only nodes that are instances of
     * {@link Node}{@code <K>}; any other node type fails the cast.
     */
    @SuppressWarnings("unchecked")
    protected final K key(RedBlackBST.Node h) {
        return ((Node<K>) h).key;
    }

    /**
     * Returns the number of keys that are less than or equal to {@code hi}.
     *
     * @param hi inclusive upper bound
     * @return count of keys {@code k} with {@code k <= hi}
     * @throws NullPointerException if {@code hi} is {@code null}
     */
    public final int sizeLe(K hi) {
        Objects.requireNonNull(hi, "hi");
        return sizeLe(root(), hi);
    }

    /**
     * Counts the keys {@code <= hi} in the subtree rooted at {@code h}.
     *
     * <p>Follows a single root-to-leaf path, descending one level per iteration.
     * When the current key is within the bound, the node and its entire left
     * subtree are counted in one step via the inherited {@code size(node)}. That
     * method is assumed to return the subtree's node count, and 0 for {@code null}.
     */
    protected final int sizeLe(Node<K> h, K hi) {
        int res = 0;
        while (h != null) {
            if (hi.compareTo(h.key) < 0) {
                // h.key > hi: h and its right subtree are all out of range
                h = h.left();
            } else {
                // h.key <= hi: h and every key in its left subtree qualify
                res += 1 + size(h.left());
                h = h.right();
            }
        }
        return res;
    }

    /**
     * Returns the number of keys that are greater than or equal to {@code lo}.
     *
     * @param lo inclusive lower bound
     * @return count of keys {@code k} with {@code k >= lo}
     * @throws NullPointerException if {@code lo} is {@code null}
     */
    public final int sizeGe(K lo) {
        Objects.requireNonNull(lo, "lo");
        return sizeGe(root(), lo);
    }

    /**
     * Counts the keys {@code >= lo} in the subtree rooted at {@code h}.
     *
     * <p>This mirrors {@link #sizeLe(Node, Comparable)}. When the current key is in
     * range, the node and its whole right subtree are counted at once.
     */
    protected final int sizeGe(Node<K> h, K lo) {
        int res = 0;
        while (h != null) {
            if (lo.compareTo(h.key) > 0) {
                // h.key < lo: h and its left subtree are all out of range
                h = h.right();
            } else {
                // h.key >= lo: h and every key in its right subtree qualify
                res += 1 + size(h.right());
                h = h.left();
            }
        }
        return res;
    }

    /**
     * Returns the number of keys in the closed interval {@code [lo, hi]}.
     *
     * @param lo inclusive lower bound
     * @param hi inclusive upper bound
     * @return count of keys {@code k} with {@code lo <= k <= hi}
     * @throws NullPointerException     if {@code lo} or {@code hi} is {@code null}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final int size(K lo, K hi) {
        checkKeyRange(lo, hi);
        return size(root(), lo, hi);
    }

    /**
     * Counts the keys in {@code [lo, hi]} within the subtree rooted at {@code h}.
     *
     * <p>Precondition: {@code lo <= hi}. The {@code cmp == 0} branch counts
     * {@code h} without re-checking it against {@code hi}.
     *
     * <p>The method descends until it reaches the first node whose key lies in the
     * interval. There the range splits: the left subtree only needs the {@code lo}
     * bound and the right subtree only needs the {@code hi} bound.
     */
    protected final int size(Node<K> h, K lo, K hi) {
        int res = 0;
        while (h != null) {
            int cmp = lo.compareTo(h.key);
            if (cmp > 0) {
                // h.key < lo: the interval lies entirely to the right
                h = h.right();
            } else if (cmp == 0) {
                // h.key == lo: nothing on the left can qualify
                res += 1 + sizeLe(h.right(), hi);
                break;
            } else if (hi.compareTo(h.key) < 0) {
                // h.key > hi: the interval lies entirely to the left
                h = h.left();
            } else {
                // lo < h.key <= hi: split point of the range
                res += sizeGe(h.left(), lo) + 1 + sizeLe(h.right(), hi);
                break;
            }
        }
        return res;
    }

    /**
     * Returns the largest key less than or equal to {@code key}.
     *
     * @param key the probe key
     * @return the floor key, or {@code null} if every key is greater than {@code key}
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final K floor(K key) {
        Objects.requireNonNull(key, "key");
        Node<K> h = floor(root(), key);
        return h == null ? null : h.key;
    }

    /**
     * Returns the node holding the largest key {@code <= key} in the subtree at
     * {@code h}, or {@code null} if there is none.
     */
    protected final Node<K> floor(Node<K> h, K key) {
        Node<K> res = null;
        while (h != null) {
            int cmp = key.compareTo(h.key);
            if (cmp == 0) return h;
            if (cmp < 0) {
                h = h.left();
            } else {
                // h.key < key: best candidate so far; a closer one can only be to the right
                res = h;
                h = h.right();
            }
        }
        return res;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}.
     *
     * @param key the probe key
     * @return the ceiling key, or {@code null} if every key is smaller than {@code key}
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final K ceil(K key) {
        Objects.requireNonNull(key, "key");
        Node<K> h = ceil(root(), key);
        return h == null ? null : h.key;
    }

    /**
     * Returns the node holding the smallest key {@code >= key} in the subtree at
     * {@code h}, or {@code null} if there is none.
     */
    protected final Node<K> ceil(Node<K> h, K key) {
        Node<K> res = null;
        while (h != null) {
            int cmp = key.compareTo(h.key);
            if (cmp == 0) return h;
            if (cmp > 0) {
                h = h.right();
            } else {
                // h.key > key: best candidate so far; a closer one can only be to the left
                res = h;
                h = h.left();
            }
        }
        return res;
    }

    /**
     * Returns all keys in ascending order, using an in-order traversal.
     *
     * @return a newly built queue holding the keys in ascending order
     */
    public final Iterable<K> keys() {
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keys(root(), queue);
        return queue;
    }

    /**
     * Appends every key of the subtree at {@code h} to {@code queue}, in ascending
     * order. The recursion depth equals the subtree height.
     */
    protected final void keys(Node<K> h, LLinkedQueue<K> queue) {
        if (h == null) return;
        keys(h.left(), queue);
        queue.offer(h.key);
        keys(h.right(), queue);
    }

    /**
     * Returns, in ascending order, all keys less than or equal to {@code hi}.
     *
     * @param hi inclusive upper bound
     * @return a newly built queue of the matching keys
     * @throws NullPointerException if {@code hi} is {@code null}
     */
    public final Iterable<K> keysLe(K hi) {
        Objects.requireNonNull(hi, "hi");
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keysLe(root(), queue, hi);
        return queue;
    }

    /**
     * Appends, in ascending order, the keys {@code <= hi} of the subtree at
     * {@code h} to {@code queue}.
     */
    protected final void keysLe(Node<K> h, LLinkedQueue<K> queue, K hi) {
        if (h == null) return;
        if (hi.compareTo(h.key) < 0) {
            // h.key > hi: only the left subtree can contain matches
            keysLe(h.left(), queue, hi);
        } else {
            // h.key <= hi: the whole left subtree qualifies without further checks
            keys(h.left(), queue);
            queue.offer(h.key);
            keysLe(h.right(), queue, hi);
        }
    }

    /**
     * Returns, in ascending order, all keys greater than or equal to {@code lo}.
     *
     * @param lo inclusive lower bound
     * @return a newly built queue of the matching keys
     * @throws NullPointerException if {@code lo} is {@code null}
     */
    public final Iterable<K> keysGe(K lo) {
        Objects.requireNonNull(lo, "lo");
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keysGe(root(), queue, lo);
        return queue;
    }

    /**
     * Appends, in ascending order, the keys {@code >= lo} of the subtree at
     * {@code h} to {@code queue}.
     */
    protected final void keysGe(Node<K> h, LLinkedQueue<K> queue, K lo) {
        if (h == null) return;
        if (lo.compareTo(h.key) > 0) {
            // h.key < lo: only the right subtree can contain matches
            keysGe(h.right(), queue, lo);
        } else {
            // h.key >= lo: the whole right subtree qualifies without further checks
            keysGe(h.left(), queue, lo);
            queue.offer(h.key);
            keys(h.right(), queue);
        }
    }

    /**
     * Returns, in ascending order, all keys in the closed interval {@code [lo, hi]}.
     *
     * @param lo inclusive lower bound
     * @param hi inclusive upper bound
     * @return a newly built queue of the matching keys
     * @throws NullPointerException     if {@code lo} or {@code hi} is {@code null}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final Iterable<K> keys(K lo, K hi) {
        checkKeyRange(lo, hi);
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keys(root(), queue, lo, hi);
        return queue;
    }

    /**
     * Appends, in ascending order, the keys in {@code [lo, hi]} of the subtree at
     * {@code h} to {@code queue}.
     *
     * <p>Precondition: {@code lo <= hi}. The {@code cmp == 0} branch emits
     * {@code h} without re-checking it against {@code hi}.
     */
    protected final void keys(Node<K> h, LLinkedQueue<K> queue, K lo, K hi) {
        if (h == null) return;
        int cmp = lo.compareTo(h.key);
        if (cmp > 0) {
            // h.key < lo: the interval lies entirely to the right
            keys(h.right(), queue, lo, hi);
        } else if (cmp == 0) {
            // h.key == lo: nothing on the left can qualify
            queue.offer(h.key);
            keysLe(h.right(), queue, hi);
        } else if (hi.compareTo(h.key) >= 0) {
            // lo < h.key <= hi: split point; each side needs only one bound
            keysGe(h.left(), queue, lo);
            queue.offer(h.key);
            keysLe(h.right(), queue, hi);
        } else {
            // h.key > hi: the interval lies entirely to the left
            keys(h.left(), queue, lo, hi);
        }
    }

    /**
     * Validates the bounds of a closed-interval query.
     *
     * @throws NullPointerException     if {@code lo} or {@code hi} is {@code null}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    private void checkKeyRange(K lo, K hi) {
        Objects.requireNonNull(lo, "lo");
        Objects.requireNonNull(hi, "hi");
        if (lo.compareTo(hi) > 0) {
            throw new IllegalArgumentException("lo (" + lo + ") must not be greater than hi (" + hi + ")");
        }
    }
}
