package shz.st.triest;

import shz.queue.LLinkedQueue;
import shz.stack.LLinkedStack;

import java.util.*;
import java.util.function.Predicate;

/**
 * Trie-based symbol table whose keys are {@code char[]} sequences over a fixed alphabet
 * ({@code chars}) and whose values are of type {@code E}.
 * <p>
 * Author's memory estimates (bytes):
 * <p>
 * 8+24+2*r (r = length of the chars array) = chars
 * 8+[49+E (size of an E instance)+8*r+alignment padding]*n (n = number of elements)+8*r*n*w (w = average key length)
 * <p>
 * B=48*(n+1)+[E+alignment padding]*n+8*r*n*(w+1)+8+(2*r+n+alignment padding)
 */
public class LTrieST<E> extends TrieST {
    /**
     * Trie node that additionally carries a value of type {@code E}.
     * <p>
     * Estimated size: 8+E (size of an E instance)+25+8*r (r = array length)+alignment padding
     * <p>
     * B=49+E (size of an E instance)+8*r+alignment padding
     * <p>
     * Example: if E is String and r=4 (e.g. the gene alphabet A,T,C,G), then
     * B=49+64+2*n (n = string length)+8*4=113+2*n+32=145+2*n+alignment padding
     * <p>
     * NOTE(review): 8+25 does not equal 49; the missing 16 bytes are presumably the object
     * header — TODO confirm.
     */
    protected static final class Node<E> extends TrieST.Node {
        /** Value stored at this node; only reported when {@code leaf} is true. */
        public E val;

        /**
         * @param r number of child slots, passed to the parent node constructor
         *          (callers in this class pass {@code chars.length})
         */
        public Node(int r) {
            super(r);
        }
    }

    /**
     * Creates an empty table over the given alphabet. Use {@link #of(char[])} to obtain instances.
     *
     * @param chars the alphabet; key characters are mapped to child slots via {@code idx}
     */
    protected LTrieST(char[] chars) {
        super(chars);
        root = new Node<>(chars.length);
    }

    /**
     * Creates an empty table over the given alphabet.
     *
     * @param chars the alphabet
     * @param <E>   value type
     * @return a new, empty table
     * @throws NullPointerException if {@code chars} is null or empty
     */
    public static <E> LTrieST<E> of(char[] chars) {
        if (chars == null || chars.length == 0) throw new NullPointerException();
        return new LTrieST<>(chars);
    }

    /**
     * Associates {@code val} with key {@code a}. Missing nodes along the path are created, and
     * any previous value for the key is overwritten.
     *
     * @param a   the key; validated by {@code check} (defined in the parent class)
     * @param val the value; must not be null
     * @throws NullPointerException if {@code val} is null
     */
    public final void put(char[] a, E val) {
        if (val == null) throw new NullPointerException();
        check(a);
        Node<E> x = root();
        for (char c : a) {
            int i = idx(c);
            // Lazily create the child for this character before descending into it.
            if (x.next[i] == null) x.next[i] = new Node<>(chars.length);
            x = x.next(i);
        }
        x.val = val;
        x.leaf = true;
    }

    /**
     * Returns the value associated with key {@code a}.
     *
     * @param a the key; validated by {@code check} (defined in the parent class)
     * @return the value, or {@code null} if the key is absent or its node is not a leaf
     */
    public final E get(char[] a) {
        check(a);
        Node<E> x = get(root, a, a.length);
        return x == null || !x.leaf ? null : x.val;
    }

    /**
     * Returns the values of every key in the table, including one stored at the root.
     *
     * @return the values (not guaranteed to be sorted), or an empty list if there are none
     */
    public final Iterable<E> getAll() {
        return get(root());
    }

    /**
     * Collects the values of every leaf in the subtree rooted at {@code x}, {@code x} included.
     * Traversal is an iterative depth-first walk driven by a stack, so the result order follows
     * the stack rather than key order.
     *
     * @param x subtree root; must not be null
     * @return the collected values, or an empty list if the subtree holds none
     */
    protected final Iterable<E> get(Node<E> x) {
        LLinkedQueue<E> queue = LLinkedQueue.of();
        LLinkedStack<Node<E>> stack = LLinkedStack.of();
        // Seed with x itself (not just its children) so a value stored exactly at x is reported.
        stack.push(x);
        while (stack.size() > 0) {
            Node<E> pop = stack.pop();
            if (pop.leaf) queue.offer(pop.val);
            push(stack, pop);
        }
        return queue.isEmpty() ? Collections.emptyList() : queue;
    }

    /**
     * Pushes every non-null child of {@code x} onto {@code stack}, in child-index order.
     */
    private void push(LLinkedStack<Node<E>> stack, Node<E> x) {
        if (x.next == null) return;
        for (int i = 0; i < chars.length; ++i) if (x.next[i] != null) stack.push(x.next(i));
    }

    /**
     * Returns the values of all keys that start with {@code prefix}, including the key equal to
     * {@code prefix} itself when present.
     *
     * @param prefix the prefix; validated by {@code check} (defined in the parent class)
     * @return the matching values, or an empty list if no key has this prefix
     */
    public final Iterable<E> getByPrefix(char[] prefix) {
        check(prefix);
        Node<E> x = get(root, prefix, prefix.length);
        if (x == null) return Collections.emptyList();
        return get(x);
    }

    /**
     * Returns keys whose stored value satisfies {@code predicate}. All keys match when
     * {@code predicate} is null.
     *
     * @param predicate filter on the stored value, or null to accept every key
     * @param limit     forwarded unchanged to {@code getChars0}; presumably caps the number of
     *                  results — TODO confirm against the parent class
     * @return the matching keys as produced by {@code getChars0}
     */
    @SuppressWarnings("unchecked")
    public final Iterable<char[]> getChars(Predicate<E> predicate, int limit) {
        return getChars0(x -> predicate == null || predicate.test(((Node<E>) x).val), limit);
    }
}
