package org.mule.weave.v2.api.expression

import org.mule.weave.v2.api.tooling.expression.Expression
import org.mule.weave.v2.codegen.CodeGenerator
import org.mule.weave.v2.grammar.AttributeValueSelectorOpId
import org.mule.weave.v2.grammar.BinaryOpIdentifier
import org.mule.weave.v2.grammar.MultiAttributeValueSelectorOpId
import org.mule.weave.v2.grammar.MultiValueSelectorOpId
import org.mule.weave.v2.grammar.SchemaValueSelectorOpId
import org.mule.weave.v2.grammar.ValueSelectorOpId
import org.mule.weave.v2.parser.ErrorAstNode
import org.mule.weave.v2.parser.annotation.InfixNotationFunctionCallAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.LiteralValueAstNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallParametersNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.selectors.NullSafeNode
import org.mule.weave.v2.parser.ast.structure.ArrayNode
import org.mule.weave.v2.parser.ast.structure.KeyNode
import org.mule.weave.v2.parser.ast.structure.KeyValuePairNode
import org.mule.weave.v2.parser.ast.structure.NameNode
import org.mule.weave.v2.parser.ast.structure.ObjectNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode

import java.util.Optional
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

/**
  * Reads and rewrites the output expression of a DataWeave script AST.
  *
  * Output locations are addressed by arrays of path segments. An empty segment marks the root of
  * the output. The mapping paths produced by [[getScriptMappingExpression]] have the form `.a.b`
  * (the root followed by the keys); presumably callers split them on `.` to obtain the segments —
  * confirm against them. Arrays, `++` concatenations and `map` calls do not consume a segment:
  * operations are applied to their elements, operands or lambda body instead.
  *
  * NOTE: several operations modify the given AST in place through `AstNode.update` (and return the
  * updated node itself), so callers must not expect the input tree to be left untouched.
  */
object ExpressionMapper {

  /**
    * Removes whatever the output defines at `pathSegments`.
    *
    * @param node         root of the output expression
    * @param pathSegments path to remove; empty segments are skipped, and an empty path (or one
    *                     made only of empty segments) removes everything
    * @return the remaining output; when nothing is left, an empty array if `node` is an array
    *         literal and an empty object otherwise
    */
  def deleteOutputAtPath(node: AstNode, pathSegments: Array[String]): AstNode = {
    doDeleteAtPath(node, pathSegments).getOrElse {
      // Everything was deleted: keep a well-formed, empty output of the same kind
      node match {
        case _: ArrayNode => ArrayNode(Seq())
        case _            => ObjectNode(Seq())
      }
    }
  }

  /**
    * Recursive worker of [[deleteOutputAtPath]].
    *
    * Nodes whose structure is not handled here (variable references, arbitrary function calls,
    * literals, ...) cannot be edited, so they are kept unchanged instead of failing.
    *
    * @return `None` when nothing remains of `astNode` after the deletion, otherwise the remaining
    *         node (possibly `astNode` itself, updated in place)
    */
  private def doDeleteAtPath(astNode: AstNode, pathSegments: Array[String]): Option[AstNode] = {
    if (pathSegments.isEmpty) {
      // The path addresses this very node: delete all of it
      None
    } else {
      astNode match {
        case fcn @ FunctionCallNode(VariableReferenceNode(NameIdentifier("++", _), _), argsNode, _, _) =>
          // A concatenation does not consume a segment: delete from its operands
          argsNode.args.flatMap(doDeleteAtPath(_, pathSegments)) match {
            case updatedArgs @ Seq(_, _) =>
              // Both operands still have content: keep the concatenation
              fcn.update(argsNode, FunctionCallParametersNode(updatedArgs))
              Some(fcn)
            case Seq(remainingOperand) =>
              // One operand became empty, so the concatenation is no longer needed
              Some(remainingOperand)
            case _ =>
              None
          }

        case ArrayNode(elems, codeAnnotations) =>
          // Arrays do not consume a segment either: delete from every element
          val remainingElems: Seq[AstNode] = elems.flatMap(doDeleteAtPath(_, pathSegments))
          Some(ArrayNode(remainingElems, codeAnnotations))

        case _ if pathSegments.head.isEmpty =>
          // Empty (root) segment: skip it; if it was the last one the whole node is deleted
          doDeleteAtPath(astNode, pathSegments.tail)

        case mapCall @ FunctionCallNode(VariableReferenceNode(NameIdentifier("map", _), _), FunctionCallParametersNode(Seq(_, fn @ FunctionNode(_, body, _, _))), _, _) =>
          // `map` does not consume a segment: delete inside the body of its lambda
          doDeleteAtPath(body, pathSegments) match {
            case Some(updatedBody) =>
              fn.update(body, updatedBody)
              Some(mapCall)
            case None =>
              None
          }

        case ObjectNode(elems, codeAnnotations) if pathSegments.length == 1 =>
          // Last segment: drop the entries with that key
          val leafSegment = pathSegments.head
          val remainingElems = elems.filterNot {
            case KeyValuePairNode(KeyNode(StringNode(name, _), _, _, _), _, _) => name == leafSegment
            case _ => false
          }
          if (remainingElems.isEmpty) None else Some(ObjectNode(remainingElems, codeAnnotations.map(_.cloneAst())))

        case ObjectNode(elems, codeAnnotations) =>
          // Intermediate segment: descend into the values of the entries with that key
          val segment = pathSegments.head
          val remainingElems: Seq[AstNode] = elems.flatMap {
            case kvpn @ KeyValuePairNode(KeyNode(StringNode(name, _), _, _, _), value, _) if name == segment =>
              doDeleteAtPath(value, pathSegments.tail) match {
                case Some(newValue) =>
                  kvpn.update(value, newValue)
                  Some(kvpn)
                case None =>
                  // The value became empty, so the whole entry goes away
                  None
              }
            case other => Some(other)
          }
          if (remainingElems.isEmpty) None else Some(ObjectNode(remainingElems, codeAnnotations))

        case other =>
          // Nothing we know how to delete from: keep the node unchanged
          Some(other)
      }
    }
  }

  /**
    * Builds the mapping [[Expression]] tree describing how the output `astNode` is produced.
    *
    * @param astNode   root of the output expression
    * @param inputVars names of the script inputs; references to them are recorded as input paths
    */
  def getScriptMappingExpression(astNode: AstNode, inputVars: List[String]): Expression = {
    doGetScriptMappingExpression(astNode, "", inputVars, Map.empty)
  }

  /**
    * Returns the node that the output key `name` currently maps to below `node`.
    *
    * Object literals are searched for the first entry with key `name`; when there is none an empty
    * object is returned so that a fresh level can be scaffolded. `map` calls are traversed into the
    * body of their lambda. Any other node is returned as is.
    */
  @tailrec
  private def getNextNode(node: AstNode, name: String): AstNode = node match {
    case ObjectNode(children, _) =>
      children
        .collectFirst {
          case KeyValuePairNode(KeyNode(StringNode(keyName, _), _, _, _), value, _) if keyName == name => value
        }
        .getOrElse(ObjectNode(Seq()))
    case FunctionCallNode(VariableReferenceNode(NameIdentifier("map", _), _), FunctionCallParametersNode(Seq(_, FunctionNode(_, body, _, _))), _, _) =>
      getNextNode(body, name)
    case _ => node
  }

  /**
    * Makes room in the output `node` for `expressionAst` at `outputSegments`, creating the
    * intermediate object levels that do not exist yet.
    *
    * @param outputSegments target path; an empty segment marks the root of the output
    * @param node           output node the path is resolved against
    * @param expressionAst  expression to place at the end of the path
    * @return the updated output. When the path starts with a key (non-empty) segment the result is
    *         usually the key/value pair for that key, which still has to be merged into the
    *         enclosing level.
    */
  def scaffoldOutputForPath(outputSegments: Array[String], node: AstNode, expressionAst: AstNode): AstNode = {
    if (outputSegments.isEmpty) {
      expressionAst
    } else {
      val currentSegment = outputSegments.head
      node match {
        case fcn @ FunctionCallNode(VariableReferenceNode(NameIdentifier("++", _), _), argsNode, _, _) =>
          // A concatenation does not consume a segment: scaffold inside every operand
          val newArgs = argsNode.args.map(scaffoldIntoChild(outputSegments, _, expressionAst))
          fcn.update(argsNode, FunctionCallParametersNode(newArgs))
          fcn

        case ArrayNode(elems, codeAnnotations) =>
          // Arrays do not consume a segment either: scaffold inside every element
          ArrayNode(elems.map(scaffoldIntoChild(outputSegments, _, expressionAst)), codeAnnotations)

        case _ if currentSegment.isEmpty =>
          // Empty segment means we are at the root of the output
          if (outputSegments.tail.isEmpty) {
            expressionAst
          } else {
            val newElem = scaffoldOutputForPath(outputSegments.tail, node, expressionAst)
            getOrCreateCurrentLevel(newElem, node)
          }

        case _ if outputSegments.length == 1 =>
          KeyValuePairNode(KeyNode(StringNode(currentSegment)), expressionAst)

        case _ =>
          val nextNode = getNextNode(node, currentSegment)
          val newExpr = scaffoldOutputForPath(outputSegments.tail, nextNode, expressionAst)
          KeyValuePairNode(KeyNode(StringNode(currentSegment)), getOrCreateCurrentLevel(newExpr, nextNode))
      }
    }
  }

  /**
    * Scaffolds `outputSegments` (non-empty) inside one operand of a concatenation or one element
    * of an array and returns the updated child.
    *
    * With a leading empty (root) segment the recursive call already returns the whole updated
    * child. Otherwise it returns the entry for the first key, which is merged into the child here.
    */
  private def scaffoldIntoChild(outputSegments: Array[String], child: AstNode, expressionAst: AstNode): AstNode = {
    val scaffolded = scaffoldOutputForPath(outputSegments, child, expressionAst)
    if (outputSegments.head.isEmpty) scaffolded else getOrCreateCurrentLevel(scaffolded, child)
  }

  /**
    * Merges `newElem` (usually a key/value pair) into the output level `node`.
    *
    *  - object literal: the first entry with the same string key is replaced in place, keeping the
    *    key order; when there is none, `newElem` is appended as the last entry
    *  - `map` call: `newElem` is merged into the body of its lambda (updated in place)
    *  - `++` call or array already rebuilt as `newElem`: `newElem` is returned
    *  - anything else: `node ++ { newElem }`
    */
  private def getOrCreateCurrentLevel(newElem: AstNode, node: AstNode): AstNode = {
    node match {
      case ObjectNode(elems, codeAnnotations) =>
        val newKey = newElem match {
          case KeyValuePairNode(KeyNode(StringNode(name, _), _, _, _), _, _) => name
          case _ => ""
        }
        // Only entries with the same string key are the "same" entry; anything else
        // (e.g. dynamic elements) must be preserved
        val existingIndex = elems.indexWhere {
          case KeyValuePairNode(KeyNode(StringNode(keyName, _), _, _, _), _, _) => keyName == newKey
          case _ => false
        }
        val newElems: Seq[AstNode] =
          if (existingIndex >= 0) elems.updated(existingIndex, newElem)
          else elems :+ newElem
        ObjectNode(newElems, codeAnnotations)

      case fcn @ FunctionCallNode(VariableReferenceNode(NameIdentifier("map", _), _), FunctionCallParametersNode(Seq(_, fn @ FunctionNode(_, body, _, _))), _, _) =>
        fn.update(body, getOrCreateCurrentLevel(newElem, body))
        fcn

      case FunctionCallNode(VariableReferenceNode(NameIdentifier("++", _), _), _, _, _) if newElem.isInstanceOf[FunctionCallNode] =>
        newElem

      case _: ArrayNode if newElem.isInstanceOf[ArrayNode] =>
        newElem

      case _ =>
        /**
          * If target is not an object then we perform the update by concatenating a new one to
          * the already existing node
          */
        val newNode = ObjectNode(Seq(newElem))
        FunctionCallNode(
          VariableReferenceNode(NameIdentifier("++"), Seq()),
          FunctionCallParametersNode(Seq(node, newNode))).annotate(InfixNotationFunctionCallAnnotation())
    }
  }

  /**
    * Recursive worker of [[getScriptMappingExpression]].
    *
    * @param astNode                  output node to describe
    * @param path                     dot-separated output path of `astNode`; every key/value pair
    *                                 appends `.` followed by the generated code of its key
    * @param inputVars                names of the script inputs
    * @param boundedInputsTranslation input path bound to each `map` item parameter in scope
    */
  private def doGetScriptMappingExpression(
    astNode: AstNode,
    path: String,
    inputVars: List[String],
    boundedInputsTranslation: Map[String, String]): Expression = {
    val location = astNode.location()

    // Describes a child that shares the current path and parameter bindings
    def mapChild(child: AstNode): Expression =
      doGetScriptMappingExpression(child, path, inputVars, boundedInputsTranslation)

    astNode match {
      case ObjectNode(elements, _) =>
        val children = elements.map(mapChild)
        DWObjectExpression(children.toArray, path, location)

      case KeyValuePairNode(key @ KeyNode(keyName, _, _, _), value, _) =>
        //TODO: Should this be only for KeyNode(StringNode) case?
        val newPath = s"$path.${CodeGenerator.generate(keyName)}"
        val keyExpr = doGetScriptMappingExpression(key, newPath, inputVars, boundedInputsTranslation)
        val valueExpr = doGetScriptMappingExpression(value, newPath, inputVars, boundedInputsTranslation)
        DWAssignmentExpression(keyExpr, valueExpr, newPath, location)

      case _: KeyNode => DWLiteralExpression(path)

      case bon @ BinaryOpNode(opNode, lhs, rhs, _) if isSelectorOpId(opNode) =>
        val sourceExpr = mapChild(lhs)
        val attrExpr = mapChild(rhs)
        val inputPath = toJavaOptional(collectExpressionInputPath(bon, inputVars, boundedInputsTranslation).headOption)
        DWValueSelectionExpression(sourceExpr, attrExpr, path, inputPath, location)

      case FunctionCallNode(
        VariableReferenceNode(NameIdentifier("map", _), _),
        FunctionCallParametersNode(Seq(collection, FunctionNode(FunctionParameters(Seq(item, index)), body, _, _))),
        _, _
        ) =>
        val collectionExpr = mapChild(collection)
        // Inside the lambda, the item parameter is translated to the collection's input path
        val inputTranslation = collectExpressionInputPath(collection, inputVars, boundedInputsTranslation).headOption match {
          case Some(collectionPath) => boundedInputsTranslation + (item.nameIdentifier.name -> collectionPath)
          case None                 => boundedInputsTranslation
        }
        val bodyExpr = doGetScriptMappingExpression(body, path, inputVars, inputTranslation)
        DWMappingExpression(
          collectionExpr,
          item.nameIdentifier.name,
          index.nameIdentifier.name,
          bodyExpr,
          path,
          location)

      case FunctionCallNode(func, args, _, _) =>
        val functionName = func match {
          case VariableReferenceNode(NameIdentifier(name, _), _) => name
          case _ => "Anonymous"
        }
        val argExpressions = args.args.map(mapChild)
        DWFunctionCallExpression(argExpressions.toArray, functionName, path, location)

      case NullSafeNode(node, _) => mapChild(node)

      case VariableReferenceNode(name, _) =>
        val inputPath = toJavaOptional(resolveInputPath(name.name, inputVars, boundedInputsTranslation))
        DWLiteralExpression(path, location, inputPath)

      case en: ErrorAstNode => DWUndefinedExpression(en.message().getMessage, path, location)

      // Error nodes were already handled by the previous case
      case _: LiteralValueAstNode => DWLiteralExpression(path)

      case ArrayNode(elems, _) =>
        val elemsExpr = elems.map(mapChild)
        DWArrayExpression(elemsExpr.toArray, path, location)

      case n =>
        val inputPaths = collectExpressionInputPath(n, inputVars, boundedInputsTranslation)
        DWComplexExpression(path, location, inputPaths.toArray)
    }
  }

  /** Binary operators that select a value out of their left-hand side. */
  private val SelectorOpIds: Seq[BinaryOpIdentifier] = Seq(
    ValueSelectorOpId,
    AttributeValueSelectorOpId,
    MultiValueSelectorOpId,
    MultiAttributeValueSelectorOpId,
    SchemaValueSelectorOpId)

  private def isSelectorOpId(opId: BinaryOpIdentifier): Boolean = SelectorOpIds.contains(opId)

  /**
    * Input path a variable reference stands for: the variable itself when it is a script input,
    * the path bound to it when it is a translated `map` item parameter, `None` otherwise.
    */
  private def resolveInputPath(variableName: String, inputVars: List[String], boundedInputsTranslation: Map[String, String]): Option[String] = {
    if (inputVars.contains(variableName)) Some(variableName)
    else boundedInputsTranslation.get(variableName)
  }

  /** Converts a Scala `Option` into the `java.util.Optional` taken by the DW expressions. */
  private def toJavaOptional(value: Option[String]): Optional[String] = value match {
    case Some(v) => Optional.of(v)
    case None    => Optional.empty()
  }

  /**
    * Collects the input paths (e.g. `payload.a`) that `inputAstNode` reads, without duplicates and
    * in the order they are first found.
    */
  private def collectExpressionInputPath(inputAstNode: AstNode, inputVars: List[String], boundedVarsTranslation: Map[String, String]): ArrayBuffer[String] = {
    val inputs: ArrayBuffer[String] = ArrayBuffer.empty

    // Records `inputPath` unless it was already collected
    def addInput(inputPath: String): Unit = {
      if (!inputs.contains(inputPath)) inputs += inputPath
    }

    AstNodeHelper.traverse(inputAstNode, {
      case BinaryOpNode(op, VariableReferenceNode(name, _), NameNode(StringNode(s, _), _, _), _) if isSelectorOpId(op) =>
        resolveInputPath(name.name, inputVars, boundedVarsTranslation).foreach(p => addInput(s"$p.$s"))
        false
      case BinaryOpNode(op, lhs, NameNode(StringNode(s, _), _, _), _) if isSelectorOpId(op) =>
        // Nested selection: extend every input path of the selected expression
        collectExpressionInputPath(lhs, inputVars, boundedVarsTranslation).foreach(p => addInput(s"$p.$s"))
        false
      case VariableReferenceNode(variable, _) =>
        resolveInputPath(variable.name, inputVars, boundedVarsTranslation).foreach(p => addInput(p))
        false
      case _ => true
    })

    inputs
  }
}