package ai.minxiao.ds4s.core.dl4j.embed

import org.datavec.api.util.ClassPathResource
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors
import org.deeplearning4j.models.word2vec.{VocabWord, Word2Vec}
import org.deeplearning4j.models.word2vec.wordstore.VocabCache
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache
import org.deeplearning4j.text.documentiterator.LabelsSource
import org.deeplearning4j.text.sentenceiterator.{BasicLineIterator, SentenceIterator}
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor
import org.deeplearning4j.text.tokenization.tokenizerfactory.{DefaultTokenizerFactory, TokenizerFactory}
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess
import org.slf4j.{Logger, LoggerFactory}

/**
  * D2V
  * Doc2Vec (Paragraph Vector): a thin, serializable holder of hyper-parameters that
  * configures and builds a DL4J
  * <a href="https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java">
  * Doc2Vec (ParagraphVectors)</a> model via `ParagraphVectors.Builder`.
  *
  * @note reference: <a href="https://cs.stanford.edu/~quocle/paragraph_vector.pdf">Quoc Le and Tomas Mikolov. Distributed Representations of Sentences and Documents. In ICML, 2014.</a>
  * @note [[build]] calls `setTokenPreProcessor` on the supplied `tokenizerFactory`, i.e. the
  *       factory instance passed in by the caller is modified in place.
  * @constructor
  * @param minWordFrequency min word frequency, default=5
  * @param layerSize word vector dimensionality, default=100
  * @param windowSize window size, default=5
  * @param epochs training epochs, default=1
  * @param batchSize batch size, default=16
  * @param iterations value forwarded to the builder's `iterations` setting, default=1
  * @param sampling sampling rate, default=0(no sampling)
  * @param learningRate learning rate, default=1E-1
  * @param minLearningRate minimum learning rate, default=1E-6
  * @param useAdaGrad whether to use AdaGrad, default=false
  * @param tokenizerFactory tokenizer factory, default=new DefaultTokenizerFactory
  * @param tokenPreProcessor token preprocessor installed on the tokenizer factory, default=new CommonPreprocessor
  * @param lookupTable optional lookup table handed to the builder when defined, default=None
  * @param vocabCache optional vocabulary cache handed to the builder when defined, default=None
  * @param seed seed for random generator, default=2018
  * @param trainWordVectors whether to train word vectors together with document vectors, default=true
  *                         (the original author associates true with PV-DM and false with PV-DBOW --
  *                         TODO confirm against the DL4J sequence learning algorithm setting)
  * @param labelsSource optional label source handed to the builder when defined, default=None
  *
  * @author mx
  */
@SerialVersionUID(875086L)
class D2V(
  minWordFrequency: Int = 5,
  layerSize: Int = 100,
  windowSize: Int = 5,
  epochs: Int = 1,
  batchSize: Int = 16,
  final val iterations: Int = 1,
  sampling: Double = 0,
  learningRate: Double = 1E-1,
  minLearningRate: Double = 1E-6,
  useAdaGrad: Boolean = false,
  tokenizerFactory: TokenizerFactory = new DefaultTokenizerFactory,
  tokenPreProcessor: TokenPreProcess = new CommonPreprocessor,
  lookupTable: Option[InMemoryLookupTable[VocabWord]] = None,
  vocabCache: Option[AbstractCache[VocabWord]] = None,
  seed: Long = 2018L,
  trainWordVectors: Boolean = true,
  labelsSource: Option[LabelsSource] = None
) extends Serializable {

  /** Configures a `ParagraphVectors.Builder` from the constructor arguments and builds the model.
    *
    * The optional `lookupTable`, `vocabCache` and `labelsSource` are only passed to the
    * builder when they are defined; otherwise the builder's own settings are left untouched.
    *
    * @return the `ParagraphVectors` instance produced by the builder
    */
  def build(): ParagraphVectors = {
    // The pre-processor is attached to the caller-supplied factory itself (in-place mutation).
    tokenizerFactory.setTokenPreProcessor(tokenPreProcessor)

    val doc2VecBuilder = new ParagraphVectors.Builder()
      .minWordFrequency(minWordFrequency)
      .layerSize(layerSize)
      .windowSize(windowSize)
      .epochs(epochs)
      .batchSize(batchSize)
      .iterations(iterations)
      .sampling(sampling)
      .learningRate(learningRate)
      .minLearningRate(minLearningRate)
      .useAdaGrad(useAdaGrad)
      .tokenizerFactory(tokenizerFactory)
      .seed(seed)
      .trainWordVectors(trainWordVectors)

    // Optional components: apply each one only when present, without Option.get.
    lookupTable.foreach(table => doc2VecBuilder.lookupTable(table))
    vocabCache.foreach(cache => doc2VecBuilder.vocabCache(cache))
    labelsSource.foreach(source => doc2VecBuilder.labelsSource(source))

    doc2VecBuilder.build()
  }
}
