package com.xzchaoo.commons.basic.topology.sort;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import lombok.val;

import com.google.common.base.Preconditions;

/**
 * Topological sort (Kahn's algorithm) over a directed graph.
 *
 * <p>Nodes are identified by the id returned from a {@link NodeFunction}. Two node objects with
 * equal ids are treated as the same node, and the first one registered is the one reported by
 * {@link #sort()}. Nodes may be registered explicitly with {@link #addNode} or implicitly as
 * endpoints of {@link #addEdge}.
 *
 * <p>{@link #sort()} does not modify the graph, so it may be called repeatedly and more nodes or
 * edges may be added between calls. Instances are not thread-safe.
 *
 * <p>created at 2020-08-11
 *
 * @param <N> node type
 * @author xiangfeng.xzc
 */
public class TopologySort<N> {
    private final NodeFunction<N>          nf;
    /** Registered nodes keyed by id, kept in registration order so sorting is deterministic. */
    private final Map<Object, TNode<N>>    nodes = new LinkedHashMap<>();

    /**
     * Creates an empty graph.
     *
     * @param nf extracts the id of a node; must not be null
     * @throws NullPointerException if {@code nf} is null
     */
    public TopologySort(NodeFunction<N> nf) {
        this.nf = Objects.requireNonNull(nf);
    }

    /**
     * Registers a node. Does nothing if a node with the same id is already registered.
     *
     * @param n the node to register
     * @throws IllegalArgumentException if {@code n} is null or its id is null
     */
    public void addNode(N n) {
        Preconditions.checkArgument(n != null, "node is null");

        Object id = nf.id(n);
        Preconditions.checkArgument(id != null, "id is null");

        nodes.computeIfAbsent(id, key -> new TNode<>(key, n));
    }

    /**
     * Adds a directed edge {@code from -> to}: {@code from} will precede {@code to} in the result of
     * {@link #sort()}. Endpoints that are not yet registered are registered implicitly. Adding the
     * same edge more than once is allowed.
     *
     * @param from source node
     * @param to   target node
     * @throws IllegalArgumentException if either node is null or has a null id
     */
    public void addEdge(N from, N to) {
        Preconditions.checkArgument(from != null, "from is null");
        Preconditions.checkArgument(to != null, "to is null");

        Object fromId = nf.id(from);
        Preconditions.checkArgument(fromId != null, "fromId is null");

        Object toId = nf.id(to);
        Preconditions.checkArgument(toId != null, "toId is null");

        val fromNode = nodes.computeIfAbsent(fromId, key -> new TNode<>(key, from));
        val toNode = nodes.computeIfAbsent(toId, key -> new TNode<>(key, to));

        fromNode.successors.add(toNode);
        ++toNode.in;
    }

    /**
     * Returns all registered nodes in a topological order: for every edge {@code from -> to},
     * {@code from} appears before {@code to}. The order is deterministic for a given sequence of
     * {@link #addNode}/{@link #addEdge} calls. The graph itself is not modified.
     *
     * @return a new mutable list of all registered nodes in topological order; empty if no node
     *     has been registered
     * @throws IllegalStateException if the graph contains a cycle
     */
    public List<N> sort() {
        List<N> result = new ArrayList<>(nodes.size());
        if (nodes.isEmpty()) {
            return result;
        }

        // Remaining in-degree of each non-root node. This is a local copy so that sort() leaves
        // TNode.in untouched and can be called any number of times.
        Map<Object, Integer> remaining = new HashMap<>();
        Deque<TNode<N>> ready = new ArrayDeque<>();
        for (TNode<N> tn : nodes.values()) {
            if (tn.in == 0) {
                ready.addLast(tn);
            } else {
                remaining.put(tn.id, tn.in);
            }
        }
        if (ready.isEmpty()) {
            throw new IllegalStateException("fail to find roots with in = 0");
        }

        while (!ready.isEmpty()) {
            TNode<N> tn = ready.removeFirst();
            result.add(tn.n);
            for (TNode<N> succ : tn.successors) {
                // Every successor has in >= 1, so it is present in 'remaining'.
                if (remaining.merge(succ.id, -1, Integer::sum) == 0) {
                    ready.addLast(succ);
                }
            }
        }

        // Nodes on a cycle never reach in-degree 0 and are therefore never emitted.
        if (result.size() != nodes.size()) {
            throw new IllegalStateException("not a DAG");
        }
        return result;
    }

    /**
     * Extracts the identity of a node.
     *
     * @param <N> node type
     */
    @FunctionalInterface
    public interface NodeFunction<N> {
        /**
         * Extract node id from n.
         *
         * @param n the node
         * @return the id of the node. It is used as a {@link Map} key, so it must implement
         *     {@code equals}/{@code hashCode} consistently; must not be null
         */
        Object id(N n);
    }

    /** Graph vertex: the user node plus its adjacency and in-degree bookkeeping. */
    private static final class TNode<T> {
        final Object        id;
        final T             n;
        /** Targets of outgoing edges in insertion order; duplicates mirror duplicate edges. */
        final List<TNode<T>> successors = new ArrayList<>();
        /** Number of incoming edges, counting duplicates. */
        int in;

        TNode(Object id, T n) {
            this.id = id;
            this.n = n;
        }
    }
}
