package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.structure.KeyNode
import org.mule.weave.v2.parser.ast.structure.KeyValuePairNode
import org.mule.weave.v2.parser.ast.structure.ObjectNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.phase.ParsingContext.CORE_MODULE
import org.mule.weave.v2.scope.AstNavigator
import org.mule.weave.v2.scope.ScopesNavigator

import scala.collection.mutable.ListBuffer

/**
  * Compilation phase that rewrites calls to the core-module log functions
  * (`log`, `logWith`, `logDebug`, `logInfo`, `logWarn`, `logError`) into calls to
  * `logInternal`, prepending a context object with the current module name and,
  * when the call sits inside a function directive, the enclosing function's name.
  *
  * Only calls whose reference resolves to `CORE_MODULE` are rewritten, so
  * user-defined functions with the same names are left alone. If at least one
  * call was rewritten, the scope graph is rebuilt with a fresh [[ScopeGraphPhase]]
  * so that it reflects the mutated AST.
  *
  * @tparam T the root AST node type carried by the [[ScopeGraphResult]]
  */
class LoggingContextInjectionPhase[T <: AstNode]() extends CompilationPhase[ScopeGraphResult[T], ScopeGraphResult[T]] {

  private val logFunctionNames = Set("log", "logWith", "logDebug", "logInfo", "logWarn", "logError")

  /**
    * Walks the AST of `source` and rewrites every core log call in place.
    *
    * @param source the AST plus its scope graph
    * @param ctx    the parsing context, forwarded to the scope-graph rebuild
    * @return `source` unchanged when nothing was rewritten; otherwise the result
    *         of re-running [[ScopeGraphPhase]] on the mutated AST
    */
  override def doCall(source: ScopeGraphResult[T], ctx: ParsingContext): PhaseResult[ScopeGraphResult[T]] = {
    val navigator: ScopesNavigator = source.scope
    val astNavigator = navigator.rootScope.astNavigator()
    val moduleName = source.astNode.location().resourceName.name
    // Local to this invocation on purpose: as an instance field it was never reset,
    // so after one rewrite every later call on the same phase instance rebuilt the
    // scope graph, and the flag would be shared if the instance is used concurrently.
    var mutated = false

    AstNodeHelper.traverseChildren(source.astNode, {
      case node: FunctionCallNode =>
        node.function match {
          case functionVrn: VariableReferenceNode =>
            val maybeCoreLogFunction = for {
              ref <- navigator.resolveVariable(functionVrn.variable)
              ms <- ref.moduleSource
              if ms.name == CORE_MODULE && logFunctionNames.contains(ref.referencedNode.name)
            } yield ref.referencedNode.name

            maybeCoreLogFunction.foreach { logFunctionName =>
              rewriteLogCall(node, functionVrn, logFunctionName, moduleName, astNavigator)
              mutated = true
            }
          case _ =>
        }
        true
      case _ => true
    })

    if (mutated) {
      // The AST was mutated, so the previously computed scope graph is stale.
      val scopeGraphPhase = new ScopeGraphPhase[T]()
      scopeGraphPhase.call(source, ctx)
    } else {
      SuccessResult(source, ctx)
    }
  }

  /**
    * Rewrites one core log call into a call to `logInternal`, mutating `node` in place.
    *
    * The call target is replaced by a reference to `logInternal` that keeps the original
    * source locations and code annotations. The arguments are then normalized:
    *  - with one argument, an empty prefix `""` is prepended;
    *  - with two arguments, the default level for `logFunctionName` is prepended, if it has one;
    *  - finally the context object (`module` and optional `function`) is prepended.
    *
    * @param node            the call node to rewrite
    * @param functionVrn     the original reference to the log function (source of locations/annotations)
    * @param logFunctionName the resolved core log function name
    * @param moduleName      the name of the module being compiled, stored under the `module` key
    * @param astNavigator    navigator used to find an enclosing function directive
    */
  private def rewriteLogCall(
    node: FunctionCallNode,
    functionVrn: VariableReferenceNode,
    logFunctionName: String,
    moduleName: String,
    astNavigator: AstNavigator): Unit = {
    val elems = ListBuffer(KeyValuePairNode(KeyNode(StringNode("module")), StringNode(moduleName)))
    parentFunctionName(node, astNavigator).foreach { name =>
      elems += KeyValuePairNode(KeyNode(StringNode("function")), StringNode(name))
    }
    val logContext = ObjectNode(elems)

    val logInternalNi = NameIdentifier.CORE_MODULE.::("logInternal")
    logInternalNi._location = functionVrn.variable._location
    val logInternalVrn = VariableReferenceNode(logInternalNi, functionVrn.codeAnnotations)
    logInternalVrn._location = functionVrn._location
    node.update(node.function, logInternalVrn)

    val args = node.args
    // add default prefix if needed
    if (args.args.length == 1) {
      args.args = StringNode("") +: args.args
    }
    // add logLevel if needed; names without a default level (e.g. logWith) are left as-is
    // instead of failing with a MatchError
    if (args.args.length == 2) {
      defaultLogLevel(logFunctionName).foreach { level =>
        args.args = StringNode(level) +: args.args
      }
    }
    args.args = logContext +: args.args
  }

  /**
    * Default log level injected for a core log function called without an explicit level.
    *
    * @param logFunctionName the core log function name
    * @return the level literal, or `None` when the function has no default level
    *         (presumably `logWith` expects the level as an explicit argument — TODO confirm)
    */
  private def defaultLogLevel(logFunctionName: String): Option[String] = logFunctionName match {
    case "log" | "logInfo" => Some("Info")
    case "logDebug"        => Some("Debug")
    case "logWarn"         => Some("Warn")
    case "logError"        => Some("Error")
    case _                 => None
  }

  /**
    * Finds the name of the nearest enclosing function directive of `node`.
    *
    * @param node         the node to start from (excluded from the search)
    * @param astNavigator navigator providing parent links
    * @return the enclosing function's name, or `None` if no ancestor is a function directive
    */
  @scala.annotation.tailrec
  private def parentFunctionName(node: AstNode, astNavigator: AstNavigator): Option[String] = {
    astNavigator.parentOf(node) match {
      case Some(fd: FunctionDirectiveNode) => Some(fd.variable.name)
      case Some(parent)                    => parentFunctionName(parent, astNavigator)
      case None                            => None
    }
  }

}
