package ai.minxiao.ds4s.core.h2o.learning

import java.io.{File, FileOutputStream}
import java.util.Date

import ai.h2o.automl.Algo
import ai.h2o.automl.AutoML
import ai.h2o.automl.AutoMLBuildSpec
import hex.ScoreKeeper.StoppingMetric
import water.fvec.H2OFrame
import water.Key
import water.support.H2OFrameSupport._

/**
  * @author mx
  */
@SerialVersionUID(726576L)
object H2OAutoLearner extends Serializable {

  /**
    * Fit an H2O AutoML run and save the leader model as a MOJO zip.
    *
    * Builds an `AutoMLBuildSpec` from the arguments, starts AutoML, waits for it to finish,
    * and writes the leader model's MOJO to `<modelSavePath>/<leader key>.zip`.
    * Optional settings left as `None` are not set on the spec, so H2O's own defaults apply.
    *
    * @param train: training data
    * @param respCol: response column
    * @param modelSavePath: directory in which the leader model's MOJO zip is written (must exist)
    * @param valid: validation data (optional)
    * @param featCols: feature/descriptive columns (optional); all other columns except the response
    *   column and the weight column (if given) are ignored. Useful when the data contains many more
    *   columns than the feature columns. Not used when `ignCols` is given.
    * @param ignCols: columns to ignore (optional), useful when only a few columns should be dropped;
    *   takes precedence over `featCols`
    * @param weigCol: weight column (optional)
    * @param projName: project name (optional)
    * @param loss: loss metric -- NOTE: currently not applied (the setting is disabled below); options are:
    *   AUTO, deviance, logloss, MSE, RMSE, MAE, RMSLE, AUC, lift_top_group, misclassification, mean_per_class_error
    * @param maxRunSeconds: max runtime in seconds (H2O default when None; previously documented as 3600)
    * @param maxRound: stopping rounds (H2O default when None; previously documented as 3)
    * @param stopTol: stopping tolerance (H2O default when None; previously documented as 0.001)
    * @param stopMetric: stopping metric, options are:
    *   StoppingMetric{ AUTO, deviance, logloss, MSE, RMSE, MAE, RMSLE, AUC, lift_top_group,
    *     misclassification, mean_per_class_error, custom, r2}
    * @param nfolds: number of cross-validation folds (optional)
    * @param keepCVPred: keep cross-validation predictions -- NOTE: currently not applied
    * @param keepCVModel: keep cross-validation models -- NOTE: currently not applied
    * @param excludedAlgs: excluded algorithms -- NOTE: currently not applied; options are:
    *   Algo {
    *     GLM, DRF, GBM, DeepLearning, StackedEnsemble, XGBoost
    *    }
    * @return the finished AutoML instance (leaderboard / leader available on it)
    * @throws IllegalStateException if AutoML finished without a leader model
    */
  def fit(
    // input spec
    train: H2OFrame, respCol: String, modelSavePath: String,
    valid: Option[H2OFrame] = None,
    featCols: Option[Array[String]] = None,
    ignCols: Option[Array[String]] = None,
    weigCol: Option[String] = None,
    // control
    projName: Option[String] = None,
    loss: Option[String] = None,
    maxRunSeconds: Option[Int] = None,
    maxRound: Option[Int] = None,
    stopTol: Option[Double] = None,
    stopMetric: Option[StoppingMetric] = None,
    nfolds: Option[Int] = None,
    keepCVPred: Option[Boolean] = None,
    keepCVModel: Option[Boolean] = None,
    // models
    excludedAlgs: Option[Array[Algo]] = None
    // ensemble parameters
  ): AutoML = {

    /*******************setups**************************************************/
    val autoMLBuildSpec = new AutoMLBuildSpec()

    // input
    val input = autoMLBuildSpec.input_spec
    input.training_frame = train._key
    input.response_column = respCol
    valid.foreach(v => input.validation_frame = v._key)
    // Explicit ignore list wins; otherwise ignore everything that is not a feature, the response,
    // or the weight column (the weight column must stay in the frame for H2O to find it).
    ignCols
      .orElse(featCols.map(fc => train.names.diff(fc.toSeq ++ (respCol +: weigCol.toSeq))))
      .foreach(cols => input.ignored_columns = cols)
    weigCol.foreach(w => input.weights_column = w)

    // control
    val control = autoMLBuildSpec.build_control
    projName.foreach(p => control.project_name = p)
    // NOTE: `loss` is accepted but not applied; the build_control field is not set here.
    //if (loss != None) autoMLBuildSpec.build_control.loss = loss.get
    val criteria = control.stopping_criteria
    maxRunSeconds.foreach(s => criteria.set_max_runtime_secs(s))
    maxRound.foreach(r => criteria.set_stopping_rounds(r))
    stopTol.foreach(t => criteria.set_stopping_tolerance(t))
    stopMetric.foreach(m => criteria.set_stopping_metric(m))

    // control: cv
    nfolds.foreach(n => control.nfolds = n)
    // NOTE: the following parameters are accepted but not applied.
    //if (keepCVPred != None) autoMLBuildSpec.build_control.keep_cross_validation_predictions = keepCVPred.get
    //if (keepCVModel != None) autoMLBuildSpec.build_control.keep_cross_validation_models = keepCVModel.get
    //if (excludedAlgs != None) autoMLBuildSpec.build_models.exclude_algos = excludedAlgs.get

    // running AutoML
    val aml = AutoML.makeAutoML(Key.make(), new Date(), autoMLBuildSpec)
    AutoML.startAutoML(aml)
    // startAutoML may return before the run completes; get() blocks until it has finished,
    // so the leader is available below.
    // Alternative forced-blocking form: AutoML.startAutoML(autoMLBuildSpec).get()
    aml.get()

    saveLeaderMojo(aml, modelSavePath)
    aml
  }

  /** Write the leader model's MOJO to `<dir>/<leader key>.zip`, always closing the stream. */
  private def saveLeaderMojo(aml: AutoML, dir: String): Unit = {
    val leader = Option(aml.leader()).getOrElse(
      throw new IllegalStateException("AutoML finished without producing a leader model; nothing to save"))
    val out = new FileOutputStream(new File(dir, s"${leader._key}.zip"))
    try leader.getMojo().writeTo(out)
    finally out.close()
  }
}
