package org.mule.weave.v2.parser.ast

import org.mule.weave.v2.grammar.MetadataAdditionOpId
import org.mule.weave.v2.parser.annotation.AstNodeAnnotation
import org.mule.weave.v2.parser.annotation.CustomInterpolationAnnotation
import org.mule.weave.v2.parser.annotation.InfixNotationFunctionCallAnnotation
import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast.conditional.DefaultNode
import org.mule.weave.v2.parser.ast.conditional.IfNode
import org.mule.weave.v2.parser.ast.conditional.UnlessNode
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.header.directives
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.InputDirective
import org.mule.weave.v2.parser.ast.header.directives.NamespaceDirective
import org.mule.weave.v2.parser.ast.header.directives.OutputDirective
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.patterns.DeconstructArrayPatternNode
import org.mule.weave.v2.parser.ast.patterns.PatternMatcherNode
import org.mule.weave.v2.parser.ast.selectors.NullSafeNode
import org.mule.weave.v2.parser.ast.selectors.NullUnSafeNode
import org.mule.weave.v2.parser.ast.structure.ArrayNode
import org.mule.weave.v2.parser.ast.structure.AttributesNode
import org.mule.weave.v2.parser.ast.structure.ConditionalNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.KeyNode
import org.mule.weave.v2.parser.ast.structure.KeyValuePairNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.NameValuePairNode
import org.mule.weave.v2.parser.ast.structure.NamespaceNode
import org.mule.weave.v2.parser.ast.structure.ObjectNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.scope.ScopesNavigator

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Helper object with functions that help query the AST.
  */
object AstNodeHelper {

  /** Name returned by [[functionRefName]] when the function value is not a variable reference. */
  val ANONYMOUS_FUNCTION: String = "AnonymousFunction"

  /**
    * Returns true if the node is a [[WeaveTypeNode]].
    *
    * @param nodeToExtract The node to check
    * @return True if it is a type node
    */
  def isWeaveTypeNode(nodeToExtract: AstNode): Boolean = {
    nodeToExtract.isInstanceOf[WeaveTypeNode]
  }

  /**
    * Returns true if the node is an [[ExpressionAstNode]].
    *
    * @param nodeToExtract The node to check
    * @return True if it is an expression node
    */
  def isExpressionNode(nodeToExtract: AstNode): Boolean = {
    nodeToExtract.isInstanceOf[ExpressionAstNode]
  }

  /**
    * Descends from `root` into the "last" child at every level until a node without children is reached.
    *
    * The last child is the one with the greatest end index. Ties are broken by the greatest start index and
    * then by the later position in the children list.
    *
    * @param root The node where to start
    * @return `root` if it has no children, otherwise the leaf reached by the descent described above
    */
  @scala.annotation.tailrec
  def lastNode(root: AstNode): AstNode = {
    val nodes = root.children()
    if (nodes.isEmpty) {
      root
    } else {
      //Two stable sorts instead of maxBy: maxBy returns the FIRST maximum, but among children with the same
      //end index we want the last one by start index and then by the original children order
      val astNode = nodes
        .sortBy(_.location().startPosition.index)
        .sortBy(_.location().endPosition.index)
        .last
      lastNode(astNode)
    }
  }

  /**
    * Returns a display name for the function referenced by `value`.
    *
    * For a [[VariableReferenceNode]], the result depends on how the reference resolves:
    *  - If the resolved declaration's parent is a [[FunctionDirectiveNode]], returns the directive's name.
    *  - If it resolves to anything else, returns the name of the resolved declaration.
    *  - If it cannot be resolved, returns the referenced variable's name.
    *
    * Any other node yields [[ANONYMOUS_FUNCTION]].
    *
    * @param value           The node used as a function
    * @param scopesNavigator The navigator used to resolve variable references
    * @return The function name
    */
  def functionRefName(value: AstNode, scopesNavigator: ScopesNavigator): String = {
    value match {
      case vrn: VariableReferenceNode => {
        scopesNavigator.resolveVariable(vrn.variable) match {
          case Some(reference) => {
            val referencedNode = reference.referencedNode
            reference.scope.astNavigator().parentOf(referencedNode) match {
              case Some(fdn: FunctionDirectiveNode) => fdn.variable.name
              case _                                => referencedNode.name
            }
          }
          case _ => vrn.variable.name
        }
      }
      case _ => ANONYMOUS_FUNCTION
    }
  }

  /**
    * Returns true if the function declared by the directive takes exactly `arity` parameters.
    * For an overloaded function, every overload must take exactly `arity` parameters.
    */
  private def isFunctionOfArity(fdn: FunctionDirectiveNode, arity: Int): Boolean = {
    fdn.literal match {
      case fn: FunctionNode            => fn.params.paramList.size == arity
      case ofn: OverloadedFunctionNode => ofn.functions.forall(_.params.paramList.size == arity)
      case _                           => false
    }
  }

  /**
    * Returns true if the function is a binary function, meaning it takes exactly two arguments.
    * For an overloaded function, every overload must take exactly two arguments.
    *
    * @param fdn The function definition to validate
    * @return True if it is a binary function definition
    */
  def isBinaryFunctionDirective(fdn: FunctionDirectiveNode): Boolean = {
    isFunctionOfArity(fdn, 2)
  }

  /**
    * Returns true if the function takes exactly one argument (every overload, if overloaded).
    *
    * @param fdn The function definition to validate
    * @return True if it is a unary function definition
    */
  def isUnaryFunctionDirective(fdn: FunctionDirectiveNode): Boolean = {
    isFunctionOfArity(fdn, 1)
  }

  /**
    * Returns true if the function takes exactly three arguments (every overload, if overloaded).
    *
    * @param fdn The function definition to validate
    * @return True if it is a ternary function definition
    */
  def isTernaryFunctionDirective(fdn: FunctionDirectiveNode): Boolean = {
    isFunctionOfArity(fdn, 3)
  }

  /**
    * Returns the Namespace Directives.
    *
    * The search depends on the kind of node:
    *  - [[DocumentNode]]: only the direct children of its header are considered.
    *  - [[ModuleNode]]: only its direct children are considered.
    *  - Any other node: all of its descendants are searched.
    *
    * @param document The root document
    * @return The collection of Namespace directives
    */
  def getNamespaceDirectives(document: AstNode): Seq[NamespaceDirective] = {
    document match {
      case dn: DocumentNode => collectDirectChildrenWith(dn.header, classOf[NamespaceDirective])
      case mn: ModuleNode   => collectDirectChildrenWith(mn, classOf[NamespaceDirective])
      case _                => collectChildrenWith(document, classOf[NamespaceDirective])
    }
  }

  /**
    * Selects all the AstNode values of the provided ObjectNode whose key name matches the provided one.
    * Only key/value pairs whose key is a [[KeyNode]] with a [[StringNode]] name are considered.
    * The result is computed lazily.
    *
    * @param name The field name
    * @param on   The ObjectNode where to search
    * @return list of values
    */
  def selectAllFieldValue(name: String, on: ObjectNode): Seq[AstNode] = {
    on.elements.toStream.flatMap({
      case KeyValuePairNode(key: KeyNode, value, _) => {
        key.keyName match {
          case str: StringNode if (str.value == name) => Some(value)
          case _                                      => None
        }
      }
      case _ => None
    })
  }

  /**
    * Selects the first AstNode value of the provided ObjectNode whose key name matches the provided one.
    *
    * @param name The field name
    * @param on   The ObjectNode where to search
    * @return value if any match
    */
  def selectFirstFieldValue(name: String, on: ObjectNode): Option[AstNode] = {
    selectAllFieldValue(name, on).headOption
  }

  /**
    * Returns true if any element of the array is a [[ConditionalNode]].
    *
    * @param array The array to inspect
    * @return True if it contains conditional elements
    */
  def containsConditionalElements(array: ArrayNode): Boolean = {
    array.elements.exists({
      case _: ConditionalNode => true
      case _                  => false
    })
  }

  /**
    * Returns true if the array has no conditional elements and every element is a literal value
    * (see [[isLiteralValue]]).
    *
    * @param on              The array to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the array is a literal
    */
  def isArrayLiteral(on: ArrayNode, scopesNavigator: ScopesNavigator): Boolean = {
    on.elements.forall({
      case _: ConditionalNode => false
      case element            => isLiteralValue(element, scopesNavigator)
    })
  }

  /**
    * Returns true if the value is considered constant. That is the case for:
    *  - a function literal;
    *  - a [[NameNode]] with a constant key name and no namespace;
    *  - a local variable reference whose declaration is the `tail` of a [[DeconstructArrayPatternNode]];
    *  - any literal value (see [[isLiteralValue]]).
    *
    * @param value           The node to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the value is constant
    */
  def isConstantType(value: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    value match {
      case _: FunctionNode => true
      case nn: NameNode => {
        //We can improve the namespace support we should validate if it reference an static namespace declaration or not
        isConstantType(nn.keyName, scopesNavigator) && nn.ns.isEmpty
      }
      case vrn: VariableReferenceNode => {
        scopesNavigator.resolveVariable(vrn.variable) match {
          case Some(ref) if (ref.isLocalReference) => {
            ref.scope.astNavigator().parentOf(ref.referencedNode) match {
              case Some(deconstructArrayPatternNode: DeconstructArrayPatternNode) => {
                deconstructArrayPatternNode.tail eq ref.referencedNode
              }
              case _ => false
            }
          }
          case _ => false
        }
      }
      case _ => isLiteralValue(value, scopesNavigator)
    }
  }

  /**
    * Returns true if `value` is a local variable reference that resolves to a parameter of `fn`.
    * The enclosing function is compared by identity.
    *
    * @param value           The node to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @param fn              The function whose parameters are considered
    * @return True if the value references one of `fn`'s parameters
    */
  def isFunctionParameter(value: AstNode, scopesNavigator: ScopesNavigator, fn: FunctionNode): Boolean = {
    value match {
      case vrn: VariableReferenceNode => {
        scopesNavigator.resolveVariable(vrn.variable) match {
          case Some(ref) if (ref.isLocalReference) => {
            scopesNavigator.astNavigator().parentOf(ref.referencedNode) match {
              case Some(fp: FunctionParameter) => {
                scopesNavigator
                  .astNavigator()
                  .parentWithType(fp, classOf[FunctionNode])
                  .exists((pfn) => pfn eq fn)
              }
              case _ => false
            }
          }
          case _ => false
        }
      }
      case _ => false
    }
  }

  /**
    * Returns true if the value is a literal, decided by the kind of node:
    *  - objects are checked with [[isObjectLiteral]];
    *  - arrays are checked with [[isArrayLiteral]];
    *  - any other node is checked with [[isLiteral]].
    *
    * @param value           The node to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the value is a literal
    */
  def isLiteralValue(value: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    value match {
      case on: ObjectNode => isObjectLiteral(on, scopesNavigator)
      case on: ArrayNode  => isArrayLiteral(on, scopesNavigator)
      case v              => isLiteral(v)
    }
  }

  /**
    * Returns true if the node is a [[NameNode]] with a literal key name and a literal (or absent) namespace.
    *
    * @param key             The node to inspect
    * @param scopesNavigator The navigator used to resolve the namespace prefix
    * @return True if the name is a literal
    */
  def isLiteralName(key: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    key match {
      case nn: NameNode => isLiteral(nn.keyName) && isLiteralNamespace(nn.ns, scopesNavigator)
      case _            => false
    }
  }

  /**
    * Returns true if the node is an [[AttributesNode]] whose attributes all meet three conditions:
    *  - each is a [[NameValuePairNode]] without a condition;
    *  - each has a literal name;
    *  - each has a literal value.
    *
    * @param attributesNode  The node to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the attributes are literal
    */
  def isLiteralAttributes(attributesNode: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    attributesNode match {
      case attributes: AttributesNode => {
        attributes.attrs.forall({
          case nvp: NameValuePairNode => {
            //A conditional attribute is never considered literal
            nvp.cond.isEmpty && isLiteralName(nvp.key, scopesNavigator) && isLiteralValue(nvp.value, scopesNavigator)
          }
          case _ => false
        })
      }
      case _ => false
    }
  }

  /**
    * Returns true if the node is a [[KeyNode]] with:
    *  - a literal key name;
    *  - literal (or absent) attributes;
    *  - a literal (or absent) namespace.
    *
    * @param key             The node to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the key is a literal
    */
  def isLiteralKey(key: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    key match {
      case kn: KeyNode => {
        val nsLiteral = isLiteralNamespace(kn.ns, scopesNavigator)
        isLiteral(kn.keyName) &&
          kn.attr.forall(attributes => isLiteralAttributes(attributes, scopesNavigator)) &&
          nsLiteral
      }
      case _ => false
    }
  }

  /**
    * Returns true if there is no namespace, or if it is a [[NamespaceNode]] whose prefix resolves to a
    * declaration located inside a [[NamespaceDirective]].
    *
    * @param ns              The optional namespace node
    * @param scopesNavigator The navigator used to resolve the prefix
    * @return True if the namespace is a literal
    */
  def isLiteralNamespace(ns: Option[AstNode], scopesNavigator: ScopesNavigator): Boolean = {
    ns.forall({
      case namespaceNode: NamespaceNode => {
        //If ref is child of Namespace Directive is literal otherwise is a dynamic expression
        scopesNavigator
          .resolveVariable(namespaceNode.prefix)
          .exists(reference => {
            reference.scope.astNavigator().parentWithType(reference.referencedNode, classOf[NamespaceDirective]).isDefined
          })
      }
      case _ => false
    })
  }

  /**
    * Returns true if every element of the object is a [[KeyValuePairNode]] that meets three conditions:
    *  - it has no condition;
    *  - it has a literal key;
    *  - it has a literal value.
    *
    * An object without elements is a literal.
    *
    * @param node            The object to inspect
    * @param scopesNavigator The navigator used to resolve references
    * @return True if the object is a literal
    */
  def isObjectLiteral(node: ObjectNode, scopesNavigator: ScopesNavigator): Boolean = {
    node.elements.forall({
      case kvp: KeyValuePairNode => {
        //A conditional field is never considered literal
        kvp.cond.isEmpty && isLiteralKey(kvp.key, scopesNavigator) && isLiteralValue(kvp.value, scopesNavigator)
      }
      case _ => false
    })
  }

  /**
    * Returns true if the node has no [[InjectedNodeAnnotation]].
    */
  def notInjectedNode(node: AstNode): Boolean = {
    node.annotation(classOf[InjectedNodeAnnotation]).isEmpty
  }

  /**
    * Annotates the node (in place) with an [[InjectedNodeAnnotation]].
    *
    * @return The same node instance
    */
  def markInjectedNode(node: AstNode): AstNode = {
    node.annotate(InjectedNodeAnnotation())
    node
  }

  /**
    * Returns true if the node has an [[InjectedNodeAnnotation]].
    */
  def isInjectedNode(node: AstNode): Boolean = {
    !notInjectedNode(node)
  }

  /**
    * Returns true if the node has an [[InfixNotationFunctionCallAnnotation]].
    */
  def isInfixFunctionCall(node: AstNode): Boolean = {
    node.annotation(classOf[InfixNotationFunctionCallAnnotation]).isDefined
  }

  /**
    * Returns true if the node has a [[CustomInterpolationAnnotation]].
    */
  def isCustomInterpolatedNode(node: AstNode): Boolean = {
    node.annotation(classOf[CustomInterpolationAnnotation]).isDefined
  }

  /**
    * Returns true if the node is a [[BinaryOpNode]] whose operator is `MetadataAdditionOpId`.
    */
  def isMetadataAdditionNode(node: AstNode): Boolean = {
    node match {
      case BinaryOpNode(MetadataAdditionOpId, _, _, _) => true
      case _                                           => false
    }
  }

  /**
    * Returns true if at least one tail position of `node` is a call to the function declared by
    * `functionIdentifier`. Recursion follows:
    *  - if/else and unless/else branches;
    *  - the right-hand side of `default`;
    *  - `using` and `do` bodies;
    *  - every pattern-matching branch.
    *
    * @param functionIdentifier The identifier of the function declaration
    * @param node               The node to check
    * @param scopeNavigator     The navigator used to resolve the called function
    * @return True if some tail position is a recursive call to that function
    */
  def isTailRecursiveCallExpression(functionIdentifier: NameIdentifier, node: AstNode, scopeNavigator: ScopesNavigator): Boolean = {
    node match {
      case fcn: FunctionCallNode => {
        isRecursiveFunctionCall(functionIdentifier, fcn, scopeNavigator)
      }
      case ifnode: IfNode => {
        isTailRecursiveCallExpression(functionIdentifier, ifnode.ifExpr, scopeNavigator) || isTailRecursiveCallExpression(functionIdentifier, ifnode.elseExpr, scopeNavigator)
      }
      case defaultNode: DefaultNode => {
        isTailRecursiveCallExpression(functionIdentifier, defaultNode.rhs, scopeNavigator)
      }
      case unless: UnlessNode => {
        isTailRecursiveCallExpression(functionIdentifier, unless.ifExpr, scopeNavigator) || isTailRecursiveCallExpression(functionIdentifier, unless.elseExpr, scopeNavigator)
      }
      case usingNode: UsingNode => {
        isTailRecursiveCallExpression(functionIdentifier, usingNode.expr, scopeNavigator)
      }
      case patternMatcher: PatternMatcherNode => {
        patternMatcher.patterns.patterns.exists((pattern) => {
          isTailRecursiveCallExpression(functionIdentifier, pattern.onMatch, scopeNavigator)
        })
      }
      case doNode: DoBlockNode => {
        isTailRecursiveCallExpression(functionIdentifier, doNode.body, scopeNavigator)
      }
      case _ => false
    }
  }

  /**
    * Returns true if the call's function is a variable reference that resolves (by identity) to `functionName`.
    *
    * @param functionName   The identifier of the function declaration
    * @param fcn            The call to inspect
    * @param scopeNavigator The navigator used to resolve the reference
    * @return True if the call targets that function
    */
  def isRecursiveFunctionCall(functionName: NameIdentifier, fcn: FunctionCallNode, scopeNavigator: ScopesNavigator): Boolean = {
    fcn.function match {
      case vrn: VariableReferenceNode => {
        //Make sure the variable reference points to the same function
        scopeNavigator.resolveVariable(vrn.variable).exists(_.referencedNode eq functionName)
      }
      case _ => false
    }
  }

  /**
    * Depth-first search, checking `root` itself first and then its children in order.
    *
    * @param root The node where to start
    * @param cond The condition to match
    * @return the first node that matches the given condition
    */
  def find(root: AstNode, cond: (AstNode) => Boolean): Option[AstNode] = {
    if (cond(root)) {
      Some(root)
    } else {
      root
        .children()
        .toStream
        .flatMap((node) => find(node, cond))
        .headOption
    }
  }

  /**
    * @return True if any element in the tree (including `root`) matches the given condition
    */
  def exists(root: AstNode, cond: (AstNode) => Boolean): Boolean = {
    find(root, cond).isDefined
  }

  /**
    * Returns the parent of an ast node. Nodes are compared by identity.
    *
    * @param rootNode The root node where to look for the parent
    * @param node     The child node
    * @return The parent if found
    */
  def parentOf(rootNode: AstNode, node: AstNode): Option[AstNode] = {
    val children = rootNode.children()
    if (children.exists(child => child eq node)) {
      Some(rootNode)
    } else {
      //Lazily search the subtrees in order and stop at the first one that contains the node
      children.toStream.flatMap(child => parentOf(child, node)).headOption
    }
  }

  /**
    * Returns the descendants of `root` (excluding `root`) that start and end on the given line.
    *
    * @param root       The root node where to search
    * @param lineNumber The line number
    * @return The list of nodes declared at this line
    */
  def elementsAtLine(root: AstNode, lineNumber: Int): Seq[AstNode] = {
    collectChildren(root, (node) => {
      node.location().endPosition.line == lineNumber && node.location().startPosition.line == lineNumber
    })
  }

  /**
    * Returns the first [[OutputDirective]] declared in the node.
    *
    * @param documentNode The document to look at
    * @return The output directive if present
    */
  def getOutputDirective(documentNode: DirectivesCapableNode): Option[OutputDirective] = {
    documentNode.directives.collectFirst({ case od: OutputDirective => od })
  }

  /**
    * Replaces the output directive of the document.
    *
    * Every existing [[OutputDirective]] is removed from the header and `outputDirective` is appended at the
    * end of the directive list. The header is mutated in place.
    *
    * @param documentNode    The document to update
    * @param outputDirective The new output directive
    * @return The same (updated) document node
    */
  def updateOutputDirective(documentNode: DocumentNode, outputDirective: OutputDirective): DocumentNode = {
    documentNode.header.directives = documentNode.header.directives.filterNot(_.isInstanceOf[OutputDirective]) :+ outputDirective
    documentNode
  }

  /**
    * Creates an [[OutputDirective]] for the given mime type. The remaining two constructor arguments are
    * set to `None`.
    *
    * @param mime The mime type
    * @return The new output directive
    */
  def createOutputDirective(mime: String): OutputDirective = {
    OutputDirective(directives.ContentType(mime), None, None)
  }

  /**
    * Returns the output mimeType of the document.
    *
    * @param documentNode The document to look at
    * @return The output mimeType if present
    */
  def getOutputMimeType(documentNode: DirectivesCapableNode): Option[String] = {
    getOutputDirective(documentNode).flatMap(od => od.mime.map(_.mime))
  }

  /**
    * Returns the output data format of the document.
    *
    * @param node The document to look at
    * @return The output data format id if present
    */
  def getOutputDataFormat(node: DirectivesCapableNode): Option[String] = {
    getOutputDirective(node).flatMap(od => od.dataFormat.map(_.id))
  }

  /**
    * Returns the list of NamespaceDirective declared.
    *
    * @param node The node where they are declared
    * @return The list of NamespaceDirective
    */
  def namespaceDirective(node: DirectivesCapableNode): Seq[NamespaceDirective] = {
    node.directives.collect({
      case e: NamespaceDirective => e
    })
  }

  /**
    * Returns the list of InputDirective declared in this DocumentNode.
    *
    * @param documentNode The document where to search
    * @return The list of InputDirective
    */
  def getInputs(documentNode: DocumentNode): Seq[InputDirective] = {
    documentNode.header.directives.collect({ case id: InputDirective => id })
  }

  /**
    * Collects the annotations of the given type from `astNode` and all its descendants, in pre-order.
    *
    * @param astNode        The root node
    * @param annotationType The annotation class to collect
    * @return The annotations found
    */
  def collectAllAnnotationsIn[T <: AstNodeAnnotation](astNode: AstNode, annotationType: Class[T]): Seq[T] = {
    astNode.annotation(annotationType).toList ++ astNode.children().flatMap(collectAllAnnotationsIn(_, annotationType))
  }

  /**
    * Returns whether or not this node represents a native call.
    *
    * Null-safe and null-unsafe selector wrappers are unwrapped first. The node is a native call when it is a
    * function call whose function resolves to a declaration named `native`.
    *
    * @param astNode The ast to query
    * @return True if this node is a native call
    */
  @scala.annotation.tailrec
  def isNativeCall(astNode: AstNode, scopesNavigator: ScopesNavigator): Boolean = {
    astNode match {
      case NullSafeNode(selector, _)   => isNativeCall(selector, scopesNavigator)
      case NullUnSafeNode(selector, _) => isNativeCall(selector, scopesNavigator)
      case fcn: FunctionCallNode       => isNativeFunctionRef(scopesNavigator, fcn.function)
      case _                           => false
    }
  }

  /**
    * Returns the identifier passed as first argument of a (possibly null-safe wrapped) function call.
    *
    * @param astNode The node to inspect
    * @return The identifier when the first argument is a string literal; None otherwise, including calls
    *         without arguments
    */
  @scala.annotation.tailrec
  def getNativeIdentifierCall(astNode: AstNode): Option[NameIdentifier] = {
    astNode match {
      case fcn: FunctionCallNode => {
        fcn.args.args.headOption match {
          case Some(StringNode(strValue, _)) => Some(NameIdentifier(strValue))
          case _                             => None
        }
      }
      case NullSafeNode(selector, _)   => getNativeIdentifierCall(selector)
      case NullUnSafeNode(selector, _) => getNativeIdentifierCall(selector)
      case _                           => None
    }
  }

  /**
    * Returns true if the (possibly null-safe wrapped) variable reference resolves to an identifier named `native`.
    */
  @scala.annotation.tailrec
  private def isNativeFunctionRef(scopesNavigator: ScopesNavigator, value: AstNode): Boolean = {
    value match {
      case NullSafeNode(selector, _)   => isNativeFunctionRef(scopesNavigator, selector)
      case NullUnSafeNode(selector, _) => isNativeFunctionRef(scopesNavigator, selector)
      case vrn: VariableReferenceNode => {
        scopesNavigator.resolveReference(vrn) match {
          case Some(name: NameIdentifier) => name.localName().name.equals("native")
          case _                          => false
        }
      }
      case _ => false
    }
  }

  /**
    * Returns whether the source range of `parent` encloses the source range of `child`.
    * Only the start and end indices of the locations are compared; no tree traversal is done.
    *
    * @param parent The container
    * @param child  The child
    * @return True if the child's location lies within the parent's location (bounds inclusive)
    */
  def containsChild(parent: AstNode, child: AstNode): Boolean = {
    val parentLocation = parent.location()
    val childLocation = child.location()
    parentLocation.startPosition.index <= childLocation.startPosition.index && parentLocation.endPosition.index >= childLocation.endPosition.index
  }

  /**
    * Returns true if any descendant is accepted by the condition. The parent itself is not tested.
    *
    * @param parent    The container
    * @param condition The condition
    * @return True if any descendant accepts
    */
  def existsChild(parent: AstNode, condition: AstNode => Boolean): Boolean = {
    parent.children().exists((node) => condition(node) || existsChild(node, condition))
  }

  /**
    * Collects all descendants that have the specified type. The node itself is never included.
    * Traversal uses an explicit stack, so the result is not guaranteed to be in source order.
    *
    * @param node      The node to traverse
    * @param classType The type of nodes to be collected
    * @return The matched list of nodes
    */
  def collectChildrenWith[T <: AstNode](node: AstNode, classType: Class[T]): Seq[T] = {
    def doCollect(rootNode: AstNode, classType: Class[T], collector: ArrayBuffer[T]): Unit = {
      val children = new mutable.Stack[AstNode]()
      children.push(rootNode)
      while (children.nonEmpty) {
        val childNodes = children.pop().children()
        var i = 0
        while (i < childNodes.length) {
          val child = childNodes(i)
          if (classType.isAssignableFrom(child.getClass)) {
            collector.+=(classType.cast(child))
          }
          children.push(child)
          i = i + 1
        }
      }
    }

    val collector = new ArrayBuffer[T]()
    doCollect(node, classType, collector)
    collector
  }

  /**
    * Collects all descendants accepted by the filter. The node itself is never tested.
    * Traversal uses an explicit stack, so the result is not guaranteed to be in source order.
    *
    * @param node   The node to traverse
    * @param filter The predicate a descendant must satisfy to be collected
    * @return The matched list of nodes
    */
  def collectChildren(node: AstNode, filter: (AstNode) => Boolean): Seq[AstNode] = {
    def doCollect(rootNode: AstNode, filter: (AstNode) => Boolean, collector: ArrayBuffer[AstNode]): Unit = {
      val children = new mutable.Stack[AstNode]()
      children.push(rootNode)
      while (children.nonEmpty) {
        val childNodes = children.pop().children()
        var i = 0
        while (i < childNodes.length) {
          val child = childNodes(i)
          if (filter(child)) {
            collector.+=(child)
          }
          children.push(child)
          i = i + 1
        }
      }
    }

    val collector = new ArrayBuffer[AstNode]()
    doCollect(node, filter, collector)
    collector
  }

  /**
    * Traverses the descendants of the node. The node itself is not passed to the callback.
    *
    * @param node   The nodes to be traversed
    * @param filter The callback that is going to be used. If it returns false, traversal of that branch stops
    */
  def traverseChildren(node: AstNode, filter: (AstNode) => Boolean): Unit = {

    def doTraverse(rootNode: AstNode, filter: (AstNode) => Boolean): Unit = {
      val children = new mutable.Stack[AstNode]()
      children.push(rootNode)
      while (children.nonEmpty) {
        val childNodes = children.pop().children()
        var i = 0
        while (i < childNodes.length) {
          val child = childNodes(i)
          if (filter(child)) {
            children.push(child)
          }
          i = i + 1
        }
      }
    }

    doTraverse(node, filter)
  }

  /**
    * Like [[traverseChildren]], but the node itself is passed to the callback first. Its children are only
    * traversed if the callback accepts the node.
    *
    * @param node   The node where traversal starts
    * @param filter The callback. If it returns false, traversal of that branch stops
    */
  def traverse(node: AstNode, filter: (AstNode) => Boolean): Unit = {
    if (filter(node)) {
      traverseChildren(node, filter)
    }
  }

  /**
    * Collects direct children that have the specified type, in children order.
    *
    * @param node      The node to traverse
    * @param classType The type of nodes to be collected
    * @return The matched list of nodes
    */
  def collectDirectChildrenWith[T <: AstNode](node: AstNode, classType: Class[T]): Seq[T] = {
    val collector = new ArrayBuffer[T]()
    val childNodes = node.children()
    var i = 0
    while (i < childNodes.length) {
      val child = childNodes(i)
      if (classType.isAssignableFrom(child.getClass)) {
        collector.+=(classType.cast(child))
      }
      i = i + 1
    }
    collector
  }

  /**
    * Returns true if the node is a simple literal node.
    *
    * @param node The node
    * @return true if it is a [[LiteralValueAstNode]]
    */
  def isLiteral(node: AstNode): Boolean = {
    node match {
      case _: LiteralValueAstNode => true
      case _                      => false
    }
  }
}
