/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.feathr.compute;

import com.google.common.collect.Sets;
import com.linkedin.data.template.IntegerArray;
import com.linkedin.feathr.compute.Aggregation;
import com.linkedin.feathr.compute.AnyNode;
import com.linkedin.feathr.compute.ConcreteKey;
import com.linkedin.feathr.compute.DataSource;
import com.linkedin.feathr.compute.External;
import com.linkedin.feathr.compute.Lookup;
import com.linkedin.feathr.compute.NodeReference;
import com.linkedin.feathr.compute.PegasusUtils;
import com.linkedin.feathr.compute.Transformation;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Computes, and rewrites in place, the node ids that an {@link AnyNode} depends on.
 *
 * <p>Two sources of dependencies are considered: the ids listed in the node's concrete key
 * (when {@link PegasusUtils#hasConcreteKey} reports one) and the ids referenced by the
 * node's own inputs, which vary with the kind of node held by the {@link AnyNode}.
 */
public class Dependencies {

    /**
     * Returns every node id the given node depends on.
     *
     * @param anyNode the node to inspect
     * @return the union (built with Guava's {@code Sets.union}) of the node's key
     *         dependencies and its input dependencies
     * @throws RuntimeException if {@code anyNode} holds none of the handled node kinds
     */
    public Set<Integer> getDependencies(AnyNode anyNode) {
        Set<Integer> fromKey = this.getKeyDependencies(anyNode);
        Set<Integer> fromInputs = getNodeDependencies(anyNode);
        return Sets.union(fromKey, fromInputs);
    }

    /**
     * Returns a copy of the ids in the node's concrete key, or an empty set when the node
     * has no concrete key.
     */
    private Set<Integer> getKeyDependencies(AnyNode anyNode) {
        if (!PegasusUtils.hasConcreteKey(anyNode)) {
            return Collections.emptySet();
        }
        return new HashSet<>(PegasusUtils.getConcreteKey(anyNode).getKey());
    }

    /**
     * Dispatches on the kind of node held by {@code anyNode} and returns the ids that
     * node's inputs refer to.
     *
     * @throws RuntimeException if none of the known node kinds matches
     */
    private static Set<Integer> getNodeDependencies(AnyNode anyNode) {
        final Set<Integer> result;
        if (anyNode.isAggregation()) {
            result = getNodeDependencies(anyNode.getAggregation());
        } else if (anyNode.isDataSource()) {
            result = getNodeDependencies(anyNode.getDataSource());
        } else if (anyNode.isLookup()) {
            result = getNodeDependencies(anyNode.getLookup());
        } else if (anyNode.isTransformation()) {
            result = getNodeDependencies(anyNode.getTransformation());
        } else if (anyNode.isExternal()) {
            result = getNodeDependencies(anyNode.getExternal());
        } else {
            throw new RuntimeException("Unhandled kind of AnyNode: " + anyNode);
        }
        return result;
    }

    /** An aggregation depends on exactly one node: the one its input refers to. */
    private static Set<Integer> getNodeDependencies(Aggregation node) {
        NodeReference input = node.getInput();
        return Collections.singleton(input.getId());
    }

    /** A transformation depends on every node referenced by its inputs. */
    private static Set<Integer> getNodeDependencies(Transformation node) {
        Set<Integer> inputIds = new HashSet<>();
        for (NodeReference input : node.getInputs()) {
            inputIds.add(input.getId());
        }
        return inputIds;
    }

    /**
     * A lookup depends on each lookup-key entry that is a node reference, plus the lookup
     * node itself. Key entries that are not node references contribute nothing.
     */
    private static Set<Integer> getNodeDependencies(Lookup node) {
        Set<Integer> collected = new HashSet<>();
        for (Lookup.LookupKey keyPart : node.getLookupKey()) {
            if (keyPart.isNodeReference()) {
                collected.add(keyPart.getNodeReference().getId());
            }
        }
        collected.add(node.getLookupNode());
        return collected;
    }

    /** Data source nodes contribute no input dependencies. */
    private static Set<Integer> getNodeDependencies(DataSource node) {
        return Collections.emptySet();
    }

    /** External nodes contribute no input dependencies. */
    private static Set<Integer> getNodeDependencies(External node) {
        return Collections.emptySet();
    }

    /**
     * Rewrites, in place, every node id the given node depends on: first the ids in its
     * concrete key (if any), then the ids referenced by its inputs.
     *
     * @param anyNode   the node whose references are rewritten
     * @param idMapping maps each currently referenced id to its replacement
     * @throws RuntimeException if {@code anyNode} holds none of the handled node kinds
     */
    static void remapDependencies(AnyNode anyNode, Function<Integer, Integer> idMapping) {
        remapKeyDependencies(anyNode, idMapping);
        remapNodeDependencies(anyNode, idMapping);
    }

    /**
     * Replaces the node's concrete key with a new array holding the mapped ids, keeping
     * their order. Does nothing when the node has no concrete key.
     */
    private static void remapKeyDependencies(AnyNode anyNode, Function<Integer, Integer> idMapping) {
        if (!PegasusUtils.hasConcreteKey(anyNode)) {
            return;
        }
        ConcreteKey key = PegasusUtils.getConcreteKey(anyNode);
        IntegerArray remappedIds = new IntegerArray();
        for (Integer keyId : key.getKey()) {
            remappedIds.add(idMapping.apply(keyId));
        }
        key.setKey(remappedIds);
    }

    /**
     * Dispatches on the kind of node held by {@code anyNode} and rewrites the ids its
     * inputs refer to. Data source and external nodes are left untouched.
     *
     * @throws RuntimeException if none of the known node kinds matches
     */
    private static void remapNodeDependencies(AnyNode anyNode, Function<Integer, Integer> idMapping) {
        if (anyNode.isAggregation()) {
            remapNodeDependencies(anyNode.getAggregation(), idMapping);
            return;
        }
        if (anyNode.isDataSource()) {
            // Nothing to rewrite for a data source.
            return;
        }
        if (anyNode.isLookup()) {
            remapNodeDependencies(anyNode.getLookup(), idMapping);
            return;
        }
        if (anyNode.isTransformation()) {
            remapNodeDependencies(anyNode.getTransformation(), idMapping);
            return;
        }
        if (anyNode.isExternal()) {
            // Nothing to rewrite for an external node.
            return;
        }
        throw new RuntimeException("Unhandled kind of AnyNode: " + anyNode);
    }

    /** Rewrites the id of the aggregation's single input reference. */
    private static void remapNodeDependencies(Aggregation node, Function<Integer, Integer> idMapping) {
        remapReference(node.getInput(), idMapping);
    }

    /** Rewrites the id of each of the transformation's input references. */
    private static void remapNodeDependencies(Transformation node, Function<Integer, Integer> idMapping) {
        for (NodeReference input : node.getInputs()) {
            remapReference(input, idMapping);
        }
    }

    /**
     * Rewrites the lookup node id first, then the id of every lookup-key entry that is a
     * node reference. Other key entries are left as they are.
     */
    private static void remapNodeDependencies(Lookup node, Function<Integer, Integer> idMapping) {
        int previousLookupId = node.getLookupNode();
        int remappedLookupId = idMapping.apply(previousLookupId);
        node.setLookupNode(remappedLookupId);

        for (Lookup.LookupKey keyPart : node.getLookupKey()) {
            if (!keyPart.isNodeReference()) {
                continue;
            }
            remapReference(keyPart.getNodeReference(), idMapping);
        }
    }

    /** Reads the id held by {@code reference}, maps it, and stores the mapped id back. */
    private static void remapReference(NodeReference reference, Function<Integer, Integer> idMapping) {
        int previousId = reference.getId();
        int remappedId = idMapping.apply(previousId);
        reference.setId(remappedId);
    }
}

