package vn.kalapa.ekyc.capturesdk.tflite

import android.app.Activity
import android.graphics.Bitmap
import android.graphics.PointF
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import vn.kalapa.ekyc.managers.KLPCardModelManager
import vn.kalapa.ekyc.utils.Helpers
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStreamReader
import kotlin.math.sqrt

/**
 * Centralized configuration for [KLPDetector] thresholds and settings.
 *
 * Values are validated on construction so that a mistyped threshold (e.g. `65f` instead of
 * `0.65f`) fails fast instead of silently disabling detection.
 *
 * @property inputMean mean subtracted from every pixel channel before inference (passed to [NormalizeOp]).
 * @property inputStd divisor applied after the mean is subtracted; must be non-zero.
 * @property confidenceThreshold minimum class score for a raw model candidate to be kept at all.
 * @property cardConfidence minimum score for a "card" box to be counted.
 * @property cornerConfidence minimum score for a non-"card" (corner) box to be counted.
 * @property iouThreshold IoU above which a lower-scored overlapping box is suppressed by NMS.
 * @property movementThreshold largest allowed shift of the tracked box center between two consecutive
 *   frames, in the same units as the box coordinates (the detector keeps only boxes inside 0..1).
 * @property stableDurationMs how long, in milliseconds, the card must stay still before auto-capture fires.
 * @property numThreads CPU threads for the interpreter when the GPU delegate is not used
 *   (not validated: TFLite treats -1 as "use the default").
 */
data class DetectorConfig(
    val inputMean: Float = 0f,
    val inputStd: Float = 255f,
    val confidenceThreshold: Float = 0.3f,
    val cardConfidence: Float = 0.65f,
    val cornerConfidence: Float = 0.35f,
    val iouThreshold: Float = 0.5f,
    val movementThreshold: Float = 0.02f,
    val stableDurationMs: Long = 500L,
    val numThreads: Int = 4
) {
    init {
        require(inputStd != 0f) { "inputStd must be non-zero" }
        require(confidenceThreshold in 0f..1f) { "confidenceThreshold must be in [0, 1], was $confidenceThreshold" }
        require(cardConfidence in 0f..1f) { "cardConfidence must be in [0, 1], was $cardConfidence" }
        require(cornerConfidence in 0f..1f) { "cornerConfidence must be in [0, 1], was $cornerConfidence" }
        require(iouThreshold in 0f..1f) { "iouThreshold must be in [0, 1], was $iouThreshold" }
        require(movementThreshold >= 0f) { "movementThreshold must be >= 0, was $movementThreshold" }
        require(stableDurationMs >= 0L) { "stableDurationMs must be >= 0, was $stableDurationMs" }
    }
}

/**
 * Decides when a detected card has been held still long enough to auto-capture.
 *
 * Every frame passed to [handle] is first reported as [CardDetectStatus.IN_MASK]. The first frame of a
 * "still run" only records the start time and the box. Each later frame compares the box center with
 * the previous frame's. A shift larger than [DetectorConfig.movementThreshold] reports
 * [CardDetectStatus.UNSTABLE] and ends the run. Once a run has lasted at least
 * [DetectorConfig.stableDurationMs] while auto-capture is on, `onImageDetected()` is fired.
 * The tracker fires it at most once in its lifetime, because [reset] does not clear that flag.
 */
private class CaptureStateTracker(private val config: DetectorConfig) {
    /** Uptime (ms) at which the current still run began; 0 means no run is in progress. */
    private var runStartedAt: Long = 0L

    /** Box seen on the previous frame of the current run. */
    private var lastBox: BoundingBox? = null

    /**
     * Set once `onImageDetected()` has fired and never cleared.
     * NOTE(review): [reset] (also used by `KLPDetector.restart`) keeps this flag — confirm that
     * auto-capture is really meant to fire only once per tracker.
     */
    private var hasFired: Boolean = false

    /** Abandons the current still run; the next [handle] call starts a new one. */
    fun reset() {
        runStartedAt = 0L
        lastBox = null
    }

    /**
     * Feeds one in-mask frame into the tracker.
     *
     * @param corner box whose center is tracked for movement.
     * @param isAutoCaptureOn when false, stability is still tracked but nothing is fired.
     * @param listener receives IN_MASK / UNSTABLE statuses and the one-shot detection callback.
     */
    fun handle(
        corner: BoundingBox,
        isAutoCaptureOn: Boolean,
        listener: OnImageDetectedListener
    ) {
        listener.onImageProcessing(CardDetectStatus.IN_MASK)
        val now = SystemClock.uptimeMillis()

        if (runStartedAt == 0L) {
            // First frame of a run: remember when and where it started, nothing to compare yet.
            runStartedAt = now
            lastBox = corner
            return
        }

        val reference = lastBox ?: corner
        val moved = centerDistance(reference, corner) > config.movementThreshold
        when {
            moved -> {
                listener.onImageProcessing(CardDetectStatus.UNSTABLE)
                reset()
            }
            isAutoCaptureOn && !hasFired && now - runStartedAt >= config.stableDurationMs -> {
                listener.onImageDetected()
                hasFired = true
            }
        }
        lastBox = corner
    }

    /** Euclidean distance between the centers of [a] and [b]. */
    private fun centerDistance(a: BoundingBox, b: BoundingBox): Float {
        val aCenterX = (a.x1 + a.x2) / 2
        val aCenterY = (a.y1 + a.y2) / 2
        val deltaX = (b.x1 + b.x2) / 2 - aCenterX
        val deltaY = (b.y1 + b.y2) / 2 - aCenterY
        return sqrt(deltaX * deltaX + deltaY * deltaY)
    }
}

/**
 * Runs the TFLite card/corner detection model on camera frames and reports the card's state
 * to [onImageListener].
 *
 * Labels are read from the `labelPath` asset and the model bytes come from [KLPCardModelManager],
 * both eagerly during construction. [detect], [getCurrentBoundingBoxes] and [restart] are
 * synchronized on this instance. Interpreter calls are additionally guarded by a private lock
 * shared with [close].
 *
 * @param activity used to open the label asset.
 * @param modelName currently not read by this class (the model bytes come from [KLPCardModelManager]).
 * @param labelPath asset path of the label file, one label per line, indexed by model class id.
 * @param isAutoCaptureOn when true, a stable card triggers `onImageDetected()`.
 * @param onImageListener receives setup results and per-frame statuses.
 * @param config thresholds and interpreter settings.
 */
class KLPDetector(
    private val activity: Activity,
    private val modelName: String,
    labelPath: String,
    var isAutoCaptureOn: Boolean = true,
    private val onImageListener: OnImageDetectedListener,
    private val config: DetectorConfig = DetectorConfig()
) : KLPDetectorListener {

    private lateinit var interpreter: Interpreter
    private var gpuDelegate: GpuDelegate? = null
    private val labels = mutableListOf<String>()
    private var tensorWidth = 0
    private var tensorHeight = 0
    private var numChannels = 0
    private var numElements = 0

    // Written by close()/restart() and read by detect(), which may run on different threads.
    @Volatile
    private var isModelInitialized = false

    private val tensorImage = TensorImage(DataType.FLOAT32)
    private lateinit var outputBuffer: TensorBuffer
    private val imageProcessor = ImageProcessor.Builder()
        .add(NormalizeOp(config.inputMean, config.inputStd))
        .add(CastOp(DataType.FLOAT32))
        .build()

    private val stateTracker = CaptureStateTracker(config)
    private val interpreterLock = Any()

    init {
        loadLabels(labelPath)
        initializeModel()
    }

    /** Appends every line of the asset at [path] to [labels]; an unreadable asset leaves it empty. */
    private fun loadLabels(path: String) {
        try {
            activity.assets.open(path).use { stream ->
                BufferedReader(InputStreamReader(stream)).use { reader ->
                    reader.lineSequence().forEach { labels.add(it) }
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
        }
    }

    /**
     * Builds the interpreter (GPU delegate when supported, CPU threads otherwise), reads the tensor
     * shapes and reports the outcome through `onSetupModelComplete`.
     */
    private fun initializeModel() {
        val modelBytes = KLPCardModelManager.getModel() ?: run {
            onImageListener.onSetupModelComplete(false)
            return
        }
        try {
            val options = Interpreter.Options().apply {
                val compat = CompatibilityList()
                if (compat.isDelegateSupportedOnThisDevice) {
                    val delegate = GpuDelegate(compat.bestOptionsForThisDevice)
                    gpuDelegate = delegate
                    addDelegate(delegate)
                } else {
                    setNumThreads(config.numThreads)
                }
            }
            interpreter = try {
                Interpreter(modelBytes, options)
            } catch (e: Exception) {
                // The configured options (typically the GPU delegate) were rejected. Free the unused
                // delegate before falling back to a default interpreter, otherwise it leaks.
                Helpers.printLog("Falling back to default interpreter: ${e.message}")
                releaseGpuDelegate()
                Interpreter(modelBytes)
            }

            // TensorImage yields an H x W x 3 buffer, i.e. NHWC input [1, height, width, 3]. For a
            // channels-first model [1, 3, height, width] the spatial dims sit one index later.
            val inShape = interpreter.getInputTensor(0).shape()
            val channelsFirst = inShape[1] == 3
            tensorHeight = if (channelsFirst) inShape[2] else inShape[1]
            tensorWidth = if (channelsFirst) inShape[3] else inShape[2]

            val outShape = interpreter.getOutputTensor(0).shape()
            numChannels = outShape[1]
            numElements = outShape[2]
            outputBuffer = TensorBuffer.createFixedSize(intArrayOf(1, numChannels, numElements), DataType.FLOAT32)

            isModelInitialized = true
            onImageListener.onSetupModelComplete(true)
        } catch (e: Exception) {
            Helpers.printLog("Error initializing interpreter: ${e.message}")
            onImageListener.onSetupModelComplete(false)
        }
    }

    /** Closes and frees the GPU delegate, if one was created. */
    private fun releaseGpuDelegate() {
        gpuDelegate?.close()
        gpuDelegate = null
    }

    /**
     * Tears down the current model, clears the tracker's current still run and loads the model again.
     * Synchronized with [detect] so a frame never runs against a half-replaced interpreter.
     */
    @Synchronized
    fun restart() {
        close()
        stateTracker.reset()
        initializeModel()
    }

    /**
     * Releases the interpreter and then the GPU delegate. Afterwards [detect] completes frames
     * without running inference until [restart] is called. Errors are logged, not thrown.
     */
    fun close() {
        isModelInitialized = false
        try {
            synchronized(interpreterLock) {
                if (::interpreter.isInitialized) interpreter.close()
            }
            // The delegate must outlive the interpreter that uses it, so it is released second.
            releaseGpuDelegate()
        } catch (e: Exception) {
            Helpers.printLog("Error closing detector: ${e.message}")
        }
    }

    /**
     * Runs the model on [frame] and reports the result to [onImageListener].
     *
     * [doneProcessed] is invoked exactly once per call: immediately when no model is loaded, otherwise
     * once the frame has been handled (including on failure). [frame] itself is never recycled.
     */
    @Synchronized
    fun detect(frame: Bitmap, doneProcessed: () -> Unit) {
        if (!isModelInitialized) {
            doneProcessed()
            return
        }
        // onDetect() also calls the completion callback it is given, so guard against a double call.
        var finished = false
        val finishOnce: () -> Unit = {
            if (!finished) {
                finished = true
                doneProcessed()
            }
        }

        val start = SystemClock.uptimeMillis()
        val resized = Bitmap.createScaledBitmap(frame, tensorWidth, tensorHeight, false)
        try {
            tensorImage.load(resized)
            val processed = imageProcessor.process(tensorImage)
            synchronized(interpreterLock) {
                interpreter.run(processed.buffer, outputBuffer.buffer)
            }
            val boxes = parseOutput(outputBuffer.floatArray)
            val inferenceTime = SystemClock.uptimeMillis() - start
            if (boxes.isNullOrEmpty()) onEmptyDetect()
            else onDetect(resized.width, resized.height, boxes, inferenceTime, finishOnce)
        } catch (e: Exception) {
            Helpers.printLog("Error running interpreter: ${e.message}")
        } finally {
            // createScaledBitmap returns the source bitmap itself when no scaling is needed;
            // recycling it would destroy the caller's frame.
            if (resized !== frame) resized.recycle()
            finishOnce()
        }
    }

    /**
     * Runs the model on [frame] and returns its corner boxes (score above `cornerConfidence`) when
     * exactly four are found. Returns null otherwise, or when the model is not loaded or inference
     * fails. Does not touch the capture tracker or the listener.
     */
    @Synchronized
    fun getCurrentBoundingBoxes(frame: Bitmap): List<BoundingBox>? {
        if (!isModelInitialized) return null
        var resized: Bitmap? = null
        return try {
            val scaled = Bitmap.createScaledBitmap(frame, tensorWidth, tensorHeight, false)
            resized = scaled
            val input = imageProcessor.process(TensorImage(DataType.FLOAT32).apply { load(scaled) })
            val out = TensorBuffer.createFixedSize(intArrayOf(1, numChannels, numElements), DataType.FLOAT32)
            synchronized(interpreterLock) { interpreter.run(input.buffer, out.buffer) }
            val corners = parseOutput(out.floatArray)
                ?.filter { it.clsName == "corner" && it.cnf > config.cornerConfidence }
                .orEmpty()
            if (corners.size == 4) corners else null
        } catch (e: Exception) {
            Helpers.printLog("Error in getCurrentBoundingBoxes: ${e.message}")
            null
        } finally {
            // Free the scaled copy, but never the caller's own bitmap (see detect()).
            resized?.takeIf { it !== frame }?.recycle()
        }
    }

    /**
     * Decodes the raw output into boxes, then applies NMS. Returns null when no candidate qualifies.
     *
     * The indexing treats [data] as channel-major `[numChannels][numElements]`. Rows 0..3 are the
     * candidate's center x, center y, width and height; rows 4.. are per-class scores.
     * NOTE(review): coordinates are presumably normalized, since boxes outside 0..1 are dropped —
     * confirm against the exported model.
     */
    private fun parseOutput(data: FloatArray): List<BoundingBox>? {
        val boxes = mutableListOf<BoundingBox>()
        for (i in 0 until numElements) {
            var bestConf = config.confidenceThreshold
            var bestIdx = -1
            var offset = i + numElements * 4
            for (c in 4 until numChannels) {
                val conf = data[offset]
                if (conf > bestConf) {
                    bestConf = conf
                    bestIdx = c - 4
                }
                offset += numElements
            }
            if (bestIdx < 0) continue

            val sx = data[i]
            val sy = data[i + numElements]
            val sw = data[i + numElements * 2]
            val sh = data[i + numElements * 3]
            val x1 = sx - sw / 2
            val y1 = sy - sh / 2
            val x2 = sx + sw / 2
            val y2 = sy + sh / 2
            if (x1 in 0f..1f && y1 in 0f..1f && x2 in 0f..1f && y2 in 0f..1f) {
                boxes.add(BoundingBox(x1, y1, x2, y2, sx, sy, sw, sh, bestConf, bestIdx, labels[bestIdx]))
            }
        }
        return boxes.takeIf { it.isNotEmpty() }?.let { applyNMS(it) }
    }

    /** Greedy non-maximum suppression: keeps boxes by descending score, dropping overlaps above `iouThreshold`. */
    private fun applyNMS(boxes: List<BoundingBox>): MutableList<BoundingBox> {
        val sorted = boxes.sortedByDescending { it.cnf }.toMutableList()
        val result = mutableListOf<BoundingBox>()
        while (sorted.isNotEmpty()) {
            val top = sorted.removeAt(0)
            result.add(top)
            val iter = sorted.iterator()
            while (iter.hasNext()) {
                if (iou(top, iter.next()) > config.iouThreshold) iter.remove()
            }
        }
        return result
    }

    /** Intersection-over-union of two boxes; 0 when the union area is not positive (degenerate boxes). */
    private fun iou(b1: BoundingBox, b2: BoundingBox): Float {
        val interW = maxOf(0f, minOf(b1.x2, b2.x2) - maxOf(b1.x1, b2.x1))
        val interH = maxOf(0f, minOf(b1.y2, b2.y2) - maxOf(b1.y1, b2.y1))
        val inter = interW * interH
        val union = b1.w * b1.h + b2.w * b2.h - inter
        return if (union > 0f) inter / union else 0f
    }

    override fun onEmptyDetect() {
        onImageListener.onImageNotDetected()
    }

    /**
     * Classifies the frame's boxes into one card plus corners. It checks that their union lies
     * inside the allowed vertical band, then either feeds the tracker or reports why not.
     * Invokes [doneProcessed] at the end, unless the model has been closed meanwhile.
     */
    override fun onDetect(
        frameWidth: Int,
        frameHeight: Int,
        boundingBoxes: List<BoundingBox>,
        inferenceTime: Long,
        doneProcessed: () -> Unit
    ) {
        if (!isModelInitialized) return
        // Allowed vertical band, as fractions of the frame height, centered on a region of height
        // frameWidth * 0.75. NOTE(review): presumably a 4:3 visible preview area — confirm with the UI.
        val actualH = frameWidth * 0.75f
        val offsetRatio = ((frameHeight - actualH) / 2) / frameHeight
        val offsetBottom = 1 - offsetRatio

        val cards = mutableListOf<BoundingBox>()
        val corners = mutableListOf<BoundingBox>()
        // Union of all accepted boxes. Seed with the extremes so the first box always wins.
        // (Float.MIN_VALUE is the smallest *positive* float, so it cannot seed a maximum.)
        val topLeft = PointF(Float.MAX_VALUE, Float.MAX_VALUE)
        val bottomRight = PointF(-Float.MAX_VALUE, -Float.MAX_VALUE)

        for (b in boundingBoxes) {
            val isCard = b.clsName == "card"
            val minConf = if (isCard) config.cardConfidence else config.cornerConfidence
            if (b.cnf < minConf) continue
            if (isCard) cards.add(b) else corners.add(b)
            topLeft.x = minOf(topLeft.x, b.x1)
            topLeft.y = minOf(topLeft.y, b.y1)
            bottomRight.x = maxOf(bottomRight.x, b.x2)
            bottomRight.y = maxOf(bottomRight.y, b.y2)
        }

        if (cards.size == 1 && corners.size in 4..5) {
            if (topLeft.y > offsetRatio && bottomRight.y < offsetBottom) {
                val centerX = (topLeft.x + bottomRight.x) / 2
                val centerY = (topLeft.y + bottomRight.y) / 2
                // Track the zero-size center of the union box instead of an individual corner box.
                val centerBox = BoundingBox(centerX, centerY, centerX, centerY, centerX, centerY, 0f, 0f, 1f, cards[0].cls, cards[0].clsName)
                Helpers.printLog("corner[0] ${corners[0]} \n centerBox $centerBox")
                stateTracker.handle(centerBox, isAutoCaptureOn, onImageListener)
            } else {
                stateTracker.reset()
                // Exactly one card is known to be present in this branch.
                onImageListener.onImageProcessing(CardDetectStatus.CORNER_NOT_REVEALED)
            }
        } else {
            stateTracker.reset()
            if (cards.isNotEmpty()) onImageListener.onImageProcessing(CardDetectStatus.OUT_OF_MASK)
        }
        doneProcessed()
    }
}

/**
 * Per-frame outcome callbacks that [KLPDetector] implements for its own post-processing of
 * model results.
 */
interface KLPDetectorListener {
    /** Invoked when a frame yielded no bounding box at all after parsing and non-maximum suppression. */
    fun onEmptyDetect()

    /**
     * Invoked with the boxes that survived parsing and non-maximum suppression for one frame.
     *
     * @param frameWidth width in pixels of the (resized) image the model actually saw.
     * @param frameHeight height in pixels of that image.
     * @param boundingBoxes surviving boxes.
     * @param inferenceTime elapsed uptime milliseconds from resizing the frame to parsing the output.
     * @param doneProcessed completion callback for the frame.
     */
    fun onDetect(
        frameWidth: Int, frameHeight: Int,
        boundingBoxes: List<BoundingBox>,
        inferenceTime: Long,
        doneProcessed: () -> Unit
    )
}
