package org.mule.weave.v2.interpreted

import org.mule.weave.v2.interpreted.node.ModuleNode
import org.mule.weave.v2.interpreted.transform.EngineGrammarTransformation
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.phase.ModuleParsingPhasesManager
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.utils.LockFactory

import java.util.concurrent.ConcurrentHashMap

/**
  * Handles the transformation of a module into an executable [[ModuleNode]].
  * Starting from a fully qualified module name, it looks the module up and runs the
  * corresponding compilation phases.
  */
trait RuntimeModuleNodeCompiler {

  /**
    * Compiles the module with the specified NameIdentifier, using this compiler as the root
    * of the compiler chain.
    *
    * @param moduleName     The name of the module to be compiled
    * @param parsingContext The parsing context
    * @return The runtime module node if it was able to successfully compile it, `None` otherwise.
    */
  def compile(moduleName: NameIdentifier, parsingContext: ParsingContext): Option[ModuleNode]

  /**
    * Compiles the module with the specified NameIdentifier.
    *
    * @param moduleName     The name of the module to be compiled
    * @param parsingContext The parsing context
    * @param rootCompiler   The root compiler in the chain of compilers. The implementations in this
    *                       file forward it unchanged to parent compilers and to the grammar
    *                       transformation (presumably so that modules referenced from the compiled
    *                       one are resolved from the root of the chain — confirm).
    * @return The runtime module node if it was able to successfully compile it, `None` otherwise.
    */
  def compile(moduleName: NameIdentifier, parsingContext: ParsingContext, rootCompiler: RuntimeModuleNodeCompiler): Option[ModuleNode]

  /**
    * Invalidates the cached entry, if any, for the given NameIdentifier so that the next
    * compilation rebuilds it.
    *
    * @param nameIdentifier The name identifier of the module to invalidate.
    */
  def invalidate(nameIdentifier: NameIdentifier): Unit
}

/**
  * Default [[RuntimeModuleNodeCompiler]]: resolves modules through the `moduleParserManager`
  * of the [[ParsingContext]] it receives and caches every successfully compiled [[ModuleNode]]
  * by its [[NameIdentifier]]. Failed lookups (`None`) are not cached.
  */
class DefaultRuntimeModuleNodeCompiler extends RuntimeModuleNodeCompiler {

  /** Successfully compiled modules, keyed by module name. */
  val modules = new ConcurrentHashMap[NameIdentifier, ModuleNode]()

  // Guards compilation and invalidation of entries in `modules`.
  private val modulesLock = LockFactory.createLock()

  override def compile(moduleName: NameIdentifier, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): Option[ModuleNode] = {
    // Fast path: lock-free read of an already compiled module.
    val cached = modules.get(moduleName)
    if (cached != null) {
      Some(cached)
    } else {
      modulesLock.lock({
        // Re-check under the lock: another thread may have compiled it while we were waiting.
        val cachedUnderLock = modules.get(moduleName)
        if (cachedUnderLock != null) {
          Some(cachedUnderLock)
        } else {
          val maybeResult = parsingContext.moduleParserManager.scopeCheckModule(moduleName, parsingContext)
          maybeResult.map(result => {
            val phaseResult = result.getResult()
            val moduleNode = EngineGrammarTransformation(parsingContext.child(moduleName), phaseResult.scope, moduleNodeLoader).transformModule(phaseResult.astNode)
            modules.put(moduleName, moduleNode)
            moduleNode
          })
        }
      })
    }
  }

  override def compile(moduleName: NameIdentifier, parsingContext: ParsingContext): Option[ModuleNode] = {
    compile(moduleName, parsingContext, this)
  }

  override def invalidate(nameIdentifier: NameIdentifier): Unit = {
    // Remove under the same lock used for compilation so an in-flight compilation of this
    // module cannot re-insert a node built from stale source right after the removal.
    modulesLock.lock({
      modules.remove(nameIdentifier)
      ()
    })
  }
}

/**
  * Compiler that compiles the modules found by its own [[ModuleParsingPhasesManager]] and
  * delegates to an optional parent compiler for everything else.
  *
  * When `parentLast` is `false` (the default) the parent is consulted first, and local
  * compilation is attempted only if the parent yields nothing. When `parentLast` is `true`,
  * local compilation is attempted first and the parent is the fallback.
  *
  * @param parserManager Parsing manager used to locate and scope-check locally compiled modules.
  * @param parent        Compiler to delegate to, if any.
  * @param parentLast    Whether the parent is consulted after (`true`) or before (`false`) local compilation.
  */
class CustomRuntimeModuleNodeCompiler(parserManager: ModuleParsingPhasesManager, parent: Option[RuntimeModuleNodeCompiler], parentLast: Boolean = false) extends RuntimeModuleNodeCompiler {

  /** Modules successfully compiled locally (not by the parent), keyed by module name. */
  val modules = new ConcurrentHashMap[NameIdentifier, ModuleNode]()

  // Guards local compilation and invalidation of entries in `modules`.
  private val modulesLock = LockFactory.createLock()

  override def compile(moduleName: NameIdentifier, parsingContext: ParsingContext): Option[ModuleNode] = {
    compile(moduleName, parsingContext, this)
  }

  override def compile(moduleName: NameIdentifier, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): Option[ModuleNode] = {
    // Defined as methods so that each source is only evaluated when `orElse` actually needs it.
    def fromParent: Option[ModuleNode] = parent.flatMap(_.compile(moduleName, parsingContext, moduleNodeLoader))
    def fromLocal: Option[ModuleNode] = localCompile(moduleName, parsingContext, moduleNodeLoader)

    if (parentLast) fromLocal.orElse(fromParent) else fromParent.orElse(fromLocal)
  }

  /** Compiles `moduleName` with this compiler's own parser manager, caching successful results. */
  private def localCompile(moduleName: NameIdentifier, parsingContext: ParsingContext, moduleNodeLoader: RuntimeModuleNodeCompiler): Option[ModuleNode] = {
    // Fast path: lock-free read of an already compiled module.
    val cached = modules.get(moduleName)
    if (cached != null) {
      Some(cached)
    } else {
      modulesLock.lock({
        // Re-check under the lock: another thread may have compiled it while we were waiting.
        val cachedUnderLock = modules.get(moduleName)
        if (cachedUnderLock != null) {
          Some(cachedUnderLock)
        } else {
          val maybeResult = parserManager.scopeCheckModule(moduleName, parsingContext)
          maybeResult.map(result => {
            val phaseResult = result.getResult()
            // NOTE(review): unlike DefaultRuntimeModuleNodeCompiler, this passes the caller's
            // parsingContext rather than parsingContext.child(moduleName) — confirm intentional.
            val moduleNode = EngineGrammarTransformation(parsingContext, phaseResult.scope, moduleNodeLoader).transformModule(phaseResult.astNode)
            modules.put(moduleName, moduleNode)
            moduleNode
          })
        }
      })
    }
  }

  override def invalidate(nameIdentifier: NameIdentifier): Unit = {
    // Remove under the compilation lock so an in-flight local compilation cannot re-insert a
    // stale node right after the removal.
    modulesLock.lock({
      modules.remove(nameIdentifier)
      ()
    })
    // Delegate after releasing the local lock so the parent's lock is never acquired while
    // this one is held.
    parent.foreach(_.invalidate(nameIdentifier))
  }
}

/**
  * Factory methods for [[RuntimeModuleNodeCompiler]] instances.
  */
object RuntimeModuleNodeCompiler {

  /** Creates a compiler that resolves modules through each call's [[ParsingContext]]. */
  def apply(): RuntimeModuleNodeCompiler = new DefaultRuntimeModuleNodeCompiler()

  /**
    * Creates a parent-first compiler that compiles modules found by `parserManager`.
    *
    * @param parserManager Parsing manager used for locally compiled modules.
    * @param parent        Optional compiler consulted before local compilation.
    */
  def apply(parserManager: ModuleParsingPhasesManager, parent: Option[RuntimeModuleNodeCompiler] = None): RuntimeModuleNodeCompiler =
    new CustomRuntimeModuleNodeCompiler(parserManager, parent)

  /** Creates a compiler that consults `parent` before compiling locally with `parserManager`. */
  def parentFirst(parserManager: ModuleParsingPhasesManager, parent: RuntimeModuleNodeCompiler): CustomRuntimeModuleNodeCompiler =
    chain(parserManager, parent, parentLast = false)

  /** Creates a compiler that compiles locally with `parserManager` and only falls back to `parent`. */
  def parentLast(parserManager: ModuleParsingPhasesManager, parent: RuntimeModuleNodeCompiler): CustomRuntimeModuleNodeCompiler =
    chain(parserManager, parent, parentLast = true)

  /**
    * Creates a compiler chained to `parent`.
    *
    * @param parentLast `true` to try local compilation before the parent, `false` for the opposite.
    */
  def chain(parserManager: ModuleParsingPhasesManager, parent: RuntimeModuleNodeCompiler, parentLast: Boolean): CustomRuntimeModuleNodeCompiler =
    new CustomRuntimeModuleNodeCompiler(parserManager, Some(parent), parentLast)
}
