package org.mule.weave.v2.ts

import org.mule.weave.v2.parser.{ Message, MessageCollector }
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.header.directives.VersionDirective
import org.mule.weave.v2.parser.location.UnknownLocation
import org.mule.weave.v2.parser.location.WeaveLocation
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.scope.ScopesNavigator
import org.mule.weave.v2.utils.IdentityHashMap

import scala.collection.Seq
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * State shared while the types of a [[TypeGraph]] are being resolved.
  *
  * The context keeps a stack of [[WeaveTypeResolutionContextValues]]. A new entry is pushed by
  * `newExecutorWithContext` / `newReverseExecutorWithContext` and removed by `endContext`. Every `current*`
  * accessor and every message reporting method works on the entry at the top of that stack, so calling them
  * while no context is active fails because the stack is empty.
  *
  * @param rootGraph graph used to store and look up the function sub graphs
  */
class WeaveTypeResolutionContext(val rootGraph: TypeGraph) {

  private val _executors = mutable.Stack[WeaveTypeResolutionContextValues]()

  /** Entry at the top of the context stack. */
  private def activeContext: WeaveTypeResolutionContextValues = _executors.top

  /**
    * Pushes a new stack entry built from the given values and a fresh [[NodeBaseMessageCollector]].
    *
    * @return the same propagator that was received, to let callers keep its precise type
    */
  private def pushContext[P <: BaseWeaveTypePropagator](propagator: P, scope: ScopesNavigator, dataGraph: TypeGraph, parsingContext: ParsingContext): P = {
    _executors.push(WeaveTypeResolutionContextValues(dataGraph, propagator, scope, parsingContext, new NodeBaseMessageCollector()))
    propagator
  }

  /** Discards the errors and warnings collected so far for the given node in the current context. */
  def clearMessageFor(node: TypeNode): Unit =
    currentNodeMessageCollector.clearMessagesFor(node)

  /** Records an error for the given node in the message collector of the current context. */
  def error(message: Message, node: TypeNode): Unit =
    currentNodeMessageCollector.error(message, node)

  /** Records an error for the given node, asking the current message collector to report it at `location`. */
  def error(message: Message, node: TypeNode, location: WeaveLocation): Unit =
    currentNodeMessageCollector.error(message, node, location)

  /** Records a warning for the given node in the message collector of the current context. */
  def warning(message: Message, node: TypeNode): Unit =
    currentNodeMessageCollector.warning(message, node)

  /** Records a warning for the given node, asking the current message collector to report it at `location`. */
  def warning(message: Message, node: TypeNode, location: WeaveLocation): Unit =
    currentNodeMessageCollector.warning(message, node, location)

  /**
    * Records a warning for a node using an arbitrary location. It is used, for example, to report
    * function warnings at the location of the function name identifier.
    */
  def unsafeLocationWarning(message: Message, node: TypeNode, location: WeaveLocation): Unit =
    currentNodeMessageCollector.unsafeLocationWarning(message, node, location)

  /**
    * Starts a new context driven by a forward [[WeaveTypePropagator]].
    *
    * @return the propagator created for the new context
    */
  def newExecutorWithContext(scope: ScopesNavigator, dataGraph: TypeGraph, parsingContext: ParsingContext): WeaveTypePropagator =
    pushContext(new WeaveTypePropagator(this), scope, dataGraph, parsingContext)

  /**
    * Starts a new context driven by a [[ReverseWeaveTypePropagator]].
    *
    * @return the propagator created for the new context
    */
  def newReverseExecutorWithContext(scope: ScopesNavigator, dataGraph: TypeGraph, parsingContext: ParsingContext): ReverseWeaveTypePropagator =
    pushContext(new ReverseWeaveTypePropagator(this), scope, dataGraph, parsingContext)

  /**
    * Finishes the current context: the errors and then the warnings collected per node are forwarded to the
    * message collector of the current parsing context, and finally the context is removed from the stack.
    */
  def endContext(): Unit = {
    val collector = currentNodeMessageCollector
    reportErrorMessages(collector.nodeErrors)
    reportWarningMessages(collector.nodeWarnings)
    _executors.pop()
  }

  /** Forwards every (location, message) pair as an error to the current parsing context. */
  private def reportErrorMessages(messages: Seq[(WeaveLocation, Message)]): Unit = {
    messages.foreach {
      case (location, message) => currentParsingContext.messageCollector.error(message, location)
    }
  }

  /** Forwards every (location, message) pair as a warning to the current parsing context. */
  private def reportWarningMessages(messages: Seq[(WeaveLocation, Message)]): Unit = {
    messages.foreach {
      case (location, message) => currentParsingContext.messageCollector.warning(message, location)
    }
  }

  /**
    * Returns the propagator of the current context, i.e. the one being used to execute the graph.
    */
  def currentExecutor: BaseWeaveTypePropagator = activeContext.bwtp

  /** Returns the per node message collector of the current context. */
  def currentNodeMessageCollector: NodeBaseMessageCollector = activeContext.nbmc

  /** Returns the syntax version declared by the graph of the current context. */
  def currentSyntaxVersion: VersionDirective = currentGraph.syntaxVersion

  /**
    * Returns the graph that is currently under execution.
    */
  def currentGraph: TypeGraph = activeContext.tg

  /**
    * Returns the scopes navigator of the current context.
    */
  def currentScopeNavigator: ScopesNavigator = activeContext.sn

  /** Returns the parsing context of the current context. */
  def currentParsingContext: ParsingContext = activeContext.pc

  /**
    * Returns the sub graphs that the root graph holds for the specified function. Each entry carries
    * the parameter types, the graph and the message collector it was registered with.
    *
    * @return The sub graphs if defined
    */
  def getFunctionSubGraphs(functionNode: FunctionNode): Option[ArrayBuffer[(Seq[WeaveType], TypeGraph, MessageCollector)]] =
    rootGraph.getFunctionSubGraphs(functionNode)

  /**
    * Returns the sub graph that the root graph holds for the specified function with the given parameter types.
    *
    * @return The graph and its message collector if defined
    */
  def getFunctionSubGraph(functionNode: FunctionNode, parameterTypes: Seq[WeaveType]): Option[(TypeGraph, MessageCollector)] =
    rootGraph.getFunctionSubGraph(functionNode, parameterTypes)

  /**
    * Removes from the root graph the sub graph of the given function and parameter types.
    */
  def removeFunctionSubGraph(functionNode: FunctionNode, parameterTypes: Seq[WeaveType]): Unit =
    rootGraph.removeFunctionSubGraph(functionNode, parameterTypes)

  /**
    * Adds to the root graph a new sub graph for the specified function with the given parameter types.
    */
  def addFunctionSubGraph(functionNode: FunctionNode, parameterTypes: Seq[WeaveType], graph: TypeGraph, messageCollector: MessageCollector): Unit =
    rootGraph.addFunctionSubGraph(functionNode, parameterTypes, graph, messageCollector)
}

/**
  * Values that make up one entry of the context stack kept by [[WeaveTypeResolutionContext]].
  *
  * @param tg   type graph being executed in this context
  * @param bwtp propagator that executes the graph
  * @param sn   scopes navigator associated with this context
  * @param pc   parsing context that receives the messages when the context ends
  * @param nbmc collector of the per node errors and warnings reported while the context is active
  */
case class WeaveTypeResolutionContextValues(tg: TypeGraph, bwtp: BaseWeaveTypePropagator, sn: ScopesNavigator, pc: ParsingContext, nbmc: NodeBaseMessageCollector)

/**
  * Collects errors and warnings keyed by the AST node of the [[TypeNode]] they were reported for, which
  * allows dropping all the messages of a node at once through `clearMessagesFor`.
  */
class NodeBaseMessageCollector {
  private val _errors = IdentityHashMap[AstNode, ArrayBuffer[(WeaveLocation, Message)]]()
  private val _warnings = IdentityHashMap[AstNode, ArrayBuffer[(WeaveLocation, Message)]]()

  /** Error buffer of the node, created on first use. */
  private def errorsOf(node: TypeNode): ArrayBuffer[(WeaveLocation, Message)] =
    _errors.getOrElseUpdate(node.astNode, ArrayBuffer())

  /** Warning buffer of the node, created on first use. */
  private def warningsOf(node: TypeNode): ArrayBuffer[(WeaveLocation, Message)] =
    _warnings.getOrElseUpdate(node.astNode, ArrayBuffer())

  /** All the collected errors, sorted by the start index of their location. */
  def nodeErrors: Seq[(WeaveLocation, Message)] = {
    val reported = _errors.toSeq.flatMap(_._2)
    reported.sortBy { case (location, _) => location.startPosition.index }
  }

  /** All the collected warnings, sorted by the start index of their location. */
  def nodeWarnings: Seq[(WeaveLocation, Message)] = {
    val reported = _warnings.toSeq.flatMap(_._2)
    reported.sortBy { case (location, _) => location.startPosition.index }
  }

  /** Records an error for the node using the location of the node itself. */
  def error(message: Message, node: TypeNode): Unit =
    error(message, node, node.location())

  /**
    * Records an error for the node. The requested location is only honoured when it is known and
    * contained in the location of the node, otherwise the location of the node is used.
    */
  def error(message: Message, node: TypeNode, location: WeaveLocation): Unit = {
    val entry = (getCorrectLocation(node, location), message)
    errorsOf(node) += entry
  }

  /**
    * Picks the location to report a message at: the requested one when it is not [[UnknownLocation]] and
    * the location of the node contains it, the location of the node in any other case.
    */
  private def getCorrectLocation(node: TypeNode, location: WeaveLocation): WeaveLocation = {
    val usable = !location.eq(UnknownLocation) && node.location().contains(location)
    if (usable) location else node.location()
  }

  /** Records a warning for the node using the location of the node itself. */
  def warning(message: Message, node: TypeNode): Unit =
    warning(message, node, node.location())

  /**
    * Records a warning for the node. The requested location is only honoured when it is known and
    * contained in the location of the node, otherwise the location of the node is used.
    */
  def warning(message: Message, node: TypeNode, location: WeaveLocation): Unit = {
    val entry = (getCorrectLocation(node, location), message)
    warningsOf(node) += entry
  }

  /**
    * Records a warning for a node using an arbitrary location, without validating it against the
    * location of the node. It is used, for example, to report function warnings at the location of
    * the function name identifier.
    */
  def unsafeLocationWarning(message: Message, node: TypeNode, location: WeaveLocation): Unit = {
    warningsOf(node) += ((location, message))
  }

  /** Drops every error and warning recorded for the AST node of the given node. */
  def clearMessagesFor(node: TypeNode): Unit = {
    val key = node.astNode
    _errors.remove(key)
    _warnings.remove(key)
  }
}

/**
  * Contract shared by the propagators that execute a type graph. A [[WeaveTypeResolutionContext]] exposes
  * the propagator of the active context through this type.
  */
trait BaseWeaveTypePropagator {

  /**
    * Schedules the given node to be processed by this propagator.
    */
  def scheduleNode(node: TypeNode): Unit

  /**
    * Runs the propagation. The implementations in this file schedule the nodes of the current graph,
    * process everything that is scheduled and then end the current context.
    */
  def run(): Unit
}

/**
  * Propagator that executes the current graph in the forward direction, calling `resolve` on every
  * scheduled node until nothing is left to execute.
  *
  * @param ctx resolution context this propagator works for
  */
class WeaveTypePropagator(ctx: WeaveTypeResolutionContext) extends BaseWeaveTypePropagator {
  val executionStack: mutable.ListBuffer[TypeNode] = mutable.ListBuffer()

  /**
    * Places the node at the end of the pending list. A node that was already pending (same instance)
    * is moved to the end instead of being duplicated.
    */
  override def scheduleNode(node: TypeNode): Unit = {
    val alreadyQueuedAt = executionStack.indexWhere(_ eq node)
    if (alreadyQueuedAt >= 0) {
      executionStack.remove(alreadyQueuedAt)
    }
    executionStack += node
  }

  /** Schedules the initial nodes, resolves everything that gets scheduled and ends the current context. */
  override def run(): Unit = {
    scheduleNodes()
    resolveScheduledNodes()
    ctx.endContext()
  }

  /** Takes nodes from the front of the pending list and resolves them until the list is empty. */
  private def resolveScheduledNodes(): Unit = {
    while (executionStack.nonEmpty) {
      executionStack.remove(0).resolve(ctx)
    }
  }

  /**
    * Seeds the pending list from the current graph: nodes with all their dependencies resolved are
    * scheduled directly, for the rest the dependencies behind their cross graph incoming edges are scheduled.
    */
  private def scheduleNodes(): Unit = {
    for (node <- ctx.currentGraph.nodes) {
      val incoming: Seq[Edge] = node.incomingEdges()
      if (node.allDependenciesResolved()) {
        scheduleNode(node)
      } else {
        // Every cross graph edge is walked with its own, initially empty, visiting stack
        incoming
          .filter(_.crossGraphEdge())
          .foreach((edge) => scheduleDependencies(edge, mutable.Stack()))
      }
    }
  }

  /**
    * Schedules what is needed for the given edge to get its type.
    *
    * @param edge  edge whose incoming type is required
    * @param stack sources being visited in the current walk, used to stop on cycles
    */
  private def scheduleDependencies(edge: Edge, stack: mutable.Stack[TypeNode]): Unit = {
    if (edge.incomingTypeDefined()) {
      // The type already arrived through this edge, so its target can be executed
      ctx.currentExecutor.scheduleNode(edge.target)
    } else {
      val source = edge.source
      source.resultType() match {
        case Some(sourceType) =>
          // The source is already typed, just push its type through the edge
          edge.propagateType(sourceType, ctx)
        case None if source.incomingEdges().isEmpty =>
          // The source depends on nothing, it can be executed right away
          scheduleNode(source)
        case None if !stack.contains(source) =>
          stack.push(source)
          source.incomingEdges().foreach { (incomingEdge) =>
            //Avoid self reference loop
            if (incomingEdge.source != edge.target) {
              scheduleDependencies(incomingEdge, stack)
            }
          }
          stack.pop()
        case None =>
          // The source is already being visited in this walk, nothing else to do
          ()
      }
    }
  }

}

/**
  * Propagator that executes the current graph calling `reverseResolve` on every scheduled node until
  * nothing is left to execute.
  *
  * @param ctx resolution context this propagator works for
  */
class ReverseWeaveTypePropagator(ctx: WeaveTypeResolutionContext) extends BaseWeaveTypePropagator {
  val executionStack: mutable.ListBuffer[TypeNode] = mutable.ListBuffer()

  /**
    * Places the node at the end of the pending list. A node that was already pending (same instance)
    * is moved to the end instead of being duplicated.
    */
  override def scheduleNode(node: TypeNode): Unit = {
    val queuedAt = executionStack.indexWhere(_ eq node)
    if (queuedAt >= 0) {
      executionStack.remove(queuedAt)
    }
    executionStack += node
  }

  /** Schedules all the nodes of the current graph, reverse resolves them and ends the current context. */
  def run(): Unit = {
    scheduleNodes()
    start()
    ctx.endContext()
  }

  /** Schedules every node of the current graph, in the order the graph exposes them. */
  private def scheduleNodes(): Unit = {
    for (node <- ctx.currentGraph.nodes) {
      scheduleNode(node)
    }
  }

  /** Takes nodes from the front of the pending list and reverse resolves them until the list is empty. */
  private def start(): Unit = {
    while (executionStack.nonEmpty) {
      executionStack.remove(0).reverseResolve(ctx)
    }
  }
}
