/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.pca.PCAModel.PCAParameters
import ai.h2o.sparkling.H2OFrame
import hex.DataInfo.TransformType
import hex.pca.PCAModel.PCAParameters.Method
import hex.pca.PCAImplementation

/**
 * Spark-side parameters for the H2O PCA algorithm, whose H2O parameter class is
 * `hex.pca.PCAModel.PCAParameters`.
 *
 * Every parameter has a camelCase Sparkling Water name, a default value, a getter and a setter.
 * The enum-like parameters (`transform`, `pcaMethod`, `pcaImpl`) are stored as strings. Their
 * setters store whatever `EnumParamValidator.getValidatedEnumValue` returns for the matching
 * H2O enum type.
 */
trait H2OPCAParams
  extends H2OAlgoParamsBase
  with HasIgnoredCols {

  protected def paramTag = reflect.classTag[PCAParameters]

  //
  // Parameter definitions
  //
  protected val transform = stringParam(
    name = "transform",
    doc = """Transformation of training data. Possible values are ``"NONE"``, ``"STANDARDIZE"``, ``"NORMALIZE"``, ``"DEMEAN"``, ``"DESCALE"``.""")

  protected val pcaMethod = stringParam(
    name = "pcaMethod",
    doc = """Specify the algorithm to use for computing the principal components: GramSVD - uses a distributed computation of the Gram matrix, followed by a local SVD; Power - computes the SVD using the power iteration method (experimental); Randomized - uses randomized subspace iteration method; GLRM - fits a generalized low-rank model with L2 loss function and no regularization and solves for the SVD using local matrix algebra (experimental). Possible values are ``"GramSVD"``, ``"Power"``, ``"Randomized"``, ``"GLRM"``.""")

  protected val pcaImpl = stringParam(
    name = "pcaImpl",
    doc = """Specify the implementation to use for computing PCA (via SVD or EVD): MTJ_EVD_DENSEMATRIX - eigenvalue decompositions for dense matrix using MTJ; MTJ_EVD_SYMMMATRIX - eigenvalue decompositions for symmetric matrix using MTJ; MTJ_SVD_DENSEMATRIX - singular-value decompositions for dense matrix using MTJ; JAMA - eigenvalue decompositions for dense matrix using JAMA. References: JAMA - http://math.nist.gov/javanumerics/jama/; MTJ - https://github.com/fommil/matrix-toolkits-java/. Possible values are ``"MTJ_EVD_DENSEMATRIX"``, ``"MTJ_EVD_SYMMMATRIX"``, ``"MTJ_SVD_DENSEMATRIX"``, ``"JAMA"``.""")

  protected val k = intParam(
    name = "k",
    doc = """Rank of matrix approximation.""")

  protected val maxIterations = intParam(
    name = "maxIterations",
    doc = """Maximum training iterations.""")

  protected val seed = longParam(
    name = "seed",
    doc = """RNG seed for initialization.""")

  protected val useAllFactorLevels = booleanParam(
    name = "useAllFactorLevels",
    doc = """Whether first factor level is included in each categorical expansion.""")

  protected val computeMetrics = booleanParam(
    name = "computeMetrics",
    doc = """Whether to compute metrics on the training data.""")

  protected val imputeMissing = booleanParam(
    name = "imputeMissing",
    doc = """Whether to impute missing entries with the column mean.""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val ignoreConstCols = booleanParam(
    name = "ignoreConstCols",
    doc = """Ignore constant columns.""")

  protected val scoreEachIteration = booleanParam(
    name = "scoreEachIteration",
    doc = """Whether to score during each iteration of model training.""")

  protected val maxRuntimeSecs = doubleParam(
    name = "maxRuntimeSecs",
    doc = """Maximum allowed runtime in seconds for model training. Use 0 to disable.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values
  // (must stay after the parameter definitions above, which are initialized in declaration order)
  //
  setDefault(
    transform -> TransformType.NONE.name(),
    pcaMethod -> Method.GramSVD.name(),
    pcaImpl -> PCAImplementation.MTJ_EVD_SYMMMATRIX.name(),
    k -> 1,
    maxIterations -> 1000,
    seed -> -1L,
    useAllFactorLevels -> false,
    computeMetrics -> true,
    imputeMissing -> false,
    modelId -> null,
    ignoreConstCols -> true,
    scoreEachIteration -> false,
    maxRuntimeSecs -> 0.0,
    exportCheckpointsDir -> null)

  //
  // Getters
  //
  def getTransform(): String = $(transform)

  def getPcaMethod(): String = $(pcaMethod)

  def getPcaImpl(): String = $(pcaImpl)

  def getK(): Int = $(k)

  def getMaxIterations(): Int = $(maxIterations)

  def getSeed(): Long = $(seed)

  def getUseAllFactorLevels(): Boolean = $(useAllFactorLevels)

  def getComputeMetrics(): Boolean = $(computeMetrics)

  def getImputeMissing(): Boolean = $(imputeMissing)

  /** May be `null`; `modelId` is a nullable parameter whose default is `null`. */
  def getModelId(): String = $(modelId)

  def getIgnoreConstCols(): Boolean = $(ignoreConstCols)

  def getScoreEachIteration(): Boolean = $(scoreEachIteration)

  def getMaxRuntimeSecs(): Double = $(maxRuntimeSecs)

  /** May be `null`; `exportCheckpointsDir` is a nullable parameter whose default is `null`. */
  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  //
  // Setters
  //
  /** Stores the value returned by `EnumParamValidator.getValidatedEnumValue[TransformType]` for `value`. */
  def setTransform(value: String): this.type =
    set(transform, EnumParamValidator.getValidatedEnumValue[TransformType](value))

  /** Stores the value returned by `EnumParamValidator.getValidatedEnumValue[Method]` for `value`. */
  def setPcaMethod(value: String): this.type =
    set(pcaMethod, EnumParamValidator.getValidatedEnumValue[Method](value))

  /** Stores the value returned by `EnumParamValidator.getValidatedEnumValue[PCAImplementation]` for `value`. */
  def setPcaImpl(value: String): this.type =
    set(pcaImpl, EnumParamValidator.getValidatedEnumValue[PCAImplementation](value))

  def setK(value: Int): this.type = set(k, value)

  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def setSeed(value: Long): this.type = set(seed, value)

  def setUseAllFactorLevels(value: Boolean): this.type = set(useAllFactorLevels, value)

  def setComputeMetrics(value: Boolean): this.type = set(computeMetrics, value)

  def setImputeMissing(value: Boolean): this.type = set(imputeMissing, value)

  def setModelId(value: String): this.type = set(modelId, value)

  def setIgnoreConstCols(value: Boolean): this.type = set(ignoreConstCols, value)

  def setScoreEachIteration(value: Boolean): this.type = set(scoreEachIteration, value)

  def setMaxRuntimeSecs(value: Double): this.type = set(maxRuntimeSecs, value)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  /**
   * Merges the inherited algorithm parameters with the PCA-specific ones.
   * Where keys collide, the PCA-specific entries win, because they are the right operand of `++`.
   */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] =
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OPCAParams(trainingFrame)

  /**
   * Collects the current PCA parameter values, keyed by their H2O (snake_case) names.
   * The ignored-columns entries derived from `trainingFrame` are then combined in.
   *
   * @param trainingFrame frame passed to `getIgnoredColsParam`
   * @return parameter values keyed by H2O parameter name
   */
  private[sparkling] def getH2OPCAParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val pcaValues: Map[String, Any] = Map(
      "transform" -> getTransform(),
      "pca_method" -> getPcaMethod(),
      "pca_impl" -> getPcaImpl(),
      "k" -> getK(),
      "max_iterations" -> getMaxIterations(),
      "seed" -> getSeed(),
      "use_all_factor_levels" -> getUseAllFactorLevels(),
      "compute_metrics" -> getComputeMetrics(),
      "impute_missing" -> getImputeMissing(),
      "model_id" -> getModelId(),
      "ignore_const_cols" -> getIgnoreConstCols(),
      "score_each_iteration" -> getScoreEachIteration(),
      "max_runtime_secs" -> getMaxRuntimeSecs(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())
    // NOTE(review): `+++` is not defined on Map in the standard library.
    // It is presumably an implicit extension provided by a parent trait outside this file; confirm.
    pcaValues +++ getIgnoredColsParam(trainingFrame)
  }

  /** Extends the inherited name map with the PCA parameters' Sparkling Water -> H2O name translations. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inherited = super.getSWtoH2OParamNameMap()
    val pcaNames = Map(
      "transform" -> "transform",
      "pcaMethod" -> "pca_method",
      "pcaImpl" -> "pca_impl",
      "k" -> "k",
      "maxIterations" -> "max_iterations",
      "seed" -> "seed",
      "useAllFactorLevels" -> "use_all_factor_levels",
      "computeMetrics" -> "compute_metrics",
      "imputeMissing" -> "impute_missing",
      "modelId" -> "model_id",
      "ignoreConstCols" -> "ignore_const_cols",
      "scoreEachIteration" -> "score_each_iteration",
      "maxRuntimeSecs" -> "max_runtime_secs",
      "exportCheckpointsDir" -> "export_checkpoints_dir")
    inherited ++ pcaNames
  }
}
