/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.concurrent.CopyOnWriteArrayList

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{EVENT, LISTENER}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.{CloseableIterator, SparkConnectClient}
import org.apache.spark.sql.streaming.StreamingQueryListener.{AsyncStateCommitCompletionEvent, Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

/**
 * Client-side bus that delivers streaming query events from the Spark Connect server to the
 * [[StreamingQueryListener]]s registered through this session.
 *
 * While at least one listener is registered, a single server-side listener bus listener is kept
 * registered. A dedicated handler thread drains its long-running response stream and forwards
 * every decoded event to all registered listeners.
 */
class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
  // Copy-on-write: forEach iterates over a snapshot, so remove() may run while iterating.
  private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
  // Handler thread consuming the server-side event stream; only accessed while holding `lock`.
  private var executionThread: Option[Thread] = Option.empty

  /** Guards listener registration, `executionThread`, and event delivery in `postToAll`. */
  val lock = new Object()

  /** Removes every registered listener; removing the last one also removes the server-side bus. */
  def close(): Unit = {
    listeners.forEach(remove(_))
  }

  // BEGIN-EDGE
  private val LOCAL_CREDENTIALS_KEY = "X-Databricks-Local-Credentials-Key"

  /**
   * Asks the server to cache the caller's credentials and returns request metadata carrying the
   * resulting credential key. Subsequent requests carrying this key are handled by the server
   * with the cached credential.
   *
   * Called in two situations:
   *  1. When opening the long-running listener-bus response stream, because that stream needs
   *     the credential to refresh the connection.
   *  2. When a listener is appended, so that Spark commands issued from its callbacks run with
   *     the registrant's credential.
   */
  private def requestCredential: Map[String, String] = {
    val cmdBuilder = Command.newBuilder()
    cmdBuilder.getStreamingQueryListenerBusCommandBuilder.setCacheUserCredentials(true)
    val response = sparkSession.execute(cmdBuilder.build()).head
    val key = response.getStreamingQueryListenerEventsResult.getCredentialKey
    Map(LOCAL_CREDENTIALS_KEY -> key)
  }

  /**
   * Asks the server to drop the credential cached for `listener`. This is a no-op when the
   * listener's metadata carries no credential key.
   */
  private def removeCachedCredential(listener: StreamingQueryListener): Unit = {
    if (listener.getMetadata.contains(LOCAL_CREDENTIALS_KEY)) {
      val cmdBuilder = Command.newBuilder()
      cmdBuilder.getStreamingQueryListenerBusCommandBuilder.setRemoveUserCredentials(
        listener.getMetadata(LOCAL_CREDENTIALS_KEY))
      sparkSession.execute(cmdBuilder.build())
    }
  }
  // END-EDGE

  /**
   * Registers `listener` so that it receives streaming query events.
   *
   * Failures are logged as warnings, not thrown:
   *  - If caching the listener's credential fails, the listener is not added.
   *  - If registering the server-side listener bus listener fails (attempted only for the first
   *    listener), the listener is dropped again and its cached credential is released.
   */
  def append(listener: StreamingQueryListener): Unit = {
    // BEGIN-EDGE
    try {
      listener.setMetadata(requestCredential)
    } catch {
      case NonFatal(e) =>
        logWarning("Failed to add the listener credential. " +
          "The listener is not added, please try again. " +
          "If error persists, please contact support.", e)
        return
    }
    // END-EDGE

    lock.synchronized {
      listeners.add(listener)

      if (listeners.size() == 1) {
        val iter = try {
          registerServerSideListener()
        } catch {
          case e: Exception =>
            logWarning("Failed to add the listener, please add it again.", e)
            listeners.remove(listener)
            // BEGIN-EDGE
            // remove() ignores listeners that are not registered, so the credential cached
            // above must be released here or it would never be released.
            try {
              removeCachedCredential(listener)
            } catch {
              case NonFatal(e2) =>
                logWarning("Failed to remove the listener credential.", e2)
            }
            // END-EDGE
            return
        }
        val handlerThread = new Thread(new Runnable {
          def run(): Unit = queryEventHandler(iter)
        })
        // Publish the thread (under `lock`) before starting it; queryEventHandler's error path
        // compares itself against `executionThread`.
        executionThread = Some(handlerThread)
        handlerThread.start()
      }
    }
  }

  /**
   * Unregisters `listener`. Listeners that are not currently registered are ignored.
   *
   * Removing the last listener first removes the server-side listener bus listener and
   * interrupts the handler thread. If that server-side removal fails, a warning is logged and
   * the listener stays registered so the caller can retry. Once the listener is removed, its
   * cached credential is released on a best-effort basis.
   */
  def remove(listener: StreamingQueryListener): Unit = lock.synchronized {
    // Unknown listeners must be ignored. Otherwise, with exactly one other listener registered,
    // the "last listener" check below would tear down the server-side registration that the
    // remaining listener still relies on.
    val registered = listeners.contains(listener)
    val isLast = registered && listeners.size() == 1
    if (registered && (!isLast || stopServerSideListener())) {
      listeners.remove(listener)

      // BEGIN-EDGE
      try {
        removeCachedCredential(listener)
      } catch {
        case NonFatal(e) =>
          logWarning("Failed to remove the listener credential. " +
            "The listener is still removed.", e)
      }
      // END-EDGE
    }
  }

  /**
   * Removes the server-side listener bus listener and, on success, interrupts and forgets the
   * handler thread. Must be called while holding `lock`.
   *
   * @return true on success; false if the server-side removal failed (a warning is logged and
   *         the handler thread is left untouched).
   */
  private def stopServerSideListener(): Boolean = {
    val cmdBuilder = Command.newBuilder()
    cmdBuilder.getStreamingQueryListenerBusCommandBuilder
      .setRemoveListenerBusListener(true)
    val removed = try {
      sparkSession.execute(cmdBuilder.build())
      true
    } catch {
      case e: Exception =>
        logWarning("Failed to remove the listener, please remove it again.", e)
        false
    }
    if (removed) {
      executionThread.foreach(_.interrupt())
      executionThread = Option.empty
    }
    removed
  }

  /** Returns a snapshot of the currently registered listeners. */
  def list(): Array[StreamingQueryListener] = lock.synchronized {
    listeners.asScala.toArray
  }

  /**
   * Registers the server-side listener bus listener and returns its response stream.
   *
   * Consumes responses until the server acknowledges the registration, then returns the
   * iterator positioned after the acknowledgement so later responses carry events. If the
   * stream ends without an acknowledgement, the exhausted iterator is returned as-is.
   */
  def registerServerSideListener(): CloseableIterator[ExecutePlanResponse] = {
    // NOTE(review): this credential is cached for the stream but its key is not retained, so
    // this class never releases it explicitly — confirm the server drops it when the stream ends.
    val credentialMetadata = requestCredential // EDGE
    val cmdBuilder = Command.newBuilder()
    cmdBuilder.getStreamingQueryListenerBusCommandBuilder
      .setAddListenerBusListener(true)

    val plan = Plan.newBuilder().setCommand(cmdBuilder.build()).build()
    val iterator = sparkSession.client.execute(plan, extraHeaders = credentialMetadata) // EDGE
    while (iterator.hasNext) {
      val response = iterator.next()
      if (response.getStreamingQueryListenerEventsResult.hasListenerBusListenerAdded &&
        response.getStreamingQueryListenerEventsResult.getListenerBusListenerAdded) {
        return iterator
      }
    }
    iterator
  }

  /**
   * Drains `iter`, decoding each streaming query listener event and posting it to all
   * registered listeners. This is the body of the handler thread started by `append`.
   *
   * If consuming the stream throws while this thread is still the bus's handler thread, a
   * warning is logged and all client-side listeners are removed. Removing the last one also
   * attempts to remove the server-side registration. If this thread has already been replaced
   * or stopped by `remove`, the current listeners are left untouched.
   */
  def queryEventHandler(iter: CloseableIterator[ExecutePlanResponse]): Unit = {
    try {
      while (iter.hasNext) {
        val response = iter.next()
        val listenerEvents = response.getStreamingQueryListenerEventsResult.getEventsList
        listenerEvents.forEach(event => {
          event.getEventType match {
            case StreamingQueryEventType.QUERY_PROGRESS_EVENT =>
              postToAll(QueryProgressEvent.fromJson(event.getEventJson))
            case StreamingQueryEventType.QUERY_IDLE_EVENT =>
              postToAll(QueryIdleEvent.fromJson(event.getEventJson))
            case StreamingQueryEventType.QUERY_TERMINATED_EVENT =>
              postToAll(QueryTerminatedEvent.fromJson(event.getEventJson))
            case StreamingQueryEventType.ASYNC_STATE_COMMIT_COMPLETION_EVENT =>
              postToAll(AsyncStateCommitCompletionEvent.fromJson(event.getEventJson))
            case _ =>
              logWarning(log"Unknown StreamingQueryListener event: ${MDC(EVENT, event)}")
          }
        })
      }
    } catch {
      case e: Exception =>
        lock.synchronized {
          // After remove() stops this thread, append() may already have registered a new bus
          // with a new handler thread and new listeners; those must not be torn down here.
          if (executionThread.contains(Thread.currentThread())) {
            logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
              " side listeners are removed and handler thread is terminated.", e)
            executionThread = Option.empty
            listeners.forEach(remove(_))
          } else {
            logDebug("StreamingQueryListenerBus handler thread received exception after it " +
              "was stopped or replaced; current listeners are left untouched.", e)
          }
        }
    }
  }

  /**
   * Delivers `event` to every registered listener while holding `lock`.
   *
   * Each callback runs inside the listener's own `ExtraRequestMetadata` handle, built from the
   * metadata set on it in `append`. An exception thrown by one listener is logged and does not
   * prevent delivery to the others.
   */
  def postToAll(event: Event): Unit = lock.synchronized {
    listeners.forEach(listener =>
      try {
        // Update the SparkConnectClient's thread local to use credentials associated with the user
        // that created the listener while running its callback.
        val handle = sparkSession.client.ExtraRequestMetadata.create(listener.getMetadata) // EDGE
        handle.runWith { // EDGE
          event match {
            case t: QueryStartedEvent =>
              listener.onQueryStarted(t)
            case t: QueryProgressEvent =>
              listener.onQueryProgress(t)
            case t: QueryIdleEvent =>
              listener.onQueryIdle(t)
            case t: QueryTerminatedEvent =>
              listener.onQueryTerminated(t)
            case t: AsyncStateCommitCompletionEvent =>
              listener.onAsyncStateCommitCompletion(t)
            case _ => logWarning(log"Unknown StreamingQueryListener event: ${MDC(EVENT, event)}")
          }
        }
      } catch {
        case e: Exception =>
          logWarning(log"Listener ${MDC(LISTENER, listener)} threw an exception", e)
      })
  }
}
