package shz.st.triest;

import shz.queue.CArrayQueue;
import shz.queue.LLinkedQueue;
import shz.stack.LLinkedStack;
import shz.stack.ZArrayStack;
import shz.tuple.Tuple2;

import java.util.Collections;
import java.util.Iterator;
import java.util.function.Predicate;

/**
 * Symbol table based on an R-way trie (prefix tree) over a fixed alphabet.
 * <p>
 * Searching for or inserting a key accesses node arrays at most (key length + 1) times.
 * <p>
 * With an alphabet of size R, a search miss in a trie built from N random keys examines about
 * log<sub>R</sub>N nodes on average.
 * <p>
 * The total number of links in a trie is between RN and RNw, where w is the average key length.
 * <p>
 * Keys are {@code char[]} sequences whose characters must all appear in the alphabet passed to the
 * constructor. The static helpers {@link #toChars(int)}, {@link #toChars(long)}, {@link #toInt(char[])}
 * and {@link #toLong(char[])} convert integers to and from keys over the alphabet {'0', '1'}: a sign
 * character ('1' = negative, '0' = non-negative) followed by the binary magnitude.
 */
public abstract class TrieST {
    /**
     * A trie node.
     */
    protected static abstract class Node {
        /** Whether a key ends at this node. */
        public boolean leaf;
        /** Child links, indexed by a character's position in the alphabet (see {@link TrieST#idx(char)}). */
        public Node[] next;

        /**
         * @param r number of child slots to allocate (presumably the alphabet size — confirm in subclasses)
         */
        protected Node(int r) {
            next = new Node[r];
        }

        /**
         * Returns child {@code i} cast to the caller's expected node subtype. The cast is unchecked, so a
         * wrong {@code T} only surfaces as a ClassCastException at the call site.
         */
        @SuppressWarnings("unchecked")
        public final <T extends Node> T next(int i) {
            return (T) next[i];
        }
    }

    /** The alphabet; a character's index in this array selects its child slot. */
    protected final char[] chars;
    /** Root of the trie (the empty prefix). The traversal methods here tolerate it being null. */
    protected Node root;

    protected TrieST(char[] chars) {
        this.chars = chars;
    }

    /** Returns the root cast to the caller's expected node subtype (unchecked cast). */
    @SuppressWarnings("unchecked")
    protected final <T extends Node> T root() {
        return (T) root;
    }

    /**
     * Validates a key.
     *
     * @throws NullPointerException  if {@code a} is null or empty
     * @throws IllegalStateException if a character of {@code a} is not in the alphabet
     */
    protected final void check(char[] a) {
        if (a == null || a.length == 0) throw new NullPointerException("key is null or empty");
        for (char c : a) {
            if (idx(c) == -1) throw new IllegalStateException("character not in alphabet: " + c);
        }
    }

    /**
     * Returns the position of {@code c} in the alphabet, or -1 if it is absent (linear scan).
     */
    protected final int idx(char c) {
        for (int i = 0; i < chars.length; ++i) {
            if (chars[i] == c) return i;
        }
        return -1;
    }

    /**
     * Follows the first {@code d} characters of {@code a}, starting from node {@code x}.
     *
     * @return the node reached, or null if the path breaks off (or {@code x} itself is null)
     */
    @SuppressWarnings("unchecked")
    protected final <T extends Node> T get(Node x, char[] a, int d) {
        for (int i = 0; i < d && x != null; ++i) {
            x = x.next[idx(a[i])];
        }
        return (T) x;
    }

    /**
     * Removes key {@code a}. Clears its {@link Node#leaf leaf} flag, then unlinks nodes on its path that
     * no longer end a key and have no children, walking back toward the root (the root is never
     * unlinked). Does nothing if the key's path does not exist.
     *
     * @throws NullPointerException  if {@code a} is null or empty
     * @throws IllegalStateException if a character of {@code a} is not in the alphabet
     */
    public final void delete(char[] a) {
        check(a);
        // path[d] is the node reached after consuming the first d characters; path[0] is the root.
        Node[] path = new Node[a.length + 1];
        path[0] = root;
        for (int d = 0; d < a.length; ++d) {
            Node node = path[d];
            if (node == null || node.next == null) return; // path breaks off: key not present
            path[d + 1] = node.next[idx(a[d])];
        }
        Node x = path[a.length];
        if (x == null) return;
        x.leaf = false;
        // Prune dead nodes bottom-up so deleted keys do not leave unreachable-key branches behind.
        for (int d = a.length; d > 0 && !path[d].leaf && !hasChild(path[d]); --d) {
            path[d - 1].next[idx(a[d - 1])] = null;
        }
    }

    /** Returns whether {@code x} has at least one non-null child link. */
    private static boolean hasChild(Node x) {
        if (x.next == null) return false;
        for (Node child : x.next) {
            if (child != null) return true;
        }
        return false;
    }

    /**
     * Collects every key whose final node is a {@link Node#leaf leaf} accepted by {@code predicate}.
     * <p>
     * The trie is walked depth-first with an explicit stack. {@code key} holds the characters on the
     * path to the current node. {@code remove} runs parallel to {@code stack}: {@link #push} adds a
     * {@code true} marker, then one {@code false} entry per child it pushes. When a {@code true} marker
     * reaches the top, every child of that level has been processed, so the parent's character is
     * dropped from {@code key}. Because children are pushed in alphabet order and popped LIFO, siblings
     * are visited in reverse alphabet order.
     *
     * @param predicate extra filter applied to leaf nodes
     * @param limit     maximum number of keys to collect; {@code <= 0} means unlimited
     * @return the matching keys, or an empty list if there are none
     */
    protected final Iterable<char[]> getChars0(Predicate<Node> predicate, int limit) {
        if (root == null) return Collections.emptyList();
        LLinkedQueue<char[]> result = LLinkedQueue.of();
        CArrayQueue key = CArrayQueue.of();

        LLinkedStack<Tuple2<Character, Node>> stack = LLinkedStack.of();
        ZArrayStack remove = ZArrayStack.of();
        push(stack, root, remove);
        while (stack.size() > 0) {
            if (limit > 0 && result.size() >= limit) break;

            Tuple2<Character, Node> pop = stack.pop();
            remove.pop(); // discard the `false` entry paired with this stack element
            key.offer(pop._1);
            if (pop._2.leaf && predicate.test(pop._2)) {
                char[] word = new char[key.size()];
                Iterator<Character> it = key.iterator();
                int i = 0;
                while (it.hasNext()) word[i++] = it.next();
                result.offer(word);
            }
            if (!push(stack, pop._2, remove)) {
                // Dead end: drop this node's character, then unwind every finished level. The bottom
                // marker belongs to the root, which contributed no character, so it is never popped here.
                key.removeTail();
                while (remove.size() > 1 && remove.peek()) {
                    remove.pop();
                    key.removeTail();
                }
            }
        }
        return result.isEmpty() ? Collections.emptyList() : result;
    }

    /**
     * Pushes the non-null children of {@code x} onto {@code stack}, preceded in {@code remove} by a
     * {@code true} level marker and followed by one {@code false} entry per child. If no child is
     * pushed, the marker is withdrawn again.
     *
     * @return whether at least one child was pushed
     */
    private boolean push(LLinkedStack<Tuple2<Character, Node>> stack, Node x, ZArrayStack remove) {
        if (x.next == null) return false;
        remove.push(true);
        boolean flag = false;
        for (int i = 0; i < chars.length; ++i) {
            if (x.next[i] != null) {
                flag = true;
                stack.push(Tuple2.apply(chars[i], x.next[i]));
                remove.push(false);
            }
        }
        if (!flag) remove.pop();
        return flag;
    }

    // Templates for the two special encodings. Never hand these out directly: they are mutable arrays.
    private static final char[] ZERO = {'0'};
    // MIN_VALUE cannot be negated, so it takes the otherwise unused "negative zero" encoding "10".
    private static final char[] MIN_VALUE = {'1', '0'};

    /**
     * Encodes {@code i} as a sign character ('1' negative, '0' otherwise) followed by the binary
     * magnitude. Zero encodes as "0" and {@link Integer#MIN_VALUE} as "10".
     *
     * @return a fresh array the caller may modify
     */
    public static char[] toChars(int i) {
        if (i == 0) return ZERO.clone();
        boolean lt0 = i < 0;
        if (lt0) {
            if (i == Integer.MIN_VALUE) return MIN_VALUE.clone();
            i = -i;
        }
        return ((lt0 ? 1 : 0) + Integer.toBinaryString(i)).toCharArray();
    }

    /**
     * Encodes {@code j} as a sign character ('1' negative, '0' otherwise) followed by the binary
     * magnitude. Zero encodes as "0" and {@link Long#MIN_VALUE} as "10".
     *
     * @return a fresh array the caller may modify
     */
    public static char[] toChars(long j) {
        if (j == 0) return ZERO.clone();
        boolean lt0 = j < 0L;
        if (lt0) {
            if (j == Long.MIN_VALUE) return MIN_VALUE.clone();
            j = -j;
        }
        return ((lt0 ? 1 : 0) + Long.toBinaryString(j)).toCharArray();
    }

    /**
     * Decodes a key produced by {@link #toChars(int)}.
     *
     * @throws NullPointerException     if {@code a} is null or empty
     * @throws IllegalStateException    if {@code a} contains a character other than '0'/'1', or is longer
     *                                  than {@link Integer#SIZE}
     * @throws IllegalArgumentException for the invalid encodings "1" and "00"
     */
    public static int toInt(char[] a) {
        if (a == null || a.length == 0) throw new NullPointerException("key is null or empty");
        for (char c : a) {
            if (c != '0' && c != '1') throw new IllegalStateException("not a binary digit: " + c);
        }
        if (a.length == 1) {
            if (a[0] == '0') return 0;
            throw new IllegalArgumentException("invalid encoding: " + new String(a));
        } else if (a.length == 2 && a[1] == '0') {
            if (a[0] == '1') return Integer.MIN_VALUE;
            throw new IllegalArgumentException("invalid encoding: " + new String(a));
        } else if (a.length > Integer.SIZE) {
            throw new IllegalStateException("too long for an int: " + a.length);
        }

        // At most 31 magnitude bits, so the accumulation cannot overflow.
        int num = 0;
        for (int i = 1; i < a.length; ++i) num = (num << 1) | (a[i] - '0');
        return a[0] == '1' ? -num : num;
    }

    /**
     * Decodes a key produced by {@link #toChars(long)}.
     *
     * @throws NullPointerException     if {@code a} is null or empty
     * @throws IllegalStateException    if {@code a} contains a character other than '0'/'1', or is longer
     *                                  than {@link Long#SIZE}
     * @throws IllegalArgumentException for the invalid encodings "1" and "00"
     */
    public static long toLong(char[] a) {
        if (a == null || a.length == 0) throw new NullPointerException("key is null or empty");
        for (char c : a) {
            if (c != '0' && c != '1') throw new IllegalStateException("not a binary digit: " + c);
        }
        if (a.length == 1) {
            if (a[0] == '0') return 0L;
            throw new IllegalArgumentException("invalid encoding: " + new String(a));
        } else if (a.length == 2 && a[1] == '0') {
            if (a[0] == '1') return Long.MIN_VALUE;
            throw new IllegalArgumentException("invalid encoding: " + new String(a));
        } else if (a.length > Long.SIZE) {
            throw new IllegalStateException("too long for a long: " + a.length);
        }

        // Accumulate in a long: an int `1 << k` would wrap for k >= 32 and corrupt high bits.
        // At most 63 magnitude bits, so the accumulation cannot overflow.
        long num = 0L;
        for (int i = 1; i < a.length; ++i) num = (num << 1) | (a[i] - '0');
        return a[0] == '1' ? -num : num;
    }
}
