package shz.model;

import shz.Validator;
import shz.msg.ServerFailureMsg;

import java.util.Arrays;
import java.util.BitSet;

/**
 * A Bloom filter for strings backed by a {@link BitSet}.
 *
 * <p>Each key is hashed {@code k} times with a polynomial rolling hash; every hash uses a
 * different multiplier (seed) and is reduced modulo a prime smaller than the bit-array size.
 * Blank keys (as decided by {@code Validator.isBlank}) are all mapped to bit 0.
 *
 * <p>NOTE(review): the bit array holds {@code k * n} bits, which is smaller than the textbook
 * optimum {@code -n * ln(p) / (ln 2)^2}, so the realized false-positive rate is likely higher
 * than the requested {@code p}. Left unchanged because it affects memory use and the size limit
 * for existing callers - confirm whether this is intended.
 *
 * <p>No synchronization is performed by this class.
 */
public class BloomFilter {
    /** Multipliers for the first hash functions; {@link #getSeeds(int)} extends the list on demand. */
    private static final int[] SEEDS = {31, 33, 37, 39, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97};

    /** Smallest bit-array size for which a prime modulus (namely 2) exists below the size. */
    private static final int MIN_BIT_SIZE = 3;

    /** Number of hash functions applied to every key. */
    private final int k;
    private final BitSet bitSet;
    /** Hash multipliers; only the first {@code k} entries are used. */
    private final int[] seeds;
    /** Prime modulus, strictly smaller than the bit-array size, so every hash is a valid bit index. */
    private final int modulus;

    /**
     * Builds a filter sized for {@code n} elements and a target false-positive rate {@code p}.
     *
     * @param p target false-positive rate; values {@code <= 0} are replaced by {@code 1 / n}
     * @param n expected number of elements, must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    protected BloomFilter(double p, int n) {
        if (n <= 0) throw new IllegalArgumentException("n must be positive: " + n);
        if (p <= 0d) p = 1.0d / n;
        // Algebraically |ln p| / ln 2; the original expression is kept so rounding stays identical.
        int hashes = (int) Math.ceil((Math.abs(n * Math.log(p) / (Math.pow(Math.log(2), 2))) / n) * Math.log(2));
        // p >= 1 (e.g. the default 1/n for n == 1) or NaN yields 0 hashes, which would make
        // exists() return true for every non-blank key; always use at least one hash.
        k = Math.max(1, hashes);
        // Below MIN_BIT_SIZE there is no prime under the size and getMP would run past zero.
        long bitSize = Math.max(MIN_BIT_SIZE, (long) k * n);
        // Presumably raises a project-specific failure when the condition is true - see ServerFailureMsg.
        ServerFailureMsg.requireNon(bitSize > Integer.MAX_VALUE, "无法构建符合指定失误率%f及元素个数%d的布隆过滤器", p, n);
        bitSet = new BitSet((int) bitSize);
        seeds = k <= SEEDS.length ? SEEDS : getSeeds(k);
        modulus = getMP((int) bitSize - 1);
    }

    /**
     * Creates a filter for {@code n} expected elements with target false-positive rate {@code p}.
     *
     * @param p target false-positive rate; values {@code <= 0} are replaced by {@code 1 / n}
     * @param n expected number of elements, must be positive
     * @return a new, empty filter
     */
    public static BloomFilter of(double p, int n) {
        return new BloomFilter(p, n);
    }

    /**
     * Creates a filter for {@code n} expected elements with the default rate {@code 1 / n}.
     *
     * @param n expected number of elements, must be positive
     * @return a new, empty filter
     */
    public static BloomFilter of(int n) {
        return new BloomFilter(0d, n);
    }

    /**
     * Returns {@code n} hash multipliers: the predefined {@link #SEEDS} followed by the next
     * odd primes after the previous entry. Only called with {@code n > SEEDS.length}.
     */
    private static int[] getSeeds(int n) {
        int[] result = new int[n];
        System.arraycopy(SEEDS, 0, result, 0, SEEDS.length);
        for (int i = SEEDS.length; i < n; ++i) {
            int candidate = result[i - 1] + 2;
            while (nonPrime(candidate)) candidate += 2;
            result[i] = candidate;
        }
        return result;
    }

    /** Returns {@code true} when {@code x} is not a prime number (trial division by odd numbers). */
    private static boolean nonPrime(int x) {
        if (x == 2) return false;
        if (x < 2 || (x & 1) == 0) return true;
        double sqrt = Math.sqrt(x);
        for (int i = 3; i <= sqrt; i += 2) if (x % i == 0) return true;
        return false;
    }

    /** Returns the largest prime that is {@code <= mp}; callers must pass {@code mp >= 2}. */
    private static int getMP(int mp) {
        while (nonPrime(mp)) --mp;
        return mp;
    }

    /**
     * Polynomial rolling hash of {@code chars} with multiplier {@code seed}, reduced modulo
     * {@link #modulus}. The intermediate product is computed in {@code long}: with a large
     * modulus an {@code int} product overflows and produces a negative bit index.
     *
     * @return a value in {@code [0, modulus)}
     */
    private int hash(char[] chars, int seed) {
        int h = 0;
        for (char c : chars) h = (int) (((long) seed * h + c) % modulus);
        return h;
    }

    /**
     * Adds every given key to the filter. Blank keys all set bit 0.
     *
     * @param keys keys to add
     */
    public final void add(String... keys) {
        for (String key : keys) {
            if (Validator.isBlank(key)) {
                bitSet.set(0);
                continue;
            }
            char[] chars = key.toCharArray();
            for (int i = 0; i < k; ++i) bitSet.set(hash(chars, seeds[i]));
        }
    }

    /**
     * Tests whether {@code key} may have been added.
     *
     * @param key key to look up
     * @return {@code false} if the key was definitely never added; {@code true} if it may have
     *     been added (false positives are possible)
     */
    public final boolean exists(String key) {
        if (Validator.isBlank(key)) return bitSet.get(0);
        char[] chars = key.toCharArray();
        for (int i = 0; i < k; ++i) {
            if (!bitSet.get(hash(chars, seeds[i]))) return false;
        }
        return true;
    }
}
