package org.mule.weave.v2.module.dwb.writer

import java.io.DataOutput
import java.io.FilterOutputStream
import java.io.IOException
import java.io.OutputStream
import java.io.UTFDataFormatException

import org.mule.weave.v2.module.dwb.WeaveBinaryUtils

/**
  * A data output stream lets an application write primitive data
  * types to an output stream in a portable way. An application can
  * then use a data input stream to read the data back in.
  * <p>
  * Unlike `java.io.DataOutputStream`, whose byte counter is an `int`,
  * this implementation keeps the number of bytes written in a `Long`
  * (see [[written]] and [[size]]), so the count stays accurate for
  * outputs larger than `Int.MaxValue` bytes.
  * <p>
  * None of the methods synchronize; an instance shared between threads
  * must be coordinated externally.
  *
  * @param out the underlying output stream that receives every byte.
  */
class LongCountDataOutputStream(out: OutputStream) extends FilterOutputStream(out) with DataOutput {
  /**
    * The number of bytes written to the data output stream so far.
    */
  var written: Long = 0L

  /**
    * Scratch buffer reused by `writeLong` so that writing a long does not
    * allocate.
    */
  private val writeBuffer = new Array[Byte](8)

  /**
    * bytearr is initialized on demand by writeUTF and grown when a longer
    * encoded string needs it.
    */
  private var bytearr: Array[Byte] = _

  /**
    * Next free index in `bytearr` while `writeUTF` is encoding a string.
    */
  private var count = 0

  /**
    * Adds `value` to the counter `written`. Takes a `Long` so callers can
    * pass products such as `len * 2L` without overflowing `Int` arithmetic.
    */
  private def incCount(value: Long): Unit = {
    written += value
  }

  /**
    * Writes the specified byte (the low eight bits of the argument
    * <code>b</code>) to the underlying output stream. If no exception
    * is thrown, the counter <code>written</code> is incremented by
    * <code>1</code>.
    * <p>
    * Implements the <code>write</code> method of <code>OutputStream</code>.
    *
    * @param      b the <code>byte</code> to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def write(b: Int): Unit = {
    out.write(b)
    incCount(1)
  }

  /**
    * Writes <code>len</code> bytes from the specified byte array
    * starting at offset <code>off</code> to the underlying output stream.
    * If no exception is thrown, the counter <code>written</code> is
    * incremented by <code>len</code>.
    *
    * @param      b   the data.
    * @param      off the start offset in the data.
    * @param      len the number of bytes to write.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
    // Delegate the whole slice at once instead of FilterOutputStream's
    // default byte-by-byte loop.
    out.write(b, off, len)
    incCount(len)
  }

  /**
    * Flushes this data output stream. This forces any buffered output
    * bytes to be written out to the stream.
    * <p>
    * The <code>flush</code> method of <code>DataOutputStream</code>
    * calls the <code>flush</code> method of its underlying output stream.
    *
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    * @see java.io.OutputStream#flush()
    */
  @throws[IOException]
  override def flush(): Unit = {
    out.flush()
  }

  /**
    * Writes a <code>boolean</code> to the underlying output stream as
    * a 1-byte value. The value <code>true</code> is written out as the
    * value <code>(byte)1</code>; the value <code>false</code> is
    * written out as the value <code>(byte)0</code>. If no exception is
    * thrown, the counter <code>written</code> is incremented by
    * <code>1</code>.
    *
    * @param      v a <code>boolean</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeBoolean(v: Boolean): Unit = {
    out.write(if (v) 1 else 0)
    incCount(1)
  }

  /**
    * Writes out a <code>byte</code> to the underlying output stream as
    * a 1-byte value. If no exception is thrown, the counter
    * <code>written</code> is incremented by <code>1</code>.
    *
    * @param      v a <code>byte</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeByte(v: Int): Unit = {
    out.write(v)
    incCount(1)
  }

  /**
    * Writes a <code>short</code> to the underlying output stream as two
    * bytes, high byte first. If no exception is thrown, the counter
    * <code>written</code> is incremented by <code>2</code>.
    *
    * @param      v a <code>short</code> to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeShort(v: Int): Unit = {
    out.write((v >>> 8) & 0xFF)
    out.write((v >>> 0) & 0xFF)
    incCount(2)
  }

  /**
    * Writes a <code>char</code> to the underlying output stream as a
    * 2-byte value, high byte first. If no exception is thrown, the
    * counter <code>written</code> is incremented by <code>2</code>.
    *
    * @param      v a <code>char</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeChar(v: Int): Unit = {
    out.write((v >>> 8) & 0xFF)
    out.write((v >>> 0) & 0xFF)
    incCount(2)
  }

  /**
    * Writes an <code>int</code> to the underlying output stream as four
    * bytes, high byte first. If no exception is thrown, the counter
    * <code>written</code> is incremented by <code>4</code>.
    *
    * @param      v an <code>int</code> to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeInt(v: Int): Unit = {
    out.write((v >>> 24) & 0xFF)
    out.write((v >>> 16) & 0xFF)
    out.write((v >>> 8) & 0xFF)
    out.write((v >>> 0) & 0xFF)
    incCount(4)
  }

  /**
    * Writes a <code>long</code> to the underlying output stream as eight
    * bytes, high byte first. If no exception is thrown, the counter
    * <code>written</code> is incremented by <code>8</code>.
    *
    * @param      v a <code>long</code> to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeLong(v: Long): Unit = {
    // Stage all eight bytes in the reusable buffer and hand them to the
    // underlying stream in a single call.
    writeBuffer(0) = (v >>> 56).toByte
    writeBuffer(1) = (v >>> 48).toByte
    writeBuffer(2) = (v >>> 40).toByte
    writeBuffer(3) = (v >>> 32).toByte
    writeBuffer(4) = (v >>> 24).toByte
    writeBuffer(5) = (v >>> 16).toByte
    writeBuffer(6) = (v >>> 8).toByte
    writeBuffer(7) = (v >>> 0).toByte
    out.write(writeBuffer, 0, 8)
    incCount(8)
  }

  /**
    * Converts the float argument to an <code>int</code> using the
    * <code>floatToIntBits</code> method in class <code>Float</code>,
    * and then writes that <code>int</code> value to the underlying
    * output stream as a 4-byte quantity, high byte first. If no
    * exception is thrown, the counter <code>written</code> is
    * incremented by <code>4</code>.
    *
    * @param      v a <code>float</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    * @see java.lang.Float#floatToIntBits(float)
    */
  @throws[IOException]
  override def writeFloat(v: Float): Unit = {
    writeInt(java.lang.Float.floatToIntBits(v))
  }

  /**
    * Converts the double argument to a <code>long</code> using the
    * <code>doubleToLongBits</code> method in class <code>Double</code>,
    * and then writes that <code>long</code> value to the underlying
    * output stream as an 8-byte quantity, high byte first. If no
    * exception is thrown, the counter <code>written</code> is
    * incremented by <code>8</code>.
    *
    * @param      v a <code>double</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    * @see java.lang.Double#doubleToLongBits(double)
    */
  @throws[IOException]
  override def writeDouble(v: Double): Unit = {
    writeLong(java.lang.Double.doubleToLongBits(v))
  }

  /**
    * Writes out the string to the underlying output stream as a
    * sequence of bytes. Each character in the string is written out, in
    * sequence, by discarding its high eight bits. If no exception is
    * thrown, the counter <code>written</code> is incremented by the
    * length of <code>s</code>.
    *
    * @param      s a string of bytes to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeBytes(s: String): Unit = {
    val len = s.length
    var i = 0
    while (i < len) {
      out.write(s.charAt(i).toByte)
      i += 1
    }
    incCount(len)
  }

  /**
    * Writes a string to the underlying output stream as a sequence of
    * characters. Each character is written to the data output stream as
    * if by the <code>writeChar</code> method. If no exception is
    * thrown, the counter <code>written</code> is incremented by twice
    * the length of <code>s</code>.
    *
    * @param      s a <code>String</code> value to be written.
    * @exception IOException  if an I/O error occurs.
    * @see java.io.DataOutputStream#writeChar(int)
    * @see java.io.FilterOutputStream#out
    */
  @throws[IOException]
  override def writeChars(s: String): Unit = {
    val len = s.length
    var i = 0
    while (i < len) {
      val v = s.charAt(i)
      out.write((v >>> 8) & 0xFF)
      out.write((v >>> 0) & 0xFF)
      i += 1
    }
    // Multiply as Long: `len * 2` in Int arithmetic wraps negative for
    // strings longer than Int.MaxValue / 2 characters.
    incCount(len * 2L)
  }

  /**
    * Writes a string to the underlying output stream using
    * <a href="DataInput.html#modified-utf-8">modified UTF-8</a>
    * encoding in a machine-independent manner.
    * <p>
    * First, two bytes are written to out as if by the <code>writeShort</code>
    * method giving the number of bytes to follow. This value is the number of
    * bytes actually written out, not the length of the string. Following the
    * length, each character of the string is output, in sequence, using the
    * modified UTF-8 encoding for the character. If no exception is thrown, the
    * counter <code>written</code> is incremented by the total number of
    * bytes written to the output stream. This will be at least two
    * plus the length of <code>str</code>, and at most two plus
    * thrice the length of <code>str</code>.
    * <p>
    * NOTE(review): the length header and the final write both use the value
    * returned by `WeaveBinaryUtils.getUTFByteLength`. This assumes that helper
    * returns exactly the byte count produced by the encoding loop below
    * (modified UTF-8: NUL as two bytes, each surrogate half as three bytes).
    * Confirm against its implementation.
    *
    * @param      str a string to be written.
    * @exception UTFDataFormatException if the encoded form of `str` is
    *                                   longer than 65535 bytes.
    * @exception IOException  if an I/O error occurs.
    */
  @throws[IOException]
  override def writeUTF(str: String): Unit = {
    val strlen = str.length
    val utflen = getUTFByteLength(str)

    if (utflen > 65535) {
      throw new UTFDataFormatException("encoded string too long: " + utflen + " bytes")
    }
    initByteArray(utflen)
    writeShort(utflen)

    var i = 0
    while (i < strlen) {
      val c: Int = str.charAt(i)
      if ((c >= 0x0001) && (c <= 0x007F)) {
        // One byte: printable/ASCII range, excluding NUL.
        addToByteArray(c)
      } else if (c > 0x07FF) {
        // Three bytes: 0x0800..0xFFFF (surrogate halves are encoded individually).
        addToByteArray(0xE0 | ((c >> 12) & 0x0F))
        addToByteArray(0x80 | ((c >> 6) & 0x3F))
        addToByteArray(0x80 | ((c >> 0) & 0x3F))
      } else {
        // Two bytes: 0x0080..0x07FF, plus NUL, which modified UTF-8 encodes as 0xC0 0x80.
        addToByteArray(0xC0 | ((c >> 6) & 0x1F))
        addToByteArray(0x80 | ((c >> 0) & 0x3F))
      }
      i += 1
    }
    write(bytearr, 0, utflen)
  }

  /**
    * Resets the write index and ensures `bytearr` can hold at least
    * `utflen + 2` bytes. It reallocates with headroom (twice the length)
    * so that subsequent similar-sized strings can reuse the buffer.
    */
  private def initByteArray(utflen: Int): Unit = {
    count = 0
    if (bytearr == null || (bytearr.length < (utflen + 2))) {
      bytearr = new Array[Byte]((utflen * 2) + 2)
    }
  }

  /**
    * Appends the low eight bits of `c` to `bytearr` and advances the index.
    */
  private def addToByteArray(c: Int): Unit = {
    bytearr(count) = c.toByte
    count += 1
  }

  /**
    * Delegates to `WeaveBinaryUtils.getUTFByteLength` for the encoded
    * length of `str`.
    */
  private def getUTFByteLength(str: String): Int = {
    WeaveBinaryUtils.getUTFByteLength(str)
  }

  /**
    * Returns the current value of the counter <code>written</code>, the
    * number of bytes written to this data output stream so far.
    *
    * @return the value of the <code>written</code> field, as a `Long`.
    */
  def size(): Long = written
}
