package org.mule.weave.v2.interpreted.transform

import org.mule.weave.v2.core.functions.BinaryFunctionValue
import org.mule.weave.v2.core.functions.UnaryFunctionValue
import org.mule.weave.v2.grammar._
import org.mule.weave.v2.interpreted.marker.ConstantArgumentAnnotation
import org.mule.weave.v2.interpreted.marker.DoBlockNodeNonShadowedVariablesAnnotation
import org.mule.weave.v2.interpreted.marker.FunctionCallNameAnnotation
import org.mule.weave.v2.interpreted.marker.ParameterReferenceFunctionCallArgumentAnnotation
import org.mule.weave.v2.interpreted.marker.RequiresMaterializationAnnotation
import org.mule.weave.v2.interpreted.marker.StaticFunctionCallAnnotation
import org.mule.weave.v2.interpreted.marker.StaticRecFunctionCallAnnotation
import org.mule.weave.v2.interpreted.node.AndNode
import org.mule.weave.v2.interpreted.node.BinaryFunctionCallNode
import org.mule.weave.v2.interpreted.node.ChainedBinaryOpNode
import org.mule.weave.v2.interpreted.node.ChainedFunctionCallNode
import org.mule.weave.v2.interpreted.node.DefaultFunctionCallNode
import org.mule.weave.v2.interpreted.node.DefaultNode
import org.mule.weave.v2.interpreted.node.EmptyFunctionCallNode
import org.mule.weave.v2.interpreted.node.IfNode
import org.mule.weave.v2.interpreted.node.NameSlot
import org.mule.weave.v2.interpreted.node.OrNode
import org.mule.weave.v2.interpreted.node.TernaryFunctionCallNode
import org.mule.weave.v2.interpreted.node.UnaryFunctionCallNode
import org.mule.weave.v2.interpreted.node.UnlessNode
import org.mule.weave.v2.interpreted.node.UsingNode
import org.mule.weave.v2.interpreted.node.ValueNode
import org.mule.weave.v2.interpreted.node.executors.BinaryExecutor
import org.mule.weave.v2.interpreted.node.executors.BinaryFunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.BinaryOpExecutor
import org.mule.weave.v2.interpreted.node.executors.BinaryStaticOverloadedFunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.DefaultFunctionCallExecutor
import org.mule.weave.v2.interpreted.node.executors.EmptyExecutor
import org.mule.weave.v2.interpreted.node.executors.EmptyFunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.FunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.TernaryExecutor
import org.mule.weave.v2.interpreted.node.executors.TernaryFunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.TernaryOverloadedStaticExecutor
import org.mule.weave.v2.interpreted.node.executors.UnaryExecutor
import org.mule.weave.v2.interpreted.node.executors.UnaryFunctionExecutor
import org.mule.weave.v2.interpreted.node.executors.UnaryOpExecutor
import org.mule.weave.v2.interpreted.node.executors.UnaryOverloadedFunctionExecutor
import org.mule.weave.v2.interpreted.node.structure.DoBlockNode
import org.mule.weave.v2.interpreted.node.structure.InlineDoBlockNode
import org.mule.weave.v2.interpreted.node.structure.header.directives
import org.mule.weave.v2.interpreted.node.{ BinaryOpNode => XBinaryOpNode }
import org.mule.weave.v2.parser.annotation.DescendantsIncludeThisAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.UsingVariableAssignment
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.operators.UnaryOpNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.NumberNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.runtime.core.operator.conversion.AsFunctionValue
import org.mule.weave.v2.runtime.core.operator.conversion.MetadataAdditionFunctionValue
import org.mule.weave.v2.runtime.core.operator.conversion.MetadataInjectorFunctionValue
import org.mule.weave.v2.runtime.core.operator.equality.EqOperator
import org.mule.weave.v2.runtime.core.operator.equality.IsOperator
import org.mule.weave.v2.runtime.core.operator.equality.NotEqOperator
import org.mule.weave.v2.runtime.core.operator.equality.SimilarOperator
import org.mule.weave.v2.runtime.core.operator.logical.NotOperator
import org.mule.weave.v2.runtime.core.operator.math._
import org.mule.weave.v2.runtime.core.operator.relational.GreaterOrEqualThanOperator
import org.mule.weave.v2.runtime.core.operator.relational.GreaterThanOperator
import org.mule.weave.v2.runtime.core.operator.relational.LessOrEqualThanOperator
import org.mule.weave.v2.runtime.core.operator.relational.LessThanOperator
import org.mule.weave.v2.runtime.core.operator.selectors._

trait EngineFunctionTransformations extends AstTransformation with EngineVariableTransformations with EngineImportTransformations {
  /**
    * Builds the engine [[UnlessNode]] for an `unless` expression.
    *
    * Each of the three sub-expressions is transformed, in the order body, condition, else,
    * and the results are handed to the `UnlessNode` constructor in that same order.
    *
    * @param ifExpr    AST of the first branch expression
    * @param condition AST of the condition
    * @param elseExpr  AST of the else branch expression
    * @return the engine node wrapping the three transformed children
    */
  def transformUnlessNode(ifExpr: AstNode, condition: AstNode, elseExpr: AstNode): UnlessNode =
    new UnlessNode(transform(ifExpr), transform(condition), transform(elseExpr))

  /**
    * Builds the engine [[IfNode]] for an `if` expression.
    *
    * Each of the three sub-expressions is transformed, in the order body, condition, else,
    * and the results are handed to the `IfNode` constructor in that same order.
    *
    * @param expression AST of the first branch expression
    * @param condition  AST of the condition
    * @param elseExpr   AST of the else branch expression
    * @return the engine node wrapping the three transformed children
    */
  def transformIfNode(expression: AstNode, condition: AstNode, elseExpr: AstNode): IfNode =
    new IfNode(transform(expression), transform(condition), transform(elseExpr))

  /**
    * Builds the engine [[DefaultNode]] from its two operands.
    *
    * @param lhs AST of the left operand, transformed first
    * @param rhs AST of the right operand, transformed second
    * @return the engine node wrapping both transformed operands
    */
  def transformDefaultNode(lhs: AstNode, rhs: AstNode): DefaultNode =
    new DefaultNode(transform(lhs), transform(rhs))

  /**
    * Tells whether `node` is a function call that carries exactly two arguments.
    *
    * @param node the AST node to inspect
    * @return `true` only for a [[FunctionCallNode]] whose argument list has size 2
    */
  private def isBinaryFunction(node: AstNode): Boolean = node match {
    case call: FunctionCallNode => call.args.args.size == 2
    case _                      => false
  }

  /**
    * Lowers a parser-level function call into an executable engine node.
    *
    * Two shapes are produced:
    *  - When the call has two arguments and its first argument is itself a two-argument call,
    *    the whole left-leaning chain is flattened into one [[ChainedFunctionCallNode]].
    *    Executors and arguments are ordered from the innermost call to the outermost one.
    *  - Otherwise a single executor is created and the node class is chosen from the
    *    executor's type (empty / unary / binary / ternary / default).
    *
    * Every executor in a chain is created with the call node it actually belongs to, so
    * call-level annotations read by `createExecutor` come from that call and not from the
    * outermost call of the chain.
    *
    * @param fcn the function call to transform
    * @return the engine node that evaluates the call
    */
  def transformFunctionCallNode(fcn: FunctionCallNode): ValueNode[_] = {
    if (fcn.args.args.size == 2 && isBinaryFunction(fcn.args.args.head)) {
      // Optimization for chained binary functions to avoid stack overflows.
      // Each entry is (function reference, the call node it belongs to, that call's arguments).
      var chainedCalls = Seq[(AstNode, FunctionCallNode, Seq[AstNode])]()
      var flattenedArgs = Seq[AstNode]()
      var currentCall: (AstNode, FunctionCallNode, Seq[AstNode]) = (fcn.function, fcn, fcn.args.args)
      var descend: Boolean = true
      while (descend) {
        val callArgs = currentCall._3
        val left = callArgs.head
        // Prepending while walking inwards leaves both sequences ordered innermost-first.
        flattenedArgs = callArgs(1) +: flattenedArgs
        chainedCalls = currentCall +: chainedCalls
        left match {
          case inner @ FunctionCallNode(innerFun, innerArgs, _, _) if innerArgs.args.size == 2 =>
            currentCall = (innerFun, inner, innerArgs.args)
          case _ =>
            // End of the chain: the innermost left operand becomes the first argument.
            flattenedArgs = left +: flattenedArgs
            descend = false
        }
      }

      val executors: Array[FunctionExecutor] = chainedCalls.map {
        case (functionRef, callNode, callArgs) => createExecutor(functionRef, callArgs, callNode)
      }.toArray
      new ChainedFunctionCallNode(executors, transformSeq(flattenedArgs).toArray)
    } else {
      val functionExecutor: FunctionExecutor = createExecutor(fcn.function, fcn.args.args, fcn)
      functionExecutor match {
        case executor: EmptyExecutor   => new EmptyFunctionCallNode(executor)
        case executor: UnaryExecutor   => new UnaryFunctionCallNode(executor, transform(fcn.args.args.head))
        case executor: BinaryExecutor  => new BinaryFunctionCallNode(executor, transform(fcn.args.args.head), transform(fcn.args.args(1)))
        case executor: TernaryExecutor => new TernaryFunctionCallNode(executor, transform(fcn.args.args.head), transform(fcn.args.args(1)), transform(fcn.args.args(2)))
        case _                         => new DefaultFunctionCallNode(functionExecutor, transformSeq(fcn.args.args).toArray)
      }
    }
  }

  /**
    * Picks the executor implementation for a function call.
    *
    * For 0 to 3 arguments, when the function reference is a static reference of the matching
    * arity, a specialised executor is built; overloaded references get the overloaded variant.
    * Any other call falls back to [[DefaultFunctionCallExecutor]].
    *
    * The per-argument flag passed to the overloaded executors is "argument is annotated as
    * constant". For the non-overloaded executors the flag is additionally `true` when the call
    * is annotated as a static recursive call and the argument is annotated as a parameter
    * reference.
    *
    * @param functionRefExpr  expression that evaluates to the function being called
    * @param args             the call arguments
    * @param functionCallNode the call node, read for [[StaticRecFunctionCallAnnotation]]
    * @return the executor to use for this call
    */
  private def createExecutor(functionRefExpr: AstNode, args: Seq[AstNode], functionCallNode: FunctionCallNode): FunctionExecutor = {
    val recursiveCall = functionCallNode.isAnnotatedWith(classOf[StaticRecFunctionCallAnnotation])
    // Falls back to the anonymous-function name when the reference carries no name annotation.
    val functionName = functionRefExpr
      .annotation(classOf[FunctionCallNameAnnotation])
      .map(_.functionName)
      .getOrElse(AstNodeHelper.ANONYMOUS_FUNCTION)

    def isConstant(arg: AstNode): Boolean =
      arg.isAnnotatedWith(classOf[ConstantArgumentAnnotation])

    def isRecursiveParameterRef(arg: AstNode): Boolean =
      recursiveCall && arg.isAnnotatedWith(classOf[ParameterReferenceFunctionCallArgumentAnnotation])

    val argCount = args.length
    if (argCount == 0 && isStaticFunctionRefWithArity(functionRefExpr, 0)) {
      new EmptyFunctionExecutor(transform(functionRefExpr), functionName, true, functionRefExpr.location())
    } else if (argCount == 1 && isStaticFunctionRefWithArity(functionRefExpr, 1)) {
      val firstConstant = isConstant(args.head)
      if (isOverloadedFunctionRef(functionRefExpr)) {
        new UnaryOverloadedFunctionExecutor(transform(functionRefExpr), functionName, firstConstant, true, functionRefExpr.location())
      } else {
        val firstConstantType = firstConstant || isRecursiveParameterRef(args.head)
        new UnaryFunctionExecutor(transform(functionRefExpr), functionName, firstConstantType, true, functionRefExpr.location())
      }
    } else if (argCount == 2 && isStaticFunctionRefWithArity(functionRefExpr, 2)) {
      val firstConstant = isConstant(args.head)
      val secondConstant = isConstant(args(1))
      if (isOverloadedFunctionRef(functionRefExpr)) {
        new BinaryStaticOverloadedFunctionExecutor(transform(functionRefExpr), functionName, firstConstant, secondConstant, true, functionRefExpr.location())
      } else {
        val firstConstantType = firstConstant || isRecursiveParameterRef(args.head)
        val secondConstantType = secondConstant || isRecursiveParameterRef(args(1))
        new BinaryFunctionExecutor(transform(functionRefExpr), functionName, firstConstantType, secondConstantType, true, functionRefExpr.location())
      }
    } else if (argCount == 3 && isStaticFunctionRefWithArity(functionRefExpr, 3)) {
      val firstConstant = isConstant(args.head)
      val secondConstant = isConstant(args(1))
      val thirdConstant = isConstant(args(2))
      if (isOverloadedFunctionRef(functionRefExpr)) {
        new TernaryOverloadedStaticExecutor(transform(functionRefExpr), functionName, firstConstant, secondConstant, thirdConstant, true, functionRefExpr.location())
      } else {
        val firstConstantType = firstConstant || isRecursiveParameterRef(args.head)
        val secondConstantType = secondConstant || isRecursiveParameterRef(args(1))
        val thirdConstantType = thirdConstant || isRecursiveParameterRef(args(2))
        new TernaryFunctionExecutor(transform(functionRefExpr), functionName, firstConstantType, secondConstantType, thirdConstantType, true, functionRefExpr.location())
      }
    } else {
      new DefaultFunctionCallExecutor(transform(functionRefExpr), true, functionRefExpr.location())
    }
  }

  /**
    * Builds the engine [[UsingNode]] for a `using` expression.
    *
    * Every assignment becomes a `VarDirective` made of the transformed name slot, the
    * transformed value and the materialization flag computed from the variable name.
    *
    * @param params the variable assignments declared by the `using`
    * @param expr   the body evaluated with those variables in scope
    * @return the engine node holding the directives and the transformed body
    */
  def transformUsing(params: Seq[UsingVariableAssignment], expr: AstNode): UsingNode = {
    val declarations = for (assignment <- params) yield {
      val slot: NameSlot = transform(assignment.name)
      val valueNode: ValueNode[_] = transform(assignment.value)
      new directives.VarDirective(slot, valueNode, needsMaterialization(assignment.name))
    }
    new UsingNode(declarations, transform(expr))
  }

  /**
    * Lowers a `do` block.
    *
    *  - With no header directives the block collapses to its transformed body.
    *  - When the block is annotated with [[DoBlockNodeNonShadowedVariablesAnnotation]] an
    *    [[InlineDoBlockNode]] is produced.
    *  - Otherwise a regular [[DoBlockNode]] is produced.
    *
    * @param doBlock the parser-level do block
    * @return the engine node for the block
    */
  def transformDoBlock(doBlock: org.mule.weave.v2.parser.ast.functions.DoBlockNode): ValueNode[_] = {
    if (doBlock.header.directives.isEmpty) {
      transform(doBlock.body)
    } else if (doBlock.isAnnotatedWith(classOf[DoBlockNodeNonShadowedVariablesAnnotation])) {
      new InlineDoBlockNode(transform(doBlock.header), transform(doBlock.body))
    } else {
      new DoBlockNode(transform(doBlock.header), transform(doBlock.body))
    }
  }

  /**
    * Lowers a unary operator application into the engine's unary operator node.
    *
    * The operator id selects the operator implementations, the operand is transformed, and
    * both are wrapped in a [[UnaryOpExecutor]] that also records whether the operand is
    * annotated as a constant argument.
    *
    * The match has no default case, so an operator id that is not listed raises a
    * `scala.MatchError`.
    *
    * @param uon the parser-level unary operator node
    * @return the engine node applying the operator to the transformed operand
    */
  def transformUnaryOpNode(uon: UnaryOpNode): ValueNode[_] = {
    val opLocation = uon.location()
    val operators: Array[_ <: UnaryFunctionValue] = uon.opId match {
      case MinusOpId =>
        Array(new NumberMinusOperator(opLocation))
      case AllSchemaSelectorOpId =>
        Array(new AllSchemaSelectorOperator(opLocation))
      case DescendantsSelectorOpId =>
        val includeThis = uon.annotation(classOf[DescendantsIncludeThisAnnotation]).isDefined
        DescendantsSelectorOperator.value(includeThis, opLocation)
      case NamespaceSelectorOpId =>
        Array(new NamespaceSelectorOperator(opLocation))
      case NotOpId =>
        Array(new NotOperator(opLocation))
      case AllAttributesSelectorOpId =>
        Array(new AllAttributesSelectorOperator(opLocation))
    }
    val operand: ValueNode[_] = transform[ValueNode[_]](uon.rhs)
    val operandIsConstant = uon.rhs.isAnnotatedWith(classOf[ConstantArgumentAnnotation])
    val executor = new UnaryOpExecutor(operators, uon.opId.name, uon.location(), operandIsConstant)
    new org.mule.weave.v2.interpreted.node.UnaryOpNode(executor, operand)
  }

  /**
    * Lowers a binary operator application into an engine node.
    *
    * Steps:
    *  1. The operator id selects an array of operator implementations, all created with the
    *     location of `bon`.
    *  2. Both operands are transformed and a [[BinaryOpExecutor]] is created, recording for
    *     each operand whether it is annotated with [[ConstantArgumentAnnotation]].
    *  3. If the transformed left operand is itself a binary operator node, or an already
    *     chained node, the result is flattened into a single [[ChainedBinaryOpNode]];
    *     otherwise a plain binary operator node is returned.
    *
    * The operator match has no default case, so an operator id that is not listed raises a
    * `scala.MatchError`.
    *
    * @param bon the parser-level binary operator node
    * @return the engine node evaluating the operator
    */
  def transformBinaryOpNode(bon: BinaryOpNode): ValueNode[_] = {
    val location = bon.location()
    // Candidate implementations per operator id. The array order is kept exactly as written;
    // NOTE(review): presumably this is the order in which BinaryOpExecutor tries them — confirm.
    val typeDispatchers: Array[BinaryFunctionValue] =
      bon.binaryOpId match {
        case AttributeValueSelectorOpId => Array(new AttributeValueSelectorOperator(location, bon.rhs.isInstanceOf[NameNode]))
        case MultiValueSelectorOpId     => Array(new KeyMultiValueSelectorOperator(location), new ArrayMultiValueSelectorOperator(location), new NullMultiValueSelectorOperator(location))
        case MultiAttributeValueSelectorOpId =>
          Array(
            new ArrayAttributeMultiValueSelectorOperator(location),
            new ObjectAttributeMultiValueSelectorOperator(location),
            new NullAttributeMultiValueSelectorOperator(location))
        case GreaterOrEqualThanOpId => Array(new GreaterOrEqualThanOperator(location))
        case AdditionOpId =>
          Array(
            new NumberAdditionOperator(location),
            new LocalDateAdditionPeriodOperator(location),
            new PeriodAdditionLocalDateOperator(location),
            new LocalDateTimeAdditionPeriodOperator(location),
            new PeriodAdditionLocalDateTimeOperator(location),
            new DateTimeAdditionPeriodOperator(location),
            new TimeAdditionPeriodOperator(location),
            new PeriodAdditionDateTimeOperator(location),
            new PeriodAdditionTimeOperator(location),
            new ArrayAdditionAnyOperator(location),
            new LocalTimeAdditionPeriodOperator(location),
            new PeriodAdditionLocalTimeOperator(location))
        case IsOpId => Array(new IsOperator(location))
        case SubtractionOpId =>
          Array(
            new NumberSubtractionNumberOperator(location),
            new ObjectSubtractionKeyOperator(location),
            new ObjectSubtractionStringOperator(location),
            new ArraySubtractionOperator(location),
            new LocalDateSubtractionPeriodOperator(location),
            new PeriodSubtractionLocalDateOperator(location),
            new LocalDateSubtractLocalDateOperator(location),
            new LocalDateTimeSubtractionPeriodOperator(location),
            new PeriodSubtractionLocalDateTimeOperator(location),
            new LocalDateTimeSubtractLocalDateTimeOperator(location),
            new DateTimeSubtractionPeriodOperator(location),
            new PeriodSubtractionDateTimeOperator(location),
            new DateTimeSubtractDateTimeOperator(location),
            new TimeSubtractionPeriodOperator(location),
            new PeriodSubtractionTimeOperator(location),
            new TimeSubtractTimeOperator(location),
            new LocalTimeSubtractionPeriodOperator(location),
            new PeriodSubtractionLocalTimeOperator(location),
            new LocalTimeSubtractLocalTimeOperator(location))
        case DivisionOpId       => Array(new NumberDivisionOperator(location))
        case MultiplicationOpId => Array(new NumberMultiplicationOperator(location))
        case RightShiftOpId     => Array(new DateTimeRightShiftOperator(location), new PrependArrayOperator(location))
        case LeftShiftOpId      => Array(new ArrayAdditionAnyOperator(location))
        case DynamicSelectorOpId =>
          // A numeric literal on the right gets only the index selectors; any other
          // right-hand side gets the wider list below.
          if (bon.rhs.isInstanceOf[NumberNode]) {
            Array(
              new ArrayIndexSelectorOperator(location),
              new ObjectIndexSelectorOperator(location),
              new StringIndexSelectorOperator(location),
              new BinaryIndexSelectorOperator(location),
              new NullIndexSelectorOperator(location))
          } else {
            Array(
              new ObjectStringValueSelectorOperator(location, bon.rhs.isInstanceOf[StringNode]),
              new ObjectNameValueSelectorOperator(location, bon.rhs.isInstanceOf[NameNode]),
              new ObjectIndexSelectorOperator(location),
              new ArrayIndexSelectorOperator(location),
              new StringIndexSelectorOperator(location),
              new BinaryIndexSelectorOperator(location),
              new ArrayStringValueSelectorOperator(location, bon.rhs.isInstanceOf[StringNode]),
              new ArrayNameValueSelectorOperator(location, bon.rhs.isInstanceOf[NameNode]),
              new LocalDateValueSelectorOperator(location),
              new DateTimeValueSelectorOperator(location),
              new LocalDateTimeValueSelectorOperator(location),
              new TimeValueSelectorOperator(location),
              new LocalTimeValueSelectorOperator(location),
              new PeriodValueSelectorOperator(location),
              new ArrayRangeSelectorOperator(location),
              new BinaryRangeSelectorOperator(location),
              new StringRangeSelectorOperator(location),
              new NullIndexSelectorOperator(location),
              new NullNameValueSelectorOperator(location))
          }
        case EqOpId                  => Array(new EqOperator(location))
        case NotEqOpId               => Array(new NotEqOperator(location))
        case GreaterThanOpId         => Array(new GreaterThanOperator(location))
        case SchemaValueSelectorOpId => Array(new SchemaValueSelectorOperator(location))
        case ValueSelectorOpId =>
          Array(
            new ObjectNameValueSelectorOperator(location, bon.rhs.isInstanceOf[NameNode]),
            new ObjectStringValueSelectorOperator(location, bon.rhs.isInstanceOf[StringNode]),
            new ArrayNameValueSelectorOperator(location, bon.rhs.isInstanceOf[NameNode]),
            new ArrayStringValueSelectorOperator(location, bon.rhs.isInstanceOf[StringNode]),
            new LocalDateValueSelectorOperator(location),
            new DateTimeValueSelectorOperator(location),
            new LocalDateTimeValueSelectorOperator(location),
            new TimeValueSelectorOperator(location),
            new LocalTimeValueSelectorOperator(location),
            new PeriodValueSelectorOperator(location),
            new NullNameValueSelectorOperator(location))
        case ObjectKeyValueSelectorOpId => ObjectKeyValueSelector.value(location)
        case FilterSelectorOpId         => FilterSelectorOperator.value(location)
        case SimilarOpId                => Array(new SimilarOperator(location))
        case RangeSelectorOpId          => Array(new ArrayRangeSelectorOperator(location), new StringRangeSelectorOperator(location), new NullRangeSelectorOperator(location))
        case LessThanOpId               => Array(new LessThanOperator(location))
        case LessOrEqualThanOpId        => Array(new LessOrEqualThanOperator(location))
        case AsOpId =>
          val requiresMaterialize = bon.rhs.isAnnotatedWith(classOf[RequiresMaterializationAnnotation])
          Array(new AsFunctionValue(location, requiresMaterialize))
        case MetadataInjectorOpId => Array(new MetadataInjectorFunctionValue(location))
        case MetadataAdditionOpId => Array(new MetadataAdditionFunctionValue(location))
      }

    // The left operand is transformed before the right one.
    val leftTransformation: ValueNode[_] = transform(bon.lhs)
    val isLeftArgConstant = bon.lhs.isAnnotatedWith(classOf[ConstantArgumentAnnotation])
    val rightTransformation: ValueNode[_] = transform(bon.rhs)
    val isRightArgConstant = bon.rhs.isAnnotatedWith(classOf[ConstantArgumentAnnotation])
    val executor = new BinaryOpExecutor(typeDispatchers, bon.binaryOpId.name, bon.location(), isLeftArgConstant, isRightArgConstant)
    // Flatten left-nested operators into one chained node instead of nesting binary nodes.
    // NOTE(review): presumably for the same stack-depth reason as the chained function calls — confirm.
    leftTransformation match {
      case xbon: XBinaryOpNode => {
        // Two operators in a row: start a chain with both executors and the three operands.
        new ChainedBinaryOpNode(Array(xbon.binaryOpExecutor, executor), Array(xbon.lhs, xbon.rhs, rightTransformation))
      }
      case cen: ChainedBinaryOpNode => {
        // Existing chain: append this operator and its right operand.
        new ChainedBinaryOpNode(cen.operations :+ executor, cen.nodes :+ rightTransformation)
      }
      case _ => {
        new XBinaryOpNode(executor, leftTransformation, rightTransformation)
      }
    }

  }

  /**
    * Builds the engine [[AndNode]] from its two operands, transforming `lhs` first and `rhs` second.
    *
    * @param lhs AST of the left operand
    * @param rhs AST of the right operand
    * @return the engine node wrapping both transformed operands
    */
  def transformAndNode(lhs: AstNode, rhs: AstNode): AndNode = new AndNode(transform(lhs), transform(rhs))

  /**
    * Builds the engine [[OrNode]] from its two operands, transforming `lhs` first and `rhs` second.
    *
    * @param lhs AST of the left operand
    * @param rhs AST of the right operand
    * @return the engine node wrapping both transformed operands
    */
  def transformOrNode(lhs: AstNode, rhs: AstNode): OrNode = new OrNode(transform(lhs), transform(rhs))

  /**
    * Tells whether `node` is a variable reference annotated as a static function call whose
    * annotation reports a function of the requested arity.
    *
    * @param node  the function reference expression to inspect
    * @param arity the number of arguments the call supplies
    * @return `false` for anything that is not a [[VariableReferenceNode]] or that lacks a
    *         [[StaticFunctionCallAnnotation]]; otherwise the annotation's answer
    */
  private def isStaticFunctionRefWithArity(node: AstNode, arity: Int): Boolean = node match {
    case reference: VariableReferenceNode =>
      reference.annotation(classOf[StaticFunctionCallAnnotation]).exists(_.isFunctionOfArity(arity))
    case _ =>
      false
  }

  /**
    * Tells whether `node` is a variable reference whose static function call annotation is
    * flagged as overloaded.
    *
    * @param node the function reference expression to inspect
    * @return `false` for anything that is not a [[VariableReferenceNode]] or that lacks a
    *         [[StaticFunctionCallAnnotation]]; otherwise the annotation's `isOverloaded` flag
    */
  private def isOverloadedFunctionRef(node: AstNode): Boolean = node match {
    case reference: VariableReferenceNode =>
      reference.annotation(classOf[StaticFunctionCallAnnotation]).exists(_.isOverloaded)
    case _ =>
      false
  }
}
