package shz.model;

import shz.queue.PQueue;
import shz.ToMap;
import shz.msg.ServerFailureMsg;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A vertex of a weighted, directed graph that can compute the shortest path from itself to
 * another vertex.
 *
 * <p>Vertices are linked with {@link #add(ShortestPath, int)}; {@link #go(ShortestPath)} runs the
 * subclass's {@link #update()} algorithm from this vertex and returns the resulting path. The
 * per-query state ({@code node.weight}, {@code prev}) is stored on the vertices themselves and no
 * synchronization is used, so concurrent queries over a shared graph would interfere.
 */
public abstract class ShortestPath {
    /** Per-vertex data: its name and the tentative distance from the current query's source. */
    protected static class Node {
        public String name;
        /** Distance from the running query's source; {@code null} when not reached or after reset. */
        public Integer weight;

        public Node(String name) {
            this.name = name;
        }
    }

    protected Node node;
    /** Outgoing edges: neighbour vertex -> edge weight. */
    protected Map<ShortestPath, Integer> nodes;
    /** Predecessor on the best path found so far during a query; {@code null} otherwise. */
    protected ShortestPath prev;

    /**
     * @param name vertex name; the algorithms track visited vertices by name, so names are
     *             presumably expected to be unique within one graph (confirm)
     * @param size passed to {@code ToMap.get(size)} when building the edge map — presumably an
     *             initial-capacity hint for the number of outgoing edges (confirm)
     */
    protected ShortestPath(String name, int size) {
        this.node = new Node(name);
        this.nodes = ToMap.get(size).build();
    }

    /**
     * Adds (or replaces) the directed edge from this vertex to {@code optimal}.
     *
     * @param optimal target vertex
     * @param weight  edge weight
     * @return this vertex, for chaining
     * @throws NullPointerException if {@code optimal} is null
     */
    public final ShortestPath add(ShortestPath optimal, int weight) {
        // Fail fast: a null neighbour would otherwise only blow up later, mid-traversal.
        nodes.put(Objects.requireNonNull(optimal, "optimal"), weight);
        return this;
    }

    /** Outcome of {@link #go(ShortestPath)}: total path weight and the vertex names along the path. */
    public static final class Result {
        final int sum;
        final List<String> ways;

        Result(int sum, List<String> ways) {
            this.sum = sum;
            this.ways = ways;
        }

        /** @return total weight of the path */
        public int sum() {
            return sum;
        }

        /** @return vertex names from the source to the destination, in travel order */
        public List<String> ways() {
            return ways;
        }

        @Override
        public String toString() {
            return "Result{" +
                    "sum=" + sum +
                    ", ways=" + ways +
                    '}';
        }
    }

    /**
     * Computes the shortest path from this vertex to {@code des}.
     *
     * <p>Every vertex reachable from this one is {@linkplain #reset() reset} afterwards — also
     * when {@link #update()} throws — so stale distances never leak into the next query.
     *
     * @param des destination vertex
     * @return the path's total weight and the vertex names from this vertex to {@code des}
     * @throws NullPointerException if {@code des} is null or was not reached from this vertex
     */
    public final Result go(ShortestPath des) {
        Objects.requireNonNull(des, "des");
        try {
            node.weight = 0;
            update();
            Integer distance = des.node.weight;
            if (distance == null) {
                // Used to surface as an anonymous unboxing NPE; keep the type, explain the cause.
                throw new NullPointerException("no path from " + node.name + " to " + des.node.name);
            }
            // Walk predecessors back to the source, then flip into travel order.
            List<String> ways = new ArrayList<>();
            for (ShortestPath step = des; step != null; step = step.prev) {
                ways.add(step.node.name);
            }
            Collections.reverse(ways);
            return new Result(distance, ways);
        } finally {
            reset();
        }
    }

    /**
     * Computes distances from this vertex (whose {@code node.weight} was set to 0 by
     * {@link #go(ShortestPath)}) to reachable vertices, recording predecessors in {@code prev}.
     */
    protected abstract void update();

    /** Clears {@code weight} and {@code prev} on every vertex reachable from this one. */
    public final void reset() {
        // Identity-based and marked at enqueue time: every reachable vertex object is cleared
        // exactly once, even if two vertices share a name.
        Set<ShortestPath> seen = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<ShortestPath> pending = new ArrayDeque<>();
        seen.add(this);
        pending.add(this);
        while (!pending.isEmpty()) {
            ShortestPath current = pending.poll();
            current.node.weight = null;
            current.prev = null;
            for (ShortestPath next : current.nodes.keySet()) {
                if (seen.add(next)) {
                    pending.add(next);
                }
            }
        }
    }

    /**
     * Iterative relaxation in the spirit of Bellman-Ford: repeats breadth-first relaxation passes
     * until the total of observed distances stops changing, bounded by the number of reached vertices.
     */
    public static final class BellmanFord extends ShortestPath {
        public BellmanFord(String name, int size) {
            super(name, size);
        }

        @Override
        protected void update() {
            Queue<ShortestPath> queue = new ArrayDeque<>();
            Set<String> book = new HashSet<>();
            // Mutable accumulator for the distances observed during one pass.
            AtomicInteger atomic = new AtomicInteger();
            update0(queue, book, atomic);
            int sum = 0, count = 0;
            // `sum != (sum = atomic.get())` compares the previous total with the new one: the left
            // operand is read before the assignment. Passes are capped at book.size() extra runs.
            // NOTE(review): the first comparison is against the initial 0, so a first pass whose
            // total happens to be 0 stops without a verification pass — confirm intended.
            while (sum != (sum = atomic.get()) && count++ < book.size()) {
                atomic.set(0);
                book.clear();
                update0(queue, book, atomic);
            }
            // Presumably throws when the condition holds, i.e. the pass cap was hit (confirm).
            // Message: "unable to determine the optimal plan".
            ServerFailureMsg.requireNon(count >= book.size(), "无法确定最优方案");
        }

        /**
         * One breadth-first relaxation pass from this vertex. Vertices are marked in {@code book}
         * when polled, so a vertex offered by several neighbours before being polled is processed
         * more than once. Each polled vertex's final distance is added to {@code atomic}.
         */
        private void update0(Queue<ShortestPath> queue, Set<String> book, AtomicInteger atomic) {
            queue.offer(this);
            while (!queue.isEmpty()) {
                ShortestPath poll = queue.poll();
                book.add(poll.node.name);
                poll.nodes.forEach((k, v) -> {
                    // Forward relaxation along poll -> k.
                    int weight = poll.node.weight + v;
                    if (k.node.weight == null || weight < k.node.weight) {
                        k.node.weight = weight;
                        k.prev = poll;
                    }
                    // Reverse relaxation: the same edge weight is also tried from k back to poll.
                    // NOTE(review): this treats edges as traversable both ways, unlike Dijkstra
                    // below; any negative edge then behaves as a negative cycle — confirm intended.
                    weight = k.node.weight + v;
                    if (weight < poll.node.weight) {
                        poll.node.weight = weight;
                        poll.prev = k;
                    }
                    if (!book.contains(k.node.name)) queue.offer(k);
                });
                atomic.addAndGet(poll.node.weight);
            }
        }
    }

    /** Dijkstra's algorithm over outgoing edges. Assumes non-negative edge weights. */
    public static final class Dijkstra extends ShortestPath {
        public Dijkstra(String name, int size) {
            super(name, size);
        }

        @Override
        protected void update() {
            // Each entry carries the distance at the moment it was queued. node.weight keeps
            // changing, and a comparator over a mutating key corrupts a heap's ordering.
            PriorityQueue<Map.Entry<Integer, ShortestPath>> frontier =
                    new PriorityQueue<>(Map.Entry.<Integer, ShortestPath>comparingByKey());
            Set<String> settled = new HashSet<>();
            frontier.offer(new AbstractMap.SimpleImmutableEntry<>(node.weight, this));
            while (!frontier.isEmpty()) {
                ShortestPath current = frontier.poll().getValue();
                // A vertex may be queued several times; only its first (shortest) entry counts.
                if (!settled.add(current.node.name)) {
                    continue;
                }
                int distance = current.node.weight;
                for (Map.Entry<ShortestPath, Integer> edge : current.nodes.entrySet()) {
                    ShortestPath next = edge.getKey();
                    // Settled distances are final; touching them could rewrite prev into a cycle.
                    if (settled.contains(next.node.name)) {
                        continue;
                    }
                    int candidate = distance + edge.getValue();
                    if (next.node.weight == null || candidate < next.node.weight) {
                        next.node.weight = candidate;
                        next.prev = current;
                        frontier.offer(new AbstractMap.SimpleImmutableEntry<>(candidate, next));
                    }
                }
            }
        }
    }
}