package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.grammar.BinaryOpIdentifier
import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.annotation.MaterializeVariableAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.MutableAstNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationArgumentsNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.DirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.InputDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.selectors.ExistsSelectorNode
import org.mule.weave.v2.parser.ast.selectors.NullSafeNode
import org.mule.weave.v2.parser.ast.selectors.NullUnSafeNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.scope.AstNavigator

import scala.collection.mutable.ArrayBuffer

/**
  * This phase is intended to only run in the context of a mule application. In that context all mappings
  * have an implicit input named vars, where a user may define any amount of separate values that they want.
  *
  * Recently we had issues with the materialization of vars. Mule declares it as a Java map whose values are
  * considered already materialized, which causes problems in some cases. For example, the following script does not
  * work on Mule if the underlying value of one of its entries is an iterator:
  * <pre>
  * %dw 2.9
  * ---
  * if (vars.a?) [1,2,3] map vars.a.b else []
  * </pre>
  * All iterations except the first one produce a null value: materializing `vars` does not propagate into `a`, so the
  * iterator is consumed on the first iteration and never reset.
  *
  * Propagating materialization was attempted, but it caused performance issues on mule applications and had to be reverted.
  *
  * This phase is a second attempt at solving the issue. It takes every first-level value selection over `vars` and
  * extracts it into a new variable. This works because executing a var directive materializes its value, so instead of
  * relying on materialization propagating from `vars`, we materialize the specific value returned by the selection.
  *
  * In the previous example the script is transformed into the following:
  * <pre>
  * %dw 2.9
  *
  * &#064;dw::Core::Lazy()
  * var &#95;&#95;vars_a_en = vars.a?
  *
  * &#064;dw::Core::Lazy()
  * var &#95;&#95;vars_a_nsn = vars.a
  * ---
  * if (&#95;&#95;vars_a_en) [1,2,3] map &#95;&#95;vars_a_nsn.b else []
  * </pre>
  * Now `vars.a` is bound to a var directive whose value gets materialized, so `.b` is selected from that materialized
  * value on every iteration.
  *
  * @tparam T result type of the phase
  */
class SingleUseVariableMaterializationPass[T <: AstNodeResultAware[DocumentNode] with DummyScopeNavigatorAware] extends CompilationPhase[T, T] {

  /**
    * Runs the pass over `source` when `enableVarsMaterializationPass` is set; otherwise returns `source` untouched.
    *
    * For every name in `forcedMaterializationVariables` that is declared as an injected input directive, first-level
    * selections over that input are extracted into injected lazy `var` directives. The original selections are then
    * rewritten to reference those variables. The document AST is mutated in place.
    *
    * @param source  phase input providing the document AST and the scope navigator
    * @param context parsing context providing the pass settings
    * @return a [[SuccessResult]] wrapping `source`
    */
  override def doCall(source: T, context: ParsingContext): PhaseResult[T] = {
    if (context.settings.enableVarsMaterializationPass) {
      val rootNode: DocumentNode = source.astNode
      context.settings.forcedMaterializationVariables.foreach(varName => {
        getImplicitVarDef(rootNode, varName).foreach(varsDef => materializeSelections(source, rootNode, varName, varsDef))
      })
    }
    SuccessResult(source, context)
  }

  /**
    * Extracts the first-level selections over a single implicit input (e.g. `vars`).
    *
    * It also inserts the resulting variable declarations into the document header.
    *
    * @param source   phase input providing the scope navigator
    * @param rootNode document whose header receives the new declarations
    * @param varName  name of the implicit input; used as prefix of the generated variable names
    * @param varsDef  identifier of the implicit input declaration
    */
  private def materializeSelections(source: T, rootNode: DocumentNode, varName: String, varsDef: NameIdentifier): Unit = {
    val collector = new DeclarationCollector()
    val navigator = source.dummyScopesNavigator.rootScope.astNavigator()
    val varUsages: Seq[BinaryOpNode] = source.dummyScopesNavigator
      .resolveLocalReferencedBy(varsDef)
      .flatMap(r => navigator.parentWithTypeMaxLevel(r.referencedNode, classOf[BinaryOpNode], 2))
      .filter(validReplaceExpression)
    val expressionNodes = varUsages.flatMap(buildReplacement(_, navigator))

    if (expressionNodes.nonEmpty) {
      val varRoot: RootExpressionNode = expressionNodes.reduce(merge)
      // The extracted variables are flagged for materialization only when more than one distinct selection exists.
      val needsMaterialize = varRoot.children.size > 1
      val firstAnchor: DirectiveNode = navigator.parentWithTypeMaxLevel(varsDef, classOf[DirectiveNode], 3).orNull
      // The first declaration goes after the implicit input's directive; each following one goes after the previous
      // generated directive, so the header keeps the selections in discovery order.
      varRoot.children.foldLeft(firstAnchor)((insertAfterNode, child) => {
        val newDirective = collector.addDeclaration(buildValueNode(varRoot, child), insertAfterNode, varName, child.expression.path(), child.kind, needsMaterialize)
        child.replacements.foreach(usage => replaceUsage(usage, newDirective, varRoot, navigator))
        newDirective
      })
    }

    collector.createDeclarations(rootNode.header)
  }

  /**
    * Builds the value expression of an extracted variable.
    *
    * The value is the selection over a fresh reference to the implicit input, wrapped in the node matching the
    * selection kind (null-safe, null-unsafe or exists).
    */
  private def buildValueNode(varRoot: RootExpressionNode, child: SelectorExpressionNode): AstNode = {
    val newRef = VariableReferenceNode(
      NameIdentifier(varRoot.variableReferenceNode.variable.name),
      varRoot.variableReferenceNode.codeAnnotations)
    val selection = BinaryOpNode(child.opId, newRef, child.expression.toSelectorNode)
    child.kind match {
      case SelectionKind.NULL_SAFE   => NullSafeNode(selection)
      case SelectionKind.NULL_UNSAFE => NullUnSafeNode(selection)
      case SelectionKind.EXISTS      => ExistsSelectorNode(selection)
    }
  }

  /**
    * Replaces one original selection (e.g. `vars.a`) with a reference to the extracted variable.
    *
    * When the selection is directly wrapped in a NullSafeNode, NullUnSafeNode or ExistsSelectorNode, the wrapper
    * itself is replaced inside the grandparent. The extracted variable already carries that same wrapper, so
    * `vars.a?` becomes `__vars_a_en`.
    *
    * Otherwise the selection is part of a larger expression, e.g. the inner `vars.a` of the chained `vars.a.b`:
    * <pre>
    *         NullSafeNode
    *              |
    *         BinaryOpNode
    *          /        \
    *   BinaryOpNode      b
    *     |     |
    *    vars   a
    * </pre>
    * In that case only the selection is replaced inside its parent, producing `__vars_a_nsn.b`. Replacing from the
    * grandparent here would drop the trailing `.b` selection.
    *
    * Usages without a parent, or whose replacement container is not mutable, are left untouched.
    */
  private def replaceUsage(usage: BinaryOpNode, newDirective: VarDirective, varRoot: RootExpressionNode, navigator: AstNavigator): Unit = {
    navigator.parentOf(usage).foreach(parentNode => {
      val (replaceFrom, toReplace): (Option[AstNode], AstNode) = parentNode match {
        case _: NullSafeNode | _: NullUnSafeNode | _: ExistsSelectorNode => (navigator.granParentOf(usage), parentNode)
        case _ => (Some(parentNode), usage)
      }
      replaceFrom match {
        case Some(masn: MutableAstNode) =>
          val node = VariableReferenceNode(NameIdentifier(newDirective.variable.name), varRoot.variableReferenceNode.codeAnnotations)
          node.annotate(InjectedNodeAnnotation())
          masn.update(toReplace, node)
        case _ =>
      }
    })
  }

  /**
    * Finds the injected input directive declaring `name` and returns its identifier, if any.
    */
  private def getImplicitVarDef(rootNode: DocumentNode, name: String): Option[NameIdentifier] = {
    rootNode.header.directives.collectFirst({
      case id: InputDirective if id.variable.name == name && AstNodeHelper.isInjectedNode(id) => id.variable
    })
  }

  /**
    * Accepts binary operations whose lhs is a variable reference and whose rhs is a string literal or a name with a
    * string key.
    */
  private def validReplaceExpression(bo: BinaryOpNode): Boolean = {
    bo.lhs.isInstanceOf[VariableReferenceNode] && (bo.rhs match {
      case _: StringNode => true
      case nn: NameNode  => nn.keyName.isInstanceOf[StringNode]
      case _             => false
    })
  }

  /**
    * Builds a single-selection expression tree for `node` when it is a `variable.name` selection.
    *
    * The selection kind is taken from the node wrapping `node`. It defaults to null-safe when there is no wrapper.
    */
  private def buildReplacement(node: BinaryOpNode, astNavigator: AstNavigator): Option[RootExpressionNode] = {
    node match {
      case bn @ BinaryOpNode(opId, v @ VariableReferenceNode(_, _), NameNode(sn @ StringNode(value, _), None, _), _) =>
        val kind = astNavigator.parentOf(bn) match {
          case Some(_: NullUnSafeNode)     => SelectionKind.NULL_UNSAFE
          case Some(_: ExistsSelectorNode) => SelectionKind.EXISTS
          case _                           => SelectionKind.NULL_SAFE
        }
        val selector = SelectorExpressionNode(SelectorExpression(value, sn.quotedBy()), opId, ArrayBuffer(bn), kind)
        Some(RootExpressionNode(v).addChild(selector))
      case _ => None
    }
  }

  /**
    * Merges the selections of `v` into `acc`.
    *
    * A selection equivalent to one already in `acc` only contributes its replacements; any other selection is added
    * as a new child.
    */
  private def merge(acc: RootExpressionNode, v: RootExpressionNode): RootExpressionNode = {
    v.children.foreach(child => {
      acc.children.find(_.canBeMerged(child)) match {
        case Some(matchingChild) => matchingChild.replacements ++= child.replacements
        case None                => acc.addChild(child)
      }
    })
    acc
  }

  /**
    * Collects the variables extracted by the pass that will be injected into the header.
    */
  class DeclarationCollector {

    private val newVarDirectives = ArrayBuffer[InsertAfter]()

    // Names already produced by this collector. Selections that share name and kind but differ in operator
    // (e.g. `vars.a` and `vars.*a`) would otherwise get the same variable name.
    private val usedNames = scala.collection.mutable.HashSet[String]()

    /**
      * Registers a new lazy, injected variable declaration.
      *
      * The variable is named `__<name>_<nameSuffix>_<nodeKind>`. If that name was already produced by this
      * collector, `_1`, `_2`, ... is appended until it is unique.
      * NOTE(review): `nameSuffix` comes straight from the selector and may be a quoted key (e.g. `a-b`). Confirm that
      * nothing requires the generated name to be a valid DataWeave identifier.
      *
      * @param value            the variable value expression
      * @param afterNode        directive after which the new declaration is inserted
      * @param name             name of the implicit input the selection was made on
      * @param nameSuffix       selection path used in the variable name
      * @param nodeKind         selection kind used in the variable name
      * @param needsMaterialize whether the variable is annotated with [[MaterializeVariableAnnotation]]
      * @return the newly created var directive (not yet inserted into the header)
      */
    def addDeclaration(value: AstNode, afterNode: DirectiveNode, name: String, nameSuffix: String, nodeKind: String, needsMaterialize: Boolean): VarDirective = {
      val nameIdentifier = NameIdentifier(uniqueName(s"__${name}_${nameSuffix}_$nodeKind"))
      val directive = VarDirective(nameIdentifier, value)
      directive.setAnnotations(Seq(AnnotationNode(NameIdentifier("dw::Core::Lazy"), Some(AnnotationArgumentsNode(Seq())))))
      directive.annotate(InjectedNodeAnnotation())
      if (needsMaterialize) {
        directive.variable.annotate(new MaterializeVariableAnnotation(true))
      }
      newVarDirectives += InsertAfter(directive, afterNode)
      directive
    }

    /** Returns `baseName` if unused, otherwise the first unused `baseName_<n>` with n >= 1, and reserves it. */
    private def uniqueName(baseName: String): String = {
      val name =
        if (!usedNames.contains(baseName)) baseName
        else Iterator.from(1).map(i => s"${baseName}_$i").dropWhile(usedNames.contains).next()
      usedNames += name
      name
    }

    /**
      * Inserts all registered declarations into `header`, each after its recorded anchor directive.
      *
      * @param header the header node that receives the declarations
      */
    def createDeclarations(header: HeaderNode): Unit = {
      newVarDirectives.foreach(toInsert => header.addDirectiveAfter(toInsert.directiveToInsert, toInsert.afterNode))
    }

    /** Returns true if at least one declaration was registered. */
    def hasDeclaration: Boolean = newVarDirectives.nonEmpty
  }

  /**
    * The top node in a selection path.
    *
    * @param variableReferenceNode the reference to the implicit input the selections start from
    */
  case class RootExpressionNode(variableReferenceNode: VariableReferenceNode) {
    val children: ArrayBuffer[SelectorExpressionNode] = ArrayBuffer()

    def addChild(expressionNode: SelectorExpressionNode): RootExpressionNode = {
      children += expressionNode
      this
    }
  }

  /**
    * A selection hanging from a [[RootExpressionNode]].
    *
    * @param expression   the selected name
    * @param opId         the selection operator
    * @param replacements all the original selection nodes that denote this same selection
    * @param kind         one of the `SelectionKind` values
    */
  case class SelectorExpressionNode(expression: SelectorExpression, opId: BinaryOpIdentifier, replacements: ArrayBuffer[BinaryOpNode], kind: String) {
    def canBeMerged(other: SelectorExpressionNode): Boolean = {
      kind == other.kind && opId == other.opId && expression.canBeMerged(other.expression)
    }
  }

  /**
    * A selected name together with the quote character it was written with, if any.
    */
  case class SelectorExpression(name: String, quoted: Option[Char]) {
    def toSelectorNode: AstNode = {
      val stringNode = StringNode(name)
      quoted.foreach(quote => stringNode.withQuotation(quote))
      NameNode(stringNode, None)
    }

    /** Two selectors can be merged when they select the same name, regardless of quoting. */
    def canBeMerged(childSelector: SelectorExpression): Boolean = name == childSelector.name

    def path(): String = name
  }

  private object SelectionKind {
    val NULL_SAFE = "nsn"
    val NULL_UNSAFE = "nusn"
    val EXISTS = "en"
  }

  case class InsertAfter(directiveToInsert: DirectiveNode, afterNode: DirectiveNode)
}