package org.mule.weave.v2.scope

import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast._
import org.mule.weave.v2.parser.ast.header.directives._
import org.mule.weave.v2.parser.ast.structure.NamespaceNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.ts.WeaveTypeReferenceResolver
import org.mule.weave.v2.utils.IdentityHashMap
import org.mule.weave.v2.utils.LazyValRef
import org.mule.weave.v2.utils.ThreadSafe

import scala.collection.mutable

/**
  * Provides an API for variable resolution over the tree of [[VariableScope]]s rooted at `rootScope`.
  *
  * This class needs to support concurrent access, as multiple threads may need to resolve variables
  * from different modules. The derived analyses (invalid references, unused declarations and import
  * information) are computed lazily on first use and then cached.
  *
  * @param rootScope The root of the scope tree to navigate.
  */
@ThreadSafe
class ScopesNavigator(val rootScope: VariableScope) extends VariableResolver {

  private lazy val invalidRefs: Seq[AstNode] = resolveInvalidReference(rootScope)
  private lazy val resolveUnusedRefs: Seq[NameIdentifier] = resolveUnused(rootScope)
  // NOTE(review): `analyzeImports` goes through `scopeOf`, but `invalidate()` only cleans
  // `astNodeToScope`; once computed, this cached result is not refreshed. Confirm that is intended.
  private lazy val importResult: ImportInformationResult = analyzeImports(rootScope)
  // Index from every AST node to the scope that owns it. `invalidate()` calls `clean()` on it,
  // presumably so that the next `get()` rebuilds it.
  private val astNodeToScope: LazyValRef[IdentityHashMap[AstNode, VariableScope]] = LazyValRef(() => {
    resolveAstNodeToScope(rootScope)
  })

  def referenceResolver: WeaveTypeReferenceResolver = rootScope.referenceResolver()

  /**
    * Collects, for `scope` and all its descendants, the references that cannot be resolved.
    *
    * Local references are resolved against their own scope. The names listed in an import directive
    * (other than `*`) are resolved against the imported module's scope.
    */
  private def resolveInvalidReference(scope: VariableScope): Seq[AstNode] = {
    val invalidReferences = mutable.ListBuffer[AstNode]()
    val navigator = rootScope.astNavigator()
    // Names whose parent is an ImportedElement point into another module. They are validated
    // against that module in the import loop below, not against this scope.
    val localVariableReferences = scope
      .references
      .filterNot(ni => navigator.parentOf(ni).exists(_.isInstanceOf[ImportedElement]))
    for (ref <- localVariableReferences) {
      if (scope.resolveVariable(ref).isEmpty) {
        invalidReferences += ref
      }
    }

    for (
      importedModule <- scope.importedModules;
      element <- importedModule._1.subElements.elements
    ) {
      // `*` imports everything, so there is no single name to check.
      if (!element.elementName.name.equals("*") && importedModule._2.resolveVariable(element.elementName).isEmpty) {
        invalidReferences += element.elementName
      }
    }
    invalidReferences ++= scope
      .children()
      .flatMap(child => resolveInvalidReference(child))
    invalidReferences
  }

  /**
    * Builds the AST-node-to-scope index in two passes.
    *
    * First, the root AST node of every scope is mapped to that scope. Then the whole AST is walked,
    * and each node is assigned the scope of the nearest enclosing scope root.
    */
  private def resolveAstNodeToScope(scope: VariableScope): IdentityHashMap[AstNode, VariableScope] = {

    // Pass 1: scope root node -> scope, for every scope in the tree.
    def toRootAstNodeScope(scope: VariableScope, collector: IdentityHashMap[AstNode, VariableScope] = IdentityHashMap[AstNode, VariableScope]()): IdentityHashMap[AstNode, VariableScope] = {
      collector.put(scope.astNode, scope)
      scope
        .children()
        .foreach(childScope => toRootAstNodeScope(childScope, collector))
      collector
    }

    // Pass 2: every node inherits its parent's scope unless it is itself the root of a scope.
    def mapAstNodeToScopes(astNode: AstNode, currentScope: VariableScope, rootAstNodeScopes: IdentityHashMap[AstNode, VariableScope], collector: IdentityHashMap[AstNode, VariableScope] = IdentityHashMap[AstNode, VariableScope]()): IdentityHashMap[AstNode, VariableScope] = {
      val scope = rootAstNodeScopes.getOrElse(astNode, currentScope)
      collector.put(astNode, scope)
      astNode
        .children()
        .foreach(child => mapAstNodeToScopes(child, scope, rootAstNodeScopes, collector))
      collector
    }

    mapAstNodeToScopes(scope.astNode, scope, toRootAstNodeScope(scope))
  }

  /**
    * Returns the references that point to `nameIdentifier`, as reported by the root scope.
    */
  def resolveLocalReferencedBy(nameIdentifier: NameIdentifier): Seq[Reference] = {
    rootScope.resolveReferenceTo(nameIdentifier)
  }

  /**
    * Records the (non-injected) import directives of the document. It also records, for every
    * reference that resolves to a declaration in another module, which module it uses.
    */
  private def analyzeImports(scope: VariableScope): ImportInformationResult = {

    def collectReferences(scope: VariableScope): Seq[NameIdentifier] = {
      // Exclude the names written inside import directives themselves.
      scope.references.filterNot(s => scope.astNavigator().isChildOfAny(s, classOf[ImportedElement])) ++
        scope.children().flatMap(collectReferences)
    }

    val result = new ImportInformationResult()
    val importDirectives: Seq[ImportDirective] = scope.astNavigator().importDirectives()

    for (id <- importDirectives) {
      if (!id.isAnnotatedWith(classOf[InjectedNodeAnnotation]) && !id.importedModule.elementName.equals(NameIdentifier.INSERTED_FAKE_VARIABLE)) {
        result.addImportDirective(id)
      }
    }

    for (identifier <- collectReferences(scope)) {
      scopeOf(identifier).flatMap(_.resolveVariable(identifier)) match {
        case Some(reference) if reference.isCrossModule =>
          // Avoid `Option.get`: a cross-module reference without a module source is skipped, not a crash.
          reference.moduleSource.foreach(module => result.addUsage(reference.referencedNode, module))
        case _ =>
      }
    }
    result
  }

  /**
    * Collects the declarations of `scope` and its descendants that no reference resolves to.
    *
    * Declarations named only with `$` characters and injected declarations are never reported.
    * The result is in scope-tree pre-order (declaration order within each scope).
    */
  private def resolveUnused(scope: VariableScope): Seq[NameIdentifier] = {
    // LinkedHashSet keeps insertion order, so the result is stable and follows the source,
    // instead of following hash order as a plain HashSet would.
    val unused = mutable.LinkedHashSet[NameIdentifier]()

    def isReportable(declaration: NameIdentifier): Boolean =
      !declaration.name.matches("""\$+""") && declaration.annotation(classOf[InjectedNodeAnnotation]).isEmpty

    def doResolveUnused(current: VariableScope): Unit = {
      // Declarations are added before this scope's references are processed, so a reference can
      // clear a declaration from its own scope or from any enclosing scope visited earlier.
      unused ++= current.declarations.filter(isReportable)
      for (ref <- current.references) {
        current.resolveVariable(ref).foreach(resolved => unused -= resolved.referencedNode)
      }
      current.children().foreach(doResolveUnused)
    }

    doResolveUnused(scope)
    unused.toSeq
  }

  /**
    * Returns the declaration to which the given reference points.
    *
    * Supports variable references, namespace prefixes, type references and bare name identifiers.
    * Any other node, or a node without a known scope, yields `None`.
    */
  def resolveReference(ref: AstNode): Option[AstNode] = {
    scopeOf(ref).flatMap { scope =>
      ref match {
        case vrn: VariableReferenceNode => scope.resolveVariable(vrn.variable).map(_.referencedNode)
        case nsn: NamespaceNode         => scope.resolveVariable(nsn.prefix).map(_.referencedNode)
        case trn: TypeReferenceNode     => scope.resolveVariable(trn.variable).map(_.referencedNode)
        case ni: NameIdentifier         => scope.resolveVariable(ni).map(_.referencedNode)
        case _                          => None
      }
    }
  }

  def astNavigator(): AstNavigator = {
    rootScope.astNavigator()
  }

  /**
    * Resolves `ref` in the scope it belongs to and returns the [[Reference]] to its declaration, if any.
    */
  def resolveVariable(ref: NameIdentifier): Option[Reference] = {
    scopeOf(ref).flatMap(_.resolveVariable(ref))
  }

  /**
    * Returns all invalid references in this scope tree.
    *
    * @return The list of all unresolved references.
    */
  def invalidReferences(): Seq[AstNode] = {
    invalidRefs
  }

  /**
    * Returns all the unused declarations, in declaration (scope-tree pre-) order.
    */
  def unusedDeclarations(): Seq[NameIdentifier] = {
    resolveUnusedRefs
  }

  /**
    * Returns the import directives and cross-module usages collected for this document.
    */
  def importInformation(): ImportInformationResult = {
    importResult
  }

  /**
    * Returns the scope that this node belongs to.
    *
    * @param nodeToSearch The node to search
    * @return The scope of that node, or `None` if the node is not part of the indexed AST
    */
  def scopeOf(nodeToSearch: AstNode): Option[VariableScope] = {
    // There should always be a scope
    getAstNodeToScope.get(nodeToSearch)
  }

  /**
    * Discards the cached AST-node-to-scope index.
    *
    * NOTE(review): the lazily cached analyses (invalid references, unused declarations and import
    * information) are not reset by this call.
    */
  def invalidate(): Unit = {
    astNodeToScope.clean()
  }

  private def getAstNodeToScope: IdentityHashMap[AstNode, VariableScope] = {
    astNodeToScope.get()
  }

}

object ScopesNavigator {

  /**
    * Creates a navigator for `astNode`.
    *
    * A [[VariableScopeFactory]] configured with `parsingContext` builds the scope tree for the node,
    * and the resulting root scope becomes the navigator's `rootScope`.
    *
    * @param astNode        The AST node whose scopes should be navigated
    * @param parsingContext The context handed to the scope factory
    * @return A new navigator over the scope tree of `astNode`
    */
  def apply(astNode: AstNode, parsingContext: ParsingContext): ScopesNavigator = {
    val scopeFactory = new VariableScopeFactory(parsingContext)
    new ScopesNavigator(scopeFactory.create(astNode))
  }
}
