package ai.minxiao.ds4s.core.h2o.learning

import java.net.URI
import scala.collection.JavaConversions._
import scala.annotation._
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import hex.glm.GLMModel.GLMParameters
import hex.glrm.GLRMModel.GLRMParameters
import hex.grid.{Grid, GridSearch, HyperSpaceSearchCriteria, HyperSpaceWalker}
import hex.grid.GridSearch.SimpleParametersBuilderFactory
import hex.grid.HyperSpaceSearchCriteria.{CartesianSearchCriteria, RandomDiscreteValueSearchCriteria, Strategy}
import hex.kmeans.KMeansModel.KMeansParameters
import hex.{Model, ModelParametersBuilderFactory}
import hex.Model.{Parameters, Output}
import hex.naivebayes.NaiveBayesModel.NaiveBayesParameters
import hex.ScoreKeeper.StoppingMetric
import hex.tree.drf.DRFModel.DRFParameters
import hex.tree.gbm.GBMModel.GBMParameters
import hex.tree.xgboost.XGBoostModel.XGBoostParameters
import hex.word2vec.Word2VecModel.{Word2VecParameters}
import org.apache.spark.h2o.H2OContext
import org.apache.spark.ml.spark.models.svm.SVMParameters
import water.fvec.H2OFrame
import water.Key
import water.support.H2OFrameSupport._
import water.support.{ModelSerializationSupport}

/** Grid search over the hyperparameters of an H2O model.
  *
  * @constructor create a grid searcher for one model type
  * @param modelType model type; selects the H2O parameter class used when the search is started
  * @param gsParams hyperparameter grid: each parameter type mapped to the candidate values to try
  * @param gsStrategy search strategy; `Strategy.RandomDiscrete` uses a
  *   `RandomDiscreteValueSearchCriteria`, any other value uses a `CartesianSearchCriteria`
  * @param rndStopCriteria stopping criteria, read only for `Strategy.RandomDiscrete`; the keys
  *   looked up are `stoppingRounds`, `stoppingMetric`, `stoppingTolerance`, `maxRuntimeSecs`
  *   and `maxModels`
  * @author mx
  */
@SerialVersionUID(727183L)
class H2OGridSearcher(modelType: ModelType.Value, gsParams: Map[ParamType.Value, Array[Any]],
  gsStrategy: Strategy = Strategy.Cartesian,
  rndStopCriteria: Map[ParamType.Value, Any] = Map.empty)
    extends H2OLearner(modelType) with Serializable {

  // grid search hyperparameters, keyed by parameter name and boxed for the Java grid-search API
  private val _gsParams: Map[String, Array[AnyRef]] = gsParams.map{
    case (k, v) => mapParamType2Str(k) -> v.map(_.asInstanceOf[AnyRef])
  }

  /** Reads an integer-valued random-search criterion, if present.
    *
    * Besides `Int`, integral values that fit into an `Int` are accepted.
    * @throws ClassCastException if the stored value is not such a number
    */
  private def intCriterion(param: ParamType.Value): Option[Int] =
    rndStopCriteria.get(param).map {
      case i: Int => i
      case s: Short => s.toInt
      case b: Byte => b.toInt
      case l: Long if l.isValidInt => l.toInt
      case other => throw new ClassCastException(
        s"rndStopCriteria(${param}) must be an Int, but was: ${other}")
    }

  /** Reads a floating-point random-search criterion, if present.
    *
    * Any numeric value is widened to `Double`, so e.g. `maxRuntimeSecs -> 60` works.
    * @throws ClassCastException if the stored value is not a number
    */
  private def doubleCriterion(param: ParamType.Value): Option[Double] =
    rndStopCriteria.get(param).map {
      case n: java.lang.Number => n.doubleValue
      case other => throw new ClassCastException(
        s"rndStopCriteria(${param}) must be numeric, but was: ${other}")
    }

  /** Runs the grid search and returns the resulting grid.
    *
    * A model type outside the ones dispatched below results in a `scala.MatchError`.
    *
    * @param train h2oframe for training
    * @param respCol response column
    * @param valid h2oframe for validation (optional)
    * @param featCols feature columns; used to derive the ignored columns when `ignCols` is not given
    * @param ignCols columns to ignore; takes precedence over `featCols`
    * @param weigCol weight column (optional)
    * @param mdKey name from which the key of the grid is made
    * @return the grid obtained from the started grid-search job
    */
  def search(train: H2OFrame, respCol: String,
    valid: Option[H2OFrame] = None,
    featCols: Option[Array[String]] = None,
    ignCols: Option[Array[String]] = None,
    weigCol: Option[String] = None,
    mdKey: String = "model")(implicit h2oContext: H2OContext): Grid[_ <: Parameters] = {

    // model parameters: adding data frames, features, response
    val mdParams = _mdParams(gsParams.keySet)
    mdParams._train = train._key
    mdParams._response_column = respCol
    valid.foreach { v => mdParams._valid = v._key }

    // Explicit ignore list wins; otherwise ignore everything that is neither a feature,
    // the response, nor the weight column (the weight column must stay in the frame,
    // otherwise it would be ignored and referenced as weights at the same time).
    val ignored: Option[Array[String]] = ignCols.orElse(featCols.map { feats =>
      val keep = feats ++ (respCol :: weigCol.toList)
      train.names.diff(keep)
    })
    ignored.foreach { cols => mdParams._ignored_columns = cols }
    weigCol.foreach { w => mdParams._weights_column = w }

    // searching space criteria
    val gsCriteria: HyperSpaceSearchCriteria = gsStrategy match {
      case Strategy.RandomDiscrete =>
        val criteria = new RandomDiscreteValueSearchCriteria
        // only criteria that were actually supplied are set; the rest keep H2O's defaults
        intCriterion(ParamType.stoppingRounds)
          .foreach(v => criteria.set_stopping_rounds(v))
        rndStopCriteria.get(ParamType.stoppingMetric)
          .foreach(v => criteria.set_stopping_metric(v.asInstanceOf[StoppingMetric]))
        doubleCriterion(ParamType.stoppingTolerance)
          .foreach(v => criteria.set_stopping_tolerance(v))
        doubleCriterion(ParamType.maxRuntimeSecs)
          .foreach(v => criteria.set_max_runtime_secs(v))
        intCriterion(ParamType.maxModels)
          .foreach(v => criteria.set_max_models(v))
        criteria
      case _ /*Cartesian*/ => new CartesianSearchCriteria
    }

    val key: Key[Grid[_ <: Parameters]] = Key.make(mdKey)

    // Each branch casts the generic parameters to the concrete class of the model type and
    // pairs it with a builder factory of the same type; `get` then yields the job's grid.
    { modelType match {
      case ModelType.NB => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[NaiveBayesParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[NaiveBayesParameters],
          gsCriteria
        )
      case ModelType.GLM => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[GLMParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[GLMParameters],
          gsCriteria
        )
      case ModelType.NN => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[DeepLearningParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[DeepLearningParameters],
          gsCriteria
        )
      case ModelType.DRF => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[DRFParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[DRFParameters],
          gsCriteria
        )
      case ModelType.GBM => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[GBMParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[GBMParameters],
          gsCriteria
        )
      case ModelType.XGB => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[XGBoostParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[XGBoostParameters],
          gsCriteria
        )
      case ModelType.KMM => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[KMeansParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[KMeansParameters],
          gsCriteria
        )
      case ModelType.GLRM => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[GLRMParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[GLRMParameters],
          gsCriteria
        )
      case ModelType.W2V => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[Word2VecParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[Word2VecParameters],
          gsCriteria
        )
      case ModelType.SVM => GridSearch.startGridSearch(key,
          mdParams.asInstanceOf[SVMParameters],
          _gsParams,
          new SimpleParametersBuilderFactory[SVMParameters],
          gsCriteria
        )
      }
    }.get
  }

  /** Runs the grid search and exports every model of the resulting grid.
    *
    * Each model is written to `<mdSavePath>/<model key>.<suffix>`, where the suffix is
    * `zip` for "mojo", `java` for "pojo" and `model` for "model".
    *
    * @param train h2oframe for training
    * @param respCol response column
    * @param valid h2oframe for validation (optional)
    * @param featCols feature columns; used to derive the ignored columns when `ignCols` is not given
    * @param ignCols columns to ignore; takes precedence over `featCols`
    * @param weigCol weight column (optional)
    * @param mdKey name from which the key of the grid is made
    * @param mdSaveType export format, one of "mojo", "pojo", "model"
    * @param mdSavePath directory (URI prefix) the models are exported to
    * @param force passed through to the export call
    * @return the grid whose models were exported
    * @throws IllegalArgumentException if `mdSaveType` is not one of the supported formats
    */
  def searchAndSave(train: H2OFrame, respCol: String,
    valid: Option[H2OFrame] = None,
    featCols: Option[Array[String]] = None,
    ignCols: Option[Array[String]] = None,
    weigCol: Option[String] = None,
    mdKey: String = "",
    mdSaveType: String = "mojo",
    mdSavePath: String = "",
    force: Boolean = false)(implicit h2oContext: H2OContext): Grid[_ <: Parameters] = {

    require(Array("mojo", "pojo", "model").contains(mdSaveType),
      s"mdSaveType must be one of mojo, pojo, model, but was: ${mdSaveType}")

    val mdSuffix = mdSaveType match {
      case "mojo" => "zip"
      case "pojo" => "java"
      case _ => "model"
    }

    // A trailing slash would yield "//<name>", whose first segment URI.create may parse
    // as an authority instead of a path; drop it before joining.
    val saveDir = mdSavePath.stripSuffix("/")

    val fittedGS = search(train, respCol,
      valid,
      featCols, ignCols, weigCol,
      mdKey)(h2oContext)

    fittedGS.getModels.foreach { md =>
      val mdUrl = URI.create(s"${saveDir}/${md._key}.${mdSuffix}")
      mdSaveType match {
        case "mojo" =>
          ModelSerializationSupport.exportMOJOModel(md, mdUrl, force)
        case "pojo" =>
          ModelSerializationSupport.exportPOJOModel(md, mdUrl, force)
        case _ =>
          ModelSerializationSupport.exportH2OModel(md, mdUrl, force)
      }
    }
    fittedGS
  }

}
