package ai.minxiao.ds4s.core.h2o.learning

import java.net.URI
import scala.collection.mutable.{Map => MMap}
import scala.annotation._

import hex.ConfusionMatrix
import hex.deeplearning.{DeepLearning, DeepLearningModel}
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import hex.deeplearning.DeepLearningModel.DeepLearningParameters.{Activation, ClassSamplingMethod, InitialWeightDistribution, Loss}
import hex.genmodel.algos.glrm.{GlrmInitialization, GlrmLoss, GlrmRegularizer}
import hex.glm.{GLM, GLMModel}
import hex.glm.GLMModel.GLMParameters
import hex.glm.GLMModel.GLMParameters.{Family, Solver}
import hex.glrm.{GLRM, GLRMModel}
import hex.glrm.GLRMModel.GLRMParameters
import hex.kmeans.{KMeans, KMeansModel}
import hex.kmeans.KMeans.{Initialization => KMeansInitialization}
import hex.kmeans.KMeansModel.KMeansParameters
import hex.Model
import hex.Model.{Parameters, Output}
import hex.naivebayes.{NaiveBayes, NaiveBayesModel}
import hex.naivebayes.NaiveBayesModel.NaiveBayesParameters
import hex.ScoreKeeper.StoppingMetric
import hex.svd.SVDModel.SVDParameters.{Method => SVDMethod}
import hex.tree.drf.{DRF, DRFModel}
import hex.tree.drf.DRFModel.DRFParameters
import hex.tree.gbm.{GBM, GBMModel}
import hex.tree.gbm.GBMModel.GBMParameters
import hex.tree.xgboost.{XGBoost, XGBoostModel}
import hex.tree.xgboost.XGBoostModel.XGBoostParameters
import hex.tree.xgboost.XGBoostModel.XGBoostParameters.{Backend, Booster, GrowPolicy, TreeMethod, DartNormalizeType}
import hex.word2vec.{Word2Vec, Word2VecModel}
import hex.word2vec.Word2Vec.{WordModel, NormModel}
import hex.word2vec.Word2VecModel.{Word2VecParameters}
import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.spark.models.MissingValuesHandling
import org.apache.spark.ml.spark.models.svm.{SVM, SVMModel, SVMParameters} // from h2o's
import org.apache.spark.ml.spark.models.svm.{Gradient, Updater}
import water.fvec.H2OFrame
import water.Key
import water.support.H2OFrameSupport._
import water.support.{ModelSerializationSupport}

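/**
  * H2OLearner
  *
  * Holds the hyperparameters for one H2O model type and trains the corresponding
  * model on an H2OFrame, optionally exporting the trained model afterwards.
  *
  * @param modelType the model type this learner configures and trains
  * @author mx
  */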
@SerialVersionUID(7276L)
class H2OLearner(modelType: ModelType.Value) extends Serializable {
  /**
    * Hyperparameters.
    *
    * Table of the hyperparameters available for the chosen `modelType`, keyed by
    * `ParamType` and seeded from the same-named `ParamValue` entries. It is built
    * on first access and mutated afterwards through `params(...)`.
    */
  private lazy val _params: MMap[ParamType.Value, Any] = modelType match {
    case ModelType.GLM =>
      MMap[ParamType.Value, Any](
        (ParamType.alpha, ParamValue.alpha),
        (ParamType.lambda, ParamValue.lambda),
        (ParamType.lambdaSearch, ParamValue.lambdaSearch),
        (ParamType.intercept, ParamValue.intercept),
        (ParamType.family, ParamValue.family),
        (ParamType.solver, ParamValue.solver),
        (ParamType.standardize, ParamValue.standardize),
        (ParamType.maxIterations, ParamValue.maxIterations),
        (ParamType.betaEpsilon, ParamValue.betaEpsilon),
        (ParamType.gradientEpsilon, ParamValue.gradientEpsilon),
        (ParamType.objectiveEpsilon, ParamValue.objectiveEpsilon)
      )

    case ModelType.NB =>
      MMap[ParamType.Value, Any](
        (ParamType.laplace, ParamValue.laplace)
      )

    // DRF and GBM expose the same set of keys
    case ModelType.DRF | ModelType.GBM =>
      MMap[ParamType.Value, Any](
        (ParamType.ntrees, ParamValue.ntrees),
        (ParamType.maxDepth, ParamValue.maxDepth),
        (ParamType.nbins, ParamValue.nbins),
        (ParamType.balanceClasses, ParamValue.balanceClasses),
        (ParamType.stoppingRounds, ParamValue.stoppingRounds),
        (ParamType.stoppingMetric, ParamValue.stoppingMetric),
        (ParamType.stoppingTolerance, ParamValue.stoppingTolerance)
      )

    case ModelType.XGB =>
      MMap[ParamType.Value, Any](
        (ParamType.ntrees, ParamValue.ntrees),
        (ParamType.maxDepth, ParamValue.maxDepth),
        (ParamType.nbins, ParamValue.nbins),
        (ParamType.balanceClasses, ParamValue.balanceClasses),
        (ParamType.regAlpha, ParamValue.regAlpha),
        (ParamType.regLambda, ParamValue.regLambda),
        (ParamType.backend, ParamValue.backend),
        (ParamType.booster, ParamValue.booster),
        (ParamType.dartNormalizeType, ParamValue.dartNormalizeType),
        (ParamType.growPolicy, ParamValue.growPolicy),
        (ParamType.treeMethod, ParamValue.treeMethod),
        (ParamType.stoppingRounds, ParamValue.stoppingRounds),
        (ParamType.stoppingMetric, ParamValue.stoppingMetric),
        (ParamType.stoppingTolerance, ParamValue.stoppingTolerance)
      )

    case ModelType.NN =>
      MMap[ParamType.Value, Any](
        (ParamType.balanceClasses, ParamValue.balanceClasses),
        (ParamType.hidden, ParamValue.hidden),
        (ParamType.activation, ParamValue.activation),
        (ParamType.epochs, ParamValue.epochs),
        (ParamType.standardize, ParamValue.standardize),
        (ParamType.miniBatchSize, ParamValue.miniBatchSize),
        (ParamType.trainSamplesPerIteration, ParamValue.trainSamplesPerIteration),
        (ParamType.adaptiveRate, ParamValue.adaptiveRate),
        (ParamType.rho, ParamValue.rho),
        (ParamType.epsilon, ParamValue.epsilon),
        (ParamType.rate, ParamValue.rate),
        (ParamType.rateAnnealing, ParamValue.rateAnnealing),
        (ParamType.rateDecay, ParamValue.rateDecay),
        (ParamType.momentumStart, ParamValue.momentumStart),
        (ParamType.momentumRamp, ParamValue.momentumRamp),
        (ParamType.momentumStable, ParamValue.momentumStable),
        (ParamType.nesterovAcceleratedGradient, ParamValue.nesterovAcceleratedGradient),
        (ParamType.inputDropoutRatio, ParamValue.inputDropoutRatio),
        (ParamType.hiddenDropoutRatios, ParamValue.hiddenDropoutRatios),
        (ParamType.l1, ParamValue.l1),
        (ParamType.l2, ParamValue.l2),
        (ParamType.maxW2, ParamValue.maxW2),
        (ParamType.initialWeightDistribution, ParamValue.initialWeightDistribution),
        (ParamType.initialWeightScale, ParamValue.initialWeightScale),
        (ParamType.loss, ParamValue.loss),
        (ParamType.scoreTrainingSamples, ParamValue.scoreTrainingSamples),
        (ParamType.scoreValidationSamples, ParamValue.scoreValidationSamples),
        (ParamType.scoreValidationSampling, ParamValue.scoreValidationSampling),
        (ParamType.classificationStop, ParamValue.classificationStop),
        (ParamType.regressionStop, ParamValue.regressionStop),
        (ParamType.elasticAveraging, ParamValue.elasticAveraging),
        (ParamType.stoppingRounds, ParamValue.stoppingRounds),
        (ParamType.stoppingMetric, ParamValue.stoppingMetric),
        (ParamType.stoppingTolerance, ParamValue.stoppingTolerance),
        (ParamType.autoencoder, ParamValue.autoencoder)
      )

    case ModelType.KMM =>
      MMap[ParamType.Value, Any](
        (ParamType.k, ParamValue.k),
        (ParamType.maxIterations, ParamValue.maxIterations),
        (ParamType.initKMM, ParamValue.initKMM),
        (ParamType.standardize, ParamValue.standardize),
        (ParamType.estimateK, ParamValue.estimateK)
      )

    case ModelType.GLRM =>
      MMap[ParamType.Value, Any](
        (ParamType.k, ParamValue.k),
        (ParamType.maxIterations, ParamValue.maxIterations),
        (ParamType.initGLRM, ParamValue.initGLRM),
        (ParamType.svdMethod, ParamValue.svdMethod),
        (ParamType.lossGLRMNum, ParamValue.lossGLRMNum),
        (ParamType.lossGLRMCat, ParamValue.lossGLRMCat),
        (ParamType.regularizationX, ParamValue.regularizationX),
        (ParamType.regularizationY, ParamValue.regularizationY),
        (ParamType.gammaX, ParamValue.gammaX),
        (ParamType.gammaY, ParamValue.gammaY),
        (ParamType.initStepSize, ParamValue.initStepSize),
        (ParamType.minStepSize, ParamValue.minStepSize),
        (ParamType.recoverSVD, ParamValue.recoverSVD),
        (ParamType.imputeOriginal, ParamValue.imputeOriginal)
      )

    case ModelType.W2V =>
      MMap[ParamType.Value, Any](
        (ParamType.wordModel, ParamValue.wordModel),
        (ParamType.normModel, ParamValue.normModel),
        (ParamType.vecSize, ParamValue.vecSize),
        (ParamType.windowSize, ParamValue.windowSize),
        (ParamType.minWordFreq, ParamValue.minWordFreq),
        (ParamType.sentSampleRate, ParamValue.sentSampleRate),
        (ParamType.epochs, ParamValue.epochs),
        (ParamType.initLearningRate, ParamValue.initLearningRate)
      )

    case ModelType.SVM =>
      MMap[ParamType.Value, Any](
        (ParamType.addIntercept, ParamValue.addIntercept),
        (ParamType.convergenceTol, ParamValue.convergenceTol),
        (ParamType.gradient, ParamValue.gradient),
        (ParamType.maxIterations, ParamValue.maxIterations),
        (ParamType.miniBatchFraction, ParamValue.miniBatchFraction),
        (ParamType.missingValuesHandling, ParamValue.missingValuesHandling),
        (ParamType.regParam, ParamValue.regParam),
        (ParamType.stepSize, ParamValue.stepSize),
        (ParamType.threshold, ParamValue.threshold),
        (ParamType.updater, ParamValue.updater)
      )
  }
  /** Returns the live hyperparameter table; changes made to it are seen by later fits. */
  def params: MMap[ParamType.Value, Any] = _params

  /**
    * Adds or overrides hyperparameters.
    *
    * @param params entries to write into the hyperparameter table
    * @return this learner, so calls can be chained
    */
  def params(params: Map[ParamType.Value, Any]): this.type = {
    _params ++= params
    this
  }
  /**
    * Builds the H2O parameter object for `modelType`, filled from the values held in `_params`.
    *
    * Every value is cast to the type of the H2O field it is written to, so an entry of
    * the wrong runtime type fails with a `ClassCastException`.
    *
    * @param paramTypeSet excluded key set: a hyperparameter listed here is not copied, which
    *                     leaves the corresponding H2O field at whatever the freshly
    *                     constructed parameter object holds
    * @return the populated, model-specific parameter object
    */
  protected def _mdParams(paramTypeSet: Set[ParamType.Value] = Set()): Parameters = {
    // Writes one hyperparameter through `write` unless its key has been excluded.
    def assign[T](key: ParamType.Value)(write: T => Unit): Unit =
      if (!paramTypeSet.contains(key)) write(_params(key).asInstanceOf[T])

    modelType match {
      case ModelType.GLM =>
        val mdParams = new GLMParameters
        // architecture params; alpha and lambda are single values wrapped into arrays
        assign[Double](ParamType.alpha)(v => mdParams._alpha = Array[Double](v))
        assign[Double](ParamType.lambda)(v => mdParams._lambda = Array[Double](v))
        assign[Boolean](ParamType.lambdaSearch)(v => mdParams._lambda_search = v)
        assign[Boolean](ParamType.intercept)(v => mdParams._intercept = v)
        assign[Family](ParamType.family)(v => mdParams._family = v)
        assign[Solver](ParamType.solver)(v => mdParams._solver = v)
        assign[Boolean](ParamType.standardize)(v => mdParams._standardize = v)
        assign[Int](ParamType.maxIterations)(v => mdParams._max_iterations = v)
        assign[Double](ParamType.objectiveEpsilon)(v => mdParams._objective_epsilon = v)
        assign[Double](ParamType.betaEpsilon)(v => mdParams._beta_epsilon = v)
        assign[Double](ParamType.gradientEpsilon)(v => mdParams._gradient_epsilon = v)
        mdParams

      case ModelType.NB =>
        val mdParams = new NaiveBayesParameters
        assign[Double](ParamType.laplace)(v => mdParams._laplace = v)
        mdParams

      case ModelType.DRF =>
        val mdParams = new DRFParameters
        assign[Int](ParamType.ntrees)(v => mdParams._ntrees = v)
        assign[Int](ParamType.maxDepth)(v => mdParams._max_depth = v)
        assign[Int](ParamType.nbins)(v => mdParams._nbins = v)
        assign[Boolean](ParamType.balanceClasses)(v => mdParams._balance_classes = v)
        // stopping criteria now honour the exclusion set, as they do for GBM/XGB/NN
        assign[Int](ParamType.stoppingRounds)(v => mdParams._stopping_rounds = v)
        assign[StoppingMetric](ParamType.stoppingMetric)(v => mdParams._stopping_metric = v)
        assign[Double](ParamType.stoppingTolerance)(v => mdParams._stopping_tolerance = v)
        mdParams

      case ModelType.GBM =>
        val mdParams = new GBMParameters
        assign[Int](ParamType.ntrees)(v => mdParams._ntrees = v)
        assign[Int](ParamType.maxDepth)(v => mdParams._max_depth = v)
        assign[Int](ParamType.nbins)(v => mdParams._nbins = v)
        assign[Boolean](ParamType.balanceClasses)(v => mdParams._balance_classes = v)
        assign[Int](ParamType.stoppingRounds)(v => mdParams._stopping_rounds = v)
        assign[StoppingMetric](ParamType.stoppingMetric)(v => mdParams._stopping_metric = v)
        assign[Double](ParamType.stoppingTolerance)(v => mdParams._stopping_tolerance = v)
        mdParams

      case ModelType.XGB =>
        val mdParams = new XGBoostParameters
        assign[Int](ParamType.ntrees)(v => mdParams._ntrees = v)
        assign[Int](ParamType.maxDepth)(v => mdParams._max_depth = v)
        // the shared `nbins` key is written to XGBoost's `_max_bins` field
        assign[Int](ParamType.nbins)(v => mdParams._max_bins = v)
        assign[Boolean](ParamType.balanceClasses)(v => mdParams._balance_classes = v)
        assign[Float](ParamType.regAlpha)(v => mdParams._reg_alpha = v)
        assign[Float](ParamType.regLambda)(v => mdParams._reg_lambda = v)
        assign[TreeMethod](ParamType.treeMethod)(v => mdParams._tree_method = v)
        assign[GrowPolicy](ParamType.growPolicy)(v => mdParams._grow_policy = v)
        assign[Booster](ParamType.booster)(v => mdParams._booster = v)
        assign[DartNormalizeType](ParamType.dartNormalizeType)(v => mdParams._normalize_type = v)
        assign[Backend](ParamType.backend)(v => mdParams._backend = v)
        assign[Int](ParamType.stoppingRounds)(v => mdParams._stopping_rounds = v)
        assign[StoppingMetric](ParamType.stoppingMetric)(v => mdParams._stopping_metric = v)
        assign[Double](ParamType.stoppingTolerance)(v => mdParams._stopping_tolerance = v)
        mdParams

      case ModelType.NN =>
        val mdParams = new DeepLearningParameters
        assign[Boolean](ParamType.balanceClasses)(v => mdParams._balance_classes = v)
        assign[Activation](ParamType.activation)(v => mdParams._activation = v)
        assign[Array[Int]](ParamType.hidden)(v => mdParams._hidden = v)
        assign[Boolean](ParamType.standardize)(v => mdParams._standardize = v)
        assign[Int](ParamType.epochs)(v => mdParams._epochs = v)
        assign[Int](ParamType.miniBatchSize)(v => mdParams._mini_batch_size = v)
        assign[Long](ParamType.trainSamplesPerIteration)(v =>
          mdParams._train_samples_per_iteration = v)
        assign[Boolean](ParamType.adaptiveRate)(v => mdParams._adaptive_rate = v)
        assign[Double](ParamType.rho)(v => mdParams._rho = v)
        assign[Double](ParamType.epsilon)(v => mdParams._epsilon = v)
        assign[Double](ParamType.rate)(v => mdParams._rate = v)
        assign[Double](ParamType.rateAnnealing)(v => mdParams._rate_annealing = v)
        assign[Double](ParamType.rateDecay)(v => mdParams._rate_decay = v)
        assign[Double](ParamType.momentumStart)(v => mdParams._momentum_start = v)
        assign[Double](ParamType.momentumRamp)(v => mdParams._momentum_ramp = v)
        assign[Double](ParamType.momentumStable)(v => mdParams._momentum_stable = v)
        assign[Boolean](ParamType.nesterovAcceleratedGradient)(v =>
          mdParams._nesterov_accelerated_gradient = v)
        assign[Double](ParamType.inputDropoutRatio)(v => mdParams._input_dropout_ratio = v)
        // hidden dropout ratios are only copied for the *WithDropout activations; the check
        // reads the activation currently held by mdParams (set above unless excluded)
        val dropoutActivation =
          mdParams._activation == Activation.TanhWithDropout ||
          mdParams._activation == Activation.RectifierWithDropout ||
          mdParams._activation == Activation.MaxoutWithDropout
        if (dropoutActivation)
          assign[Array[Double]](ParamType.hiddenDropoutRatios)(v =>
            mdParams._hidden_dropout_ratios = v)
        assign[Double](ParamType.l1)(v => mdParams._l1 = v)
        assign[Double](ParamType.l2)(v => mdParams._l2 = v)
        assign[Float](ParamType.maxW2)(v => mdParams._max_w2 = v)
        assign[InitialWeightDistribution](ParamType.initialWeightDistribution)(v =>
          mdParams._initial_weight_distribution = v)
        assign[Double](ParamType.initialWeightScale)(v => mdParams._initial_weight_scale = v)
        assign[Loss](ParamType.loss)(v => mdParams._loss = v)
        assign[Long](ParamType.scoreTrainingSamples)(v => mdParams._score_training_samples = v)
        assign[Long](ParamType.scoreValidationSamples)(v =>
          mdParams._score_validation_samples = v)
        assign[ClassSamplingMethod](ParamType.scoreValidationSampling)(v =>
          mdParams._score_validation_sampling = v)
        assign[Double](ParamType.classificationStop)(v => mdParams._classification_stop = v)
        assign[Double](ParamType.regressionStop)(v => mdParams._regression_stop = v)
        assign[Boolean](ParamType.elasticAveraging)(v => mdParams._elastic_averaging = v)
        assign[Int](ParamType.stoppingRounds)(v => mdParams._stopping_rounds = v)
        assign[StoppingMetric](ParamType.stoppingMetric)(v => mdParams._stopping_metric = v)
        assign[Double](ParamType.stoppingTolerance)(v => mdParams._stopping_tolerance = v)
        assign[Boolean](ParamType.autoencoder)(v => mdParams._autoencoder = v)
        mdParams

      case ModelType.KMM =>
        val mdParams = new KMeansParameters
        assign[Int](ParamType.k)(v => mdParams._k = v)
        assign[Int](ParamType.maxIterations)(v => mdParams._max_iterations = v)
        assign[KMeansInitialization](ParamType.initKMM)(v => mdParams._init = v)
        assign[Boolean](ParamType.standardize)(v => mdParams._standardize = v)
        assign[Boolean](ParamType.estimateK)(v => mdParams._estimate_k = v)
        mdParams

      case ModelType.GLRM =>
        val mdParams = new GLRMParameters
        assign[Int](ParamType.k)(v => mdParams._k = v)
        assign[Int](ParamType.maxIterations)(v => mdParams._max_iterations = v)
        assign[GlrmInitialization](ParamType.initGLRM)(v => mdParams._init = v)
        assign[SVDMethod](ParamType.svdMethod)(v => mdParams._svd_method = v)
        assign[GlrmLoss](ParamType.lossGLRMNum)(v => mdParams._loss = v)
        assign[GlrmLoss](ParamType.lossGLRMCat)(v => mdParams._multi_loss = v)
        assign[GlrmRegularizer](ParamType.regularizationX)(v => mdParams._regularization_x = v)
        assign[Double](ParamType.gammaX)(v => mdParams._gamma_x = v)
        assign[GlrmRegularizer](ParamType.regularizationY)(v => mdParams._regularization_y = v)
        assign[Double](ParamType.gammaY)(v => mdParams._gamma_y = v)
        assign[Double](ParamType.initStepSize)(v => mdParams._init_step_size = v)
        assign[Double](ParamType.minStepSize)(v => mdParams._min_step_size = v)
        assign[Boolean](ParamType.recoverSVD)(v => mdParams._recover_svd = v)
        assign[Boolean](ParamType.imputeOriginal)(v => mdParams._impute_original = v)
        mdParams

      case ModelType.W2V =>
        val mdParams = new Word2VecParameters
        assign[WordModel](ParamType.wordModel)(v => mdParams._word_model = v)
        assign[NormModel](ParamType.normModel)(v => mdParams._norm_model = v)
        assign[Int](ParamType.vecSize)(v => mdParams._vec_size = v)
        assign[Int](ParamType.windowSize)(v => mdParams._window_size = v)
        assign[Int](ParamType.minWordFreq)(v => mdParams._min_word_freq = v)
        assign[Float](ParamType.sentSampleRate)(v => mdParams._sent_sample_rate = v)
        assign[Int](ParamType.epochs)(v => mdParams._epochs = v)
        assign[Float](ParamType.initLearningRate)(v => mdParams._init_learning_rate = v)
        mdParams

      case ModelType.SVM =>
        val mdParams = new SVMParameters
        assign[Boolean](ParamType.addIntercept)(v => mdParams._add_intercept = v)
        assign[Double](ParamType.convergenceTol)(v => mdParams._convergence_tol = v)
        assign[Gradient](ParamType.gradient)(v => mdParams._gradient = v)
        assign[Int](ParamType.maxIterations)(v => mdParams._max_iterations = v)
        assign[Double](ParamType.miniBatchFraction)(v => mdParams._mini_batch_fraction = v)
        assign[MissingValuesHandling](ParamType.missingValuesHandling)(v =>
          mdParams._missing_values_handling = v)
        assign[Double](ParamType.regParam)(v => mdParams._reg_param = v)
        assign[Double](ParamType.stepSize)(v => mdParams._step_size = v)
        assign[Double](ParamType.threshold)(v => mdParams._threshold = v)
        assign[Updater](ParamType.updater)(v => mdParams._updater = v)
        mdParams
    }
  }
  /** score curve */
  /** confusion matrix */
  /** variance importance */

  /** fit
    *
    * Trains the model selected by `modelType` on the given H2O frame, using the
    * hyperparameters currently held by this learner.
    *
    * @param train h2oframe for training
    * @param respCol response column
    * @param valid h2oframe for validation
    * @param featCols feature columns; only consulted when `ignCols` is not given, in which
    *                 case every training column other than the features, the response and
    *                 the weight column is ignored
    * @param ignCols ignore columns; takes precedence over `featCols`
    * @param weigCol weight column
    * @param mdKey key for the trained model. The KMM, GLRM and W2V builders are created from
    *              the parameters alone, so `mdKey` is not applied to those three model types.
    * @param h2oContext H2O context, handed to the SVM builder
    * @return the trained model
    */
  def fit(train: H2OFrame, respCol: String,
    valid: Option[H2OFrame] = None,
    featCols: Option[Array[String]] = None,
    ignCols: Option[Array[String]] = None,
    weigCol: Option[String] = None,
    mdKey: String = "model")(implicit h2oContext: H2OContext): Model[_ <: Model[_, _, _], _ <: Parameters, _ <: Output] = {

    val mdParams = _mdParams()
    mdParams._train = train._key
    mdParams._response_column = respCol
    valid.foreach(frame => mdParams._valid = frame._key)
    (ignCols, featCols) match {
      case (Some(ignored), _) =>
        mdParams._ignored_columns = ignored
      case (None, Some(features)) =>
        // keep the weight column out of the ignored list: it is not a feature, but it is
        // still needed because it is set as the weights column below
        val kept = (features :+ respCol) ++ weigCol.toSeq
        mdParams._ignored_columns = train.names.diff(kept)
      case (None, None) => // nothing to ignore
    }
    weigCol.foreach(col => mdParams._weights_column = col)

    modelType match {
      case ModelType.GLM =>
        val key: Key[GLMModel] = Key.make(mdKey)
        val model = new GLM(mdParams.asInstanceOf[GLMParameters], key)
        model.trainModel.get
      case ModelType.NB =>
        val model = new NaiveBayes(mdParams.asInstanceOf[NaiveBayesParameters]) {
          _result = Key.make[NaiveBayesModel](mdKey) // workaround to set the model key
        }
        model.trainModel.get
        // the trained model is read back from the DKV under the key set above
        water.DKV.get(mdKey).get[NaiveBayesModel]
      case ModelType.DRF =>
        val key: Key[DRFModel] = Key.make(mdKey)
        val model = new DRF(mdParams.asInstanceOf[DRFParameters], key)
        model.trainModel.get
      case ModelType.GBM =>
        val key: Key[GBMModel] = Key.make(mdKey)
        val model = new GBM(mdParams.asInstanceOf[GBMParameters], key)
        model.trainModel.get
      case ModelType.XGB =>
        val key: Key[XGBoostModel] = Key.make(mdKey)
        val model = new XGBoost(mdParams.asInstanceOf[XGBoostParameters], key)
        model.trainModel.get
      case ModelType.NN =>
        val key: Key[DeepLearningModel] = Key.make(mdKey)
        val model = new DeepLearning(mdParams.asInstanceOf[DeepLearningParameters], key)
        model.trainModel.get
      // NOTE(review): the three builders below take no key here, so mdKey is unused for them;
      // confirm whether the `_result` workaround used for NB/SVM is wanted for these as well.
      case ModelType.KMM =>
        val model = new KMeans(mdParams.asInstanceOf[KMeansParameters])
        model.trainModel.get
      case ModelType.GLRM =>
        val model = new GLRM(mdParams.asInstanceOf[GLRMParameters])
        model.trainModel.get
      case ModelType.W2V =>
        val model = new Word2Vec(mdParams.asInstanceOf[Word2VecParameters])
        model.trainModel.get
      case ModelType.SVM =>
        val model = new SVM(mdParams.asInstanceOf[SVMParameters], h2oContext) {
          _result = Key.make[SVMModel](mdKey) // same key workaround as for NB
        }
        model.trainModel.get
        water.DKV.get(mdKey).get[SVMModel]
    }
  }

  /** fitAndSave
    *
    * Trains a model exactly as `fit` does, then exports it to
    * `mdSavePath/mdKey.<suffix>`, where the suffix is `zip` for "mojo", `java` for
    * "pojo" and `model` for "model".
    *
    * @param train h2oframe for training
    * @param respCol response column
    * @param valid h2oframe for validation
    * @param featCols feature columns
    * @param ignCols ignore columns
    * @param weigCol weight column
    * @param mdKey model key passed to `fit`; also used as the exported file's base name
    * @param mdSaveType export format: "mojo", "pojo" or "model"
    * @param mdSavePath location the export is written under
    * @param force passed through to the exporter
    * @param h2oContext H2O context, passed on to `fit`
    * @return the trained model
    * @throws IllegalArgumentException if `mdSaveType` is not one of the supported formats
    */
  def fitAndSave(train: H2OFrame, respCol: String,
    valid: Option[H2OFrame] = None,
    featCols: Option[Array[String]] = None,
    ignCols: Option[Array[String]] = None,
    weigCol: Option[String] = None,
    mdKey: String = "",
    mdSaveType: String = "mojo",
    mdSavePath: String = "",
    force: Boolean = false)(implicit h2oContext: H2OContext): Model[_ <: Model[_, _, _], _ <: Parameters, _ <: Output] = {

    // validate before training so a bad format does not cost a full fit
    require(Array("mojo", "pojo", "model").contains(mdSaveType),
      s"mdSaveType must be one of mojo, pojo, model but was '${mdSaveType}'")

    val fittedModel = fit(train, respCol,
      valid,
      featCols, ignCols, weigCol,
      mdKey)(h2oContext)

    // export destination for a given file suffix
    def destination(mdSuffix: String): URI =
      URI.create(s"${mdSavePath}/${mdKey}.${mdSuffix}")

    mdSaveType match {
      case "mojo" =>
        ModelSerializationSupport.exportMOJOModel(fittedModel, destination("zip"), force)
      case "pojo" =>
        ModelSerializationSupport.exportPOJOModel(fittedModel, destination("java"), force)
      case _ => // "model": guaranteed by the require above
        ModelSerializationSupport.exportH2OModel(fittedModel, destination("model"), force)
    }
    fittedModel
  }

}
