package vn.kalapa.ekyc.liveness

import android.graphics.Bitmap
import androidx.collection.LruCache
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetector
import com.google.mlkit.vision.face.FaceDetectorOptions
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import vn.kalapa.ekyc.KalapaSDK
import vn.kalapa.ekyc.liveness.models.ComeClose
import vn.kalapa.ekyc.liveness.models.GoFar
import vn.kalapa.ekyc.liveness.models.HoldSteady2Seconds
import vn.kalapa.ekyc.liveness.models.LivenessAction
import vn.kalapa.ekyc.liveness.models.LivenessActionStatus
import vn.kalapa.ekyc.liveness.models.Processing
import vn.kalapa.ekyc.liveness.models.Success
import vn.kalapa.ekyc.liveness.models.TiltLeft
import vn.kalapa.ekyc.liveness.models.TiltRight
import vn.kalapa.ekyc.liveness.models.TurnDown
import vn.kalapa.ekyc.liveness.models.TurnLeft
import vn.kalapa.ekyc.liveness.models.TurnRight
import vn.kalapa.ekyc.liveness.models.TurnUp
import vn.kalapa.ekyc.managers.KLPFaceDetectorListener
import vn.kalapa.ekyc.utils.Common
import vn.kalapa.ekyc.utils.Helpers
import java.util.Collections
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.max
import kotlin.random.Random

/**
 * One detected face together with the context it was captured in.
 *
 * @property inputTime wall-clock time in milliseconds at which the face was recorded.
 * @property face the ML Kit detection result.
 * @property frameWidth width, in pixels, of the frame the face was detected in.
 * @property frameHeight height, in pixels, of the frame the face was detected in.
 */
class InputFace(
    val inputTime: Long,
    val face: Face,
    val frameWidth: Int,
    val frameHeight: Int,
)
/**
 * Runs one liveness-check session: camera frames are fed to an ML Kit face detector and the
 * detected faces are walked through a scripted sequence of [LivenessAction]s selected by the
 * session type.
 *
 * Progress is exposed through [sessionStatus] and reported to the [KLPFaceDetectorListener]
 * passed to [process]. The session counts as finished once the status is VERIFIED or FAILED
 * (see [isFinished]); call [renewSession] to start over and [cleanup] to release resources.
 *
 * @param livenessSessionType which action script to run; defaults to PASSIVE.
 */
class LivenessSession(private var livenessSessionType: Common.LIVENESS_VERSION = Common.LIVENESS_VERSION.PASSIVE) {
    /** Outcome of the most recently processed frame; starts as UNVERIFIED. */
    var sessionStatus: LivenessSessionStatus = LivenessSessionStatus.UNVERIFIED

    /** Faces accepted during the current attempt; cleared whenever the attempt is reset. */
    val faceList = ConcurrentLinkedQueue<InputFace>()

    /** Actions activated so far; the last element is the one currently being evaluated. */
    private val actionList = Collections.synchronizedList(mutableListOf<LivenessAction>())

    /** Returns a snapshot copy of [faceList]. */
    fun getFaceList(): List<InputFace> = faceList.toList()

    /** Key into [index2Action] of the active step; 0 means no step has been started yet. */
    private var currActionIdx: Int = -1
    private var currAction: LivenessAction? = null

    /** Scripted steps for the current session type, keyed by 1-based step number. */
    private lateinit var index2Action: Map<Int, LivenessAction>

    // Both bitmaps are assigned the same frame in handleAction(); read them only after
    // gotTypicalFace is true, otherwise the lateinit access throws.
    lateinit var typicalFace: Bitmap
    lateinit var typicalFrame: Bitmap
    var gotTypicalFace = false

    /** Recent detections keyed by the millisecond timestamp at which the frame was submitted. */
    private val recentDetections = LruCache<Long, Face>(CACHE_SIZE)
    private val processingScope = CoroutineScope(Dispatchers.Default + SupervisorJob())

    /** Listener from the latest [process] call; used for the onNextStep() notifications. */
    private lateinit var faceDetectorListener: KLPFaceDetectorListener

    private val faceDetector: FaceDetector = FaceDetection.getClient(
        FaceDetectorOptions.Builder()
            .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST)
            .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_ALL)
            .setMinFaceSize(0.25f)
            .enableTracking()
            .build()
    )

    init {
        genActionList()
    }

    /**
     * Restarts the session with the given type: status goes back to UNVERIFIED, collected faces
     * and cached detections are dropped, and a fresh action script is generated.
     */
    fun renewSession(livenessSessionType: Common.LIVENESS_VERSION) {
        this.livenessSessionType = livenessSessionType
        sessionStatus = LivenessSessionStatus.UNVERIFIED
        gotTypicalFace = false
        refreshFaceList()
        genActionList()
    }

    /** Drops every collected face and every cached detection. */
    private fun refreshFaceList() {
        faceList.clear()
        recentDetections.evictAll()
    }

    /**
     * Builds the step script for the current session type and rewinds to "no step started".
     * ACTIVE sessions pick one direction at random for each of the turn/tilt steps.
     */
    private fun genActionList() {
        Helpers.printLog("genActionList ${livenessSessionType.name}")
        index2Action = when (livenessSessionType) {
            Common.LIVENESS_VERSION.PASSIVE -> mapOf(
                1 to HoldSteady2Seconds(1500),
                2 to Processing()
            )

            Common.LIVENESS_VERSION.ACTIVE -> mapOf(
                1 to HoldSteady2Seconds(1500),
                2 to if (Random.nextBoolean()) TurnLeft() else TurnRight(),
                3 to if (Random.nextBoolean()) TurnUp() else TurnDown(),
                4 to if (Random.nextBoolean()) TiltLeft() else TiltRight(),
                5 to Processing()
            )

            Common.LIVENESS_VERSION.SEMI_ACTIVE -> mapOf(
                1 to HoldSteady2Seconds(1500),
                2 to GoFar(),
                3 to ComeClose(),
                4 to Processing()
            )
        }
        currActionIdx = 0
        currAction = null
        actionList.clear()
    }

    /**
     * Submits one camera frame for face detection and advances the session with the result.
     *
     * Does nothing when the session is already finished. Detection is asynchronous; [listener]
     * receives the resulting status via onMessage once the frame has been handled, or FAILED if
     * detection or preprocessing throws.
     *
     * @param frame the camera frame to analyse.
     * @param rotationAngle rotation, in degrees, handed to ML Kit together with the bitmap.
     * @param listener receiver of status and next-step callbacks for this and later frames.
     */
    fun process(
        frame: Bitmap,
        rotationAngle: Int,
        listener: KLPFaceDetectorListener,
    ) {
        this.faceDetectorListener = listener
        if (isFinished()) return

        processingScope.launch {
            try {
                val timestamp = System.currentTimeMillis()

                // A hit is only possible when an earlier frame was submitted in the same
                // millisecond. Reuse its result and skip detection for this frame entirely.
                val cachedFace = recentDetections.get(timestamp)
                if (cachedFace != null) {
                    handleDetectionResult(listOf(cachedFace), frame, listener)
                    return@launch
                }

                // Keep a reference to the bitmap actually given to the detector so the success
                // callback can report its dimensions without reaching into InputImage internals.
                val detectionBitmap = withContext(Dispatchers.Default) {
                    scaleForDetection(frame)
                }
                val optimizedImage = InputImage.fromBitmap(detectionBitmap, rotationAngle)
                Helpers.printLog("original - ${frame.width} ${frame.height} optimizedImage - ${optimizedImage.width} ${optimizedImage.height}")

                faceDetector.process(optimizedImage)
                    .addOnSuccessListener { faces ->
                        if (isFinished()) return@addOnSuccessListener

                        faces.firstOrNull()?.let {
                            recentDetections.put(timestamp, it)
                        }

                        handleDetectionResult(faces, detectionBitmap, listener)
                    }
                    .addOnFailureListener { exception ->
                        Helpers.printLog("$TAG Detection failed: ${exception.message}")
                        sessionStatus = LivenessSessionStatus.FAILED
                        listener.onMessage(sessionStatus, currAction?.TAG)
                    }
            } catch (e: CancellationException) {
                // Cancellation (e.g. triggered by cleanup()) is not a detection failure:
                // propagate it instead of flagging the session FAILED and calling the listener.
                throw e
            } catch (e: Exception) {
                Helpers.printLog("$TAG Processing error: ${e.message}")
                sessionStatus = LivenessSessionStatus.FAILED
                listener.onMessage(sessionStatus, currAction?.TAG)
            }
        }
    }

    /**
     * Translates a detector result into a session status and notifies [faceDetectorListener].
     *
     * No faces yields NO_FACE; more than one face of acceptable size yields TOO_MANY_FACES and
     * resets the collected faces; otherwise the first face is evaluated further.
     */
    private fun handleDetectionResult(
        faces: List<Face>,
        frame: Bitmap,
        faceDetectorListener: KLPFaceDetectorListener,
    ) {
        if (faces.isEmpty()) {
            sessionStatus = LivenessSessionStatus.NO_FACE
        } else {
            val facesInRange = faces.count { face ->
                LivenessAction.isFaceSizeInRange(
                    InputFace(System.currentTimeMillis(), face, frame.width, frame.height)
                )
            }

            if (faces.size > 1 && facesInRange > 1) {
                sessionStatus = LivenessSessionStatus.TOO_MANY_FACES
                refreshFaceList()
            } else {
                val face = InputFace(
                    System.currentTimeMillis(),
                    faces[0],
                    frame.width,
                    frame.height
                )
                processFaceDetection(face, frame)
            }
        }

        Helpers.printLog("$TAG Processing: ${currAction?.TAG} - Status $sessionStatus ${faces.size} - faceList ${faceList.size}")
        faceDetectorListener.onMessage(sessionStatus, currAction?.TAG)
    }

    /**
     * Validates a single face (size, and during the hold-steady step also position and angle)
     * and, when it is acceptable, records it and lets the current action evaluate it.
     */
    private fun processFaceDetection(inputFace: InputFace, frame: Bitmap) {
        when {
            LivenessAction.isFaceTooSmall(inputFace) ->
                sessionStatus = LivenessSessionStatus.TOO_SMALL

            LivenessAction.isFaceTooBig(inputFace) ->
                sessionStatus = LivenessSessionStatus.TOO_LARGE

            else -> {
                var shouldProcessFace = true
                if (currAction is HoldSteady2Seconds) {
                    // 2.10.5: Remove check margin right unless it is HoldSteadyFor2Seconds
                    if (!LivenessAction.isFaceMarginRight(inputFace.face, frame.width, frame.height)) {
                        shouldProcessFace = false
                        sessionStatus = LivenessSessionStatus.OFF_CENTER
                        refreshFaceList()
                    } else if (KalapaSDK.config.livenessVersion != Common.LIVENESS_VERSION.ACTIVE.version && !LivenessAction.isFaceLookStraight(inputFace.face)) {
                        shouldProcessFace = false
                        sessionStatus = LivenessSessionStatus.ANGLE_NOT_CORRECT
                        refreshFaceList()
                    }
                }
                if (shouldProcessFace) {
                    faceList.add(inputFace)
                    if (faceList.size > MAX_N_FRAME) {
                        sessionStatus = LivenessSessionStatus.EXPIRED
                    } else {
                        handleAction(frame)
                    }
                }
            }
        }
    }

    /**
     * Advances the step index by [nextIndex] and activates an action.
     *
     * With `nextIndex == 0` the supplied [action] is appended while staying on the same step;
     * otherwise the scripted action for the new step is appended. Moving past the last scripted
     * step marks the session VERIFIED.
     */
    private fun addAction(nextIndex: Int, action: LivenessAction? = null) {
        currActionIdx += nextIndex
        if (currActionIdx > (index2Action.keys.maxOrNull() ?: 0)) {
            sessionStatus = LivenessSessionStatus.VERIFIED
            return
        }
        if (nextIndex == 0) action?.let { actionList.add(it) }
        else index2Action[currActionIdx]?.let { actionList.add(it) }
    }

    /**
     * Lets the most recently activated action evaluate the session and reacts to its verdict:
     * SUCCESS moves the script forward, TIMEOUT regenerates the script, FAILED sets the session
     * status to PROCESSING.
     */
    private fun handleAction(frame: Bitmap) {
        // First accepted face of an attempt: activate step 1.
        if (currActionIdx == 0) addAction(1)
        val previousAction = currAction
        if (currActionIdx <= (index2Action.keys.maxOrNull() ?: 0)) {
            val action: LivenessAction = actionList.last()
            currAction = action
            val currActionStatus = action.process(this)

            if (currActionStatus == LivenessActionStatus.SUCCESS &&
                (action.TAG == "HoldSteady2Seconds" || action.TAG == "ComeClose")
            ) {
                if (!gotTypicalFace) {
                    Helpers.printLog("$TAG Found typical frame!")
                    typicalFrame = frame
                    typicalFace = frame
                }

                gotTypicalFace = true
            }

            when (currActionStatus) {
                LivenessActionStatus.SUCCESS -> {
                    if (livenessSessionType == Common.LIVENESS_VERSION.PASSIVE) {
                        faceDetectorListener.onNextStep()
                    } else if (action.TAG == Success().TAG && previousAction?.TAG != "HoldSteady2Seconds") {
                        faceDetectorListener.onNextStep()
                    }
                    // A "break" action hands over to the next scripted step; any other action
                    // is followed by an interim Success action on the same step.
                    if (action.isBreakAction) addAction(1) else addAction(0, Success())
                }

                LivenessActionStatus.TIMEOUT -> {
                    genActionList()
                }

                LivenessActionStatus.FAILED -> sessionStatus = LivenessSessionStatus.PROCESSING
            }
        }
    }

    /**
     * Returns [image] unchanged when both sides are within [MAX_DETECTION_DIMENSION]; otherwise
     * returns a copy scaled so that the longer side equals that limit, keeping the aspect ratio.
     */
    private fun scaleForDetection(image: Bitmap): Bitmap {
        val longestSide = max(image.width, image.height)
        if (longestSide <= MAX_DETECTION_DIMENSION) return image

        val scaleFactor = MAX_DETECTION_DIMENSION.toFloat() / longestSide
        // Clamp to 1px so extreme aspect ratios cannot produce a zero-sized bitmap request.
        val scaledWidth = (image.width * scaleFactor).toInt().coerceAtLeast(1)
        val scaledHeight = (image.height * scaleFactor).toInt().coerceAtLeast(1)
        return Bitmap.createScaledBitmap(image, scaledWidth, scaledHeight, true)
    }

    /**
     * Returns the TAG of the most recently activated action, or an empty string when no step
     * has started or the step index is not below the number of activated actions.
     */
    fun getCurrentAction(): String {
        return if (currActionIdx <= 0 || currActionIdx >= actionList.size) ""
        else actionList.last().TAG
    }

    /** True once the session has reached a terminal status (FAILED or VERIFIED). */
    fun isFinished(): Boolean {
        return sessionStatus == LivenessSessionStatus.FAILED ||
                sessionStatus == LivenessSessionStatus.VERIFIED
    }

    /** Cancels in-flight processing, closes the face detector and drops collected data. */
    fun cleanup() {
        processingScope.cancel()
        faceDetector.close()
        refreshFaceList()
    }

    private companion object {
        /** Number of accepted faces after which an attempt is marked EXPIRED. */
        const val MAX_N_FRAME = 600

        /** Capacity of the recent-detections cache. */
        const val CACHE_SIZE = 10

        /** Longest bitmap side, in pixels, that is passed to the detector without downscaling. */
        const val MAX_DETECTION_DIMENSION = 1024

        const val TAG = "OptimizedLivenessSession"
    }
}

/**
 * Status of a liveness session after a frame has been handled.
 *
 * @property status stable integer code associated with the entry.
 */
enum class LivenessSessionStatus(val status: Int) {
    /** Initial state of a new or renewed session. */
    UNVERIFIED(-1),

    /** The current action has not been satisfied yet. */
    PROCESSING(0),

    /** Every scripted action has been completed. */
    VERIFIED(1),

    /** Too many frames were collected without completing the attempt. */
    EXPIRED(2),

    /** Face detection or frame processing raised an error. */
    FAILED(3),

    /** The detected face is too small in the frame. */
    TOO_SMALL(4),

    /** The detected face is too large in the frame. */
    TOO_LARGE(5),

    /** The face is outside the accepted margins. */
    OFF_CENTER(6),

    /** More than one face of acceptable size was detected. */
    TOO_MANY_FACES(7),

    /** No face was detected in the frame. */
    NO_FACE(8),

    /** Presumably reported when the eyes are closed; not set anywhere in this file. */
    EYE_CLOSED(9),

    /** The face is not looking straight at the camera. */
    ANGLE_NOT_CORRECT(10),
}
//val PROCESSING = 0
//val VERIFIED = 1
//val EXPIRED = 2
//val FAILED = 3
//val TOO_SMALL = 4
//val TOO_LARGE = 5
//val OFF_CENTER = 6
