package ai.minxiao.ds4s.core.dl4j.learning

import scala.collection.JavaConverters._
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.glove.Glove
import org.deeplearning4j.models.sequencevectors.SequenceVectors
import org.deeplearning4j.models.word2vec.Word2Vec
import org.deeplearning4j.models.word2vec.wordstore.VocabCache
import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.text.sentenceiterator.SentenceIterator
import org.deeplearning4j.ui.api.UIServer
import org.deeplearning4j.ui.stats.StatsListener
import org.deeplearning4j.ui.storage.InMemoryStatsStorage
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.api.DataSet
import org.nd4j.linalg.dataset.api.iterator.{DataSetIterator, MultiDataSetIterator}

import ai.minxiao.ds4s.core.dl4j.evaluation.Evaluator
import ai.minxiao.ds4s.core.dl4j.ui.UIStarter

/**
  * DL4J Learner
  *
  * @author mx
  */
object Learner {

  // ---------------------------------------------------------------------------
  // Shared helpers
  // ---------------------------------------------------------------------------

  /**
    * Enable UI monitoring for a multilayer network.
    * Stats are kept in memory when `storagePath` is None, otherwise they are written to the given path.
    * @param net multilayer network
    * @param storagePath Some(path) to save the stats to, or None for in-memory storage
    * @param listenFreq listener frequency
    */
  private def attachUI(net: MultiLayerNetwork, storagePath: Option[String], listenFreq: Int): Unit =
    storagePath match {
      case Some(path) => UIStarter.write2File(net, path, listenFreq)
      case None => UIStarter.initNewInMemory(net, listenFreq)
    }

  /**
    * Enable UI monitoring for a computation graph.
    * Stats are kept in memory when `storagePath` is None, otherwise they are written to the given path.
    * @param net computation graph
    * @param storagePath Some(path) to save the stats to, or None for in-memory storage
    * @param listenFreq listener frequency
    */
  private def attachUI(net: ComputationGraph, storagePath: Option[String], listenFreq: Int): Unit =
    storagePath match {
      case Some(path) => UIStarter.write2File(net, path, listenFreq)
      // NOTE(review): the trailing `false` is passed through unchanged from the original call sites;
      // its meaning is defined by UIStarter.initNewInMemory — confirm there.
      case None => UIStarter.initNewInMemory(net, listenFreq, false)
    }

  /** True when evaluation should run after the 0-based `epoch`, i.e. after every `testFreq`-th epoch. */
  private def isTestEpoch(epoch: Int, testFreq: Int): Boolean = (epoch + 1) % testFreq == 0

  /**
    * Fail fast on a non-positive `testFreq` when test data is supplied. Without this check the
    * modulo in `isTestEpoch` would throw an ArithmeticException (for 0) only after a full epoch of training.
    */
  private def requireValidTestFreq(testData: Option[Any], testFreq: Int): Unit =
    require(testData.isEmpty || testFreq > 0,
      s"testFreq must be positive when testData is provided, got $testFreq")

  // ---------------------------------------------------------------------------
  // Iterator-based training
  // ---------------------------------------------------------------------------

  /**
    * Fitting with UI Monitoring.
    * Runs `epochs` passes of `net.fit(trainData)`, resetting the iterator after each pass.
    * @param net multilayer network
    * @param trainData training data
    * @param storagePath storage path to save the stats, default=None (in memory)
    * @param useUI whether to use UI, default=false
    * @param listenFreq listener frequency, default=1
    * @param epochs epochs, default=1
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork, trainData: DataSetIterator, storagePath: Option[String] = None,
    useUI: Boolean = false, listenFreq: Int = 1, epochs: Int = 1): MultiLayerNetwork  = {

    if (useUI) attachUI(net, storagePath, listenFreq)

    for (_ <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()
    }

    net
  }

  /**
    * Fitting with UI Monitoring.
    * Runs `epochs` passes of `net.fit(trainData)`, resetting the iterator after each pass.
    * No defaults here: Scala allows default arguments on only one overload of `run`.
    * @param net computation graph
    * @param trainData training data
    * @param storagePath storage path to save the stats, None = in memory
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @param epochs epochs
    * @return the same `net` instance, after fitting
    */
  def run(net: ComputationGraph, trainData: DataSetIterator, storagePath: Option[String],
    useUI: Boolean, listenFreq: Int, epochs: Int): ComputationGraph  = {

    if (useUI) attachUI(net, storagePath, listenFreq)

    for (_ <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()
    }

    net
  }

  /**
    * Fitting with UI Monitoring, one epoch, in-memory stats storage.
    * @param net computation graph
    * @param trainData training data
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @return the same `net` instance, after fitting
    */
  def run(net: ComputationGraph, trainData: DataSetIterator,
    useUI: Boolean, listenFreq: Int): ComputationGraph  = {
    // Forward the caller's useUI flag (it was previously hard-coded to `true`).
    run(net, trainData, None, useUI, listenFreq, 1)
  }

  // An epochs-free MultiLayerNetwork/DataSetIterator overload is unnecessary:
  // use run(net, trainData, ...) above, whose `epochs` defaults to 1.

  // ---------------------------------------------------------------------------
  // Array-based training
  // ---------------------------------------------------------------------------

  /**
    * Fitting with UI Monitoring (in-memory stats storage).
    * @param net multilayer network
    * @param features training data features
    * @param labels training data labels
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @param epochs number of times `net.fit(features, labels)` is called
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork,
    features: INDArray, labels: INDArray,
    useUI: Boolean, listenFreq: Int, epochs: Int): MultiLayerNetwork  = {

    if (useUI) attachUI(net, None, listenFreq)

    for (_ <- 0 until epochs) net.fit(features, labels)

    net
  }

  /**
    * Fitting with UI Monitoring (in-memory stats storage) and optional masks.
    * @param net multilayer network
    * @param features training data features
    * @param labels training data labels
    * @param featureMasks feature masks; None is passed to DL4J as null
    * @param labelMasks label masks; None is passed to DL4J as null
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @param epochs number of fit calls
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork,
    features: INDArray, labels: INDArray, featureMasks: Option[INDArray], labelMasks: Option[INDArray],
    useUI: Boolean, listenFreq: Int, epochs: Int): MultiLayerNetwork  = {

    if (useUI) attachUI(net, None, listenFreq)

    // The DL4J fit overload takes nullable mask arrays.
    val featureMask = featureMasks.orNull
    val labelMask = labelMasks.orNull

    for (_ <- 0 until epochs)
      net.fit(features, labels, featureMask, labelMask)

    net
  }

  /**
    * Fitting with UI Monitoring (in-memory stats storage), one epoch.
    * @param net multilayer network
    * @param features training data features
    * @param labels training data labels
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork,
    features: INDArray, labels: INDArray,
    useUI: Boolean, listenFreq: Int): MultiLayerNetwork  = {
    run(net, features, labels, useUI, listenFreq, 1)
  }

  /**
    * Fitting with UI Monitoring, one epoch.
    * @param net multilayer network
    * @param features training data features
    * @param labels training data labels
    * @param useUI whether to use UI
    * @param listenFreq listener frequency
    * @param storagePath storage path to save the stats, None = in memory
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork,
    features: INDArray, labels: INDArray,
    useUI: Boolean, listenFreq: Int, storagePath: Option[String]): MultiLayerNetwork  = {

    if (useUI) attachUI(net, storagePath, listenFreq)

    net.fit(features, labels)

    net
  }

  // ---------------------------------------------------------------------------
  // Training with periodic evaluation
  // ---------------------------------------------------------------------------

  /**
    * Fitting with UI Monitoring and periodic evaluation.
    * After every `testFreq`-th epoch, evaluates on `testData` (if given), prints the stats and
    * resets the test iterator.
    * @param net multilayer network
    * @param trainData training data
    * @param storagePath storage path to save the stats, default=None (in memory)
    * @param useUI whether to use UI, default=false
    * @param listenFreq listener frequency, default=1
    * @param testData Some(test data), default=None
    * @param testFreq test frequency (related to epochs), default=1 (every epoch); must be > 0 when testData is given
    * @param epochs training epochs, default=1
    * @return the same `net` instance, after fitting
    * @throws IllegalArgumentException if testData is given and testFreq <= 0
    */
  def runWithTest(net: MultiLayerNetwork, trainData: DataSetIterator, storagePath: Option[String] = None,
    useUI: Boolean = false, listenFreq: Int = 1,
    testData: Option[DataSetIterator] = None, testFreq: Int = 1, epochs: Int = 1): MultiLayerNetwork  = {

    requireValidTestFreq(testData, testFreq)
    if (useUI) attachUI(net, storagePath, listenFreq)

    for (epoch <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()

      testData.foreach { test =>
        if (isTestEpoch(epoch, testFreq)) {
          println(Evaluator.run(net, test).stats)
          test.reset()
        }
      }
    }

    net
  }

  /**
    * Fitting with UI Monitoring and periodic evaluation (computation graph, DataSet iterators).
    * After every `testFreq`-th epoch, evaluates on `testData` (if given), prints the stats and
    * resets the test iterator.
    * @param net computation graph
    * @param trainData training data
    * @param storagePath storage path to save the stats, default=None (in memory)
    * @param useUI whether to use UI, default=false
    * @param listenFreq listener frequency, default=1
    * @param testData Some(test data), default=None
    * @param testFreq test frequency (related to epochs), default=1 (every epoch); must be > 0 when testData is given
    * @param epochs training epochs, default=1
    * @return the same `net` instance, after fitting
    * @throws IllegalArgumentException if testData is given and testFreq <= 0
    */
  def runWithTestCG(net: ComputationGraph, trainData: DataSetIterator, storagePath: Option[String] = None,
    useUI: Boolean = false, listenFreq: Int = 1,
    testData: Option[DataSetIterator] = None, testFreq: Int = 1, epochs: Int = 1): ComputationGraph  = {

    requireValidTestFreq(testData, testFreq)
    if (useUI) attachUI(net, storagePath, listenFreq)

    for (epoch <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()

      testData.foreach { test =>
        if (isTestEpoch(epoch, testFreq)) {
          println(Evaluator.run(net, test).stats)
          test.reset()
        }
      }
    }

    net
  }

  /**
    * Fitting with UI Monitoring and periodic evaluation (computation graph, MultiDataSet iterators).
    * After every `testFreq`-th epoch, evaluates on `testData` (if given), prints the stats and
    * resets the test iterator.
    * @param net computation graph
    * @param trainData training data
    * @param storagePath storage path to save the stats, default=None (in memory)
    * @param useUI whether to use UI, default=false
    * @param listenFreq listener frequency, default=1
    * @param testData Some(test data), default=None
    * @param testFreq test frequency (related to epochs), default=1 (every epoch); must be > 0 when testData is given
    * @param epochs training epochs, default=1
    * @param evalType evaluation type passed through to Evaluator.run, default="default"
    * @return the same `net` instance, after fitting
    * @throws IllegalArgumentException if testData is given and testFreq <= 0
    */
  def runWithTestCGMD(net: ComputationGraph, trainData: MultiDataSetIterator, storagePath: Option[String] = None,
    useUI: Boolean = false, listenFreq: Int = 1,
    testData: Option[MultiDataSetIterator] = None, testFreq: Int = 1, epochs: Int = 1, evalType: String = "default"): ComputationGraph  = {

    requireValidTestFreq(testData, testFreq)
    if (useUI) attachUI(net, storagePath, listenFreq)

    for (epoch <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()

      testData.foreach { test =>
        if (isTestEpoch(epoch, testFreq)) {
          println(Evaluator.run(net, test, evalType).stats)
          test.reset()
        }
      }
    }

    net
  }

  /**
    * Fitting with UI Monitoring and periodic multi-evaluation (computation graph, MultiDataSet iterators).
    * After every `testFreq`-th epoch, runs Evaluator.run with `evalTypes`/`nCMap` on `testData` (if given),
    * prints every resulting evaluation's stats and resets the test iterator.
    * `evalTypes` and `nCMap` have no defaults, so pass them by name when relying on earlier defaults.
    * @param net computation graph
    * @param trainData training data
    * @param storagePath storage path to save the stats, default=None (in memory)
    * @param useUI whether to use UI, default=false
    * @param listenFreq listener frequency, default=1
    * @param testData Some(test data), default=None
    * @param testFreq test frequency (related to epochs), default=1 (every epoch); must be > 0 when testData is given
    * @param epochs training epochs, default=1
    * @param evalTypes evaluation types passed to Evaluator.run (presumably keyed by output index — confirm in Evaluator)
    * @param nCMap map passed to Evaluator.run (presumably number of classes per output — confirm in Evaluator)
    * @return the same `net` instance, after fitting
    * @throws IllegalArgumentException if testData is given and testFreq <= 0
    */
  def runWithTestCGMDEval(net: ComputationGraph, trainData: MultiDataSetIterator, storagePath: Option[String] = None,
    useUI: Boolean = false, listenFreq: Int = 1,
    testData: Option[MultiDataSetIterator] = None, testFreq: Int = 1, epochs: Int = 1,
    evalTypes: Map[Int, Array[String]],
    nCMap: Map[Int, Int]): ComputationGraph  = {

    requireValidTestFreq(testData, testFreq)
    if (useUI) attachUI(net, storagePath, listenFreq)

    for (epoch <- 0 until epochs) {
      net.fit(trainData)
      trainData.reset()

      testData.foreach { test =>
        if (isTestEpoch(epoch, testFreq)) {
          val evals = Evaluator.run(net, test, evalTypes, nCMap)
          evals.asScala.values.flatMap(_.map(_.stats)).foreach(println)
          test.reset()
        }
      }
    }

    net
  }

  // ---------------------------------------------------------------------------
  // Single DataSet training
  // ---------------------------------------------------------------------------

  /**
    * Fit on a single DataSet several times.
    * @param net multilayer network
    * @param trainData training data
    * @param epochs number of times `net.fit(trainData)` is called
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork, trainData: DataSet, epochs: Int): MultiLayerNetwork  = {
    for (_ <- 0 until epochs)
      net.fit(trainData)

    net
  }

  /**
    * Fit on a single DataSet once.
    * @param net multilayer network
    * @param trainData training data
    * @return the same `net` instance, after fitting
    */
  def run(net: MultiLayerNetwork, trainData: DataSet): MultiLayerNetwork  =
    run(net, trainData, 1)

  // ---------------------------------------------------------------------------
  // Word embeddings
  // ---------------------------------------------------------------------------

  /**
    * Train Word2Vec models.
    * @param vec Word2Vec model; `trainData` is set as its sentence iterator before fitting
    * @param trainData training data
    * @return the same `vec` instance, fitted
    */
  def run(vec: Word2Vec, trainData: SentenceIterator): Word2Vec = {
    vec.setSentenceIterator(trainData)
    vec.fit()
    vec
  }

  /**
    * Train Word2Vec and save it.
    * @param vec Word2Vec model; `trainData` is set as its sentence iterator before fitting
    * @param trainData training data
    * @param saveModelTo path passed to WordVectorSerializer.writeWord2VecModel after fitting
    * @return the same `vec` instance, fitted
    */
  def run(vec: Word2Vec, trainData: SentenceIterator, saveModelTo: String): Word2Vec = {
    val fitted = run(vec, trainData)
    WordVectorSerializer.writeWord2VecModel(fitted, saveModelTo)
    fitted
  }

  /**
    * Train Glove and save it.
    * This method only calls `vec.fit()`, so it presumably assumes `vec` was built with its training
    * data already configured — confirm at call sites.
    * @param vec Glove
    * @param saveModelTo path passed to WordVectorSerializer.writeWordVectors after fitting
    * @return the same `vec` instance, fitted
    */
  def run(vec: Glove, saveModelTo: String): Glove = {
    vec.fit()
    WordVectorSerializer.writeWordVectors(vec, saveModelTo)
    vec
  }

}
