package org.mule.weave.v2.parser.ast

import org.mule.weave.v2.api.tooling.ast.DWAstNodeKind
import org.mule.weave.v2.parser.annotation.AstNodeAnnotation
import org.mule.weave.v2.parser.annotation.EnclosedMarkAnnotation
import org.mule.weave.v2.parser.ast.annotation.AnnotationCapableNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.location.Position
import org.mule.weave.v2.parser.location.UnknownLocation
import org.mule.weave.v2.parser.location.WeaveLocation

import scala.collection.TraversableOnce
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Represents all the elements in the DataWeave Syntax Tree.
  *
  * Besides its children, every node can carry a list of [[AstNodeAnnotation]]s and a list of
  * [[CommentNode]]s, both kept in insertion order.
  */
trait AstNode extends WeaveLocationCapable {

  /**
    * Contains the annotations attached to this node, in insertion order.
    */
  private var _annotations: Array[AstNodeAnnotation] = Array.empty

  /**
    * Contains the comments attached to this node, in insertion order.
    */
  private var _comments: Array[CommentNode] = Array.empty

  /**
    * Returns true when this node has no children.
    */
  def isLeaf(): Boolean = children().isEmpty

  /**
    * Returns the children nodes
    *
    * @return The list of nodes (empty unless overridden)
    */
  def children(): Seq[AstNode] = Children.Empty

  /**
    * Returns the list of annotations of this node
    *
    * @return The annotations, in the order they were added
    */
  def annotations(): Seq[AstNodeAnnotation] = _annotations

  /**
    * Annotates this node with the specified annotation
    *
    * @param annotation The annotation to append
    * @return This same node, so calls can be chained
    */
  def annotate(annotation: AstNodeAnnotation): this.type = {
    _annotations = _annotations :+ annotation
    this
  }

  /**
    * Returns the first annotation whose runtime class is exactly `annoType`.
    * Instances of subclasses of `annoType` are NOT matched.
    */
  def annotation[T <: AstNodeAnnotation](annoType: Class[T]): Option[T] =
    _annotations.find(candidate => candidate.getClass.equals(annoType)).map(annoType.cast(_))

  /**
    * Returns every annotation whose runtime class is exactly `annoType`, in insertion order.
    * Instances of subclasses of `annoType` are NOT matched.
    */
  def annotationsBy[T <: AstNodeAnnotation](annoType: Class[T]): Seq[T] =
    _annotations.filter(candidate => candidate.getClass.equals(annoType)).map(annoType.cast(_))

  /**
    * Appends a comment to this node.
    */
  def addComment(comment: CommentNode): Unit = {
    _comments = _comments :+ comment
  }

  /**
    * Returns true if this node has an annotation whose runtime class is exactly `annoType`.
    */
  def isAnnotatedWith[T <: AstNodeAnnotation](annoType: Class[T]): Boolean =
    _annotations.exists(candidate => candidate.getClass.equals(annoType))

  /**
    * Returns the comments attached to this node, in insertion order.
    */
  def comments: Seq[CommentNode] = _comments

  /**
    * Returns the first attached comment whose type is `CommentType.DocComment`, if any.
    */
  def weaveDoc: Option[CommentNode] = _comments.find(_.commentType == CommentType.DocComment)

  /**
    * Returns the semantic start location. This location takes into account if this node is enclosed in parenthesis or not
    *
    * @return The start of the [[EnclosedMarkAnnotation]] location when present, otherwise the start of this node's location
    */
  def semanticStartPosition(): Position = {
    annotation(classOf[EnclosedMarkAnnotation]) match {
      case Some(ea) => ea.location.startPosition
      case _        => location().startPosition
    }
  }

  /**
    * Returns true when this node has a doc comment attached.
    */
  def hasWeaveDoc: Boolean = weaveDoc.isDefined

  /**
    * Returns a copy of this node built by `doClone`, carrying over this node's annotations,
    * comments and explicitly set location.
    */
  def cloneAst(): AstNode = {
    val astNode = doClone()
    copyAnnotationsAndCommentsTo(astNode)
    astNode._location = _location
    astNode
  }

  /**
    * Appends this node's annotations and comments after the ones already present on `astNode`.
    */
  def copyAnnotationsAndCommentsTo(astNode: AstNode): Unit = {
    // The fields are immutable-size arrays: append by reassigning the target's fields.
    astNode._annotations = astNode._annotations ++ _annotations
    astNode._comments = astNode._comments ++ _comments
  }

  /**
    * Appends this node's comments after the ones already present on `astNode`.
    */
  def copyCommentsTo(astNode: AstNode): Unit = {
    astNode._comments = astNode._comments ++ _comments
  }

  /**
    * Returns a string identifying the kind of this node
    * (for example, [[CommentNode]] returns `DWAstNodeKind.COMMENT_NODE`).
    */
  def getKind(): String

  /**
    * Creates the structural copy used by `cloneAst`; annotations, comments and location are
    * copied onto the result by `cloneAst` afterwards.
    */
  protected def doClone(): AstNode
}

/**
  * An [[AstNode]] sub-trait that adds no members of its own.
  * NOTE(review): the meaning of "virtual" is not established in this file — presumably nodes that
  * do not map directly to source text; confirm against its implementations.
  */
trait VirtualAstNode extends AstNode

/**
  * Marker trait indicating that the given node represents an expression: it evaluates and returns a value.
  */
trait ExpressionAstNode extends AnnotationCapableNode

/**
  * A node that exposes a name through a [[NameIdentifier]].
  */
trait NamedAstNode extends AstNode {
  def nameIdentifier: NameIdentifier
}

/**
  * Represents a node that is just a container of nodes.
  *
  * When no location was set explicitly, the location is computed from the children: it spans from
  * the smallest start position to the largest end position among the children whose location is
  * known. When no child has a known location, `UnknownLocation` is used.
  */
trait ContainerAstNode extends AstNode {

  // Evaluated on first access and then cached; children changed after that are not reflected.
  private lazy val containerLocation: WeaveLocation = {
    val knownLocations = children().map(_.location()).filterNot(_ eq UnknownLocation)
    if (knownLocations.isEmpty) {
      UnknownLocation
    } else {
      val start = knownLocations.map(_.startPosition).min
      val end = knownLocations.map(_.endPosition).max
      // Take the resource name from a child whose location is known: the first child may be unknown.
      WeaveLocation(start, end, knownLocations.head.resourceName)
    }
  }

  /**
    * Returns the explicitly set location, falling back to the location derived from the children.
    */
  override def location(): WeaveLocation = _location.getOrElse(containerLocation)
}

/**
  * Mutable builder that collects the children nodes of a given node.
  * Every append returns the builder itself so calls can be chained; empty options are skipped.
  */
class Children {

  private val builder = new ArrayBuffer[AstNode]()

  /** Appends the node held by `element`; does nothing when it is `None`. */
  def +=(element: Option[AstNode]): Children = {
    element.foreach(node => builder += node)
    this
  }

  /** Appends every node held by `element`; does nothing when it is `None`. */
  def ++=(element: Option[TraversableOnce[AstNode]]): Children = {
    element.foreach(nodes => builder ++= nodes)
    this
  }

  /** Appends a single node. */
  def +=(element: AstNode): Children = {
    builder += element
    this
  }

  /** Appends all the given nodes, in order. */
  def ++=(element: TraversableOnce[AstNode]): Children = {
    builder ++= element
    this
  }

  /**
    * Returns the collected nodes, in insertion order.
    * NOTE(review): in Scala 2.12 `ArrayBuffer.result()` returns the buffer itself, so appends made
    * after calling this may be visible through the returned Seq — verify for the Scala version in use.
    */
  def result(): Seq[AstNode] = builder.result()

}

/**
  * Factory methods for [[Children]] builders, plus a shared empty children sequence.
  */
object Children {

  /** Shared empty sequence, used as the default result of `AstNode.children()`. */
  val Empty: Seq[AstNode] = Seq.empty

  /** Creates an empty builder. */
  def apply(): Children = new Children()

  /** Creates a builder pre-filled with all the given nodes, in order. */
  def apply(elements: AstNode*): Children = new Children().++=(elements)

  /** Creates a builder pre-filled with a single node. */
  def apply(elements: AstNode): Children = new Children().+=(elements)

  /** Creates a builder pre-filled with two nodes, in order. */
  def apply(element1: AstNode, element2: AstNode): Children = new Children().+=(element1).+=(element2)
}

/**
  * Factory methods that build small, fixed-size children arrays.
  *
  * Arrays are allocated and filled by hand, which avoids the varargs/ClassTag path of `Array.apply`.
  */
object Child {

  /** Returns a one-element array holding `element1`. */
  def apply(element1: AstNode): Array[AstNode] = {
    val result = new Array[AstNode](1)
    result.update(0, element1)
    result
  }

  /** Returns a two-element array holding the given nodes, in order. */
  def apply(element1: AstNode, element2: AstNode): Array[AstNode] = {
    val result = new Array[AstNode](2)
    result.update(0, element1)
    result.update(1, element2)
    result
  }

  /**
    * Returns `element1`, `element2` and, when present, the node held by `optional`, in that order.
    */
  def apply(element1: AstNode, element2: AstNode, optional: Option[AstNode]): Array[AstNode] = {
    optional match {
      case Some(element3) => Child(element1, element2, element3)
      case None           => Child(element1, element2)
    }
  }

  /**
    * Returns a one-element sequence with the node held by `optional`, or an empty sequence for `None`.
    */
  def apply(optional: Option[AstNode]): Seq[AstNode] = {
    optional match {
      case Some(element1) => Child(element1)
      case None           => Seq.empty
    }
  }

  /** Returns a three-element array holding the given nodes, in order. */
  def apply(element1: AstNode, element2: AstNode, element3: AstNode): Array[AstNode] = {
    val result = new Array[AstNode](3)
    result.update(0, element1)
    result.update(1, element2)
    result.update(2, element3)
    result
  }

  /** Returns a four-element array holding the given nodes, in order. */
  def apply(element1: AstNode, element2: AstNode, element3: AstNode, element4: AstNode): Array[AstNode] = {
    val result = new Array[AstNode](4)
    result.update(0, element1)
    result.update(1, element2)
    result.update(2, element3)
    result.update(3, element4)
    result
  }
}

/**
  * Trait that marks that a given node can update its children
  */
trait MutableAstNode {

  /**
    * Replaces the child node that is equal to `toBeReplaced` with `withNode`.
    *
    * @param toBeReplaced The node to be replaced
    * @param withNode     The replacement
    */
  def update(toBeReplaced: AstNode, withNode: AstNode): Unit
}

/**
  * Represents a node that is a literal value, keeping the exact text the user wrote.
  */
trait LiteralValueAstNode extends AstNode {

  /**
    * The literal text written by the user
    */
  val literalValue: String
}

/**
  * A literal value node that is also an [[ExpressionAstNode]].
  */
trait LiteralExpressionValueAstNode extends LiteralValueAstNode with ExpressionAstNode {}

/**
  * A comment from the source: its literal text together with its [[CommentType]].
  */
case class CommentNode(literalValue: String, commentType: CommentType.Value) extends LiteralValueAstNode {

  override def getKind(): String = DWAstNodeKind.COMMENT_NODE

  /** Clones through the case-class `copy`, keeping the same text and comment type. */
  override protected def doClone(): AstNode = copy()
}

/**
  * The kinds of comment a [[CommentNode]] can hold, declared in id order (0, 1, 2).
  */
object CommentType extends Enumeration {
  val LineComment: Value = Value
  val BlockComment: Value = Value
  val DocComment: Value = Value
}
