package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.AstNodeHelper
import org.mule.weave.v2.parser.ast.functions.DoBlockNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.OverloadedFunctionNode
import org.mule.weave.v2.parser.ast.header.directives.DirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.DocumentNode
import org.mule.weave.v2.parser.location.Position
import org.mule.weave.v2.parser.location.WeaveLocation

import scala.collection.mutable

/**
  * In this phase the overloaded functions will be aggregated into a composite one.
  *
  * Every group of [[FunctionDirectiveNode]]s that share the same name and whose literal is a
  * [[FunctionNode]] is replaced by a single [[FunctionDirectiveNode]] whose literal is an
  * [[OverloadedFunctionNode]] wrapping all the overloads. This is applied to the directives of the
  * root document/module and to the header directives of every [[DoBlockNode]] found in the AST.
  *
  * @tparam T the ast node type
  */
class FunctionAggregationPhase[T <: AstNode] extends FullCompileOnlyPhase[T, ParsingResult[T]] {

  /**
    * Aggregates overloaded function directives, mutating the parsed AST in place.
    *
    * @param source  the parsing result whose AST is rewritten
    * @param context the parsing context (not used by this phase)
    */
  override def run(source: ParsingResult[T], context: ParsingContext): Unit = {
    val rootNode = source.astNode

    // NOTE(review): only DocumentNode and ModuleNode roots are handled; any other root type
    // raises a MatchError here.
    rootNode match {
      case mn: DocumentNode =>
        mn.header.directives = mapDirectives(mn.header.directives)
      case mn: ModuleNode =>
        mn.elements = mapDirectives(mn.directives)
    }

    // Do blocks have their own header directives, so overloads declared there are aggregated too.
    val doBlocks: Seq[DoBlockNode] = AstNodeHelper.collectChildrenWith(rootNode, classOf[DoBlockNode])
    doBlocks.foreach(node => node.header.directives = mapDirectives(node.header.directives))
  }

  /**
    * Collapses function directives sharing a name into a single overloaded directive.
    *
    * Only directives whose literal is a [[FunctionNode]] are aggregated; every other directive is
    * kept untouched. A name declared once keeps its original directive. A name declared more than
    * once is emitted as one aggregated directive (built when its first declaration is reached) and
    * its remaining declarations are dropped, since they are already part of the aggregate.
    *
    * @param directives the directives to process
    * @return the processed directives, sorted by the index of their end position
    */
  private def mapDirectives(directives: Seq[DirectiveNode]): Seq[DirectiveNode] = {
    // groupBy keeps the original relative order inside each group, so overloads stay in
    // declaration order.
    val overloadsByName: Map[String, Seq[FunctionDirectiveNode]] =
      directives
        .collect({
          case fdn: FunctionDirectiveNode if fdn.literal.isInstanceOf[FunctionNode] => fdn
        })
        .groupBy(_.variable.name)

    // Names already emitted; a hash set keeps the membership check O(1) instead of a linear scan
    // per function directive.
    val processedFunctions: mutable.Set[String] = mutable.HashSet[String]()

    directives
      .flatMap({
        case fdn: FunctionDirectiveNode if fdn.literal.isInstanceOf[FunctionNode] =>
          val functionName = fdn.variable.name
          // `add` returns false when the name was already processed: we filter already processed functions.
          if (!processedFunctions.add(functionName)) {
            None
          } else {
            val overloads: Seq[FunctionDirectiveNode] = overloadsByName(functionName)
            if (overloads.size == 1) {
              Some(fdn)
            } else {
              Some(aggregateOverloads(fdn, overloads))
            }
          }
        case directive => Some(directive)
      })
      // This sort is needed in order to be able to find nodes as overloaded can mess all the indexes
      .sortBy(_.location().endPosition.index)
  }

  /**
    * Builds the directive that replaces a group of overloads with the same name.
    *
    * @param first     the first overload encountered; its variable is cloned and its comments copied
    * @param overloads all overloads sharing the name, in declaration order
    * @return a directive whose literal is an [[OverloadedFunctionNode]] and whose location spans from
    *         the smallest start position to the greatest end position among the overloads
    */
  private def aggregateOverloads(first: FunctionDirectiveNode, overloads: Seq[FunctionDirectiveNode]): FunctionDirectiveNode = {
    val directiveNode: FunctionDirectiveNode = FunctionDirectiveNode(variable = first.variable.cloneAst(), literal = OverloadedFunctionNode(overloads))
    first.copyCommentsTo(directiveNode)
    val startPositions: Seq[Position] = overloads.map(_.location().startPosition)
    val endPositions: Seq[Position] = overloads.map(_.location().endPosition)
    directiveNode._location = Some(WeaveLocation(startPositions.min, endPositions.max, directiveNode.children().head.location().resourceName))
    directiveNode
  }
}
