package com.xzchaoo.commons.basic.topology;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

import org.jctools.queues.MpscArrayQueue;

import com.xzchaoo.commons.basic.Ack;
import com.xzchaoo.commons.basic.drainloop.DrainLoop;

/**
 * Executes the nodes of a directed graph in dependency order: a node is handed to the consumer
 * only after every one of its predecessors has been acknowledged, and at most {@code maxWip}
 * nodes are in flight (handed out but not yet acknowledged) at any time.
 *
 * <p>Acknowledgements are passed through a bounded multi-producer queue and are processed inside
 * {@link #drainLoop0()}. NOTE(review): this class assumes that {@code DrainLoop#drainLoop()}
 * serializes calls to {@code drainLoop0()} and re-runs it when it is requested while a pass is
 * already in progress - confirm against {@code DrainLoop}.
 *
 * <p>The graph itself ({@link #add(Object, Object)}) is not synchronized; build it before calling
 * {@link #execute(BiConsumer, Runnable)}.
 *
 * <p>created at 2020-08-11
 *
 * @author xiangfeng.xzc
 */
public class DrainLoopTopologyExecutor3<N> extends DrainLoop
    implements TopologyExecutor3<N> {
    private final NodeFunction<N> nf;

    /** All known nodes, keyed by the id returned from {@code nf.id(..)}. */
    private final Map<String, GNode<N>>     nodes = new HashMap<>();
    /** Outgoing edges: source id -> target ids (one entry per added edge, duplicates kept). */
    private final Map<String, List<String>> edges = new HashMap<>();
    /** Nodes whose acknowledgement has been signalled but not yet processed by the drain loop. */
    private final MpscArrayQueue<GNode<N>>  ackQ  = new MpscArrayQueue<>(65536);
    /** Nodes whose predecessors are all acknowledged and that are waiting to be handed out. */
    private       Deque<GNode<N>>           q;
    // per-execution state, (re)initialized by execute()
    private       BiConsumer<N, Ack>        consumer;
    private       Runnable                  complete;
    /** Number of nodes acknowledged in the current execution. */
    private       int                       ackCount;
    /** Upper bound for {@link #wip}; a negative value is never reached, i.e. no limit. */
    private final int                       maxWip;
    /** Number of nodes handed to the consumer and not yet acknowledged. */
    private       int                       wip;

    /**
     * @param nf     function used to derive the id of a node
     * @param maxWip maximum number of nodes in flight at the same time; a negative value never
     *               matches the in-flight counter and therefore disables the limit
     * @throws IllegalArgumentException if {@code maxWip} is 0, because no node could ever be started
     */
    public DrainLoopTopologyExecutor3(NodeFunction<N> nf, int maxWip) {
        if (maxWip == 0) {
            throw new IllegalArgumentException("maxWip must not be 0");
        }
        this.nf = nf;
        this.maxWip = maxWip;
    }

    /**
     * Add an edge: 'from - to'. Both endpoints are registered as nodes on first use.
     *
     * @param from from
     * @param to   to
     */
    public void add(N from, N to) {
        String fromId = nf.id(from);
        String toId = nf.id(to);
        GNode<N> fromNode =
            nodes.computeIfAbsent(fromId, ignored -> new GNode<>(fromId, from));
        GNode<N> toNode =
            nodes.computeIfAbsent(toId, ignored -> new GNode<>(toId, to));
        ++fromNode.out;
        ++toNode.in;
        edges.computeIfAbsent(fromId, ignored -> new ArrayList<>()) //
             .add(toId);
    }

    /**
     * Verifies that the graph is acyclic and returns its nodes in a topological order.
     *
     * <p>The traversal works on a private copy of the in-degrees, so the graph is left untouched
     * whether the check succeeds or fails.
     *
     * @return all nodes, each one after all of its predecessors
     * @throws IllegalStateException if the graph contains a cycle; the message lists the ids of
     *                               the nodes that could not be ordered
     */
    @Override
    public List<N> check() {
        Map<String, Integer> remaining = new HashMap<>();
        Deque<GNode<N>> ready = new ArrayDeque<>();
        for (GNode<N> gn : nodes.values()) {
            remaining.put(gn.id, gn.in);
            if (gn.in == 0) {
                ready.addLast(gn);
            }
        }

        List<N> result = new ArrayList<>(nodes.size());
        while (!ready.isEmpty()) {
            GNode<N> gn = ready.removeFirst();
            result.add(gn.n);
            List<String> toIdList = edges.get(gn.id);
            if (toIdList == null) {
                continue;
            }
            for (String toId : toIdList) {
                // one decrement per edge, so duplicate edges are balanced with add()
                int left = remaining.merge(toId, -1, Integer::sum);
                if (left == 0) {
                    ready.addLast(nodes.get(toId));
                }
            }
        }

        if (result.size() != nodes.size()) {
            // every node that was never released still has unsatisfied predecessors
            List<String> stuck = new ArrayList<>();
            for (GNode<N> gn : nodes.values()) {
                if (remaining.get(gn.id) != 0) {
                    stuck.add(gn.id);
                }
            }
            throw new IllegalStateException("not DAG, remain nodes = " + String.join(",", stuck));
        }

        return result;
    }

    /**
     * Starts executing the graph. Every node is passed to {@code consumer} together with an
     * {@link Ack}; acknowledging a node releases its successors. After the last node has been
     * acknowledged {@code complete} is run once. For an empty graph {@code complete} is run
     * immediately.
     *
     * <p>The per-run state is reset on every call, so the same graph can be executed again after
     * a previous execution has completed.
     *
     * @param consumer receives each node and the ack used to signal that the node is done
     * @param complete callback run after all nodes have been acknowledged
     */
    @Override
    public void execute(BiConsumer<N, Ack> consumer, Runnable complete) {
        this.consumer = consumer;
        this.complete = complete;
        this.q = new ArrayDeque<>();
        this.ackCount = 0;
        this.wip = 0;

        if (nodes.isEmpty()) {
            // no ack will ever arrive, so completion has to be signalled here
            complete.run();
            return;
        }

        for (GNode<N> gn : nodes.values()) {
            // count down a working copy so that the recorded in-degree survives the run
            gn.pending = gn.in;
            if (gn.pending == 0) {
                q.addLast(gn);
            }
        }

        drainLoop();
    }

    /**
     * Records that {@code gn} has finished and requests a drain pass to process it.
     *
     * @throws IllegalStateException if the ack queue is full
     */
    private void ack(GNode<N> gn) {
        if (!ackQ.offer(gn)) {
            throw new IllegalStateException("ack queue is full");
        }
        drainLoop();
    }

    /**
     * Processes every queued ack: frees an in-flight slot, releases successors whose predecessors
     * are now all acknowledged, and runs the completion callback after the last node.
     */
    private void consumeAcks() {
        for (; ; ) {
            GNode<N> gn = ackQ.relaxedPoll();
            if (gn == null) {
                return;
            }
            --wip;
            List<String> toList = edges.get(gn.id);
            if (toList != null) {
                for (String toId : toList) {
                    GNode<N> toNode = nodes.get(toId);
                    if (--toNode.pending == 0) {
                        q.addLast(toNode);
                    }
                }
            }
            if (++ackCount == nodes.size()) {
                complete.run();
            }
        }
    }

    /** Hands out ready nodes until none are left or the in-flight limit is reached. */
    private void consumeRoots() {
        while (wip != maxWip) {
            GNode<N> gn = q.pollFirst();
            if (gn == null) {
                return;
            }
            ++wip;
            internalExecute(gn);
        }
    }

    @Override
    protected void drainLoop0() {
        for (; ; ) {
            // Acks may release new nodes into q, so they are consumed before the ready nodes.
            consumeAcks();
            consumeRoots();
            // Acks signalled while nodes were being handed out (e.g. a consumer that acks
            // synchronously) need another round. Otherwise stop instead of spinning while the
            // in-flight limit is reached: ack() requests a new drain pass after each enqueue.
            if (ackQ.isEmpty()) {
                break;
            }
        }
    }

    /** Passes the node to the consumer with an ack that reports back to this executor. */
    private void internalExecute(GNode<N> gn) {
        consumer.accept(gn.n, Ack.once(() -> ack(gn)));
    }

    /** Graph node: the user object plus its id and edge counters. */
    static class GNode<N> {
        final String id;
        final N      n;
        /** Number of incoming edges; only changed by {@code add}. */
        int in;
        /** Number of outgoing edges; only changed by {@code add}. */
        int out;
        /** Predecessors not yet acknowledged in the current execution. */
        int pending;

        GNode(String id, N n) {
            this.id = id;
            this.n = n;
        }
    }
}
