package com.xzchaoo.commons.basic.concurrent;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Objects;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;

/**
 * A serializing queue-and-drain context, modeled after gRPC's {@code io.grpc.SynchronizationContext}.
 *
 * <p>Items submitted through {@link #add(Object)} are placed on a FIFO queue and handed to the
 * configured {@link Consumer} one at a time. Whichever thread wins the CAS on the owner field
 * does the draining, so items may be consumed on different threads over time, but never
 * concurrently.
 *
 * <ul>
 * <li><b>Ordering</b>: items are consumed in the order they were enqueued.</li>
 * <li><b>Serialization</b>: at most one thread runs the consumer at any moment.</li>
 * <li><b>Non-reentrancy</b>: calling {@link #add(Object)} or {@link #drain()} from inside the
 * consumer does not run the consumer recursively; the new item is picked up by the drain loop
 * that is already running on that thread.</li>
 * </ul>
 *
 * <p>Exceptions thrown by the consumer are caught and logged; draining then continues with the
 * next item. The consumer should nevertheless not throw.
 *
 * <p>created at 2020-10-30
 *
 * @param <T> type of the items handed to the consumer
 * @author xiangfeng.xzc
 */
public class SynchronizationContext3<T> {
    private static final Logger LOGGER = LoggerFactory.getLogger(SynchronizationContext3.class);

    // A class literal can only name the raw type, hence the narrowly scoped suppression.
    @SuppressWarnings("rawtypes")
    private static final AtomicReferenceFieldUpdater<SynchronizationContext3, Thread> THREAD_UPDATER //
            = AtomicReferenceFieldUpdater.newUpdater(SynchronizationContext3.class, Thread.class, "thread");

    private final ConcurrentLinkedQueue<T> q = new ConcurrentLinkedQueue<>();
    private final Consumer<? super T> consumer;

    /** The thread currently draining the queue, or {@code null} when nobody is draining. */
    private volatile Thread thread;

    /**
     * Creates a context that feeds every submitted item to {@code consumer}.
     *
     * @param consumer callback invoked once per item, serially
     * @throws NullPointerException if {@code consumer} is {@code null}
     */
    public SynchronizationContext3(Consumer<? super T> consumer) {
        this.consumer = Objects.requireNonNull(consumer);
    }

    /**
     * Static factory equivalent to {@link #SynchronizationContext3(Consumer)}.
     *
     * @param consumer callback invoked once per item, serially
     * @param <T>      type of the items handed to the consumer
     * @return a new context
     * @throws NullPointerException if {@code consumer} is {@code null}
     */
    public static <T> SynchronizationContext3<T> create(Consumer<? super T> consumer) {
        return new SynchronizationContext3<>(consumer);
    }

    /**
     * Enqueues an item and then tries to drain the queue on the calling thread.
     *
     * <p>If another thread (or an enclosing call on this thread) is already draining, this method
     * returns right after enqueueing and the item is consumed by that drain loop.
     *
     * @param t the item to enqueue; must not be {@code null}
     * @throws NullPointerException if {@code t} is {@code null} (rejected by the underlying queue)
     */
    public void add(T t) {
        q.offer(t);
        drain();
    }

    /**
     * Tells whether the calling thread is the one currently draining this context.
     *
     * @return {@code true} if the owner field equals {@link Thread#currentThread()}
     */
    public boolean isCurrentThreadInContext() {
        return this.thread == Thread.currentThread();
    }

    /**
     * Consumes queued items on the calling thread unless a drain is already in progress.
     *
     * <p>This is an alternative to a drain loop based on a work-in-progress counter: ownership is
     * decided by a CAS on the {@code thread} field. The thread that swaps {@code null} for itself
     * drains the queue; every other caller returns immediately.
     */
    public void drain() {
        final Thread current = Thread.currentThread();
        final ConcurrentLinkedQueue<T> q = this.q;
        do {
            if (!THREAD_UPDATER.compareAndSet(this, null, current)) {
                // Someone else (or an outer frame on this thread) owns the context and will
                // see our item, either in its poll loop or in its post-release emptiness check.
                return;
            }
            try {
                T t;
                while ((t = q.poll()) != null) {
                    try {
                        // user task should never throw
                        consumer.accept(t);
                    } catch (Throwable e) {
                        LOGGER.error("Exception caught when run task", e);
                    }
                }
            } finally {
                // Always release ownership. Without this, anything escaping the handler above
                // (e.g. the logger itself failing) would leave the owner field set forever and
                // no later drain() could ever make progress.
                THREAD_UPDATER.set(this, null);
            }
            // An item may have been enqueued after the last poll() but before ownership was
            // released, in which case its producer's drain() bailed out; re-check and retry.
        } while (!q.isEmpty());
    }
}
