package ai.minxiao.ds4s.core.dl4j.embed

import org.datavec.api.util.ClassPathResource
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.glove.Glove
import org.deeplearning4j.models.word2vec.VocabWord
import org.deeplearning4j.models.word2vec.wordstore.VocabCache
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache
import org.deeplearning4j.text.sentenceiterator.{BasicLineIterator, SentenceIterator}
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor
import org.deeplearning4j.text.tokenization.tokenizerfactory.{DefaultTokenizerFactory, TokenizerFactory}
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess
import org.slf4j.{Logger, LoggerFactory}

/**
  * GVec: Global Vectors for Word Representation.
  *
  * A configuration wrapper around DL4J's
  * <a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/Glove.java">
  * Glove</a>; call `build()` to obtain a `Glove` model configured from these parameters.
  * @note <a href="https://www.aclweb.org/anthology/D14-1162">Jeffrey Pennington and Richard Socher and Christopher D. Manning. Glove: Global Vectors for Word Representation. In EMNLP, 2014.</a>
  * @constructor
  * @param minWordFrequency minimum word frequency, default=5
  * @param layerSize word vector dimensionality, default=100
  * @param windowSize window size, default=5
  * @param epochs training epochs, default=1
  * @param batchSize batch size, default=16
  * @param iterations training iterations (passed to `Glove.Builder.iterations`), default=1
  * @param learningRate learning rate, default=1E-1
  * @param minLearningRate minimum learning rate, default=1E-6
  * @param useAdaGrad whether to use AdaGrad, default=false
  * @param tokenizerFactory tokenizer factory, default=new DefaultTokenizerFactory;
  *        note that `build()` installs `tokenPreProcessor` on this factory in place
  * @param tokenPreProcessor token preprocessor, default=new CommonPreprocessor
  * @param lookupTable optional pre-built lookup table; handed to the builder only when defined, default=None
  * @param vocabCache optional vocabulary cache; handed to the builder only when defined, default=None
  * @param seed seed for random generator, default=2018
  * @param alpha alpha exponent for co-occurrence weighting (passed to `Glove.Builder.alpha`), default=0.75
  * @param xMax co-occurrence cutoff for weighting (passed to `Glove.Builder.xMax`), default=100
  * @param shuffle shuffle flag (passed to `Glove.Builder.shuffle`), default=true
  * @param symmetric symmetric flag (passed to `Glove.Builder.symmetric`), default=true
  * @param iterate sentence iterator supplying the training corpus (required, no default)
  *
  * @author mx
  */
@SerialVersionUID(875086L)
class GVec(
  minWordFrequency: Int = 5,
  layerSize: Int = 100,
  windowSize: Int = 5,
  epochs: Int = 1,
  batchSize: Int = 16,
  final val iterations: Int = 1,
  learningRate: Double = 1E-1,
  minLearningRate: Double = 1E-6,
  useAdaGrad: Boolean = false,
  tokenizerFactory: TokenizerFactory = new DefaultTokenizerFactory,
  tokenPreProcessor: TokenPreProcess = new CommonPreprocessor,
  lookupTable: Option[InMemoryLookupTable[VocabWord]] = None,
  vocabCache: Option[AbstractCache[VocabWord]] = None,
  seed: Long = 2018L,
  alpha: Double = 0.75,
  xMax: Double = 100,
  shuffle: Boolean = true,
  symmetric: Boolean = true,
  iterate: SentenceIterator
) extends Serializable {

  /**
    * Builds a `Glove` model from this configuration.
    *
    * Side effect: `tokenPreProcessor` is set on `tokenizerFactory` in place, so a
    * caller-supplied factory is modified by this call.
    *
    * @return the `Glove` instance produced by `Glove.Builder.build()`
    */
  def build(): Glove = {
    // Configure the tokenizer up front rather than inside the builder chain so the
    // mutation of the (possibly caller-owned) factory is explicit.
    tokenizerFactory.setTokenPreProcessor(tokenPreProcessor)

    val gloveBuilder = new Glove.Builder()
      .minWordFrequency(minWordFrequency)
      .layerSize(layerSize)
      .windowSize(windowSize)
      .epochs(epochs)
      .batchSize(batchSize)
      .iterations(iterations)
      .learningRate(learningRate)
      .minLearningRate(minLearningRate)
      .useAdaGrad(useAdaGrad)
      .tokenizerFactory(tokenizerFactory)
      .seed(seed)
      .alpha(alpha)
      .xMax(xMax)
      .shuffle(shuffle)
      .symmetric(symmetric)
      .iterate(iterate)

    // Optional components are only applied when supplied; the builder methods
    // configure `gloveBuilder` itself, so their return values are not needed.
    lookupTable.foreach(table => gloveBuilder.lookupTable(table))
    vocabCache.foreach(cache => gloveBuilder.vocabCache(cache))

    gloveBuilder.build()
  }
}
