package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.parser.ModuleParser
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.utils.LockFactory

import java.util.concurrent.ConcurrentHashMap

/**
  * Drives a module through the compilation phases (parsing, scope checking, type checking),
  * letting implementations cache each phase's outcome for performance.
  */
trait ModuleParsingPhasesManager {

  /**
    * Discards every cached phase result held by this manager.
    */
  def invalidateAll(): Unit

  /**
    * Discards the cached phase results of a single module.
    *
    * @param nameIdentifier The name identifier of the module whose results must be dropped
    */
  def invalidateModule(nameIdentifier: NameIdentifier): Unit

  /**
    * Parses the module identified by the given NameIdentifier.
    *
    * @param nameIdentifier The name identifier of the module
    * @param parentContext  The parsing context the module is parsed under
    * @return The parsing Phase Result, or None when this manager produced no result
    */
  def parseModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]]

  /**
    * Runs the scope phase on the module identified by the given NameIdentifier.
    *
    * @param nameIdentifier The name identifier of the module
    * @param parentContext  The parsing context the module is processed under
    * @return The scope Phase Result, or None when this manager produced no result
    */
  def scopeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ScopeGraphResult[ModuleNode]]]

  /**
    * Runs the type phase on the module identified by the given NameIdentifier.
    *
    * @param nameIdentifier The name identifier of the module
    * @param parentContext  The parsing context the module is processed under
    * @return The type-checking Phase Result, or None when this manager produced no result
    */
  def typeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[TypeCheckingResult[ModuleNode]]]

  /**
    * Tells whether this manager is able to resolve the module with the given NameIdentifier.
    *
    * @param nameIdentifier The module's name identifier
    * @return true when the module can be resolved by this ModuleParsingPhasesManager
    */
  def canResolveModule(nameIdentifier: NameIdentifier): Boolean

  /**
    * Returns the manager responsible for resolving the given module.
    *
    * By default this is the manager itself when [[canResolveModule]] accepts the identifier and `None`
    * otherwise; implementations that combine several managers may override it.
    *
    * @param nameIdentifier The module's name identifier
    * @return the ModuleParsingPhasesManager that resolves the name identifier, if any
    */
  def moduleParsingPhasesManagerForNameIdentifier(nameIdentifier: NameIdentifier): Option[ModuleParsingPhasesManager] =
    Some(this).filter(_.canResolveModule(nameIdentifier))

}

object ModuleParsingPhasesManager {

  /**
    * Creates the default, caching manager backed by the given loader.
    *
    * @param moduleLoader the loader used to find and load module definitions
    * @return a new [[DefaultModuleParsingPhasesManager]]
    */
  def apply(moduleLoader: ModuleLoaderManager): ModuleParsingPhasesManager = {
    val manager = new DefaultModuleParsingPhasesManager(moduleLoader)
    manager
  }
}

/**
  * Decorator that forwards every operation to `delegate` and feeds `dependencyGraph` with each module
  * returned by [[parseModule]].
  *
  * @param delegate        the manager that actually runs the phases
  * @param dependencyGraph graph loaded from every parse result returned by the delegate (even results
  *                        with errors) and cleared by [[invalidateAll]]
  */
class WithDependencyGraphParsingPhasesManager(val delegate: ModuleParsingPhasesManager, dependencyGraph: DependencyGraph) extends ModuleParsingPhasesManager {

  override def invalidateModule(nameIdentifier: NameIdentifier): Unit = {
    // NOTE(review): dependencyGraph is not updated here; only invalidateAll clears it — confirm intended.
    delegate.invalidateModule(nameIdentifier)
  }

  /**
    * Parses through the delegate and, when it returns a result, records that module's dependencies.
    */
  override def parseModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    val result = delegate.parseModule(nameIdentifier, parentContext)
    result.foreach(parsed => dependencyGraph.loadDependenciesFrom(nameIdentifier, parsed))
    result
  }

  // NOTE(review): the scope/type phases go straight to the delegate, so dependencyGraph is only fed when
  // parseModule is called on this wrapper — confirm callers always parse through it first.
  override def scopeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ScopeGraphResult[ModuleNode]]] = {
    delegate.scopeCheckModule(nameIdentifier, parentContext)
  }

  override def typeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[TypeCheckingResult[ModuleNode]]] = {
    delegate.typeCheckModule(nameIdentifier, parentContext)
  }

  /**
    * Invalidates the delegate's caches and clears the dependency graph.
    */
  override def invalidateAll(): Unit = {
    delegate.invalidateAll()
    dependencyGraph.invalidateAll()
  }

  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean = delegate.canResolveModule(nameIdentifier)
}

object WithDependencyGraphParsingPhasesManager {

  /**
    * Wraps `delegate` so that parsed modules feed `dependencyGraph`.
    *
    * @param delegate        the manager that runs the phases
    * @param dependencyGraph the graph to populate with parsed modules' dependencies
    * @return the decorating manager
    */
  def apply(delegate: ModuleParsingPhasesManager, dependencyGraph: DependencyGraph): WithDependencyGraphParsingPhasesManager =
    new WithDependencyGraphParsingPhasesManager(delegate = delegate, dependencyGraph = dependencyGraph)
}

/**
  * Combines a `parent` and a `child` manager, consulting `parent` first for every operation.
  *
  * @param child  manager queried only when `parent` produces no result
  * @param parent manager that takes precedence
  */
class HierarchicalModuleParsingPhasesManager(val child: ModuleParsingPhasesManager, val parent: ModuleParsingPhasesManager) extends ModuleParsingPhasesManager {

  // Listing `parent` before `child` is what gives the parent precedence.
  private val parentFirst: ModuleParsingPhasesManager = CompositeModuleParsingPhasesManager(parent, child)

  override def invalidateAll(): Unit = {
    parentFirst.invalidateAll()
  }

  override def invalidateModule(nameIdentifier: NameIdentifier): Unit = {
    parentFirst.invalidateModule(nameIdentifier)
  }

  override def parseModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    parentFirst.parseModule(nameIdentifier, parentContext)
  }

  override def scopeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ScopeGraphResult[ModuleNode]]] = {
    parentFirst.scopeCheckModule(nameIdentifier, parentContext)
  }

  override def typeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[TypeCheckingResult[ModuleNode]]] = {
    parentFirst.typeCheckModule(nameIdentifier, parentContext)
  }

  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean = {
    parentFirst.canResolveModule(nameIdentifier)
  }

  /**
    * Prefers whatever manager the parent hierarchy picks. Otherwise it returns this hierarchical manager
    * (not `child` itself) when the child can resolve the module.
    */
  override def moduleParsingPhasesManagerForNameIdentifier(nameIdentifier: NameIdentifier): Option[ModuleParsingPhasesManager] = {
    val resolvedByParent = parent.moduleParsingPhasesManagerForNameIdentifier(nameIdentifier)
    if (resolvedByParent.nonEmpty) {
      resolvedByParent
    } else if (child.canResolveModule(nameIdentifier)) {
      Some(this)
    } else {
      None
    }
  }
}

/**
  * Default [[ModuleParsingPhasesManager]]. It loads modules through `moduleDefinitionLoader` and caches
  * the outcome of each phase (parsing, scope checking, type checking) per NameIdentifier.
  *
  * Every phase follows the same protocol:
  *  - read its cache without locking;
  *  - on a miss, take the lock keyed by the module's NameIdentifier;
  *  - read the cache again, and only then compute and store the result.
  *
  * Cached values are Options, so a module that could not be loaded is cached as `None`.
  *
  * @param moduleDefinitionLoader the loader used to find and load module definitions
  */
class DefaultModuleParsingPhasesManager(val moduleDefinitionLoader: ModuleLoaderManager) extends ModuleParsingPhasesManager {

  private val parsingCache = new ConcurrentHashMap[NameIdentifier, Option[PhaseResult[ParsingResult[ModuleNode]]]]()
  private val scopedGraphCache = new ConcurrentHashMap[NameIdentifier, Option[PhaseResult[ScopeGraphResult[ModuleNode]]]]()
  private val typeCheckingCache = new ConcurrentHashMap[NameIdentifier, Option[PhaseResult[TypeCheckingResult[ModuleNode]]]]()

  /**
    * Identifier used to look the module up through the binary loader before the regular one.
    */
  private def binaryModuleNameIdentifier(identifier: NameIdentifier) = {
    NameIdentifier(identifier.name, Some(ModuleLoader.BINARY_LOADER_NAME))
  }

  /**
    * A single lock object serves every module. It is always acquired with the module's NameIdentifier as
    * key, so one instance is enough (assumes LockFactory locks per key — confirm).
    */
  private val lock = LockFactory.createLock()

  /**
    * Slow path shared by all phases. Under the lock keyed by `nameIdentifier`, it returns the entry already
    * in `cache`; if there is none, it evaluates `compute`, stores the result and returns it.
    *
    * Re-reading the cache under the lock lets a caller that waited on a concurrent computation reuse that
    * result instead of rerunning the phase and overwriting the cached entry.
    *
    * @param cache          the phase cache to read and populate
    * @param nameIdentifier the module whose entry is requested
    * @param compute        evaluated under the lock, only when no entry is cached
    * @return the cached or freshly computed entry
    */
  private def computeIfAbsentUnderLock[T <: AnyRef](cache: ConcurrentHashMap[NameIdentifier, T], nameIdentifier: NameIdentifier)(compute: => T): T = {
    lock.lock(
      nameIdentifier, {
        val cached = cache.get(nameIdentifier)
        if (cached != null) {
          cached
        } else {
          val computed = compute
          cache.put(nameIdentifier, computed)
          computed
        }
      })
  }

  /**
    * Loads the module, trying the binary loader before the regular one. When loading produced no errors,
    * it runs the canonical phases on the result; results with errors are returned unchanged.
    */
  private def loadCanonicalModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    val moduleContext: ParsingContext = parentContext.child(nameIdentifier)
    val maybeModule: Option[PhaseResult[ParsingResult[ModuleNode]]] =
      moduleDefinitionLoader
        .loadModule(binaryModuleNameIdentifier(nameIdentifier), moduleContext)
        .orElse(moduleDefinitionLoader.loadModule(nameIdentifier, moduleContext))
    maybeModule.map(moduleResult => {
      if (moduleResult.hasErrors()) {
        moduleResult
      } else {
        //apply canonical phase
        ModuleParser.canonicalPhasePhases().call(moduleResult.getResult(), moduleContext)
      }
    })
  }

  override def parseModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    val cached = parsingCache.get(nameIdentifier)
    if (cached != null) {
      cached
    } else {
      computeIfAbsentUnderLock(parsingCache, nameIdentifier)(loadCanonicalModule(nameIdentifier, parentContext))
    }
  }

  override def scopeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ScopeGraphResult[ModuleNode]]] = {
    val cached = scopedGraphCache.get(nameIdentifier)
    if (cached != null) {
      cached
    } else {
      // The previous phase runs before this lock is taken because parseModule acquires the lock for the
      // same key itself. NOTE(review): presumably the lock is not reentrant — confirm before moving this.
      val module: Option[PhaseResult[ParsingResult[ModuleNode]]] = parseModule(nameIdentifier, parentContext)
      computeIfAbsentUnderLock(scopedGraphCache, nameIdentifier) {
        module.map(parsed => {
          parsed.onSuccess(parsingResult => {
            val libraryContext: ParsingContext = parentContext.child(nameIdentifier)
            ModuleParser
              .scopePhasePhases()
              .call(parsingResult, libraryContext)
          })
        })
      }
    }
  }

  override def typeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[TypeCheckingResult[ModuleNode]]] = {
    val cached = typeCheckingCache.get(nameIdentifier)
    if (cached != null) {
      cached
    } else {
      // As in scopeCheckModule, the previous phase (which takes the same per-name lock) runs first.
      val module: Option[PhaseResult[ScopeGraphResult[ModuleNode]]] = scopeCheckModule(nameIdentifier, parentContext)
      computeIfAbsentUnderLock(typeCheckingCache, nameIdentifier) {
        module.map(scoped => {
          scoped.onSuccess(scopeResult => {
            val libraryContext: ParsingContext = parentContext.child(nameIdentifier)
            ModuleParser.typeCheckPhasePhases().call(scopeResult, libraryContext)
          })
        })
      }
    }
  }

  override def invalidateModule(nameIdentifier: NameIdentifier): Unit = {
    lock.lock(nameIdentifier, {
      parsingCache.remove(nameIdentifier)
      scopedGraphCache.remove(nameIdentifier)
      typeCheckingCache.remove(nameIdentifier)
    })
  }

  override def invalidateAll(): Unit = {
    // Keys are gathered from every cache, not just parsingCache. The scope and type phases run their
    // previous phase outside the lock, so an invalidateModule in between can leave a later-phase entry
    // with no parsing entry.
    val keys = new java.util.HashSet[NameIdentifier](parsingCache.keySet())
    keys.addAll(scopedGraphCache.keySet())
    keys.addAll(typeCheckingCache.keySet())
    val iterator = keys.iterator()
    while (iterator.hasNext) {
      invalidateModule(iterator.next())
    }
  }

  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean = moduleDefinitionLoader.canResolveModule(nameIdentifier)
}

/**
  * Chains several managers. Each phase is delegated to the managers in the given order, and the first
  * manager that returns a result wins; later managers are not consulted for that call.
  *
  * @param loaders the managers to consult, in priority order
  */
class CompositeModuleParsingPhasesManager(loaders: ModuleParsingPhasesManager*) extends ModuleParsingPhasesManager {

  override def parseModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] =
    loaders.iterator
      .map(manager => manager.parseModule(nameIdentifier, parentContext))
      .collectFirst { case Some(parsed) => parsed }

  override def scopeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[ScopeGraphResult[ModuleNode]]] =
    loaders.iterator
      .map(manager => manager.scopeCheckModule(nameIdentifier, parentContext))
      .collectFirst { case Some(scoped) => scoped }

  override def typeCheckModule(nameIdentifier: NameIdentifier, parentContext: ParsingContext): Option[PhaseResult[TypeCheckingResult[ModuleNode]]] =
    loaders.iterator
      .map(manager => manager.typeCheckModule(nameIdentifier, parentContext))
      .collectFirst { case Some(typed) => typed }

  /** Invalidates the module in every chained manager. */
  override def invalidateModule(nameIdentifier: NameIdentifier): Unit = {
    for (manager <- loaders) manager.invalidateModule(nameIdentifier)
  }

  /** Invalidates every chained manager. */
  override def invalidateAll(): Unit = {
    for (manager <- loaders) manager.invalidateAll()
  }

  /** True as soon as any chained manager can resolve the module. */
  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean =
    loaders.iterator.exists(manager => manager.canResolveModule(nameIdentifier))
}

object CompositeModuleParsingPhasesManager {

  /**
    * Builds a composite that consults `loaders` in the order given.
    *
    * @param loaders the managers to chain, highest priority first
    * @return the composite manager
    */
  def apply(loaders: ModuleParsingPhasesManager*): CompositeModuleParsingPhasesManager = {
    new CompositeModuleParsingPhasesManager(loaders: _*)
  }
}
