package com.xzchaoo.commons.basic.topology;

import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import com.xzchaoo.commons.basic.Ack;

/**
 * Lock based {@link TopologyExecutor} impl.
 * <p>
 * Scheduling state (the run queue, {@code wip}, {@code ackCount}, the {@code draining} flag and
 * the per-node {@code in} counters decremented on ack) is guarded by {@code lock}. Node
 * {@code execute} callbacks and the {@code onComplete} callback are invoked <em>without</em>
 * holding the lock. Ready nodes are dispatched from an iterative drain loop, so nodes that ack
 * synchronously do not make the call stack grow with the length of the graph.
 * <p>
 * NOTE(review): {@link #add} does not take the lock; the graph is presumably built by a single
 * thread before {@link #check()} / {@link #execute()} are called — confirm with callers.
 * <p> created at 2020/8/12
 *
 * @author xzchaoo
 */
public class LockTopologyExecutor<N extends TopologyExecutor.Node>
    implements TopologyExecutor<N> {

    private final Set<N> nodes = Collections
        .newSetFromMap(new IdentityHashMap<>());

    private final IdentityHashMap<N, List<N>> edges = new IdentityHashMap<>();
    private final Lock                        lock  = new ReentrantLock();
    private final Runnable                    onComplete;
    /**
     * Max work in progress count; any value {@code <= 0} means no limit.
     */
    private final int                         maxWip;
    /**
     * Nodes that are ready to run (all predecessors acked) but not yet dispatched.
     * Guarded by {@code lock}.
     */
    private final LinkedList<N>               q     = new LinkedList<>();

    /** Number of dispatched nodes that have not acked yet. Guarded by {@code lock}. */
    private int     wip;
    /** Number of nodes that have acked so far. Guarded by {@code lock}. */
    private int     ackCount;
    /**
     * True while some frame (on any thread) is inside the dispatch loop of {@link #drainLoop()}.
     * Guarded by {@code lock}.
     */
    private boolean draining;

    public LockTopologyExecutor(TopologyExecutorConfig config) {
        this.maxWip = config.getMaxWip();
        this.onComplete = config.getOnComplete();
    }

    /**
     * Registers the edge {@code from -> to}, adding both nodes to the graph and bumping
     * {@code from.out} and {@code to.in}. Adding the same edge twice registers it twice.
     */
    @Override
    public void add(N from, N to) {
        edges.computeIfAbsent(from, ignored -> new ArrayList<>()) //
            .add(to);
        ++from.out;
        ++to.in;
        nodes.add(from);
        nodes.add(to);
    }

    /**
     * Validates that the graph is a DAG and schedules its roots for execution.
     * <p>
     * Runs Kahn's algorithm: nodes whose {@code in} count is 0 are roots, and a node is emitted
     * once all of its predecessors have been emitted. The walk consumes the nodes' {@code in}
     * counters, so they are backed up into {@code inBackup}/{@code outBackup} first and restored
     * afterwards, whether or not the check succeeds. The roots are appended to the run queue
     * used by {@link #execute()} only when every node was reached.
     * <p>
     * NOTE(review): each successful call appends the roots again, so calling this more than once
     * schedules them more than once.
     *
     * @return all nodes, in a topological order
     * @throws IllegalStateException "roots are empty" if no node has an {@code in} count of 0
     *                               (including an empty graph), or "Not a DAG" if some nodes lie
     *                               on, or downstream of, a cycle
     */
    @Override
    public List<N> check() {
        List<N> roots = new ArrayList<>();
        for (N n : nodes) {
            n.inBackup = n.in;
            n.outBackup = n.out;
            if (n.in == 0) {
                roots.add(n);
            }
        }
        if (roots.isEmpty()) {
            throw new IllegalStateException("roots are empty");
        }
        List<N> result = new ArrayList<>(nodes.size());
        LinkedList<N> pending = new LinkedList<>(roots);
        try {
            while (!pending.isEmpty()) {
                N n = pending.pollFirst();
                result.add(n);
                List<N> toList = edges.get(n);
                if (toList != null) {
                    for (N to : toList) {
                        if (--to.in == 0) {
                            pending.offerLast(to);
                        }
                    }
                }
            }
        } finally {
            // Always undo the walk's decrements, also for a cyclic graph, so that execution (or
            // a later check) starts from the original counts.
            for (N n : nodes) {
                n.in = n.inBackup;
                n.out = n.outBackup;
            }
        }
        if (result.size() != nodes.size()) {
            throw new IllegalStateException("Not a DAG");
        }
        // Schedule only once the graph is known to be valid.
        lock.lock();
        try {
            q.addAll(roots);
        } finally {
            lock.unlock();
        }
        return result;
    }

    /**
     * Dispatches ready nodes, respecting {@code maxWip}. Only nodes scheduled by a successful
     * {@link #check()} (and their successors as they become ready) are run; without a prior
     * check the run queue is empty and this returns immediately.
     */
    @Override
    public void execute() {
        drainLoop();
    }

    /**
     * Handles completion of {@code n}: frees its WIP slot, makes successors whose last
     * predecessor this was runnable, dispatches whatever is now ready, and fires
     * {@code onComplete} once every node has acked. The callback runs outside the lock.
     */
    private void ack(N n) {
        boolean completed;
        lock.lock();
        try {
            --wip;
            List<N> toList = edges.get(n);
            if (toList != null) {
                for (N to : toList) {
                    if (--to.in == 0) {
                        q.offerLast(to);
                    }
                }
            }
            completed = ++ackCount == nodes.size();
        } finally {
            lock.unlock();
        }
        drainLoop();
        if (completed && onComplete != null) {
            onComplete.run();
        }
    }

    /**
     * Dispatches queued nodes while capacity allows. At most one frame drains at a time. A
     * nested or concurrent call returns immediately and leaves the work to the active drainer,
     * which re-checks the queue and {@code wip} under the lock before it stops. Node callbacks
     * are invoked without holding the lock.
     */
    private void drainLoop() {
        lock.lock();
        try {
            if (draining) {
                // Another frame (further up this thread's stack, or on another thread) is already
                // dispatching; it will observe anything enqueued before this call.
                return;
            }
            draining = true;
        } finally {
            lock.unlock();
        }
        for (;;) {
            N next;
            lock.lock();
            try {
                if (q.isEmpty() || (maxWip > 0 && wip >= maxWip)) {
                    // Clear the flag in the same critical section that saw "nothing to do": a
                    // concurrent ack either enqueued before this (and we saw it) or will find
                    // draining == false and drain by itself.
                    draining = false;
                    return;
                }
                ++wip;
                next = q.pollFirst();
            } finally {
                lock.unlock();
            }
            try {
                execute(next);
            } catch (RuntimeException | Error e) {
                // Don't leave the executor permanently marked as draining.
                lock.lock();
                try {
                    draining = false;
                } finally {
                    lock.unlock();
                }
                throw e;
            }
        }
    }

    /**
     * Runs {@code n}, handing it an {@link Ack} that calls {@link #ack} for it.
     * NOTE(review): {@code Ack.once} presumably makes repeated acks of the same node no-ops —
     * confirm.
     */
    private void execute(N n) {
        n.execute(Ack.once(() -> ack(n)));
    }
}
