/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.fir.backend

import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.backend.utils.ConversionTypeOrigin
import org.jetbrains.kotlin.fir.backend.utils.buildSubstitutorByCalledCallable
import org.jetbrains.kotlin.fir.backend.utils.implicitCast
import org.jetbrains.kotlin.fir.backend.utils.unwrapCallRepresentative
import org.jetbrains.kotlin.fir.declarations.FirTypeAlias
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirUnitExpression
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.types.*
import org.jetbrains.kotlin.fir.visitors.FirDefaultVisitor
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.types.AbstractTypeChecker
import org.jetbrains.kotlin.types.TypeApproximatorConfiguration

class Fir2IrImplicitCastInserter(private val c: Fir2IrComponents) : Fir2IrComponents by c, FirDefaultVisitor<IrElement, IrElement>() {
    override fun visitElement(element: FirElement, data: IrElement): IrElement {
        error("Each FIR element should have it's own `visit` overload, but got ${element::class.simpleName} here")
    }

    override fun visitAnnotation(annotation: FirAnnotation, data: IrElement): IrElement = data

    override fun visitAnnotationCall(annotationCall: FirAnnotationCall, data: IrElement): IrElement = data

    override fun visitAnonymousObjectExpression(anonymousObjectExpression: FirAnonymousObjectExpression, data: IrElement): IrElement = data

    override fun visitAnonymousFunctionExpression(anonymousFunctionExpression: FirAnonymousFunctionExpression, data: IrElement): IrElement {
        return data
    }

    override fun visitBooleanOperatorExpression(booleanOperatorExpression: FirBooleanOperatorExpression, data: IrElement): IrElement = data

    override fun visitComparisonExpression(comparisonExpression: FirComparisonExpression, data: IrElement): IrElement = data

    override fun visitTypeOperatorCall(typeOperatorCall: FirTypeOperatorCall, data: IrElement): IrElement = data

    override fun visitEqualityOperatorCall(equalityOperatorCall: FirEqualityOperatorCall, data: IrElement): IrElement = data

    override fun visitLiteralExpression(literalExpression: FirLiteralExpression, data: IrElement): IrElement = data

    override fun visitThisReceiverExpression(thisReceiverExpression: FirThisReceiverExpression, data: IrElement): IrElement = data

    override fun visitQualifiedAccessExpression(qualifiedAccessExpression: FirQualifiedAccessExpression, data: IrElement): IrElement = data

    override fun visitPropertyAccessExpression(propertyAccessExpression: FirPropertyAccessExpression, data: IrElement): IrElement = data

    override fun visitResolvedQualifier(resolvedQualifier: FirResolvedQualifier, data: IrElement): IrElement = data

    override fun visitErrorResolvedQualifier(errorResolvedQualifier: FirErrorResolvedQualifier, data: IrElement): IrElement {
        // Support for error suppression case
        return visitResolvedQualifier(errorResolvedQualifier, data)
    }

    override fun visitGetClassCall(getClassCall: FirGetClassCall, data: IrElement): IrElement = data

    override fun visitFunctionCall(functionCall: FirFunctionCall, data: IrElement): IrElement = data

    override fun visitCheckNotNullCall(checkNotNullCall: FirCheckNotNullCall, data: IrElement): IrElement = data

    override fun visitCheckedSafeCallSubject(checkedSafeCallSubject: FirCheckedSafeCallSubject, data: IrElement): IrElement = data

    override fun visitSafeCallExpression(safeCallExpression: FirSafeCallExpression, data: IrElement): IrElement = data

    override fun visitStringConcatenationCall(stringConcatenationCall: FirStringConcatenationCall, data: IrElement): IrElement = data

    override fun visitArrayLiteral(arrayLiteral: FirArrayLiteral, data: IrElement): IrElement = data

    override fun visitNamedArgumentExpression(namedArgumentExpression: FirNamedArgumentExpression, data: IrElement): IrElement = data

    override fun visitVarargArgumentsExpression(varargArgumentsExpression: FirVarargArgumentsExpression, data: IrElement): IrElement = data

    override fun visitSpreadArgumentExpression(spreadArgumentExpression: FirSpreadArgumentExpression, data: IrElement): IrElement = data

    // ==================================================================================

    override fun visitExpression(expression: FirExpression, data: IrElement): IrElement {
        return when (expression) {
            is FirBlock -> (data as IrContainerExpression).insertImplicitCasts()
            is FirUnitExpression -> coerceToUnitIfNeeded(data as IrExpression, builtins)
            else -> data
        }
    }

    override fun visitStatement(statement: FirStatement, data: IrElement): IrElement {
        return when (statement) {
            is FirTypeAlias -> data
            is FirUnitExpression -> coerceToUnitIfNeeded(data as IrExpression, builtins)
            is FirBlock -> (data as IrContainerExpression).insertImplicitCasts()
            else -> statement.accept(this, data)
        }
    }

    // ==================================================================================

    override fun visitWhenExpression(whenExpression: FirWhenExpression, data: IrElement): IrElement {
        if (data is IrBlock) {
            return data.insertImplicitCasts()
        }
        val irWhen = data as IrWhen
        if (irWhen.branches.size != whenExpression.branches.size) {
            return data
        }
        val firBranchMap = irWhen.branches.zip(whenExpression.branches).toMap()
        irWhen.branches.replaceAll {
            visitWhenBranch(firBranchMap.getValue(it), it)
        }
        return data
    }

    override fun visitWhenSubjectExpression(whenSubjectExpression: FirWhenSubjectExpression, data: IrElement): IrElement = data

    override fun visitWhenBranch(whenBranch: FirWhenBranch, data: IrElement): IrBranch {
        val irBranch = data as IrBranch
        (irBranch.result as? IrContainerExpression)?.insertImplicitCasts()
        return data
    }

    override fun visitElvisExpression(elvisExpression: FirElvisExpression, data: IrElement): IrElement = data

    // ==================================================================================

    override fun visitDoWhileLoop(doWhileLoop: FirDoWhileLoop, data: IrElement): IrElement {
        val loop = data as IrDoWhileLoop
        (loop.body as? IrContainerExpression)?.insertImplicitCasts(coerceLastExpressionToUnit = true)
        return data
    }

    override fun visitWhileLoop(whileLoop: FirWhileLoop, data: IrElement): IrElement {
        val loop = data as IrWhileLoop
        (loop.body as? IrContainerExpression)?.insertImplicitCasts(coerceLastExpressionToUnit = true)
        return data
    }

    override fun visitBreakExpression(breakExpression: FirBreakExpression, data: IrElement): IrElement = data

    override fun visitContinueExpression(continueExpression: FirContinueExpression, data: IrElement): IrElement = data

    // ==================================================================================

    override fun visitTryExpression(tryExpression: FirTryExpression, data: IrElement): IrElement {
        val irTry = data as IrTry

        irTry.tryResult = irTry.tryResult.insertSpecialCast(
            tryExpression.tryBlock, tryExpression.tryBlock.resolvedType, tryExpression.resolvedType
        )
        for ((irCatch, firCatch) in irTry.catches.zip(tryExpression.catches)) {
            irCatch.result = irCatch.result.insertSpecialCast(firCatch.block, firCatch.block.resolvedType, tryExpression.resolvedType)
        }
        (irTry.finallyExpression as? IrContainerExpression)?.insertImplicitCasts(coerceLastExpressionToUnit = true)
        return data
    }

    override fun visitThrowExpression(throwExpression: FirThrowExpression, data: IrElement): IrElement =
        (data as IrThrow).insertSpecialCast(throwExpression, throwExpression.exception.resolvedType, throwExpression.resolvedType)

    override fun visitBlock(block: FirBlock, data: IrElement): IrElement =
        (data as? IrContainerExpression)?.insertImplicitCasts() ?: data

    override fun visitReturnExpression(returnExpression: FirReturnExpression, data: IrElement): IrElement {
        val irReturn = data as? IrReturn ?: return data
        val expectedType = returnExpression.target.labeledElement.returnTypeRef
        irReturn.value = irReturn.value.insertSpecialCast(
            returnExpression.result,
            returnExpression.result.resolvedType,
            expectedType.coneType
        )
        return data
    }

    // ==================================================================================

    /**
     * This functions processes the following casts:
     * - coercion to Unit
     * - nullability casts based on nullability annotations
     * - casts for dynamic types
     */
    internal fun IrExpression.insertSpecialCast(
        expression: FirExpression,
        valueType: ConeKotlinType,
        expectedType: ConeKotlinType,
    ): IrExpression {
        if (this is IrTypeOperatorCall) {
            return this
        }

        if (this is IrContainerExpression) {
            insertImplicitCasts(coerceLastExpressionToUnit = type.isUnit())
        }

        val expandedValueType = valueType.fullyExpandedType(session)
        val expandedExpectedType = expectedType.fullyExpandedType(session)

        return when {
            expandedExpectedType.isUnit -> {
                coerceToUnitIfNeeded(this, builtins)
            }
            expandedValueType is ConeDynamicType -> {
                if (expandedExpectedType !is ConeDynamicType && !expandedExpectedType.isNullableAny) {
                    generateImplicitCast(this, expandedExpectedType.toIrType(c, ConversionTypeOrigin.DEFAULT))
                } else {
                    this
                }
            }
            typeCanBeEnhancedOrFlexibleNullable(expandedValueType, session) && !expandedExpectedType.acceptsNullValues() -> {
                insertImplicitNotNullCastIfNeeded(expression)
            }
            else -> this
        }
    }

    internal fun IrExpression.insertCastForSmartcastWithIntersection(
        argumentType: ConeKotlinType,
        expectedType: ConeKotlinType
    ): IrExpression {
        if (argumentType !is ConeIntersectionType) return this
        val approximatedArgumentType = argumentType.approximateForIrOrNull(c) ?: argumentType
        if (approximatedArgumentType.isSubtypeOf(expectedType, session)) return this

        return findComponentOfIntersectionForExpectedType(argumentType, expectedType)?.let {
            generateImplicitCast(this, it.toIrType(c))
        } ?: this
    }

    private fun ConeKotlinType.acceptsNullValues(): Boolean {
        // For Captured(in Type) it only accepts nulls if `Type` does
        if (this is ConeCapturedType && this.constructor.projection.kind == ProjectionKind.IN) {
            // But `Captured(in Type)?` does accepts nulls independently of `Type`
            if (isMarkedNullable) return true
            return constructor.projection.type!!.canBeNull(session)
        }
        return canBeNull(session) || hasEnhancedNullability
    }

    private fun IrExpression.insertImplicitNotNullCastIfNeeded(expression: FirExpression): IrExpression {
        if (this is IrGetEnumValue) return this
        // [TypeOperatorLowering] will retrieve the source (from start offset to end offset) as an assertion message.
        // Avoid type casting if we can't determine the source for some reasons, e.g., implicit `this` receiver.
        if (expression.source == null) return this
        return implicitNotNullCast(this)
    }

    private fun IrContainerExpression.insertImplicitCasts(coerceLastExpressionToUnit: Boolean = false): IrContainerExpression {
        if (statements.isEmpty()) return this

        val lastIndex = statements.lastIndex
        statements.forEachIndexed { i, irStatement ->
            if (irStatement !is IrErrorCallExpression && irStatement is IrExpression) {
                if (i != lastIndex || coerceLastExpressionToUnit) {
                    statements[i] = coerceToUnitIfNeeded(irStatement, builtins)
                }
            }
        }

        return this
    }

    internal fun IrBlockBody.insertImplicitCasts(): IrBlockBody {
        if (statements.isEmpty()) return this

        statements.forEachIndexed { i, irStatement ->
            if (irStatement !is IrErrorCallExpression && irStatement is IrExpression) {
                statements[i] = coerceToUnitIfNeeded(irStatement, builtins)
            }
        }
        return this
    }

    override fun visitSmartCastExpression(smartCastExpression: FirSmartCastExpression, data: IrElement): IrElement {
        // We don't want an implicit cast to Nothing?. This expression just encompasses nullability after null check.
        return if (smartCastExpression.isStable && smartCastExpression.smartcastTypeWithoutNullableNothing == null) {
            val smartcastedType = smartCastExpression.resolvedType
            val approximatedType = smartcastedType.approximateForIrOrNull(c)
            if (approximatedType != null) {
                val originalType = smartCastExpression.originalExpression.resolvedType
                val originalNotNullType = originalType.withNullability(nullable = false, session.typeContext)
                if (originalNotNullType.isSubtypeOf(approximatedType, session)) {
                    return data
                }
            }
            implicitCastOrExpression(data as IrExpression, approximatedType ?: smartcastedType)
        } else {
            data
        }
    }

    internal fun implicitCastFromReceivers(
        originalIrReceiver: IrExpression,
        receiver: FirExpression,
        selector: FirQualifiedAccessExpression,
        typeOrigin: ConversionTypeOrigin,
    ): IrExpression {
        return implicitCastFromReceiverForIntersectionTypeOrNull(
            originalIrReceiver,
            receiver,
            selector,
            typeOrigin
        ) ?: implicitCastOrExpression(originalIrReceiver, receiver.resolvedType, typeOrigin)
    }

    private fun implicitCastFromReceiverForIntersectionTypeOrNull(
        originalIrReceiver: IrExpression,
        receiver: FirExpression,
        selector: FirQualifiedAccessExpression,
        typeOrigin: ConversionTypeOrigin,
    ): IrExpression? {
        val receiverExpressionType = receiver.resolvedType.lowerBoundIfFlexible() as? ConeIntersectionType ?: return null
        val referencedDeclaration = selector.calleeReference.toResolvedCallableSymbol()?.unwrapCallRepresentative(c)?.fir

        val receiverType = with(selector) {
            when {
                receiver === dispatchReceiver -> {
                    val dispatchReceiverType = referencedDeclaration?.dispatchReceiverType as? ConeClassLikeType ?: return null
                    dispatchReceiverType.replaceArgumentsWithStarProjections()
                }
                receiver === extensionReceiver -> {
                    val extensionReceiverType = referencedDeclaration?.receiverParameter?.typeRef?.coneType ?: return null
                    val substitutor = selector.buildSubstitutorByCalledCallable(c)
                    val substitutedType = substitutor.substituteOrSelf(extensionReceiverType)
                    // Frontend may write captured types as type arguments (by design), so we need to approximate receiver type after substitution
                    val approximatedType = session.typeApproximator.approximateToSuperType(
                        substitutedType,
                        TypeApproximatorConfiguration.InternalTypesApproximation
                    )
                    approximatedType ?: substitutedType
                }
                else -> return null
            }
        }

        return findComponentOfIntersectionForExpectedType(receiverExpressionType, receiverType)?.let {
            implicitCastOrExpression(originalIrReceiver, it, typeOrigin)
        }
    }

    private fun findComponentOfIntersectionForExpectedType(type: ConeIntersectionType, expectedType: ConeKotlinType): ConeKotlinType? {
        for (componentType in type.intersectedTypes) {
            if (AbstractTypeChecker.isSubtypeOf(session.typeContext, componentType, expectedType)) {
                return componentType
            }
        }
        return null
    }

    private fun implicitCastOrExpression(
        original: IrExpression, castType: ConeKotlinType, typeOrigin: ConversionTypeOrigin = ConversionTypeOrigin.DEFAULT
    ): IrExpression {
        return implicitCastOrExpression(original, castType.toIrType(c, typeOrigin))
    }

    companion object {
        internal fun implicitCastOrExpression(original: IrExpression, castType: IrType): IrExpression {
            val originalNotNull = original.type.makeNotNull()
            if (originalNotNull == castType.makeNotNull()) return original
            return generateImplicitCast(original, castType)
        }

        private fun generateImplicitCast(original: IrExpression, castType: IrType): IrExpression {
            val typeOperator = if (original.type is IrDynamicType) {
                IrTypeOperator.IMPLICIT_DYNAMIC_CAST
            } else {
                IrTypeOperator.IMPLICIT_CAST
            }

            return implicitCast(original, castType, typeOperator)
        }

        internal fun coerceToUnitIfNeeded(original: IrExpression, builtins: Fir2IrBuiltinSymbolsContainer): IrExpression {
            val valueType = original.type
            return if (valueType.isUnit() || valueType.isNothing())
                original
            else
                IrTypeOperatorCallImpl(
                    original.startOffset, original.endOffset,
                    builtins.unitType,
                    IrTypeOperator.IMPLICIT_COERCION_TO_UNIT,
                    builtins.unitType,
                    original
                )
        }

        fun implicitNotNullCast(original: IrExpression): IrTypeOperatorCall {
            // Cast type massage 1. Remove @EnhancedNullability
            // Cast type massage 2. Convert it to a non-null variant (in case of @FlexibleNullability)
            val castType = original.type.removeAnnotations { annotationCall ->
                val constructorSymbol = annotationCall.symbol.takeIf { it.isBound } ?: return@removeAnnotations false
                /*
                 * @EnhancedNullability and @FlexibleNullability are symbols from builtins and should be already
                 *   bound at the time of body conversion, so it's safe to take the owner for them
                 * If symbol is unbound then this annotation can not be neither @EnhancedNullability or @FlexibleNullability
                 */
                @OptIn(UnsafeDuringIrConstructionAPI::class)
                val classId = constructorSymbol.owner.parentAsClass.classId
                classId == StandardClassIds.Annotations.EnhancedNullability ||
                        classId == StandardClassIds.Annotations.FlexibleNullability
            }.makeNotNull()
            return IrTypeOperatorCallImpl(
                original.startOffset,
                original.endOffset,
                castType,
                IrTypeOperator.IMPLICIT_NOTNULL,
                castType,
                original
            )
        }

        internal fun typeCanBeEnhancedOrFlexibleNullable(type: ConeKotlinType, session: FirSession): Boolean {
            return when {
                type.hasEnhancedNullability -> true
                type.hasFlexibleMarkedNullability && type.canBeNull(session) -> true
                else -> false
            }
        }
    }
}
