/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package wvlet.airframe.http

import wvlet.airframe.codec.{GenericException, GenericStackTraceElement, MessageCodec}
import wvlet.airframe.http.HttpMessage.Response
import wvlet.airframe.http.RPCException.rpcErrorMessageCodec
import wvlet.airframe.http.internal.HttpResponseBodyCodec
import wvlet.airframe.json.Json
import wvlet.airframe.msgpack.spi.MsgPack
import wvlet.log.LogSupport

import scala.util.Try
import scala.util.control.NonFatal

/**
  * RPCException provides a backend-independent (e.g., Finagle or gRPC) RPC error reporting mechanism. Create this
  * exception with the `newException(...)` method of an [[RPCStatus]], e.g.,
  * `RPCStatus.INTERNAL_ERROR_I0.newException("error message")`.
  *
  * If necessary, we can add more standard error_details parameters like
  * https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto
  *
  * @param status
  *   RPC status of this error
  * @param message
  *   error message
  * @param cause
  *   the underlying cause, if any; it is also registered as the Throwable cause
  * @param appErrorCode
  *   optional application-specific status code
  * @param metadata
  *   optional application-specific metadata
  */
case class RPCException(
    // RPC status
    status: RPCStatus = RPCStatus.INTERNAL_ERROR_I0,
    // Error message
    message: String = "",
    // Cause of the exception
    cause: Option[Throwable] = None,
    // [optional] Application-specific status code
    appErrorCode: Option[Int] = None,
    // [optional] Application-specific metadata
    metadata: Map[String, Any] = Map.empty
) extends Exception(s"[${status}] ${message}", cause.orNull)
    with LogSupport {

  // Per-instance override set by noStackTrace; None means "use status.shouldReportStackTrace".
  // This is a body field, not a constructor parameter, so `copy(...)` does not carry it over.
  private var _includeStackTrace: Option[Boolean] = None

  /**
    * Do not embed the stack trace and the cause object in the RPC error response of this exception.
    *
    * This mutates this instance and returns it, so it can be chained after construction.
    */
  def noStackTrace: RPCException = {
    _includeStackTrace = Some(false)
    this
  }

  /**
    * Whether the stack trace and the cause should be included in the serialized error. An explicit [[noStackTrace]]
    * call takes precedence over the default of the RPC status.
    */
  def shouldReportStackTrace: Boolean = _includeStackTrace.getOrElse(status.shouldReportStackTrace)

  /**
    * Convert this exception into the serializable [[RPCErrorMessage]] model. The stack trace and the cause are
    * included only when [[shouldReportStackTrace]] is true.
    */
  def toMessage: RPCErrorMessage = {
    val reportStackTrace = shouldReportStackTrace
    RPCErrorMessage(
      code = status.code,
      codeName = status.name,
      message = message,
      stackTrace = if (reportStackTrace) Some(GenericException.extractStackTrace(this)) else None,
      cause = if (reportStackTrace) cause else None,
      appErrorCode = appErrorCode,
      metadata = metadata
    )
  }

  /**
    * Serialize this exception (via [[toMessage]]) as JSON.
    */
  def toJson: Json = {
    rpcErrorMessageCodec.toJson(toMessage)
  }

  /**
    * Serialize this exception (via [[toMessage]]) as MessagePack.
    */
  def toMsgPack: MsgPack = {
    rpcErrorMessageCodec.toMsgPack(toMessage)
  }

  /**
    * Convert this exception to an HTTP response. The response has `status.httpStatus`, the RPC status code in the
    * `HttpHeader.xAirframeRPCStatus` header, and this error serialized as JSON in the body.
    *
    * Embedding the body is best-effort: if serialization fails with a non-fatal error, a warning is logged and the
    * response is returned without the JSON body.
    */
  def toResponse: HttpMessage.Response = {
    val baseResponse = Http
      .response(status.httpStatus)
      .addHeader(HttpHeader.xAirframeRPCStatus, status.code.toString)

    try {
      // Embed RPCError into the response body
      baseResponse.withJson(toJson)
    } catch {
      // NonFatal lets fatal errors (e.g., OutOfMemoryError) propagate instead of being logged and ignored
      case NonFatal(ex) =>
        warn(s"Failed to serialize RPCException: ${this}", ex)
        baseResponse
    }
  }
}

/**
  * A model class for the RPC error message body. This message is embedded in an HTTP response body or a gRPC
  * trailer.
  *
  * We need this class to avoid serializing/deserializing RPCException directly with airframe-codec, so that we can
  * properly propagate the exact stack trace to the client.
  *
  * NOTE(review): RPCException's companion derives its codec from this declaration
  * (`MessageCodec.of[RPCErrorMessage]`), so renaming, reordering, or retyping fields presumably changes the
  * serialized error format seen by clients — confirm wire-compatibility before editing.
  *
  * @param code
  *   numeric RPC status code (`RPCStatus.code`); used to restore the status on deserialization
  * @param codeName
  *   RPC status name (`RPCStatus.name`); informational only, not read back when rebuilding an RPCException
  * @param message
  *   error message
  * @param stackTrace
  *   stack trace of the original exception; set only when the exception reports stack traces
  * @param cause
  *   cause of the original exception; set only when the exception reports stack traces
  * @param appErrorCode
  *   optional application-specific status code
  * @param metadata
  *   optional application-specific metadata
  */
case class RPCErrorMessage(
    code: Int = RPCStatus.UNKNOWN_I1.code,
    codeName: String = RPCStatus.UNKNOWN_I1.name,
    message: String = "",
    stackTrace: Option[Seq[GenericStackTraceElement]] = None,
    cause: Option[Throwable] = None,
    appErrorCode: Option[Int] = None,
    metadata: Map[String, Any] = Map.empty
)

object RPCException {

  // Codec for the RPC error wire format; also used by RPCException.toJson / toMsgPack
  private val rpcErrorMessageCodec = MessageCodec.of[RPCErrorMessage]

  /**
    * Rebuild an RPCException from a deserialized [[RPCErrorMessage]], restoring the original stack trace when the
    * message carries one.
    */
  private def fromRPCErrorMessage(m: RPCErrorMessage): RPCException = {
    val ex = new RPCException(
      status = RPCStatus.ofCode(m.code),
      message = m.message,
      cause = m.cause,
      appErrorCode = m.appErrorCode,
      metadata = m.metadata
    )
    // Recover the original stack trace
    m.stackTrace.foreach { x =>
      ex.setStackTrace(x.map(_.toJavaStackTraceElement).toArray)
    }
    ex
  }

  /**
    * Deserialize an RPCException from its JSON representation (see `RPCException.toJson`).
    */
  def fromJson(json: String): RPCException = {
    val m = rpcErrorMessageCodec.fromJson(json)
    fromRPCErrorMessage(m)
  }

  /**
    * Deserialize an RPCException from its MessagePack representation (see `RPCException.toMsgPack`).
    */
  def fromMsgPack(msgpack: MsgPack): RPCException = {
    val m = rpcErrorMessageCodec.fromMsgPack(msgpack)
    fromRPCErrorMessage(m)
  }

  /**
    * Reconstruct an RPCException from an HTTP error response.
    *
    *   - If the RPC status header is missing or not an integer, a DATA_LOSS_I8 exception is returned.
    *   - If the body is empty, an exception for the header's status is returned.
    *   - Otherwise, the body is decoded as an [[RPCErrorMessage]]. If decoding fails, an exception wrapping the
    *     failure is returned instead.
    *
    * This method returns the exception; it does not throw for non-fatal decoding failures.
    */
  def fromResponse(response: HttpMessage.Response): RPCException = {
    val rpcStatusCode: Option[Int] =
      response
        .getHeader(HttpHeader.xAirframeRPCStatus)
        .flatMap(x => Try(x.toInt).toOption)

    rpcStatusCode match {
      case Some(rpcStatus) =>
        try {
          if (response.message.isEmpty) {
            val status = RPCStatus.ofCode(rpcStatus)
            status.newException(status.name)
          } else {
            // The body codec is only needed when there is a body to decode
            val responseBodyCodec = new HttpResponseBodyCodec[Response]
            val msgpack           = responseBodyCodec.toMsgPack(response)
            RPCException.fromMsgPack(msgpack)
          }
        } catch {
          case NonFatal(e) =>
            // The failure may have come from ofCode itself (e.g., an unrecognized code); resolving it again
            // unguarded would let that error escape this handler.
            val status = Try(RPCStatus.ofCode(rpcStatus)).getOrElse(RPCStatus.UNKNOWN_I1)
            status.newException(s"Failed to parse the RPC error details: ${errorDescription(e)}", e)
        }
      case None =>
        RPCStatus.DATA_LOSS_I8.newException(s"Invalid RPC response: ${response}")
    }
  }

  // Human-readable description of a failure; getMessage may be null, so fall back to the class name
  private def errorDescription(e: Throwable): String =
    Option(e.getMessage).getOrElse(e.getClass.getName)
}
