package vn.kalapa.ekyc.capturesdk.tflite

import android.app.Activity
import android.graphics.Bitmap
import android.graphics.PointF
import android.os.SystemClock
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import vn.kalapa.ekyc.managers.KLPCardModelManager
import vn.kalapa.ekyc.utils.Helpers
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStreamReader
import kotlin.math.sqrt

/**
 * Detects an ID card and its corners in camera frames with a TensorFlow Lite model and drives the
 * auto-capture flow through [OnImageDetectedListener].
 *
 * Lifecycle: the constructor loads the labels and tries to build the interpreter from
 * [KLPCardModelManager.getModel]; the outcome is reported via
 * [OnImageDetectedListener.onSetupModelComplete]. [restart] re-runs that setup and [close]
 * releases the interpreter and GPU delegate.
 *
 * Threading: [detect] and [getCurrentBoundingBoxes] are `@Synchronized` on this instance, and every
 * interpreter run, replacement and close additionally goes through `interpreterLock`, so [close]
 * cannot release the interpreter in the middle of an inference.
 *
 * @param activity used only to open [labelPath] from the app assets.
 * @param modelName currently unused; the model comes from [KLPCardModelManager].
 * @param labelPath asset path of a newline-separated label file; line `i` names class index `i`.
 * @param isAutoCaptureOn when true, [OnImageDetectedListener.onImageDetected] fires once a card has
 *   stayed inside the mask for [VERIFY_CAPTURING_DURATION] ms, after which frames are ignored until
 *   [restart] (or until this flag is turned off).
 * @param onImageListener receives setup and per-frame detection callbacks.
 */
class KLPDetector(
    private val activity: Activity,
    private val modelName: String,
    labelPath: String,
    var isAutoCaptureOn: Boolean = true,
    private val onImageListener: OnImageDetectedListener,
) : KLPDetectorListener {

    /** Guards every use of [interpreter] so [close] never races an in-flight `run`. */
    private val interpreterLock = Any()
    private lateinit var interpreter: Interpreter
    private var gpuDelegate: GpuDelegate? = null
    private var isModelInitialized = false

    /** Class names indexed by the model's class id (line order of the label file). */
    private val labels = mutableListOf<String>()

    // Model geometry read from the interpreter's tensors; 0 means "not known yet".
    private var tensorWidth = 0
    private var tensorHeight = 0
    private var numChannel = 0
    private var numElements = 0

    // Auto-capture state machine (see onDetect).
    private var conditionStartTime: Long = 0L
    private var detected = false
    private var previousBoundingBox: BoundingBox? = null

    // Buffers reused across detect() calls; safe because detect() is @Synchronized.
    private val tensorImage: TensorImage = TensorImage(INPUT_IMAGE_TYPE)
    private lateinit var outputBuffer: TensorBuffer

    // (pixel - 0) / 255, then cast to the model's input type.
    private val imageProcessor = ImageProcessor.Builder()
        .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
        .add(CastOp(INPUT_IMAGE_TYPE))
        .build()

    init {
        loadLabels(labelPath)
        initializeModel()
    }

    /**
     * Builds the interpreter from [KLPCardModelManager.getModel], reads the tensor geometry and
     * reports the result through [OnImageDetectedListener.onSetupModelComplete].
     *
     * A GPU delegate is used when [CompatibilityList] reports support; otherwise the CPU path with
     * [CPU_THREADS] threads. Any failure is logged, partially created resources are released and
     * `false` is reported.
     */
    private fun initializeModel() {
        val cardModel = KLPCardModelManager.getModel()
        if (cardModel == null) {
            // No model available yet; restart() asks the manager again.
            onImageListener.onSetupModelComplete(false)
            return
        }
        val readyToScan = try {
            val compatList = CompatibilityList()
            val options = Interpreter.Options()
            if (compatList.isDelegateSupportedOnThisDevice) {
                val delegate = GpuDelegate(compatList.bestOptionsForThisDevice)
                gpuDelegate = delegate
                options.addDelegate(delegate)
            } else {
                options.setNumThreads(CPU_THREADS)
            }
            // The compatibility list can claim GPU support that the interpreter then rejects, so
            // creation with the configured options is tried first, then a CPU-only interpreter.
            val newInterpreter = try {
                Interpreter(cardModel, options)
            } catch (e: Exception) {
                Helpers.printLog("GPU interpreter creation failed, falling back to CPU: ${e.message}")
                releaseGpuDelegate()
                Interpreter(cardModel, Interpreter.Options().apply { setNumThreads(CPU_THREADS) })
            }
            synchronized(interpreterLock) { interpreter = newInterpreter }
            readTensorShapes(newInterpreter)
            synchronized(interpreterLock) { isModelInitialized = true }
            true
        } catch (e: Exception) {
            Helpers.printLog("Error initializing interpreter: ${e.message}")
            // Release whatever was created before the failure (interpreter and/or delegate).
            close()
            false
        }
        onImageListener.onSetupModelComplete(readyToScan)
    }

    /**
     * Copies the input/output tensor dimensions into the geometry fields and sizes [outputBuffer].
     *
     * Input is assumed to be NHWC `[1, H, W, C]`, or channels-first `[1, 3, H, W]` when dimension 1
     * equals 3. Output is assumed to be `[1, numChannel, numElements]`. TODO confirm against the model.
     */
    private fun readTensorShapes(model: Interpreter) {
        model.getInputTensor(0)?.shape()?.let { shape ->
            val channelsFirst = shape[1] == 3
            tensorHeight = if (channelsFirst) shape[2] else shape[1]
            tensorWidth = if (channelsFirst) shape[3] else shape[2]
        }
        model.getOutputTensor(0)?.shape()?.let { shape ->
            numChannel = shape[1]
            numElements = shape[2]
            outputBuffer = TensorBuffer.createFixedSize(
                intArrayOf(1, numChannel, numElements),
                OUTPUT_IMAGE_TYPE,
            )
        }
    }

    /**
     * Appends every line of the asset [labelPath] to [labels]. Read errors are logged rather than
     * thrown. Blank lines are kept because line position is the class index.
     */
    private fun loadLabels(labelPath: String) {
        try {
            activity.assets.open(labelPath).bufferedReader().useLines { lines -> labels.addAll(lines) }
        } catch (e: IOException) {
            Helpers.printLog("Error loading labels from $labelPath: ${e.message}")
        }
    }

    /**
     * Releases the current interpreter, clears the auto-capture state and builds a new interpreter.
     *
     * @param isGpu currently ignored: GPU use is decided by [CompatibilityList] during setup.
     */
    @Suppress("UNUSED_PARAMETER")
    fun restart(isGpu: Boolean) {
        close()
        detected = false
        conditionStartTime = 0L
        previousBoundingBox = null
        initializeModel()
    }

    /**
     * Releases the interpreter and then the GPU delegate. Safe to call more than once.
     *
     * Afterwards [detect] only invokes its completion callback until [restart] succeeds.
     */
    fun close() {
        synchronized(interpreterLock) {
            isModelInitialized = false
            if (::interpreter.isInitialized) {
                try {
                    interpreter.close()
                } catch (e: Exception) {
                    // Already closed.
                }
            }
        }
        releaseGpuDelegate()
    }

    /**
     * Closes and forgets the GPU delegate, if any.
     *
     * Callers close the interpreter that used the delegate first.
     */
    private fun releaseGpuDelegate() {
        try {
            gpuDelegate?.close()
        } catch (e: Exception) {
            // Already closed.
        }
        gpuDelegate = null
    }

    /** True once the input and output tensor dimensions have been read from the model. */
    private fun hasTensorShapes(): Boolean =
        tensorWidth > 0 && tensorHeight > 0 && numChannel > 0 && numElements > 0

    /**
     * Runs the model on [frame] and forwards the result to [onEmptyDetect] or [onDetect].
     *
     * [doneProcessed] is invoked exactly once per call. That includes the cases where the model is
     * not ready or inference fails, so a caller waiting on it is never left hanging.
     *
     * The caller keeps ownership of [frame]; only a scaled copy created here is recycled.
     */
    @Synchronized
    fun detect(frame: Bitmap, doneProcessed: () -> Unit) {
        var completed = false
        val completeOnce: () -> Unit = {
            if (!completed) {
                completed = true
                doneProcessed()
            }
        }
        if (!isModelInitialized || !hasTensorShapes()) {
            completeOnce()
            return
        }

        val startTime = SystemClock.uptimeMillis()
        var resizedBitmap: Bitmap? = null
        try {
            val resized = Bitmap.createScaledBitmap(frame, tensorWidth, tensorHeight, false)
            resizedBitmap = resized
            tensorImage.load(resized)
            val processedImage = imageProcessor.process(tensorImage)
            val hasRun = synchronized(interpreterLock) {
                // close() may have run since the check above.
                if (isModelInitialized) {
                    interpreter.run(processedImage.buffer, outputBuffer.buffer)
                    true
                } else {
                    false
                }
            }
            if (!hasRun) return

            val bestBoxes = bestBox(outputBuffer.floatArray)
            val inferenceTime = SystemClock.uptimeMillis() - startTime
            if (bestBoxes == null) {
                onEmptyDetect()
            } else {
                onDetect(resized.width, resized.height, bestBoxes, inferenceTime, completeOnce)
            }
        } catch (e: Exception) {
            Helpers.printLog("Error running interpreter: ${e.message}")
        } finally {
            // createScaledBitmap returns `frame` itself when it already has the target size.
            // That bitmap belongs to the caller and must not be recycled here.
            resizedBitmap?.takeIf { it !== frame }?.recycle()
            completeOnce()
        }
    }

    /**
     * Runs the model synchronously on [frame] and returns its corner detections.
     *
     * @return exactly four `corner` boxes whose confidence exceeds [CORNER_CONF]. Returns `null` if
     *   the model is not ready, inference fails, or a different number of corners is found.
     */
    @Synchronized
    fun getCurrentBoundingBoxes(frame: Bitmap): List<BoundingBox>? {
        if (!isModelInitialized || !hasTensorShapes()) {
            Helpers.printLog("Tensor dimensions are not initialized.")
            return null
        }
        var resizedBitmap: Bitmap? = null
        return try {
            val resized = Bitmap.createScaledBitmap(frame, tensorWidth, tensorHeight, false)
            resizedBitmap = resized
            // Separate input/output buffers from the ones detect() reuses.
            val processedImage = imageProcessor.process(TensorImage(INPUT_IMAGE_TYPE).apply { load(resized) })
            val output = TensorBuffer.createFixedSize(
                intArrayOf(1, numChannel, numElements),
                OUTPUT_IMAGE_TYPE,
            )
            synchronized(interpreterLock) {
                if (!isModelInitialized) return null
                interpreter.run(processedImage.buffer, output.buffer)
            }
            bestBox(output.floatArray).orEmpty()
                .filter { it.clsName == CORNER_LABEL && it.cnf > CORNER_CONF }
                .takeIf { it.size == 4 }
        } catch (e: OutOfMemoryError) {
            Helpers.printLog("OutOfMemoryError: ${e.message}")
            null
        } catch (e: IllegalArgumentException) {
            Helpers.printLog("IllegalArgumentException: ${e.message}")
            null
        } catch (e: Exception) {
            Helpers.printLog("Error running interpreter: ${e.message}")
            null
        } finally {
            resizedBitmap?.takeIf { it !== frame }?.recycle()
        }
    }

    /**
     * Decodes the raw model output into boxes and de-duplicates them with [applyNMS].
     *
     * [array] is read as a row-major `[numChannel][numElements]` matrix. For candidate `c`, rows
     * 0..3 hold `cx, cy, w, h` and rows 4.. hold one score per class. A candidate is kept when its
     * best class score exceeds [CONFIDENCE_THRESHOLD] and all four box edges lie in `[0, 1]`.
     *
     * @return the surviving boxes, or `null` when none survive.
     */
    private fun bestBox(array: FloatArray): List<BoundingBox>? {
        val candidates = mutableListOf<BoundingBox>()
        for (c in 0 until numElements) {
            var bestScore = CONFIDENCE_THRESHOLD
            var bestClass = -1
            for (row in 4 until numChannel) {
                val score = array[c + numElements * row]
                if (score > bestScore) {
                    bestScore = score
                    bestClass = row - 4
                }
            }
            if (bestClass < 0) continue

            val cx = array[c]
            val cy = array[c + numElements]
            val w = array[c + numElements * 2]
            val h = array[c + numElements * 3]
            val x1 = cx - w / 2F
            val y1 = cy - h / 2F
            val x2 = cx + w / 2F
            val y2 = cy + h / 2F
            // Discard boxes with any edge outside [0, 1].
            if (!(isUnitRange(x1) && isUnitRange(y1) && isUnitRange(x2) && isUnitRange(y2))) continue

            candidates += BoundingBox(
                x1 = x1, y1 = y1, x2 = x2, y2 = y2,
                cx = cx, cy = cy, w = w, h = h,
                cnf = bestScore, cls = bestClass,
                // Tolerate a label file shorter than the model's class list instead of crashing.
                clsName = labels.getOrElse(bestClass) { "class_$bestClass" },
            )
        }
        return if (candidates.isEmpty()) null else applyNMS(candidates)
    }

    /** True when [value] lies in the closed interval `[0, 1]` (false for NaN). */
    private fun isUnitRange(value: Float): Boolean = value in 0F..1F

    /**
     * Greedy non-maximum suppression.
     *
     * Repeatedly keeps the highest-confidence box and drops every remaining box whose IoU with it
     * exceeds [IOU_THRESHOLD].
     */
    private fun applyNMS(boxes: List<BoundingBox>): MutableList<BoundingBox> {
        val remaining = boxes.sortedByDescending { it.cnf }.toMutableList()
        val selected = mutableListOf<BoundingBox>()
        while (remaining.isNotEmpty()) {
            val best = remaining.removeAt(0)
            selected += best
            remaining.removeAll { calculateIoU(best, it) > IOU_THRESHOLD }
        }
        return selected
    }

    /** Intersection-over-union of two boxes; 0 when the union area is not positive. */
    private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
        val interWidth = maxOf(0F, minOf(box1.x2, box2.x2) - maxOf(box1.x1, box2.x1))
        val interHeight = maxOf(0F, minOf(box1.y2, box2.y2) - maxOf(box1.y1, box2.y1))
        val intersection = interWidth * interHeight
        val union = box1.w * box1.h + box2.w * box2.h - intersection
        return if (union > 0F) intersection / union else 0F
    }

    override fun onEmptyDetect() {
        onImageListener.onImageNotDetected()
    }

    /**
     * Advances the auto-capture state machine with one frame's detections.
     *
     * A frame is "in mask" when all of the following hold:
     * - it contains exactly one confident card;
     * - it contains four or five confident corners;
     * - every confident box lies inside a vertically centred band of height `frameWidth * 0.75`,
     *   measured as a fraction of `frameHeight`.
     *
     * While frames stay in mask and the card does not move by more than [MOVEMENT_THRESHOLD], a
     * timer runs. Once [VERIFY_CAPTURING_DURATION] ms have passed with auto-capture on,
     * [OnImageDetectedListener.onImageDetected] fires.
     *
     * [doneProcessed] is always invoked once before returning, including the early exit taken
     * after a capture.
     */
    override fun onDetect(
        frameWidth: Int,
        frameHeight: Int,
        boundingBoxes: List<BoundingBox>,
        inferenceTime: Long,
        doneProcessed: () -> Unit,
    ) {
        try {
            if (!(detected && isAutoCaptureOn)) {
                updateCaptureState(frameWidth, frameHeight, boundingBoxes)
            }
        } finally {
            doneProcessed()
        }
    }

    /** One step of the capture state machine described on [onDetect]. */
    private fun updateCaptureState(frameWidth: Int, frameHeight: Int, boundingBoxes: List<BoundingBox>) {
        // Vertical band, as fractions of the frame height, that every box must stay inside.
        val maskHeight = frameWidth * MASK_HEIGHT_TO_WIDTH
        val maskTop = ((frameHeight - maskHeight) / 2) / frameHeight
        val maskBottom = 1 - maskTop

        val confident = boundingBoxes.filter { box ->
            box.cnf >= if (box.clsName == CARD_LABEL) CARD_CONF else CORNER_CONF
        }
        val cards = confident.filter { it.clsName == CARD_LABEL }
        val cornerCount = confident.count { it.clsName == CORNER_LABEL }
        // cards.size == 1 short-circuits first, so `confident` is non-empty when minOf/maxOf run.
        val inMask = cards.size == 1 && cornerCount in 4..5 &&
            confident.minOf { it.y1 } > maskTop && confident.maxOf { it.y2 } < maskBottom

        if (!inMask) {
            conditionStartTime = 0L
            if (cards.isNotEmpty()) onImageListener.onImageOutOfMask()
            return
        }

        onImageListener.onImageInMask()
        // Track the single card rather than an arbitrary corner. Corner order follows confidence
        // and can change between frames, which could make a stationary card look like it jumped.
        val current = cards.single()
        val previous = previousBoundingBox
        val now = SystemClock.uptimeMillis()
        when {
            conditionStartTime == 0L || previous == null -> conditionStartTime = now
            hasMovedSignificantly(previous, current) -> {
                onImageListener.onImageOutOfMask()
                conditionStartTime = 0L
            }
            isAutoCaptureOn && now - conditionStartTime >= VERIFY_CAPTURING_DURATION -> {
                onImageListener.onImageDetected()
                detected = true
            }
        }
        previousBoundingBox = current
    }

    /**
     * Calculate the Euclidean distance between the centres of two bounding boxes.
     * @param box1 the first bounding box
     * @param box2 the second bounding box
     * @return the centre-to-centre distance, in the boxes' coordinate units
     */
    private fun calculateDistance(box1: BoundingBox, box2: BoundingBox): Float {
        val dx = (box2.x1 + box2.x2) / 2 - (box1.x1 + box1.x2) / 2
        val dy = (box2.y1 + box2.y2) / 2 - (box1.y1 + box1.y2) / 2
        return sqrt(dx * dx + dy * dy)
    }

    /** True when the box centre moved further than [MOVEMENT_THRESHOLD] between two frames. */
    private fun hasMovedSignificantly(previousBox: BoundingBox, currentBox: BoundingBox): Boolean =
        calculateDistance(previousBox, currentBox) > MOVEMENT_THRESHOLD

    private companion object {
        const val INPUT_MEAN = 0f
        const val INPUT_STANDARD_DEVIATION = 255f
        val INPUT_IMAGE_TYPE: DataType = DataType.FLOAT32
        val OUTPUT_IMAGE_TYPE: DataType = DataType.FLOAT32
        const val CONFIDENCE_THRESHOLD = 0.3F
        const val IOU_THRESHOLD = 0.5F
        const val CARD_CONF = 0.65f
        const val CORNER_CONF = 0.35f
        const val VERIFY_CAPTURING_DURATION = 500L
        const val MOVEMENT_THRESHOLD = 0.4F

        /** Interpreter threads used whenever the GPU delegate is not in use. */
        const val CPU_THREADS = 4

        /** Height of the capture band relative to the frame width (presumably the card mask). */
        const val MASK_HEIGHT_TO_WIDTH = 0.75f
        const val CARD_LABEL = "card"
        const val CORNER_LABEL = "corner"
    }
}

/**
 * Receives the outcome of one detection pass over a camera frame, e.g. from [KLPDetector].
 */
interface KLPDetectorListener {
    /**
     * Called with the boxes found in a frame.
     *
     * @param frameWidth width in pixels of the image the model ran on (KLPDetector passes the resized frame).
     * @param frameHeight height in pixels of that same image.
     * @param boundingBoxes detections that survived thresholding and non-maximum suppression.
     * @param inferenceTime time spent on the pass as measured by the caller (KLPDetector uses milliseconds).
     * @param doneProcessed completion callback for this frame; presumably to be invoked once handling
     *   is finished (KLPDetector's own implementation does so) — confirm with callers.
     */
    fun onDetect(
        frameWidth: Int, frameHeight: Int,
        boundingBoxes: List<BoundingBox>,
        inferenceTime: Long,
        doneProcessed: () -> Unit,
    )

    /** Called when a processed frame yielded no bounding boxes. */
    fun onEmptyDetect()
}