package ai.minxiao.ds4s.core.h2o

import hex.deeplearning.DeepLearningModel.DeepLearningParameters.{Activation, ClassSamplingMethod,
  InitialWeightDistribution, Loss}
import hex.genmodel.algos.glrm.{GlrmInitialization, GlrmLoss, GlrmRegularizer}
import hex.glm.GLMModel.GLMParameters.{Family, Solver}
import hex.kmeans.KMeans.{Initialization => KMeansInitialization}
import hex.ScoreKeeper.StoppingMetric
import hex.svd.SVDModel.SVDParameters.{Method => SVDMethod}
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.{Backend, Booster, DartNormalizeType,
  GrowPolicy, TreeMethod}
import hex.word2vec.Word2Vec.{WordModel, NormModel}
import org.apache.spark.ml.spark.models.MissingValuesHandling
import org.apache.spark.ml.spark.models.svm.SVMParameters // from h2o's
import org.apache.spark.ml.spark.models.svm.{Gradient, Updater}

/**
  * Package object learning, containing model types, model parameters, default parameter values
  *
  * @author mx
  */
package object learning {

  /** Step one of describing a learner: pick the model type
    * (for example naive Bayes or distributed random forest).
    *
    * Ids are assigned by `Enumeration` in declaration order, so new entries
    * should be appended rather than inserted.
    */
  object ModelType extends Enumeration {
    type ModelType = Value

    // ---------- supervised ----------
    val GLM = Value
    /** Naive Bayes. */
    val NB = Value
    /** Distributed random forest. */
    val DRF = Value
    val GBM = Value
    /** <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-extensions/xgboost/src/main/java/hex/tree/xgboost/XGBoostModel.java">XGBoost</a>
      */
    val XGB = Value
    /** <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-algos/src/main/java/hex/deeplearning/DeepLearningModel.java#L1309">Neural Networks</a>
      */
    val NN = Value
    /** Spark SVM: only for binary classification or regression. */
    val SVM = Value

    // ---------- unsupervised ----------
    val KMM = Value
    val GLRM = Value
    val W2V = Value
  }

  /** Step two of describing a learner: name the hyperparameters to set
    * (for example laplace for NB, number of trees for distributed random forest).
    *
    * Ids are assigned by `Enumeration` in declaration order, so the order of the
    * declarations below must be kept stable.
    */
  object ParamType extends Enumeration {
    type ParamType = Value

    /** All Models: early stopping
      * <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-core/src/main/java/hex/Model.java">Shared Model Parameters</a>
      */
    val stoppingRounds = Value
    val stoppingMetric = Value
    val stoppingTolerance = Value

    // ---------- GLM ----------
    val family, solver, intercept = Value
    /** R(w) = alpha * L1 + (1 - alpha) * L2: 0 -> ridge; 1 -> lasso */
    val alpha = Value
    val betaEpsilon, gradientEpsilon = Value
    /** lambda * R(w): 0 -> no regularizer */
    val lambda = Value
    val lambdaSearch, objectiveEpsilon = Value

    // ---------- NB ----------
    val laplace = Value

    // ---------- DRF, GBM, XGB ----------
    val ntrees, maxDepth, nbins, balanceClasses = Value

    // ---------- XGB ----------
    val treeMethod = Value
    val growPolicy = Value
    val booster = Value
    val dartNormalizeType = Value
    val backend = Value
    val regAlpha, regLambda = Value // L1 and L2 respectively

    // ---------- NN ----------
    // architecture: hidden units, then activation function
    val hidden, activation = Value
    // loss
    val loss = Value
    // regularization: dropout
    val inputDropoutRatio = Value
    val hiddenDropoutRatios = Value
    // regularization: regularizer
    val l1 = Value
    val l2 = Value
    // regularization: others
    val maxW2 = Value
    // optimization: initialization
    val initialWeightDistribution = Value
    val initialWeightScale = Value
    // optimization: update (NN, W2V)
    // epochs: number of passes over the training data; miniBatchSize: mini batch size
    val epochs, miniBatchSize = Value
    // number of training data rows to be processed per iteration;
    // this number will be divided among multiple nodes (if applicable)
    val trainSamplesPerIteration = Value
    // adaptive rate (adadelta)
    val adaptiveRate = Value
    val rho = Value
    val epsilon = Value
    // learning rate settings, active only if adaptiveRate = false
    val rate = Value
    val rateAnnealing = Value
    val rateDecay = Value
    // momentum
    val momentumStart = Value
    val momentumRamp = Value
    val momentumStable = Value
    // nesterovAcceleratedGradient belongs to momentum; elasticAveraging is elastic average
    val nesterovAcceleratedGradient, elasticAveraging = Value
    // optimization: stop
    val classificationStop = Value
    val regressionStop = Value
    // scoring
    val scoreTrainingSamples = Value
    val scoreValidationSamples = Value
    val scoreValidationSampling = Value

    // ---------- NN, KMM: data ----------
    // true for automatically standardizing data
    val standardize = Value

    // ---------- KMM, GLRM, GLM ----------
    val k, maxIterations = Value

    // ---------- KMM ----------
    val initKMM, estimateK = Value

    // ---------- GLRM ----------
    val initGLRM, svdMethod = Value
    val lossGLRMNum = Value
    val lossGLRMCat = Value
    val regularizationX = Value
    val gammaX = Value
    val regularizationY = Value
    val gammaY = Value
    val initStepSize = Value
    val minStepSize = Value
    val recoverSVD = Value
    val imputeOriginal = Value

    // ---------- W2V ----------
    val wordModel = Value
    val normModel = Value
    val vecSize = Value
    val windowSize = Value
    val minWordFreq = Value
    val sentSampleRate = Value
    val initLearningRate = Value

    // ---------- Grid Search ----------
    val maxModels = Value
    val maxRuntimeSecs = Value

    // ---------- SVM ----------
    val addIntercept, convergenceTol, miniBatchFraction = Value
    val regParam, stepSize, threshold = Value
    val gradient = Value
    val updater = Value
    val missingValuesHandling = Value

    // ---------- NN: AutoEncoder ----------
    val autoencoder = Value
  }

  /** Step three of describing a learner: the values assigned to the hyperparameters.
    *
    * Members are named after the matching `ParamType` entries and kept in roughly
    * alphabetical order; each note starts with the model(s) the value is meant for.
    */
  object ParamValue {

    /** NN: activation function
      * {{{Options: Activation.activation
      * --Tanh
      * --TanhWithDropout
      * --Rectifier
      * --RectifierWithDropout
      * --Maxout
      * --MaxoutWithDropout
      * --ExpRectifier
      * --ExpRectifierWithDropout
      * }}}
      */
    val activation: Activation = Activation.Tanh

    /** NN:
      * <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-algos/src/main/java/hex/deeplearning/DeepLearningModel.java#L1391">Adaptive Rate</a>
      */
    val adaptiveRate: Boolean = true

    /** SVM */
    val addIntercept: Boolean = true

    /** GLM, XGB: 0.0 means only L2 */
    val alpha: Double = 0.0

    /** NN, AutoEncoder */
    val autoencoder: Boolean = false

    /** DRF, GBM, XGB, NN: to balance classes */
    val balanceClasses: Boolean = false

    /** XGB
      * {{{Options: Backend.backend
      * --auto
      * --gpu
      * --cpu
      * }}}
      */
    val backend: Backend = Backend.auto

    /** GLM: only applies to IRLSM */
    val betaEpsilon: Double = 1e-4

    /** XGB
      * {{{Options: Booster.booster
      * --gbtree
      * --gblinear
      * --dart
      * }}}
      */
    val booster: Booster = Booster.gbtree

    /** NN: early stop based on the classification error threshold */
    val classificationStop: Double = 0.0

    /** SVM */
    val convergenceTol: Double = 1e-6

    /** XGB
      * {{{Options: DartNormalizeType.dartNormalizeType
      * --tree
      * --forest
      * }}}
      */
    val dartNormalizeType: DartNormalizeType = DartNormalizeType.tree

    /** NN: elastic averaging */
    val elasticAveraging: Boolean = true

    /** NN: number of passes over the training dataset */
    val epochs: Int = 5

    /** NN: typical values are between 1e-10 and 1e-4 */
    val epsilon: Double = 1e-8

    /** KMM: estimate K */
    val estimateK: Boolean = false

    /** GLM
      * {{{Options: Family.family
      * --Binary Classification: binomial
      * --Multi-class Classification: multinomial
      * --Regression: gaussian, quasibinomial, ordinal, poisson, gamma, tweedie,
      * }}}
      */
    val family: Family = Family.binomial

    /** GLRM: weight of R(X) in gamma_x R(X) + gamma_y R(Y) */
    val gammaX: Double = 0.0

    /** GLRM: weight of R(Y) in gamma_x R(X) + gamma_y R(Y) */
    val gammaY: Double = 0.0

    /** SVM: Gradient
      * Options: Hinge, LeastSquares, Logistic
      */
    val gradient: Gradient = Gradient.Logistic

    /** GLM */
    val gradientEpsilon: Double = 1e-6

    /** XGB
      * {{{Options:
      * --depthwise
      * --lossguide
      * }}}
      */
    val growPolicy: GrowPolicy = GrowPolicy.depthwise

    /** NN */
    val hidden: Array[Int] = Array(2)

    /** NN: hidden dropout ratios */
    val hiddenDropoutRatios: Array[Double] = Array(0.0)

    /** GLRM: impute the original */
    val imputeOriginal: Boolean = false

    /** NN: input dropout ratio */
    val inputDropoutRatio: Double = 0.0

    /** GLRM Initialization
      * {{{Options: GlrmInitialization.initialization
      * --PlusPlus
      * --Random
      * --SVD
      * --Power
      * }}}
      */
    val initGLRM: GlrmInitialization = GlrmInitialization.PlusPlus

    /** NN: Initialization
      * {{{Options:
      * --UniformAdaptive: optimized initialization based on the size of the network;
      * --Uniform: zero mean with a parameterized interval (-initialWeightScale, initialWeightScale)
      * --Normal: zero mean with a parameterized standard deviation N(0, initialWeightScale^2)
      * }}}
      */
    val initialWeightDistribution: InitialWeightDistribution = InitialWeightDistribution.UniformAdaptive

    /** NN: scale used by the Uniform and Normal initial weight distributions */
    val initialWeightScale: Double = 1.0

    /** KMM Initialization
      * {{{Options: KMeansInitialization.initialization
      * --Random
      * --PlusPlus
      * --Furthest
      * --User
      * }}}
      */
    val initKMM: KMeansInitialization = KMeansInitialization.Furthest

    /** W2V: init learning rate */
    val initLearningRate: Float = 0.025F

    /** GLRM: initial step size */
    val initStepSize: Double = 1.0

    /** GLRM: minimum step size */
    val minStepSize: Double = 1e-4

    /** GLM */
    val intercept: Boolean = true

    /** KMM, GLRM: dimensionality (cluster) */
    val k: Int = 2

    /** NN: l1 regularizer */
    val l1: Double = 0.0

    /** NN: l2 regularizer */
    val l2: Double = 0.0

    /** GLM, XGB */
    val lambda: Double = 1.0

    /** GLM: use for GD with lambda as lambda min */
    val lambdaSearch: Boolean = false

    /** NB: Laplace smoothing */
    val laplace: Double = 0.0

    /** NN: loss
      * {{{Options: Loss.loss
      * --Automatic
      * --Classification: Quadratic, ModifiedHuber, CrossEntropy
      * --Regression: Absolute, Quadratic, Huber, Quantile
      * }}}
      */
    val loss: Loss = Loss.Automatic

    /** GLRM: <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-genmodel/src/main/java/hex/genmodel/algos/glrm/GlrmLoss.java">Loss</a>
      * for numeric features
      * {{{Options: GlrmLoss.glrmloss
      * --Numeric features
      * ----Quadratic
      * ----Absolute
      * ----Huber
      * ----Poisson
      * ----Periodic
      * --Binary features
      * ----Logistic
      * ----Hinge
      * --Multinomial features
      * ----Categorical
      * ----Ordinal
      * }}}
      */
    val lossGLRMNum: GlrmLoss = GlrmLoss.Quadratic

    /** GLRM: loss for categorical features; options are the same as for lossGLRMNum */
    val lossGLRMCat: GlrmLoss = GlrmLoss.Categorical

    /** DRF, GBM, XGB */
    val maxDepth: Int = 5

    /** KMM, GLRM: max iterations (max iterations for Lloyds) */
    val maxIterations: Int = 10

    /** 0 to disable */
    val maxModels: Int = 0

    /** 0 to disable */
    val maxRuntimeSecs: Double = 0.0

    /** NN: maximum on the sum of the squared incoming weights into any one neuron */
    val maxW2: Float = Float.MaxValue

    /** SVM: mini batch fraction */
    val miniBatchFraction: Double = 0.1

    /** NN: mini batch size */
    val miniBatchSize: Int = 1

    /** W2V: min word freq */
    val minWordFreq: Int = 5

    /** SVM: MissingValuesHandling */
    val missingValuesHandling: MissingValuesHandling = MissingValuesHandling.Skip

    /** NN: momentum at the start; momentum is only active if the adaptive learning rate is disabled */
    val momentumStart: Double = 0.0

    /** NN: assuming momentumStable > momentumStart, controls the number of instances
      * needed to reach the stable momentum
      */
    val momentumRamp: Double = 1e6

    /** NN: stable momentum */
    val momentumStable: Double = 0.0

    /** DRF, GBM, XGB (only for Dart): minimal number of bins for numerical features */
    val nbins: Int = 20

    /** NN: Nesterov accelerated gradient descent */
    val nesterovAcceleratedGradient: Boolean = true

    /** W2V: NormModel
      * {{{
      * Options: NormModel.normModel
      * --HSM
      * }}}
      */
    val normModel: NormModel = NormModel.HSM

    /** DRF, GBM, XGB */
    val ntrees: Int = 5

    /** GLM */
    val objectiveEpsilon: Double = 1e-6

    /** NN: learning rate, only active if adaptiveRate is disabled */
    val rate: Double = 0.005

    /** NN: learning rate annealing, only active if adaptiveRate is disabled
      * <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-algos/src/main/java/hex/deeplearning/DeepLearningModel.java#L1454">learning rate annealing</a>
      * The annealing rate is the inverse of the number of training samples for halving the learning rate,
      * i.e. with 1e-6 it takes 1e6 training samples to halve the learning rate
      */
    val rateAnnealing: Double = 1e-6

    /** NN: rate decay, only active if adaptiveRate is disabled */
    val rateDecay: Double = 1.0

    /** GLRM: recover SVD */
    val recoverSVD: Boolean = false

    /** XGB: L1 */
    val regAlpha: Float = 0.0F

    /** XGB: L2 */
    val regLambda: Float = 0.0F

    /** SVM: regParam */
    val regParam: Double = 0.1

    /** NN: early stop based on the regression error (MSE) threshold */
    val regressionStop: Double = 1e-6

    /** GLRM: regularizer on X
      * {{{Options: GlrmRegularizer.glrmRegularizer
      * --None
      * --Quadratic
      * --L2
      * --L1
      * --NonNegative
      * --OneSparse
      * --UnitOneSparse
      * --Simplex
      * }}}
      */
    val regularizationX: GlrmRegularizer = GlrmRegularizer.None

    /** GLRM: regularizer on Y; options are the same as for regularizationX */
    val regularizationY: GlrmRegularizer = GlrmRegularizer.None

    /** NN: typical values are between 0.9 and 0.999 */
    val rho: Double = 0.99

    /** NN: number of training samples for scoring, 0 for all */
    val scoreTrainingSamples: Long = 0L

    /** NN: number of validation samples for scoring, 0 for all */
    val scoreValidationSamples: Long = 0L

    /** NN: score validation sampling method
      * {{{Options: ClassSamplingMethod.classSamplingMethod
      * --Uniform
      * --Stratified
      * }}}
      */
    val scoreValidationSampling: ClassSamplingMethod = ClassSamplingMethod.Uniform

    /** W2V: sent sample rate */
    val sentSampleRate: Float = 1e-3F

    /** GLM
      * <a href = "http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/algo-params/solver.html">GLM Solver</a>
      * {{{Options: Solver.solver
      * --IRLSM: Iteratively Reweighted Least Squares Method
      * --L_BFGS: Limited-memory Broyden-Fletcher-Goldfarb-Shanno algorithm
      * --COORDINATE_DESCENT: Coordinate Descent
      * --COORDINATE_DESCENT_NAIVE: Coordinate Descent Naive
      * --AUTO: Sets the solver based on given data and parameters (default)
      * --GRADIENT_DESCENT_LH: Gradient Descent Likelihood (available for Ordinal family only; default for Ordinal family)
      * --GRADIENT_DESCENT_SQERR: Gradient Descent Squared Error (available for Ordinal family only)
      * Guidelines:
      * --L_BFGS works much better for L2-only multinomial and if you have too many active predictors.
      * --You must use IRLSM if you have p-values.
      * --IRLSM and COORDINATE_DESCENT share the same path (i.e., they both compute the same gram matrix), they just solve it differently.
      * --Use COORDINATE_DESCENT if you have less than 5000 predictors and L1 penalty.
      * --COORDINATE_DESCENT performs better when lambda_search is enabled. Also with bounds, it tends to get a higher accuracy.
      * --Use GRADIENT_DESCENT_LH or GRADIENT_DESCENT_SQERR when family=ordinal. With GRADIENT_DESCENT_LH, the model parameters are adjusted by minimizing the loss function; with GRADIENT_DESCENT_SQERR, the model parameters are adjusted using the loss function.
      * }}}
      */
    val solver: Solver = Solver.AUTO

    /** GLM, NN, KMM (standardize columns) */
    val standardize: Boolean = true

    /** SVM */
    val stepSize: Double = 1e-3

    /** All models: stopping rounds, 0 to disable */
    val stoppingRounds: Int = 0

    /** stopping metric: {AUTO, deviance, logloss, MSE, RMSE,MAE,RMSLE, AUC, lift_top_group, misclassification, mean_per_class_error, custom, r2}
      * <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-core/src/main/java/hex/ScoreKeeper.java">Stopping Metric</a>
      * mean_per_class_error --> average recall
      */
    val stoppingMetric: StoppingMetric = StoppingMetric.AUTO

    /** stopping tolerance */
    val stoppingTolerance: Double = 1e-4

    /** GLRM: SVD method */
    val svdMethod: SVDMethod = SVDMethod.Randomized

    /** SVM: threshold to separate positive and negative class */
    val threshold: Double = 0.0

    /** NN: number of training samples per iteration (if using N nodes, each node will get 1/N samples)
      * {{{Options:
      * --  0: one epoch per iteration,
      * -- -1: the maximum amount of data per iteration (if **replicate training data** is enabled, N epochs will be trained per iteration on N nodes, otherwise one epoch).
      * -- -2: automatic mode (auto-tuning)
      * }}}
      */
    val trainSamplesPerIteration: Long = -2L

    /** XGB
      * {{{Options:
      * --auto
      * --exact
      * --approx
      * --hist
      * }}}
      */
    val treeMethod: TreeMethod = TreeMethod.auto

    /** SVM: updater
      * Options: L1, L2, Simple
      */
    val updater: Updater = Updater.L2

    /** W2V: vec size */
    val vecSize: Int = 100

    /** W2V: window size */
    val windowSize: Int = 5

    /** W2V: WordModel
      * {{{
      * Options: WordModel.wordModel
      * --SkipGram
      * }}}
      */
    val wordModel: WordModel = WordModel.SkipGram
  }


  /** Map a ParamType to its string name.
    *
    * Each name is the snake_case form of the ParamType identifier prefixed with an
    * underscore (presumably the field name on the underlying model-parameter class;
    * verify against the caller). Two exceptions: `loss` and `lossGLRMNum` both map to
    * "_loss", and `lossGLRMCat` maps to "_multi_loss".
    *
    * `autoencoder`, `maxModels` and `maxRuntimeSecs` used to have no case here, so
    * passing them raised a `scala.MatchError`; they are now mapped following the same
    * naming convention.
    *
    * NOTE(review): H2O's GLRM and KMeans parameter classes appear to call their
    * initialization field `_init`, whereas this function returns "_init_glrm" and
    * "_init_kmm"; confirm how callers consume these two strings before relying on them.
    *
    * @param paramType the hyperparameter identifier
    * @return the string name associated with `paramType`
    */
  def mapParamType2Str(paramType: ParamType.Value): String = paramType match {
    case ParamType.activation                  => "_activation"
    case ParamType.adaptiveRate                => "_adaptive_rate"
    case ParamType.addIntercept                => "_add_intercept"
    case ParamType.alpha                       => "_alpha"
    case ParamType.autoencoder                 => "_autoencoder" // previously unmatched
    case ParamType.balanceClasses              => "_balance_classes"
    case ParamType.backend                     => "_backend"
    case ParamType.betaEpsilon                 => "_beta_epsilon"
    case ParamType.booster                     => "_booster"
    case ParamType.classificationStop          => "_classification_stop"
    case ParamType.convergenceTol              => "_convergence_tol"
    case ParamType.dartNormalizeType           => "_dart_normalize_type"
    case ParamType.elasticAveraging            => "_elastic_averaging"
    case ParamType.epochs                      => "_epochs"
    case ParamType.epsilon                     => "_epsilon"
    case ParamType.estimateK                   => "_estimate_k"
    case ParamType.family                      => "_family"
    case ParamType.gammaX                      => "_gamma_x"
    case ParamType.gammaY                      => "_gamma_y"
    case ParamType.gradient                    => "_gradient"
    case ParamType.gradientEpsilon             => "_gradient_epsilon"
    case ParamType.growPolicy                  => "_grow_policy"
    case ParamType.hidden                      => "_hidden"
    case ParamType.hiddenDropoutRatios         => "_hidden_dropout_ratios"
    case ParamType.imputeOriginal              => "_impute_original"
    case ParamType.inputDropoutRatio           => "_input_dropout_ratio"
    case ParamType.initGLRM                    => "_init_glrm"
    case ParamType.initialWeightDistribution   => "_initial_weight_distribution"
    case ParamType.initialWeightScale          => "_initial_weight_scale"
    case ParamType.initKMM                     => "_init_kmm"
    case ParamType.initLearningRate            => "_init_learning_rate"
    case ParamType.initStepSize                => "_init_step_size"
    case ParamType.minStepSize                 => "_min_step_size"
    case ParamType.intercept                   => "_intercept"
    case ParamType.k                           => "_k"
    case ParamType.l1                          => "_l1"
    case ParamType.l2                          => "_l2"
    case ParamType.lambda                      => "_lambda"
    case ParamType.lambdaSearch                => "_lambda_search"
    case ParamType.laplace                     => "_laplace"
    // The NN loss and the GLRM numeric loss share one name.
    case ParamType.loss | ParamType.lossGLRMNum => "_loss"
    case ParamType.lossGLRMCat                 => "_multi_loss"
    case ParamType.maxDepth                    => "_max_depth"
    case ParamType.maxIterations               => "_max_iterations"
    case ParamType.maxModels                   => "_max_models" // previously unmatched
    case ParamType.maxRuntimeSecs              => "_max_runtime_secs" // previously unmatched
    case ParamType.maxW2                       => "_max_w2"
    case ParamType.miniBatchFraction           => "_mini_batch_fraction"
    case ParamType.miniBatchSize               => "_mini_batch_size"
    case ParamType.minWordFreq                 => "_min_word_freq"
    case ParamType.missingValuesHandling       => "_missing_values_handling"
    case ParamType.momentumStart               => "_momentum_start"
    case ParamType.momentumRamp                => "_momentum_ramp"
    case ParamType.momentumStable              => "_momentum_stable"
    case ParamType.nbins                       => "_nbins"
    case ParamType.nesterovAcceleratedGradient => "_nesterov_accelerated_gradient"
    case ParamType.normModel                   => "_norm_model"
    case ParamType.ntrees                      => "_ntrees"
    case ParamType.objectiveEpsilon            => "_objective_epsilon"
    case ParamType.rate                        => "_rate"
    case ParamType.rateAnnealing               => "_rate_annealing"
    case ParamType.rateDecay                   => "_rate_decay"
    case ParamType.recoverSVD                  => "_recover_svd"
    case ParamType.regAlpha                    => "_reg_alpha"
    case ParamType.regLambda                   => "_reg_lambda"
    case ParamType.regParam                    => "_reg_param"
    case ParamType.regularizationX             => "_regularization_x"
    case ParamType.regularizationY             => "_regularization_y"
    case ParamType.regressionStop              => "_regression_stop"
    case ParamType.rho                         => "_rho"
    case ParamType.scoreTrainingSamples        => "_score_training_samples"
    case ParamType.scoreValidationSamples      => "_score_validation_samples"
    case ParamType.scoreValidationSampling     => "_score_validation_sampling"
    case ParamType.sentSampleRate              => "_sent_sample_rate"
    case ParamType.solver                      => "_solver"
    case ParamType.standardize                 => "_standardize"
    case ParamType.stepSize                    => "_step_size"
    case ParamType.stoppingRounds              => "_stopping_rounds"
    case ParamType.stoppingMetric              => "_stopping_metric"
    case ParamType.stoppingTolerance           => "_stopping_tolerance"
    case ParamType.svdMethod                   => "_svd_method"
    case ParamType.threshold                   => "_threshold"
    case ParamType.trainSamplesPerIteration    => "_train_samples_per_iteration"
    case ParamType.treeMethod                  => "_tree_method"
    case ParamType.updater                     => "_updater"
    case ParamType.vecSize                     => "_vec_size"
    case ParamType.windowSize                  => "_window_size"
    case ParamType.wordModel                   => "_word_model"
  }
}
