/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.word2vec.Word2VecModel.Word2VecParameters
import ai.h2o.sparkling.H2OFrame
import hex.word2vec.Word2Vec.NormModel
import hex.word2vec.Word2Vec.WordModel

/**
  * Spark ML parameters for the H2O Word2Vec algorithm.
  *
  * The backend parameter class is `Word2VecParameters` (see [[paramTag]]). The trait defines the
  * Sparkling Water (camelCase) params with their defaults, getters and setters. It also translates
  * their values into the snake_case keys returned by [[getH2OAlgorithmParams]].
  */
trait Word2VecParamsV3
  extends H2OAlgoParamsBase
  with HasPreTrained {

  protected def paramTag = reflect.classTag[Word2VecParameters]

  //
  // Parameter definitions
  //
  protected val vecSize = intParam(
    name = "vecSize",
    doc = """Set size of word vectors.""")

  protected val windowSize = intParam(
    name = "windowSize",
    doc = """Set max skip length between words.""")

  protected val sentSampleRate = floatParam(
    name = "sentSampleRate",
    doc = """Set threshold for occurrence of words. Those that appear with higher frequency in the training data
		will be randomly down-sampled; useful range is (0, 1e-5).""")

  protected val normModel = stringParam(
    name = "normModel",
    doc = """Use Hierarchical Softmax. Possible values are ``"HSM"``.""")

  protected val epochs = intParam(
    name = "epochs",
    doc = """Number of training iterations to run.""")

  protected val minWordFreq = intParam(
    name = "minWordFreq",
    doc = """This will discard words that appear less than <int> times.""")

  protected val initLearningRate = floatParam(
    name = "initLearningRate",
    doc = """Set the starting learning rate.""")

  protected val wordModel = stringParam(
    name = "wordModel",
    doc = """The word model to use (SkipGram or CBOW). Possible values are ``"SkipGram"``, ``"CBOW"``.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values
  //
  // Enum-backed params default to the enum constant's name so that the getters always return a
  // String that the corresponding setter would also accept.
  setDefault(
    vecSize -> 100,
    windowSize -> 5,
    sentSampleRate -> 0.001f,
    normModel -> NormModel.HSM.name(),
    epochs -> 5,
    minWordFreq -> 5,
    initLearningRate -> 0.025f,
    wordModel -> WordModel.SkipGram.name(),
    modelId -> null,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  //
  // Getters
  //
  /** Size of the word vectors (default 100). */
  def getVecSize(): Int = $(vecSize)

  /** Maximum skip length between words (default 5). */
  def getWindowSize(): Int = $(windowSize)

  /** Down-sampling threshold for frequent words (default 0.001). */
  def getSentSampleRate(): Float = $(sentSampleRate)

  /** Name of the `NormModel` constant in use (default `NormModel.HSM.name()`). */
  def getNormModel(): String = $(normModel)

  /** Number of training iterations (default 5). */
  def getEpochs(): Int = $(epochs)

  /** Minimum word frequency below which words are discarded (default 5). */
  def getMinWordFreq(): Int = $(minWordFreq)

  /** Starting learning rate (default 0.025). */
  def getInitLearningRate(): Float = $(initLearningRate)

  /** Name of the `WordModel` constant in use (default `WordModel.SkipGram.name()`). */
  def getWordModel(): String = $(wordModel)

  /** Destination model id; `null` by default (auto-generated per the param doc). */
  def getModelId(): String = $(modelId)

  /** Maximum training runtime in seconds; 0.0 by default (disabled per the param doc). */
  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  /** Checkpoint export directory; `null` by default. */
  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  //
  // Setters
  //
  /** Sets the size of the word vectors. */
  def setVecSize(value: Int): this.type = {
    set(vecSize, value)
  }

  /** Sets the maximum skip length between words. */
  def setWindowSize(value: Int): this.type = {
    set(windowSize, value)
  }

  /** Sets the down-sampling threshold for frequent words. */
  def setSentSampleRate(value: Float): this.type = {
    set(sentSampleRate, value)
  }

  /**
    * Sets the norm model.
    *
    * The value is passed through `EnumParamValidator.getValidatedEnumValue[NormModel]` and the
    * result of that call is stored. NOTE(review): presumably it normalizes the value and rejects
    * names that are not `NormModel` constants; confirm in EnumParamValidator.
    */
  def setNormModel(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[NormModel](value)
    set(normModel, validated)
  }

  /** Sets the number of training iterations. */
  def setEpochs(value: Int): this.type = {
    set(epochs, value)
  }

  /** Sets the minimum word frequency below which words are discarded. */
  def setMinWordFreq(value: Int): this.type = {
    set(minWordFreq, value)
  }

  /** Sets the starting learning rate. */
  def setInitLearningRate(value: Float): this.type = {
    set(initLearningRate, value)
  }

  /**
    * Sets the word model (SkipGram or CBOW per the param doc).
    *
    * The value is passed through `EnumParamValidator.getValidatedEnumValue[WordModel]` and the
    * result of that call is stored. NOTE(review): presumably it rejects names that are not
    * `WordModel` constants; confirm in EnumParamValidator.
    */
  def setWordModel(value: String): this.type = {
    val validated = EnumParamValidator.getValidatedEnumValue[WordModel](value)
    set(wordModel, validated)
  }

  /** Sets the destination model id. */
  def setModelId(value: String): this.type = {
    set(modelId, value)
  }

  /** Sets the maximum training runtime in seconds. */
  def setMaxRuntimeSecs(value: Double): this.type = {
    set(maxRuntimeSecs, value)
  }

  /** Sets the directory that generated models are exported to. */
  def setExportCheckpointsDir(value: String): this.type = {
    set(exportCheckpointsDir, value)
  }

  /**
    * Combines the parent's algorithm params with the Word2Vec-specific ones.
    *
    * On key collisions, the entries from [[getWord2VecParamsV3]] win because they are on the
    * right-hand side of `++`.
    */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    super.getH2OAlgorithmParams(trainingFrame) ++ getWord2VecParamsV3(trainingFrame)
  }

  /**
    * Builds the Word2Vec params keyed by their H2O (snake_case) names.
    *
    * The entries from `getPreTrainedParam(trainingFrame)` (provided by [[HasPreTrained]]) are
    * appended last.
    *
    * @param trainingFrame frame forwarded to `getPreTrainedParam`
    * @return current param values keyed by H2O parameter name
    */
  private[sparkling] def getWord2VecParamsV3(trainingFrame: H2OFrame): Map[String, Any] = {
    // Standard Scala map concatenation (`++`), consistent with the rest of this trait;
    // `+++` is not a Scala collection operator.
    Map(
      "vec_size" -> getVecSize(),
      "window_size" -> getWindowSize(),
      "sent_sample_rate" -> getSentSampleRate(),
      "norm_model" -> getNormModel(),
      "epochs" -> getEpochs(),
      "min_word_freq" -> getMinWordFreq(),
      "init_learning_rate" -> getInitLearningRate(),
      "word_model" -> getWordModel(),
      "model_id" -> getModelId(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir()) ++
      getPreTrainedParam(trainingFrame)
  }

  /**
    * Extends the parent's Sparkling Water → H2O param-name mapping with the Word2Vec params.
    *
    * Keys are the Spark param names declared above. Values are the keys used in
    * [[getWord2VecParamsV3]].
    */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    super.getSWtoH2OParamNameMap() ++
      Map(
        "vecSize" -> "vec_size",
        "windowSize" -> "window_size",
        "sentSampleRate" -> "sent_sample_rate",
        "normModel" -> "norm_model",
        "epochs" -> "epochs",
        "minWordFreq" -> "min_word_freq",
        "initLearningRate" -> "init_learning_rate",
        "wordModel" -> "word_model",
        "modelId" -> "model_id",
        "maxRuntimeSecs" -> "max_runtime_secs",
        "exportCheckpointsDir" -> "export_checkpoints_dir")
  }

}
