/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.common;

import com.google.common.collect.Lists;
import com.linkedin.feathr.common.ErasedEntityTaggedFeature;
import com.linkedin.feathr.common.TaggedFeatureName;
import com.linkedin.feathr.common.TaggedFeatureUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Dependency graph between features, used to check feature definitions and to build evaluation plans.
 *
 * <p>Every node is a feature name. A feature is anchored when it is listed in {@code anchoredFeatures}. It is
 * derived when {@code dependencyFeatures} lists input features for it. A name that only ever appears as some
 * other feature's input is still declared, but has neither role. The constructor rejects dependency cycles.
 *
 * <p>A feature is <em>reachable</em> (resolvable) when it is anchored, or when it has at least one dependency
 * and all of its dependencies are reachable.
 */
public class FeatureDependencyGraph {
    /** All declared features keyed by name; unmodifiable once the constructor has finished. */
    private final Map<String, Node> _nodeMap;

    /**
     * Builds the graph, rejects dependency cycles, and computes reachability and depth for every node.
     *
     * @param dependencyFeatures for each derived feature name, the key-tagged input features it is computed from
     * @param anchoredFeatures names of the anchored features
     * @throws RuntimeException if the dependencies contain a cycle
     */
    public FeatureDependencyGraph(Map<String, Set<ErasedEntityTaggedFeature>> dependencyFeatures, Collection<String> anchoredFeatures) {
        Map<String, Node> nodes = new HashMap<>();
        anchoredFeatures.forEach(featureName -> {
            Node node = nodes.computeIfAbsent(featureName, Node::new);
            node._anchored = true;
        });
        dependencyFeatures.forEach((featureName, inputFeatures) -> {
            Node node = nodes.computeIfAbsent(featureName, Node::new);
            node._dependencies = inputFeatures.stream()
                    .map(input -> input.getFeatureName().toString())
                    .distinct()
                    .map(inputName -> nodes.computeIfAbsent(inputName, Node::new))
                    .collect(Collectors.toSet());
            node._inputs = inputFeatures;
        });

        Set<Node> visited = new HashSet<>();
        // Insertion-ordered so that the DFS stack, and therefore any reported cycle path, reads in visit order.
        Set<Node> ancestors = new LinkedHashSet<>();
        nodes.forEach((name, node) -> {
            checkForReachabilityAndCyclicDependencies(ancestors, visited, node);
            if (!ancestors.isEmpty()) {
                throw new RuntimeException(String.format("Assertion failed, ancestor not reachable. ancestors=%s visited=%s node=%s", ancestors, visited, node));
            }
        });
        this._nodeMap = Collections.unmodifiableMap(nodes);

        // Assign depths in dependency order, so a node's dependencies already have their depth when it is reached.
        Set<String> reachableFeatures = this._nodeMap.keySet().stream().filter(this::isReachable).collect(Collectors.toSet());
        for (String name : findSimpleDependencyOrdering(reachableFeatures)) {
            Node node = nodes.get(name);
            if (node._anchored || node._dependencies.isEmpty()) {
                // Planning treats anchored features as leaves and never schedules their inputs first. So they
                // start at depth 0, rather than at a value that depends on the iteration order of the set above.
                node._maxDepth = 0;
            } else {
                node._maxDepth = node._dependencies.stream().mapToInt(dependency -> dependency._maxDepth).max().getAsInt() + 1;
            }
        }
    }

    /**
     * Depth-first walk that detects cycles and sets {@code _reachable} on every node visited from {@code node}.
     *
     * @param ancestors nodes on the current DFS path (the caller passes an insertion-ordered set)
     * @param visited nodes whose walk has already completed
     * @param node node to visit
     * @throws RuntimeException if {@code node} is already on the current DFS path
     */
    private static void checkForReachabilityAndCyclicDependencies(Set<Node> ancestors, Set<Node> visited, Node node) {
        if (ancestors.contains(node)) {
            String path = ancestors.stream().map(ancestor -> ancestor._name).collect(Collectors.joining("->"));
            throw new RuntimeException("Detected dependency cycle: " + path + "->" + node._name);
        }
        if (!visited.add(node)) {
            return;
        }
        // A non-anchored node with no dependencies cannot be resolved.
        boolean allDependenciesAreReachable = !node._dependencies.isEmpty();
        ancestors.add(node);
        for (Node dependency : node._dependencies) {
            checkForReachabilityAndCyclicDependencies(ancestors, visited, dependency);
            allDependenciesAreReachable &= dependency._reachable;
        }
        ancestors.remove(node);
        node._reachable = allDependenciesAreReachable || node._anchored;
    }

    /**
     * Returns whether the feature appears in the graph at all, either anchored, derived, or as an input.
     *
     * @param feature feature name
     * @return {@code true} if a node exists for {@code feature}
     */
    public boolean isDeclared(String feature) {
        return this._nodeMap.containsKey(feature);
    }

    /**
     * Returns whether the feature is declared and resolvable.
     *
     * @param feature feature name
     * @return {@code true} if the feature is declared and reachable
     * @deprecated {@link #isReachableWithErrorMessage(String)} reports the same result together with a diagnostic message.
     */
    @Deprecated
    public boolean isReachable(String feature) {
        Node node = this._nodeMap.get(feature);
        return node != null && node._reachable;
    }

    /**
     * Returns whether the feature is declared and resolvable, plus an explanation when it is not.
     *
     * @param feature feature name
     * @return a pair whose first element is the reachability flag and whose second element is an error
     *     message (empty when the feature is reachable)
     */
    public Pair<Boolean, String> isReachableWithErrorMessage(String feature) {
        Node node = this._nodeMap.get(feature);
        if (node == null) {
            return new Pair<>(false, String.format("Trying to find feature %s in the dependency graph but didn't find any matched feature node. ", feature));
        }
        if (!node._reachable) {
            return new Pair<>(false, String.format("Trying to find dependencies of feature %s in the dependency graph but failed. Please check its dependencies.%n", feature));
        }
        return new Pair<>(true, "");
    }

    /**
     * Orders the given features and their transitive dependencies so that every feature follows its dependencies.
     *
     * @param features feature names to plan
     * @return distinct feature names in dependency order
     * @throws IllegalArgumentException if any feature involved is not reachable
     */
    private List<String> findSimpleDependencyOrdering(Collection<String> features) {
        LinkedHashSet<String> plan = new LinkedHashSet<>();
        for (String feature : features) {
            findDependencyOrderingRecursive(feature, plan);
        }
        return new ArrayList<>(plan);
    }

    /**
     * Appends {@code feature} to {@code plan} after its dependencies. Anchored features are leaves, so their
     * inputs are not scheduled.
     */
    private void findDependencyOrderingRecursive(String feature, LinkedHashSet<String> plan) {
        if (plan.contains(feature)) {
            // A feature is added only after all of its dependencies, so there is nothing left to schedule.
            // Re-walking it would repeat work, exponentially so on graphs with shared dependencies.
            return;
        }
        checkReachable(isReachableWithErrorMessage(feature), feature);
        Node node = this._nodeMap.get(feature);
        if (!node._anchored) {
            for (Node dependency : node._dependencies) {
                findDependencyOrderingRecursive(dependency._name, plan);
            }
        }
        plan.add(feature);
    }

    /**
     * Returns an evaluation order for the given features and everything they depend on.
     *
     * @param features requested feature names
     * @return distinct feature names, each placed after its dependencies
     * @throws IllegalArgumentException if any requested feature (or a dependency of one) cannot be resolved
     */
    public List<String> getPlan(Collection<String> features) {
        // Fail fast on the requested features before doing any planning work.
        for (String feature : features) {
            checkReachable(isReachableWithErrorMessage(feature), feature);
        }
        return findSimpleDependencyOrdering(features);
    }

    /** Throws if {@code reachableWithError} reports {@code feature} as unresolvable, including its diagnostic. */
    private void checkReachable(Pair<Boolean, String> reachableWithError, String feature) {
        if (!reachableWithError.fst) {
            String detail = reachableWithError.snd == null ? "" : reachableWithError.snd.trim();
            throw new IllegalArgumentException("Requirement failed. Feature " + feature + " can't be resolved in the dependency graph."
                    + (detail.isEmpty() ? "" : " " + detail));
        }
    }

    /**
     * Equivalent to {@link #getOrderedPlanForFeatureUrns(Collection)}.
     *
     * @param request requested key-tagged features
     * @return the ordered plan
     * @deprecated call {@link #getOrderedPlanForFeatureUrns(Collection)} directly.
     */
    @Deprecated
    public List<TaggedFeatureName> getOrderedPlanForRequest(Collection<TaggedFeatureName> request) {
        return getOrderedPlanForFeatureUrns(request);
    }

    /**
     * Returns the ordered plan for key-tagged features, including the dependencies those features need.
     *
     * <p>The distinct key tags of the request form a shared tag list. Each request is converted to integer-keyed
     * form relative to that list via {@code TaggedFeatureUtils}, planned, and then converted back to string tags.
     *
     * @param request requested key-tagged features
     * @return key-tagged features in dependency order
     */
    public List<TaggedFeatureName> getOrderedPlanForFeatureUrns(Collection<TaggedFeatureName> request) {
        List<String> keyList = request.stream().flatMap(x -> x.getKeyTag().stream()).distinct().collect(Collectors.toList());
        List<ErasedEntityTaggedFeature> erased = request.stream().map(x -> TaggedFeatureUtils.eraseStringTags(x, keyList)).collect(Collectors.toList());
        List<ErasedEntityTaggedFeature> plan = getOrderedPlanWithIntegerKeys(erased);
        return plan.stream().map(x -> TaggedFeatureUtils.getTaggedFeatureNameFromStringTags(x, keyList)).collect(Collectors.toList());
    }

    /**
     * Groups the ordered plan for {@code requestedFeatures} into stages by feature depth.
     *
     * @param requestedFeatures requested key-tagged features
     * @return one (unordered) set per distinct depth, in ascending depth order
     */
    public List<Set<TaggedFeatureName>> getComputationPipeline(Collection<TaggedFeatureName> requestedFeatures) {
        Map<Integer, Set<TaggedFeatureName>> featureStages = new TreeMap<>();
        for (TaggedFeatureName featureUrn : getOrderedPlanForFeatureUrns(requestedFeatures)) {
            Node currentFeatureNode = this._nodeMap.get(featureUrn.getFeatureName().toString());
            featureStages.computeIfAbsent(currentFeatureNode._maxDepth, depth -> new HashSet<>()).add(featureUrn);
        }
        return new ArrayList<>(featureStages.values());
    }

    /**
     * Plans integer-keyed features: orders the requested feature names, then propagates key bindings to inputs.
     *
     * <p>The ordering is walked in reverse (dependents before their dependencies). For each binding of a
     * feature, every input's binding is rewritten by mapping its integer positions through that binding.
     */
    private List<ErasedEntityTaggedFeature> getOrderedPlanWithIntegerKeys(Collection<ErasedEntityTaggedFeature> requested) {
        List<String> distinctFeatureNames = requested.stream().map(name -> name.getFeatureName().toString()).distinct().collect(Collectors.toList());
        List<String> ordering = findSimpleDependencyOrdering(distinctFeatureNames);

        // feature name -> every key binding it must be evaluated with
        Map<String, Set<List<Integer>>> scratch = new HashMap<>();
        requested.forEach(tfn -> scratch.computeIfAbsent(tfn.getFeatureName().toString(), x -> new HashSet<>()).add(tfn.getBinding()));
        Lists.reverse(ordering).forEach(featureName -> {
            if (!scratch.containsKey(featureName)) {
                throw new RuntimeException("dependency graph " + scratch.toString() + " doesn't has this feature " + featureName);
            }
            Set<ErasedEntityTaggedFeature> dependencies = this._nodeMap.get(featureName)._inputs;
            for (List<Integer> keyArguments : scratch.get(featureName)) {
                for (ErasedEntityTaggedFeature dependency : dependencies) {
                    List<Integer> dependencyKeyParameters = dependency.getBinding();
                    String dependencyFeatureName = dependency.getFeatureName().toString();
                    List<Integer> substitutedKeyParameters = dependencyKeyParameters.stream().map(keyArguments::get).collect(Collectors.toList());
                    scratch.computeIfAbsent(dependencyFeatureName, x -> new HashSet<>()).add(substitutedKeyParameters);
                }
            }
        });
        return ordering.stream()
                .flatMap(featureName -> scratch.get(featureName).stream().map(keyTag -> new ErasedEntityTaggedFeature(keyTag, featureName)))
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return this._nodeMap.values().toString();
    }

    /** Minimal two-element tuple. */
    private static class Pair<T, T1> {
        public T fst;
        public T1 snd;

        public Pair(T first, T1 second) {
            this.fst = first;
            this.snd = second;
        }
    }

    /** One feature in the graph; fields are filled in while the enclosing graph is constructed. */
    private static class Node
    implements Serializable {
        String _name;
        // Distinct nodes named by _inputs; empty for features with no listed inputs.
        Set<Node> _dependencies = Collections.emptySet();
        // Key-tagged inputs exactly as supplied for this feature.
        Set<ErasedEntityTaggedFeature> _inputs = Collections.emptySet();
        boolean _anchored = false;
        boolean _reachable = false;
        // -1 until assigned; only reachable nodes receive a depth.
        int _maxDepth = -1;

        Node(String name) {
            this._name = name;
        }

        @Override
        public String toString() {
            return this._name + " [reachable=" + this._reachable + ", anchored=" + this._anchored + ", depth=" + this._maxDepth + "] -> " + this._inputs;
        }
    }
}

