package org.mule.weave.v2.interpreted.transform.phase

import org.mule.weave.compiler.WeaveCompilerConfigProperties.DISABLED_CONSTANT_FOLDING_PHASE
import org.mule.weave.v2.grammar.AdditionOpId
import org.mule.weave.v2.grammar.DivisionOpId
import org.mule.weave.v2.grammar.MultiplicationOpId
import org.mule.weave.v2.grammar.SubtractionOpId
import org.mule.weave.v2.model.values.math.Number
import org.mule.weave.v2.parser.annotation.PreCompiledTypeAnnotation
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.structure.NumberNode
import org.mule.weave.v2.parser.ast.structure.StringInterpolationNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier.CORE_MODULE
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.MutableAstNode
import org.mule.weave.v2.parser.location.WeaveLocation
import org.mule.weave.v2.parser.phase.CompilationPhase
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.parser.phase.PhaseResult
import org.mule.weave.v2.parser.phase.ScopeGraphPhase
import org.mule.weave.v2.parser.phase.ScopeGraphResult
import org.mule.weave.v2.parser.phase.SuccessResult
import org.mule.weave.v2.runtime.core.operator.math.NumberAdditionOperator
import org.mule.weave.v2.runtime.core.operator.math.NumberDivisionOperator
import org.mule.weave.v2.runtime.core.operator.math.NumberMultiplicationOperator
import org.mule.weave.v2.runtime.core.operator.math.NumberSubtractionNumberOperator
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.scope.Reference
import org.mule.weave.v2.scope.ScopesNavigator
import org.mule.weave.v2.utils.StringEscapeHelper

import scala.annotation.tailrec
import scala.util.Try

/**
  * Compilation phase that performs constant folding.
  *
  * https://en.wikipedia.org/wiki/Constant_folding
  *
  * The tree is walked bottom-up and the following expressions are replaced in place by their literal result:
  *  - `+`, `-`, `*` and `/` binary operations whose operands are number literals or local variables declared by a
  *    [[VarDirective]] whose value is a number literal;
  *  - calls to the core module `++` function with exactly two arguments that are string literals, local variables
  *    whose value is a string literal, or other foldable `++` calls;
  *  - string interpolations, where adjacent constant parts are merged into a single string literal.
  *
  * The phase is skipped when it is globally disabled, when the parsing context says it should not run, or when the
  * root node is annotated with [[PreCompiledTypeAnnotation]]. If any node is replaced, the scope graph is rebuilt.
  *
  * @tparam T The Type of Node
  */
class ConstantFoldingPhase[T <: AstNode]() extends CompilationPhase[ScopeGraphResult[T], ScopeGraphResult[T]] {

  /** Fully qualified name of the core `++` function, the only function whose calls this phase folds. */
  val PLUS_PLUS_IDENTIFIER: NameIdentifier = CORE_MODULE.::("++")

  // Set by `updateTree` whenever a node is replaced; tells `doCall` that the scope graph must be rebuilt.
  private var astGraphMutated = false

  override def doCall(input: ScopeGraphResult[T], ctx: ParsingContext): PhaseResult[ScopeGraphResult[T]] = {
    if (DISABLED_CONSTANT_FOLDING_PHASE || !ctx.shouldRunConstantFoldingPhase() || input.astNode.isAnnotatedWith(classOf[PreCompiledTypeAnnotation])) {
      SuccessResult(input, ctx)
    } else {
      // Reset per run: otherwise, after the first mutating compilation, the flag stays true and every later
      // call on this phase instance would rebuild the scope graph even when nothing was folded.
      astGraphMutated = false
      foldTree(input.astNode, input.scope)
      if (astGraphMutated) {
        // The tree was mutated in place, so the scope graph computed for `input` no longer matches it.
        new ScopeGraphPhase[T]().call(input, ctx)
      } else {
        SuccessResult(input, ctx)
      }
    }
  }

  /** Folds every foldable descendant of `rootNode` in place and returns `rootNode`. */
  private def foldTree(rootNode: AstNode, scopeNavigator: ScopesNavigator): AstNode = {
    updateTree(rootNode, {
      case bon: BinaryOpNode             => foldBinaryOpNode(scopeNavigator, bon)
      case fcn: FunctionCallNode         => foldFunctionCallNode(scopeNavigator, fcn)
      case sin: StringInterpolationNode  => foldStringInterpolationNode(scopeNavigator, sin)
      case _                             => None
    })
  }

  /**
    * Walks the subtree of `astNode` in post-order and replaces nodes using `mapper`.
    *
    * For every child of a [[MutableAstNode]], the child's own subtree is processed first, then `mapper` is applied to
    * the child; a `Some` result replaces the child in its parent and marks the graph as mutated. Children of nodes that
    * are not mutable are only recursed into, never replaced. `mapper` is never applied to `astNode` itself.
    *
    * @param astNode the root of the subtree to process
    * @param mapper  returns the replacement for a node, or `None` to keep it
    * @return `astNode` itself (mutated in place)
    */
  def updateTree(astNode: AstNode, mapper: AstNode => Option[AstNode]): AstNode = {
    astNode match {
      case parent: MutableAstNode =>
        parent.children().foreach(child => {
          // Fold the child's subtree first so the mapper sees already-folded operands.
          updateTree(child, mapper)
          mapper(child).foreach(replacement => {
            parent.update(child, replacement)
            astGraphMutated = true
          })
        })
      case _ =>
        // These children cannot be replaced, but their mutable descendants still can.
        astNode.children().foreach(child => updateTree(child, mapper))
    }
    astNode
  }

  /**
    * Folds an arithmetic [[BinaryOpNode]] whose operands are both constant numbers into a [[NumberNode]] carrying the
    * operation's location. Returns `None` for other operators, non-constant operands, or when evaluation throws.
    */
  private def foldBinaryOpNode(sn: ScopesNavigator, bop: BinaryOpNode): Option[AstNode] = {
    bop.binaryOpId match {
      case AdditionOpId | SubtractionOpId | MultiplicationOpId | DivisionOpId =>
        // If evaluation throws, the expression is left untouched so it is evaluated (and fails) at runtime instead.
        val folded: Try[Option[NumberNode]] = Try({
          for {
            leftValue <- toNumberValue(bop.lhs, sn)
            rightValue <- toNumberValue(bop.rhs, sn)
            result <- evaluateArithmetic(bop, leftValue, rightValue)
          } yield {
            val node = NumberNode(result.toString)
            node._location = bop._location
            node
          }
        })
        folded.toOption.flatten
      case _ => None
    }
  }

  /** Applies the arithmetic operator of `bop` to two constant operands; `None` for any other operator. */
  private def evaluateArithmetic(bop: BinaryOpNode, left: Number, right: Number): Option[Number] = {
    val location = bop.location()
    bop.binaryOpId match {
      case AdditionOpId       => Some(new NumberAdditionOperator(location).doAddition(left, right))
      case SubtractionOpId    => Some(new NumberSubtractionNumberOperator(location).doSubtraction(left, right))
      case MultiplicationOpId => Some(new NumberMultiplicationOperator(location).doMultiplication(left, right)(None))
      case DivisionOpId       => Some(new NumberDivisionOperator(location).doDivision(left, right))
      case _                  => None
    }
  }

  /**
    * Returns the value expression of the [[VarDirective]] that `vrn` refers to, when `vrn` resolves to a local
    * reference whose referenced node's parent is a `VarDirective`.
    */
  private def localVarValue(sn: ScopesNavigator, vrn: VariableReferenceNode): Option[AstNode] = {
    sn.resolveVariable(vrn.variable) match {
      case Some(ref) if ref.isLocalReference =>
        val astNavigator: AstNavigator = sn.rootScope.astNavigator()
        astNavigator.parentOf(ref.referencedNode) match {
          // NOTE(review): presumably skips references located inside the directive's own value
          // (self-referencing definitions) — confirm the argument order of `isDescendantOf`.
          case Some(vd: VarDirective) if !astNavigator.isDescendantOf(vd, vrn) => Some(vd.value)
          case _ => None
        }
      case _ => None
    }
  }

  /** Returns the constant number represented by `node`: a number literal or a local variable bound to one. */
  private def toNumberValue(node: AstNode, sn: ScopesNavigator): Option[Number] = {
    val literal: Option[NumberNode] = node match {
      case nn: NumberNode               => Some(nn)
      case vrn: VariableReferenceNode   => localVarValue(sn, vrn).collect({ case nn: NumberNode => nn })
      case _                            => None
    }
    literal.map(nn => Number(nn.location(), nn.literalValue))
  }

  /**
    * Folds a call to the core `++` function with exactly two constant string arguments into a single [[StringNode]]
    * carrying the call's location. Returns `None` for any other call, any other arity, non-constant arguments, or
    * when unescaping an argument throws.
    */
  private def foldFunctionCallNode(sn: ScopesNavigator, fcn: FunctionCallNode): Option[AstNode] = {
    fcn.function match {
      case vrn: VariableReferenceNode if sn.resolveVariable(vrn.variable).exists(_.fqnReferenceName == PLUS_PLUS_IDENTIFIER) =>
        val folded: Try[Option[StringNode]] = Try({
          val args = fcn.args.args
          // Every argument must be constant: a call with another arity is never folded.
          if (args.size == 2) {
            for {
              left <- toStringNode(sn, args.head)
              right <- toStringNode(sn, args(1))
            } yield {
              val stringNode = StringNode(unescapeStringIfNecessary(left) + unescapeStringIfNecessary(right))
              stringNode._location = fcn._location
              stringNode
            }
          } else {
            None
          }
        })
        folded.toOption.flatten
      case _ => None
    }
  }

  /** Returns the runtime value of `str`, unescaping it according to its quote character when it has one. */
  private def unescapeStringIfNecessary(str: StringNode): String = {
    str.quotedBy().fold(str.value)(quote => StringEscapeHelper.unescapeString(str.value, quote, str.location()))
  }

  /**
    * Returns the constant string represented by `node`: a string literal, a local variable bound to one, or a
    * foldable `++` call.
    */
  private def toStringNode(sn: ScopesNavigator, node: AstNode): Option[StringNode] = {
    node match {
      case str: StringNode            => Some(str)
      case vrn: VariableReferenceNode => localVarValue(sn, vrn).collect({ case str: StringNode => str })
      case fcn: FunctionCallNode      => foldFunctionCallNode(sn, fcn).collect({ case str: StringNode => str })
      case _                          => None
    }
  }

  /**
    * Folds the constant parts of a string interpolation.
    *
    * When every part is constant, the result is a single [[StringNode]]. When only some adjacent parts could be
    * merged or resolved, the result is a new [[StringInterpolationNode]]. Both carry the interpolation's location.
    * Returns `None` when nothing was folded, so the tree is not needlessly marked as mutated.
    */
  private def foldStringInterpolationNode(sn: ScopesNavigator, sin: StringInterpolationNode): Option[AstNode] = {
    val folded: Try[Option[AstNode]] = Try({
      val original: Seq[AstNode] = sin.elements
      val reduced: Seq[AstNode] = reduceStringNodes(sn, original)
      reduced match {
        case Seq(single: StringNode) =>
          val node = StringNode(single.value)
          node._location = sin._location
          Some(node)
        case _ if wasReduced(original, reduced) =>
          val node = StringInterpolationNode(reduced)
          node._location = sin._location
          Some(node)
        case _ =>
          None
      }
    })
    folded.toOption.flatten
  }

  /**
    * Tells whether `reduceStringNodes` actually folded something.
    *
    * Something was folded if adjacent parts were merged (the sequence got shorter) or a non-literal part was replaced
    * by its string value. Literal parts are always rebuilt as new [[StringNode]]s, so they alone do not count.
    */
  private def wasReduced(original: Seq[AstNode], reduced: Seq[AstNode]): Boolean = {
    // Equal sizes mean no merge happened, so elements correspond position by position.
    reduced.size != original.size ||
      reduced.zip(original).exists({
        case (_, _: StringNode) => false
        case (after, before)    => !(after eq before)
      })
  }

  /** Replaces every maximal run of constant parts in `elements` by a single joined [[StringNode]]. */
  private def reduceStringNodes(sn: ScopesNavigator, elements: Seq[AstNode]): Seq[AstNode] = {
    if (elements.isEmpty) {
      Nil
    } else {
      val run = takeWhileStringNode(sn, elements)
      if (run.isEmpty) {
        elements.head +: reduceStringNodes(sn, elements.tail)
      } else {
        joinStringNodes(run) +: reduceStringNodes(sn, elements.drop(run.length))
      }
    }
  }

  /**
    * Concatenates the unescaped values of a non-empty run of constant parts into one [[StringNode]].
    * The node's location spans from the first to the last source part, when both have a location.
    */
  private def joinStringNodes(run: Seq[StringNodeWithSource]): StringNode = {
    val joined = StringNode(run.map(part => unescapeStringIfNecessary(part.sn)).mkString)
    joined._location = (run.head.source._location, run.last.source._location) match {
      case (Some(start), Some(end)) =>
        Some(WeaveLocation(startPosition = start.startPosition, endPosition = end.endPosition, resourceName = start.resourceName))
      case _ => None
    }
    joined
  }

  /** Returns the longest prefix of `nodes` that resolves to constant strings, paired with each original node. */
  private def takeWhileStringNode(sn: ScopesNavigator, nodes: Seq[AstNode]): Seq[StringNodeWithSource] = {
    @tailrec
    def loop(remaining: Seq[AstNode], acc: List[StringNodeWithSource]): List[StringNodeWithSource] = {
      val next = remaining.headOption.flatMap(node => toStringNode(sn, node).map(str => StringNodeWithSource(node, str)))
      next match {
        case Some(part) => loop(remaining.tail, part :: acc)
        case None       => acc
      }
    }

    loop(nodes, Nil).reverse
  }
}

/**
  * A constant string part found while reducing a string interpolation.
  *
  * @param source the original element the constant was derived from
  * @param sn     the string literal that `source` resolves to
  */
case class StringNodeWithSource(
  source: AstNode,
  sn: StringNode
)