package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.parser.InvalidTypeParameterCall
import org.mule.weave.v2.parser.InvalidTypeRef
import org.mule.weave.v2.parser.annotation.InjectedNodeAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.header.directives.TypeDirective
import org.mule.weave.v2.parser.ast.types.{ TypeParameterNode, TypeParametersListNode, TypeReferenceNode }
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.scope.{ AstNavigator, Reference, ScopesNavigator, VariableScope }

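/**
  * Checks every type reference found in the AST.
  *
  * A reference that resolves must point to a type directive or a type parameter; anything else is
  * reported as an invalid type reference. The number of type arguments supplied must match the
  * number of type parameters declared by the referenced type directive. When a reference to a
  * parameterized type omits its type arguments, they are injected from the declared parameters
  * (each parameter's `base` when present, otherwise `Any`).
  *
  * @tparam T The AST node type carried by the [[ScopeGraphResult]] this phase runs on.
  */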
class TypeParameterCheckerPhase[T <: AstNode] extends FullCompileOnlyPhase[T, ScopeGraphResult[T]] {

  /**
    * Checks every [[TypeReferenceNode]] returned by the source scope's AST navigator and
    * invalidates the scope graph once at the end if any node had type arguments injected.
    *
    * @param source  The scope graph result whose type references are checked.
    * @param context The parsing context; supplies the container name and collects the errors.
    */
  override def run(source: ScopeGraphResult[T], context: ParsingContext): Unit = {
    // Both values are the same for every node, so they are read once up front.
    val scopesNavigator: ScopesNavigator = source.scope
    val containerName: NameIdentifier = context.nameIdentifier
    val typeReferenceNodes: Seq[TypeReferenceNode] = scopesNavigator.astNavigator().allWithType(classOf[TypeReferenceNode])

    val astModified = typeReferenceNodes.foldLeft(false) { (modifiedSoFar, typeRefNode) =>
      // Computed before the `||` below so every node is still processed after the first modification.
      // A node for which no scope is found has nothing to resolve against and is skipped,
      // mirroring how unresolved references are ignored further down.
      val modifiedNow = scopesNavigator.scopeOf(typeRefNode.variable).exists { variableScope =>
        injectTypeParametersOnTypeReferenceNode(containerName, variableScope, typeRefNode, context)
      }
      modifiedSoFar || modifiedNow
    }

    // Injected type arguments change the AST, so the scope graph must be invalidated.
    if (astModified) {
      scopesNavigator.invalidate()
    }
  }

  /**
    * Resolves the given type reference and, depending on what it points to, validates its type
    * arguments, injects missing ones, or reports an error.
    *
    * Nothing is done when the reference cannot be resolved in `scope`, or when the referenced node
    * has no parent in the AST of the scope that declares it.
    *
    * @param containerName The name of the container, used when the reference has no module source.
    * @param scope         The variable scope used to resolve the reference.
    * @param typeRefNode   The type reference node to check (and possibly modify).
    * @param context       The parsing context used to report errors.
    * @return True if type arguments were injected into `typeRefNode`.
    */
  private def injectTypeParametersOnTypeReferenceNode(containerName: NameIdentifier, scope: VariableScope, typeRefNode: TypeReferenceNode, context: ParsingContext): Boolean = {
    val nameIdentifier: NameIdentifier = typeRefNode.variable
    val maybeReference: Option[Reference] = scope.resolveVariable(nameIdentifier)

    maybeReference match {
      case Some(typeReference) =>
        val parentNameIdentifier: NameIdentifier = typeReference.moduleSource.getOrElse(containerName)
        val referenceAstNavigator: AstNavigator = typeReference.scope.astNavigator()

        referenceAstNavigator.parentOf(typeReference.referencedNode) match {
          case Some(typeDirective: TypeDirective) =>
            validateTypeParameters(parentNameIdentifier, typeReference, typeRefNode, context, nameIdentifier, typeDirective)

          case Some(_: TypeParameterNode) =>
            // A type parameter cannot itself receive type arguments.
            typeRefNode.typeArguments.foreach { typeArguments =>
              context.messageCollector.error(InvalidTypeParameterCall(nameIdentifier.name, typeArguments.size, 0), typeRefNode.location())
            }
            false

          case Some(otherNode) =>
            // References carrying the inserted fake variable name are not reported here.
            if (typeRefNode.variable.name != NameIdentifier.INSERTED_FAKE_VARIABLE_NAME) {
              context.messageCollector.error(InvalidTypeRef(typeRefNode, otherNode), typeRefNode.location())
            }
            false

          case None =>
            // The referenced node has no parent: there is nothing to validate against.
            false
        }

      case None =>
        // No action needed if the reference is not found.
        false
    }
  }

  /**
    * Validates the type arguments of a type reference against the type parameters declared by the
    * referenced type directive, injecting default arguments when the reference supplies none.
    *
    * @param parentNameIdentifier The name identifier for the module source from the type reference or the scope parsing context.
    * @param typeReference        The resolved reference.
    * @param typeRefNode          The type reference node to validate.
    * @param context              The parsing context used to report errors.
    * @param nameIdentifier       The name identifier of the type reference, used in error messages.
    * @param typeDirective        The type directive to validate against.
    * @return True if default type arguments were injected into `typeRefNode`.
    */
  private def validateTypeParameters(parentNameIdentifier: NameIdentifier, typeReference: Reference, typeRefNode: TypeReferenceNode, context: ParsingContext, nameIdentifier: NameIdentifier, typeDirective: TypeDirective): Boolean = {
    (typeDirective.typeParametersListNode, typeRefNode.typeArguments) match {
      case (None, Some(suppliedArguments)) =>
        // The directive declares no type parameters but the reference passes arguments.
        context.messageCollector.error(InvalidTypeParameterCall(nameIdentifier.name, suppliedArguments.size, 0), typeRefNode.location())
        false

      case (None, None) =>
        false

      case (Some(typeParams), maybeArguments) =>
        val declaredParameters = typeParams.children()
        if (declaredParameters.isEmpty) {
          false
        } else {
          maybeArguments match {
            case Some(suppliedArguments) =>
              // Ensure parameters and arguments have the same arity.
              if (suppliedArguments.size != declaredParameters.size) {
                context.messageCollector.error(InvalidTypeParameterCall(nameIdentifier.name, suppliedArguments.size, declaredParameters.size), typeRefNode.location())
              }
              false

            case None =>
              // The reference omits its arguments: fill them in from the declaration.
              injectDefaultTypeParameters(parentNameIdentifier, typeReference, typeRefNode, context, typeParams)
              true
          }
        }
    }
  }

  /**
    * Sets the type arguments of `typeRefNode` from the declared type parameters.
    *
    * For each declared parameter the argument is its `base` when present (type references are
    * cloned and re-qualified through [[handleTypeReferenceNode]], any other node is used as is),
    * or an injected `Any` type reference when the parameter has no `base`.
    *
    * @param parentNameIdentifier The name identifier for the module source from the type reference or the scope parsing context.
    * @param typeReference        The resolved reference.
    * @param typeRefNode          The type reference node to modify.
    * @param context              The parsing context.
    * @param typeParams           The declared type parameters to derive the arguments from.
    */
  private def injectDefaultTypeParameters(parentNameIdentifier: NameIdentifier, typeReference: Reference, typeRefNode: TypeReferenceNode, context: ParsingContext, typeParams: TypeParametersListNode): Unit = {
    val injectedArguments = typeParams.typeParameters.map { typeParameter =>
      typeParameter.base match {
        case Some(baseReference: TypeReferenceNode) =>
          handleTypeReferenceNode(parentNameIdentifier, typeReference, baseReference, context)

        case Some(otherBase) =>
          otherBase

        case None =>
          TypeReferenceNode(NameIdentifier("Any")).annotate(InjectedNodeAnnotation())
      }
    }
    typeRefNode.typeArguments = Some(injectedArguments)
  }

  /**
    * Produces the type reference node to inject for a declared parameter whose `base` is itself a
    * type reference.
    *
    * The reference is resolved in the scope of `typeReference`. When it resolves, a clone is built
    * whose name is the referenced node's local name combined (via `::`) with the reference's
    * module source, or with `parentNameIdentifier` when there is none; that clone is then checked
    * recursively so its own missing type arguments are injected too. When it does not resolve, a
    * copy of the original node with a cloned name identifier is returned.
    *
    * @param parentNameIdentifier The name identifier for the module source from the type reference or the scope parsing context.
    * @param typeReference        The resolved reference whose scope is used for resolution.
    * @param typeReferenceNode    The type reference node to handle.
    * @param context              The parsing context.
    * @return The new type reference node to use as injected argument.
    */
  private def handleTypeReferenceNode(parentNameIdentifier: NameIdentifier, typeReference: Reference, typeReferenceNode: TypeReferenceNode, context: ParsingContext): TypeReferenceNode = {
    val declaringScope: VariableScope = typeReference.scope

    declaringScope.resolveVariable(typeReferenceNode.variable).map { reference =>
      val localName = reference.referencedNode.localName()
      val typeParamFQN = reference.moduleSource.getOrElse(parentNameIdentifier).::(localName.name)
      val qualifiedClone = cloneTypeReferenceNode(typeReferenceNode, typeParamFQN)

      // Look for injecting type parameters on arguments based on type directive definition.
      // The returned flag is irrelevant here: the caller already reports the AST as modified.
      injectTypeParametersOnTypeReferenceNode(parentNameIdentifier, declaringScope, qualifiedClone, context)

      qualifiedClone
    }.getOrElse {
      // If no reference is found, return a copy from the original node.
      typeReferenceNode.copy(typeReferenceNode.variable.doClone())
    }
  }

  /**
    * Builds a new [[TypeReferenceNode]] named `nameIdentifier` from an existing one.
    *
    * Type arguments that are themselves type references are cloned recursively (with a cloned
    * name identifier); any other argument node is reused as is. An absent or empty argument list
    * yields a node created without type arguments.
    *
    * @param typeReferenceNode The node to clone.
    * @param nameIdentifier    The name identifier for the new node.
    * @return The new type reference node.
    */
  private def cloneTypeReferenceNode(typeReferenceNode: TypeReferenceNode, nameIdentifier: NameIdentifier): TypeReferenceNode = {
    typeReferenceNode.typeArguments match {
      case Some(typeArgs) if typeArgs.nonEmpty =>
        val clonedArguments = typeArgs.map {
          case nestedReference: TypeReferenceNode =>
            cloneTypeReferenceNode(nestedReference, nestedReference.variable.doClone())
          case otherNode => otherNode
        }
        TypeReferenceNode(nameIdentifier, Some(clonedArguments))

      case _ =>
        TypeReferenceNode(nameIdentifier)
    }
  }
}