/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.coxph.CoxPHModel.CoxPHParameters
import ai.h2o.sparkling.H2OFrame
import hex.coxph.CoxPHModel.CoxPHParameters.CoxPHTies

/**
 * Spark ML parameters for the H2O Cox Proportional Hazards (CoxPH) algorithm.
 *
 * Every Sparkling Water parameter declared here comes with a getter/setter pair. When the
 * algorithm is configured for a training [[H2OFrame]], the current values are exported under
 * their H2O (snake_case) names, e.g. `lreMin` becomes `lre_min`.
 */
trait H2OCoxPHParams
  extends H2OAlgoParamsBase
  with HasIgnoredCols
  with HasInteractionPairs {

  /** Class tag of the H2O-side parameter class (`CoxPHParameters`) backing this trait. */
  protected def paramTag: reflect.ClassTag[CoxPHParameters] = reflect.classTag[CoxPHParameters]

  //
  // Parameter definitions
  // (these vals must stay above the `setDefault` call below, which reads them during
  // trait initialization)
  //
  protected val startCol = nullableStringParam(name = "startCol", doc = """Start Time Column.""")

  protected val stopCol = nullableStringParam(name = "stopCol", doc = """Stop Time Column.""")

  protected val stratifyBy = nullableStringArrayParam(
    name = "stratifyBy",
    doc = """List of columns to use for stratification.""")

  protected val ties = stringParam(
    name = "ties",
    doc = """Method for Handling Ties. Possible values are ``"efron"``, ``"breslow"``.""")

  protected val init = doubleParam(name = "init", doc = """Coefficient starting value.""")

  protected val lreMin = doubleParam(name = "lreMin", doc = """Minimum log-relative error.""")

  protected val maxIterations = intParam(
    name = "maxIterations",
    doc = """Maximum number of iterations.""")

  protected val interactionsOnly = nullableStringArrayParam(
    name = "interactionsOnly",
    doc = """A list of columns that should only be used to create interactions but should not itself participate in model training.""")

  protected val interactions = nullableStringArrayParam(
    name = "interactions",
    doc = """A list of predictor column indices to interact. All pairwise combinations will be computed for the list.""")

  protected val useAllFactorLevels = booleanParam(
    name = "useAllFactorLevels",
    doc = """(Internal. For development only!) Indicates whether to use all factor levels.""")

  protected val singleNodeMode = booleanParam(
    name = "singleNodeMode",
    doc = """Run on a single node to reduce the effect of network overhead (for smaller datasets).""")

  protected val modelId = nullableStringParam(
    name = "modelId",
    doc = """Destination id for this model; auto-generated if not specified.""")

  protected val labelCol = stringParam(name = "labelCol", doc = """Response variable column.""")

  protected val weightCol = nullableStringParam(
    name = "weightCol",
    doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor.""")

  protected val offsetCol = nullableStringParam(
    name = "offsetCol",
    doc = """Offset column. This will be added to the combination of columns before applying the link function.""")

  protected val exportCheckpointsDir = nullableStringParam(
    name = "exportCheckpointsDir",
    doc = """Automatically export generated models to this directory.""")

  //
  // Default values (nullable parameters default to null, i.e. "not specified")
  //
  setDefault(
    startCol -> null,
    stopCol -> null,
    stratifyBy -> null,
    ties -> CoxPHTies.efron.name(),
    init -> 0.0,
    lreMin -> 9.0,
    maxIterations -> 20,
    interactionsOnly -> null,
    interactions -> null,
    useAllFactorLevels -> false,
    singleNodeMode -> false,
    modelId -> null,
    labelCol -> "label",
    weightCol -> null,
    offsetCol -> null,
    exportCheckpointsDir -> null)

  //
  // Accessors: one getter/setter pair per parameter, in declaration order
  //
  def getStartCol(): String = $(startCol)

  def setStartCol(value: String): this.type = set(startCol, value)

  def getStopCol(): String = $(stopCol)

  def setStopCol(value: String): this.type = set(stopCol, value)

  def getStratifyBy(): Array[String] = $(stratifyBy)

  def setStratifyBy(value: Array[String]): this.type = set(stratifyBy, value)

  def getTies(): String = $(ties)

  /**
   * Sets the tie-handling method. `value` is first checked against the H2O `CoxPHTies` enum
   * by `EnumParamValidator`, and the validator's result is what gets stored.
   */
  def setTies(value: String): this.type =
    set(ties, EnumParamValidator.getValidatedEnumValue[CoxPHTies](value))

  def getInit(): Double = $(init)

  def setInit(value: Double): this.type = set(init, value)

  def getLreMin(): Double = $(lreMin)

  def setLreMin(value: Double): this.type = set(lreMin, value)

  def getMaxIterations(): Int = $(maxIterations)

  def setMaxIterations(value: Int): this.type = set(maxIterations, value)

  def getInteractionsOnly(): Array[String] = $(interactionsOnly)

  def setInteractionsOnly(value: Array[String]): this.type = set(interactionsOnly, value)

  def getInteractions(): Array[String] = $(interactions)

  def setInteractions(value: Array[String]): this.type = set(interactions, value)

  def getUseAllFactorLevels(): Boolean = $(useAllFactorLevels)

  def setUseAllFactorLevels(value: Boolean): this.type = set(useAllFactorLevels, value)

  def getSingleNodeMode(): Boolean = $(singleNodeMode)

  def setSingleNodeMode(value: Boolean): this.type = set(singleNodeMode, value)

  def getModelId(): String = $(modelId)

  def setModelId(value: String): this.type = set(modelId, value)

  def getLabelCol(): String = $(labelCol)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  def getWeightCol(): String = $(weightCol)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def getOffsetCol(): String = $(offsetCol)

  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  def getExportCheckpointsDir(): String = $(exportCheckpointsDir)

  def setExportCheckpointsDir(value: String): this.type = set(exportCheckpointsDir, value)

  //
  // Conversion to H2O parameters
  //

  /** Extends the inherited H2O parameter map with the CoxPH-specific entries. */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] =
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2OCoxPHParams(trainingFrame)

  /**
   * Maps each CoxPH parameter's H2O name to its current value. It then combines, in order, the
   * ignored-columns and interaction-pairs entries for `trainingFrame` using the
   * project-defined `+++` operator.
   */
  private[sparkling] def getH2OCoxPHParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val coxPHSpecific: Map[String, Any] = Map(
      "start_column" -> getStartCol(),
      "stop_column" -> getStopCol(),
      "stratify_by" -> getStratifyBy(),
      "ties" -> getTies(),
      "init" -> getInit(),
      "lre_min" -> getLreMin(),
      "max_iterations" -> getMaxIterations(),
      "interactions_only" -> getInteractionsOnly(),
      "interactions" -> getInteractions(),
      "use_all_factor_levels" -> getUseAllFactorLevels(),
      "single_node_mode" -> getSingleNodeMode(),
      "model_id" -> getModelId(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "offset_column" -> getOffsetCol(),
      "export_checkpoints_dir" -> getExportCheckpointsDir())
    val withIgnoredCols = coxPHSpecific +++ getIgnoredColsParam(trainingFrame)
    withIgnoredCols +++ getInteractionPairsParam(trainingFrame)
  }

  /** Translation table from Sparkling Water parameter names to H2O parameter names. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val inherited = super.getSWtoH2OParamNameMap()
    inherited ++ Map(
      "startCol" -> "start_column",
      "stopCol" -> "stop_column",
      "stratifyBy" -> "stratify_by",
      "ties" -> "ties",
      "init" -> "init",
      "lreMin" -> "lre_min",
      "maxIterations" -> "max_iterations",
      "interactionsOnly" -> "interactions_only",
      "interactions" -> "interactions",
      "useAllFactorLevels" -> "use_all_factor_levels",
      "singleNodeMode" -> "single_node_mode",
      "modelId" -> "model_id",
      "labelCol" -> "response_column",
      "weightCol" -> "weights_column",
      "offsetCol" -> "offset_column",
      "exportCheckpointsDir" -> "export_checkpoints_dir")
  }
}
    