package ai.minxiao.ds4s.core.h2o.prediction

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset,
  Row, SQLContext, SparkSession}
import org.apache.spark.sql.types._

/**
  * H2O Spark AutoEncoder Predictor Trait
  *
  * Adds Spark-facing anomaly-reason helpers (RDD, Row and DataFrame inputs) on top of
  * the scoring provided by [[H2OAEPredictor]].
  *
  * @note must be mixed into an H2OSparkPredictor (enforced by the self-type)
  * @example new H2OSparkPredictor(tp, src) with H2OSparkAEPredictor
  *
  * @author mx
  */
trait H2OSparkAEPredictor extends H2OAEPredictor {
  this: H2OSparkPredictor =>

  /**
    * Infers the top-N anomaly reasons for the data held in an RDD.
    *
    * The RDD is converted with `getRowData` and then scored by the inherited
    * `anomalyReason(rowData, feat2Reason, topN)`.
    *
    * @param rdd data to infer anomaly reasons (presumably (column name, value) pairs — confirm
    *            against `getRowData`)
    * @param feat2Reason feature-index => reason
    * @param topN top-N reason
    * @return top-N (reason, score) pairs
    */
  def anomalyReason(rdd: RDD[(String, Any)], feat2Reason: Map[Int, String], topN: Int): Array[(String, Double)] = {
    val rowData = getRowData(rdd)
    anomalyReason(rowData, feat2Reason, topN)
  }

  /**
    * Infers the top-N anomaly reasons for a single Spark row.
    *
    * @param row data to infer anomaly reasons
    * @param cols row-data column names, positionally aligned with `row`
    * @param feat2Reason feature-index => reason
    * @param topN top-N reason
    * @return top-N (reason, score) pairs
    */
  private def _anomalyReason(row: Row, cols: Array[String], feat2Reason: Map[Int, String], topN: Int): Array[(String, Double)] = {
    val rowData = getRowData(row, cols)
    anomalyReason(rowData, feat2Reason, topN)
  }

  /**
    * Infers the top-N anomaly reasons for a single row and prefixes them with the row's key values.
    *
    * @param row data to infer anomaly reasons
    * @param cols row-data column names, positionally aligned with `row`
    * @param feat2Reason feature-index => reason
    * @param topN top-N reason
    * @param keys original key columns to copy into the output; each must appear in `cols`
    * @return top-N reason and scores with original keys
    *         (original keys, reason1, reason2, ..., reasonN, score1, score2, ..., scoreN)
    * @throws IllegalArgumentException if a key is not one of `cols`
    */
  def anomalyReason(row: Row, cols: Array[String], feat2Reason: Map[Int, String], topN: Int, keys: Array[String]): Row = {
    val pairs = _anomalyReason(row, cols, feat2Reason, topN)
    // all reasons first, then all scores, matching the documented output layout
    val reasonsAndScores = pairs.map(_._1) ++ pairs.map(_._2)
    val keyValues = keys.map { key =>
      val idx = cols.indexOf(key)
      // fail with a descriptive message instead of an index error from row.get(-1)
      require(idx >= 0, s"key column '$key' not found in columns [${cols.mkString(", ")}]")
      row.get(idx)
    }
    Row((keyValues ++ reasonsAndScores): _*)
  }

  /**
    * Infers the top-N anomaly reasons for every row of a DataFrame.
    *
    * The output has one row per input row. Its columns are the key columns (with their original
    * types and nullability), then `reasonPrefix1..reasonPrefixN` (string), then
    * `scorePrefix1..scorePrefixN` (double).
    *
    * NOTE(review): the per-row scorer is a method of this predictor, so the predictor is captured
    * in the Spark closure and presumably must be serializable.
    * NOTE(review): assumes the inherited scorer returns exactly `topN` pairs per row; fewer
    * would misalign the produced rows against the schema built here — confirm.
    *
    * @param df dataframe to infer reasons
    * @param feat2Reason feature index to reason
    * @param topN top-N reason
    * @param keys original key columns to carry over; each must be a column of `df`
    * @param reasonPrefix reason column-name prefix
    * @param scorePrefix score column-name prefix
    * @param sqlContext context used to create the result DataFrame
    * @return dataframe of keys attached with reasons and scores
    * @throws IllegalArgumentException if a key is not a column of `df`
    */
  def anomalyReason(df: DataFrame,
    feat2Reason: Map[Int, String], topN: Int,
    keys: Array[String])(
    reasonPrefix: String, scorePrefix: String)(implicit sqlContext: SQLContext): DataFrame = {
    val header = df.columns
    // Reuse the source field definitions so keys keep their real type/nullability
    // (non-string or nullable keys previously broke encoding); also fails fast on
    // the driver when a key column does not exist.
    val keyFields = keys.toSeq.map(key => df.schema(key))
    val reasonFields = (1 to topN).map(i => StructField(s"$reasonPrefix$i", StringType, false))
    val scoreFields = (1 to topN).map(i => StructField(s"$scorePrefix$i", DoubleType, false))
    val schema = StructType(keyFields ++ reasonFields ++ scoreFields)
    val rows = df.rdd.map(row => anomalyReason(row, header, feat2Reason, topN, keys))
    sqlContext.createDataFrame(rows, schema)
  }
}
