package shz.st.tst;

import shz.queue.CArrayQueue;
import shz.stack.LLinkedStack;

import java.util.Collections;
import java.util.function.Predicate;

/**
 * Ternary search tree (TST) whose values are of type {@code char}.
 * <p>
 * Size estimate: 8+48*n (n is the number of elements).
 * <p>
 * B=24+48*n
 */
public class CTST extends TST {
    /**
     * TST node that carries a {@code char} value.
     * <p>
     * Size estimate: 2+27+alignment padding=32.
     * <p>
     * B=48
     */
    protected static final class Node extends TST.Node {
        /** Value of the key that ends at this node; only meaningful when {@code leaf} is true. */
        public char val;

        /**
         * Creates a node for the key character {@code c}.
         *
         * @param c the character this node discriminates on
         */
        public Node(char c) {
            super(c);
        }
    }

    protected CTST() {
    }

    /**
     * Creates an empty tree.
     *
     * @return a new, empty {@code CTST}
     */
    public static CTST of() {
        return new CTST();
    }

    /**
     * Associates {@code val} with the key {@code a}, overwriting any previous value for that key.
     *
     * @param a   the key; must be non-null and non-empty
     * @param val the value to store
     * @throws NullPointerException     if {@code a} is null
     * @throws IllegalArgumentException if {@code a} is empty
     */
    public final void put(char[] a, char val) {
        if (a == null) throw new NullPointerException();
        // An empty key has no terminating node that could be marked as a leaf; reject it explicitly
        // instead of failing with an ArrayIndexOutOfBoundsException on a[0] inside the recursion.
        if (a.length == 0) throw new IllegalArgumentException("key must not be empty");
        root = put(root(), a, val, 0);
    }

    /**
     * Recursively inserts {@code a[d..]} into the subtree rooted at {@code x}.
     *
     * @param x   subtree root, may be null (a new node is created)
     * @param a   the full key; callers must pass a non-empty array
     * @param val the value to store at the key's last character
     * @param d   index of the key character being matched at this level
     * @return the (possibly newly created) subtree root
     */
    protected final Node put(Node x, char[] a, char val, int d) {
        if (x == null) x = new Node(a[d]);
        // Smaller/larger characters go to the left/right siblings at the same depth d;
        // an equal character advances to the next key character through the middle link.
        if (a[d] < x.c) x.left = put(x.left(), a, val, d);
        else if (a[d] > x.c) x.right = put(x.right(), a, val, d);
        else if (d < a.length - 1) x.mid = put(x.mid(), a, val, d + 1);
        else {
            // Last key character matched: this node terminates the key.
            x.val = val;
            x.leaf = true;
        }
        return x;
    }

    /**
     * Returns the value associated with the key {@code a}.
     *
     * @param a the key
     * @return the stored value, or null if no key equal to {@code a} is present
     * @throws NullPointerException if {@code a} is null
     */
    public final Character get(char[] a) {
        if (a == null) throw new NullPointerException();
        Node x = get(root, a, 0);
        // A node reached by the key may exist only as an interior node of a longer key.
        return x == null || !x.leaf ? null : x.val;
    }

    /**
     * Returns the values of all keys in the tree.
     * <p>
     * The order is the traversal order of an internal stack, not the sorted key order.
     *
     * @return all stored values; an empty iterable if the tree is empty
     */
    public final Iterable<Character> getAll() {
        return get(root(), false);
    }

    /**
     * Collects leaf values reachable from {@code x}.
     *
     * @param x      starting node, may be null
     * @param prefix if true, {@code x} is the node matching the last character of a prefix and only
     *               its middle subtree (keys strictly extending the prefix) is collected; if false,
     *               {@code x} is treated as a subtree root, and its own value and all three
     *               subtrees are collected
     * @return the collected values; an empty iterable if none
     */
    protected final Iterable<Character> get(Node x, boolean prefix) {
        if (x == null) return Collections.emptyList();
        CArrayQueue queue = CArrayQueue.of();
        LLinkedStack<Node> stack = LLinkedStack.of();
        // When traversing a whole subtree, its root may itself terminate a key (e.g. a
        // one-character key stored at the tree root); it must not be skipped.
        if (!prefix && x.leaf) queue.offer(x.val);
        if (x.mid != null) stack.push(x.mid());
        if (!prefix) {
            if (x.left != null) stack.push(x.left());
            if (x.right != null) stack.push(x.right());
        }
        while (stack.size() > 0) {
            Node pop = stack.pop();
            if (pop.leaf) queue.offer(pop.val);
            push(stack, pop);
        }
        return queue.isEmpty() ? Collections.emptyList() : queue;
    }

    /**
     * Pushes the non-null children of {@code x} (left, mid, right) onto {@code stack}.
     */
    private void push(LLinkedStack<Node> stack, Node x) {
        if (x.left != null) stack.push(x.left());
        if (x.mid != null) stack.push(x.mid());
        if (x.right != null) stack.push(x.right());
    }

    /**
     * Returns the values of keys that start with {@code prefix}.
     * <p>
     * NOTE(review): for a non-empty prefix, the value of a key exactly equal to {@code prefix} is
     * not included (only keys strictly longer than the prefix are). Confirm this is intended.
     *
     * @param prefix the key prefix; an empty prefix matches every key
     * @return the matching values; an empty iterable if none
     * @throws NullPointerException if {@code prefix} is null
     */
    public final Iterable<Character> getByPrefix(char[] prefix) {
        if (prefix == null) throw new NullPointerException();
        // Every key starts with the empty prefix.
        if (prefix.length == 0) return getAll();
        Node x = get(root, prefix, 0);
        if (x == null) return Collections.emptyList();
        return get(x, true);
    }

    /**
     * Returns keys whose values satisfy {@code predicate}, delegating the traversal to
     * {@code getChars0}.
     *
     * @param predicate filter on stored values; null accepts every value
     * @param limit     passed through to {@code getChars0}; presumably the maximum number of keys
     *                  returned. TODO confirm against TST
     * @return the matching keys
     */
    public final Iterable<char[]> getChars(Predicate<Character> predicate, int limit) {
        // Assumes getChars0 only hands CTST nodes to the filter (the cast below relies on it).
        return getChars0(x -> predicate == null || predicate.test(((Node) x).val), limit);
    }
}
