/*
 * DATABRICKS CONFIDENTIAL & PROPRIETARY
 * __________________
 *
 * Copyright 2023-present Databricks, Inc.
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains the property of Databricks, Inc.
 * and its suppliers, if any.  The intellectual and technical concepts contained herein are
 * proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
 * patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
 * or reproduction of this information is strictly forbidden unless prior written permission is
 * obtained from Databricks, Inc.
 *
 * If you view or obtain a copy of this information and believe Databricks, Inc. may not have
 * intended it to be made available, please promptly report it to Databricks Legal Department
 * @ legal@databricks.com.
 */
package com.databricks.spark.sql.remotefiltering

import java.util.UUID

import scala.collection.JavaConverters._
import scala.concurrent.duration.FiniteDuration
import scala.util.Random
import scala.util.control.NonFatal

import com.databricks.sql.remotefiltering.{CloudFetchResults, EmbeddedSparkConnectClient, RemoteMetric, RemoteMetricsNode}

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.connect.client.MTlsBuilder
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.types.StructType


/**
 * Implementation of the abstract [[EmbeddedSparkConnectClient]] defined in Spark SQL. The
 * implementation is split from the abstract class so that it can be loaded via reflection from a
 * different classloader.
 *
 * Within the Spark Connect client, this class can be freely used.
 *
 * Session lifecycle, as implemented below:
 *  - `persistSession = false`: a session is opened for each call and closed when the call
 *    returns or throws.
 *  - `persistSession = true`: the session is kept open across calls. Unless
 *    `reuseSessionEnabled` is set, it is closed after [[execute]] and a later call opens a new one.
 *  - `reuseSessionEnabled = true`: calls are retried (exponential backoff plus jitter) on known
 *    transient gRPC errors; the failed session is discarded and a new one is opened.
 */
class EmbeddedSparkConnectClientImpl(
    host: String,
    token: String,
    clusterId: Option[String],
    sourceClusterId: Option[String],
    customTags: Option[String],
    workloadId: Option[String],
    port: Int = 443,
    grpcMaxMessageSize: Int = 1024 * 1024 * 1024,
    persistSession: Boolean = false,
    mtlsEnabled: Boolean = true,
    reuseSessionEnabled: Boolean = false)
  extends EmbeddedSparkConnectClient(host, token, clusterId, sourceClusterId, customTags,
      workloadId, port, reuseSessionEnabled)
    with Logging {

  // The user agent is used to identify the usage for the remote filtering usage directly
  // from the existing query logs.
  private val USER_AGENT = "SCALA_REMOTE_FILTERING"

  // Substrings of exception messages that identify transient failures worth retrying.
  private val retriableErrorMarkers: Seq[String] = Seq(
    // The gateway has closed the session.
    "grpc_shaded.io.grpc.StatusRuntimeException: FAILED_PRECONDITION: BAD_REQUEST:",
    // The gateway has reported a transient error.
    "grpc_shaded.io.grpc.StatusRuntimeException: ABORTED: ABORTED:",
    // Too many new sessions are being created; back off and retry.
    "grpc_shaded.io.grpc.StatusRuntimeException: RESOURCE_EXHAUSTED: RESOURCE_EXHAUSTED:")

  // The spark session is lazily created and reassigned to a new session if the previous one
  // is closed.
  private var spark: Option[SparkSession] = None

  // The session ID is used to create the session. This might differ from
  // EmbeddedSparkConnectClient.sessionId in case new session id is created
  // after this client is created.
  private var clientSessionId: Option[String] = None

  /**
   * Builds a new Spark Connect session against `host:port`, attaching the routing, auth and
   * billing headers derived from the constructor arguments.
   *
   * When no `clusterId` is given, the current `sessionId` is pinned via a header and remembered in
   * `clientSessionId` so a retry can tell whether it should rotate the shared session id.
   */
  private def createNewSession(): SparkSession = {
    // TODO(nemanja.boric): remove the mTLS SparkConf once we are in the PuPr and start charging.
    val clientBuilder = (new MTlsBuilder)
      .withMTlsEnabled(mtlsEnabled)
      .host(host)
      .port(port)
      .userId("na")
      .userAgent(USER_AGENT)
      .grpcMaxMessageSize(grpcMaxMessageSize)

    clusterId match {
      case Some(id) =>
        logInfo(s"Remote filtering shared cluster ID: ${id}")
        clientBuilder.option("x-databricks-cluster-id", id)
      case None =>
        // Read the id once so the logged, sent and remembered values cannot diverge.
        val currentSessionId = sessionId
        logInfo(s"Remote filtering serverless session ID: ${currentSessionId}")
        clientBuilder.option("x-databricks-session-id", currentSessionId)
        clientSessionId = Some(currentSessionId)
    }

    if (token.nonEmpty) {
      clientBuilder.token(token)
    }

    workloadId.foreach { id =>
      logInfo(s"Propagating workloadId to remote cluster: ${id}")
      clientBuilder.option("x-databricks-workload-id", id)
    }

    // Add billing metadata.
    clientBuilder.option("x-databricks-workload-type", "FilteringService")
    sourceClusterId.foreach { id =>
      logInfo(s"Remote filtering source cluster ID: ${id}")
      clientBuilder.option("x-databricks-fgac-source-cluster-id", id)
    }
    customTags.foreach { tags =>
      logInfo(s"Remote filtering source cluster customTags: ${tags}")
      clientBuilder.option("x-databricks-custom-user-tags", tags)
    }

    val client = clientBuilder.build()
    SparkSession.builder().client(client).getOrCreate()
  }

  /**
   * Closes the current session, if any, and forgets it so the next call opens a fresh one. The
   * reference is cleared before closing, so the client never keeps pointing at a session whose
   * `close()` was attempted, even if `close()` throws.
   */
  private def closeSession(): Unit = {
    val current = spark
    spark = None
    current.foreach(_.close())
  }

  /**
   * Same as [[closeSession]] but logs and swallows non-fatal errors. Used when discarding a
   * session that has already failed, where a close error must not hide the retry.
   */
  private def closeSessionQuietly(): Unit = {
    try {
      closeSession()
    } catch {
      case NonFatal(closeError) =>
        logWarning("Failed to close the remote filtering session before retrying.", closeError)
    }
  }

  /**
   * Runs `f` against the current session, opening one first if none is open. Unless
   * `persistSession` is set, the session is closed once `f` returns or throws.
   */
  def callWithClient[S](f: SparkSession => S): S = {
    val session = spark.getOrElse {
      val created = createNewSession()
      spark = Some(created)
      created
    }
    try {
      f(session)
    } finally {
      if (!persistSession) {
        // Close the session if persistSession is false
        closeSession()
      }
    }
  }

  /**
   * Runs `f` via [[callWithClient]], retrying up to `maxRetries` times on retriable
   * [[SparkException]]s. Before each retry, the failed session is discarded and the shared
   * session id may be rotated. The call then sleeps `delay` plus up to `maxJitter` of random
   * jitter, and `delay` doubles for the next attempt.
   */
  private def callAndRetryWithClient[S](f: SparkSession => S, maxRetries: Int,
      delay: FiniteDuration, maxJitter: FiniteDuration = FiniteDuration(2000, "ms")): S = {
    try {
      callWithClient(f)
    } catch {
      case e: SparkException if maxRetries > 0 && isRetriableError(e.getMessage) =>
        if (clientSessionId.contains(EmbeddedSparkConnectClient.sessionId)) {
          // Only set a new session id if the client was using the most recent
          // session id.
          EmbeddedSparkConnectClient.sessionId = UUID.randomUUID().toString
        }
        // The failed session is unusable. Release its client resources rather than leaking them
        // when persistSession kept it open; this is a no-op if callWithClient already closed it.
        closeSessionQuietly()
        logInfo(s"Creating a new session with session id ${sessionId} due to session closure.")
        val jitter = Random.nextDouble() * maxJitter.toMillis
        Thread.sleep(delay.toMillis + jitter.toLong)
        callAndRetryWithClient(f, maxRetries - 1, delay * 2, maxJitter)
    }
  }

  /**
   * Checks the error message to see if it contains any of the specific error strings that are
   * known to be transient and can be retried.
   *
   * @param message The exception message string; may be `null`.
   * @return true if the message contains any of the specific error strings, false otherwise
   *         (including when the message is `null`).
   */
  private def isRetriableError(message: String): Boolean =
    Option(message).exists(msg => retriableErrorMarkers.exists(msg.contains))

  /**
   * Runs `f` with a session. When `reuseSessionEnabled` is set, transient failures are retried;
   * otherwise `f` is run exactly once.
   */
  private def withClient[S](f: SparkSession => S): S = {
    if (reuseSessionEnabled) {
      // Retry the operation with a new session in case of timeout. Low number of retries since if
      // query still fails after a new session is created, it is likely due to a separate issue.
      callAndRetryWithClient(f, 3, FiniteDuration(10, "s"))
    } else {
      callWithClient(f)
    }
  }

  /** Wraps an opaque plan reference, which must be a [[proto.Plan]], into a [[DataFrame]]. */
  private def toDataFrame(session: SparkSession, plan: Object): DataFrame =
    session.newDataFrame(plan.asInstanceOf[proto.Plan])

  /**
   * Creates a new proto plan that reads the given table.
   *
   * @param tableName
   *   The name of the table.
   * @return
   *   The proto plan reading the table.
   */
  override def table(tableName: String): proto.Plan = withClient { session =>
    session.read.table(tableName).plan
  }

  /**
   * Returns a new proto plan that limits `plan` to at most `limit` rows.
   *
   * @param plan
   *   Proto plan (passed as an opaque object).
   * @param limit
   *   Maximum number of rows.
   * @return
   *   new proto plan
   */
  override def limit(plan: Object, limit: Int): proto.Plan = withClient { session =>
    toDataFrame(session, plan).limit(limit).plan
  }

  /**
   * Given a plan object, parses each predicate as a SQL expression string and filters the plan by
   * their conjunction. This is similar to calling `df.filter`. An empty predicate array leaves the
   * plan unfiltered.
   *
   * @param plan
   *   Proto plan
   * @param predicates
   *   List of string expressions to be applied
   * @return
   *   new proto plan
   */
  override def applyPredicates(plan: Object, predicates: Array[String]): proto.Plan = withClient {
    session =>
      val df: DataFrame = toDataFrame(session, plan)
      // Parse every predicate into a Column, AND them together and push the result into the plan.
      predicates.map(expr).reduceOption(_ && _).fold(df)(p => df.filter(p)).plan
  }

  /**
   * Groups `plan` by `groupExpr` and computes `aggExpr`, both given as SQL expression strings.
   * Without aggregate expressions, the query is rewritten as a `DISTINCT` over the group columns.
   *
   * @param plan
   *   Proto plan
   * @param aggExpr
   *   Aggregate expressions
   * @param groupExpr
   *   Grouping expressions
   * @return
   *   new proto plan
   */
  override def applyGroupBy(
      plan: Object,
      aggExpr: Array[String],
      groupExpr: Array[String]): proto.Plan =
    withClient { session =>
      val df: DataFrame = toDataFrame(session, plan)
      val aggs: Array[Column] = aggExpr.map(expr(_))
      val groups: Array[Column] = groupExpr.map(expr(_))

      // If no aggregations are present, we simply rewrite this as a distinct query.
      val res = if (aggs.isEmpty) {
        df.select(groups: _*).distinct()
      } else {
        val rgds: RelationalGroupedDataset = df.groupBy(groups: _*)
        rgds.agg(aggs.head, aggs.tail: _*)
      }
      res.plan
    }

  /**
   * Returns a new proto plan object for a given SQL query.
   *
   * @param query
   *   The SQL query string.
   * @return
   *   The proto plan for the query.
   */
  override def sql(query: String): proto.Plan = withClient { session =>
    session.sql(query).plan
  }

  /**
   * Executes the given proto plan and returns its results as cloud fetch batches, together with
   * the execution metrics when the server reports them.
   *
   * With `persistSession` set and `reuseSessionEnabled` unset, the persisted session is closed
   * afterwards and a later call opens a new one.
   *
   * @param plan
   *   Proto plan (passed as an opaque object).
   * @return
   *   The fetched results and optional remote metrics.
   */
  override def execute(plan: Object): CloudFetchResults = try {
    withClient { session =>
      val result = toDataFrame(session, plan).collectHybridCloudResult()
      try {
        CloudFetchResults(result.toCloudFetchBatches, result.metricsOpt.map(translateMetrics))
      } finally {
        result.close()
      }
    }
  } finally {
    if (persistSession && !reuseSessionEnabled) {
      // Close spark session only after executing plan. The reference is cleared so later calls do
      // not reuse a closed session, and nothing is closed (or thrown) if none was ever opened.
      closeSession()
    }
  }

  /**
   * Translates [[proto.ExecutePlanResponse.Metrics]] into a [[Seq]] of [[RemoteMetricsNode]].
   *
   * @param metrics A protobuf metrics object
   * @return One [[RemoteMetricsNode]] per reported plan node, keyed by metric name
   */
  private def translateMetrics(metrics: proto.ExecutePlanResponse.Metrics)
  : Seq[RemoteMetricsNode] = {
    metrics.getMetricsList.asScala.toSeq.map { metricsObject =>
      RemoteMetricsNode(
        metricsObject.getName,
        metricsObject.getExecutionMetricsMap.asScala.map {
          case (metricsKey, metric) =>
            (metricsKey, RemoteMetric(metric.getName, metric.getMetricType, metric.getValue))
        }.toMap)
    }
  }

  /**
   * Returns the schema of the plan object.
   *
   * @param plan
   *   Proto plan (passed as an opaque object).
   * @return
   *   The schema of the plan's output.
   */
  override def schema(plan: Object): StructType = withClient { session =>
    toDataFrame(session, plan).schema
  }

  /**
   * Projects `plan` onto the named columns. An empty column list returns the plan unchanged.
   *
   * @param plan
   *   Proto plan (passed as an opaque object).
   * @param cols
   *   Column names to keep
   * @return
   *   new proto plan
   */
  override def select(plan: Object, cols: Array[String]): proto.Plan = withClient { session =>
    if (cols.isEmpty) { // ES-1057381: Do not prune if no columns are selected
      plan.asInstanceOf[proto.Plan]
    } else {
      toDataFrame(session, plan).select(cols.head, cols.tail: _*).plan
    }
  }
}
