package com.feingto.cloud.cache.provider;

import com.feingto.cloud.cache.enums.IntervalUnit;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.data.redis.core.RedisTemplate;

import java.util.Optional;
import java.util.concurrent.TimeUnit;

/**
 * Token-bucket rate limiter whose buckets are persisted in a Redis hash.
 *
 * <p>Each bucket is stored as one field (the caller-supplied {@code sign}) of the hash
 * {@code "feingto:monitor:control"}. A bucket holds at most {@code limit} tokens, which refill
 * evenly over the configured window; every permitted call consumes one token.
 *
 * <p>Thread-safety: the public methods are {@code synchronized}, which serializes callers of
 * this instance only. Reading a bucket and writing it back are separate Redis calls, so callers
 * in other processes (or other instances of this class) that share the same hash are not
 * serialized against each other. NOTE(review): if strict cross-process limiting is required,
 * this needs a Redis-side atomic operation (e.g. a Lua script).
 *
 * @author longfei
 */
public class RedisTokenProvider {
    /** Redis hash key under which every bucket is stored; each {@code sign} is a field of it. */
    private static final String CACHE_KEY = "feingto:monitor:control";
    private final RedisHashCache<TokenBucket> redisHashCache;

    /**
     * Creates a provider backed by the given template.
     *
     * @param redisTemplate template used to build the underlying {@code RedisHashCache}
     */
    @SuppressWarnings("rawtypes")
    public RedisTokenProvider(RedisTemplate redisTemplate) {
        redisHashCache = new RedisHashCache<>(redisTemplate);
    }

    /**
     * Checks whether {@code sign} has reached {@code limit} calls per second.
     * Equivalent to {@code isLimit(sign, limit, 1L, IntervalUnit.SECONDS)}.
     *
     * @param sign  unique identifier, used as the Redis hash field
     * @param limit maximum number of calls per second; {@code <= 0} disables limiting
     * @return {@code true} if the call is rejected (limited); {@code false} if it is allowed
     */
    public boolean isLimit(String sign, long limit) {
        return isLimit(sign, limit, 1L, IntervalUnit.SECONDS);
    }

    /**
     * Checks whether {@code sign} has reached {@code limit} calls within a window of
     * {@code frequency} x {@code intervalUnit}, consuming one token when the call is allowed.
     *
     * @param sign         unique identifier, used as the Redis hash field
     * @param limit        bucket capacity (maximum calls per window); {@code <= 0} disables limiting
     * @param frequency    window length, in {@code intervalUnit}; {@code <= 0} disables limiting
     * @param intervalUnit unit of {@code frequency}, e.g. per second / per minute; units other than
     *                     MILLISECONDS, SECONDS, MINUTES, HOURS and DAYS disable limiting
     * @return {@code true} if the call is rejected (limited); {@code false} if it is allowed
     */
    public synchronized boolean isLimit(String sign, long limit, long frequency, IntervalUnit intervalUnit) {
        if (limit <= 0 || frequency <= 0) {
            return false;
        }

        final long windowMillis = toMillis(frequency, intervalUnit);
        if (windowMillis <= 0) {
            // Unsupported interval unit: not limited.
            return false;
        }

        return Optional.ofNullable(redisHashCache.get(CACHE_KEY, sign, TokenBucket.class))
                .map(bucket -> tryConsume(sign, bucket, limit, windowMillis))
                .orElseGet(() -> {
                    // First call for this sign: start with a full bucket and consume one token.
                    redisHashCache.put(CACHE_KEY, sign, new TokenBucket()
                            .lastRefillTime(System.currentTimeMillis())
                            .tokensRemaining(limit - 1));
                    return false;
                });
    }

    /**
     * Clears every bucket by removing the whole {@code CACHE_KEY} hash.
     */
    public synchronized void reset() {
        redisHashCache.clear(CACHE_KEY);
    }

    /**
     * Converts a window length to milliseconds.
     *
     * <p>The conversion saturates at {@link Long#MAX_VALUE} instead of overflowing. The previous
     * hand-written factors for HOURS (360000) and DAYS (8640000) were off by a factor of 10.
     *
     * @param frequency    window length in {@code intervalUnit}
     * @param intervalUnit unit of {@code frequency}
     * @return window length in milliseconds, or {@code -1} if the unit is not supported
     */
    private static long toMillis(long frequency, IntervalUnit intervalUnit) {
        final TimeUnit unit;
        switch (intervalUnit) {
            case MILLISECONDS:
                unit = TimeUnit.MILLISECONDS;
                break;
            case SECONDS:
                unit = TimeUnit.SECONDS;
                break;
            case MINUTES:
                unit = TimeUnit.MINUTES;
                break;
            case HOURS:
                unit = TimeUnit.HOURS;
                break;
            case DAYS:
                unit = TimeUnit.DAYS;
                break;
            default:
                return -1L;
        }
        return unit.toMillis(frequency);
    }

    /**
     * Refills an existing bucket for the time elapsed since its last credited refill, then
     * consumes one token if one is available. The updated bucket is written back to Redis.
     *
     * @param sign         Redis hash field of the bucket
     * @param bucket       bucket state read from Redis
     * @param limit        bucket capacity (positive)
     * @param windowMillis time in which a full bucket refills (positive)
     * @return {@code true} if no token was available (limited); {@code false} if one was consumed
     */
    private boolean tryConsume(String sign, TokenBucket bucket, long limit, long windowMillis) {
        long now = System.currentTimeMillis();
        // Clamp so a clock that moved backwards cannot produce a negative refill.
        long elapsed = Math.max(0L, now - bucket.lastRefillTime());

        long available;
        long refillTime;
        if (elapsed > windowMillis) {
            // A whole window has passed: the bucket is full again.
            available = limit;
            refillTime = now;
        } else {
            double millisPerToken = windowMillis * 1.0 / limit;
            long granted = (long) (elapsed / millisPerToken);
            long total = granted + bucket.tokensRemaining();
            if (total >= limit) {
                // Capped at capacity: any surplus elapsed time is lost, so restart from now.
                available = limit;
                refillTime = now;
            } else {
                available = total;
                // Advance only by the time actually converted into tokens, so the fractional
                // remainder carries over to the next call instead of being discarded.
                refillTime = bucket.lastRefillTime() + (long) (granted * millisPerToken);
            }
        }

        if (available <= 0) {
            // No token credited (granted == 0), so the refill timestamp stays unchanged.
            bucket.tokensRemaining(0L);
            redisHashCache.put(CACHE_KEY, sign, bucket);
            return true;
        }
        bucket.lastRefillTime(refillTime);
        bucket.tokensRemaining(available - 1);
        redisHashCache.put(CACHE_KEY, sign, bucket);
        return false;
    }

    /**
     * Persisted bucket state: the epoch-millis time up to which refills have been credited,
     * and the number of tokens left.
     */
    @Data
    @Accessors(fluent = true)
    private static class TokenBucket {
        private long lastRefillTime;
        private long tokensRemaining;
    }
}
