/*
 * Copyright 2010-2023 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.compile

import ksp.com.intellij.psi.PsiElement
import ksp.com.intellij.psi.util.PsiTreeUtil
import ksp.org.jetbrains.kotlin.KtFakeSourceElementKind
import ksp.org.jetbrains.kotlin.analysis.api.compile.CodeFragmentCapturedValue
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.api.LLFirResolveSession
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.api.resolveToFirSymbol
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.util.containingKtFileIfAny
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.util.parentsCodeFragmentAware
import ksp.org.jetbrains.kotlin.descriptors.ClassKind
import ksp.org.jetbrains.kotlin.fir.FirElement
import ksp.org.jetbrains.kotlin.fir.FirSession
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.toRegularClassSymbol
import ksp.org.jetbrains.kotlin.fir.declarations.*
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isInline
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isLocal
import ksp.org.jetbrains.kotlin.fir.expressions.*
import ksp.org.jetbrains.kotlin.fir.labelName
import ksp.org.jetbrains.kotlin.fir.references.FirSuperReference
import ksp.org.jetbrains.kotlin.fir.references.FirThisReference
import ksp.org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import ksp.org.jetbrains.kotlin.fir.resolve.defaultType
import ksp.org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.*
import ksp.org.jetbrains.kotlin.fir.symbols.lazyResolveToPhase
import ksp.org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import ksp.org.jetbrains.kotlin.fir.types.FirTypeRef
import ksp.org.jetbrains.kotlin.fir.types.builder.buildResolvedTypeRef
import ksp.org.jetbrains.kotlin.fir.types.resolvedType
import ksp.org.jetbrains.kotlin.fir.resolve.toSymbol
import ksp.org.jetbrains.kotlin.fir.visitors.FirDefaultVisitorVoid
import ksp.org.jetbrains.kotlin.name.StandardClassIds
import ksp.org.jetbrains.kotlin.psi
import ksp.org.jetbrains.kotlin.psi.*
import java.util.*

/**
 * A value captured by a code fragment, together with the FIR symbol it was derived from.
 *
 * @property value descriptor of the captured value (its kind, name, mutation and inline-crossing flags).
 * @property symbol the FIR symbol the captured value originates from; also used as its identity.
 * @property typeRef the type reference associated with the captured value.
 */
class CodeFragmentCapturedSymbol(val value: CodeFragmentCapturedValue, val symbol: FirBasedSymbol<*>, val typeRef: FirTypeRef)

/**
 * Identity key of a captured value: two captures referring to the same [symbol] are considered the same value.
 */
data class CodeFragmentCapturedId(
    val symbol: FirBasedSymbol<*>,
)

/**
 * Entry point for finding the values a [FirCodeFragment] captures from its surrounding context.
 */
object CodeFragmentCapturedValueAnalyzer {
    /**
     * Analyzes [codeFragment] and returns the captured values together with the source files they require.
     *
     * Declarations made inside the fragment itself are collected first and excluded from the captured set.
     */
    fun analyze(resolveSession: LLFirResolveSession, codeFragment: FirCodeFragment): CodeFragmentCapturedValueData {
        // Pass 1: everything the fragment declares on its own.
        val declarationCollector = CodeFragmentDeclarationCollector()
        codeFragment.accept(declarationCollector)
        val ownSymbols = declarationCollector.symbols.toSet()

        // Pass 2: everything referenced from outside the fragment.
        val capturedValueVisitor = CodeFragmentCapturedValueVisitor(resolveSession, ownSymbols)
        codeFragment.accept(capturedValueVisitor)

        return CodeFragmentCapturedValueData(
            symbols = capturedValueVisitor.values,
            files = capturedValueVisitor.files,
        )
    }
}

/**
 * Result of [CodeFragmentCapturedValueAnalyzer.analyze].
 *
 * @property symbols captured values, in the order they were first registered.
 * @property files non-compiled files that declare local classes, anonymous objects or local functions
 *   used by the fragment.
 */
class CodeFragmentCapturedValueData(
    val symbols: List<CodeFragmentCapturedSymbol>,
    val files: List<KtFile>,
)

/**
 * Gathers the symbol of every [FirDeclaration] encountered during a FIR traversal, in visiting order.
 *
 * Used to determine which declarations belong to the code fragment itself, so they are not treated as captured.
 */
private class CodeFragmentDeclarationCollector : FirDefaultVisitorVoid() {
    private val declarationSymbols: MutableList<FirBasedSymbol<*>> = ArrayList()

    /** Unmodifiable view over the symbols collected so far. */
    val symbols: List<FirBasedSymbol<*>>
        get() = Collections.unmodifiableList(declarationSymbols)

    override fun visitElement(element: FirElement) {
        val declaration = element as? FirDeclaration
        if (declaration != null) {
            declarationSymbols.add(declaration.symbol)
        }

        // Recurse so that nested declarations are collected as well.
        element.acceptChildren(this)
    }
}

/**
 * Collects the values a [FirCodeFragment] captures from its surrounding context.
 *
 * Every FIR element of the fragment is visited, and the following are registered:
 * - `super` references ([CodeFragmentCapturedValue.SuperClass]) and `this` references to non-object classes
 *   ([CodeFragmentCapturedValue.ContainingClass]);
 * - labeled context receivers and extension receivers of enclosing functions/properties;
 * - value parameters, local variables, local delegated properties, foreign values and backing fields referenced
 *   by calls;
 * - uses of `coroutineContext`.
 *
 * It also records the non-compiled [KtFile]s declaring local classes, anonymous objects and local functions that
 * the fragment refers to, either directly or through the type of an expression.
 *
 * Symbols in [selfSymbols] (declarations made inside the fragment itself) are never registered as captured.
 */
private class CodeFragmentCapturedValueVisitor(
    private val resolveSession: LLFirResolveSession,
    private val selfSymbols: Set<FirBasedSymbol<*>>,
) : FirDefaultVisitorVoid() {
    // Keyed by symbol, so repeated references to one value yield a single entry; insertion order is preserved.
    private val collectedMappings = LinkedHashMap<CodeFragmentCapturedId, CodeFragmentCapturedSymbol>()
    private val collectedFiles = LinkedHashSet<KtFile>()

    // Stack of symbols being assigned by the enclosing `FirVariableAssignment`s. A reference to the symbol on top
    // of the stack is treated as a mutation in `processCall`.
    private val assignmentLhs = mutableListOf<FirBasedSymbol<*>>()

    /** Captured values registered so far, in first-registration order. */
    val values: List<CodeFragmentCapturedSymbol>
        get() = collectedMappings.values.toList()

    /** Non-compiled files declaring local classes, anonymous objects or local functions used by the fragment. */
    val files: List<KtFile>
        get() = collectedFiles.toList()

    private val session: FirSession
        get() = resolveSession.useSiteFirSession

    override fun visitElement(element: FirElement) {
        processElement(element)

        // While the children of an assignment are visited, its target symbol stays on the stack so that
        // references to it are recognized as writes.
        val lhs = (element as? FirVariableAssignment)?.lValue?.toResolvedCallableSymbol(session)
        if (lhs != null) {
            assignmentLhs.add(lhs)
        }

        element.acceptChildren(this)

        if (lhs != null) {
            // `removeAt(lastIndex)` rather than `removeLast()`: when compiled against JDK 21+, `removeLast()` binds to
            // `java.util.List.removeLast()` (SequencedCollection) and fails with NoSuchMethodError on older runtimes.
            val popped = assignmentLhs.removeAt(assignmentLhs.lastIndex)
            require(popped == lhs) { "Unbalanced assignment stack: expected $lhs, but popped $popped" }
        }
    }

    /**
     * Registers whatever [element] itself captures. Its children are handled separately by [visitElement].
     */
    private fun processElement(element: FirElement) {
        if (element is FirExpression) {
            // The expression type may be a local class or an anonymous object; `registerFile` records its file then.
            val symbol = element.resolvedType.toSymbol(session)
            if (symbol != null) {
                registerFile(symbol)
            }
        }

        when (element) {
            is FirSuperReference -> {
                val symbol = (element.superTypeRef as? FirResolvedTypeRef)?.toRegularClassSymbol(session)
                if (symbol != null && symbol !in selfSymbols) {
                    val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                    val capturedValue = CodeFragmentCapturedValue.SuperClass(symbol.classId, isCrossingInlineBounds)
                    register(CodeFragmentCapturedSymbol(capturedValue, symbol, element.superTypeRef))
                }
            }
            is FirThisReference -> {
                val contextReceiverNumber = element.contextReceiverNumber
                val symbol = element.boundSymbol
                if (symbol != null && symbol !in selfSymbols) {
                    when (symbol) {
                        is FirClassSymbol<*> -> {
                            // `this` of an object declaration is not registered as a captured value.
                            if (symbol.classKind != ClassKind.OBJECT) {
                                val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                                val capturedValue = CodeFragmentCapturedValue.ContainingClass(symbol.classId, isCrossingInlineBounds)
                                val typeRef = buildResolvedTypeRef { coneType = symbol.defaultType() }
                                register(CodeFragmentCapturedSymbol(capturedValue, symbol, typeRef))
                            }
                        }
                        is FirFunctionSymbol<*>, is FirPropertySymbol -> {
                            @Suppress("USELESS_IS_CHECK") // Smart-cast is not applied from a 'when' condition in K1
                            require(symbol is FirCallableSymbol<*>)

                            if (contextReceiverNumber >= 0) {
                                // `this` bound to a context receiver; only labeled context receivers are registered.
                                val contextReceiver = symbol.resolvedContextReceivers[contextReceiverNumber]
                                val labelName = contextReceiver.labelName
                                if (labelName != null) {
                                    val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                                    val capturedValue = CodeFragmentCapturedValue
                                        .ContextReceiver(contextReceiverNumber, labelName, isCrossingInlineBounds)
                                    register(
                                        CodeFragmentCapturedSymbol(
                                            capturedValue, contextReceiver.symbol, contextReceiver.typeRef
                                        )
                                    )
                                }
                            } else {
                                // `this` bound to an extension receiver. The label falls back from the explicit
                                // `this@label`, to the lambda label, to the callable name.
                                val labelName = element.labelName
                                    ?: (symbol as? FirAnonymousFunctionSymbol)?.label?.name
                                    ?: symbol.name.asString()

                                val receiverParameter = symbol.receiverParameter ?: error("Receiver parameter not found")
                                val typeRef = receiverParameter.typeRef
                                val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                                val capturedValue = CodeFragmentCapturedValue.ExtensionReceiver(labelName, isCrossingInlineBounds)
                                register(
                                    CodeFragmentCapturedSymbol(capturedValue, receiverParameter.symbol, typeRef)
                                )
                            }
                        }
                    }
                }
            }
            is FirResolvable -> {
                val symbol = element.calleeReference.toResolvedCallableSymbol()
                if (symbol != null && symbol !in selfSymbols) {
                    processCall(element, symbol)
                }
            }
        }
    }

    /**
     * Registers the value accessed through [element], which resolves to the callable [symbol]. A read or write of a
     * value parameter, local property or backing field becomes a captured value; non-local properties are only
     * resolved to body phase.
     */
    private fun processCall(element: FirElement, symbol: FirCallableSymbol<*>) {
        // Desugared inc/dec FIR looks as follows:
        // lval <unary>: R|kotlin/Int| = R|<local>/x|
        // R|<local>/x| = R|<local>/<unary>|.R|kotlin/Int.inc|()
        // We visit the x in the first line before we visit the assignment and need to check the source to determine that the variable
        // is mutated.
        // The x in the second line isn't visited because it's a FirDesugaredAssignmentValueReferenceExpression.
        val isMutated = assignmentLhs.lastOrNull() == symbol || element.source?.kind is KtFakeSourceElementKind.DesugaredIncrementOrDecrement
        when (symbol) {
            is FirValueParameterSymbol -> {
                val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                val capturedValue = CodeFragmentCapturedValue.Local(symbol.name, isMutated, isCrossingInlineBounds)
                register(CodeFragmentCapturedSymbol(capturedValue, symbol, symbol.resolvedReturnTypeRef))
            }
            is FirPropertySymbol -> {
                if (symbol.isLocal) {
                    val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                    val capturedValue = when {
                        symbol.isForeignValue -> CodeFragmentCapturedValue.ForeignValue(symbol.name, isCrossingInlineBounds)
                        symbol.hasDelegate -> CodeFragmentCapturedValue.LocalDelegate(symbol.name, isMutated, isCrossingInlineBounds)
                        else -> CodeFragmentCapturedValue.Local(symbol.name, isMutated, isCrossingInlineBounds)
                    }
                    register(CodeFragmentCapturedSymbol(capturedValue, symbol, symbol.resolvedReturnTypeRef))
                } else {
                    // Property call generation depends on complete backing field resolution (Fir2IrLazyProperty.backingField)
                    symbol.lazyResolveToPhase(FirResolvePhase.BODY_RESOLVE)
                }
            }
            is FirBackingFieldSymbol -> {
                val propertyName = symbol.propertySymbol.name
                val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
                val capturedValue = CodeFragmentCapturedValue.BackingField(propertyName, isMutated, isCrossingInlineBounds)
                register(CodeFragmentCapturedSymbol(capturedValue, symbol, symbol.resolvedReturnTypeRef))
            }
            is FirNamedFunctionSymbol -> {
                if (symbol.isLocal) {
                    registerFile(symbol)
                }
            }
        }

        if (symbol.callableId == StandardClassIds.Callables.coroutineContext) {
            val isCrossingInlineBounds = isCrossingInlineBounds(element, symbol)
            val capturedValue = CodeFragmentCapturedValue.CoroutineContext(isCrossingInlineBounds)
            register(CodeFragmentCapturedSymbol(capturedValue, symbol, symbol.resolvedReturnTypeRef))
        }
    }

    /**
     * Records [mapping], keeping at most one entry per symbol.
     *
     * If an entry for the same symbol already exists, both must be the same [CodeFragmentCapturedValue] subclass.
     * The existing entry is replaced only when it is not mutated and the new one is.
     */
    private fun register(mapping: CodeFragmentCapturedSymbol) {
        val id = CodeFragmentCapturedId(mapping.symbol)
        val previousMapping = collectedMappings[id]

        if (previousMapping != null) {
            val previousValue = previousMapping.value
            val newValue = mapping.value

            require(previousValue.javaClass == newValue.javaClass) {
                "Inconsistent captured value kinds for the same symbol: $previousValue vs $newValue"
            }

            // Only replace non-mutated value with a mutated one.
            if (previousValue.isMutated || !newValue.isMutated) {
                return
            }
        }

        collectedMappings[id] = mapping
        registerFile(mapping.symbol)
    }

    /**
     * Records the containing file of [symbol] when the symbol is a local class, an anonymous object or a local
     * function, and the file is not a compiled one. Otherwise does nothing.
     */
    private fun registerFile(symbol: FirBasedSymbol<*>) {
        val needsRegistration = when (symbol) {
            is FirRegularClassSymbol -> symbol.isLocal
            is FirAnonymousObjectSymbol -> true
            is FirNamedFunctionSymbol -> symbol.callableId.isLocal
            else -> false
        }

        if (!needsRegistration) {
            return
        }

        val file = symbol.fir.containingKtFileIfAny ?: return
        if (!file.isCompiled) {
            collectedFiles.add(file)
        }
    }

    /**
     * Returns `true` if a non-inline function lies between the use site ([element]) and the declaration of [symbol].
     *
     * The check walks the PSI parents of the use site, stopping at their common parent. It returns `true` at the
     * first [KtFunction] whose resolved callable symbol is not inline. It returns `false` when either PSI element is
     * unavailable, when no common parent exists, or when no such function is found.
     */
    private fun isCrossingInlineBounds(element: FirElement, symbol: FirBasedSymbol<*>): Boolean {
        val callSite = element.source?.psi ?: return false
        val declarationSite = symbol.fir.source?.psi ?: return false
        val commonParent = findCommonParentContextAware(callSite, declarationSite) ?: return false

        for (elementInBetween in callSite.parentsCodeFragmentAware) {
            if (elementInBetween === commonParent) {
                break
            }

            if (elementInBetween is KtFunction) {
                val symbolInBetween = elementInBetween.resolveToFirSymbol(resolveSession)
                if (symbolInBetween is FirCallableSymbol<*> && !symbolInBetween.isInline) {
                    return true
                }
            }
        }

        return false
    }

    /**
     * Finds the common PSI parent of [callSite] and [declarationSite].
     *
     * If the two share no tree and [callSite] lives in a [KtCodeFragment], the search is retried from the fragment's
     * context element. Returns `null` when no common parent can be found.
     */
    private fun findCommonParentContextAware(callSite: PsiElement, declarationSite: PsiElement): PsiElement? {
        val directParent = PsiTreeUtil.findCommonParent(callSite, declarationSite)
        if (directParent != null) {
            return directParent
        }

        val codeFragment = callSite.containingFile as? KtCodeFragment ?: return null
        val codeFragmentContext = codeFragment.context ?: return null
        return PsiTreeUtil.findCommonParent(codeFragmentContext, declarationSite)
    }
}
