package shz.st.tst;

import shz.queue.CArrayQueue;
import shz.queue.LLinkedQueue;

import java.util.Collections;
import java.util.Iterator;
import java.util.function.Predicate;

/**
 * Symbol table based on a ternary search trie (TST).
 * <p>
 * A TST built from N strings of average length w contains between 3N and 3Nw links in total.
 * <p>
 * In a TST built from N random strings, a search miss compares about ln N characters on average.
 * A search hit or an insertion compares each character of the search key once, plus about ln N
 * additional character comparisons.
 * <p>
 * Keys are {@code char[]} values. A key is present when the node reached by its last character has
 * {@link Node#leaf} set. This base class provides lookup, deletion and key enumeration; node
 * creation and insertion are left to subclasses.
 */
public abstract class TST {
    /**
     * One trie node, labelled with a single character.
     * <p>
     * {@code left} and {@code right} lead to nodes for smaller and larger characters at the same key
     * position; {@code mid} leads to the nodes for the next key position.
     */
    protected static abstract class Node {
        public char c;
        public Node left, mid, right;
        /** {@code true} when the key spelled by the path ending at this node is stored. */
        public boolean leaf;

        protected Node(char c) {
            this.c = c;
        }

        /** Returns {@link #left} cast to the caller's node subtype; the cast is unchecked. */
        @SuppressWarnings("unchecked")
        public final <T extends Node> T left() {
            return (T) left;
        }

        /** Returns {@link #mid} cast to the caller's node subtype; the cast is unchecked. */
        @SuppressWarnings("unchecked")
        public final <T extends Node> T mid() {
            return (T) mid;
        }

        /** Returns {@link #right} cast to the caller's node subtype; the cast is unchecked. */
        @SuppressWarnings("unchecked")
        public final <T extends Node> T right() {
            return (T) right;
        }
    }

    /** Root of the trie; {@code null} for an empty trie. */
    protected Node root;

    protected TST() {
    }

    /** Returns {@link #root} cast to the caller's node subtype; the cast is unchecked. */
    @SuppressWarnings("unchecked")
    protected final <T extends Node> T root() {
        return (T) root;
    }

    /**
     * Walks the trie from {@code x} along the characters {@code a[d..n]} and returns the node that
     * matches {@code a[n]}. When {@code d > n} the node matching {@code a[d]} is returned.
     * <p>
     * The returned node is not checked for {@link Node#leaf}; callers decide whether it marks a
     * stored key. The walk is iterative, so it does not use a stack frame per compared character.
     *
     * @param x start node (root of the subtrie for key position {@code d})
     * @param a key characters
     * @param d index of the first character to match
     * @param n index of the last character to match
     * @return the matching node, or {@code null} when no such path exists or the walk runs past the
     *         end of {@code a} (for example an empty key)
     */
    @SuppressWarnings("unchecked")
    protected final <T extends Node> T get(Node x, char[] a, int d, int n) {
        while (x != null) {
            // Running off the end of the key means "no match" instead of an index exception.
            if (d >= a.length) return null;
            char ch = a[d];
            if (ch < x.c) {
                x = x.left;
            } else if (ch > x.c) {
                x = x.right;
            } else if (d < n) {
                x = x.mid;
                d++;
            } else {
                return (T) x;
            }
        }
        return null;
    }

    /**
     * Walks the trie from {@code x} along {@code a[d..]} and returns the node matching the last
     * character of {@code a}, or {@code null} if there is none.
     */
    protected final <T extends Node> T get(Node x, char[] a, int d) {
        return get(x, a, d, a.length - 1);
    }

    /**
     * Removes key {@code a} if it is stored. Nodes on the search path that no longer lead to any key
     * are unlinked when that can be done without restructuring the tree.
     *
     * @param a the key to remove; an empty key is a no-op because it cannot match any node
     * @throws NullPointerException if {@code a} is {@code null}
     */
    public final void delete(char[] a) {
        if (a == null) throw new NullPointerException();
        if (a.length == 0) return;
        root = delete(root, a, 0);
    }

    /**
     * Unmarks key {@code a} in the subtrie rooted at {@code x}, which holds key position {@code d}.
     * The search follows left/right for a different character at the same position and mid for the
     * next position, so the terminal node is found wherever it sits.
     *
     * @return the new root of this subtrie (possibly {@code null} or a child of {@code x})
     */
    private Node delete(Node x, char[] a, int d) {
        if (x == null) return null;
        char ch = a[d];
        if (ch < x.c) {
            x.left = delete(x.left, a, d);
        } else if (ch > x.c) {
            x.right = delete(x.right, a, d);
        } else if (d < a.length - 1) {
            x.mid = delete(x.mid, a, d + 1);
        } else {
            x.leaf = false;
        }
        return prune(x);
    }

    /**
     * Returns {@code x}, or the node that should replace it when {@code x} no longer carries a key.
     * A node without {@code leaf} and without {@code mid} contributes no key of its own, so it can
     * be replaced by its only left/right child (or removed if it has none). A node that has both
     * children is kept so that the BST ordering needs no restructuring.
     */
    private static Node prune(Node x) {
        if (x.leaf || x.mid != null) return x;
        if (x.left == null) return x.right;
        if (x.right == null) return x.left;
        return x;
    }

    /**
     * Collects every stored key whose terminal node satisfies {@code predicate}.
     * <p>
     * Nodes are visited in pre-order (node, left, mid, right). The result is therefore not in
     * lexicographic order, and when {@code limit} is reached the returned keys are the first ones
     * met in that traversal.
     *
     * @param predicate filter applied to the terminal node of each stored key
     * @param limit     maximum number of keys to collect; a value {@code <= 0} means no limit
     * @return the matching keys, or an empty list when none match
     */
    protected final Iterable<char[]> getChars0(Predicate<Node> predicate, int limit) {
        LLinkedQueue<char[]> result = LLinkedQueue.of();
        CArrayQueue prefix = CArrayQueue.of(16);
        collect(result, prefix, root, predicate, limit);
        return result.isEmpty() ? Collections.emptyList() : result;
    }

    /**
     * Pre-order traversal helper for {@link #getChars0(Predicate, int)}. {@code prefix} holds the
     * characters of the mid-links leading to {@code x}: one is offered before descending into a
     * mid subtrie and removed with {@code removeTail()} afterwards.
     */
    private void collect(LLinkedQueue<char[]> result, CArrayQueue prefix, Node x,
                         Predicate<Node> predicate, int limit) {
        // At most one key is added per visit, so checking on entry keeps the result within limit.
        if (x == null || (limit > 0 && result.size() >= limit)) return;
        if (x.leaf && predicate.test(x)) result.offer(toKey(prefix, x.c));
        collect(result, prefix, x.left, predicate, limit);
        if (x.mid != null) {
            prefix.offer(x.c);
            collect(result, prefix, x.mid, predicate, limit);
            prefix.removeTail();
        }
        collect(result, prefix, x.right, predicate, limit);
    }

    /** Copies the characters held in {@code prefix}, followed by {@code last}, into a new array. */
    private static char[] toKey(CArrayQueue prefix, char last) {
        char[] chars = new char[prefix.size() + 1];
        int i = 0;
        Iterator<Character> it = prefix.iterator();
        while (it.hasNext()) chars[i++] = it.next();
        chars[i] = last;
        return chars;
    }
}
