package ai.minxiao.ds4s.core.nlp.embeds

import scala.collection.immutable.{HashMap, HashSet}
import scala.io.Source

import breeze.linalg.{DenseMatrix, DenseVector, Transpose}

/**
  * Embeds Trait
  *
  * Lookup interface for word embeddings and for sentence embeddings derived
  * from them. Every lookup returns an `Option`, so "no embedding available"
  * is reported as `None` rather than by throwing.
  *
  * The behaviour notes on the members below describe `EmbedsImpl`, the
  * implementation defined in the companion object in this file; the trait
  * itself only fixes the signatures.
  *
  * @author mx
  */
trait Embeds {

  /** Embedding of a single word.
    *
    * @param w the word to look up
    * @return the word's vector, or `None` when `EmbedsImpl` does not find
    *         `w` in its vocabulary
    */
  def wEmb(w: String): Option[DenseVector[Float]]

  /** Embedding of a tokenized sentence.
    *
    * `EmbedsImpl` averages the vectors of the tokens it knows and ignores the
    * rest.
    *
    * @param s the sentence as an array of tokens
    * @return the averaged vector, or `None` when none of the tokens is known
    */
  def sEmb(s: Array[String]): Option[DenseVector[Float]]

  /** Embedding of an untokenized sentence.
    *
    * `EmbedsImpl` splits `s` on single space characters and delegates to the
    * array overload.
    *
    * @param s the sentence as one string
    * @return same as the array overload applied to the resulting tokens
    */
  def sEmb(s: String): Option[DenseVector[Float]]

  /** Writes the embeddings to `path`.
    *
    * @param path   destination path
    * @param format format identifier; the accepted values are not defined in
    *               this file (presumably by `IOImpl` -- TODO confirm)
    */
  def save(path: String, format: String): Unit
}

/**
  * Factory for [[Embeds]] instances backed by embeddings loaded from storage.
  */
@SerialVersionUID(6998100L)
object Embeds extends Serializable {

  /** Loads embeddings and wraps them in an [[Embeds]].
    *
    * @param path   location of the stored embeddings
    * @param format format identifier, passed through unchanged to `load`
    * @return an [[Embeds]] over the loaded vocabulary and matrix
    */
  def apply(path: String, format: String): Embeds =
    new EmbedsImpl(path, format)

  /** Default [[Embeds]] implementation: a vocabulary plus an embedding matrix
    * with one row per word.
    *
    * `load` and the four-argument `save` used below are not defined in this
    * class; presumably they are inherited from `IOImpl` -- TODO confirm.
    *
    * @param embPath   location of the stored embeddings
    * @param embFormat format identifier handed to `load`
    */
  @SerialVersionUID(6973L)
  private class EmbedsImpl(embPath: String, embFormat: String)
      extends Embeds with IOImpl with Serializable {

    // `words` is used as a word -> row-index lookup into `embeds`;
    // loading happens eagerly, at construction time.
    private val (words, embeds) = load(embPath, embFormat)
    private val vSize = words.size
    private val dim = embeds.cols

    /** Row of the embedding matrix for `w`, or `None` when `w` is not in the
      * vocabulary.
      *
      * NOTE(review): Breeze row slices are normally views over the matrix, so
      * the returned vector likely shares storage with `embeds` -- callers
      * should not mutate it. Confirm against the Breeze version in use.
      */
    override def wEmb(w: String): Option[DenseVector[Float]] =
      // One lookup instead of `get` followed by `apply`.
      words.get(w).map(idx => embeds(idx, ::).t)

    /** Mean of the vectors of the tokens found in the vocabulary.
      *
      * Unknown tokens are skipped and do not count towards the divisor.
      *
      * @return `None` when `s` is empty or contains no known token
      */
    override def sEmb(s: Array[String]): Option[DenseVector[Float]] = {
      // Single pass: keep only the vectors of known tokens, in input order.
      val known = s.flatMap(w => wEmb(w).toList)
      // `reduceOption` yields `None` for the empty case and otherwise sums
      // left to right before dividing by the number of known tokens.
      known.reduceOption(_ + _).map(_ / known.length.toFloat)
    }

    /** Splits `s` on single space characters and embeds the resulting tokens.
      *
      * Consecutive spaces produce empty tokens, which are treated like any
      * other token (i.e. skipped unless the vocabulary contains them).
      */
    override def sEmb(s: String): Option[DenseVector[Float]] =
      sEmb(s.split(" "))

    /** Persists the loaded vocabulary and matrix to `path` in `format`. */
    override def save(path: String, format: String): Unit =
      save(words, embeds, path, format)

    override def toString: String =
      s"Embeds: vSize = $vSize, dim = $dim"
  }
}
