package ai.minxiao.ds4s.core.h2o.prediction

import hex.genmodel.easy.RowData
import hex.ModelCategory
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset,
  Row, SQLContext, SparkSession}
import org.apache.spark.sql.types._

/**
  * H2OSparkPredictor
  *
  * A trained H2O model for use in a Spark environment.
  * It loads a pre-trained H2O pojo/mojo model (via the parent predictor) and adds
  * prediction functions over RDD, Row and DataFrame inputs.
  *
  * @constructor load a trained model with a type and source (name/path).
  * @param tp the model's type: mojo vs pojo
  * @param src the name/path of the pojo/mojo
  *
  * @author mx
  */
@SerialVersionUID(728380L)
class H2OSparkPredictor(private val tp: String, private val src: String)
    extends H2OPredictor(tp, src) with Serializable {

  /**
    * Build a RowData from an RDD of (feature name, feature value) pairs.
    *
    * Only pairs whose name is contained in `this.names` are kept. The filtered pairs are
    * collected to the driver before the RowData is filled: mutating a driver-side object
    * inside `RDD.foreach` only changes the executors' deserialized copies, which would
    * leave the returned RowData empty on a cluster.
    *
    * @param rdd one instance, each element being a feature (name-value pair)
    * @return the RowData holding the features retained for the model
    */
  protected def getRowData(rdd: RDD[(String, Any)]): RowData = {
    // Local copy so the filter closure does not need to serialize the whole predictor.
    val modelNames = this.names
    val rowData = new RowData
    rdd.filter { case (name, _) => modelNames.contains(name) }
      .collect() // a single instance's features; bring them to the driver
      .foreach { case (name, value) => rowData.put(name, value.asInstanceOf[AnyRef]) }
    rowData
  }

  /**
    * Predict on RDD
    * @param rdd One instance is a RDD[Tuple2[String, Any]], where each element is a feature (name-value pair)
    * @return label-score pair (if applicable)
    */
  def predictRdd(rdd: RDD[(String, Any)]): (String, Double) =
    predict(getRowData(rdd))

  /**
    * Build a RowData from a Row.
    *
    * @param row one instance; values are looked up by field name, so the row must carry a schema
    * @param cols the data column names
    * @return the RowData holding the columns present both in `this.names` and in `cols`
    */
  protected def getRowData(row: Row, cols: Array[String]): RowData = {
    val rowData = new RowData
    this.names.filter(cols.contains) // intersection of model columns and data columns
      .foreach(col => rowData.put(col, row.get(row.fieldIndex(col)).asInstanceOf[AnyRef]))
    rowData
  }

  /**
    * Predict on Row, one Row is one instance; helper function used in predictRow
    * @param row one instance
    * @param cols Row column names (dataframe column names)
    * @return label-score pair (if applicable)
    */
  def _predictRow(row: Row, cols: Array[String]): (String, Double) =
    predict(getRowData(row, cols))

  /**
    * Assemble an output row: the values of the key columns followed by label and score.
    *
    * @param row the source row the key values are read from
    * @param cols the data column names, positionally aligned with `row`
    * @param keys names of the key columns to carry over
    * @param prediction the (label, score) pair appended after the keys
    * @return Row(key values..., label, score)
    * @throws IllegalArgumentException if a key is not one of `cols`
    */
  private def attachKeys(row: Row, cols: Array[String], keys: Array[String],
    prediction: (String, Double)): Row = {
    val keyValues = keys.map { key =>
      val idx = cols.indexOf(key)
      // Fail with a readable message instead of an index error on position -1.
      require(idx >= 0, s"key column '$key' not found in columns: ${cols.mkString(", ")}")
      row.get(idx)
    }
    Row(keyValues :+ prediction._1 :+ prediction._2: _*)
  }

  /**
    * Prediction on Row data with key attached
    * @param row data
    * @param cols data column names
    * @param keys keys, array of strings, kept in the result, default=Array()
    * @return a row of predicted label with confidence score (if applicable) and original keys (if provided)
    */
  def predictRow(row: Row, cols: Array[String], keys: Array[String] = Array()): Row =
    attachKeys(row, cols, keys, _predictRow(row, cols))

  /**
    * Conditional Prediction on Row data with key attached, along with non-prediction value/score
    * @param row data
    * @param cols data column names
    * @param usePrediction the logic function to decide whether to apply the model on this particular instance
    * @param labelCol origin label column; if the model is not applied, this column's value is retrieved directly
    * @param score the score reported when the model is not applied
    * @param keys origin keys, default=Array()
    * @return the row of predictions with keys (optional)
    */
  def predictRowCond(row: Row, cols: Array[String],
    usePrediction: Row => Boolean, labelCol: String, score: Double,
    keys: Array[String] = Array()): Row = {
    val p =
      if (usePrediction(row)) _predictRow(row, cols)
      else (row.getString(cols.indexOf(labelCol)), score)
    attachKeys(row, cols, keys, p)
  }

  /**
    * prediction on a whole dataframe with key attached
    *
    * Key columns keep the data type and nullability they have in `df`, so keys that are
    * not strings (or that contain nulls) produce a frame whose schema matches its data.
    *
    * @param df data
    * @param keys keys, default=Array()
    * @param predLabelName the column name used to store the prediction label, default="label"
    * @param predScoreName the column name used to store the prediction score, default="score"
    * @return dataframe with key(s), label, score
    * @throws IllegalArgumentException if a key is not a column of `df`
    */
  def predictDF(df: DataFrame, keys: Array[String] = Array.empty[String],
    predLabelName: String = "label", predScoreName: String = "score")(implicit sqlContext: SQLContext): DataFrame = {
    val header = df.columns
    val inputFields = df.schema.fields
    // Resolve keys on the driver so a bad key fails before any task is launched; the
    // lookup mirrors the positional one used per row (first column with that name).
    val keyFields = keys.map { key =>
      val idx = header.indexOf(key)
      require(idx >= 0, s"key column '$key' not found in columns: ${header.mkString(", ")}")
      inputFields(idx)
    }
    val schema = StructType(
      keyFields :+
      StructField(predLabelName, StringType, false) :+
      StructField(predScoreName, DoubleType, false)
    )
    val predictions = df.rdd.map(row => predictRow(row, header, keys))
    sqlContext.createDataFrame(predictions, schema)
  }

}
