package com.deque.networking.utils

import okhttp3.Interceptor
import okhttp3.Interceptor.Chain
import okhttp3.MediaType
import okhttp3.RequestBody
import okhttp3.Response
import okhttp3.ResponseBody
import okio.Buffer
import okio.BufferedSink
import okio.GzipSink
import okio.GzipSource
import okio.buffer

/**
 * OkHttp interceptor that applies gzip compression to traffic for the axe result endpoint.
 *
 * Only HTTPS requests whose URL contains `/attest/result/axe` are touched:
 * - `GET`: requests a gzip-encoded response and decompresses it before returning it.
 * - `POST`: gzip-compresses the request body and sets `Content-Encoding: gzip`.
 *
 * Every other request is forwarded to the chain unchanged.
 */
class GzipRequestInterceptor : Interceptor {
    override fun intercept(chain: Chain): Response {
        val originalRequest = chain.request()
        val isAxeResultEndpoint = originalRequest.url().toString().contains(AXE_RESULT_PATH)

        if (!isAxeResultEndpoint || !originalRequest.isHttps) return chain.proceed(originalRequest)

        return when (originalRequest.method()) {
            "GET" -> handleGetResult(chain)
            "POST" -> handlePostResult(chain)
            else -> chain.proceed(originalRequest)
        }
    }

    /**
     * Proceeds with the request after asking for a gzip-encoded response, and returns the
     * response decompressed when the server actually answered with gzip.
     */
    private fun handleGetResult(chain: Chain): Response {
        val gzipRequest = chain.request().newBuilder()
            // header() replaces any existing value, so Accept-Encoding is never sent twice.
            .header(ACCEPT_ENCODING, GZIP)
            .build()
        val response = chain.proceed(gzipRequest)

        return if (isGzipped(response)) unzip(response) else response
    }

    /**
     * Proceeds with a copy of the request whose body is gzip-compressed and whose
     * `Content-Length` reflects the compressed size.
     *
     * The request is forwarded untouched when it has no body, or when it already declares
     * a `Content-Encoding` (compressing again would produce a doubly-encoded payload).
     */
    private fun handlePostResult(chain: Chain): Response {
        val request = chain.request()
        val body = request.body() ?: return chain.proceed(request)
        if (request.header(CONTENT_ENCODING) != null) return chain.proceed(request)

        val compressedRequest = request.newBuilder()
            .header(CONTENT_ENCODING, GZIP)
            .post(GzipRequestBody(body).withContentSize())
            .build()
        return chain.proceed(compressedRequest)
    }

    /**
     * Returns a copy of [response] whose body is the gzip-decompressed payload and whose
     * `Content-Encoding` / `Content-Length` headers are removed. A response without a body
     * is returned as-is.
     */
    private fun unzip(response: Response): Response {
        val body = response.body() ?: return response

        // Read raw bytes rather than a UTF-8 string so payloads in other charsets (or binary
        // payloads) survive unchanged; `use` closes the gzip source even if reading fails.
        val decompressed = GzipSource(body.source()).buffer().use { it.readByteArray() }
        val responseBody = ResponseBody.create(body.contentType(), decompressed)

        // The encoding and length headers describe the compressed payload and no longer apply.
        val strippedHeaders = response.headers().newBuilder()
            .removeAll(CONTENT_ENCODING)
            .removeAll(CONTENT_LENGTH)
            .build()

        return response.newBuilder()
            .headers(strippedHeaders)
            .body(responseBody)
            .build()
    }

    /** Returns true when the response declares a `gzip` content encoding (case-insensitive). */
    private fun isGzipped(response: Response): Boolean =
        GZIP.equals(response.header(CONTENT_ENCODING), ignoreCase = true)

    private companion object {
        const val AXE_RESULT_PATH = "/attest/result/axe"
        const val ACCEPT_ENCODING = "Accept-Encoding"
        const val CONTENT_ENCODING = "Content-Encoding"
        const val CONTENT_LENGTH = "Content-Length"
        const val GZIP = "gzip"
    }
}

/**
 * [RequestBody] wrapper that gzip-compresses the wrapped body while it is being written.
 *
 * The compressed size is not known before writing, so [contentLength] reports `-1`.
 * Use [withContentSize] to get a body that reports its exact compressed length.
 */
class GzipRequestBody(private val requestBody: RequestBody) : RequestBody() {
    override fun contentType(): MediaType? = requestBody.contentType()

    /** Always `-1`: the compressed length is unknown until the body has been written. */
    override fun contentLength(): Long = -1

    /** Writes the wrapped body to [sink] through a gzip sink, closing the gzip sink afterwards. */
    override fun writeTo(sink: BufferedSink) {
        // `use` closes the gzip sink even when the wrapped body throws while writing,
        // so the compressor is not left open on the failure path.
        GzipSink(sink).buffer().use { gzipSink -> requestBody.writeTo(gzipSink) }
    }

    /**
     * Compresses the wrapped body into memory right away and returns a [RequestBody] that
     * serves those bytes, with [RequestBody.contentLength] equal to the compressed size.
     */
    fun withContentSize(): RequestBody {
        val buffer = Buffer()
        writeTo(buffer)

        return object : RequestBody() {
            override fun contentType(): MediaType? = requestBody.contentType()

            override fun contentLength(): Long = buffer.size

            override fun writeTo(sink: BufferedSink) {
                // Write a snapshot so the buffer is not drained and the body can be written again.
                sink.write(buffer.snapshot())
            }
        }
    }
}