package com.xzchaoo.commons.basic.concurrent;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import lombok.Getter;
import lombok.ToString;

/**
 * A counting semaphore with the same (non-fair) semantics as {@link java.util.concurrent.Semaphore},
 * except that the maximum number of permits can be changed at runtime via {@link #setMax(int)}.
 *
 * <p>An acquire of {@code n} permits succeeds when {@code used + n <= max}. Shrinking the maximum
 * below the number of permits in use is allowed: held permits are not revoked, and new acquires
 * block until enough permits have been released.
 *
 * <p>Every public method reads and writes the mutable counters only while holding {@code lock}.
 *
 * <p>created at 2021/2/8
 *
 * @author xiangfeng.xzc
 */
public class DynamicSemaphore {
    private final Lock      lock      = new ReentrantLock();
    /**
     * Signalled when permits may have become available (on release or when the maximum grows).
     */
    private final Condition available = lock.newCondition();
    /**
     * Maximum number of permits.
     */
    private       int       max;
    /**
     * Number of permits currently in use. It can exceed {@link #max} after the maximum is reduced.
     * As with {@code Semaphore}, releases are not checked against acquires, so it can also go
     * negative.
     */
    private       int       used;
    /**
     * Number of threads currently blocked in an acquire method. This includes threads that have
     * been signalled but have not yet re-acquired the lock.
     */
    private       int       waiting;
    /**
     * Number of blocked threads (counted in {@link #waiting}) that requested more than one permit.
     */
    private       int       waitingForMany;

    /**
     * Creates a semaphore with the given maximum and no permits in use.
     *
     * @param max initial maximum number of permits; unlike {@link #setMax(int)}, this value is
     *            not validated
     */
    public DynamicSemaphore(int max) {
        this.max = max;
    }

    /**
     * Changes the maximum number of permits. Permits already held are never revoked. If the new
     * maximum is below the number in use, later acquires block until enough permits are released.
     *
     * @param max the new maximum; must be positive
     * @throws IllegalArgumentException if {@code max <= 0}
     */
    public void setMax(int max) {
        requirePositive("setMax", max);

        lock.lock();
        try {
            int oldMax = this.max;
            this.max = max;
            // Only a larger limit can let a blocked waiter proceed; shrinking never frees anything.
            if (max > oldMax) {
                signalWaiters();
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Acquires {@code n} permits, blocking uninterruptibly until they are available.
     *
     * @param n number of permits; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public void acquire(int n) {
        requirePositive("acquire", n);

        lock.lock();
        try {
            // used may exceed max after setMax shrank the limit; the check still handles that.
            while (!canAcquire(n)) {
                enterWait(n);
                try {
                    available.awaitUninterruptibly();
                } finally {
                    exitWait(n);
                }
            }
            used += n;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Acquires one permit, blocking uninterruptibly until it is available.
     */
    public void acquire() {
        acquire(1);
    }

    /**
     * Acquires one permit, blocking until it is available or the thread is interrupted.
     *
     * @throws InterruptedException if the current thread is interrupted while acquiring
     */
    public void acquireInterruptibly() throws InterruptedException {
        acquireInterruptibly(1);
    }

    /**
     * Acquires {@code n} permits, blocking until they are available or the thread is interrupted.
     *
     * @param n number of permits; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     * @throws InterruptedException     if the current thread is interrupted while acquiring
     */
    public void acquireInterruptibly(int n) throws InterruptedException {
        requirePositive("acquire", n);

        lock.lockInterruptibly();
        try {
            // used may exceed max after setMax shrank the limit; the check still handles that.
            while (!canAcquire(n)) {
                enterWait(n);
                try {
                    available.await();
                } finally {
                    exitWait(n);
                }
            }
            used += n;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Acquires one permit only if it is available at the time of the call.
     *
     * @return {@code true} if the permit was acquired
     */
    public boolean tryAcquire() {
        return tryAcquire(1);
    }

    /**
     * Acquires one permit, waiting up to the given time for it to become available.
     *
     * @param timeout maximum time to wait; a value {@code <= 0} means do not wait
     * @param unit    unit of {@code timeout}
     * @return {@code true} if the permit was acquired, {@code false} if the wait timed out
     * @throws InterruptedException if the current thread is interrupted while acquiring
     */
    public boolean tryAcquire(long timeout, TimeUnit unit) throws InterruptedException {
        return tryAcquire(1, timeout, unit);
    }

    /**
     * Acquires {@code n} permits, waiting up to the given time for them to become available.
     *
     * @param n       number of permits; must be positive
     * @param timeout maximum time to wait; a value {@code <= 0} means do not wait
     * @param unit    unit of {@code timeout}
     * @return {@code true} if the permits were acquired, {@code false} if the wait timed out
     * @throws IllegalArgumentException if {@code n <= 0}
     * @throws InterruptedException     if the current thread is interrupted while acquiring
     */
    public boolean tryAcquire(int n, long timeout, TimeUnit unit) throws InterruptedException {
        requirePositive("acquire", n);
        if (timeout <= 0) {
            return tryAcquire(n);
        }
        long nanos = unit.toNanos(timeout);
        lock.lockInterruptibly();
        try {
            // Re-check the condition before honouring the timeout. awaitNanos may report the
            // deadline as passed even when it returned because of a signal, and giving up without
            // looking would waste that signal while permits are free.
            while (!canAcquire(n)) {
                if (nanos <= 0) {
                    return false;
                }
                enterWait(n);
                try {
                    nanos = available.awaitNanos(nanos);
                } finally {
                    exitWait(n);
                }
            }
            used += n;
            return true;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Acquires {@code n} permits only if they are all available at the time of the call.
     *
     * @param n number of permits; must be positive
     * @return {@code true} if the permits were acquired
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public boolean tryAcquire(int n) {
        requirePositive("acquire", n);

        lock.lock();
        try {
            if (!canAcquire(n)) {
                return false;
            }
            used += n;
            return true;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns {@code n} permits and wakes waiters that may now proceed. As with {@code Semaphore},
     * the caller is not required to have acquired the permits it releases.
     *
     * @param n number of permits; must be positive
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public void release(int n) {
        requirePositive("release", n);

        lock.lock();
        try {
            used -= n;
            signalWaiters();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns one permit.
     */
    public void release() {
        release(1);
    }

    /**
     * Returns a consistent snapshot of the counters, taken while holding the lock.
     *
     * @return a new {@link Stat} holding the current max, used and waiting values
     */
    public Stat stat() {
        Stat stat = new Stat();
        lock.lock();
        try {
            stat.max = max;
            stat.used = used;
            stat.waiting = waiting;
        } finally {
            lock.unlock();
        }
        return stat;
    }

    /**
     * Throws {@link IllegalArgumentException} with message {@code "<op> <value>"} if
     * {@code value <= 0}.
     */
    private static void requirePositive(String op, int value) {
        if (value <= 0) {
            throw new IllegalArgumentException(op + " " + value);
        }
    }

    /**
     * Returns whether {@code n} more permits fit under the maximum. The sum is computed in
     * {@code long} so that a huge {@code n} cannot wrap around. The caller must hold {@code lock}.
     */
    private boolean canAcquire(int n) {
        return (long) used + n <= max;
    }

    /**
     * Records that the current thread is about to block waiting for {@code n} permits.
     * The caller must hold {@code lock}.
     */
    private void enterWait(int n) {
        ++waiting;
        if (n > 1) {
            ++waitingForMany;
        }
    }

    /**
     * Undoes {@link #enterWait(int)}. The caller must hold {@code lock}.
     */
    private void exitWait(int n) {
        --waiting;
        if (n > 1) {
            --waitingForMany;
        }
    }

    /**
     * Wakes waiters that may be able to acquire permits now. The caller must hold {@code lock}.
     */
    private void signalWaiters() {
        if (waiting == 0 || used >= max) {
            return;
        }
        long free = (long) max - used;
        // signal() picks a waiter without regard to how many permits it needs. A waiter asking for
        // several permits could absorb the wakeup and park again, while a smaller request that
        // would have fit stays blocked. Waking exactly `free` waiters is therefore only safe when
        // every waiter wants a single permit. Because `waiting` also counts already-signalled
        // threads, this may over-signal, which is harmless.
        if (waitingForMany > 0 || free >= waiting) {
            available.signalAll();
        } else {
            for (long i = 0; i < free; i++) {
                available.signal();
            }
        }
    }

    /**
     * Point-in-time snapshot of a {@link DynamicSemaphore}'s counters.
     */
    @Getter
    @ToString
    public static class Stat {
        /** Maximum number of permits when the snapshot was taken. */
        private int max;
        /** Permits in use when the snapshot was taken. */
        private int used;
        /** Threads blocked in an acquire method when the snapshot was taken. */
        private int waiting;
    }
}
