package com.xzchaoo.commons.basic.concurrent;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import java.util.Objects;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

/**
 * A serializing {@link Executor}, modeled after gRPC's {@code io.grpc.SynchronizationContext}.
 *
 * <p>Tasks are executed one at a time, in the order they were enqueued. There is no dedicated
 * worker thread. Whichever caller of {@link #drain()} (directly or via {@link #execute(Runnable)})
 * wins the race to become the draining thread runs every queued task, so consecutive tasks may run
 * on different threads. Guarantees:
 * <ul>
 * <li>Ordering: tasks run in their enqueue order.</li>
 * <li>Serialization: at most one task runs at any time.</li>
 * <li>Non-reentrancy: a task that calls {@link #execute(Runnable)} does not run the new task
 * inline. The new task is run by the enclosing drain loop after the calling task returns.</li>
 * </ul>
 *
 * <p>A task that throws does not break the context. The exception is caught and logged, and
 * draining continues with the next task. Tasks are nevertheless expected not to throw.
 *
 * <p>created at 2020-10-30
 *
 * @author xiangfeng.xzc
 */
public class SynchronizationContext2 implements Executor {
    private static final Logger LOGGER = LoggerFactory.getLogger(SynchronizationContext2.class);

    /**
     * Claims and releases {@link #thread}. The thread whose CAS succeeds becomes the draining thread.
     */
    private static final AtomicReferenceFieldUpdater<SynchronizationContext2, Thread> THREAD_UPDATER
            = AtomicReferenceFieldUpdater
            .newUpdater(SynchronizationContext2.class, Thread.class, "thread");

    /**
     * Pending tasks. {@link ConcurrentLinkedQueue} is unbounded and non-blocking, so enqueueing a
     * non-null task never blocks and never fails.
     */
    private final ConcurrentLinkedQueue<Runnable> q = new ConcurrentLinkedQueue<>();

    /**
     * The thread currently draining {@link #q}, or {@code null} when no drain is in progress.
     */
    private volatile Thread thread;

    /**
     * Creates a new context with an empty queue.
     *
     * @return a new context
     */
    public static SynchronizationContext2 create() {
        return new SynchronizationContext2();
    }

    /**
     * Enqueues a task and then drains the queue.
     *
     * <p>If no other thread is draining, the calling thread runs all pending tasks (including this
     * one) before returning. Otherwise this method returns immediately, and the thread that is
     * currently draining runs the task.
     *
     * @param command the task to run; must not be {@code null}
     * @throws NullPointerException if {@code command} is {@code null}
     */
    @Override
    public void execute(@Nonnull Runnable command) {
        executeLater(command);
        drain();
    }

    /**
     * Enqueues a task without running it. The task runs during a later {@link #drain()} call, which
     * may happen on any thread.
     *
     * @param command the task to run; must not be {@code null}
     * @throws NullPointerException  if {@code command} is {@code null}
     * @throws IllegalStateException if the queue rejects the task; this is a defensive check that
     *                               cannot trigger with the current unbounded queue
     */
    public void executeLater(@Nonnull Runnable command) {
        // @Nonnull is not enforced at runtime; fail fast with a clear message.
        Objects.requireNonNull(command, "command");
        if (!q.offer(command)) {
            throw new IllegalStateException("queue is full");
        }
    }

    /**
     * Tells whether the calling thread is the one currently draining this context. This is
     * {@code true}, for example, when called from inside a task run by this context.
     *
     * @return {@code true} if the current thread holds the drain claim
     */
    public boolean isCurrentThreadInContext() {
        return this.thread == Thread.currentThread();
    }

    /**
     * Runs all queued tasks on the calling thread, unless another drain is already in progress.
     *
     * <p>This is an alternative to a WIP-counter drain loop. The right to drain is claimed by
     * CAS-ing the current thread into {@link #thread}. A thread that loses the race returns
     * immediately, and the winner runs the tasks the loser enqueued.
     */
    public void drain() {
        Thread current = Thread.currentThread();
        ConcurrentLinkedQueue<Runnable> queue = this.q;
        do {
            if (!THREAD_UPDATER.compareAndSet(this, null, current)) {
                // Another thread, or an outer frame of this thread (re-entrant call from a task),
                // is already draining and will pick up whatever is in the queue.
                return;
            }
            try {
                Runnable task;
                while ((task = queue.poll()) != null) {
                    try {
                        // user tasks are expected not to throw; if one does, log it and move on
                        task.run();
                    } catch (Throwable e) {
                        LOGGER.error("Exception caught when run task", e);
                    }
                }
            } finally {
                // Always release the claim. Otherwise an unexpected throw escaping the loop (e.g.
                // from the logger) would leave the context claimed forever, and no later drain()
                // could ever run a task.
                THREAD_UPDATER.set(this, null);
            }
            // A producer may have enqueued after our last poll() but before the release above.
            // Its CAS failed, so it did not drain; re-check and drain again if needed.
        } while (!queue.isEmpty());
    }
}
