/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.cfg.api;

import static org.mule.runtime.api.component.TypedComponentIdentifier.ComponentType.ERROR_HANDLER;
import static org.mule.runtime.api.message.error.matcher.ErrorTypeMatcherUtils.createErrorTypeMatcher;
import static org.mule.runtime.api.meta.model.parameter.ParameterGroupModel.DEFAULT_GROUP_NAME;
import static org.mule.sdk.api.stereotype.MuleStereotypes.FLOW;
import static org.mule.sdk.api.stereotype.MuleStereotypes.SUB_FLOW;

import static java.lang.String.format;
import static java.util.Collections.reverse;
import static java.util.Optional.empty;
import static java.util.Optional.of;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toList;

import org.mule.runtime.api.component.ComponentIdentifier;
import org.mule.runtime.api.util.LazyValue;
import org.mule.runtime.ast.api.ArtifactAst;
import org.mule.runtime.ast.api.ComponentAst;
import org.mule.runtime.ast.api.ComponentParameterAst;
import org.mule.runtime.cfg.internal.node.ChainedExecutionPathNodeBuilder;
import org.mule.runtime.cfg.internal.node.NullNode;
import org.mule.runtime.cfg.internal.node.ReferencedChainNode;
import org.mule.runtime.cfg.internal.node.RouterExecutionPathNodeBuilder;
import org.mule.runtime.cfg.internal.node.ScopeExecutionPathNodeBuilder;
import org.mule.runtime.cfg.internal.node.SimpleOperationNode;
import org.mule.runtime.cfg.internal.node.SourceNode;
import org.mule.runtime.cfg.internal.node.errorhandling.ErrorHandlerNode;
import org.mule.runtime.cfg.internal.node.errorhandling.ErrorHandlerWrapperNode;
import org.mule.runtime.cfg.internal.node.errorhandling.ErrorHandlingContext;
import org.mule.runtime.cfg.internal.node.errorhandling.ErrorHandlingExecutionPathNodeBuilder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Factory that creates the {@link ChainExecutionPathTree} corresponding to the given {@link ComponentAst}.
 * <p>
 * Generated trees are cached by component location. The caches are plain {@link HashMap}s with no synchronization, so an instance
 * of this factory must not be used concurrently from several threads.
 *
 * @since 1.1
 */
public class ChainExecutionPathTreeFactory {

  // TODO (W-12591953) Handle this using information from the extension model
  private static final String FLOW_REF = "flow-ref";
  private static final String BATCH_NAMESPACE = "batch";
  private static final ComponentIdentifier BATCH_ON_COMPLETE =
      ComponentIdentifier.builder().namespace(BATCH_NAMESPACE).name("on-complete").build();
  private static final String MUNIT_NAMESPACE = "munit";
  private static final ComponentIdentifier MUNIT_TEST_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("test").build();
  private static final ComponentIdentifier MUNIT_BEHAVIOR_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("behavior").build();
  private static final ComponentIdentifier MUNIT_EXECUTION_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("execution").build();
  private static final ComponentIdentifier MUNIT_VALIDATION_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("validation").build();
  private static final ComponentIdentifier MUNIT_BEFORE_SUITE_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("before-suite").build();
  private static final ComponentIdentifier MUNIT_BEFORE_TEST_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("before-test").build();
  private static final ComponentIdentifier MUNIT_AFTER_SUITE_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("after-suite").build();
  private static final ComponentIdentifier MUNIT_AFTER_TEST_IDENTIFIER =
      ComponentIdentifier.builder().namespace(MUNIT_NAMESPACE).name("after-test").build();
  // MUnit elements that are treated as plain nested chains (see recursiveGenerateFor)
  private static final Set<ComponentIdentifier> MUNIT_CHAINS = new HashSet<>();
  // Error expression used for an on-error handler that declares no "type"
  private static final String ANY_POSSIBLE_ERROR = "MULE:ANY";

  static {
    MUNIT_CHAINS.add(MUNIT_TEST_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_BEHAVIOR_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_EXECUTION_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_VALIDATION_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_BEFORE_SUITE_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_BEFORE_TEST_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_AFTER_SUITE_IDENTIFIER);
    MUNIT_CHAINS.add(MUNIT_AFTER_TEST_IDENTIFIER);
  }

  private final ArtifactAst application;
  // Trees returned by generateFor: sources included, built with a fresh error handling context. Keyed by location.
  private final Map<String, ChainExecutionPathTree> cachedTrees = new HashMap<>();
  // Trees for chains reached through a flow-ref: sources excluded. Keyed by location. Kept apart from cachedTrees because both
  // kinds of tree can be requested for the same top-level flow (same key) and they are not interchangeable.
  private final Map<String, ChainExecutionPathTree> cachedReferencedTrees = new HashMap<>();

  public ChainExecutionPathTreeFactory(ArtifactAst application) {
    this.application = application;
  }

  /**
   * Creates a {@link ChainExecutionPathTree} for the given {@link ComponentAst} representing a chain (for instance, a flow or a
   * sub-flow). Sources found in the chain are included in the tree.
   * <p>
   * Results are cached by the location of {@code chainComponentAst}; repeated calls for the same location return the cached tree.
   *
   * @param chainComponentAst the chain {@link ComponentAst} to generate the execution path tree for.
   * @return an execution path tree representing how the chain would be executed.
   * @throws IllegalStateException if the chain contains a component the factory cannot handle, or an empty nested chain.
   */
  public ChainExecutionPathTree generateFor(ComponentAst chainComponentAst) {
    return cachedTrees.computeIfAbsent(chainComponentAst.getLocation().getLocation(),
                                       location -> recursiveGenerateFor(chainComponentAst, new ErrorHandlingContext()));
  }

  /**
   * Creates (or retrieves from cache) the tree of a chain reached through a {@code flow-ref}, excluding sources
   * ({@code includeSources} is {@code false}). The tree is built with the error handling context of the first reference that
   * resolves it; later references to the same chain reuse that tree, which also stops recursive flow-refs from being expanded
   * forever.
   */
  private ChainExecutionPathTree generateForWithoutFlows(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    return cachedReferencedTrees.computeIfAbsent(chainComponentAst.getLocation().getLocation(),
                                                 location -> recursiveGenerateFor(chainComponentAst, errorHandlers, false));
  }

  private ChainExecutionPathTree recursiveGenerateFor(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    return recursiveGenerateFor(chainComponentAst, errorHandlers, true);
  }

  /**
   * Dispatches on the component type of {@code chainComponentAst} to build its node.
   *
   * @param includeSources whether {@code SOURCE} components become {@link SourceNode}s (otherwise they become the
   *                       {@link NullNode} instance).
   * @throws IllegalStateException if the component type (or, for {@code UNKNOWN}, the identifier) is not handled.
   */
  private ChainExecutionPathTree recursiveGenerateFor(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers,
                                                      boolean includeSources) {
    ComponentIdentifier identifier = chainComponentAst.getIdentifier();
    switch (chainComponentAst.getComponentType()) {
      case OPERATION:
        return getOperationNode(chainComponentAst, errorHandlers);
      case SOURCE:
        return includeSources ? new SourceNode(chainComponentAst) : NullNode.getInstance();
      case FLOW:
      case SUB_FLOW:
      case CHAIN:
      case ROUTE:
        return getNestedChainFrom(chainComponentAst, errorHandlers, includeSources);
      case ROUTER:
        return createRouter(chainComponentAst, errorHandlers);
      case SCOPE:
        // TODO (W-17317906) For an unexplained reason, munit:validation is set to be a scope breaking everything
        // if we treat it as one.
        return isMunitChain(identifier) ? getNestedChainFrom(chainComponentAst, errorHandlers, includeSources)
            : createScope(chainComponentAst, errorHandlers);
      case ERROR_HANDLER:
        return createCompleteErrorHandler(chainComponentAst, errorHandlers);
      case ON_ERROR:
        return processErrorHandler(chainComponentAst, errorHandlers);
      case UNKNOWN:
        if (identifier.equals(BATCH_ON_COMPLETE)) {
          return createScope(chainComponentAst, errorHandlers);
        } else if (identifier.getNamespace().equals(BATCH_NAMESPACE)) {
          return getNestedChainFrom(chainComponentAst, errorHandlers, includeSources);
        } else if (identifier.equals(MUNIT_TEST_IDENTIFIER)) {
          return munitTestChain(chainComponentAst, errorHandlers);
        } else if (isMunitChain(identifier)) {
          return getNestedChainFrom(chainComponentAst, errorHandlers, includeSources);
        }
        // Any other UNKNOWN identifier falls out of the switch into the exception below
    }
    throw new IllegalStateException(format("ComponentType of ComponentAST is not currently handled by the factory. Component: %s. Location: %s",
                                           chainComponentAst.getIdentifier().getName(),
                                           chainComponentAst.getLocation().getLocation()));
  }

  /**
   * Builds the node for an error handler (inline, or resolved through its {@code ref} parameter) and pushes its on-error
   * handlers into {@code errorHandlers}. The caller is responsible for popping them once they no longer apply (see
   * {@code getNestedChainFrom}).
   */
  private ErrorHandlerWrapperNode createCompleteErrorHandler(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    // We need to treat them in reverse order so the first one to appear later on is the first error handler and not the last one
    List<ComponentAst> entries = getErrorHandlerEntries(chainComponentAst);
    reverse(entries);
    List<ChainExecutionPathTree> handlers =
        entries.stream().map(child -> recursiveGenerateFor(child, errorHandlers)).collect(toList());
    errorHandlers.addHandlers(handlers);
    return new ErrorHandlerWrapperNode(chainComponentAst, handlers);
  }

  /**
   * Returns a mutable copy of the on-error entries of the given error handler. The {@code ref} parameter is followed when it
   * names an existing top-level element. Both pushing and popping handlers use this method, so the counts always agree.
   */
  private List<ComponentAst> getErrorHandlerEntries(ComponentAst errorHandler) {
    return new ArrayList<>(getReferencedErrorHandler(errorHandler).orElse(errorHandler).directChildren());
  }

  /**
   * Resolves the top-level element named by the {@code ref} parameter of the given error handler, if the parameter holds a
   * literal string and such an element exists.
   */
  private Optional<ComponentAst> getReferencedErrorHandler(ComponentAst referentErrorHandler) {
    ComponentParameterAst errorHandlerRef = referentErrorHandler.getParameter(DEFAULT_GROUP_NAME, "ref");
    if (errorHandlerRef != null && errorHandlerRef.getValue().getRight() instanceof String) {
      String errorHandlerName = (String) errorHandlerRef.getValue().getRight();
      return getTopLevelElementWithName(errorHandlerName);
    }
    return empty();
  }

  /**
   * Builds the node for a single {@code on-error-propagate} / {@code on-error-continue} handler.
   *
   * @throws IllegalStateException for any other on-error identifier. This is checked before doing any work, so the caller gets
   *                               this error rather than a failure from processing an unsupported handler.
   */
  private ErrorHandlerNode processErrorHandler(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    String handlerName = chainComponentAst.getIdentifier().getName();
    if (!"on-error-propagate".equals(handlerName) && !"on-error-continue".equals(handlerName)) {
      // TODO Think about on-error and other error-handling references (W-12392496)
      throw new IllegalStateException("Identifier type not supported: " + handlerName);
    }

    // Null-checked like the "ref" parameter in getReferencedErrorHandler; a missing or empty "type" matches any error
    ComponentParameterAst typeParameter = chainComponentAst.getParameter(DEFAULT_GROUP_NAME, "type");
    String errorExpression = typeParameter == null ? null : typeParameter.getRawValue();
    if (errorExpression == null) {
      errorExpression = ANY_POSSIBLE_ERROR;
    }

    return new ErrorHandlingExecutionPathNodeBuilder(chainComponentAst)
        .setChild(getNestedChainFrom(chainComponentAst, errorHandlers, false))
        .setErrorMatcher(createErrorTypeMatcher(application.getErrorTypeRepository(), errorExpression)).build();
  }

  /**
   * Builds a chained node from the direct children of {@code chainComponentAst}. If the last child is an error handler, it is
   * processed first and its handlers stay in {@code errorHandlers} only while the rest of the chain is processed.
   *
   * @throws IllegalStateException if the builder produces no node (empty chain).
   */
  private ChainExecutionPathTree getNestedChainFrom(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers,
                                                    boolean includeSources) {
    List<ComponentAst> children = new ArrayList<>(chainComponentAst.directChildren());
    ChainedExecutionPathNodeBuilder builder = new ChainedExecutionPathNodeBuilder(chainComponentAst);

    // Error handler should be processed first, so it is available while processing the rest of the chain.
    // getErrorHandler also removes it from children so it is not processed again as a regular step.
    Optional<ComponentAst> errorHandler = getErrorHandler(children);
    errorHandler.map(eh -> createCompleteErrorHandler(eh, errorHandlers)).ifPresent(builder::addOwnedErrorHandler);

    children.forEach(child -> builder.addChild(recursiveGenerateFor(child, errorHandlers, includeSources)));

    // The error handler has now to be removed because in following chains, nodes, etc... this error handler doesn't apply anymore.
    // Pop exactly as many handlers as createCompleteErrorHandler pushed: for a referenced error handler that is the number of
    // entries of the referenced element, not of the (childless) inline reference.
    errorHandler.ifPresent(eh -> errorHandlers.removeLastHandlers(getErrorHandlerEntries(eh).size()));

    ChainExecutionPathTree node = builder.build();
    if (node == null) {
      throw new IllegalStateException(format("empty route/scope/flow. Location: %s",
                                             chainComponentAst.getLocation().getLocation()));
    }
    return node;
  }

  private Optional<ComponentAst> getErrorHandler(List<ComponentAst> children) {
    // If there is an error handler, it would be the last component of the chain. We also remove it in that case from the rest of
    // the elements of the chain, to avoid re-processing it (which would be an error, since it's not part of the actual execution
    // chain, and it could also lead to references from this error handler to itself)
    if (children.isEmpty()) {
      return empty();
    }
    return children.get(children.size() - 1).getComponentType().equals(ERROR_HANDLER) ? of(children.remove(children.size() - 1))
        : empty();
  }

  /**
   * Returns the first top-level component whose component id equals {@code name}, if any.
   */
  private Optional<ComponentAst> getTopLevelElementWithName(String name) {
    return application.topLevelComponentsStream().filter(ast -> ast.getComponentId().map(id -> id.equals(name)).orElse(false))
        .findFirst();
  }

  /**
   * Builds the node for an operation. Three cases:
   * <ul>
   * <li>a {@code flow-ref} with a statically resolvable target becomes a {@link ReferencedChainNode} whose target tree is
   * generated lazily;</li>
   * <li>batch operations become scopes;</li>
   * <li>anything else (including unresolvable flow-refs) becomes a {@link SimpleOperationNode}.</li>
   * </ul>
   */
  private ChainExecutionPathTree getOperationNode(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    if (chainComponentAst.getIdentifier().getName().equals(FLOW_REF)) {
      Optional<ComponentParameterAst> parameter = chainComponentAst.getParameters().stream()
          .filter(param -> param.getModel().getAllowedStereotypes().stream()
              .anyMatch(stereotypeModel -> stereotypeModel.isAssignableTo(FLOW) || stereotypeModel.isAssignableTo(SUB_FLOW)))
          .findFirst();
      Optional<ComponentAst> flowComponent = parameter.flatMap(param -> getTopLevelElementWithName(param.getRawValue()));
      if (flowComponent.isPresent()) {
        // if it is not present, it could be because it is a dynamic reference, in which case we don't have a full flow
        // reference
        // Error handling context needs to be cloned because it is consumed lazily.
        ErrorHandlingContext errorHandlingContext = new ErrorHandlingContext(errorHandlers);
        return new ReferencedChainNode(chainComponentAst,
                                       new LazyValue<>(() -> generateForWithoutFlows(flowComponent.get(), errorHandlingContext)));
      }
    } else if (chainComponentAst.getIdentifier().getNamespace().equals(BATCH_NAMESPACE)) {
      return createScope(chainComponentAst, errorHandlers);
    }
    return new SimpleOperationNode(chainComponentAst, errorHandlers, application.getErrorTypeRepository());
  }

  /**
   * Builds a scope node. The nested chain gets a copy of the error handling context, so handlers it pushes do not leak out; the
   * scope node itself receives the outer context.
   */
  private ChainExecutionPathTree createScope(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    return new ScopeExecutionPathNodeBuilder(chainComponentAst)
        .withChild(getNestedChainFrom(chainComponentAst, new ErrorHandlingContext(errorHandlers),
                                      false))
        .withErrorHandlerContext(errorHandlers)
        .build();
  }

  /**
   * Builds a router node with one route per direct child. Each route gets its own copy of the error handling context.
   */
  private ChainExecutionPathTree createRouter(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    RouterExecutionPathNodeBuilder router = new RouterExecutionPathNodeBuilder(chainComponentAst);
    chainComponentAst.directChildren()
        .forEach(routeChild -> router
            .withRoute(recursiveGenerateFor(routeChild, new ErrorHandlingContext(errorHandlers))));
    return router.withErrorHandlerContext(errorHandlers).build();
  }

  /**
   * Builds an {@code munit:test} as a plain chain of its direct children, all sharing {@code errorHandlers}.
   */
  private ChainExecutionPathTree munitTestChain(ComponentAst chainComponentAst, ErrorHandlingContext errorHandlers) {
    ChainedExecutionPathNodeBuilder builder = new ChainedExecutionPathNodeBuilder(chainComponentAst);
    chainComponentAst.directChildrenStream().map(child -> recursiveGenerateFor(child, errorHandlers)).forEach(builder::addChild);
    return builder.build();
  }

  private static boolean isMunitChain(ComponentIdentifier identifier) {
    return MUNIT_CHAINS.contains(identifier);
  }

}
