package org.mule.weave.v2.parser.phase

import org.mule.weave.v2.annotations.WeaveApi
import org.mule.weave.v2.parser.Message
import org.mule.weave.v2.parser.MessageCollector
import org.mule.weave.v2.parser.annotation.PreCompiledTypeAnnotation
import org.mule.weave.v2.parser.ast.AstNode
import org.mule.weave.v2.parser.location.WeaveLocation

/**
  * A phase whose input and output have the same type and which may be skipped.
  *
  * When [[shouldCall]] answers `false`, the source is passed through untouched as a `SuccessResult`
  * and no start/end phase notifications are emitted. Otherwise [[run]] is executed between the
  * start and end notifications, and the very same source is returned as a `SuccessResult`.
  *
  * @tparam S Source (and result) type
  */
trait ConditionalPhase[S] extends CompilationPhase[S, S] {

  /**
    * Decides whether this phase has to run for the given source.
    *
    * @return `true` to run the phase, `false` to pass the source through unchanged
    */
  def shouldCall(source: S, context: ParsingContext): Boolean

  /**
    * The work of the phase. It is executed only for its side effects; the source itself is what gets returned.
    */
  def run(source: S, context: ParsingContext): Unit

  override def call(source: S, context: ParsingContext): PhaseResult[_ <: S] = {
    if (!shouldCall(source, context)) {
      // Skipped phases are invisible to the notification manager.
      SuccessResult(source, context)
    } else {
      context.notificationManager.startPhase(context.nameIdentifier, this)
      val phaseResult = doCall(source, context)
      context.notificationManager.endPhase(context.nameIdentifier, this, phaseResult)
      phaseResult
    }
  }

  /**
    * Executes [[run]] and always reports success with the unchanged source.
    */
  final def doCall(source: S, context: ParsingContext): PhaseResult[_ <: S] = {
    run(source, context)
    SuccessResult(source, context)
  }
}

/**
  * A conditional phase that only runs during a full compilation.
  *
  * The phase is skipped whenever the source's AST node carries a `PreCompiledTypeAnnotation`.
  *
  * @tparam T AST node type
  * @tparam S Source type, which exposes the AST node
  */
trait FullCompileOnlyPhase[T <: AstNode, S <: AstNodeResultAware[T]] extends ConditionalPhase[S] {

  override def shouldCall(source: S, context: ParsingContext): Boolean = {
    val isPreCompiled = source.astNode.isAnnotatedWith(classOf[PreCompiledTypeAnnotation])
    !isPreCompiled
  }

  /**
    * The phase body; invoked only for sources that are not pre-compiled.
    */
  override def run(source: S, context: ParsingContext): Unit
}

/**
  * A single step of the compilation pipeline: transforms a source of type `S` into a [[PhaseResult]] of `R`.
  *
  * [[call]] wraps the actual work ([[doCall]]) with start/end notifications sent to the context's
  * notification manager. Phases can be combined with [[chainWith]] and [[enrichWith]].
  *
  * @tparam S Source tree type
  * @tparam R Return type
  */
trait CompilationPhase[S, +R] {

  /**
    * Runs this phase, emitting a start notification before [[doCall]] and an end notification
    * (carrying the produced result) after it.
    *
    * @return the result produced by [[doCall]]
    */
  def call(source: S, context: ParsingContext): PhaseResult[_ <: R] = {
    context.notificationManager.startPhase(context.nameIdentifier, this)
    val phaseResult = doCall(source, context)
    context.notificationManager.endPhase(context.nameIdentifier, this, phaseResult)
    phaseResult
  }

  /**
    * The phase implementation, without notifications. Callers should normally go through [[call]].
    */
  def doCall(source: S, context: ParsingContext): PhaseResult[_ <: R]

  /**
    * Builds a phase that runs this phase and then hands its value to `next`.
    * The combined phase may stop early with a failure when this phase does not succeed
    * (see `CompositeCompilationPhase`).
    */
  def chainWith[B >: R, Q](next: CompilationPhase[B, Q]): CompilationPhase[S, Q] =
    new CompositeCompilationPhase(this, next)

  /**
    * Like [[chainWith]], but `next` keeps the result type, so the combined phase still yields `B`
    * (see `EnrichedCompilationPhase`).
    */
  def enrichWith[B >: R](next: CompilationPhase[B, B]): CompilationPhase[S, B] =
    new EnrichedCompilationPhase(this, next)
}

/**
  * A phase that only checks its input. It is not expected to modify the AST or to add new data
  * structures to the compilation context, and its output is the same value it receives as input.
  *
  * Verification is skipped when `context.skipVerification` is set, or when the source's AST node
  * is annotated with `PreCompiledTypeAnnotation`.
  *
  * @tparam T AST node type
  * @tparam S Source type, which exposes the AST node
  */
trait VerificationPhase[T <: AstNode, S <: AstNodeResultAware[T]] extends ConditionalPhase[S] {

  override def shouldCall(source: S, context: ParsingContext): Boolean = {
    // The annotation is only inspected when verification has not been globally disabled.
    val skipped = context.skipVerification || source.astNode.isAnnotatedWith(classOf[PreCompiledTypeAnnotation])
    !skipped
  }

  final override def run(source: S, context: ParsingContext): Unit = verify(source, context)

  /**
    * Performs the checks on `source`. Problems are expected to be reported through the context
    * rather than by changing the source.
    */
  def verify(source: S, context: ParsingContext): Unit
}

/**
  * Sequential composition of two phases: runs `first`, then feeds its value to `second`.
  *
  * The chain stops and returns a `FailureResult` when `first` produced no value, or, in strict mode,
  * when the first result reports any error (`hasErrors`, which reads the result's message collector).
  *
  * @param first  phase producing an `R`
  * @param second phase consuming the `R` and producing a `Q`
  */
class CompositeCompilationPhase[S, R, Q](first: CompilationPhase[S, R], second: CompilationPhase[_ >: R, Q]) extends CompilationPhase[S, Q] {
  override def doCall(source: S, context: ParsingContext): PhaseResult[_ <: Q] = {
    val result: PhaseResult[R] = first.call(source, context)
    // An empty result can never be forwarded: getResult() would throw. This must hold in strict
    // mode too, where a phase may fail without having recorded an error message.
    if (result.noResult() || (context.strictMode && result.hasErrors())) {
      FailureResult(context)
    } else {
      second.call(result.getResult(), context)
    }
  }
}

/**
  * Runs `first` and then an enrichment phase `second` that keeps the result type `R`.
  *
  * The chain stops and returns a `FailureResult` when `first` produced no value, or, in strict mode,
  * when the first result reports any error (`hasErrors`, which reads the result's message collector).
  *
  * @param first  phase producing an `R`
  * @param second phase refining that `R` into another `R`
  */
class EnrichedCompilationPhase[S, R](first: CompilationPhase[S, R], second: CompilationPhase[_ >: R, R]) extends CompilationPhase[S, R] {
  override def doCall(source: S, context: ParsingContext): PhaseResult[_ <: R] = {
    val result: PhaseResult[R] = first.call(source, context)
    // An empty result can never be forwarded: getResult() would throw. This must hold in strict
    // mode too, where a phase may fail without having recorded an error message.
    if (result.noResult() || (context.strictMode && result.hasErrors())) {
      FailureResult(context)
    } else {
      second.call(result.getResult(), context)
    }
  }
}

/**
  * The result of a compilation phase: an optional value together with the message collector that
  * holds the errors and warnings associated with it.
  *
  * All error and warning queries read directly from `messageCollector`. They therefore reflect every
  * message recorded in that collector, not only those produced by the phase that created this result.
  * For results built from a `ParsingContext`, that collector is the context's collector.
  *
  * @param result           The result value, or `None` when the phase produced no value
  * @param messageCollector The collector holding the error and warning messages for this result
  * @tparam T The phase return type
  */
@WeaveApi(Seq("data-weave-agent"))
class PhaseResult[+T](result: Option[T], messageCollector: MessageCollector) {

  /**
    * True if no result value is present (the negation of [[hasResult]]).
    *
    * @return
    */
  def noResult(): Boolean = !hasResult()

  /**
    * The phase result value.
    *
    * @throws CompilationException when no result is present; the exception carries this result's message collector
    * @return the result value
    */
  def getResult(): T = {
    result.getOrElse(throw new CompilationException(messageCollector))
  }

  /**
    * The result value, if any, without throwing.
    *
    * @return `Some(value)` when a result is present, otherwise `None`
    */
  def mayBeResult: Option[T] = result

  /**
    * The message collector associated with this result.
    *
    * @return the collector this result was built with
    */
  def messages(): MessageCollector = {
    messageCollector
  }

  /**
    * True if a result value is present.
    *
    * @return
    */
  def hasResult(): Boolean = result.isDefined

  /**
    * True if no result value is present.
    *
    * @return
    */
  def isEmpty(): Boolean = result.isEmpty

  /**
    * True if the message collector holds at least one error.
    *
    * @return
    */
  def hasErrors(): Boolean = messageCollector.errorMessages.nonEmpty

  /**
    * True if the message collector holds no errors.
    *
    * @return
    */
  def noErrors(): Boolean = messageCollector.errorMessages.isEmpty

  /**
    * The error messages in the collector, each paired with its location.
    */
  def errorMessages(): Seq[(WeaveLocation, Message)] = {
    messageCollector.errorMessages
  }

  /**
    * The warning messages in the collector, each paired with its location.
    */
  def warningMessages(): Seq[(WeaveLocation, Message)] = {
    messageCollector.warningMessages
  }

  /**
    * Continues with `f` when a value is present. Otherwise returns an empty result that shares
    * this result's message collector.
    */
  def onSuccess[B](f: T => PhaseResult[B]): PhaseResult[B] = {
    result match {
      case Some(value) => f(value)
      case None        => FailureResult(messages())
    }
  }

  /**
    * Transforms the value with `f`, keeping the same message collector; an empty result stays empty.
    */
  def map[B](f: T => B): PhaseResult[B] = {
    result match {
      case Some(value) => SuccessResult(f(value), messages())
      case None        => FailureResult(messages())
    }
  }

}

/**
  * Factory for [[PhaseResult]]s that carry no value.
  */
object FailureResult {

  /** An empty result reporting through the message collector of the given parsing context. */
  def apply[T](context: ParsingContext): PhaseResult[T] = apply[T](context.messageCollector)

  /** An empty result reporting through the given message collector. */
  def apply[T](context: MessageCollector): PhaseResult[T] = new PhaseResult[T](None, context)
}

/**
  * Factory for [[PhaseResult]]s that carry a value.
  */
object SuccessResult {

  /** A result holding `result`, reporting through the message collector of the given parsing context. */
  def apply[T](result: T, context: ParsingContext): PhaseResult[T] = apply[T](result, context.messageCollector)

  /** A result holding `result`, reporting through the given message collector. */
  def apply[T](result: T, messageCollector: MessageCollector): PhaseResult[T] =
    new PhaseResult[T](Some(result), messageCollector)
}
