/*
 * Copyright (C) 2023 ByteDance Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.bytedance.ultimate.inflater.plugin.internal.transformer

import com.bytedance.ultimate.inflater.plugin.internal.find
import org.objectweb.asm.Opcodes
import org.objectweb.asm.tree.ClassNode
import org.objectweb.asm.tree.InsnList
import org.objectweb.asm.tree.InsnNode
import org.objectweb.asm.tree.LdcInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.TypeInsnNode
import org.objectweb.asm.tree.VarInsnNode
import java.util.regex.Pattern

/**
 * Fills in the stub factory methods of generated view creator classes.
 *
 * A class is processed only when its internal name starts with
 * `com/bytedance/ultimate/inflater/internal/ui/view` and ends with `_ViewCreator`. It must also not be one of the
 * built-in creators, and its direct super class must be `AbsCachedViewCreator` or `AbsAppCompatViewCreator`.
 *
 * Inside such a class, a stub method is any method whose body carries a `"Stub: <fully.qualified.ClassName>"`
 * string constant (normally `throw new IllegalStateException("Stub: ...")`). Its body is replaced as follows:
 * - `createXXXView(Context, AttributeSet): View` becomes `return new <ClassName>(context, attributeSet)`;
 * - `createXXXCustomLayoutInflaterFactory(): LayoutInflater.Factory` becomes `return new <ClassName>()`.
 *
 * Created by ChengTao(chentao.joe@bytedance.com) on 2023/4/20.
 */
@Suppress("DuplicatedCode")
internal class ViewCreatorTransformer : ClassTransformer {

    override fun transform(klass: ClassNode) {
        if (!classFilter(klass)) {
            return
        }
        // 1. createXXXView(context, attributeSet) -> return new XXXView(context, attributeSet)
        modifyMethod(klass, createViewFilter, CREATE_VIEW_MAX_STACK) { insnList, classInternalName ->
            insnList.apply {
                add(TypeInsnNode(Opcodes.NEW, classInternalName))
                add(InsnNode(Opcodes.DUP))
                // Assumes an instance method: slot 0 is `this`, slots 1 and 2 hold the Context and
                // AttributeSet parameters required by the method descriptor.
                add(VarInsnNode(Opcodes.ALOAD, 1))
                add(VarInsnNode(Opcodes.ALOAD, 2))
                add(
                    MethodInsnNode(
                        Opcodes.INVOKESPECIAL,
                        classInternalName,
                        "<init>",
                        "(L$CONTEXT;L$ATTRIBUTE_SET;)V",
                        false
                    )
                )
                add(InsnNode(Opcodes.ARETURN))
            }
        }
        // 2. createXXXCustomLayoutInflaterFactory() -> return new XXXLayoutInflaterFactory()
        modifyMethod(
            klass,
            createLayoutInflaterFactoryFilter,
            CREATE_LAYOUT_INFLATER_FACTORY_MAX_STACK
        ) { insnList, classInternalName ->
            insnList.apply {
                add(TypeInsnNode(Opcodes.NEW, classInternalName))
                add(InsnNode(Opcodes.DUP))
                add(
                    MethodInsnNode(
                        Opcodes.INVOKESPECIAL,
                        classInternalName,
                        "<init>",
                        "()V",
                        false
                    )
                )
                add(InsnNode(Opcodes.ARETURN))
            }
        }
    }

    /**
     * Replaces the body of every method of [klass] that [methodFilter] accepts and whose instructions contain a
     * `"Stub: <fully.qualified.ClassName>"` string constant. Methods without such a constant are left untouched.
     *
     * @param klass class whose methods are rewritten in place.
     * @param methodFilter selects candidate methods (by name/descriptor).
     * @param requiredMaxStack operand-stack depth needed by the code [methodModifier] emits. The method's
     * `maxStack` is raised to at least this value. When the ClassWriter recomputes maxs, this has no effect.
     * @param methodModifier appends the new body to an empty [InsnList]. It receives the slash-separated internal
     * name of the class named in the stub message.
     */
    private inline fun modifyMethod(
        klass: ClassNode,
        methodFilter: (method: MethodNode) -> Boolean,
        requiredMaxStack: Int,
        methodModifier: (insnList: InsnList, classInternalName: String) -> Unit
    ) {
        klass.methods.filter(methodFilter).forEach { methodNode ->
            // 1. locate the stub message: exact throw template first, then any matching constant
            val stubLdc = findStrictStubLdc(methodNode.instructions)
                ?: findLenientStubLdc(methodNode.instructions)
                ?: return@forEach
            // 2. "Stub: com.foo.BarView" -> "com/foo/BarView"
            val classInternalName = stubLdc.cst.toString()
                .removePrefix(STUB_PREFIX)
                .trim()
                .replace('.', '/')
            if (classInternalName.isEmpty()) {
                // `NEW ""` would produce invalid bytecode; leave the stub as it is.
                return@forEach
            }
            // 3. replace the instructions of the method
            val newInstructions = InsnList().also { methodModifier(it, classInternalName) }
            methodNode.instructions.clear()
            methodNode.instructions.add(newInstructions)
            // These tables reference labels of the old body, which were removed by clear().
            methodNode.tryCatchBlocks?.clear()
            methodNode.localVariables?.clear()
            methodNode.visibleLocalVariableAnnotations?.clear()
            methodNode.invisibleLocalVariableAnnotations?.clear()
            methodNode.maxStack = maxOf(methodNode.maxStack, requiredMaxStack)
        }
    }

    /**
     * Strict match of the stub template, starting at the method's first `new` instruction:
     * ```
     * new           // class java/lang/IllegalStateException
     * dup
     * ldc           // String Stub: XXX
     * invokespecial // Method java/lang/IllegalStateException."<init>":(Ljava/lang/String;)V
     * checkcast     // class java/lang/Throwable   (optional, emitted by kotlinc)
     * athrow
     * ```
     * @return the `ldc` node carrying the stub message, or null if the template does not match.
     */
    private fun findStrictStubLdc(instructions: InsnList): LdcInsnNode? {
        val newNode = instructions.find(Opcodes.NEW) as? TypeInsnNode ?: return null
        if (newNode.desc != ILLEGAL_STATE_EXCEPTION) {
            return null
        }
        val dupNode = newNode.next?.takeIf { it.opcode == Opcodes.DUP } ?: return null
        val ldcNode = (dupNode.next as? LdcInsnNode)?.takeIf(::isStubLdc) ?: return null
        val initNode = (ldcNode.next as? MethodInsnNode)?.takeIf {
            it.opcode == Opcodes.INVOKESPECIAL &&
                it.owner == ILLEGAL_STATE_EXCEPTION &&
                it.name == "<init>" &&
                it.desc == "(Ljava/lang/String;)V"
        } ?: return null
        var throwNode = initNode.next
        if (throwNode is TypeInsnNode && throwNode.opcode == Opcodes.CHECKCAST && throwNode.desc == THROWABLE) {
            throwNode = throwNode.next
        }
        return ldcNode.takeIf { throwNode?.opcode == Opcodes.ATHROW }
    }

    /**
     * Lenient match: the first `ldc` anywhere in the method whose constant starts with [STUB_PREFIX].
     * All `ldc` nodes are scanned because earlier ones (e.g. parameter names passed to Kotlin null-check
     * intrinsics) may precede the stub message.
     */
    private fun findLenientStubLdc(instructions: InsnList): LdcInsnNode? =
        instructions.toArray()
            .filterIsInstance<LdcInsnNode>()
            .firstOrNull(::isStubLdc)

    /** True when [node] loads a String constant beginning with [STUB_PREFIX]. */
    private fun isStubLdc(node: LdcInsnNode): Boolean =
        (node.cst as? String)?.startsWith(STUB_PREFIX) == true

    companion object {
        private const val ABS_CACHED_VIEW_CREATOR =
            "com/bytedance/ultimate/inflater/internal/ui/view/AbsCachedViewCreator"
        private const val ABS_APP_COMPAT_VIEW_CREATOR =
            "com/bytedance/ultimate/inflater/internal/ui/view/AbsAppCompatViewCreator"
        private const val INCLUDE_VIEW_CREATOR =
            "com/bytedance/ultimate/inflater/internal/ui/view/IncludeViewCreator"
        private const val MERGE_VIEW_CREATOR =
            "com/bytedance/ultimate/inflater/internal/ui/view/MergeViewCreator"
        private const val GENERATE_VIEW_CREATOR_NAME_PREFIX =
            "com/bytedance/ultimate/inflater/internal/ui/view"

        /** Hand-written creators that live in the generated-creator package but must not be touched. */
        private val BUILT_IN_VIEW_CREATORS = setOf(
            ABS_CACHED_VIEW_CREATOR,
            ABS_APP_COMPAT_VIEW_CREATOR,
            INCLUDE_VIEW_CREATOR,
            MERGE_VIEW_CREATOR
        )

        /** Super classes a generated creator may directly extend. */
        private val GENERATED_VIEW_CREATOR_SUPER_NAMES = setOf(
            ABS_CACHED_VIEW_CREATOR,
            ABS_APP_COMPAT_VIEW_CREATOR
        )

        /** Accepts only generated `*_ViewCreator` classes (see the class KDoc for the exact rules). */
        private val classFilter: (ClassNode) -> Boolean = { klass ->
            klass.name.startsWith(GENERATE_VIEW_CREATOR_NAME_PREFIX) &&
                klass.name.endsWith("_ViewCreator") &&
                klass.name !in BUILT_IN_VIEW_CREATORS &&
                klass.superName in GENERATED_VIEW_CREATOR_SUPER_NAMES
        }

        // View
        private const val VIEW = "android/view/View"
        private const val CONTEXT = "android/content/Context"
        private const val ATTRIBUTE_SET = "android/util/AttributeSet"
        private val createViewPattern = Pattern.compile(
            "create[a-zA-Z]*View"
        )

        /** Stack depth of `new; dup; aload 1; aload 2`. */
        private const val CREATE_VIEW_MAX_STACK = 4

        /** Matches `createXXXView(Context, AttributeSet): View`. */
        private val createViewFilter: (MethodNode) -> Boolean = { method ->
            createViewPattern.matcher(method.name).matches() &&
                method.desc == "(L$CONTEXT;L$ATTRIBUTE_SET;)L$VIEW;"
        }

        // LayoutInflater.Factory
        private const val LAYOUT_INFLATER_FACTORY =
            "android/view/LayoutInflater\$Factory"
        private val createLayoutInflaterFactoryPattern = Pattern.compile(
            "create[a-zA-Z]*CustomLayoutInflaterFactory"
        )

        /** Stack depth of `new; dup`. */
        private const val CREATE_LAYOUT_INFLATER_FACTORY_MAX_STACK = 2

        /** Matches `createXXXCustomLayoutInflaterFactory(): LayoutInflater.Factory`. */
        private val createLayoutInflaterFactoryFilter: (MethodNode) -> Boolean = { method ->
            createLayoutInflaterFactoryPattern.matcher(method.name).matches() &&
                method.desc == "()L$LAYOUT_INFLATER_FACTORY;"
        }

        // Stub template
        private const val STUB_PREFIX = "Stub: "
        private const val ILLEGAL_STATE_EXCEPTION = "java/lang/IllegalStateException"
        private const val THROWABLE = "java/lang/Throwable"
    }
}
