package org.mule.weave.v2.compilation.serializer

import org.mule.weave.v2.compilation.ArraySerializableAstNode
import org.mule.weave.v2.compilation.BooleanSerializableValueAstNode
import org.mule.weave.v2.compilation.IntSerializableValueAstNode
import org.mule.weave.v2.compilation.NodeSerializableAstNode
import org.mule.weave.v2.compilation.NoneSerializableValueAstNode
import org.mule.weave.v2.compilation.SerializableAstNode
import org.mule.weave.v2.compilation.SerializableAstNodeKind
import org.mule.weave.v2.compilation.SerializableAstNodeLocation
import org.mule.weave.v2.compilation.StringSerializableValueAstNode
import org.mule.weave.v2.compilation.exception.DeserializationException
import org.mule.weave.v2.compilation.exception.IncompatibleBinarySerializationVersion
import org.mule.weave.v2.compilation.exception.MissingBinarySerializationDeclaration
import org.mule.weave.v2.version.DataWeaveRuntimeVersion

import java.io.{ DataInputStream, DataOutputStream, EOFException }
import java.util
import java.util.EmptyStackException
import scala.collection.mutable
import scala.util.control.NonFatal

/**
 * Binary (de)serializer for [[SerializableAstNode]] trees.
 *
 * Layout of a serialized stream:
 * {{{
 * | DW_MAGIC_NUM(INT) | BINARY_MAJOR(SHORT) | BINARY_MINOR(SHORT) | NODE RECORDS ... |
 * }}}
 * Node records are written in post-order, and each node's children are emitted in reverse order.
 * The reader can therefore rebuild the tree with a single stack: when a parent record is read,
 * its children are the topmost stack entries, and popping them restores their original order.
 */
object SerializableAstNodeSerializer {
  // Assumes binaryWeaveVersion looks like "<major>.<minor>[...]" with numeric components — TODO confirm.
  private val binaryVersion = DataWeaveRuntimeVersion.binaryWeaveVersion.split("\\.").map(_.toShort)

  /** Marker written at the start of every serialized stream and checked by [[deserialize]]. */
  val DW_MAGIC_NUM: Int = 0xDA7ADA7A

  /**
   * Writes the binary header, then the records of `astNode` and all of its descendants.
   *
   * @param astNode      root of the tree to serialize
   * @param outputStream destination stream; it is neither flushed nor closed here
   */
  def serialize(astNode: SerializableAstNode, outputStream: DataOutputStream): Unit = {
    outputStream.writeInt(DW_MAGIC_NUM)
    outputStream.writeShort(binaryVersion.head)
    outputStream.writeShort(binaryVersion(1))

    doSerialize(astNode, outputStream)
  }

  /**
   * Recursively writes `astNode` in post-order: first its children in reverse order, then the
   * node's own record. The record layout for each node type is given next to its case below.
   *
   * NOTE(review): the recursion depth equals the tree depth, so very deep trees could overflow the JVM stack.
   */
  private def doSerialize(astNode: SerializableAstNode, outputStream: DataOutputStream): Unit = {
    astNode match {
      case nodeAstNode: NodeSerializableAstNode => // | KIND(SHORT) CHILDREN_COUNT(INT) START_INDEX(INT) END_INDEX(INT) |
        nodeAstNode.children.reverseIterator.foreach(child => doSerialize(child, outputStream))
        outputStream.writeShort(nodeAstNode.kind)
        outputStream.writeInt(nodeAstNode.children.length)
        // Only node records carry a location; array and value records do not.
        outputStream.writeInt(astNode.location().startIndex)
        outputStream.writeInt(astNode.location().endIndex)
      case arrayAstNode: ArraySerializableAstNode => // | KIND(SHORT) CHILDREN_COUNT(INT) |
        arrayAstNode.children.reverseIterator.foreach(child => doSerialize(child, outputStream))
        outputStream.writeShort(arrayAstNode.kind())
        outputStream.writeInt(arrayAstNode.children.length)
      case stringValueAstNode: StringSerializableValueAstNode => // | KIND(SHORT) LENGTH(SHORT) STRING(MODIFIED_UTF_8) |
        outputStream.writeShort(stringValueAstNode.kind())
        // NOTE: writeUTF is limited to 65535 encoded bytes and throws UTFDataFormatException beyond that.
        outputStream.writeUTF(stringValueAstNode.value)
      case booleanValueAstNode: BooleanSerializableValueAstNode => // | KIND(SHORT) BOOLEAN(BYTE) |
        outputStream.writeShort(booleanValueAstNode.kind())
        outputStream.writeBoolean(booleanValueAstNode.value)
      case intValueAstNode: IntSerializableValueAstNode => // | KIND(SHORT) INT_VAL(INT) |
        outputStream.writeShort(intValueAstNode.kind())
        outputStream.writeInt(intValueAstNode.value)
      case _: NoneSerializableValueAstNode => // | KIND(SHORT) |
        outputStream.writeShort(astNode.kind())
    }
  }

  /**
   * Reads a stream produced by [[serialize]] and rebuilds the tree.
   *
   * Records are read until the end of the stream. Value records are pushed onto a stack. Array and
   * node records pop their children, which the writer emitted in reverse order, and then push
   * themselves. Any kind other than the value/array kinds is decoded as a node record. When the
   * stream ends, exactly one node (the root) must remain on the stack.
   *
   * @param dataInputStream source positioned at the binary header; it is read to the end and not closed here
   * @return the root node
   * @throws MissingBinarySerializationDeclaration  if the header cannot be read
   * @throws DeserializationException               if the magic number does not match
   * @throws IncompatibleBinarySerializationVersion if the binary version is not supported
   * @throws IllegalStateException                  if a children count is negative, exceeds the nodes read so far,
   *                                                or orphan nodes remain at the end
   * @throws java.io.EOFException                   if the stream ends right after the header or inside a record
   */
  def deserialize(dataInputStream: DataInputStream): SerializableAstNode = {
    val stack = new util.Stack[SerializableAstNode]()

    // Pops `count` nodes. Because the writer emitted children in reverse order, popping yields them
    // back in their original order.
    def extract(count: Int): Array[SerializableAstNode] = {
      if (count < 0) {
        throw new IllegalStateException(s"Parser requested a negative number of elements (#$count); the binary is corrupt.")
      }
      // Grows lazily instead of pre-sizing, so a corrupt huge count fails on stack underflow
      // rather than on an enormous allocation.
      val children = mutable.ArrayBuffer[SerializableAstNode]()
      (1 to count).foreach { _ =>
        try {
          children += stack.pop()
        } catch {
          case e: EmptyStackException => throw new IllegalStateException(s"Parser requested #$count elements but the stack has less nodes stacked.", e)
        }
      }
      children.toArray
    }

    validateBinaryHeader(dataInputStream)

    var eof = false
    // The first kind is read outside the EOF guard: a stream with a header but no records is an error.
    var kind = dataInputStream.readShort()
    while (!eof) {
      val node = kind match {
        case SerializableAstNodeKind.STRING_VALUE_NODE =>
          StringSerializableValueAstNode(dataInputStream.readUTF())
        case SerializableAstNodeKind.BOOLEAN_VALUE_NODE =>
          BooleanSerializableValueAstNode(dataInputStream.readBoolean())
        case SerializableAstNodeKind.INT_VALUE_NODE =>
          IntSerializableValueAstNode(dataInputStream.readInt())
        case SerializableAstNodeKind.NONE_VALUE_NODE =>
          NoneSerializableValueAstNode()
        case SerializableAstNodeKind.ARRAY_NODE =>
          val childrenCount = dataInputStream.readInt()
          ArraySerializableAstNode(extract(childrenCount))
        case _ => // NodeSerializableAstNode (any type node)
          val childrenCount = dataInputStream.readInt()
          // Arguments are evaluated left to right: start index, then end index, matching doSerialize.
          val location = SerializableAstNodeLocation.apply(dataInputStream.readInt(), dataInputStream.readInt())
          NodeSerializableAstNode(kind, location, extract(childrenCount))
      }

      stack.push(node)
      try {
        kind = dataInputStream.readShort()
      } catch {
        // End of stream at a record boundary is the normal termination condition.
        case _: EOFException => eof = true
      }
    }

    val astNode = stack.pop()
    if (!stack.empty()) {
      throw new IllegalStateException("Parser stack not empty, orphan nodes remain")
    }
    astNode
  }

  /**
   * Reads and checks the header: the magic number, then the binary major and minor version.
   * A stream is accepted when its major version equals this runtime's and its minor version is
   * not newer than this runtime's.
   */
  private def validateBinaryHeader(dataInputStream: DataInputStream): Unit = {
    val (dwMagic, binaryMajor, binaryMinor) =
      try {
        // Tuple elements are evaluated left to right, in the same order serialize writes them.
        (dataInputStream.readInt(), dataInputStream.readShort(), dataInputStream.readShort())
      } catch {
        // Wrap only recoverable failures (e.g. EOFException on a truncated stream); fatal JVM
        // errors such as OutOfMemoryError must propagate untouched.
        case NonFatal(e) => throw MissingBinarySerializationDeclaration(e)
      }

    if (dwMagic != DW_MAGIC_NUM) {
      throw DeserializationException("Corrupt bdwl header, couldn't find DW magic number")
    }

    if (binaryMajor != binaryVersion.head || binaryMinor > binaryVersion(1)) {
      throw IncompatibleBinarySerializationVersion(binaryMajor, binaryMinor)
    }
  }
}

