package shz.st.bst;

import shz.queue.LLinkedQueue;

/**
 * Binary search tree (BST) symbol table that keeps its keys in sorted order.
 * <p>
 * In a BST built from N random keys, search and insertion take about 2 ln N
 * (roughly 1.39 lg N) compares on average.
 * <p>
 * A major reason BSTs are so widely used is that they keep the keys in order. That makes them a
 * basis for implementing the many methods of an ordered symbol-table API: clients can access
 * key-value pairs not only by key but also by the relative order of the keys
 * ({@link #floor}, {@link #ceil}, {@link #select}, range counts and range iteration).
 * <p>
 * The tree is never rebalanced. Every operation takes time proportional to the height of the
 * tree in the worst case, and inserting keys in sorted order produces a tree whose height equals
 * its size. Several methods ({@code put}, {@code delete}, {@code depth}, the {@code keys*}
 * family) are recursive, so their stack depth also grows with the height.
 * <p>
 * {@code null} keys are rejected with {@link NullPointerException}. {@code null} values are
 * accepted, so {@link #get} returning {@code null} does not by itself mean the key is absent.
 * No synchronization is performed.
 * <p>
 * Estimated memory footprint for n entries (K/V = bytes of the key/value objects):
 * <p>
 * 8+[52+K+V+alignment padding]*n
 * <p>
 * B=24+48*n+(4+K+V+alignment padding)*n
 * <p>
 * NOTE(review): expanding the second formula gives 24+(52+K+V+padding)*n, so the two estimates
 * differ in their constant term (8 vs. 24 bytes) — confirm which one is intended.
 *
 * @param <K> key type; ordering is defined by {@link Comparable#compareTo}
 * @param <V> value type
 */
public class BST<K extends Comparable<K>, V> {
    /**
     * A tree node. {@code size} is the number of nodes in the subtree rooted at this node
     * (this node included); every structural modification in {@link BST} recomputes it.
     * <p>
     * Estimated size: 16+8+K+8+V+4+8*2+alignment padding
     * (presumably object header, key reference, key, value reference, value, {@code size},
     * two child references — TODO confirm), i.e.
     * <p>
     * B=52+K+V+alignment padding
     */
    protected static class Node<K extends Comparable<K>, V> {
        public K key;
        public V val;
        public Node<K, V> left, right;
        /** Number of nodes in the subtree rooted at this node, including the node itself. */
        public int size = 1;

        public Node(K key, V val) {
            this.key = key;
            this.val = val;
        }
    }

    /** Root of the tree; {@code null} once every entry has been deleted. */
    protected Node<K, V> root;

    /**
     * Creates a tree holding a single entry. Does not null-check {@code key};
     * public callers should go through {@link #of(Comparable, Object)}.
     */
    protected BST(K key, V val) {
        root = new Node<>(key, val);
    }

    /**
     * Creates a tree holding the single entry {@code key -> val}.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public static <K extends Comparable<K>, V> BST<K, V> of(K key, V val) {
        if (key == null) throw new NullPointerException();
        return new BST<>(key, val);
    }

    /**
     * Creates a tree holding the single entry {@code key -> null}.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public static <K extends Comparable<K>, V> BST<K, V> of(K key) {
        return of(key, null);
    }

    /** Returns the number of entries in the tree. */
    public final int size() {
        return size(root);
    }

    /** Returns the number of nodes in the subtree rooted at {@code x} (0 for {@code null}). */
    protected final int size(Node<K, V> x) {
        return x == null ? 0 : x.size;
    }

    /**
     * Returns the number of keys less than or equal to {@code hi}.
     *
     * @throws NullPointerException if {@code hi} is {@code null}
     */
    public final int sizeLe(K hi) {
        if (hi == null) throw new NullPointerException();
        return sizeLe(root, hi);
    }

    /** Counts the keys {@code <= hi} in the subtree rooted at {@code x}, walking one root-to-leaf path. */
    protected final int sizeLe(Node<K, V> x, K hi) {
        int res = 0;
        while (x != null) {
            if (hi.compareTo(x.key) < 0) x = x.left;
            else {
                // x and its whole left subtree are <= hi.
                res += 1 + size(x.left);
                x = x.right;
            }
        }
        return res;
    }

    /**
     * Returns the number of keys greater than or equal to {@code lo}.
     *
     * @throws NullPointerException if {@code lo} is {@code null}
     */
    public final int sizeGe(K lo) {
        if (lo == null) throw new NullPointerException();
        return sizeGe(root, lo);
    }

    /** Counts the keys {@code >= lo} in the subtree rooted at {@code x}, walking one root-to-leaf path. */
    protected final int sizeGe(Node<K, V> x, K lo) {
        int res = 0;
        while (x != null) {
            if (lo.compareTo(x.key) > 0) x = x.right;
            else {
                // x and its whole right subtree are >= lo.
                res += 1 + size(x.right);
                x = x.left;
            }
        }
        return res;
    }

    /**
     * Returns the number of keys in the closed interval {@code [lo, hi]}.
     *
     * @throws NullPointerException     if {@code lo} or {@code hi} is {@code null}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final int size(K lo, K hi) {
        if (lo == null || hi == null) throw new NullPointerException();
        if (lo.compareTo(hi) > 0) throw new IllegalArgumentException();
        return size(root, lo, hi);
    }

    /**
     * Counts the keys in {@code [lo, hi]} within the subtree rooted at {@code x}.
     * Assumes {@code lo <= hi}.
     */
    protected final int size(Node<K, V> x, K lo, K hi) {
        int res = 0;
        while (x != null) {
            int cmp_lo = lo.compareTo(x.key);
            if (cmp_lo > 0) x = x.right;
            else if (cmp_lo == 0) {
                // Left subtree is entirely < lo; only the right subtree can still hold keys <= hi.
                res += 1 + sizeLe(x.right, hi);
                break;
            } else if (hi.compareTo(x.key) < 0) x = x.left;
            else {
                // lo < x.key <= hi: the interval splits at x.
                res += sizeGe(x.left, lo) + 1 + sizeLe(x.right, hi);
                break;
            }
        }
        return res;
    }

    /** Returns {@code true} if the tree holds no entries. */
    public final boolean isEmpty() {
        return size() == 0;
    }

    /** Returns {@code true} if the tree holds exactly one entry. */
    public final boolean isLeaf() {
        return size() == 1;
    }

    /**
     * Inserts {@code key -> val}, replacing the value if {@code key} is already present.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final void put(K key, V val) {
        if (key == null) throw new NullPointerException();
        root = put(root, key, val);
    }

    /** Inserts into the subtree rooted at {@code x} and returns its (possibly new) root. */
    protected final Node<K, V> put(Node<K, V> x, K key, V val) {
        if (x == null) return new Node<>(key, val);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns the value mapped to {@code key}, or {@code null} if the key is absent
     * (or was stored with a {@code null} value).
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final V get(K key) {
        if (key == null) throw new NullPointerException();
        Node<K, V> x = get(root, key);
        return x == null ? null : x.val;
    }

    /** Returns the node holding {@code key} in the subtree rooted at {@code x}, or {@code null}. */
    protected final Node<K, V> get(Node<K, V> x, K key) {
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp == 0) break;
            x = cmp < 0 ? x.left : x.right;
        }
        return x;
    }

    /** Returns the smallest key, or {@code null} if the tree is empty. */
    public final K min() {
        return root == null ? null : min(root).key;
    }

    /** Returns the leftmost node of the subtree rooted at {@code x}; {@code x} must be non-null. */
    protected final Node<K, V> min(Node<K, V> x) {
        Node<K, V> l;
        while ((l = x.left) != null) x = l;
        return x;
    }

    /** Returns the largest key, or {@code null} if the tree is empty. */
    public final K max() {
        return root == null ? null : max(root).key;
    }

    /** Returns the rightmost node of the subtree rooted at {@code x}; {@code x} must be non-null. */
    protected final Node<K, V> max(Node<K, V> x) {
        Node<K, V> r;
        while ((r = x.right) != null) x = r;
        return x;
    }

    /** Returns the height of the tree counted in nodes: 0 when empty, 1 for a single node. */
    public final int depth() {
        return depth(root);
    }

    /** Recursively computes the height (in nodes) of the subtree rooted at {@code x}. */
    protected final int depth(Node<K, V> x) {
        if (x == null) return 0;
        return Math.max(depth(x.left), depth(x.right)) + 1;
    }

    /**
     * Returns the largest key less than or equal to {@code key}, or {@code null} if none exists.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final K floor(K key) {
        if (key == null) throw new NullPointerException();
        Node<K, V> x = floor(root, key);
        return x == null ? null : x.key;
    }

    /** Returns the node with the largest key {@code <= key} in the subtree, or {@code null}. */
    protected final Node<K, V> floor(Node<K, V> x, K key) {
        Node<K, V> res = null;
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp == 0) return x;
            if (cmp < 0) x = x.left;
            else {
                // x is a candidate; a closer one may exist in its right subtree.
                res = x;
                x = x.right;
            }
        }
        return res;
    }

    /**
     * Returns the smallest key greater than or equal to {@code key}, or {@code null} if none exists.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final K ceil(K key) {
        if (key == null) throw new NullPointerException();
        Node<K, V> x = ceil(root, key);
        return x == null ? null : x.key;
    }

    /** Returns the node with the smallest key {@code >= key} in the subtree, or {@code null}. */
    protected final Node<K, V> ceil(Node<K, V> x, K key) {
        Node<K, V> res = null;
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp == 0) return x;
            if (cmp > 0) x = x.right;
            else {
                // x is a candidate; a closer one may exist in its left subtree.
                res = x;
                x = x.left;
            }
        }
        return res;
    }

    /**
     * Returns the key of rank {@code k}, where ranks are 1-based ({@code select(1)} is the minimum).
     * Returns {@code null} if {@code k > size()}.
     *
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public final K select(int k) {
        if (k < 1) throw new IllegalArgumentException();
        Node<K, V> x = select(root, k);
        return x == null ? null : x.key;
    }

    /** Returns the node of 1-based rank {@code k} in the subtree rooted at {@code x}, or {@code null}. */
    protected final Node<K, V> select(Node<K, V> x, int k) {
        while (x != null) {
            // Rank of k relative to x: 0 means x itself, negative means in the left subtree.
            int k0 = k - size(x.left) - 1;
            if (k0 == 0) break;
            if (k0 < 0) x = x.left;
            else {
                x = x.right;
                k = k0;
            }
        }
        return x;
    }

    /** Removes the entry with the smallest key; does nothing if the tree is empty. */
    public final void deleteMin() {
        if (root == null) return;
        root = deleteMin(root);
    }

    /**
     * Removes the leftmost node of the subtree rooted at {@code x} (which must be non-null)
     * and returns the subtree's new root.
     */
    protected final Node<K, V> deleteMin(Node<K, V> x) {
        if (x.left == null) return x.right;
        x.left = deleteMin(x.left);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /** Removes the entry with the largest key; does nothing if the tree is empty. */
    public final void deleteMax() {
        if (root == null) return;
        root = deleteMax(root);
    }

    /**
     * Removes the rightmost node of the subtree rooted at {@code x} (which must be non-null)
     * and returns the subtree's new root.
     */
    protected final Node<K, V> deleteMax(Node<K, V> x) {
        if (x.right == null) return x.left;
        x.right = deleteMax(x.right);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Removes {@code key} and its value if present; does nothing otherwise.
     *
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public final void delete(K key) {
        if (key == null) throw new NullPointerException();
        root = delete(root, key);
    }

    /**
     * Removes {@code key} from the subtree rooted at {@code x} and returns the subtree's new root.
     * A node with two children is replaced by its successor (Hibbard deletion).
     * <p>
     * This performs well in general but may be a problem for large-scale applications
     * (presumably because always promoting the successor skews the tree over long sequences
     * of random insertions and deletions).
     */
    protected final Node<K, V> delete(Node<K, V> x, K key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = delete(x.left, key);
        else if (cmp > 0) x.right = delete(x.right, key);
        else {
            if (x.left == null) return x.right;
            if (x.right == null) return x.left;
            Node<K, V> t = x;
            // Replace t with its successor: the minimum node of the right subtree.
            x = min(t.right);
            // Order matters: unlink the successor from t.right BEFORE giving it t's left subtree.
            // deleteMin follows left links, so if x.left were already t.left it would descend
            // into t's left subtree (removing the wrong key) and leave x linked below x.right,
            // creating a cycle.
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns all keys via an in-order traversal, offered to the queue in ascending order.
     * NOTE(review): iteration order relies on LLinkedQueue iterating in FIFO order — confirm.
     */
    public final Iterable<K> keys() {
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keys(root, queue);
        return queue;
    }

    /** Offers every key of the subtree rooted at {@code x} to {@code queue} in ascending order. */
    protected final void keys(Node<K, V> x, LLinkedQueue<K> queue) {
        if (x == null) return;
        keys(x.left, queue);
        queue.offer(x.key);
        keys(x.right, queue);
    }

    /**
     * Returns all keys less than or equal to {@code hi}, in ascending order.
     *
     * @throws NullPointerException if {@code hi} is {@code null}
     */
    public final Iterable<K> keysLe(K hi) {
        if (hi == null) throw new NullPointerException();
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keysLe(root, queue, hi);
        return queue;
    }

    /** Offers the keys {@code <= hi} of the subtree rooted at {@code x} to {@code queue} in ascending order. */
    protected final void keysLe(Node<K, V> x, LLinkedQueue<K> queue, K hi) {
        if (x == null) return;
        int cmp = hi.compareTo(x.key);
        if (cmp < 0) keysLe(x.left, queue, hi);
        else {
            // x.key <= hi, so the entire left subtree qualifies without further comparisons.
            keys(x.left, queue);
            queue.offer(x.key);
            keysLe(x.right, queue, hi);
        }
    }

    /**
     * Returns all keys greater than or equal to {@code lo}, in ascending order.
     *
     * @throws NullPointerException if {@code lo} is {@code null}
     */
    public final Iterable<K> keysGe(K lo) {
        if (lo == null) throw new NullPointerException();
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keysGe(root, queue, lo);
        return queue;
    }

    /** Offers the keys {@code >= lo} of the subtree rooted at {@code x} to {@code queue} in ascending order. */
    protected final void keysGe(Node<K, V> x, LLinkedQueue<K> queue, K lo) {
        if (x == null) return;
        int cmp = lo.compareTo(x.key);
        if (cmp > 0) keysGe(x.right, queue, lo);
        else {
            // x.key >= lo, so the entire right subtree qualifies without further comparisons.
            keysGe(x.left, queue, lo);
            queue.offer(x.key);
            keys(x.right, queue);
        }
    }

    /**
     * Returns all keys in the closed interval {@code [lo, hi]}, in ascending order.
     *
     * @throws NullPointerException     if {@code lo} or {@code hi} is {@code null}
     * @throws IllegalArgumentException if {@code lo > hi}
     */
    public final Iterable<K> keys(K lo, K hi) {
        if (lo == null || hi == null) throw new NullPointerException();
        if (lo.compareTo(hi) > 0) throw new IllegalArgumentException();
        LLinkedQueue<K> queue = LLinkedQueue.of();
        keys(root, queue, lo, hi);
        return queue;
    }

    /**
     * Offers the keys in {@code [lo, hi]} of the subtree rooted at {@code x} to {@code queue}
     * in ascending order. Assumes {@code lo <= hi}.
     */
    protected final void keys(Node<K, V> x, LLinkedQueue<K> queue, K lo, K hi) {
        if (x == null) return;
        int cmp_lo = lo.compareTo(x.key);
        if (cmp_lo > 0) {
            keys(x.right, queue, lo, hi);
        } else if (cmp_lo == 0) {
            // Left subtree is entirely < lo.
            queue.offer(x.key);
            keysLe(x.right, queue, hi);
        } else {
            int cmp_hi = hi.compareTo(x.key);
            if (cmp_hi >= 0) {
                // lo < x.key <= hi: the interval splits at x.
                keysGe(x.left, queue, lo);
                queue.offer(x.key);
                keysLe(x.right, queue, hi);
            } else keys(x.left, queue, lo, hi);
        }
    }
}

