/*
 * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.backend.common.lower

import ksp.org.jetbrains.kotlin.backend.common.BackendContext
import ksp.org.jetbrains.kotlin.backend.common.BodyLoweringPass
import ksp.org.jetbrains.kotlin.backend.common.ir.isInlineFunWithReifiedParameter
import ksp.org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import ksp.org.jetbrains.kotlin.ir.builders.declarations.addExtensionReceiver
import ksp.org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
import ksp.org.jetbrains.kotlin.ir.builders.declarations.buildFun
import ksp.org.jetbrains.kotlin.ir.builders.irCall
import ksp.org.jetbrains.kotlin.ir.builders.irGet
import ksp.org.jetbrains.kotlin.ir.builders.irReturn
import ksp.org.jetbrains.kotlin.ir.declarations.*
import ksp.org.jetbrains.kotlin.ir.expressions.IrBody
import ksp.org.jetbrains.kotlin.ir.expressions.IrExpression
import ksp.org.jetbrains.kotlin.ir.expressions.IrFunctionReference
import ksp.org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import ksp.org.jetbrains.kotlin.ir.expressions.impl.IrFunctionReferenceImpl
import ksp.org.jetbrains.kotlin.ir.expressions.impl.fromSymbolOwner
import ksp.org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
import ksp.org.jetbrains.kotlin.ir.types.IrTypeArgument
import ksp.org.jetbrains.kotlin.ir.types.IrTypeSubstitutor
import ksp.org.jetbrains.kotlin.ir.util.SYNTHETIC_OFFSET
import ksp.org.jetbrains.kotlin.ir.util.render
import ksp.org.jetbrains.kotlin.ir.util.setDeclarationsParent
import ksp.org.jetbrains.kotlin.ir.util.typeSubstitutionMap
import ksp.org.jetbrains.kotlin.ir.visitors.IrTransformer
import ksp.org.jetbrains.kotlin.name.Name
import ksp.org.jetbrains.kotlin.utils.addToStdlib.runIf

/**
 * Replaces a callable reference to an inline function with reified type parameters by a callable reference to a new,
 * non-inline local wrapper function (`<name>$wrap`) whose signature uses the substituted types.
 *
 * The wrapper and the new reference are emitted together in an [IrStatementOrigin.ADAPTED_FUNCTION_REFERENCE] block.
 * A receiver bound by the original reference (`x::f`) becomes the wrapper's extension receiver and is bound on the new
 * reference, so it is evaluated once, when the reference is created, exactly as for any other bound callable reference.
 * Receivers that are not bound by the original reference become leading value parameters of the wrapper, in
 * dispatch-then-extension order, followed by the wrapped function's own value parameters.
 */
class WrapInlineDeclarationsWithReifiedTypeParametersLowering(val context: BackendContext) : BodyLoweringPass {
    private val irFactory
        get() = context.irFactory

    override fun lower(irBody: IrBody, container: IrDeclaration) {
        irBody.transformChildren(object : IrTransformer<IrDeclarationParent?>() {
            // `data` tracks the innermost enclosing declaration parent; generated wrappers are parented to it.
            override fun visitDeclaration(declaration: IrDeclarationBase, data: IrDeclarationParent?) =
                super.visitDeclaration(declaration, declaration as? IrDeclarationParent ?: data)

            override fun visitFunctionReference(expression: IrFunctionReference, data: IrDeclarationParent?): IrExpression {
                // Lower nested references first (e.g. ones appearing inside a bound receiver expression).
                expression.transformChildren(this, data)

                val owner = expression.symbol.owner as? IrSimpleFunction
                if (owner == null || !owner.isInlineFunWithReifiedParameter()) {
                    return expression
                }
                val wrapperParent = data
                    ?: error("Unable to get a proper parent while lower ${expression.render()} at ${container.render()}")
                val wrapper = buildWrapper(expression, owner, wrapperParent)

                return context.createIrBuilder(container.symbol).irBlock(
                    expression,
                    origin = IrStatementOrigin.ADAPTED_FUNCTION_REFERENCE
                ) {
                    +wrapper
                    +IrFunctionReferenceImpl.fromSymbolOwner(
                        expression.startOffset,
                        expression.endOffset,
                        expression.type,
                        wrapper.symbol,
                        wrapper.typeParameters.size,
                        expression.reflectionTarget
                    ).apply {
                        // Bind the original receiver on the new reference (rather than moving it into the wrapper
                        // body) so that it is evaluated once, at the point where the reference is created.
                        extensionReceiver = expression.dispatchReceiver ?: expression.extensionReceiver
                    }
                }
            }
        }, container as? IrDeclarationParent)
    }

    /**
     * Builds the non-inline wrapper `<owner name>$wrap` that forwards its parameters to [owner], passing the type
     * arguments of [expression] explicitly. The wrapper is parented to [wrapperParent].
     *
     * The receiver bound by [expression] (if any) is read through the wrapper's extension receiver; the caller is
     * responsible for binding it on the reference to the wrapper.
     */
    private fun buildWrapper(
        expression: IrFunctionReference,
        owner: IrSimpleFunction,
        wrapperParent: IrDeclarationParent,
    ): IrSimpleFunction {
        val boundDispatchReceiver = expression.dispatchReceiver
        val boundExtensionReceiver = expression.extensionReceiver
        // References to members that are also extensions are not allowed in Kotlin, so at most one receiver can be
        // bound; the wrapper has a single extension receiver to carry it.
        check(boundDispatchReceiver == null || boundExtensionReceiver == null) {
            "Callable reference binds both dispatch and extension receivers: ${expression.render()}"
        }

        @Suppress("UNCHECKED_CAST")
        val typeSubstitutor = IrTypeSubstitutor(expression.typeSubstitutionMap as Map<IrTypeParameterSymbol, IrTypeArgument>)

        return irFactory.buildFun {
            name = Name.identifier("${owner.name}${"$"}wrap")
            returnType = typeSubstitutor.substitute(owner.returnType)
            visibility = DescriptorVisibilities.LOCAL
            origin = IrDeclarationOrigin.ADAPTER_FOR_CALLABLE_REFERENCE
            startOffset = SYNTHETIC_OFFSET
            endOffset = SYNTHETIC_OFFSET
        }.apply {
            parent = wrapperParent

            // Value-parameter order must match the function type of `expression`: unbound dispatch receiver,
            // unbound extension receiver, then the wrapped function's own value parameters.
            val dispatchReceiverSource = owner.dispatchReceiverParameter?.let {
                addReceiverSource(it, boundDispatchReceiver, typeSubstitutor)
            }
            val extensionReceiverSource = owner.extensionReceiverParameter?.let {
                addReceiverSource(it, boundExtensionReceiver, typeSubstitutor)
            }
            val forwardedParameters = owner.valueParameters.map { valueParameter ->
                addValueParameter(valueParameter.name, typeSubstitutor.substitute(valueParameter.type))
            }

            val irBuilder = context.createIrBuilder(symbol, SYNTHETIC_OFFSET, SYNTHETIC_OFFSET)
            body = irFactory.createBlockBody(expression.startOffset, expression.endOffset) {
                val call = irBuilder.irCall(owner.symbol).apply {
                    dispatchReceiver = dispatchReceiverSource?.let { irBuilder.irGet(it) }
                    extensionReceiver = extensionReceiverSource?.let { irBuilder.irGet(it) }
                    forwardedParameters.forEachIndexed { index, parameter ->
                        putValueArgument(index, irBuilder.irGet(parameter))
                    }
                    // Pass the (possibly reified) type arguments explicitly; the wrapper itself is not generic.
                    for (i in 0 until expression.typeArgumentsCount) {
                        putTypeArgument(i, expression.getTypeArgument(i))
                    }
                }
                statements.add(irBuilder.irReturn(call))
            }
        }
    }

    /**
     * Adds to this wrapper the parameter that supplies [receiverParameter] of the wrapped function. If the original
     * reference bound [boundValue], that parameter is the wrapper's extension receiver, typed as the bound value.
     * Otherwise it is a new value parameter with the substituted receiver type.
     */
    private fun IrSimpleFunction.addReceiverSource(
        receiverParameter: IrValueParameter,
        boundValue: IrExpression?,
        typeSubstitutor: IrTypeSubstitutor,
    ): IrValueParameter =
        if (boundValue != null) {
            addExtensionReceiver(boundValue.type)
        } else {
            addValueParameter(receiverParameter.name, typeSubstitutor.substitute(receiverParameter.type))
        }
}
