/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.ast.graph.internal;

import static java.util.Collections.emptySet;
import static java.util.Collections.reverse;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Stream.concat;

import static com.google.common.collect.Lists.newArrayList;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.api.util.LazyValue;
import org.mule.runtime.ast.api.ArtifactAst;
import org.mule.runtime.ast.api.ComponentAst;
import org.mule.runtime.ast.graph.api.ArtifactAstDependencyGraph;
import org.mule.runtime.ast.graph.api.ComponentAstDependency;
import org.mule.runtime.ast.graph.internal.cycle.GraphCycleRemover;
import org.mule.runtime.ast.internal.FilteredArtifactAst;

import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

import org.jgrapht.Graph;
import org.jgrapht.traverse.TopologicalOrderIterator;
import org.slf4j.Logger;


/**
 * Default {@link ArtifactAstDependencyGraph} implementation, backed by a JGraphT {@link Graph} whose vertices are the
 * components of the {@code source} artifact and whose edges go from a component to a component it depends on (the
 * target of an outgoing edge is treated as a dependency).
 * <p>
 * Transitive dependencies are memoized per vertex in a plain {@link HashMap}, and the "dependencies first" ordering used
 * by {@link #dependencyComparator()} is computed lazily on first use.
 * NOTE(review): the memo map is not synchronized; confirm instances are not shared across threads.
 */
public class DefaultArtifactAstDependencyGraph implements ArtifactAstDependencyGraph {

  private static final Logger LOGGER = getLogger(DefaultArtifactAstDependencyGraph.class);

  private final ArtifactAst source;
  private final Graph<ComponentAst, ComponentAstEdge> graph;
  // memoized results of findVertexDependencies, keyed by vertex
  private final Map<ComponentAst, Set<ComponentAst>> transitiveDependenciesCache = new HashMap<>();
  private final Set<ComponentAstDependency> missingDependencies;

  // position of each component in the "dependencies first" order; computed once, on first use
  private final LazyValue<Map<ComponentAst, Integer>> dependenciesFirstIndex;

  /**
   * @param source              the artifact whose components are the vertices of {@code graph}
   * @param graph               the dependency graph; an outgoing edge points from a component to one of its dependencies
   * @param missingDependencies dependencies that could not be resolved; returned as-is by
   *                            {@link #getMissingDependencies()}
   */
  public DefaultArtifactAstDependencyGraph(ArtifactAst source, Graph<ComponentAst, ComponentAstEdge> graph,
                                           Set<ComponentAstDependency> missingDependencies) {
    this.source = source;
    this.graph = graph;
    this.missingDependencies = missingDependencies;
    this.dependenciesFirstIndex = new LazyValue<>(() -> positionsOf(computeDependenciesFirstList()));
  }

  /**
   * Computes all components of the (cycle-free) graph ordered so that dependencies come before their dependants.
   */
  private List<ComponentAst> computeDependenciesFirstList() {
    List<ComponentAst> topLevelComponents = source.topLevelComponentsStream().collect(toList());
    // The cycle remover receives a comparator ordering vertices by their declaration position among the top-level
    // components (presumably used to pick deterministically which edges to drop - TODO confirm in GraphCycleRemover).
    GraphCycleRemover<ComponentAst, ComponentAstEdge> graphCycleRemover =
        new GraphCycleRemover<>(graph, byPosition(positionsOf(topLevelComponents)));

    List<ComponentAst> componentAsts = newArrayList(new TopologicalOrderIterator<>(graphCycleRemover.removeCycles()));
    // edges point from a component to its dependencies, so the reverse topological order lists dependencies first
    reverse(componentAsts);
    return componentAsts;
  }

  /**
   * Maps each element of {@code list} to the index of its first occurrence, i.e. the value {@link List#indexOf(Object)}
   * would return, so lookups are O(1) instead of a linear scan per comparison.
   */
  private static <T> Map<T, Integer> positionsOf(List<T> list) {
    Map<T, Integer> positions = new HashMap<>();
    int index = 0;
    for (T element : list) {
      // keep the first occurrence, as List#indexOf does
      positions.putIfAbsent(element, index++);
    }
    return positions;
  }

  /**
   * @return the position of {@code element} in {@code positions}, or {@code -1} when absent (as {@link List#indexOf})
   */
  private static int positionOf(Map<?, Integer> positions, Object element) {
    return positions.getOrDefault(element, -1);
  }

  /**
   * Builds a comparator ordering elements by their position; elements without a position sort first.
   */
  private static <T> Comparator<T> byPosition(Map<T, Integer> positions) {
    return (o1, o2) -> Integer.compare(positionOf(positions, o1), positionOf(positions, o2));
  }

  /**
   * Builds an artifact containing {@code vertex} and all of its transitive dependencies.
   *
   * @param vertex the component whose dependencies are required
   * @return a {@link FilteredArtifactAst} over the source artifact, restricted to the required components
   */
  @Override
  public ArtifactAst minimalArtifactFor(ComponentAst vertex) {
    Set<ComponentAst> requiredComponents = findVertexDependenciesWithVertex(vertex);

    logResolvedMinimal(vertex, requiredComponents);
    return new FilteredArtifactAst(source, requiredComponents::contains);
  }

  /**
   * Builds an artifact containing every component of the source artifact (searched recursively) that matches
   * {@code vertexPredicate}, together with all of their transitive dependencies.
   *
   * @param vertexPredicate selects the components whose dependencies are required
   * @return a {@link FilteredArtifactAst} over the source artifact, restricted to the required components
   */
  @Override
  public ArtifactAst minimalArtifactFor(Predicate<ComponentAst> vertexPredicate) {
    Set<ComponentAst> requiredComponents = source.recursiveStream()
        .filter(vertexPredicate)
        .flatMap(vertex -> findVertexDependenciesWithVertex(vertex).stream())
        .collect(toSet());

    logResolvedMinimal(vertexPredicate, requiredComponents);
    return new FilteredArtifactAst(source, requiredComponents::contains);
  }

  /**
   * Finds the first component (searched recursively) whose component id equals {@code componentName} and returns it
   * together with its transitive dependencies.
   *
   * @param componentName the component id to look for
   * @return a new set with the component and its dependencies, or an empty (immutable) set if no component matches
   */
  @Override
  public Set<ComponentAst> getRequiredComponents(String componentName) {
    Set<ComponentAst> requiredComponents = source.recursiveStream()
        .filter(isComponentName(componentName))
        .findFirst()
        // findVertexDependenciesWithVertex already returns a fresh set, no need to copy it again
        .map(this::findVertexDependenciesWithVertex)
        .orElse(emptySet());

    // log the name itself; a lambda's toString() carries no useful information
    logResolvedMinimal(componentName, requiredComponents);
    return requiredComponents;
  }

  /**
   * @return a predicate matching components whose (optional) component id equals {@code componentName}
   */
  private Predicate<ComponentAst> isComponentName(String componentName) {
    return x -> x.getComponentId().map(name -> name.equals(componentName)).orElse(false);
  }

  /**
   * @return the missing-dependencies set given at construction time (not copied)
   */
  @Override
  public Set<ComponentAstDependency> getMissingDependencies() {
    return missingDependencies;
  }

  /**
   * @return a new mutable set containing {@code vertex} and all of its transitive dependencies
   */
  private Set<ComponentAst> findVertexDependenciesWithVertex(ComponentAst vertex) {
    Set<ComponentAst> requiredComponents = new HashSet<>(findVertexDependencies(vertex));
    requiredComponents.add(vertex);
    return requiredComponents;
  }

  /**
   * Computes (and memoizes) the transitive dependencies of {@code vertex}, excluding the vertex itself unless reached
   * through an edge. The cached set must not be exposed to callers; {@link #findVertexDependenciesWithVertex} copies it.
   */
  private Set<ComponentAst> findVertexDependencies(ComponentAst vertex) {
    return transitiveDependenciesCache.computeIfAbsent(vertex, comp -> {
      LOGGER.trace("> Processing vertex '{}'...", comp);
      // collect every element the vertex depends on: each direct target plus that edge's transitive dependencies
      return graph.outgoingEdgesOf(comp).stream()
          .flatMap(outgoingEdge -> {
            Stream<ComponentAst> transitiveOutgoingDependenciesOf = outgoingEdge.transitiveOutgoingDependenciesOf(graph);

            if (LOGGER.isTraceEnabled()) {
              // materialize the stream so it can be both logged and returned
              final Set<ComponentAst> deps = transitiveOutgoingDependenciesOf.collect(toSet());
              LOGGER.trace("> Transitive outgoing deps:");
              deps.forEach(rq -> LOGGER.trace("    '{}'", rq));
              transitiveOutgoingDependenciesOf = deps.stream();
            }
            return concat(Stream.of(outgoingEdge.getTarget()), transitiveOutgoingDependenciesOf);
          })
          .collect(toSet());
    });
  }

  /**
   * Logs, at debug level, the components resolved for {@code vertexPredicate}.
   */
  private void logResolvedMinimal(Object vertexPredicate, Set<ComponentAst> requiredComponents) {
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Minimal artifact for '{}' contains:", vertexPredicate);
      requiredComponents.forEach(rq -> LOGGER.debug("    '{}'", rq));
    }
  }

  /**
   * Returns a comparator that orders components so that dependencies come before the components depending on them.
   * Components not present in the graph's ordering sort before all others. The ordering is computed on the first
   * comparison and reused afterwards.
   */
  @Override
  public Comparator<ComponentAst> dependencyComparator() {
    return (c1, c2) -> {
      Map<ComponentAst, Integer> positions = dependenciesFirstIndex.get();
      int result = Integer.compare(positionOf(positions, c1), positionOf(positions, c2));

      if (LOGGER.isTraceEnabled()) {
        if (result > 0) {
          LOGGER.trace("dependencyComparator('{}' > '{}')", c1, c2);
        } else if (result < 0) {
          LOGGER.trace("dependencyComparator('{}' < '{}')", c1, c2);
        } else {
          LOGGER.trace("dependencyComparator('{}' == '{}')", c1, c2);
        }
      }
      return result;
    };
  }
}
