package ai.minxiao.ds4s.core.dl4j.learning

import java.nio.charset.Charset
import java.util.{ArrayList, Collections}
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._

import org.deeplearning4j.earlystopping.{EarlyStoppingConfiguration, EarlyStoppingModelSaver, EarlyStoppingResult}
import org.deeplearning4j.earlystopping.saver.{InMemoryModelSaver, LocalFileGraphSaver, LocalFileModelSaver}
import org.deeplearning4j.earlystopping.scorecalc.{DataSetLossCalculator, ROCScoreCalculator, ScoreCalculator}
import org.deeplearning4j.earlystopping.scorecalc.ROCScoreCalculator.{ROCType, Metric => ROCMetric}
import org.deeplearning4j.earlystopping.termination.{
  EpochTerminationCondition, BestScoreEpochTerminationCondition, MaxEpochsTerminationCondition, ScoreImprovementEpochTerminationCondition,
  IterationTerminationCondition, InvalidScoreIterationTerminationCondition, MaxScoreIterationTerminationCondition, MaxTimeIterationTerminationCondition
}
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer
import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator

/**
  * ESLearner: early-stopping training for DL4J networks.
  *
  * All options are given as plain strings/values and are translated into a DL4J `EarlyStoppingConfiguration`
  * once, when the learner is constructed; `run` then trains a `MultiLayerNetwork` with that configuration.
  *
  * @constructor validates the option names (throwing IllegalArgumentException on an unknown one)
  * and eagerly builds the early-stopping configuration
  * @tparam T model type the configuration is built for; note that `run` only trains `MultiLayerNetwork`
  * @param modelSaver model saver type, options: "in-memory-model", "local-file-graph", "local-file-model",
  * default="in-memory-model". "local-file-graph" creates a DL4J `LocalFileGraphSaver` (a ComputationGraph
  * saver), so `run` rejects it.
  * @param directory model save directory for the local-file savers, default="./"
  * @param encoding charset passed to the local-file savers, default=Charset.defaultCharset;
  * charsets every JVM supports: {{{
  * US-ASCII    Seven-bit ASCII, a.k.a. ISO646-US, a.k.a. the Basic Latin block of the Unicode character set
  * ISO-8859-1  ISO Latin Alphabet No. 1, a.k.a. ISO-LATIN-1
  * UTF-8       Eight-bit UCS Transformation Format
  * UTF-16BE    Sixteen-bit UCS Transformation Format, big-endian byte order
  * UTF-16LE    Sixteen-bit UCS Transformation Format, little-endian byte order
  * UTF-16      Sixteen-bit UCS Transformation Format, byte order identified by an optional byte-order mark
  * }}}
  * @param saveLastModel whether to also save the last model in addition to the best model, default=false
  * @param epochTerminationConditions epoch termination conditions, default=Array("max-epochs");
  * all options: "max-epochs", "best-score-epoch", "score-improvement-epoch"
  * @param maxEpochs max number of epochs for "max-epochs", default=10
  * @param bestExpectedScore best expected score for "best-score-epoch", default=1E-3
  * @param lesserBetter whether a lower score is better for "best-score-epoch", default=true.
  * NOTE(review): this flag is currently not forwarded to `BestScoreEpochTerminationCondition`, so it has no
  * effect; check whether the DL4J version in use still honours such a flag before wiring it in.
  * @param maxEpochsWithNoImprovement max consecutive epochs without improvement for "score-improvement-epoch",
  * default=5
  * @param minImprovement min improvement for "score-improvement-epoch", default=1E-3
  * @param evaluateEveryNEpochs compute the score every N epochs, default=1
  * @param iterationTerminationConditions iteration termination conditions,
  * default=Array("invalid-score-iteration");
  * all options: "invalid-score-iteration", "max-score-iteration", "max-time-iteration"
  * @param maxScore max score for "max-score-iteration", default=1E-3.
  * NOTE(review): DL4J presumably stops training once the minibatch score exceeds this value, so 1E-3 is
  * likely far too small for most losses; confirm before enabling "max-score-iteration".
  * @param maxTimeAmount max time amount for "max-time-iteration", in units of `maxTimeUnit`, default=30L
  * @param maxTimeUnit time unit of `maxTimeAmount`, default=TimeUnit.MINUTES;
  * all options, TimeUnit.X: DAYS, HOURS, MICROSECONDS, MILLISECONDS, MINUTES, NANOSECONDS, SECONDS
  * @param scoreCalculator score calculator, default="dataset-loss"; all options: "dataset-loss", "roc-score"
  * @param dataSetIterator dataset the score calculator evaluates on; required, must not be null
  * @param average whether to average the score for "dataset-loss", default=true
  * @param rocType ROC type for "roc-score", default=ROCType.ROC; all options: ROC, BINARY, MULTICLASS
  * @param rocMetric ROC metric for "roc-score", default=Metric.AUC; all options: AUC, AUPRC
  *
  * @author mx
  */
@SerialVersionUID(698376L)
class ESLearner[T <: Model] (
  // modelSaver
  modelSaver: String = "in-memory-model",
  directory: String = "./",
  encoding: Charset = Charset.defaultCharset,
  saveLastModel: Boolean = false,
  // epochTerminationConditions
  epochTerminationConditions: Array[String] = Array("max-epochs"),
  maxEpochs: Int = 10,
  bestExpectedScore: Double = 1E-3,
  lesserBetter: Boolean = true,
  maxEpochsWithNoImprovement: Int = 5,
  minImprovement: Double = 1E-3,
  // evaluate frequency
  evaluateEveryNEpochs: Int = 1,
  // iterationTerminationConditions
  iterationTerminationConditions: Array[String] = Array("invalid-score-iteration"),
  maxScore: Double = 1E-3,
  maxTimeAmount: Long = 30L,
  maxTimeUnit: TimeUnit = TimeUnit.MINUTES,
  // scoreCalculator
  scoreCalculator: String = "dataset-loss",
  dataSetIterator: DataSetIterator,
  average: Boolean = true,
  rocType: ROCType = ROCType.ROC,
  rocMetric: ROCMetric = ROCMetric.AUC
) extends Serializable {

  // Validate every option name up front so a typo fails here, not silently via the `case _` defaults below.
  require(
    Set("in-memory-model", "local-file-graph", "local-file-model").contains(modelSaver),
    s"unknown modelSaver: $modelSaver")
  require(
    epochTerminationConditions.forall(Set("max-epochs", "best-score-epoch", "score-improvement-epoch").contains),
    s"unknown epochTerminationConditions in: ${epochTerminationConditions.mkString(", ")}")
  require(
    iterationTerminationConditions.forall(
      Set("invalid-score-iteration", "max-score-iteration", "max-time-iteration").contains),
    s"unknown iterationTerminationConditions in: ${iterationTerminationConditions.mkString(", ")}")
  require(
    Set("dataset-loss", "roc-score").contains(scoreCalculator),
    s"unknown scoreCalculator: $scoreCalculator")
  require(dataSetIterator != null, "dataSetIterator is required for score calculation")

  // Everything is built as block-local values so the constructor parameters used only here (e.g. the
  // non-Serializable `encoding`) do not become fields of this Serializable class.
  private val esConf: EarlyStoppingConfiguration[T] = {
    // The file savers are not parameterised by T, hence the casts; they are unchecked at runtime (erasure).
    val saver: EarlyStoppingModelSaver[T] = modelSaver match {
      case "local-file-graph"         => new LocalFileGraphSaver(directory, encoding).asInstanceOf[EarlyStoppingModelSaver[T]]
      case "local-file-model"         => new LocalFileModelSaver(directory, encoding).asInstanceOf[EarlyStoppingModelSaver[T]]
      case _ /* "in-memory-model" */  => new InMemoryModelSaver[T]
    }

    // Ascribing the element type inside the lambda makes the array's runtime class the interface itself,
    // instead of casting an array whose element class depends on type inference.
    val epochConditions: Array[EpochTerminationCondition] = epochTerminationConditions.map { name =>
      val condition: EpochTerminationCondition = name match {
        case "best-score-epoch"        => new BestScoreEpochTerminationCondition(bestExpectedScore)
        case "score-improvement-epoch" => new ScoreImprovementEpochTerminationCondition(maxEpochsWithNoImprovement, minImprovement)
        case _ /* "max-epochs" */      => new MaxEpochsTerminationCondition(maxEpochs)
      }
      condition
    }

    val iterationConditions: Array[IterationTerminationCondition] = iterationTerminationConditions.map { name =>
      val condition: IterationTerminationCondition = name match {
        case "max-score-iteration"           => new MaxScoreIterationTerminationCondition(maxScore)
        case "max-time-iteration"            => new MaxTimeIterationTerminationCondition(maxTimeAmount, maxTimeUnit)
        case _ /* "invalid-score-iteration" */ => new InvalidScoreIterationTerminationCondition
      }
      condition
    }

    val calculator: ScoreCalculator[T] = (scoreCalculator match {
      case "roc-score"            => new ROCScoreCalculator(rocType, rocMetric, dataSetIterator)
      case _ /* "dataset-loss" */ => new DataSetLossCalculator(dataSetIterator, average)
    }).asInstanceOf[ScoreCalculator[T]]

    new EarlyStoppingConfiguration.Builder[T]()
      .modelSaver(saver)
      .saveLastModel(saveLastModel)
      .epochTerminationConditions(epochConditions: _*)
      .evaluateEveryNEpochs(evaluateEveryNEpochs)
      .iterationTerminationConditions(iterationConditions: _*)
      .scoreCalculator(calculator)
      .build()
  }

  /**
    * Trains `net` on `trainData` with early stopping, using the configuration built at construction time.
    *
    * @param net network to train
    * @param trainData training data
    * @param printResult whether to print a summary via `printESResult`, default=false
    * @return the DL4J early-stopping result
    * @throws IllegalArgumentException if this learner was configured with modelSaver="local-file-graph",
    *   whose saver cannot store a MultiLayerNetwork
    */
  def run(net: MultiLayerNetwork, trainData: DataSetIterator, printResult: Boolean = false): EarlyStoppingResult[MultiLayerNetwork] = {
    // Fail fast: otherwise the mismatch only surfaces as a ClassCastException when the first best model is saved.
    require(
      !esConf.getModelSaver.isInstanceOf[LocalFileGraphSaver],
      "modelSaver \"local-file-graph\" saves ComputationGraph models and cannot be used to train a " +
        "MultiLayerNetwork; use \"local-file-model\" or \"in-memory-model\"")
    val trainer = new EarlyStoppingTrainer(
      esConf.asInstanceOf[EarlyStoppingConfiguration[MultiLayerNetwork]],
      net,
      trainData
    )
    val result = trainer.fit()
    if (printResult) printESResult(result)
    result
  }

  /**
    * Prints an early-stopping summary to stdout: termination reason and details, total epochs, best epoch and
    * its score, then one `epoch<TAB>score` line per recorded epoch in ascending epoch order.
    *
    * @param esResult result returned by `run`
    */
  def printESResult(esResult: EarlyStoppingResult[MultiLayerNetwork]): Unit = {
    println(s"Termination reason: ${esResult.getTerminationReason}")
    println(s"Termination details: ${esResult.getTerminationDetails}")
    println(s"Total epochs: ${esResult.getTotalEpochs}")
    println(s"Best epoch number: ${esResult.getBestModelEpoch}")
    println(s"Score at best epoch: ${esResult.getBestModelScore}")

    // Print score vs. epoch
    println("Score vs. Epoch:")
    esResult.getScoreVsEpoch.asScala.toSeq
      .sortBy(_._1.intValue)
      .foreach { case (epoch, score) => println(s"$epoch\t$score") }
  }
}
