package ai.minxiao.ds4s.core.nlp.embeds

import java.io._
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import scala.collection.immutable.{HashMap, HashSet}
import scala.io.Source

import breeze.linalg.{DenseMatrix, DenseVector, Transpose}

/**
  * IO implementations for word-embedding tables.
  *
  * A table is a pair `(words, embeds)`: `words` maps each word to a row index of
  * `embeds`, and row `i` of `embeds` is the embedding of the word mapped to `i`.
  *
  * Supported formats:
  *  - `"txt"`: one entry per line, `word v1 v2 ... vd`, tokens separated by spaces.
  *  - `"bin"`: two Java-serialized files, `path + ".words"` (the word map) and
  *    `path + ".embeds"` (the flat array returned by `embeds.toArray`).
  *
  * @author mx
  */
trait IOImpl {

  /**
    * Loads an embedding table from disk.
    *
    * @param path   text file path for `"txt"`; common prefix of the `.words` / `.embeds`
    *               files for `"bin"`
    * @param format `"txt"` or `"bin"`
    * @return the word-to-row map and the `rows x dim` embedding matrix; an empty text file
    *         (or an empty serialized vocabulary) yields an empty map and a `0 x 0` matrix
    * @throws IllegalArgumentException for an unsupported format, text rows of differing
    *         dimension, or serialized data whose size does not fit the vocabulary
    */
  def load(path: String, format: String): (HashMap[String, Int], DenseMatrix[Float]) = {

    def loadText(): (HashMap[String, Int], DenseMatrix[Float]) = {
      val src = Source.fromFile(path)
      // Materialize with toArray *before* closing the source: getLines is lazy.
      val entries: Array[(String, Array[Float])] =
        try {
          src.getLines()
            .filter(_.trim.nonEmpty) // blank lines carry no entry
            .map { line =>
              val tokens = line.split(" ")
              // Drop empty tokens so repeated spaces between values do not reach "".toFloat.
              (tokens.head, tokens.tail.filter(_.nonEmpty).map(_.toFloat))
            }
            .toArray
        } finally src.close()

      if (entries.isEmpty) {
        (HashMap.empty[String, Int], new DenseMatrix(0, 0, Array.empty[Float]))
      } else {
        val dim = entries.head._2.length
        entries.find(_._2.length != dim).foreach { case (w, v) =>
          throw new IllegalArgumentException(
            s"Inconsistent embedding dimension in $path: '$w' has ${v.length} values, expected $dim")
        }
        // Id = position among non-blank lines. Built with `+` so the result really is a
        // HashMap (toMap yields Map1..Map4 for small inputs, which broke a downcast).
        // A repeated word keeps the id of its last occurrence.
        val words = entries.iterator.map(_._1).zipWithIndex
          .foldLeft(HashMap.empty[String, Int])((acc, kv) => acc + kv)
        // Each entry's values fill one column of a dim x n matrix; transposing gives n x dim.
        val embeds = new DenseMatrix(dim, entries.length, entries.flatMap(_._2)).t
        (words, embeds)
      }
    }

    def loadBinary(): (HashMap[String, Int], DenseMatrix[Float]) = {
      // NOTE(security): Java deserialization can trigger code in classes on the classpath;
      // only load .words/.embeds files from trusted sources.
      def readSerialized(file: String): AnyRef = {
        val fileIn = new FileInputStream(file)
        // Closing the underlying stream in finally also covers a failing
        // ObjectInputStream constructor (it reads the stream header eagerly).
        try new ObjectInputStream(new BufferedInputStream(fileIn)).readObject()
        finally fileIn.close()
      }

      val words = readSerialized(path + ".words").asInstanceOf[HashMap[String, Int]]
      val flat = readSerialized(path + ".embeds").asInstanceOf[Array[Float]]
      // Row count is max id + 1 rather than words.size: loading a text file with a repeated
      // word yields more matrix rows than distinct words (only the last occurrence keeps an id).
      val numRows = if (words.isEmpty) 0 else words.valuesIterator.max + 1
      val dim = if (numRows == 0) 0 else flat.length / numRows
      require(numRows * dim == flat.length,
        s"Embedding data in $path.embeds (${flat.length} values) does not fit $numRows rows")
      (words, new DenseMatrix(numRows, dim, flat))
    }

    format match {
      case "txt" => loadText()
      case "bin" => loadBinary()
      case other =>
        throw new IllegalArgumentException(s"Unsupported format '$other'; expected 'txt' or 'bin'")
    }
  }

  /**
    * Writes an embedding table to disk.
    *
    * @param words  map from word to its row in `embeds`
    * @param embeds matrix whose row `words(w)` is the embedding of `w`
    * @param path   output file for `"txt"`; prefix of the `.words` / `.embeds` files for `"bin"`
    * @param format `"txt"` or `"bin"`
    * @throws IllegalArgumentException for an unsupported format (nothing is written)
    */
  def save(words: HashMap[String, Int], embeds: DenseMatrix[Float],
    path: String, format: String): Unit = {

    def saveText(): Unit = {
      val writer = new BufferedWriter(new FileWriter(new File(path)))
      try {
        // Stream one line per word, in id order: load("txt") assigns ids by line position,
        // so a save/load round trip keeps ids 0 until n instead of reshuffling by hash order.
        words.toVector.sortBy(_._2).foreach { case (w, i) =>
          writer.write(embeds(i, ::).t.toArray.mkString(w + " ", " ", "\n"))
        }
      } finally writer.close()
    }

    def saveBinary(): Unit = {
      def writeSerialized(file: String, obj: AnyRef): Unit = {
        val fileOut = new FileOutputStream(file)
        try {
          val out = new ObjectOutputStream(new BufferedOutputStream(fileOut))
          out.writeObject(obj)
          out.flush() // push buffered bytes into fileOut before it is closed
        } finally fileOut.close()
      }

      writeSerialized(path + ".words", words)
      writeSerialized(path + ".embeds", embeds.toArray)
    }

    format match {
      case "txt" => saveText()
      case "bin" => saveBinary()
      case other =>
        throw new IllegalArgumentException(s"Unsupported format '$other'; expected 'txt' or 'bin'")
    }
  }
}
