/*
 * DATABRICKS CONFIDENTIAL & PROPRIETARY
 * __________________
 *
 * Copyright 2023-present Databricks, Inc.
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains the property of Databricks, Inc.
 * and its suppliers, if any.  The intellectual and technical concepts contained herein are
 * proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
 * patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
 * or reproduction of this information is strictly forbidden unless prior written permission is
 * obtained from Databricks, Inc.
 *
 * If you view or obtain a copy of this information and believe Databricks, Inc. may not have
 * intended it to be made available, please promptly report it to Databricks Legal Department
 * @ legal@databricks.com.
 */
package com.databricks.spark.sql.remotefiltering

import java.util.UUID

import com.databricks.sql.remotefiltering.{CloudFetchResult, EmbeddedSparkConnectClient}

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.StructType

/**
 * This class is the implementation of the abstract class defined in Spark SQL. The implementation
 * is split from the abstract class so that it can be loaded via reflection from a different
 * classloader.
 *
 * Every operation opens a fresh Spark Connect session (see `withClient`), performs its work and
 * closes the session again. Plans are accepted as plain `Object` references and cast to
 * `proto.Plan` before use.
 *
 * Within the Spark Connect client, this class can be freely used.
 */
class EmbeddedSparkConnectClientImpl(
    host: String,
    token: String,
    clusterId: Option[String],
    workloadId: Option[String],
    port: Int = 443)
    extends EmbeddedSparkConnectClient(host, token, clusterId, workloadId, port)
    with Logging {

  // The user agent is used to identify the usage for the remote filtering usage directly
  // from the existing query logs.
  private val USER_AGENT_FRAGMENT = ";user_agent=SCALA_REMOTE_FILTERING"

  // UUID used to identify the serverless session. There is exactly one session per instance
  // of the embedded client.
  private val sessionID = UUID.randomUUID().toString

  /**
   * Builds a Spark Connect session against `host:port`, runs `f` with it and always closes the
   * session afterwards, even when `f` throws.
   *
   * The connection string targets `clusterId` when it is set; otherwise it sends this instance's
   * `sessionID`, so every call made through this client reuses the same serverless session. The
   * token and workload ID parameters are only appended when they are present.
   *
   * @param f
   *   The work to perform with the session.
   * @return
   *   Whatever `f` returns.
   */
  private def withClient[S](f: SparkSession => S): S = {
    val endpoint = clusterId match {
      case Some(id) => s"x-databricks-cluster-id=${id}"
      case None =>
        logInfo(s"Remote filtering serverless session ID: ${sessionID}")
        s"x-databricks-session-id=${sessionID}"
    }

    // An empty token means no credential parameter is sent at all.
    val tokenConfig = if (token.isEmpty) {
      ""
    } else {
      s";token=${token}"
    }

    val propagatedWorkloadId = workloadId match {
      case Some(id) =>
        logInfo(s"Propagating workloadId to remote cluster: ${id}")
        s";x-databricks-workload-id=${id}"
      case None => ""
    }

    val spark = SparkSession.builder()
      .remote(s"sc://${host}:${port}/${tokenConfig};user_id=na;${endpoint}${USER_AGENT_FRAGMENT}" +
      s"${propagatedWorkloadId}")
      .build()
    try {
      f(spark)
    } finally {
      spark.close()
    }
  }

  /**
   * Creates a new proto plan that reads from the given table.
   *
   * @param tableName
   *   The name of the table.
   * @return
   *   The proto plan of `spark.read.table(tableName)`.
   */
  override def table(tableName: String): proto.Plan = withClient { spark =>
    spark.read.table(tableName).plan
  }

  /**
   * Applies a row limit on top of the given plan. This is similar to calling `df.limit`.
   *
   * @param plan
   *   Proto plan, passed as a plain object reference.
   * @param limit
   *   Maximum number of rows the new plan produces.
   * @return
   *   new proto plan
   */
  override def limit(plan: Object, limit: Int): proto.Plan = withClient { spark =>
    val df = spark.newDataFrame(plan.asInstanceOf[proto.Plan])
    df.limit(limit).plan
  }

  /**
   * Given a plan object, applies the predicates as SQL string expressions on the plan. This is
   * similar to calling `df.filter`.
   *
   * All predicates are combined with AND into a single filter. When `predicates` is empty no
   * filter is added.
   *
   * @param plan
   *   Proto plan
   * @param predicates
   *   List of string expressions to be applied
   * @return
   *   new proto plan
   */
  override def applyPredicates(plan: Object, predicates: Array[String]): proto.Plan = withClient {
    spark =>
      // Parse each predicate as a SQL expression, AND them together and add a single filter.
      val df: DataFrame = spark.newDataFrame(plan.asInstanceOf[proto.Plan])
      predicates.map(expr).reduceOption(_ && _).map(p => df.filter(p)).getOrElse(df).plan
  }

  /**
   * Applies a grouping with aggregations on top of the given plan. This is similar to calling
   * `df.groupBy(groupExpr).agg(aggExpr)`. Both arrays hold SQL expression strings.
   *
   * When `aggExpr` is empty, the query is rewritten as a distinct projection of the grouping
   * expressions, which is what a GROUP BY without aggregates computes.
   *
   * @param plan
   *   Proto plan
   * @param aggExpr
   *   Aggregate expressions
   * @param groupExpr
   *   Grouping expressions
   * @return
   *   new proto plan
   */
  override def applyGroupBy(
      plan: Object,
      aggExpr: Array[String],
      groupExpr: Array[String]): proto.Plan =
    withClient { spark =>
      val df: DataFrame = spark.newDataFrame(plan.asInstanceOf[proto.Plan])
      val aggs: Array[Column] = aggExpr.map(expr(_))
      val groups: Array[Column] = groupExpr.map(expr(_))

      // If no aggregations are present, we simply rewrite this as a distinct query.
      val res = if (aggs.isEmpty) {
        df.select(groups: _*).distinct()
      } else {
        val rgds: RelationalGroupedDataset = df.groupBy(groups: _*)
        rgds.agg(aggs.head, aggs.tail: _*)
      }
      res.plan
    }

  /**
   * Returns a new proto plan object for a given SQL query.
   *
   * @param query
   *   The SQL query string.
   * @return
   *   The proto plan of `spark.sql(query)`.
   */
  override def sql(query: String): proto.Plan = withClient { spark =>
    spark.sql(query).plan
  }

  /**
   * Given a valid proto plan object, executes the query and returns the collected result as a
   * sequence of [[CloudFetchResult]] batches.
   *
   * The hybrid cloud result handle is always closed, even if converting it into batches fails.
   *
   * @param plan
   *   Proto plan, passed as a plain object reference.
   * @return
   *   The result batches.
   */
  override def execute(plan: Object): Seq[CloudFetchResult] = withClient { spark =>
    val df: DataFrame = spark.newDataFrame(plan.asInstanceOf[proto.Plan])
    val result = df.collectHybridCloudResult()
    try {
      result.toCloudFetchBatches
    } finally {
      result.close()
    }
  }

  /**
   * Returns the schema of the plan object.
   *
   * @param plan
   *   Proto plan, passed as a plain object reference.
   * @return
   *   The schema the plan produces.
   */
  override def schema(plan: Object): StructType = withClient { spark =>
    spark.newDataFrame(plan.asInstanceOf[proto.Plan]).schema
  }

  /**
   * Projects the named columns out of the plan. This is similar to calling `df.select` with
   * column names.
   *
   * An empty `cols` yields a projection with zero columns instead of failing.
   *
   * @param plan
   *   Proto plan, passed as a plain object reference.
   * @param cols
   *   Names of the columns to keep, in output order.
   * @return
   *   new proto plan
   */
  override def select(plan: Object, cols: Array[String]): proto.Plan = withClient { spark =>
    val df: DataFrame = spark.newDataFrame(plan.asInstanceOf[proto.Plan])
    val projected =
      if (cols.isEmpty) {
        // `cols.head` would throw on an empty array; selecting no columns is a valid
        // zero-column projection.
        df.select(Seq.empty[Column]: _*)
      } else {
        df.select(cols.head, cols.tail: _*)
      }
    projected.plan
  }
}
