/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.compute;

import com.linkedin.data.template.IntegerMap;
import com.linkedin.feathr.compute.AnyNode;
import com.linkedin.feathr.compute.AnyNodeArray;
import com.linkedin.feathr.compute.ComputeGraph;
import com.linkedin.feathr.compute.ComputeGraphBuilder;
import com.linkedin.feathr.compute.Dependencies;
import com.linkedin.feathr.compute.PegasusUtils;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Static helpers for validating, merging and de-duplicating {@link ComputeGraph} instances.
 *
 * <p>Throughout this class a node ID is used as an index into {@code graph.getNodes()}, and the
 * feature-name table ({@code graph.getFeatureNames()}) maps a feature name to such a node ID.
 */
public class ComputeGraphs {
    private ComputeGraphs() {
    }

    /**
     * Runs the structural checks on {@code graph}, in this order: sequential node IDs, existence of
     * every referenced node, absence of dependency cycles, and absence of External nodes that name a
     * feature defined in the same graph.
     *
     * @param graph the graph to check
     * @return the same {@code graph} instance, to allow call chaining
     * @throws RuntimeException if any of the checks fails
     */
    public static ComputeGraph validate(ComputeGraph graph) {
        ensureNodeIdsAreSequential(graph);
        ensureNodeReferencesExist(graph);
        ensureNoDependencyCycles(graph);
        ensureNoExternalReferencesToSelf(graph);
        return graph;
    }

    /**
     * Concatenates several graphs into one.
     *
     * <p>Each input node is copied, its dependencies are shifted by the number of nodes already added
     * to the builder, and the feature names of each input graph are shifted by the same offset.
     * External nodes that refer to a feature defined by the merged graph are then removed, and the
     * result is validated.
     *
     * @param inputGraphs the graphs to merge; their nodes are copied, not modified
     * @return the merged and validated graph
     * @throws RuntimeException if the merged graph fails {@link #validate(ComputeGraph)}
     */
    public static ComputeGraph merge(Collection<ComputeGraph> inputGraphs) {
        ComputeGraphBuilder builder = new ComputeGraphBuilder();
        inputGraphs.forEach(inputGraph -> {
            // All IDs of this input graph are shifted by the number of nodes added so far.
            int offset = builder.peekNextNodeId();
            inputGraph.getNodes().forEach(inputNode -> {
                AnyNode copy = PegasusUtils.copy(inputNode);
                Dependencies.remapDependencies(copy, i -> i + offset);
                builder.addNode(copy);
            });
            inputGraph.getFeatureNames().forEach(
                (featureName, nodeId) -> builder.addFeatureName(featureName, nodeId + offset));
        });
        ComputeGraph mergedGraph = builder.build(new ComputeGraph(), false);
        return validate(removeExternalNodesForFeaturesDefinedInThisGraph(mergedGraph));
    }

    /**
     * Collapses nodes that are equal to each other (once their own IDs are ignored and their
     * dependencies have been collapsed) into a single node.
     *
     * <p>The nodes are copied and every copy gets ID 0, so that two nodes differing only in ID can be
     * found through a {@link HashMap} lookup. Nodes are then processed depth-first, dependencies
     * before dependents. When a node equals one that was already kept, every node depending on it and
     * every feature name pointing at it is redirected to the kept node. Finally the kept nodes get
     * their original IDs back and the graph is re-indexed.
     *
     * <p>The input graph is not modified: both its nodes and its feature-name table are copied.
     *
     * @param inputGraph the graph to de-duplicate
     * @return a new graph containing only the kept nodes, with sequential IDs
     * @throws RuntimeException if a dependency cycle is found
     * @throws CloneNotSupportedException retained from the existing signature
     */
    public static ComputeGraph removeRedundancies(ComputeGraph inputGraph) throws CloneNotSupportedException {
        Map<Integer, Set<Integer>> whoDependsOnMeIndex = getReverseDependencyIndex(inputGraph);
        Map<Integer, Set<String>> featureDependencyIndex = getReverseFeatureDependencyIndex(inputGraph);
        List<AnyNode> nodes = inputGraph.getNodes().stream().map(PegasusUtils::copy).collect(Collectors.toList());
        // Neutralize the IDs so that the map lookup below compares nodes without regard to their ID.
        nodes.forEach(node -> PegasusUtils.setNodeId(node, 0));
        // Work on a copy: redirecting feature names must not leak into the caller's graph.
        IntegerMap featureNameMap = new IntegerMap(inputGraph.getFeatureNames());
        // Kept node -> its original ID. The keys are only mutated after all lookups are done.
        Map<AnyNode, Integer> standardizedNodes = new HashMap<>();
        Deque<Integer> stack = IntStream.range(0, nodes.size()).boxed()
            .collect(Collectors.toCollection(ArrayDeque::new));
        List<VisitedState> visitedState =
            new ArrayList<>(Collections.nCopies(nodes.size(), VisitedState.NOT_VISITED));
        while (!stack.isEmpty()) {
            int thisNodeId = stack.pop();
            if (visitedState.get(thisNodeId) == VisitedState.VISITED) {
                continue;
            }
            AnyNode thisNode = nodes.get(thisNodeId);
            Set<Integer> myDependencies = new Dependencies().getDependencies(thisNode);
            List<Integer> unfinishedDependencies = myDependencies.stream()
                .filter(i -> visitedState.get(i) != VisitedState.VISITED)
                .collect(Collectors.toList());
            if (!unfinishedDependencies.isEmpty()) {
                // Seeing an in-progress node again while it still has open dependencies means one
                // of those dependencies led back to it.
                if (visitedState.get(thisNodeId) == VisitedState.IN_PROGRESS) {
                    throw new RuntimeException("Dependency cycle detected at node " + thisNodeId);
                }
                // Revisit this node after its dependencies, which are pushed on top of it.
                stack.push(thisNodeId);
                visitedState.set(thisNodeId, VisitedState.IN_PROGRESS);
                unfinishedDependencies.forEach(stack::push);
                continue;
            }
            Integer standardizedNodeId = standardizedNodes.get(thisNode);
            if (standardizedNodeId != null) {
                // Duplicate of a kept node: redirect dependents and feature names to the kept one.
                whoDependsOnMeIndex.getOrDefault(thisNodeId, Collections.emptySet()).forEach(
                    nodeWhoDependsOnMe -> Dependencies.remapDependencies(
                        nodes.get(nodeWhoDependsOnMe),
                        id -> id == thisNodeId ? standardizedNodeId : id));
                featureDependencyIndex.getOrDefault(thisNodeId, Collections.emptySet()).forEach(
                    featureThatPointsToMe -> featureNameMap.put(featureThatPointsToMe, standardizedNodeId));
            } else {
                standardizedNodes.put(thisNode, thisNodeId);
            }
            visitedState.set(thisNodeId, VisitedState.VISITED);
        }
        // Restore the original IDs; reindexNodes reads them to build its old-ID -> new-ID table.
        standardizedNodes.forEach((node, id) -> PegasusUtils.setNodeId(node, id));
        return reindexNodes(standardizedNodes.keySet(), featureNameMap);
    }

    /**
     * Drops every External node whose name is present in the graph's own feature-name table, after
     * redirecting dependencies on such a node to the node the feature name points at.
     *
     * <p>The dependencies of the nodes in {@code inputGraph} are rewritten in place.
     *
     * @param inputGraph the graph to clean up
     * @return {@code inputGraph} itself when nothing had to be removed, otherwise a re-indexed graph
     */
    private static ComputeGraph removeExternalNodesForFeaturesDefinedInThisGraph(ComputeGraph inputGraph) {
        List<AnyNode> nodes = inputGraph.getNodes();
        IntegerMap featureNames = inputGraph.getFeatureNames();
        // ID of a removable External node -> ID of the node defining that feature.
        Map<Integer, Integer> externalNodeRemappedIds = new HashMap<>();
        for (int id = 0; id < nodes.size(); ++id) {
            AnyNode node = nodes.get(id);
            if (!node.isExternal()) {
                continue;
            }
            Integer featureNodeId = featureNames.get(node.getExternal().getName());
            if (featureNodeId != null) {
                externalNodeRemappedIds.put(id, featureNodeId);
            }
        }
        if (externalNodeRemappedIds.isEmpty()) {
            return inputGraph;
        }
        nodes.forEach(node -> Dependencies.remapDependencies(
            node, id -> externalNodeRemappedIds.getOrDefault(id, id)));
        return removeNodes(inputGraph, externalNodeRemappedIds::containsKey);
    }

    /**
     * Builds a graph without the nodes whose index matches {@code predicate}.
     *
     * @param computeGraph the source graph
     * @param predicate tested against each node index; a node is dropped when it returns true
     * @return the re-indexed graph made of the remaining nodes
     * @throws RuntimeException from {@link #reindexNodes} if a feature name still refers to a dropped node
     */
    static ComputeGraph removeNodes(ComputeGraph computeGraph, Predicate<Integer> predicate) {
        List<AnyNode> allNodes = computeGraph.getNodes();
        List<AnyNode> nodesToKeep = IntStream.range(0, allNodes.size()).boxed()
            .filter(predicate.negate())
            .map(id -> allNodes.get(id))
            .collect(Collectors.toList());
        return reindexNodes(nodesToKeep, computeGraph.getFeatureNames());
    }

    /**
     * Adds {@code nodes} to a fresh builder, in iteration order, and rewrites node dependencies and
     * feature names from the old IDs (as reported by {@code PegasusUtils.getNodeId}) to the IDs
     * returned by the builder.
     *
     * @param nodes the nodes to keep; their dependencies are remapped in place
     * @param featureNames feature name to old node ID
     * @return the graph produced by the builder
     * @throws RuntimeException if an ID passed to the remapping is not the old ID of one of {@code nodes}
     */
    static ComputeGraph reindexNodes(Collection<AnyNode> nodes, IntegerMap featureNames) {
        Map<Integer, Integer> indexRemapping = new HashMap<>();
        ComputeGraphBuilder builder = new ComputeGraphBuilder();
        nodes.forEach(node -> {
            // Read the old ID before handing the node to the builder.
            int oldId = PegasusUtils.getNodeId(node);
            int newId = builder.addNode(node);
            indexRemapping.put(oldId, newId);
        });
        Function<Integer, Integer> remap = oldId -> {
            Integer newId = indexRemapping.get(oldId);
            if (newId == null) {
                throw new RuntimeException("Node " + oldId + " not found in subgraph.");
            }
            return newId;
        };
        nodes.forEach(node -> Dependencies.remapDependencies(node, remap));
        featureNames.forEach((featureName, nodeId) -> builder.addFeatureName(featureName, remap.apply(nodeId)));
        return builder.build();
    }

    /**
     * Indexes, for every node ID that is depended upon, the IDs of the nodes depending on it.
     *
     * @param graph the graph to index
     * @return dependency node ID to the set of dependent node IDs; IDs nobody depends on are absent
     */
    private static Map<Integer, Set<Integer>> getReverseDependencyIndex(ComputeGraph graph) {
        List<AnyNode> nodes = graph.getNodes();
        Map<Integer, Set<Integer>> reverseDependencies = new HashMap<>();
        for (int nodeId = 0; nodeId < nodes.size(); ++nodeId) {
            for (int dependencyNodeId : new Dependencies().getDependencies(nodes.get(nodeId))) {
                reverseDependencies.computeIfAbsent(dependencyNodeId, x -> new HashSet<>()).add(nodeId);
            }
        }
        return reverseDependencies;
    }

    /**
     * Indexes the feature-name table by node ID; several feature names may point at the same node.
     *
     * @param graph the graph whose feature names are indexed
     * @return node ID to the set of feature names pointing at it
     */
    static Map<Integer, Set<String>> getReverseFeatureDependencyIndex(ComputeGraph graph) {
        Map<Integer, Set<String>> reverseDependencies = new HashMap<>();
        graph.getFeatureNames().forEach((featureName, nodeId) ->
            reverseDependencies.computeIfAbsent(nodeId, x -> new HashSet<>(1)).add(featureName));
        return reverseDependencies;
    }

    /**
     * Checks that the node at index {@code i} carries ID {@code i}.
     *
     * @param graph the graph to check
     * @throws RuntimeException if any node's ID differs from its index
     */
    static void ensureNodeIdsAreSequential(ComputeGraph graph) {
        List<AnyNode> nodes = graph.getNodes();
        for (int i = 0; i < nodes.size(); ++i) {
            if (PegasusUtils.getNodeId(nodes.get(i)) != i) {
                throw new RuntimeException("Graph nodes must be ID'd sequentially from 0 to N-1 where N is the number of nodes.");
            }
        }
    }

    /**
     * Checks that every dependency ID lies within {@code [0, nodeCount - 1]}.
     *
     * @param graph the graph to check
     * @throws RuntimeException naming the out-of-range IDs and the offending node
     */
    static void ensureNodeReferencesExist(ComputeGraph graph) {
        int maxValidId = graph.getNodes().size() - 1;
        graph.getNodes().forEach(anyNode -> {
            Set<Integer> dependencies = new Dependencies().getDependencies(anyNode);
            List<Integer> missingDependencies = dependencies.stream()
                .filter(id -> id < 0 || id > maxValidId)
                .collect(Collectors.toList());
            if (!missingDependencies.isEmpty()) {
                throw new RuntimeException("Encountered missing dependencies " + missingDependencies + " for node " + anyNode + ". Graph = " + graph);
            }
        });
    }

    /**
     * Checks that no External, Aggregation, DataSource, Lookup or Transformation node reports
     * {@code hasConcreteKey()}.
     *
     * @param graph the graph to check
     * @throws RuntimeException for the first node found with a concrete key
     */
    static void ensureNoConcreteKeys(ComputeGraph graph) {
        graph.getNodes().forEach(node -> {
            boolean hasConcreteKey = node.isExternal() && node.getExternal().hasConcreteKey()
                || node.isAggregation() && node.getAggregation().hasConcreteKey()
                || node.isDataSource() && node.getDataSource().hasConcreteKey()
                || node.isLookup() && node.getLookup().hasConcreteKey()
                || node.isTransformation() && node.getTransformation().hasConcreteKey();
            if (hasConcreteKey) {
                throw new RuntimeException("A concrete key has already been set for the node " + node);
            }
        });
    }

    /**
     * Checks that no External node names a feature present in the graph's own feature-name table.
     *
     * @param graph the graph to check
     * @throws RuntimeException for the first such External node
     */
    static void ensureNoExternalReferencesToSelf(ComputeGraph graph) {
        graph.getNodes().stream().filter(AnyNode::isExternal).forEach(node -> {
            String featureName = node.getExternal().getName();
            if (graph.getFeatureNames().containsKey(featureName)) {
                throw new RuntimeException("Graph contains External node " + node + " but also contains feature " + featureName + " in its feature name table: " + graph.getFeatureNames() + ". Graph = " + graph);
            }
        });
    }

    /**
     * Checks the graph for dependency cycles with an iterative depth-first traversal.
     *
     * <p>Dependency IDs are used as list indexes without a range check here;
     * {@link #validate(ComputeGraph)} runs {@link #ensureNodeReferencesExist(ComputeGraph)} first.
     *
     * @param graph the graph to check
     * @throws RuntimeException if a node is reached again while its dependencies are still unfinished
     */
    static void ensureNoDependencyCycles(ComputeGraph graph) {
        List<AnyNode> nodes = graph.getNodes();
        Deque<Integer> stack = IntStream.range(0, nodes.size()).boxed()
            .collect(Collectors.toCollection(ArrayDeque::new));
        List<VisitedState> visitedState =
            new ArrayList<>(Collections.nCopies(nodes.size(), VisitedState.NOT_VISITED));
        while (!stack.isEmpty()) {
            int nodeId = stack.pop();
            if (visitedState.get(nodeId) == VisitedState.VISITED) {
                continue;
            }
            Set<Integer> dependencies = new Dependencies().getDependencies(nodes.get(nodeId));
            List<Integer> unfinishedDependencies = dependencies.stream()
                .filter(i -> visitedState.get(i) != VisitedState.VISITED)
                .collect(Collectors.toList());
            if (unfinishedDependencies.isEmpty()) {
                visitedState.set(nodeId, VisitedState.VISITED);
                continue;
            }
            if (visitedState.get(nodeId) == VisitedState.IN_PROGRESS) {
                throw new RuntimeException("Dependency cycle involving node " + nodeId);
            }
            // Come back to this node once the dependencies pushed on top of it are done.
            stack.push(nodeId);
            unfinishedDependencies.forEach(stack::push);
            visitedState.set(nodeId, VisitedState.IN_PROGRESS);
        }
    }

    /** Traversal state of a node during the depth-first walks in this class. */
    private enum VisitedState {
        /** Not popped from the stack yet. */
        NOT_VISITED,
        /** Popped once and pushed back, waiting for its dependencies to finish. */
        IN_PROGRESS,
        /** Fully processed. */
        VISITED
    }
}

