package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.grammar.BinaryOpIdentifier
import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.MutableAstNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationArgumentsNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.DirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.InputDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.selectors.NullSafeNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.scope.AstNavigator

import scala.collection.mutable.ArrayBuffer

/**
  * This phase is intended to only run in the context of a mule application. In that context all mappings
  * have an implicit input named vars, where a user may define any amount of separate values that they want.
  *
  * Recently we had issues due to materialization of vars: in Mule it is declared as a Java map whose values are
  * considered already materialized, which causes issues in some cases. For example, the following script will not
  * work in Mule if the underlying value of one of its elements is an iterator:
  * <pre>
  * %dw 2.6
  * ---
  * [1,2,3] map vars.a.b
  * </pre>
  * All iterations apart from the first one will result in a null value: materializing `vars` doesn't propagate
  * inside `a`, so the iterator is consumed the first time and never reset.
  *
  * Propagating materialization was attempted, but it caused performance issues on mule applications and had to be reverted.
  *
  * This phase is a second attempt at solving this issue, the way it works is that it will take all first level value
  * selections over `vars` and extract them into new variables. This fixes the issue because execution of a var directive
  * performs materialization of its value, so instead of relying on materialization through propagation from `vars` we
  * attempt materialization on the specific value returned after the selection.
  *
  * In the previous example the script will be transformed into the following:
  * <pre>
  * %dw 2.6
  *
  * &#064;dw::Core::Lazy()
  * var &#95;&#95;vars_a = vars.a
  * ---
  * [1,2,3] map &#95;&#95;vars_a.b
  * </pre>
  * Now `vars.a` is evaluated by the injected var directive and the result of that evaluation gets materialized, so
  * every iteration selects `b` from the same materialized value.
  *
  * @tparam T result type of the phase
  */
class SingleUseVariableMaterializationPass[T <: AstNodeResultAware[DocumentNode] with DummyScopeNavigatorAware] extends FullCompileOnlyPhase[DocumentNode, T] {

  /**
    * Runs the pass over `source` when `enableVarsMaterializationPass` is set; otherwise it does nothing.
    *
    * For every name in `forcedMaterializationVariables` that has an injected input directive in the header,
    * first level selections over that input are extracted into injected `var` directives.
    *
    * @param source  phase input providing the document AST and the dummy scopes navigator
    * @param context parsing context whose settings enable the pass and list the variables to process
    */
  override def run(source: T, context: ParsingContext): Unit = {
    if (context.settings.enableVarsMaterializationPass) {
      val rootNode: DocumentNode = source.astNode

      // Shared by every collector of this run so that injected variables never collide with each other
      // (e.g. `vars.a` and `vars.*a` both derive `__vars_a`) nor with variables already declared in the header.
      val reservedNames = scala.collection.mutable.HashSet[String]()
      rootNode.header.directives.foreach {
        case vd: VarDirective   => reservedNames += vd.variable.name
        case id: InputDirective => reservedNames += id.variable.name
        case _                  =>
      }

      context.settings.forcedMaterializationVariables.foreach(varName =>
        getImplicitVarDef(rootNode, varName).foreach(varsDef =>
          materializeVariable(source, rootNode, varName, varsDef, new DeclarationCollector(reservedNames))))
    }
  }

  /**
    * Extracts the first level selections over one implicit input into new `var` directives, rewrites every
    * matched selection to reference the new variable, and adds the new directives to the header.
    *
    * @param source    phase input providing the dummy scopes navigator
    * @param rootNode  the document whose header receives the new directives
    * @param varName   name of the implicit input, used to derive the new variable names
    * @param varsDef   identifier of the injected input directive being processed
    * @param collector accumulates the new directives until they are added to the header
    */
  private def materializeVariable(source: T, rootNode: DocumentNode, varName: String, varsDef: NameIdentifier, collector: DeclarationCollector): Unit = {
    val navigator = source.dummyScopesNavigator.rootScope.astNavigator()
    // For each reference to the input, look at most two levels up for the enclosing binary (selection) operation.
    val varsUsages: Seq[BinaryOpNode] = source.dummyScopesNavigator
      .resolveLocalReferencedBy(varsDef)
      .flatMap(r => navigator.parentWithTypeMaxLevel(r.referencedNode, classOf[BinaryOpNode], 2))
      .filter(validReplaceExpression)
    val expressionNodes = varsUsages.flatMap(buildReplacement(_, navigator))

    if (expressionNodes.nonEmpty) {
      // Group all usages by (selector, operator) so each distinct selection becomes a single declaration.
      val varsRoot: RootExpressionNode = expressionNodes.reduce(merge)
      val varsRef = varsRoot.variableReferenceNode
      val firstAnchor: DirectiveNode = navigator.parentWithTypeMaxLevel(varsDef, classOf[DirectiveNode], 3).orNull

      // The first declaration goes right after the input directive; each following one after the previous
      // declaration, so the fold carries the current insertion anchor.
      varsRoot.children.foldLeft(firstAnchor)((insertAfterNode, child) => {
        val newRef = VariableReferenceNode(NameIdentifier(varsRef.variable.name), varsRef.codeAnnotations)
        val newValueNode = NullSafeNode(BinaryOpNode(child.opId, newRef, child.expression.toSelectorNode))
        val newDirective = collector.addDeclaration(newValueNode, insertAfterNode, varName, child.expression.path())

        // Swap every original selection for a reference to the freshly declared variable.
        child.replacements.foreach(r =>
          navigator.parentOf(r) match {
            case Some(masn: MutableAstNode) =>
              val node = VariableReferenceNode(NameIdentifier(newDirective.variable.name), varsRef.codeAnnotations)
              node.annotate(InjectedNodeAnnotation())
              masn.update(r, node)
            case _ =>
          })
        newDirective
      })
    }

    collector.createDeclarations(rootNode.header)
  }

  /**
    * Finds the injected input directive declaring `name`.
    *
    * @param rootNode the document whose header is searched
    * @param name     the input name to look for
    * @return the input's identifier, or None if there is no injected input with that name
    */
  private def getImplicitVarDef(rootNode: DocumentNode, name: String): Option[NameIdentifier] = {
    rootNode.header.directives.collectFirst {
      case id: InputDirective if id.variable.name == name && AstNodeHelper.isInjectedNode(id) => id.variable
    }
  }

  /**
    * Whether `bo` is a candidate for extraction: a variable reference on the left and a string
    * (or string-keyed name) on the right.
    */
  private def validReplaceExpression(bo: BinaryOpNode): Boolean = {
    val validRhs = bo.rhs match {
      case _: StringNode => true
      case nn: NameNode  => nn.keyName.isInstanceOf[StringNode]
      case _             => false
    }
    val validLhs = bo.lhs.isInstanceOf[VariableReferenceNode]
    validRhs && validLhs
  }

  /**
    * Builds a single-child expression tree for a `<variable> <op> <name>` selection.
    *
    * Returns None when the right-hand side is not a plain, namespace-less string name, or when the variable
    * reference has no enclosing NullSafeNode (such usages are left untouched instead of failing compilation).
    */
  private def buildReplacement(node: BinaryOpNode, astNavigator: AstNavigator): Option[RootExpressionNode] = {
    node match {
      case bn @ BinaryOpNode(opId, v @ VariableReferenceNode(_, _), rhs, _) =>
        val maybeSelector = rhs match {
          case NameNode(sn @ StringNode(value, _), None, _) => Some(SelectorExpression(value, sn.quotedBy()))
          case _                                            => None
        }
        for {
          selector     <- maybeSelector
          nullSafeNode <- astNavigator.parentWithType(v, classOf[NullSafeNode])
        } yield RootExpressionNode(v, nullSafeNode).addChild(SelectorExpressionNode(selector, opId, ArrayBuffer(bn)))
      case _ => None
    }
  }

  /**
    * Folds the (single) child of `v` into `acc`: if `acc` already has a child with the same selector name and
    * operator, the replacements are appended to it; otherwise the child is added. Mutates and returns `acc`.
    */
  private def merge(acc: RootExpressionNode, v: RootExpressionNode): RootExpressionNode = {
    v.children.headOption.foreach(child =>
      acc.children.find(existing => existing.expression.canBeMerged(child.expression) && existing.opId == child.opId) match {
        case Some(existing) => existing.replacements ++= child.replacements
        case None           => acc.addChild(child)
      })
    acc
  }

  /**
    * Collects all the variables that were extracted by the pass so they can be injected into the header.
    *
    * @param usedNames names that must not be used by a new declaration; every name handed out by this collector
    *                  is added to it. Share one set between collectors to keep names unique across all of them.
    */
  class DeclarationCollector(usedNames: scala.collection.mutable.Set[String] = scala.collection.mutable.HashSet[String]()) {

    private val newVarDirectives = ArrayBuffer[InsertAfter]()

    /**
      * Registers a new injected variable declaration annotated with `dw::Core::Lazy`.
      *
      * The variable is named `__<name>_<nameSuffix>`. If that name is already taken, a numeric suffix
      * (`_1`, `_2`, ...) is appended until the name is unique.
      *
      * @param value      the variable expression node
      * @param afterNode  the directive after which the declaration will be inserted
      * @param name       the name of the variable the value was extracted from
      * @param nameSuffix the selection path used to derive the new variable name
      * @return the new variable directive
      */
    def addDeclaration(value: AstNode, afterNode: DirectiveNode, name: String, nameSuffix: String): VarDirective = {
      val nameIdentifier = NameIdentifier(reserveName(s"__${name}_$nameSuffix"))
      val directive = VarDirective(nameIdentifier, value)
      directive.setAnnotations(Seq(AnnotationNode(NameIdentifier("dw::Core::Lazy"), Some(AnnotationArgumentsNode(Seq())))))
      directive.annotate(InjectedNodeAnnotation())
      newVarDirectives += InsertAfter(directive, afterNode)
      directive
    }

    /** Returns `base` if unused, otherwise the first unused `base_N` (N = 1, 2, ...), and marks it as used. */
    private def reserveName(base: String): String = {
      var candidate = base
      var index = 0
      while (usedNames.contains(candidate)) {
        index += 1
        candidate = s"${base}_$index"
      }
      usedNames += candidate
      candidate
    }

    /**
      * Adds all the collected variable declarations to the HeaderNode, each after its recorded anchor directive.
      *
      * @param header the header node into which the declarations are inserted
      */
    def createDeclarations(header: HeaderNode): Unit =
      newVarDirectives.foreach(toInsert => header.addDirectiveAfter(toInsert.directiveToInsert, toInsert.afterNode))

    /** Whether at least one declaration has been registered. */
    def hasDeclaration: Boolean = newVarDirectives.nonEmpty
  }

  /**
    * The top node in a selection path.
    *
    * @param variableReferenceNode The variable reference
    * @param nullSafeNode          The enclosing null safe node of the reference
    */
  case class RootExpressionNode(variableReferenceNode: VariableReferenceNode, nullSafeNode: NullSafeNode) {
    val children: ArrayBuffer[SelectorExpressionNode] = ArrayBuffer()

    /** Appends `expressionNode` to the children and returns this node. */
    def addChild(expressionNode: SelectorExpressionNode): RootExpressionNode = {
      children += expressionNode
      this
    }
  }

  /**
    * Represents a node in the expression tree.
    *
    * @param expression   The selection
    * @param opId         The selection operator
    * @param replacements All the nodes that point to this selection path
    */
  case class SelectorExpressionNode(expression: SelectorExpression, opId: BinaryOpIdentifier, replacements: ArrayBuffer[BinaryOpNode])

  /**
    * A selected name, with the quote character it was written with (if any).
    */
  case class SelectorExpression(name: String, quoted: Option[Char]) {

    /** Builds a NameNode selecting `name`, preserving the original quotation. */
    def toSelectorNode: AstNode = {
      val stringNode = StringNode(name)
      quoted.foreach(quote => stringNode.withQuotation(quote))
      NameNode(stringNode, None)
    }

    /** Two selections can be merged when they select the same name; quotation is ignored. */
    def canBeMerged(childSelector: SelectorExpression): Boolean = this.name == childSelector.name

    def path(): String = name
  }

  case class InsertAfter(directiveToInsert: DirectiveNode, afterNode: DirectiveNode)
}