/*
 * Copyright (C) 2023 ByteDance Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:Suppress("UNCHECKED_CAST")

package com.bytedance.ultimate.inflater.plugin.arsc

import java.lang.reflect.Field
import kotlin.reflect.KClass

/**
 * Scratch space keyed by field name, holding one table of values and one table of indices.
 *
 * The `cast` function in this file creates one instance per decoded object, stores the byte-list
 * index seen just before each field is decoded, and stores the decoded value right after.
 *
 * Created by ChengTao(chentao.joe@bytedance.com) on 2022/8/12.
 */
class CastContext {
    // Value recorded for each field name; a stored null is indistinguishable from "absent".
    private val valuesByName: MutableMap<String, Any?> = LinkedHashMap()

    // Index recorded for each field name.
    private val indicesByName: MutableMap<String, Int> = LinkedHashMap()

    /** Stores (or overwrites) the value remembered under [filedName]. */
    fun setFieldValue(filedName: String, value: Any?) {
        valuesByName[filedName] = value
    }

    /** Stores (or overwrites) the index remembered under [filedName]. */
    fun setFieldIndex(filedName: String, index: Int) {
        indicesByName[filedName] = index
    }

    /** Returns the value remembered under [filedName], or `null` when none (or a null) was stored. */
    fun getFieldValue(filedName: String): Any? = valuesByName[filedName]

    /** Returns the index remembered under [filedName], or `null` when none was stored. */
    fun getFiledIndex(filedName: String): Int? = indicesByName[filedName]
}

/**
 * Marks a field whose conversion is delegated to the [FieldCaster] implementation named by [value].
 *
 * The reflection helpers in this file create that caster through its no-argument constructor
 * (`getDeclaredConstructor().newInstance()`), so the referenced class must provide one.
 */
@Target(AnnotationTarget.FIELD)
annotation class FieldCast(val value: KClass<out FieldCaster<*>>)

/**
 * Two-way converter between a byte representation and a field value of type [T].
 */
interface FieldCaster<T> {
    /**
     * Produces the value of [field] from [byteList].
     *
     * [castContext] carries the values and indices recorded for fields handled earlier in the
     * same object. NOTE(review): judging by the name and by [AbsFieldCaster], implementations
     * are also expected to add the consumed size to [byteList] — confirm against `ByteList`.
     */
    fun castAndPlusAssign(field: Field, castContext: CastContext, byteList: ByteList): T

    /**
     * Converts [value], read from [field], back into bytes.
     *
     * [interceptor] is the interceptor in effect for the current un-cast pass.
     */
    fun unCast(field: Field, value: T, interceptor: UnCastInterceptor): List<Byte>
}

/**
 * Base [FieldCaster] that splits decoding into two steps: [cast] produces the value and
 * [getFieldSize] yields the amount that is then added to the byte list.
 *
 * The [CastContext] passed to [castAndPlusAssign] is kept so subclasses can look up values and
 * indices of previously handled fields through the protected helpers.
 */
@Suppress("MemberVisibilityCanBePrivate", "SameParameterValue")
abstract class AbsFieldCaster<T> : FieldCaster<T> {

    // Assigned at the start of every castAndPlusAssign call; the lookup helpers below read it,
    // so they fail if used before the first call.
    private lateinit var activeContext: CastContext

    /** Produces the value of [field] from [byteList]. */
    abstract fun cast(field: Field, byteList: ByteList): T

    /** Amount added to the byte list after [value] was produced; defaults to `sizeOf(field)`. */
    open fun getFieldSize(field: Field, value: T) = sizeOf(field)

    /**
     * Remembers [castContext], delegates to [cast], then adds [getFieldSize] to [byteList]
     * and returns the produced value.
     */
    final override fun castAndPlusAssign(
        field: Field,
        castContext: CastContext,
        byteList: ByteList
    ): T {
        activeContext = castContext
        val decoded = cast(field, byteList)
        byteList += getFieldSize(field, decoded)
        return decoded
    }

    /** Value recorded under [key] in the current context, unchecked-cast to [V]; `null` if none. */
    protected fun <V> getFieldValueOrNull(key: String): V? = activeContext.getFieldValue(key) as V?

    /** Like [getFieldValueOrNull] but throws a null-pointer exception when nothing is recorded. */
    protected fun <V> getFieldValue(key: String): V = getFieldValueOrNull<V>(key)!!

    /** Index recorded under [fieldName] in the current context, or `null` if none. */
    protected fun getFieldIndexOrNull(fieldName: String): Int? =
        activeContext.getFiledIndex(fieldName)

    /** Like [getFieldIndexOrNull] but throws a null-pointer exception when nothing is recorded. */
    protected fun getFieldIndex(fieldName: String): Int = getFieldIndexOrNull(fieldName)!!
}

/**
 * Placeholder caster: neither direction is implemented, so both methods throw via [TODO].
 */
open class DefaultFieldCaster<T> : AbsFieldCaster<T>() {
    override fun cast(field: Field, byteList: ByteList): T = TODO("Not yet implemented")

    override fun unCast(field: Field, value: T, interceptor: UnCastInterceptor): List<Byte> =
        TODO("Not yet implemented")
}

/**
 * Caster for fields that are stored as an integer but exposed as some other type [T].
 *
 * Subclasses only supply the two mapping functions; reading and writing the integer itself is
 * handled here using `sizeOf(field)` as the width.
 */
abstract class IntFieldTransformCaster<T> : AbsFieldCaster<T>() {
    /** Reads an integer of `sizeOf(field)` width at offset 0 and maps it with [castTransform]. */
    final override fun cast(field: Field, byteList: ByteList): T {
        val width = sizeOf(field)
        val raw = byteList.castToInt(0, width)
        return castTransform(raw)
    }

    /** Maps [value] with [unCastTransform] and converts the integer to `sizeOf(field)` bytes. */
    final override fun unCast(field: Field, value: T, interceptor: UnCastInterceptor): List<Byte> {
        val raw = unCastTransform(value)
        return raw.toByteList(sizeOf(field))
    }

    /** Maps the stored integer to the exposed value. */
    abstract fun castTransform(value: Int): T

    /** Maps the exposed value back to the stored integer. */
    abstract fun unCastTransform(value: T): Int
}

/**
 * `true` when the field declares at least one of the annotations the cast helpers recognise:
 * `unit_8_t`, `unit_16_t`, `unit_32_t`, `unit_size_t_group` or [FieldCast].
 */
val Field.canCast: Boolean
    get() = declaredAnnotations.any { marker ->
        marker is unit_8_t ||
            marker is unit_16_t ||
            marker is unit_32_t ||
            marker is unit_size_t_group ||
            marker is FieldCast
    }


/** Reified convenience overload: forwards to `cast(byteList, T::class.java)`. */
inline fun <reified T> cast(byteList: ByteList): T = cast(byteList, T::class.java)

/**
 * Builds an instance of [clazz] from [byteList] by reflection.
 *
 * Every declared field for which [canCast] holds is decoded in the order returned by
 * `declaredFields`, and the decoded values are passed to the declared constructor whose
 * parameter types equal those fields' types, in that same order.
 * NOTE(review): this relies on `declaredFields` matching constructor parameter order, which
 * the reflection API does not formally promise.
 *
 * After construction the distance travelled since entry is subtracted from [byteList]
 * (per the original comments, this resets its index to where it started).
 */
fun <T> cast(byteList: ByteList, clazz: Class<T>): T {
    // 1. remember where we started so the index can be reset afterwards
    val startIndex = byteList.currentIndex

    // 2. decode every castable field and feed the results to the matching constructor
    val castableFields = clazz.declaredFields.filter { it.canCast }
    val constructor = clazz.getDeclaredConstructor(*castableFields.map { it.type }.toTypedArray())
    val castContext = CastContext()
    val initArgs = arrayOfNulls<Any>(castableFields.size)
    for ((position, field) in castableFields.withIndex()) {
        // The index is recorded before decoding and the value after, so casters of later
        // fields can look up both for earlier fields.
        castContext.setFieldIndex(field.name, byteList.currentIndex)
        val decoded = getFieldValueAndPlusAssign(field, castContext, byteList)
        castContext.setFieldValue(field.name, decoded)
        initArgs[position] = decoded
    }
    val instance = constructor.newInstance(*initArgs)

    // 3. reset the index
    byteList -= (byteList.currentIndex - startIndex)
    return instance as T
}

/**
 * Decodes the value of a single [field] from [byteList].
 *
 * Checks are made in this order:
 * 1. a [FieldCast] annotation: a fresh instance of the named caster does the work;
 * 2. a `unit_size_t_group` annotation: the field type is decoded via [castAndPlusAssign];
 * 3. a primitive `Int` field: `sizeOf(field)` bytes are read, or `0` is returned when that
 *    size is not positive;
 * 4. anything else yields `null`.
 */
fun getFieldValueAndPlusAssign(
    field: Field,
    castContext: CastContext,
    byteList: ByteList
): Any? {
    val annotations = field.declaredAnnotations

    // 1. an explicit caster takes precedence over everything else
    val casterClass = annotations.filterIsInstance<FieldCast>().firstOrNull()?.value?.java
    if (casterClass != null) {
        val caster = casterClass.getDeclaredConstructor().newInstance()
        return caster.castAndPlusAssign(field, castContext, byteList)
    }

    return when {
        // 2. the field is marked as unit_size_t_group
        annotations.any { it is unit_size_t_group } -> castAndPlusAssign(byteList, field.type)

        // 3. plain Int field
        field.type == Int::class.java -> {
            val size = sizeOf(field)
            if (size > 0) byteList.castToIntAndPlusAssign(size) else 0
        }

        // 4. nothing in this function knows how to decode the field
        else -> null
    }
}

/** Reified convenience overload: forwards to `castAndPlusAssign(byteList, T::class.java)`. */
inline fun <reified T> castAndPlusAssign(byteList: ByteList): T =
    castAndPlusAssign(byteList, T::class.java)

/**
 * Decodes an instance of [clazz] with [cast] and then adds `sizeOf(clazz)` to [byteList].
 */
fun <T> castAndPlusAssign(byteList: ByteList, clazz: Class<T>): T {
    val decoded = cast(byteList, clazz)
    byteList += sizeOf(clazz)
    return decoded
}

/**
 * Hook consulted by the un-cast helpers in this file before a field value is converted to bytes.
 */
interface UnCastInterceptor {
    /**
     * Returns the bytes to use for [value] of [field], or `null` to let the default
     * conversion handle it.
     */
    fun intercept(field: Field, value: Any): List<Byte>?
}

/**
 * [UnCastInterceptor] that never intercepts: it answers `null` for every field.
 */
class UnInterceptUnCastInterceptor : UnCastInterceptor {
    // Return type spelled out; it is the same type the expression body infers.
    override fun intercept(field: Field, value: Any): Nothing? = null
}


/**
 * Converts this object to bytes.
 *
 * A [List] receiver is converted element by element (null elements are skipped) and the
 * results are concatenated. Any other receiver has each declared field for which [canCast]
 * holds read reflectively and converted; fields that produce no bytes contribute nothing.
 *
 * @param interceptor consulted for every non-null field value before the default conversion.
 */
fun Any.unCast(interceptor: UnCastInterceptor = UnInterceptUnCastInterceptor()): List<Byte> {
    if (this is List<*>) {
        return this.filterNotNull().flatMap { element -> element.unCast(interceptor) }
    }
    return this::class.java.declaredFields
        .filter { it.canCast }
        .flatMap { field ->
            // Made accessible so non-public fields can be read as well.
            field.isAccessible = true
            unCastFieldToByteList(field, field.get(this), interceptor).orEmpty()
        }
}

/**
 * Converts one field [value] to bytes, or returns `null` when there is nothing to emit.
 *
 * Order of precedence: a null value yields `null`; then a non-null answer from [interceptor];
 * then a [FieldCast] caster (a fresh instance per call); then recursive [unCast] for
 * `unit_size_t_group` fields; then `sizeOf(field)` bytes for an `Int`; otherwise `null`.
 */
private fun unCastFieldToByteList(
    field: Field,
    value: Any?,
    interceptor: UnCastInterceptor
): List<Byte>? {
    if (value == null) {
        return null
    }

    // 1. the interceptor gets the first say
    val intercepted = interceptor.intercept(field, value)
    if (intercepted != null) {
        return intercepted
    }

    val annotations = field.declaredAnnotations
    val casterClass = annotations.filterIsInstance<FieldCast>().firstOrNull()?.value?.java
    return when {
        // 2. explicit caster declared on the field
        casterClass != null -> {
            val caster = casterClass.getDeclaredConstructor().newInstance() as FieldCaster<Any>
            caster.unCast(field, value, interceptor)
        }

        // 3. the field is marked as unit_size_t_group
        annotations.any { it is unit_size_t_group } -> value.unCast(interceptor)

        // 4. plain Int value
        value is Int -> value.toByteList(sizeOf(field))

        else -> null
    }
}