/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ai.h2o.sparkling.ml.params

import hex.rulefit.RuleFitModel.RuleFitParameters
import ai.h2o.sparkling.H2OFrame
import hex.rulefit.RuleFitModel.Algorithm
import hex.rulefit.RuleFitModel.ModelType
import hex.genmodel.utils.DistributionFamily
import hex.MultinomialAucType

/**
 * Spark ML parameters of the H2O RuleFit algorithm as exposed by Sparkling Water.
 *
 * Every Spark-side parameter is declared with its user-visible documentation, receives a default
 * value, is exposed through a getter/setter pair, and is finally mapped to its H2O-side parameter
 * name when the algorithm parameters are assembled for a training frame.
 */
trait H2ORuleFitParams
  extends H2OAlgoParamsBase
  with HasUnsupportedOffsetCol
  with HasIgnoredCols {

  protected def paramTag: reflect.ClassTag[RuleFitParameters] = reflect.classTag[RuleFitParameters]

  // ---------------------------------------------------------------------------------------------
  // Parameter declarations. The documentation strings are user-visible and are kept verbatim.
  // ---------------------------------------------------------------------------------------------

  protected val seed =
    longParam(name = "seed", doc = """Seed for pseudo random number generator (if applicable).""")

  protected val algorithm =
    stringParam(
      name = "algorithm",
      doc = """The algorithm to use to generate rules. Possible values are ``"DRF"``, ``"GBM"``, ``"AUTO"``.""")

  protected val minRuleLength =
    intParam(name = "minRuleLength", doc = """Minimum length of rules. Defaults to 3.""")

  protected val maxRuleLength =
    intParam(name = "maxRuleLength", doc = """Maximum length of rules. Defaults to 3.""")

  // Kept verbatim, including the trailing space and the embedded line break.
  protected val maxNumRules =
    intParam(
      name = "maxNumRules",
      doc = "The maximum number of rules to return. defaults to -1 which means the number of rules is selected \nby diminishing returns in model deviance.")

  protected val modelType =
    stringParam(
      name = "modelType",
      doc = """Specifies type of base learners in the ensemble. Possible values are ``"RULES"``, ``"RULES_AND_LINEAR"``, ``"LINEAR"``.""")

  protected val ruleGenerationNtrees =
    intParam(
      name = "ruleGenerationNtrees",
      doc = """Specifies the number of trees to build in the tree model. Defaults to 50.""")

  protected val removeDuplicates =
    booleanParam(
      name = "removeDuplicates",
      doc = """Whether to remove rules which are identical to an earlier rule. Defaults to true.""")

  protected val lambdaValue =
    nullableDoubleArrayParam(name = "lambdaValue", doc = """Lambda for LASSO regressor.""")

  protected val modelId =
    nullableStringParam(
      name = "modelId",
      doc = """Destination id for this model; auto-generated if not specified.""")

  protected val distribution =
    stringParam(
      name = "distribution",
      doc = """Distribution function. Possible values are ``"AUTO"``, ``"bernoulli"``, ``"quasibinomial"``, ``"modified_huber"``, ``"multinomial"``, ``"ordinal"``, ``"gaussian"``, ``"poisson"``, ``"gamma"``, ``"tweedie"``, ``"huber"``, ``"laplace"``, ``"quantile"``, ``"fractionalbinomial"``, ``"negativebinomial"``, ``"custom"``.""")

  protected val labelCol =
    stringParam(name = "labelCol", doc = """Response variable column.""")

  protected val weightCol =
    nullableStringParam(
      name = "weightCol",
      doc = """Column with observation weights. Giving some observation a weight of zero is equivalent to excluding it from the dataset; giving an observation a relative weight of 2 is equivalent to repeating that row twice. Negative weights are not allowed. Note: Weights are per-row observation weights and do not increase the size of the data frame. This is typically the number of times a row is repeated, but non-integer values are supported as well. During training, rows with higher weights matter more, due to the larger loss function pre-factor. If you set weight = 0 for a row, the returned prediction frame at that row is zero and this is incorrect. To get an accurate prediction, remove all rows with weight == 0.""")

  protected val aucType =
    stringParam(
      name = "aucType",
      doc = """Set default multinomial AUC type. Possible values are ``"AUTO"``, ``"NONE"``, ``"MACRO_OVR"``, ``"WEIGHTED_OVR"``, ``"MACRO_OVO"``, ``"WEIGHTED_OVO"``.""")

  // ---------------------------------------------------------------------------------------------
  // Defaults. This must run after every parameter above has been initialized.
  // ---------------------------------------------------------------------------------------------
  setDefault(
    seed -> -1L,
    algorithm -> Algorithm.AUTO.name(),
    minRuleLength -> 3,
    maxRuleLength -> 3,
    maxNumRules -> -1,
    modelType -> ModelType.RULES_AND_LINEAR.name(),
    ruleGenerationNtrees -> 50,
    removeDuplicates -> true,
    lambdaValue -> null,
    modelId -> null,
    distribution -> DistributionFamily.AUTO.name(),
    labelCol -> "label",
    weightCol -> null,
    aucType -> MultinomialAucType.AUTO.name())

  // ---------------------------------------------------------------------------------------------
  // Accessors, grouped per parameter. The setters for enum-backed parameters store the value
  // returned by EnumParamValidator.getValidatedEnumValue rather than the raw argument.
  // ---------------------------------------------------------------------------------------------

  def getSeed(): Long = $(seed)

  def setSeed(value: Long): this.type = set(seed, value)

  def getAlgorithm(): String = $(algorithm)

  /** Stores `value` after validating it against [[Algorithm]]. */
  def setAlgorithm(value: String): this.type =
    set(algorithm, EnumParamValidator.getValidatedEnumValue[Algorithm](value))

  def getMinRuleLength(): Int = $(minRuleLength)

  def setMinRuleLength(value: Int): this.type = set(minRuleLength, value)

  def getMaxRuleLength(): Int = $(maxRuleLength)

  def setMaxRuleLength(value: Int): this.type = set(maxRuleLength, value)

  def getMaxNumRules(): Int = $(maxNumRules)

  def setMaxNumRules(value: Int): this.type = set(maxNumRules, value)

  def getModelType(): String = $(modelType)

  /** Stores `value` after validating it against [[ModelType]]. */
  def setModelType(value: String): this.type =
    set(modelType, EnumParamValidator.getValidatedEnumValue[ModelType](value))

  def getRuleGenerationNtrees(): Int = $(ruleGenerationNtrees)

  def setRuleGenerationNtrees(value: Int): this.type = set(ruleGenerationNtrees, value)

  def getRemoveDuplicates(): Boolean = $(removeDuplicates)

  def setRemoveDuplicates(value: Boolean): this.type = set(removeDuplicates, value)

  /** May return null, because the default value of this parameter is null. */
  def getLambdaValue(): Array[Double] = $(lambdaValue)

  def setLambdaValue(value: Array[Double]): this.type = set(lambdaValue, value)

  /** May return null, because the default value of this parameter is null. */
  def getModelId(): String = $(modelId)

  def setModelId(value: String): this.type = set(modelId, value)

  def getDistribution(): String = $(distribution)

  /** Stores `value` after validating it against [[DistributionFamily]]. */
  def setDistribution(value: String): this.type =
    set(distribution, EnumParamValidator.getValidatedEnumValue[DistributionFamily](value))

  def getLabelCol(): String = $(labelCol)

  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** May return null, because the default value of this parameter is null. */
  def getWeightCol(): String = $(weightCol)

  def setWeightCol(value: String): this.type = set(weightCol, value)

  def getAucType(): String = $(aucType)

  /** Stores `value` after validating it against [[MultinomialAucType]]. */
  def setAucType(value: String): this.type =
    set(aucType, EnumParamValidator.getValidatedEnumValue[MultinomialAucType](value))

  // ---------------------------------------------------------------------------------------------
  // Translation to H2O
  // ---------------------------------------------------------------------------------------------

  /** Extends the inherited H2O parameters with the RuleFit-specific ones for `trainingFrame`. */
  override private[sparkling] def getH2OAlgorithmParams(trainingFrame: H2OFrame): Map[String, Any] =
    super.getH2OAlgorithmParams(trainingFrame) ++ getH2ORuleFitParams(trainingFrame)

  /**
   * Builds the RuleFit parameter values keyed by their H2O names.
   *
   * The offset-column and ignored-columns entries are combined in, in that order, using the
   * `+++` operator, which is defined outside this trait.
   *
   * @param trainingFrame frame passed to the offset-column and ignored-columns helpers
   * @return H2O parameter name to value
   */
  private[sparkling] def getH2ORuleFitParams(trainingFrame: H2OFrame): Map[String, Any] = {
    val ruleFitValues: Map[String, Any] = Map(
      "seed" -> getSeed(),
      "algorithm" -> getAlgorithm(),
      "min_rule_length" -> getMinRuleLength(),
      "max_rule_length" -> getMaxRuleLength(),
      "max_num_rules" -> getMaxNumRules(),
      "model_type" -> getModelType(),
      "rule_generation_ntrees" -> getRuleGenerationNtrees(),
      "remove_duplicates" -> getRemoveDuplicates(),
      "lambda" -> getLambdaValue(),
      "model_id" -> getModelId(),
      "distribution" -> getDistribution(),
      "response_column" -> getLabelCol(),
      "weights_column" -> getWeightCol(),
      "auc_type" -> getAucType())
    val withOffsetCol = ruleFitValues +++ getUnsupportedOffsetColParam(trainingFrame)
    withOffsetCol +++ getIgnoredColsParam(trainingFrame)
  }

  /** Extends the inherited Sparkling Water to H2O name mapping with the RuleFit parameter names. */
  override private[sparkling] def getSWtoH2OParamNameMap(): Map[String, String] = {
    val ruleFitNames = Map(
      "seed" -> "seed",
      "algorithm" -> "algorithm",
      "minRuleLength" -> "min_rule_length",
      "maxRuleLength" -> "max_rule_length",
      "maxNumRules" -> "max_num_rules",
      "modelType" -> "model_type",
      "ruleGenerationNtrees" -> "rule_generation_ntrees",
      "removeDuplicates" -> "remove_duplicates",
      "lambdaValue" -> "lambda",
      "modelId" -> "model_id",
      "distribution" -> "distribution",
      "labelCol" -> "response_column",
      "weightCol" -> "weights_column",
      "aucType" -> "auc_type")
    super.getSWtoH2OParamNameMap() ++ ruleFitNames
  }
}
