/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.glrm.GLRMModel.GLRMParameters
import ai.h2o.sparkling.H2OFrame
import hex.DataInfo.TransformType
import hex.genmodel.algos.glrm.GlrmLoss
import hex.genmodel.algos.glrm.GlrmLoss
import hex.genmodel.algos.glrm.GlrmRegularizer
import hex.genmodel.algos.glrm.GlrmRegularizer
import hex.genmodel.algos.glrm.GlrmInitialization
import hex.svd.SVDModel.SVDParameters.Method

/**
  * Parameters of the GLRM algorithm, typed against `GLRMParameters`.
  *
  * The trait declares each parameter, registers its default value, exposes a getter/setter
  * pair for it and provides the translation of the parameter values and names to the
  * snake_case keys used on the H2O side.
  */
trait H2OGLRMParams
  extends H2OAlgoParamsBase
  with HasUserX
  with HasUserY
  with HasLossByColNames {

  /** Class tag of the H2O parameter class backing this trait. */
  protected def paramTag: reflect.ClassTag[GLRMParameters] = reflect.classTag[GLRMParameters]

  //
  // Parameter definitions
  //
  protected val transform = stringParam(
    name = "transform",
    doc = """Transformation of training data. Possible values are ``"NONE"``, ``"STANDARDIZE"``, ``"NORMALIZE"``, ``"DEMEAN"``, ``"DESCALE"``.""")

  protected val k = intParam(
    name = "k",
    doc = """Rank of matrix approximation.""")

  protected val loss = stringParam(
    name = "loss",
    doc = """Numeric loss function. Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val multiLoss = stringParam(
    name = "multiLoss",
    doc = """Categorical loss function. Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val lossByCol = nullableStringArrayParam(
    name = "lossByCol",
    doc = """Loss function by column (override). Possible values are ``"Quadratic"``, ``"Absolute"``, ``"Huber"``, ``"Poisson"``, ``"Periodic(0)"``, ``"Logistic"``, ``"Hinge"``, ``"Categorical"``, ``"Ordinal"``.""")

  protected val period = intParam(
    name = "period",
    doc = """Length of period (only used with periodic loss function).""")

  protected val regularizationX = stringParam(
    name = "regularizationX",
    doc = """Regularization function for X matrix. Possible values are ``"None"``, ``"Quadratic"``, ``"L2"``, ``"L1"``, ``"NonNegative"``, ``"OneSparse"``, ``"UnitOneSparse"``, ``"Simplex"``.""")

  protected val regularizationY = stringParam(
    name = "regularizationY",
    doc = """Regularization function for Y matrix. Possible values are ``"None"``, ``"Quadratic"``, ``"L2"``, ``"L1"``, ``"NonNegative"``, ``"OneSparse"``, ``"UnitOneSparse"``, ``"Simplex"``.""")

  protected val gammaX = doubleParam(
    name = "gammaX",
    doc = """Regularization weight on X matrix.""")

  protected val gammaY = doubleParam(
    name = "gammaY",
    doc = """Regularization weight on Y matrix.""")

  protected val maxIterations = intParam(
    name = "maxIterations",
    doc = """Maximum number of iterations.""")

  protected val maxUpdates = intParam(
    name = "maxUpdates",
    doc = """Maximum number of updates, defaults to 2*max_iterations.""")

  protected val initStepSize = doubleParam(
    name = "initStepSize",
    doc = """Initial step size.""")

  protected val minStepSize = doubleParam(
    name = "minStepSize",
    doc = """Minimum step size.""")

  protected val seed = longParam(
    name = "seed",
    doc = """RNG seed for initialization.""")

  protected val init = stringParam(
    name = "init",
    doc = """Initialization mode. Possible values are ``"Random"``, ``"SVD"``, ``"PlusPlus"``, ``"User"``, ``"Power"``.""")

  protected val svdMethod = stringParam(
    name = "svdMethod",
    doc = """Method for computing SVD during initialization (Caution: Randomized is currently experimental and unstable). Possible values are ``"GramSVD"``, ``"Power"``, ``"Randomized"``.""")

  protected val loadingName = nullableStringParam(
    name = "loadingName",
    doc = """[Deprecated] Use representation_name instead.  Frame key to save resulting X.""")

  protected val representationName = nullableStringParam(
    name = "representationName",
    doc = """Frame key to save resulting X.""")

  protected val expandUserY = booleanParam(
    name = "expandUserY",
    doc = """Expand categorical columns in user-specified initial Y.""")

  protected val imputeOriginal = booleanParam(
    name = "imputeOriginal",
    doc = """Reconstruct original training data by reversing transform.""")

  protected val recoverSvd = booleanParam(
    name = "recoverSvd",
    doc = """Recover singular values and eigenvectors of XY.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val ignoredCols = nullableStringArrayParam(
    name = "ignoredCols",
    doc = """Names of columns to ignore for training.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values
  //
  // Enum-backed parameters store the enum constant's name() as a plain string.
  setDefault(
    transform -> TransformType.NONE.name(),
    k -> 1,
    loss -> GlrmLoss.Quadratic.name(),
    multiLoss -> GlrmLoss.Categorical.name(),
    lossByCol -> null,
    period -> 1,
    regularizationX -> GlrmRegularizer.None.name(),
    regularizationY -> GlrmRegularizer.None.name(),
    gammaX -> 0.0,
    gammaY -> 0.0,
    maxIterations -> 1000,
    maxUpdates -> 2000,
    initStepSize -> 1.0,
    minStepSize -> 1.0e-4,
    seed -> -1L,
    init -> GlrmInitialization.PlusPlus.name(),
    svdMethod -> Method.Randomized.name(),
    loadingName -> null,
    representationName -> null,
    expandUserY -> true,
    imputeOriginal -> false,
    recoverSvd -> false,
    modelId -> null,
    ignoredCols -> null,
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  //
  // Getters
  //
  // Each getter reads the corresponding parameter through `$`.
  def getTransform(): String = $(transform)

  def getK(): Int = $(k)

  def getLoss(): String = $(loss)

  def getMultiLoss(): String = $(multiLoss)

  def getLossByCol(): Array[String] = $(lossByCol)

  def getPeriod(): Int = $(period)

  def getRegularizationX(): String = $(regularizationX)

  def getRegularizationY(): String = $(regularizationY)

  def getGammaX(): Double = $(gammaX)

  def getGammaY(): Double = $(gammaY)

  def getMaxIterations(): Int = $(maxIterations)

  def getMaxUpdates(): Int = $(maxUpdates)

  def getInitStepSize(): Double = $(initStepSize)

  def getMinStepSize(): Double = $(minStepSize)

  def getSeed(): Long = $(seed)

  def getInit(): String = $(init)

  def getSvdMethod(): String = $(svdMethod)

  def getLoadingName(): String = $(loadingName)

  def getRepresentationName(): String = $(representationName)

  def getExpandUserY(): Boolean = $(expandUserY)

  def getImputeOriginal(): Boolean = $(imputeOriginal)

  def getRecoverSvd(): Boolean = $(recoverSvd)

  def getModelId(): String = $(modelId)

  def getIgnoredCols(): Array[String] = $(ignoredCols)

  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  //
  // Setters
  //
  // Setters of enum-backed parameters pass the value through EnumParamValidator before storing
  // it; all remaining setters store the value as given.
  def setTransform(value: String): this.type =
    set(transform, EnumParamValidator.getValidatedEnumValue[TransformType](value))

  def setK(value: Int): this.type = set(k, value)

  def setLoss(value: String): this.type =
    set(loss, EnumParamValidator.getValidatedEnumValue[GlrmLoss](value))

  def setMultiLoss(value: String): this.type =
    set(multiLoss, EnumParamValidator.getValidatedEnumValue[GlrmLoss](value))

  def setLossByCol(value: Array[String]): this.type =
    set(lossByCol, EnumParamValidator.getValidatedEnumValues[GlrmLoss](value, nullEnabled = true))

  def setPeriod(value: Int): this.type = set(period, value)

  def setRegularizationX(value: String): this.type =
    set(regularizationX, EnumParamValidator.getValidatedEnumValue[GlrmRegularizer](value))

  def setRegularizationY(value: String): this.type =
    set(regularizationY, EnumParamValidator.getValidatedEnumValue[GlrmRegularizer](value))

  def setGammaX(value: Double): this.type = set(gammaX, value)

  def setGammaY(value: Double): this.type = set(gammaY, value)

  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def setMaxUpdates(value: Int): this.type = set(maxUpdates, value)

  def setInitStepSize(value: Double): this.type = set(initStepSize, value)

  def setMinStepSize(value: Double): this.type = set(minStepSize, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setInit(value: String): this.type =
    set(init, EnumParamValidator.getValidatedEnumValue[GlrmInitialization](value))

  def setSvdMethod(value: String): this.type =
    set(svdMethod, EnumParamValidator.getValidatedEnumValue[Method](value))

  def setLoadingName(value: String): this.type = set(loadingName, value)

  def setRepresentationName(value: String): this.type = set(representationName, value)

  def setExpandUserY(value: Boolean): this.type = set(expandUserY, value)

  def setImputeOriginal(value: Boolean): this.type = set(imputeOriginal, value)

  def setRecoverSvd(value: Boolean): this.type = set(recoverSvd, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setIgnoredCols(value: Array[String]): this.type = set(ignoredCols, value)

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  /**
    * Returns the parameters inherited from the parent trait combined with the GLRM-specific
    * ones; on a key clash the GLRM-specific entry wins (right operand of `++`).
    */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val inheritedParams = super.getH2OAlgorithmParams(trainingFrame)
    inheritedParams ++ getH2OGLRMParams(trainingFrame)
  }

  /**
    * Builds the GLRM-specific parameter map keyed by H2O parameter names and combines it
    * (via `+++`) with the user X, user Y and loss-by-column-names parameters obtained for
    * the given training frame.
    */
  private[sparkling] def getH2OGLRMParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val glrmParams = Map(
      "transform" -> getTransform(),
      "k" -> getK(),
      "loss" -> getLoss(),
      "multi_loss" -> getMultiLoss(),
      "loss_by_col" -> getLossByCol(),
      "period" -> getPeriod(),
      "regularization_x" -> getRegularizationX(),
      "regularization_y" -> getRegularizationY(),
      "gamma_x" -> getGammaX(),
      "gamma_y" -> getGammaY(),
      "max_iterations" -> getMaxIterations(),
      "max_updates" -> getMaxUpdates(),
      "init_step_size" -> getInitStepSize(),
      "min_step_size" -> getMinStepSize(),
      "seed" -> getSeed(),
      "init" -> getInit(),
      "svd_method" -> getSvdMethod(),
      "loading_name" -> getLoadingName(),
      "representation_name" -> getRepresentationName(),
      "expand_user_y" -> getExpandUserY(),
      "impute_original" -> getImputeOriginal(),
      "recover_svd" -> getRecoverSvd(),
      "model_id" -> getModelId(),
      "ignored_columns" -> getIgnoredCols(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())

    glrmParams +++
      getUserXParam(trainingFrame) +++
      getUserYParam(trainingFrame) +++
      getLossByColNamesParam(trainingFrame)
  }

  /**
    * Returns the mapping from the camelCase parameter names declared in this trait to the
    * corresponding H2O parameter names, appended to the mapping inherited from the parent.
    */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inheritedNames = super.getSWtoH2OParamNameMap()
    inheritedNames ++ Map(
      "transform" -> "transform",
      "k" -> "k",
      "loss" -> "loss",
      "multiLoss" -> "multi_loss",
      "lossByCol" -> "loss_by_col",
      "period" -> "period",
      "regularizationX" -> "regularization_x",
      "regularizationY" -> "regularization_y",
      "gammaX" -> "gamma_x",
      "gammaY" -> "gamma_y",
      "maxIterations" -> "max_iterations",
      "maxUpdates" -> "max_updates",
      "initStepSize" -> "init_step_size",
      "minStepSize" -> "min_step_size",
      "seed" -> "seed",
      "init" -> "init",
      "svdMethod" -> "svd_method",
      "loadingName" -> "loading_name",
      "representationName" -> "representation_name",
      "expandUserY" -> "expand_user_y",
      "imputeOriginal" -> "impute_original",
      "recoverSvd" -> "recover_svd",
      "modelId" -> "model_id",
      "ignoredCols" -> "ignored_columns",
      "ignoreConstCols" -> "ignore_const_cols",
      "scoreEachIteration" -> "score_each_iteration",
      "maxRuntimeSecs" -> "max_runtime_secs",
      "exportCheckpointsDir" -> "export_checkpoints_dir")
  }

}
