/*
 * Decompiled with CFR 0.152.
 */
package ai.knowly.langtorch.capability.graph;

import ai.knowly.langtorch.capability.graph.AutoValue_CapabilityGraph;
import ai.knowly.langtorch.capability.graph.NodeAdapter;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.reflect.TypeToken;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

/**
 * A directed graph of {@link NodeAdapter}s that is executed once per call in topological order.
 *
 * <p>Each node is registered with {@link #addNode}; the ids returned by its {@code getOutDegree()}
 * name the nodes that receive its output as an input. {@link #process} seeds the graph with initial
 * inputs, runs every node exactly once in dependency order, and returns the outputs of the nodes
 * that have no outgoing edges.
 *
 * <p>All state lives in plain {@link HashMap}s and multimaps and no synchronization is performed
 * here, so an instance should not be used from several threads without external locking.
 */
@AutoValue
public abstract class CapabilityGraph {

    /** Creates an empty graph with fresh, mutable backing collections. */
    public static CapabilityGraph create() {
        return new AutoValue_CapabilityGraph(
                new HashMap<>(),
                ArrayListMultimap.create(),
                new HashMap<>(),
                ArrayListMultimap.create(),
                new HashMap<>());
    }

    /** Registered nodes keyed by node id. */
    abstract HashMap<String, NodeAdapter<?, ?>> nodes();

    /** Inputs collected for each node id during the current run. */
    abstract Multimap<String, Object> inputMap();

    /** Output produced by each node id during the most recent run. */
    abstract HashMap<String, Object> outputMap();

    /** Reverse edges: target node id mapped to the ids of the nodes that feed it. */
    abstract Multimap<String, String> inDegreeMap();

    /** Runtime type that initial inputs must satisfy, keyed by node id. */
    abstract HashMap<String, TypeToken<?>> inputTypes();

    /**
     * Registers a node and records an edge from it to each id in its {@code getOutDegree()}.
     *
     * <p>Edges may name nodes that are added later; they are only resolved when the graph is
     * processed. NOTE(review): adding a second node with an existing id replaces the adapter but
     * does not remove the edges recorded for the previous one.
     *
     * @param nodeAdapter the node to add, keyed by its {@code getId()}
     * @param inputType the type that initial inputs supplied to this node must be assignable to
     */
    public <I, O> void addNode(NodeAdapter<I, O> nodeAdapter, Class<I> inputType) {
        String id = nodeAdapter.getId();
        nodes().put(id, nodeAdapter);
        inputTypes().put(id, TypeToken.of(inputType));
        for (String outDegree : nodeAdapter.getOutDegree()) {
            inDegreeMap().put(outDegree, id);
        }
    }

    /**
     * Runs every node once, in topological order, and returns the outputs of the end nodes.
     *
     * <p>Inputs and outputs left over from a previous call are discarded first, so the same graph
     * can be processed repeatedly without earlier runs leaking inputs into later ones.
     *
     * @param initialInputMap initial input per node id; each value must be an instance of the type
     *     registered for that node via {@link #addNode}
     * @return a new map from each end node id (a node with no outgoing edges) to its output
     * @throws IllegalArgumentException if an initial input targets an unregistered node or does not
     *     match that node's registered input type
     * @throws NullPointerException if an initial input value is {@code null}
     * @throws IllegalStateException if the graph contains a cycle or an edge to an unregistered node
     * @throws ExecutionException propagated from a node's {@code process}
     * @throws InterruptedException propagated from a node's {@code process}
     */
    public Map<String, Object> process(Map<String, Object> initialInputMap)
            throws ExecutionException, InterruptedException {
        // inputMap is a multimap: without this reset every repeated call would append another copy
        // of each input, and nodes would be fed stale values from earlier runs.
        inputMap().clear();
        outputMap().clear();

        for (Map.Entry<String, Object> entry : initialInputMap.entrySet()) {
            setInitialInput(entry.getKey(), entry.getValue());
        }

        for (String id : topologicalSort()) {
            NodeAdapter<?, ?> nodeAdapter = nodes().get(id);
            // Snapshot the inputs: Multimap.get returns a live view, which a later run's reset
            // would otherwise empty underneath any node that kept a reference to it.
            List<Object> input = new ArrayList<>(inputMap().get(id));
            Object output = processNode(nodeAdapter, input);
            addOutput(id, output);
            for (String outDegree : nodeAdapter.getOutDegree()) {
                addInput(outDegree, output);
            }
        }

        Map<String, Object> endNodeOutputs = new HashMap<>();
        for (String id : getEndNodeIds()) {
            endNodeOutputs.put(id, outputMap().get(id));
        }
        return endNodeOutputs;
    }

    /**
     * Invokes {@code nodeAdapter.process} with the collected inputs, binding the wildcard-typed
     * node to its own type parameters.
     */
    private <I, O> O processNode(NodeAdapter<I, O> nodeAdapter, Collection<Object> input)
            throws ExecutionException, InterruptedException {
        // Unchecked by design: initial inputs are type-checked in setInitialInput, but values
        // forwarded from upstream nodes are not. NOTE(review): assumes NodeAdapter.process accepts
        // a Collection<I> or a supertype such as Iterable<I> — confirm against NodeAdapter.
        @SuppressWarnings("unchecked")
        Collection<I> typedInput = (Collection<I>) (Collection<?>) input;
        return nodeAdapter.process(typedInput);
    }

    /**
     * Returns the output recorded for {@code id} by the most recent {@link #process} call, or
     * {@code null} if that node has not produced one.
     */
    public Object getOutput(String id) {
        return outputMap().get(id);
    }

    /** Returns the ids of all registered nodes that have no outgoing edges. */
    private List<String> getEndNodeIds() {
        List<String> endNodeIds = new ArrayList<>();
        for (NodeAdapter<?, ?> nodeAdapter : nodes().values()) {
            if (nodeAdapter.getOutDegree().isEmpty()) {
                endNodeIds.add(nodeAdapter.getId());
            }
        }
        return endNodeIds;
    }

    /**
     * Validates an initial input against the node's registered type and queues it for that node.
     *
     * @throws IllegalArgumentException if no node is registered under {@code id} or the input's
     *     runtime class is not assignable to the registered type
     * @throws NullPointerException if {@code input} is {@code null}
     */
    private void setInitialInput(String id, Object input) {
        TypeToken<?> expectedType = inputTypes().get(id);
        if (expectedType == null) {
            throw new IllegalArgumentException("No node registered with id " + id);
        }
        Objects.requireNonNull(input, "Initial input for node " + id + " must not be null");
        if (!expectedType.isSupertypeOf(input.getClass())) {
            throw new IllegalArgumentException("Input type for node " + id + " does not match the expected type");
        }
        addInput(id, input);
    }

    /** Appends {@code input} to the inputs collected for node {@code id}. */
    private void addInput(String id, Object input) {
        inputMap().put(id, input);
    }

    /** Records {@code output} as the result of node {@code id}. */
    private void addOutput(String id, Object output) {
        outputMap().put(id, output);
    }

    /**
     * Orders all registered nodes so that every node appears after all of its predecessors
     * (Kahn's algorithm, FIFO queue).
     *
     * @throws IllegalStateException if a node has an edge to an unregistered node, or if the graph
     *     contains a cycle
     */
    private List<String> topologicalSort() {
        Map<String, Integer> remainingInDegree = new HashMap<>();
        Deque<String> ready = new ArrayDeque<>();
        for (String id : nodes().keySet()) {
            int degree = inDegreeMap().get(id).size();
            remainingInDegree.put(id, degree);
            if (degree == 0) {
                ready.add(id);
            }
        }

        List<String> sorted = new ArrayList<>(nodes().size());
        while (!ready.isEmpty()) {
            String current = ready.poll();
            sorted.add(current);
            for (String next : nodes().get(current).getOutDegree()) {
                Integer degree = remainingInDegree.get(next);
                if (degree == null) {
                    // Previously an NPE from unboxing null; report the dangling edge explicitly.
                    throw new IllegalStateException(
                            "Node " + current + " has an edge to unregistered node " + next);
                }
                int remaining = degree - 1;
                remainingInDegree.put(next, remaining);
                if (remaining == 0) {
                    ready.add(next);
                }
            }
        }

        // Any node never released from the queue sits on (or behind) a cycle.
        if (sorted.size() != nodes().size()) {
            throw new IllegalStateException("The graph contains a cycle");
        }
        return sorted;
    }
}

