package net.sansa_stack.rdf.spark.io

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}


/**
  * Converts N-Triples data to Parquet, either as partitioned files on disk or as a table
  * registered in the session catalog.
  *
  * @param session the Spark session used for all reads, SQL statements and writes
  */
class ParquetTransformer(val session: SparkSession) {

  /**
    * Loads the triples found at `inputPath` with `NTripleReader` and writes them as Parquet files
    * to `outputPath`.
    *
    * Every triple becomes one row with the non-nullable string columns `s`, `p` and `o`, filled
    * with the `toString` form of subject, predicate and object. Rows are sorted by `s`, the
    * output is partitioned by `p`, and anything already present at `outputPath` is overwritten.
    *
    * @param inputPath  location of the N-Triples input
    * @param outputPath directory the Parquet output is written to
    */
  def exportAsParquet(inputPath: String, outputPath: String): Unit = {
    val tripleSchema = StructType(Seq(
      StructField("s", StringType, nullable = false),
      StructField("p", StringType, nullable = false),
      StructField("o", StringType, nullable = false)))

    val rows = NTripleReader.load(session, inputPath)
      .map(t => Row(t.getSubject.toString, t.getPredicate.toString, t.getObject.toString))

    session.createDataFrame(rows, tripleSchema)
      .sort("s")
      .write
      .mode(org.apache.spark.sql.SaveMode.Overwrite)
      .partitionBy("p")
      .parquet(outputPath)
  }

  /**
    * Exposes the files under `inputPath` as the default triples table, prints its schema, row
    * count and first ten rows, and then saves the content as the Parquet table `ntriples`.
    *
    * The table data is stored at `outputPath`; an existing `ntriples` table is overwritten.
    *
    * @param inputPath  directory containing the N-Triples files
    * @param outputPath location where the Parquet files of the `ntriples` table are stored
    */
  def transform(inputPath: String, outputPath: String): Unit = {
    // create triples table
    val triples = createTriplesTableHive(tableDirectory = inputPath)
    triples.printSchema()

    println(s"#triples: ${triples.count()}")
    triples.show(10)

    // save as Parquet; the explicit path makes the table data land at outputPath
    // instead of the warehouse default location
    triples.repartition(4)
      .write
      .mode(org.apache.spark.sql.SaveMode.Overwrite)
      .format("parquet")
      .option("path", outputPath)
      .saveAsTable("ntriples")
  }

  /** Placeholder without any behaviour; kept so existing callers keep compiling. */
  def triplesTable(): Unit = {

  }

  /**
    * Pattern handed to the regex SerDe: three capture groups for subject, predicate and the
    * remainder of the line.
    *
    * NOTE(review): the backslashes are doubled on top of the Scala escaping, presumably because
    * the SQL parser unescapes the string literal once more - confirm before simplifying.
    */
  val regex: String = "(\\\\S+)\\\\s+(\\\\S+)\\\\s+(.*)"


  /**
    * Creates (if it does not exist yet) an external table over the files in `tableDirectory`,
    * parsed line by line with [[regex]], and returns that table.
    *
    * NOTE(review): the statement uses Hive's `RegexSerDe`, so this presumably needs a session
    * built with Hive support - confirm against the deployment.
    *
    * @param schema         table and column names to use
    * @param tableDirectory directory holding the raw triple files
    * @return the triples table as a DataFrame with the three columns named by `schema`
    */
  def createTriplesTableHive(schema: SQLSchema = SQLSchemaDefault, tableDirectory: String): DataFrame = {
    session.sql(
      s"""
         |CREATE EXTERNAL TABLE IF NOT EXISTS ${schema.triplesTable}
         |(${schema.subjectCol} STRING, ${schema.predicateCol} STRING, ${schema.objectCol} STRING)
         |ROW FORMAT  SERDE 'org.apache.hadoop.hive.serde2.RegexSerDe'
         |WITH SERDEPROPERTIES(
         |"input.regex" = '$regex'
         |)
         |LOCATION '$tableDirectory'
      """.stripMargin)

    // The DDL statement itself yields a DataFrame without columns; hand back the table instead
    // so callers can actually inspect and query the triples.
    session.table(schema.triplesTable)
  }
}

/**
  * Factory and command-line entry point for [[ParquetTransformer]].
  *
  * When run as an application it reads the N-Triples data at the first argument and writes it
  * in Parquet format to the path given as second argument.
  *
  * @author Lorenz Buehmann
  */
object ParquetTransformer {

  /** Creates a transformer that works on the given session. */
  def apply(session: SparkSession): ParquetTransformer = new ParquetTransformer(session)

  /**
    * Runs the Parquet export.
    *
    * @param args `args(0)` is the input path, `args(1)` the output path
    * @throws IllegalArgumentException if fewer than two arguments are given
    */
  def main(args: Array[String]): Unit = {
    // Fail with a usage hint before a Spark session is started for nothing.
    require(args.length >= 2, "Usage: ParquetTransformer <inputPath> <outputPath>")

    val session = SparkSession.builder()
      .appName("Parquet transformer")
//      .config("spark.eventLog.enabled", true)
//      .master("local")
//      .enableHiveSupport()
      .getOrCreate()

    try {
      ParquetTransformer(session).exportAsParquet(args(0), args(1))

//      val df = session.read.format("parquet").load(args(1))
//        .createOrReplaceTempView("triples")
//      val chunk = session.sql("select * from triples where s = 'http://dbpedia.org/resource/Dennis_Bergkamp'")
// //      val chunk = df.filter("s = 'http://dbpedia.org/resource/Dennis_Bergkamp'")
//      chunk.show(10)
//      chunk.explain(true)

//      //ParquetTransformer(session).transform("/home/user/work/datasets/lubm/1000", "/tmp/lubm/10/parquet")
//
//      // read table by name
//      val triples = session.table("ntriples")
//      println(triples.count())
//
//      println(triples.show(10))
    } finally {
      // Release the session even when the export fails.
      session.close()
    }
  }

}



/**
  * Names of the table and columns under which an RDF graph is stored in SQL.
  *
  * @param triplesTable name of the table holding the triples
  * @param subjectCol   name of the column holding the subjects
  * @param predicateCol name of the column holding the predicates
  * @param objectCol    name of the column holding the objects
  *
  * @author Lorenz Buehmann
  */
class SQLSchema(
    val triplesTable: String,
    val subjectCol: String,
    val predicateCol: String,
    val objectCol: String)

/** The default [[SQLSchema]]: table `TRIPLES` with the columns `s`, `p` and `o`. */
object SQLSchemaDefault
  extends SQLSchema(triplesTable = "TRIPLES", subjectCol = "s", predicateCol = "p", objectCol = "o")
