package com.liveperson.infra.network.http.requests

import com.liveperson.infra.Command
import com.liveperson.infra.ICallback
import com.liveperson.infra.log.LPLog
import com.liveperson.infra.model.JWKInfo
import com.liveperson.infra.otel.LPTelemetryAttributeKey
import com.liveperson.infra.otel.LPTelemetryManager
import com.liveperson.infra.otel.LPTraceType
import com.liveperson.infra.otel.models.OtlpAttribute
import com.liveperson.infra.otel.models.OtlpValueData
import com.nimbusds.jose.jwk.JWKSet
import java.net.URL

/**
 * [Command] that downloads the JSON Web Key Set served at `https://<domain>/well-known/jwks`
 * and reports its first key through [callback].
 *
 * Exactly one callback method is invoked per [execute] call:
 * - [ICallback.onError] with the thrown exception when the set cannot be fetched or parsed;
 * - [ICallback.onSuccess] with `null` when no set was obtained or it contains no keys;
 * - [ICallback.onSuccess] with a [JWKInfo] built from the first key's ID otherwise.
 *
 * The fetch is wrapped in a [LPTraceType.GET_JWKS_KEY_REQ] telemetry span. The span is ended on
 * success and cancelled on every other outcome.
 *
 * @param domain host substituted into the JWKS URL template.
 * @param callback receives the single outcome of the request.
 */
class JWKSRequest(val domain: String, val callback: ICallback<JWKInfo, Exception>) : Command {

    /**
     * Synchronously loads the JWK set via [JWKSet.load], bounded by the connect/read timeouts
     * and the size limit declared in the companion object, then dispatches the result to [callback].
     */
    override fun execute() {
        val jwksUrl = String.format(LP_JWKS_URL, domain)
        val otlpAttributes: MutableList<OtlpAttribute> = mutableListOf(
            OtlpAttribute(
                LPTelemetryAttributeKey.URL_FULL.value,
                OtlpValueData.StringValue(jwksUrl)
            ),
            OtlpAttribute(
                LPTelemetryAttributeKey.METHOD.value,
                OtlpValueData.StringValue("GET")
            )
        )
        val jwksSpan = LPTelemetryManager.begin(LPTraceType.GET_JWKS_KEY_REQ, otlpAttributes)

        val jwkSet: JWKSet? = try {
            JWKSet.load(URL(jwksUrl), CONNECT_TIMEOUT_MS, READ_TIMEOUT_MS, SIZE_LIMIT_BYTES)
        } catch (e: Exception) {
            LPLog.d(TAG, "Failed to load jwkSet from: $jwksUrl")
            jwksSpan?.cancel()
            callback.onError(e)
            // Stop here. Falling through would also cancel the span again and call
            // onSuccess(null), notifying the caller twice with contradictory results.
            return
        }

        if (jwkSet == null || jwkSet.size() <= 0) {
            LPLog.d(TAG, "jwkSet is blank")
            jwksSpan?.cancel()
            callback.onSuccess(null)
            return
        }

        // Only the first key of the set is used; its key ID is looked up again to obtain the JWK.
        // NOTE(review): if that first key has no `kid`, the lookup presumably yields a null jwk;
        // confirm JWKInfo tolerates that.
        val kId = jwkSet.keys[0].keyID
        val jwk = jwkSet.getKeyByKeyId(kId)
        jwksSpan?.end()
        callback.onSuccess(JWKInfo(jwk, kId))
    }

    companion object {
        private const val TAG = "JWKSRequest"

        // NOTE(review): RFC 8615 well-known URIs use the "/.well-known/" prefix (with a dot);
        // confirm the server really serves "/well-known/jwks".
        private const val LP_JWKS_URL = "https://%s/well-known/jwks"

        // HTTP connect timeout in milliseconds
        private const val CONNECT_TIMEOUT_MS = 30000

        // HTTP read timeout in milliseconds
        private const val READ_TIMEOUT_MS = 30000

        // JWK set size limit, in bytes
        private const val SIZE_LIMIT_BYTES = 300000
    }
}