package org.mule.weave.v2.interpreted.transform

import org.mule.weave.v2.interpreted.marker.LazyVarDirectiveAnnotation
import org.mule.weave.v2.interpreted.node.NameSlot
import org.mule.weave.v2.interpreted.node.ValueNode
import org.mule.weave.v2.interpreted.node.expressions.InlineTailRecFunctionCall
import org.mule.weave.v2.interpreted.node.expressions.TailRecFunctionBodyNode
import org.mule.weave.v2.interpreted.node.pattern.{ PatternMatcherNode => XPatternMatcherNode }
import org.mule.weave.v2.interpreted.node.structure
import org.mule.weave.v2.interpreted.node.structure.header.directives
import org.mule.weave.v2.interpreted.node.structure.header.directives.Directive
import org.mule.weave.v2.interpreted.node.structure.header.directives.FunctionDirective
import org.mule.weave.v2.interpreted.node.structure.header.directives.{ VarDirective => XVarDirective }
import org.mule.weave.v2.interpreted.node.structure.{ DoBlockNode => XDoBlockNode }
import org.mule.weave.v2.interpreted.node.{ DefaultNode => XDefaultNode }
import org.mule.weave.v2.interpreted.node.{ IfNode => XIfNode }
import org.mule.weave.v2.interpreted.node.{ UnlessNode => XUnlessNode }
import org.mule.weave.v2.interpreted.node.{ UsingNode => XUsingNode }
import org.mule.weave.v2.model.values.math.Number
import org.mule.weave.v2.parser
import org.mule.weave.v2.parser.annotation.TailRecFunctionAnnotation
import org.mule.weave.v2.parser.annotation.TailRecFunctionCallAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.WeaveLocationCapable
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.conditional.DefaultNode
import org.mule.weave.v2.parser.ast.conditional.IfNode
import org.mule.weave.v2.parser.ast.conditional.UnlessNode
import org.mule.weave.v2.parser.ast.functions
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.functions.UsingNode
import org.mule.weave.v2.parser.ast.functions.{ FunctionNode => AstFunctionNode }
import org.mule.weave.v2.parser.ast.header.directives._
import org.mule.weave.v2.parser.ast.patterns.PatternExpressionNode
import org.mule.weave.v2.parser.ast.patterns.PatternMatcherNode
import org.mule.weave.v2.parser.ast.structure.UriNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.runtime.exception.CompilationExecutionException

/**
  * Converts parser header-directive AST nodes (format, output, namespace, version, var and function
  * directives) into their interpreter counterparts.
  *
  * Function directives whose name identifier carries a [[TailRecFunctionAnnotation]] are compiled with
  * tail-recursion elimination: annotated self-calls in tail position are replaced by an
  * [[InlineTailRecFunctionCall]] that rebinds the parameters, wrapped in a [[TailRecFunctionBodyNode]].
  */
trait EngineDirectiveTransformations extends AstTransformation with EnginePatternTransformations {

  /**
    * Converts a parser format expression into the interpreter's format expression.
    * Only `ContentType` and `DataFormatId` are handled; any other `FormatExpression` fails with a `MatchError`.
    */
  def transformFormat(format: FormatExpression): directives.FormatExpression = {
    format match {
      case ContentType(mime) => new directives.ContentType(mime)
      case DataFormatId(id)  => new directives.DataFormatId(id)
    }
  }

  /** Maps an optional parser content type to the interpreter's content type, preserving its mime string. */
  def transformContentType(mime: Option[ContentType]): Option[directives.ContentType] =
    mime.map(m => new directives.ContentType(m.mime))

  /**
    * Builds a namespace directive from its prefix and uri.
    * The third constructor argument is `!transformingModule`, i.e. `true` when the current transformation is not
    * for a module (its exact meaning inside `NamespaceDirective` is not visible here).
    */
  def transformNamespaceDirective(prefix: NameIdentifier, uri: UriNode, transformationStack: TransformationStack): directives.NamespaceDirective = {
    new directives.NamespaceDirective(transform(prefix, transformationStack), transform(uri, transformationStack), !transformingModule)
  }

  /** Wraps the major version string as a `Number` in an interpreter `VersionMajor`. */
  def transformVersionMajor(v: String): directives.VersionMajor = new directives.VersionMajor(Number(v))

  /** Builds a directive option whose name becomes a `StringNode` and whose value is transformed normally. */
  def transformDirectiveOption(name: DirectiveOptionName, value: AstNode, transformationStack: TransformationStack): directives.DirectiveOption = {
    new directives.DirectiveOption(structure.StringNode(name.name), transform(value, transformationStack))
  }

  /** Builds a version directive from its transformed major and minor parts. */
  def transformVersionDirective(major: VersionMajor, minor: VersionMinor, transformationStack: TransformationStack): directives.VersionDirective =
    new directives.VersionDirective(transform(major, transformationStack), transform(minor, transformationStack))

  /**
    * Builds an output directive from an optional content type, optional writer options and an optional
    * data format id (passed through as its raw id).
    */
  def transformOutputDirective(id: Option[DataFormatId], mime: Option[ContentType], options: Option[Seq[DirectiveOption]], transformationStack: TransformationStack): directives.OutputDirective =
    new directives.OutputDirective(transformContentType(mime), transformOptionSeq(options, transformationStack), id.map(_.id))

  /** Wraps the minor version string as a `Number` in an interpreter `VersionMinor`. */
  def transformVersionMinor(v: String): directives.VersionMinor = new directives.VersionMinor(Number(v))

  /**
    * Transforms a `var` directive.
    *
    * - A value that is a function literal is compiled as a [[FunctionDirective]] (see [[transformFunctionDirective]]).
    * - A variable annotated with [[LazyVarDirectiveAnnotation]] becomes a `LazyVarDirective`.
    * - Anything else becomes a plain `VarDirective`.
    */
  def transformVarDirective(variable: NameIdentifier, value: AstNode, codeAnnotations: Seq[AnnotationNode], transformationStack: TransformationStack): Directive = {
    value match {
      case _: AstFunctionNode =>
        // The value node doubles as the location given to a tail-recursive function body.
        transformFunctionDirective(variable, value, codeAnnotations, value, transformationStack)
      case _ if variable.isAnnotatedWith(classOf[LazyVarDirectiveAnnotation]) =>
        new directives.LazyVarDirective(transform(variable, transformationStack), transform(value, transformationStack), needsMaterialization(variable))
      case _ =>
        new directives.VarDirective(transform(variable, transformationStack), transform(value, transformationStack), needsMaterialization(variable))
    }
  }

  /** Transforms a function literal into its interpreter value; `functionName` is set for named functions. */
  def transformFunctionNode(fn: functions.FunctionNode, functionName: Option[NameIdentifier] = None, transformationStack: TransformationStack): ValueNode[_]

  /** Builds the interpreter function value for `fn` using an already transformed body. */
  def createFunctionNode(fn: functions.FunctionNode, functionName: Option[NameIdentifier], bodyValue: ValueNode[_], transformationStack: TransformationStack): ValueNode[_]

  /** Transforms an overloaded function (several definitions sharing `name`) into its interpreter value. */
  def transformOverloadedFunctionNode(ofn: OverloadedFunctionNode, name: String, transformationStack: TransformationStack): ValueNode[_]

  /** `true` when the function's name identifier has been marked with [[TailRecFunctionAnnotation]]. */
  private def isRecursiveCallExpression(functionIdentifier: NameIdentifier): Boolean = {
    functionIdentifier.isAnnotatedWith(classOf[TailRecFunctionAnnotation])
  }

  /**
    * Transforms a named function (or overloaded function) into a [[FunctionDirective]].
    *
    * The AST node is pushed on `transformationStack` while its value is transformed and dropped afterwards.
    *
    * @param location source location assigned to the tail-recursive body node, when tail recursion is eliminated
    * @throws MatchError if `node` is neither a `FunctionNode` nor an `OverloadedFunctionNode`
    */
  def transformFunctionDirective(functionName: NameIdentifier, node: AstNode, codeAnnotations: Seq[AnnotationNode], location: WeaveLocationCapable, transformationStack: TransformationStack): FunctionDirective = {
    val functionValue: ValueNode[_] = node match {
      case fn: parser.ast.functions.FunctionNode => {
        transformationStack.pushNode(fn)
        val transformedFunction = if (isRecursiveCallExpression(functionName)) {
          // Tail-recursion elimination: self-calls in tail position are inlined as parameter rebinding.
          val recursiveCallBody = new TailRecFunctionBodyNode(transformTailRecursionElimination(functionName, fn.body, fn.params, transformationStack))
          recursiveCallBody._location = Some(location.location())
          // NOTE(review): codeAnnotations are not passed to applyInterceptor on this path — confirm this is intended.
          createFunctionNode(fn, Some(functionName), recursiveCallBody, transformationStack)
        } else {
          val plainFunction = transformFunctionNode(fn, Some(functionName), transformationStack)
          applyInterceptor(plainFunction, functionName, codeAnnotations, transformationStack)
        }
        transformationStack.dropNode()
        transformedFunction
      }
      case ofn: parser.ast.functions.OverloadedFunctionNode => {
        transformationStack.pushNode(ofn)
        val overloadedFunctionNode = transformOverloadedFunctionNode(ofn, functionName.name, transformationStack)
        transformationStack.dropNode()
        overloadedFunctionNode
      }
    }

    new FunctionDirective(transform(functionName, transformationStack), functionValue)
  }

  /** Wraps `functionValue` according to `codeAnnotations`; the concrete behavior is defined by implementors. */
  def applyInterceptor(functionValue: ValueNode[_], functionName: NameIdentifier, codeAnnotations: Seq[AnnotationNode], transformationStack: TransformationStack): ValueNode[_]

  /** Resolves `reference` to its name slot, if any. */
  def transformReference(reference: NameIdentifier): Option[NameSlot]

  /**
    * Transforms a function body, walking into the expressions that are in tail position
    * (if/else and unless branches, the right-hand side of `default`, do/using bodies and pattern-match branches)
    * so that annotated self-calls there become [[InlineTailRecFunctionCall]]s. Everything else is transformed normally.
    * The resulting node's location is set to `body`'s location.
    */
  private def transformTailRecursionElimination(functionName: NameIdentifier, body: AstNode, params: FunctionParameters, transformationStack: TransformationStack): ValueNode[Any] = {
    // Transforms an expression whose value is the function's result.
    // NOTE(review): the only external caller checks isRecursiveCallExpression(functionName) first, so the
    // `transform` fallback looks unreachable; it is kept to preserve the original semantics.
    def transformTailPosition(expression: AstNode): ValueNode[Any] = {
      if (isRecursiveCallExpression(functionName)) {
        transformRecursiveCall(functionName, expression, params, transformationStack)
      } else {
        transform(expression, transformationStack)
      }
    }

    val result = body match {
      case ifNode: IfNode => {
        val ifExpression = transformTailPosition(ifNode.ifExpr)
        val elseExpression = transformTailPosition(ifNode.elseExpr)
        new XIfNode(ifExpression, transform(ifNode.condition, transformationStack), elseExpression)
      }
      case defaultNode: DefaultNode => {
        // Only the fallback side is in tail position; lhs is an ordinary expression.
        val rhs = transformTailPosition(defaultNode.rhs)
        new XDefaultNode(transform(defaultNode.lhs, transformationStack), rhs)
      }
      case unless: UnlessNode => {
        val ifExpression = transformTailPosition(unless.ifExpr)
        val elseExpression = transformTailPosition(unless.elseExpr)
        new XUnlessNode(ifExpression, transform(unless.condition, transformationStack), elseExpression)
      }
      case doBlock: DoBlockNode => {
        new XDoBlockNode(transform(doBlock.header, transformationStack), transformTailRecursionElimination(functionName, doBlock.body, params, transformationStack))
      }
      case usingNode: UsingNode => {
        val varDirectives = usingNode.assignments.assignmentSeq.map((variable) => {
          val engineVariable: NameSlot = transform(variable.name, transformationStack)
          val engineAstNode: ValueNode[_] = transform(variable.value, transformationStack)
          new directives.VarDirective(engineVariable, engineAstNode, needsMaterialization(variable.name))
        })
        XUsingNode(varDirectives, transformTailRecursionElimination(functionName, usingNode.expr, params, transformationStack))
      }
      case patternMatcher: PatternMatcherNode => {
        val patterns = patternMatcher.patterns.patterns.map((pattern) => {
          transformPattern(functionName, pattern, params, transformationStack)
        })
        new XPatternMatcherNode(transform(patternMatcher.lhs, transformationStack), patterns.toArray)
      }
      case fcn: FunctionCallNode if fcn.isAnnotatedWith(classOf[TailRecFunctionCallAnnotation]) =>
        transformRecursiveCall(functionName, fcn, params, transformationStack)
      case _ => {
        transform(body, transformationStack)
      }
    }
    result._location = Some(body.location())
    result
  }

  /** Transforms one pattern-match case, treating its right-hand side as being in tail position. */
  private def transformPattern(functionName: NameIdentifier, body: PatternExpressionNode, params: FunctionParameters, transformationStack: TransformationStack) = {
    def tailBranch(node: AstNode): ValueNode[Any] = transformTailRecursionElimination(functionName, node, params, transformationStack)

    body match {
      case parser.ast.patterns.ExpressionPatternNode(pattern, name, function) =>
        transformExpressionPatternNode(pattern, name, function, tailBranch(function), transformationStack)
      case parser.ast.patterns.LiteralPatternNode(pattern, name, function) =>
        transformLiteralPatternNode(pattern, name, function, tailBranch(function), transformationStack)
      case parser.ast.patterns.RegexPatternNode(pattern, name, function) =>
        transformRegexPatternNode(pattern, name, function, tailBranch(function), transformationStack)
      case parser.ast.patterns.DefaultPatternNode(value, name) =>
        transformDefaultPatternNode(name, value, tailBranch(value), transformationStack)
      case tpn: parser.ast.patterns.TypePatternNode =>
        transformTypePatternNode(tpn, tailBranch(tpn.onMatch), transformationStack)
      case parser.ast.patterns.EmptyArrayPatternNode(function) =>
        transformEmptyArrayNode(function, tailBranch(function), transformationStack)
      case parser.ast.patterns.DeconstructArrayPatternNode(head, tail, function) =>
        transformDeconstructArrayNode(head, tail, function, tailBranch(function), transformationStack)
      case parser.ast.patterns.EmptyObjectPatternNode(function) =>
        transformEmptyObjectNode(function, tailBranch(function), transformationStack)
      case parser.ast.patterns.DeconstructObjectPatternNode(headKey, headValue, tail, function) =>
        transformDeconstructObjectNode(headKey, headValue, tail, function, tailBranch(function), transformationStack)
    }
  }

  /**
    * Transforms an expression in tail position.
    *
    * A call annotated with [[TailRecFunctionCallAnnotation]] becomes an [[InlineTailRecFunctionCall]] that rebinds
    * each parameter to the corresponding argument. Any other node is handed back to
    * [[transformTailRecursionElimination]].
    *
    * Missing arguments are filled from default values. If the last parameter has a default, the trailing
    * parameters' defaults are appended; otherwise the leading parameters' defaults are prepended.
    *
    * @throws CompilationExecutionException if, after filling defaults, the argument count differs from the parameter count
    */
  private def transformRecursiveCall(variable: NameIdentifier, node: AstNode, params: FunctionParameters, transformationStack: TransformationStack): ValueNode[Any] = {
    node match {
      case functionCallNode: FunctionCallNode if functionCallNode.isAnnotatedWith(classOf[TailRecFunctionCallAnnotation]) => {
        val explicitArgs = functionCallNode.args.args
        val paramsSeq = params.paramList
        val missing = paramsSeq.size - explicitArgs.size
        val args =
          if (missing <= 0) {
            explicitArgs
          } else if (paramsSeq.lastOption.exists(_.defaultValue.isDefined)) {
            explicitArgs ++ paramsSeq.takeRight(missing).flatMap(_.defaultValue)
          } else {
            paramsSeq.take(missing).flatMap(_.defaultValue) ++ explicitArgs
          }

        if (args.size != paramsSeq.size) {
          throw new CompilationExecutionException(functionCallNode.location(), s"Not enough arguments: `${args.size}`, function: `${variable.name}` requires: `${paramsSeq.size}`.")
        }

        // Sizes are equal here, so zip pairs every parameter with exactly one argument.
        val statements: Seq[XVarDirective] = paramsSeq.zip(args).map {
          case (param, arg) =>
            val paramName = param.variable
            new XVarDirective(transform(paramName, transformationStack), transform(arg, transformationStack), needsMaterialization(paramName))
        }

        // Give the inlined call its call-site location even when it sits inside an if/else/default branch.
        // Those branches are not covered by the location assignment in transformTailRecursionElimination.
        val tailCall: ValueNode[Any] = new InlineTailRecFunctionCall(statements.toArray)
        tailCall._location = Some(functionCallNode.location())
        tailCall
      }
      case _ => transformTailRecursionElimination(variable, node, params, transformationStack)
    }
  }
}
