package com.spring.boxes.dollar.support.throttle;

import static com.google.common.util.concurrent.Futures.addCallback;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.lang.System.currentTimeMillis;
import static java.util.Objects.requireNonNull;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.ListenableFuture;
import com.spring.boxes.dollar.support.ThrowableSupplier;

import lombok.extern.slf4j.Slf4j;

/**
 * Limits how many executions of a named operation may be in flight at the same time.
 *
 * <p>Throttling (a cap on concurrent executions) differs from a rate limiter (permits per
 * second). Example: a single machine is limited to 1000 qps and one call takes about 200 ms.
 * One thread can then complete 5 calls per second, so sustaining 1000 qps needs 200 calls in
 * flight at once; 200 is the equivalent throttling threshold.
 *
 * <p>Instances are cached per {@link ThrottlingKey#getName() key name} and obtained through
 * {@link #of(ThrottlingKey)}. The threshold is read from the key held by the cached instance.
 */
@Slf4j
public class Throttling {

    private static final Logger logger = LoggerFactory.getLogger(Throttling.class);

    /** Wall-clock time (ms) of the most recent {@link #tryEnter()} or {@link #leave()} call. */
    private volatile long lastAccessTimestamp;
    private final ThrottlingKey throttlingKey;
    /** Number of {@link #tryEnter()} calls not yet balanced by {@link #leave()}. */
    private final AtomicInteger concurrency = new AtomicInteger();
    /** Cache of instances, keyed by {@link ThrottlingKey#getName()}. */
    private static final ConcurrentMap<String, Throttling> MAP = new ConcurrentHashMap<>();

    private Throttling(ThrottlingKey throttlingKey) {
        this.throttlingKey = throttlingKey;
    }

    /**
     * Runs {@code result} under the throttle identified by {@code throttlingKey}. When the call
     * is throttled, an error is logged and {@code null} is returned.
     *
     * @param throttlingKey identifies the throttle and supplies its threshold
     * @param result        the protected operation
     * @return the operation's value, or {@code null} when throttled
     * @throws X whatever {@code result} throws
     */
    public static <T, X extends Throwable> T supplyWithThrottling(ThrottlingKey throttlingKey,
            ThrowableSupplier<? extends T, X> result) throws X {
        return supplyWithThrottling(throttlingKey, result, () -> {
            logger.error("{} invocation throttled", throttlingKey.getName());
            return null;
        });
    }

    /**
     * Runs {@code result} under the throttle identified by {@code throttlingKey}, falling back to
     * {@code onThrottledResult} when the concurrency threshold is exceeded.
     *
     * <p>For asynchronous work use {@link #asyncSupplyWithThrottling}; do not wrap a blocking
     * {@link Future#get} call with this synchronous variant.
     *
     * @param throttlingKey     identifies the throttle and supplies its threshold
     * @param result            the protected operation
     * @param onThrottledResult fallback value supplier; may be {@code null}, in which case
     *                          {@code null} is returned when throttled
     * @return the operation's value, or the fallback value when throttled
     * @throws X whatever {@code result} throws
     */
    public static <T, X extends Throwable> T supplyWithThrottling(ThrottlingKey throttlingKey,
            ThrowableSupplier<? extends T, X> result, Supplier<? extends T> onThrottledResult) throws X {
        Throttling throttling = Throttling.of(throttlingKey);
        if (!throttling.tryEnter()) {
            // tryEnter() counted this attempt even though it was rejected. Undo that right away
            // so a slow fallback does not keep the counter inflated and cause later callers to
            // be throttled while nothing is actually running.
            throttling.leave();
            return supplyNullable(onThrottledResult);
        }
        try {
            return result.get();
        } finally {
            throttling.leave();
        }
    }

    /**
     * Asynchronous variant of {@link #supplyWithThrottling(ThrottlingKey, ThrowableSupplier)}.
     * When the call is throttled, an error is logged and a future holding {@code null} is
     * returned.
     *
     * @param throttlingKey identifies the throttle and supplies its threshold
     * @param result        starts the protected asynchronous operation
     * @return the operation's future, or an already-completed future of {@code null} when throttled
     */
    public static <T> ListenableFuture<T> asyncSupplyWithThrottling(ThrottlingKey throttlingKey,
            Supplier<ListenableFuture<T>> result) {
        return asyncSupplyWithThrottling(throttlingKey, result, () -> {
            logger.error("{} invocation throttled", throttlingKey.getName());
            return null;
        });
    }

    /**
     * Starts an asynchronous operation under the throttle identified by {@code throttlingKey}.
     * The concurrency slot is held until the returned future completes (successfully or not).
     *
     * @param throttlingKey     identifies the throttle and supplies its threshold
     * @param result            starts the protected asynchronous operation; must not return
     *                          {@code null}
     * @param onThrottledResult fallback value supplier; may be {@code null}, in which case the
     *                          returned future holds {@code null} when throttled
     * @return the operation's future, or an already-completed future of the fallback value when
     *         throttled
     * @throws NullPointerException if {@code result} returns {@code null}
     */
    public static <T> ListenableFuture<T> asyncSupplyWithThrottling(ThrottlingKey throttlingKey,
            Supplier<ListenableFuture<T>> result, Supplier<T> onThrottledResult) {
        Throttling throttling = Throttling.of(throttlingKey);
        if (!throttling.tryEnter()) {
            // Release the rejected attempt before running the fallback (see the sync variant).
            throttling.leave();
            return immediateFuture(supplyNullable(onThrottledResult));
        }
        // Until the callback is attached, this method owns the slot and must release it itself
        // if anything below throws.
        boolean releaseDelegatedToCallback = false;
        try {
            ListenableFuture<T> future = requireNonNull(result.get(), "result future");
            // directExecutor() runs the callback on the thread that completes the future, or
            // inline right here if the future is already done.
            addCallback(future, new FutureCallback<T>() {
                @Override
                public void onSuccess(@Nullable T result) {
                    throttling.leave();
                }

                @Override
                public void onFailure(@Nonnull Throwable t) {
                    throttling.leave();
                }
            }, directExecutor());
            releaseDelegatedToCallback = true;
            return future;
        } finally {
            if (!releaseDelegatedToCallback) {
                throttling.leave();
            }
        }
    }

    /** Returns {@code onThrottledResult.get()}, or {@code null} when no supplier was given. */
    private static <T> T supplyNullable(@Nullable Supplier<? extends T> onThrottledResult) {
        return onThrottledResult == null ? null : onThrottledResult.get();
    }

    /** Returns the key this instance was created with. */
    public ThrottlingKey getThrottlingKey() {
        return throttlingKey;
    }

    /** Returns the current value of the concurrency counter. */
    public int getCurrentConcurrency() {
        return concurrency.get();
    }

    /** Returns the number of cached {@code Throttling} instances. */
    private static int totalThrottlingCount() {
        return MAP.size();
    }

    /**
     * Returns the {@code Throttling} registered under {@code throttlingKey}'s name, creating it
     * on first use. A later key with the same name gets the instance (and therefore the
     * threshold) of the first key registered under that name.
     *
     * @param throttlingKey identifies the throttle; must not be {@code null}
     * @throws NullPointerException if {@code throttlingKey} is {@code null}
     */
    public static Throttling of(ThrottlingKey throttlingKey) {
        requireNonNull(throttlingKey, "throttlingKey");
        String name = throttlingKey.getName();
        // Plain lookup first; computeIfAbsent is only reached when no instance exists yet.
        Throttling existing = MAP.get(name);
        if (existing != null) {
            return existing;
        }
        return MAP.computeIfAbsent(name, k -> new Throttling(throttlingKey));
    }

    /**
     * Tries to enter the throttled section.
     *
     * <p>The counter is incremented whether or not entry succeeds, so every call to this
     * method, including one that returns {@code false}, must be balanced by a call to
     * {@link #leave()}.
     *
     * @return {@code true} if entry succeeded, {@code false} if the threshold was exceeded
     */
    public boolean tryEnter() {
        lastAccessTimestamp = currentTimeMillis();
        return concurrency.incrementAndGet() <= throttlingKey.getThreshold();
    }

    /**
     * Leaves the throttled section, balancing one earlier {@link #tryEnter()} call.
     *
     * @return the concurrency count after leaving
     */
    public int leave() {
        lastAccessTimestamp = currentTimeMillis();
        return concurrency.decrementAndGet();
    }
}
