package com.xzchaoo.commons.basic.topology;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

import com.xzchaoo.commons.basic.Ack;

/**
 * A directed acyclic graph whose nodes are consumed asynchronously in topological order.
 *
 * <p>Edges are registered with {@link #add(Object, Object)}. During {@link #consume(BiConsumer)} a
 * node is handed to the consumer only after every node with an edge into it has been acknowledged.
 * At most {@code maxWip} nodes (shared across all runs of this graph) are handed out but not yet
 * acknowledged at any time.
 *
 * <p>Node identity comes from {@link NodeFunction}: values mapping to the same id are the same node,
 * and the value first registered for an id is the one passed to consumers.
 *
 * <p>Thread-safety: {@code add} and {@code consume} are {@code synchronized}, and acknowledgement
 * bookkeeping runs under the same monitor.
 *
 * <p>created at 2020-08-11
 *
 * @param <N> node value type
 * @author xiangfeng.xzc
 */
public class TopologyGraph<N> {
    private final NodeFunction<N>           nf;
    private final Map<String, GNode<N>>     nodes = new HashMap<>();
    private final Map<String, List<String>> edges = new HashMap<>();
    private final Semaphore                 semaphore;

    /**
     * Creates an empty graph.
     *
     * @param nf     maps each node value to its identity string
     * @param maxWip maximum number of nodes handed to consumers but not yet acknowledged; must be
     *               positive
     * @throws NullPointerException     if {@code nf} is null
     * @throws IllegalArgumentException if {@code maxWip <= 0} (a zero-permit semaphore would block
     *                                  the first dispatch forever)
     */
    public TopologyGraph(NodeFunction<N> nf, int maxWip) {
        if (maxWip <= 0) {
            throw new IllegalArgumentException("maxWip must be positive: " + maxWip);
        }
        this.nf = Objects.requireNonNull(nf, "nf");
        this.semaphore = new Semaphore(maxWip);
    }

    /**
     * Adds the directed edge {@code from -> to}, so that {@code to} is not consumed until
     * {@code from} has been acknowledged. Both endpoints are registered as nodes if absent.
     * Adding the same edge twice counts as two edges.
     *
     * <p>Edges added while a {@link #consume} run is in progress do not affect that run.
     *
     * @param from the predecessor node
     * @param to   the successor node
     */
    public synchronized void add(N from, N to) {
        String fromId = nf.id(from);
        String toId = nf.id(to);
        GNode<N> fromNode =
            nodes.computeIfAbsent(fromId, ignored -> new GNode<>(fromId, from));
        GNode<N> toNode =
            nodes.computeIfAbsent(toId, ignored -> new GNode<>(toId, to));
        ++fromNode.out;
        ++toNode.in;
        edges.computeIfAbsent(fromId, ignored -> new ArrayList<>()) //
             .add(toId);
    }

    /**
     * Asynchronously consumes every node in topological order.
     *
     * <p>Each node is passed to {@code consumer} together with an {@link Ack}. The node's successors
     * become eligible once that ack is invoked; repeated invocations of the same ack are ignored.
     * The consumer may ack synchronously (inside {@code accept}) or later from any thread.
     * {@code accept} may be invoked while this graph's monitor is held, and the dispatching thread
     * blocks while all wip permits are taken. A consumer that acks synchronously makes dispatch
     * recurse once per level of the dependency chain, so very deep graphs should ack asynchronously.
     * If {@code accept} throws, the exception propagates to the dispatching thread and that node's
     * permit is not returned.
     *
     * <p>Each call works on a snapshot of the graph taken at call time, so the graph can be consumed
     * more than once.
     *
     * @param consumer receives each node and the ack to invoke once that node is done
     * @return a latch that reaches zero once every node has been acknowledged (already zero for an
     *         empty graph)
     * @throws NullPointerException  if {@code consumer} is null
     * @throws IllegalStateException if the graph contains a cycle; nothing is consumed in that case
     */
    public synchronized CountDownLatch consume(BiConsumer<N, Ack> consumer) {
        Objects.requireNonNull(consumer, "consumer");
        Run run = new Run(consumer);
        run.requireAcyclic();
        // Collect the roots before dispatching: a consumer that acks synchronously mutates
        // run.pending (re-entrantly, on this thread) while we would otherwise still be scanning it.
        List<GNode<N>> roots = new ArrayList<>();
        for (GNode<N> gn : run.runNodes.values()) {
            if (run.pending.get(gn.id) == 0) {
                roots.add(gn);
            }
        }
        for (GNode<N> gn : roots) {
            run.dispatch(gn);
        }
        return run.latch;
    }

    /**
     * Bookkeeping for a single {@link #consume} call. It is built from a snapshot of the graph, and
     * its mutable state ({@code pending}) is only touched while holding the graph's monitor.
     */
    private final class Run {
        final BiConsumer<N, Ack>        consumer;
        final Map<String, GNode<N>>     runNodes;
        final Map<String, List<String>> successors = new HashMap<>();
        // Number of not-yet-acknowledged predecessors of each node id.
        final Map<String, Integer>      pending    = new HashMap<>();
        final CountDownLatch            latch;

        Run(BiConsumer<N, Ack> consumer) {
            this.consumer = consumer;
            this.runNodes = new HashMap<>(nodes);
            for (GNode<N> gn : runNodes.values()) {
                pending.put(gn.id, gn.in);
            }
            // Copy adjacency lists so a concurrent add() cannot change this run's edges.
            for (Map.Entry<String, List<String>> e : edges.entrySet()) {
                successors.put(e.getKey(), new ArrayList<>(e.getValue()));
            }
            this.latch = new CountDownLatch(runNodes.size());
        }

        /** Returns the successor ids of {@code id} in this run's snapshot (empty if none). */
        List<String> successorsOf(String id) {
            return successors.getOrDefault(id, Collections.emptyList());
        }

        /** Throws if Kahn's algorithm cannot order every node, i.e. the snapshot has a cycle. */
        void requireAcyclic() {
            Map<String, Integer> remaining = new HashMap<>(pending);
            Deque<String> ready = new ArrayDeque<>();
            for (Map.Entry<String, Integer> e : remaining.entrySet()) {
                if (e.getValue() == 0) {
                    ready.add(e.getKey());
                }
            }
            int ordered = 0;
            while (!ready.isEmpty()) {
                String id = ready.poll();
                ordered++;
                for (String toId : successorsOf(id)) {
                    if (remaining.merge(toId, -1, Integer::sum) == 0) {
                        ready.add(toId);
                    }
                }
            }
            if (ordered != runNodes.size()) {
                throw new IllegalStateException("topology graph contains a cycle: "
                    + (runNodes.size() - ordered) + " node(s) cannot be ordered");
            }
        }

        /** Takes a wip permit (blocking if necessary) and hands {@code gn} to the consumer. */
        void dispatch(GNode<N> gn) {
            semaphore.acquireUninterruptibly();
            AtomicBoolean acked = new AtomicBoolean();
            consumer.accept(gn.n, () -> {
                // Only the first ack counts: a repeat would leak a permit, over-count the latch and
                // release successors before their other predecessors finish.
                if (acked.compareAndSet(false, true)) {
                    onAck(gn);
                }
            });
        }

        /** Returns the node's permit, counts it done, and dispatches successors that became ready. */
        void onAck(GNode<N> gn) {
            // Release before taking the monitor: the thread holding it may be blocked in dispatch()
            // waiting for exactly this permit, so acquiring the monitor first could deadlock.
            semaphore.release();
            latch.countDown();
            synchronized (TopologyGraph.this) {
                for (String toId : successorsOf(gn.id)) {
                    if (pending.merge(toId, -1, Integer::sum) == 0) {
                        dispatch(runNodes.get(toId));
                    }
                }
            }
        }
    }

    /** A registered node together with its static edge counts. */
    static class GNode<N> {
        final String id;
        final N      n;
        // Number of edges ending at this node (duplicates counted).
        int in;
        // Number of edges starting at this node (duplicates counted).
        int out;

        public GNode(String id, N n) {
            this.id = id;
            this.n = n;
        }
    }
}
