package org.mule.weave.v2.scope

import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast._
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.functions._
import org.mule.weave.v2.parser.ast.header.directives._
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.patterns._
import org.mule.weave.v2.parser.ast.structure.NamespaceNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.updates.UpdateExpressionNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.parser.phase.PhaseResult
import org.mule.weave.v2.parser.phase.ScopeGraphResult
import org.mule.weave.v2.ts.WeaveTypeReferenceResolver
import org.mule.weave.v2.utils.IdentityHashMap

import scala.collection.mutable

/**
  * Query layer over a scope tree built by [[ScopeFactory]].
  *
  * Answers which scope owns a node and which declaration a reference resolves to. It also reports
  * which references cannot be resolved, which declarations are never used, and which imports are used.
  * The invalid/unused/import results are `lazy val`s: computed on first request, then cached.
  *
  * @param rootScope the root of the scope tree to navigate
  */
class ScopesNavigator(val rootScope: VariableScope) {

  private lazy val invalidRefs: Seq[AstNode] = resolveInvalidReference(rootScope)
  private lazy val resolveUnusedRefs: Seq[NameIdentifier] = resolveUnused(rootScope)
  private lazy val importResult: ImportInformationResult = analyzeImports(rootScope)
  // Node -> owning scope index. Built on first use by getAstNodeToScope() and dropped by invalidate().
  private var astNodeToScope: IdentityHashMap[AstNode, VariableScope] = _

  def referenceResolver: WeaveTypeReferenceResolver = rootScope.referenceResolver()

  /**
    * Collects the references of `scope` and all its descendants that cannot be resolved.
    *
    * Identifiers whose parent node is an [[ImportedElement]] are skipped in the local pass because
    * they name declarations of another module. Instead, every explicitly imported element (except
    * the `*` wildcard) is checked against the root scope of the module it is imported from.
    */
  private def resolveInvalidReference(scope: VariableScope): Seq[AstNode] = {
    val invalidReferences = mutable.ListBuffer[AstNode]()
    //We filter ImportedElements as they are remote variable references
    val localVariableReferences = scope
      .references()
      .filterNot((ni) => {
        rootScope.astNavigator().parentOf(ni).exists(_.isInstanceOf[ImportedElement])
      })
    for (ref <- localVariableReferences) {
      if (scope.resolveVariable(ref).isEmpty) {
        invalidReferences += ref
      }
    }

    for (
      importedModule <- scope.importedModules();
      element <- importedModule._1.subElements.elements
    ) {
      if (!element.elementName.name.equals("*")) {
        // importedModule._2 is the root scope of the imported module
        if (importedModule._2.resolveVariable(element.elementName).isEmpty) {
          invalidReferences += element.elementName
        }
      }
    }
    invalidReferences ++= scope
      .children()
      .flatMap((scope) => {
        resolveInvalidReference(scope)
      })
    invalidReferences
  }

  /**
    * Builds the node -> scope index for the whole tree under `scope`.
    *
    * Every node is mapped to the scope whose `astNode` it is. If it is no scope's `astNode`, it is
    * mapped to the scope inherited from its nearest such ancestor. Identity-keyed maps are used
    * because AST nodes compare structurally.
    */
  private def resolveAstNodeToScope(scope: VariableScope): IdentityHashMap[AstNode, VariableScope] = {

    // Indexes each scope in the tree by the AST node that opened it.
    def toRootAstNodeScope(scope: VariableScope, collector: IdentityHashMap[AstNode, VariableScope] = IdentityHashMap[AstNode, VariableScope]()): IdentityHashMap[AstNode, VariableScope] = {
      collector.put(scope.astNode, scope)
      scope
        .children()
        .foreach((scope) => {
          toRootAstNodeScope(scope, collector)
        })
      collector
    }

    // Walks the AST, switching scope whenever a node opens one, and records the scope of every node.
    def mapAstNodeToScopes(astNode: AstNode, currentScope: VariableScope, rootAstNodeScopes: IdentityHashMap[AstNode, VariableScope], collector: IdentityHashMap[AstNode, VariableScope] = IdentityHashMap[AstNode, VariableScope]()): IdentityHashMap[AstNode, VariableScope] = {
      val scope = rootAstNodeScopes.getOrElse(astNode, currentScope)
      collector.put(astNode, scope)
      astNode
        .children()
        .foreach((child) => {
          mapAstNodeToScopes(child, scope, rootAstNodeScopes, collector)
        })
      collector
    }

    mapAstNodeToScopes(scope.astNode, scope, toRootAstNodeScope(scope))
  }

  /**
    * Returns the references that point to the given declaration, as reported by the root scope.
    */
  def resolveLocalReferencedBy(nameIdentifier: NameIdentifier): Seq[Reference] = {
    rootScope.resolveReferenceTo(nameIdentifier)
  }

  /**
    * Computes which import directives exist and which imported declarations are actually used.
    *
    * Import directives annotated with [[InjectedNodeAnnotation]], or importing the
    * `NameIdentifier.INSERTED_FAKE_VARIABLE` placeholder, are ignored. Every resolvable reference
    * (outside import directives) that resolves across modules is registered as a usage of its module.
    */
  private def analyzeImports(scope: VariableScope): ImportInformationResult = {

    // All references of `scope` and its descendants, excluding identifiers inside import directives.
    def collectReferences(scope: VariableScope): Seq[NameIdentifier] = {
      //Exclude ImportDirective References
      scope.references()
        .filterNot((s) => scope.astNavigator().isChildOfAny(s, classOf[ImportedElement])) ++ (scope.children().flatMap((scope) => collectReferences(scope)))
    }

    val result = new ImportInformationResult()
    val importDirectives: Seq[ImportDirective] = scope.astNavigator().importDirectives()

    for (id <- importDirectives) {
      if (!id.isAnnotatedWith(classOf[InjectedNodeAnnotation]) && !id.importedModule.elementName.equals(NameIdentifier.INSERTED_FAKE_VARIABLE)) {
        result.addImportDirective(id)
      }
    }

    val identifiers: Seq[NameIdentifier] = collectReferences(scope)
    for (identifier <- identifiers) {
      val maybeReference = scopeOf(identifier).flatMap(_.resolveVariable(identifier))
      maybeReference match {
        case None =>
        case Some(value) => {
          if (value.isCrossModule) {
            // NOTE(review): assumes moduleSource is always defined for cross-module references.
            result.addUsage(value.referencedNode, value.moduleSource.get)
          }
        }
      }
    }
    result
  }

  /**
    * Returns the declarations in `scope` and its descendants that no reference resolves to.
    *
    * Declarations named only with `$` characters (e.g. `$`, `$$`) and declarations annotated with
    * [[InjectedNodeAnnotation]] are never reported. The result is in declaration (visit) order.
    */
  private def resolveUnused(scope: VariableScope): Seq[NameIdentifier] = {
    // NameIdentifier compares structurally, so two distinct declarations of `x` are `equals`.
    // A hash set would merge them and one usage would hide every other `x`; the declarations
    // therefore have to be tracked by identity.
    val candidates = mutable.ArrayBuffer[NameIdentifier]()
    val used = IdentityHashMap[AstNode, AstNode]()

    def collect(current: VariableScope): Unit = {
      candidates ++= current
        .declarations()
        .filter((x) => !x.name.matches("""\$+""") && x.annotation(classOf[InjectedNodeAnnotation]).isEmpty)
      for (ref <- current.references()) {
        current.resolveVariable(ref).foreach((reference) => {
          used.put(reference.referencedNode, reference.referencedNode)
        })
      }
      current
        .children()
        .foreach((child) => {
          collect(child)
        })
    }

    collect(scope)
    candidates.filter((declaration) => used.get(declaration).isEmpty).toList
  }

  /**
    * Returns the declaration to which the given reference points.
    *
    * Supports variable references, namespace prefixes, type references and bare identifiers;
    * any other node, or a node that is not part of this scope tree, yields `None`.
    */
  def resolveReference(ref: AstNode): Option[AstNode] = {
    val maybeScope: Option[VariableScope] = scopeOf(ref)
    maybeScope match {
      case None => None
      case Some(scope) => {
        ref match {
          case vrn: VariableReferenceNode => scope.resolveVariable(vrn.variable).map(_.referencedNode)
          case nsn: NamespaceNode         => scope.resolveVariable(nsn.prefix).map(_.referencedNode)
          case trn: TypeReferenceNode     => scope.resolveVariable(trn.variable).map(_.referencedNode)
          case ni: NameIdentifier         => scope.resolveVariable(ni).map(_.referencedNode)
          case _                          => None
        }
      }
    }

  }

  def astNavigator(): AstNavigator = {
    rootScope.astNavigator()
  }

  /**
    * Resolves `ref` in the scope it belongs to.
    *
    * @return the resolved [[Reference]] (the declaration plus resolution info), or `None` when the
    *         identifier is not part of this tree or cannot be resolved.
    */
  def resolveVariable(ref: NameIdentifier): Option[Reference] = {
    val maybeScope: Option[VariableScope] = scopeOf(ref)
    maybeScope.flatMap(_.resolveVariable(ref))
  }

  /**
    * Returns all invalid references in this scope
    *
    * @return The list of all unresolved references.
    */
  def invalidReferences(): Seq[AstNode] = {
    invalidRefs
  }

  /**
    * Returns all the unused declarations
    */
  def unusedDeclarations(): Seq[NameIdentifier] = {
    resolveUnusedRefs
  }

  /**
    * Returns the import directives and the cross-module usages recorded for them.
    */
  def importInformation(): ImportInformationResult = {
    importResult
  }

  /**
    * Returns the scope that this node belongs to.
    *
    * @param nodeToSearch The node to search (matched by identity)
    * @return The scope of that node, or `None` if the node is not part of this tree
    */
  def scopeOf(nodeToSearch: AstNode): Option[VariableScope] = {
    //There should always be a scope
    getAstNodeToScope().get(nodeToSearch)
  }

  /**
    * Drops the node -> scope index so it is rebuilt on the next lookup.
    *
    * Only that index is reset; the cached invalid/unused/import results are `lazy val`s and are not
    * recomputed.
    */
  def invalidate(): Unit = {
    astNodeToScope = null
  }

  // Returns the node -> scope index, building it on first use.
  private def getAstNodeToScope(): IdentityHashMap[AstNode, VariableScope] = {
    if (astNodeToScope == null) {
      astNodeToScope = resolveAstNodeToScope(rootScope)
    }
    astNodeToScope
  }

}

object ScopesNavigator {

  /**
    * Builds the scope tree for `astNode` and wraps it in a [[ScopesNavigator]].
    *
    * @param astNode        the node whose scopes should be built (typically a document or module root)
    * @param parsingContext context handed to the [[ScopeFactory]]
    * @param parentScope    optional enclosing scope, assigned as the parent of the new root scope
    * @return a navigator over the freshly built scope tree
    */
  def apply(astNode: AstNode, parsingContext: ParsingContext, parentScope: Option[VariableScope] = None): ScopesNavigator = {
    val factory = new ScopeFactory(parsingContext)
    val root: VariableScope = factory.create(astNode)
    root.parentScope = parentScope
    new ScopesNavigator(root)
  }
}

/**
  * Builds the [[VariableScope]] tree for an AST.
  *
  * Walks the tree once, registering declarations and references on the scope they belong to. It
  * opens child scopes for the constructs that introduce bindings: type directives, functions,
  * patterns, `using`, `do` blocks, update expressions and nested modules.
  *
  * @param context parsing context; used to obtain the scope graphs of imported modules
  */
class ScopeFactory(context: ParsingContext) {

  /**
    * Creates the root scope for `astNode` and populates it.
    *
    * A [[ModuleNode]] root gets a scope created with the module's name, and only its elements are
    * visited. Any other node gets a scope created without a name, and the node itself is visited.
    */
  def create(astNode: AstNode): VariableScope = {
    astNode match {
      case mn: ModuleNode =>
        val scope = VariableScope(context, mn.name.name, astNode)
        mn.elements.foreach((x) => visitNode(x, scope))
        scope
      case _ =>
        val scope = VariableScope(context, astNode)
        visitNode(astNode, scope)
        scope
    }
  }

  // Visits every direct child of `astNode`, in order, within the same scope.
  private def visitChildren(astNode: AstNode, scope: VariableScope): Unit = {
    // foreach rather than indexed access: apply(i) is linear when children() is a List.
    astNode.children().foreach((child) => visitNode(child, scope))
  }

  /**
    * Registers the declarations and references of `astNode` (and its subtree) on `scope`,
    * creating child scopes for binding constructs. Nodes without a dedicated case are traversed
    * transparently in the current scope.
    */
  private def visitNode(astNode: AstNode, scope: VariableScope): Unit = {
    astNode match {
      case _: VersionDirective =>
      case vd @ VarDirective(varName, _, _, _) => {
        visitChildren(vd, scope)
        scope.addDeclaration(varName)
      }
      case id @ ImportDirective(moduleName, _, codeAnnotations) => {
        val importedNameIdentifier = moduleName.elementName
        //If the module imports itself we don't register anything (annotations are still visited)
        if (!importedNameIdentifier.equals(context.nameIdentifier)) {
          val module: PhaseResult[ScopeGraphResult[ModuleNode]] = context.getScopeGraphForModule(importedNameIdentifier)
          if (module.hasResult()) {
            scope.addImportedModule(id, module.getResult().scope.rootScope)
          }
          // Each explicitly imported name is a reference; the `*` wildcard is not.
          id.subElements.elements.foreach((element) => {
            if (!element.elementName.equals(NameIdentifier.$star)) {
              scope.addReference(element.elementName)
            }
          })
        }
        visitNodes(codeAnnotations, scope)
      }
      case TypeDirective(varName, typeParametersListNode, typeExpression, codeAnnotations) => {
        scope.addDeclaration(varName)
        // Type parameters are declared in a child scope attached to the type expression.
        val typeScope = scope.createChild(typeExpression)
        typeParametersListNode.foreach((typeParametersList) => {
          typeParametersList.typeParameters.foreach((node) => {
            typeScope.addDeclaration(node.name)
            visitNode(node, scope)
          })
        })
        visitNode(typeExpression, typeScope)
        visitNodes(codeAnnotations, scope)
      }
      case trn @ TypeReferenceNode(_, typeParameters, schema, typeSchema, codeAnnotations) => {
        scope.addReference(trn.variable)
        typeParameters.foreach((params) => visitNodes(params, scope))
        visitOptionNode(schema, scope)
        visitOptionNode(typeSchema, scope)
        visitNodes(codeAnnotations, scope)
      }
      case id: InputDirective => {
        scope.addDeclaration(id.variable)
        visitChildren(id, scope)
      }
      case ad: AnnotationDirectiveNode => {
        scope.addDeclaration(ad.nameIdentifier)
        visitChildren(ad, scope)
      }
      case AnnotationNode(name, args) => {
        scope.addReference(name)
        visitOptionNode(args, scope)
      }
      case NamespaceDirective(prefix, _, codeAnnotations) => {
        scope.addDeclaration(prefix)
        visitNodes(codeAnnotations, scope)
      }
      case nn: NamespaceNode => {
        scope.addReference(nn.prefix)
      }
      case uen: UpdateExpressionNode => {
        val variableScope = scope.createChild(uen)
        variableScope.addDeclaration(uen.name)
        variableScope.addDeclaration(uen.indexId)
        visitNode(uen.selector, variableScope)
        visitNode(uen.updateExpression, variableScope)
        visitOptionNode(uen.condition, variableScope)
      }
      case pn: PatternExpressionNode => {
        // Every match case gets its own scope holding the names it binds.
        val expressionPatternScope: VariableScope = scope.createChild(pn)
        pn match {
          case RegexPatternNode(pattern, name, onMatch) => {
            expressionPatternScope.addDeclaration(name)
            visitNode(pattern, expressionPatternScope)
            visitNode(onMatch, expressionPatternScope)
          }
          case TypePatternNode(pattern, name, onMatch) => {
            visitNode(pattern, expressionPatternScope)
            expressionPatternScope.addDeclaration(name)
            visitNode(onMatch, expressionPatternScope)
          }
          case LiteralPatternNode(pattern, name, onMatch) => {
            visitNode(pattern, expressionPatternScope)
            expressionPatternScope.addDeclaration(name)
            visitNode(onMatch, expressionPatternScope)
          }
          case ExpressionPatternNode(pattern, name, onMatch) => {
            expressionPatternScope.addDeclaration(name)
            visitNode(pattern, expressionPatternScope)
            visitNode(onMatch, expressionPatternScope)
          }
          case EmptyArrayPatternNode(onMatch) => {
            visitNode(onMatch, expressionPatternScope)
          }
          case DeconstructArrayPatternNode(headNameIdentifier, tailNameIdentifier, onMatch) => {
            expressionPatternScope.addDeclaration(headNameIdentifier)
            expressionPatternScope.addDeclaration(tailNameIdentifier)
            visitNode(onMatch, expressionPatternScope)
          }
          case DeconstructObjectPatternNode(headKeyNameIdentifier, headValueNameIdentifier, tailNameIdentifier, onMatch) => {
            expressionPatternScope.addDeclaration(headKeyNameIdentifier)
            expressionPatternScope.addDeclaration(headValueNameIdentifier)
            expressionPatternScope.addDeclaration(tailNameIdentifier)
            visitNode(onMatch, expressionPatternScope)
          }
          case DefaultPatternNode(value, name) => {
            expressionPatternScope.addDeclaration(name)
            visitNode(value, expressionPatternScope)
          }
          case _ =>
        }
      }
      case FunctionDirectiveNode(varName, fn, codeAnnotation) => {
        scope.addDeclaration(varName)
        visitNodes(codeAnnotation, scope)
        visitNode(fn, scope)
      }
      case ModuleNode(name, elements) => {
        val child: VariableScope = scope.createChild(name.name, astNode)
        elements.foreach((x) => visitNode(x, child))
      }
      case overloadedFunctionNode: OverloadedFunctionNode => {
        overloadedFunctionNode.functions.foreach((function) => visitNode(function, scope))
      }
      case fn @ FunctionNode(arguments, body, returnType, typeParameterList) => {
        //Function has a scope with the function definition. All type parameters and function params and return type
        val functionScope = scope.createChild(fn)
        typeParameterList.foreach((typeParametersList) => {
          functionScope.addDeclarations(typeParametersList.typeParameters.map(_.name))
          // Visit the type parameter nodes as well (as the TypeDirective case does) so that type
          // references nested in them (presumably their bounds) are recorded, not silently skipped.
          typeParametersList.typeParameters.foreach((typeParameter) => visitNode(typeParameter, functionScope))
        })
        arguments.paramList.foreach((argument) => {
          argument.wtype.foreach((argType) => visitNode(argType, functionScope))
        })

        arguments.paramList.foreach((argument) => {
          argument.codeAnnotations.foreach((ann) => visitNode(ann, functionScope))
        })

        returnType.foreach((returnTypeNode) => visitNode(returnTypeNode, functionScope))
        arguments.paramList.flatMap(_.defaultValue).foreach((defaultValue) => visitNode(defaultValue, functionScope))
        //End of function scope.
        //Function body is child of the function definition scope; parameters are declared there.
        val funBodyScope: VariableScope = functionScope.createChild(body)
        val argNames: Seq[NameIdentifier] = arguments.paramList.map(_.variable)
        funBodyScope.addDeclarations(argNames)
        visitNode(body, funBodyScope)
      }
      case UsingNode(assignments, expr, codeAnnotations) => {
        // Each assignment opens a scope nested in the previous one, so later values see earlier names.
        var currentScope = scope
        assignments.assignmentSeq.foreach((assignment) => {
          currentScope = currentScope.createChild(assignment)
          visitNode(assignment.value, currentScope)
          currentScope.addDeclaration(assignment.name)
        })
        val usingScope = currentScope.createChild(expr)
        visitNode(expr, usingScope)
        visitNodes(codeAnnotations, scope)
      }
      case doBlock: DoBlockNode => {
        val doBlockScope = scope.createChild(doBlock)
        doBlock.header.children().foreach((x) => visitNode(x, doBlockScope))
        visitNode(doBlock.body, doBlockScope)
        visitNodes(doBlock.codeAnnotations, scope)
      }
      case ref: VariableReferenceNode => {
        scope.addReference(ref.variable)
        visitNodes(ref.codeAnnotations, scope)
      }
      case node => {
        visitChildren(node, scope)
      }
    }
  }

  // Visits each node in `nodes`, in order, within `scope`.
  private def visitNodes(nodes: Seq[AstNode], scope: VariableScope): Unit = {
    nodes.foreach((co) => {
      visitNode(co, scope)
    })
  }

  // Visits the node, if present, within `scope`.
  private def visitOptionNode(nodes: Option[AstNode], scope: VariableScope): Unit = {
    nodes.foreach((co) => {
      visitNode(co, scope)
    })
  }
}
