package org.mule.weave.v2.parser.phase
import org.mule.weave.v2.parser.{ DuplicatedSyntaxDirective, InvalidSyntaxDirectivePlacement }
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.ast.header.HeaderNode
import org.mule.weave.v2.parser.ast.header.directives.{ DirectiveNode, VersionDirective }
import org.mule.weave.v2.parser.ast.module.ModuleNode

class SyntaxDirectiveValidation extends AstNodeVerifier {

  /**
    * Decides whether syntax directive problems are reported as errors or only as warnings.
    *
    * Syntax directive validations were introduced after DW 2.4. A script whose first syntax directive declares
    * a version newer than 2.4 (2.5 and above, or any major version above 2) gets errors. Older versions only
    * get warnings, presumably so pre-existing scripts keep working (NOTE(review): confirm intent).
    *
    * @param firstSyntaxDirective the first syntax (version) directive declared in the script
    * @return true when violations must be reported as errors, false when they are reported as warnings
    */
  private def shouldError(firstSyntaxDirective: VersionDirective): Boolean = {
    val major = firstSyntaxDirective.major.v.toInt
    // The minor version is only parsed when it matters (major == 2), same as the original short-circuit.
    major > 2 || (major == 2 && firstSyntaxDirective.minor.v.toInt > 4)
  }

  /**
    * Validates the syntax (`VersionDirective`) directives found among `directives`.
    *
    * After DW 2.4 the syntax directive must be unique per script and must be at the start of the script
    * (barring comments):
    *  - every syntax directive after the first one is reported as `DuplicatedSyntaxDirective`;
    *  - the first syntax directive is reported as `InvalidSyntaxDirectivePlacement` when it is not the first
    *    directive in `directives`.
    * The severity (error or warning) is chosen from the version declared by the first syntax directive,
    * see `shouldError`. Nothing is reported when there is no syntax directive at all.
    *
    * @param directives the directives to validate, in the order they appear in the script
    * @param context    the parsing context whose message collector receives the reported problems
    */
  def checkSyntaxDirective(directives: Seq[DirectiveNode], context: ParsingContext): Unit = {
    val syntaxDirectives = directives.collect { case d: VersionDirective => d }

    syntaxDirectives.headOption.foreach { first =>
      // Severity depends only on the first syntax directive. It is computed once, and lazily, so the
      // version is only parsed when there is actually something to report.
      lazy val asError = shouldError(first)

      syntaxDirectives.tail.foreach { duplicate =>
        if (asError) context.messageCollector.error(DuplicatedSyntaxDirective(), duplicate.location())
        else context.messageCollector.warning(DuplicatedSyntaxDirective(), duplicate.location())
      }

      // `collect` keeps the original order, so the syntax directive is well placed exactly when the
      // first collected one is also the first directive overall.
      if (!directives.headOption.contains(first)) {
        if (asError) context.messageCollector.error(InvalidSyntaxDirectivePlacement(), first.location())
        else context.messageCollector.warning(InvalidSyntaxDirectivePlacement(), first.location())
      }
    }
  }

  /**
    * Applies the syntax directive checks to the directive list of `HeaderNode`s and `ModuleNode`s.
    * Any other node is ignored.
    *
    * @param node    the AST node being verified
    * @param context the parsing context used to report problems
    */
  override def verify(node: AstNode, context: ParsingContext): Unit = node match {
    case HeaderNode(directives) => checkSyntaxDirective(directives, context)
    case ModuleNode(_, nodes)   => checkSyntaxDirective(nodes, context)
    case _                      => ()
  }
}
