/*
 * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.fir.session

import ksp.org.jetbrains.kotlin.KtSourceElement
import ksp.org.jetbrains.kotlin.analyzer.common.CommonPlatformAnalyzerServices
import ksp.org.jetbrains.kotlin.config.JvmAnalysisFlags
import ksp.org.jetbrains.kotlin.config.LanguageVersionSettings
import ksp.org.jetbrains.kotlin.fir.*
import ksp.org.jetbrains.kotlin.fir.analysis.CheckersComponent
import ksp.org.jetbrains.kotlin.fir.analysis.FirDefaultOverridesBackwardCompatibilityHelper
import ksp.org.jetbrains.kotlin.fir.analysis.FirOverridesBackwardCompatibilityHelper
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.FirIdentityLessPlatformDeterminer
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.FirInlineCheckerPlatformSpecificComponent
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.FirPlatformUpperBoundsProvider
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.FirPrimaryConstructorSuperTypeCheckerPlatformComponent
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirNameConflictsTracker
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.expression.FirGenericArrayClassLiteralSupport
import ksp.org.jetbrains.kotlin.fir.analysis.jvm.FirJvmOverridesBackwardCompatibilityHelper
import ksp.org.jetbrains.kotlin.fir.analysis.jvm.checkers.FirJavaNullabilityWarningUpperBoundsProvider
import ksp.org.jetbrains.kotlin.fir.analysis.jvm.checkers.FirJvmAnnotationsPlatformSpecificSupportComponent
import ksp.org.jetbrains.kotlin.fir.analysis.jvm.checkers.FirJvmInlineCheckerComponent
import ksp.org.jetbrains.kotlin.fir.analysis.jvm.checkers.FirJvmPrimaryConstructorSuperTypeCheckerPlatformComponent
import ksp.org.jetbrains.kotlin.fir.caches.FirCachesFactory
import ksp.org.jetbrains.kotlin.fir.caches.FirThreadUnsafeCachesFactory
import ksp.org.jetbrains.kotlin.fir.declarations.*
import ksp.org.jetbrains.kotlin.fir.deserialization.FirDeserializationExtension
import ksp.org.jetbrains.kotlin.fir.extensions.*
import ksp.org.jetbrains.kotlin.fir.java.FirJavaVisibilityChecker
import ksp.org.jetbrains.kotlin.fir.java.FirJvmDefaultModeComponent
import ksp.org.jetbrains.kotlin.fir.java.FirSyntheticPropertiesStorage
import ksp.org.jetbrains.kotlin.fir.java.JvmSupertypeUpdater
import ksp.org.jetbrains.kotlin.fir.java.deserialization.FirJvmDeserializationExtension
import ksp.org.jetbrains.kotlin.fir.java.enhancement.FirAnnotationTypeQualifierResolver
import ksp.org.jetbrains.kotlin.fir.java.enhancement.FirEnhancedSymbolsStorage
import ksp.org.jetbrains.kotlin.fir.java.enhancement.JavaCompilerRequiredAnnotationEnhancementProvider
import ksp.org.jetbrains.kotlin.fir.java.scopes.JavaOverridabilityRules
import ksp.org.jetbrains.kotlin.fir.modules.FirJavaModuleResolverProvider
import ksp.org.jetbrains.kotlin.fir.resolve.*
import ksp.org.jetbrains.kotlin.fir.resolve.calls.overloads.ConeCallConflictResolverFactory
import ksp.org.jetbrains.kotlin.fir.resolve.calls.overloads.FirDeclarationOverloadabilityHelperImpl
import ksp.org.jetbrains.kotlin.fir.resolve.calls.FirSyntheticNamesProvider
import ksp.org.jetbrains.kotlin.fir.resolve.calls.jvm.JvmCallConflictResolverFactory
import ksp.org.jetbrains.kotlin.fir.resolve.inference.InferenceComponents
import ksp.org.jetbrains.kotlin.fir.resolve.providers.impl.FirQualifierResolverImpl
import ksp.org.jetbrains.kotlin.fir.resolve.providers.impl.FirTypeResolverImpl
import ksp.org.jetbrains.kotlin.fir.resolve.transformers.FirDummyCompilerLazyDeclarationResolver
import ksp.org.jetbrains.kotlin.fir.resolve.transformers.PlatformSupertypeUpdater
import ksp.org.jetbrains.kotlin.fir.resolve.transformers.mpp.FirExpectActualMatchingContextImpl
import ksp.org.jetbrains.kotlin.fir.resolve.transformers.plugin.CompilerRequiredAnnotationEnhancementProvider
import ksp.org.jetbrains.kotlin.fir.scopes.*
import ksp.org.jetbrains.kotlin.fir.scopes.impl.*
import ksp.org.jetbrains.kotlin.fir.scopes.jvm.FirJvmDelegatedMembersFilter
import ksp.org.jetbrains.kotlin.fir.scopes.jvm.JvmMappedScope.FirMappedSymbolStorage
import ksp.org.jetbrains.kotlin.fir.serialization.FirProvidedDeclarationsForMetadataService
import ksp.org.jetbrains.kotlin.fir.symbols.FirLazyDeclarationResolver
import ksp.org.jetbrains.kotlin.fir.types.*
import ksp.org.jetbrains.kotlin.incremental.components.EnumWhenTracker
import ksp.org.jetbrains.kotlin.incremental.components.ImportTracker
import ksp.org.jetbrains.kotlin.incremental.components.LookupTracker
import ksp.org.jetbrains.kotlin.resolve.jvm.JvmTypeSpecificityComparator
import ksp.org.jetbrains.kotlin.resolve.jvm.modules.JavaModuleResolver

// -------------------------- Required components --------------------------

/**
 * Registers the platform-independent components every [FirSession] is set up with.
 *
 * Some keys registered here receive common/default implementations which
 * [registerJavaComponents] registers again with JVM-specific ones
 * ([FirEnumEntriesSupport], [FirAnnotationsPlatformSpecificSupportComponent],
 * [FirPrimaryConstructorSuperTypeCheckerPlatformComponent], [FirGenericArrayClassLiteralSupport]).
 *
 * NOTE(review): many constructors below receive `this` session; whether any of them reads other
 * components eagerly is not visible here, so treat the registration order as significant.
 *
 * @param languageVersionSettings stored in this session as its [FirLanguageSettingsComponent].
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerCommonComponents(languageVersionSettings: LanguageVersionSettings) {
    register(FirLanguageSettingsComponent::class, FirLanguageSettingsComponent(languageVersionSettings))
    register(TypeComponents::class, TypeComponents(this))
    register(InferenceComponents::class, InferenceComponents(this))

    register(FirDeclaredMemberScopeProvider::class, FirDeclaredMemberScopeProvider(this))
    register(FirCorrespondingSupertypesCache::class, FirCorrespondingSupertypesCache(this))
    register(FirDefaultParametersResolver::class, FirDefaultParametersResolver())

    register(FirExtensionService::class, FirExtensionService(this))

    // Storages for synthesized / generated / override declarations.
    register(FirSubstitutionOverrideStorage::class, FirSubstitutionOverrideStorage(this))
    register(FirIntersectionOverrideStorage::class, FirIntersectionOverrideStorage(this))
    register(FirSynthesizedStorage::class, FirSynthesizedStorage(this))
    register(FirGeneratedMemberDeclarationsStorage::class, FirGeneratedMemberDeclarationsStorage(this))
    register(FirSamConstructorStorage::class, FirSamConstructorStorage(this))
    register(FirOverrideService::class, FirOverrideService(this))
    register(FirDynamicMembersStorage::class, FirDynamicMembersStorage(this))
    // Common implementation; registerJavaComponents registers FirJvmEnumEntriesSupport for this key.
    register(FirEnumEntriesSupport::class, FirEnumEntriesSupport(this))
    register(FirOverrideChecker::class, FirStandardOverrideChecker(this))
    register(FirDeclarationOverloadabilityHelper::class, FirDeclarationOverloadabilityHelperImpl(this))
    // Platform-neutral defaults; the JVM setup registers its own implementations for these keys.
    register(FirAnnotationsPlatformSpecificSupportComponent::class, FirAnnotationsPlatformSpecificSupportComponent.Default)
    register(FirPrimaryConstructorSuperTypeCheckerPlatformComponent::class, FirPrimaryConstructorSuperTypeCheckerPlatformComponent.Default)
    register(FirGenericArrayClassLiteralSupport::class, FirGenericArrayClassLiteralSupport.Disabled)
    register(FirMissingDependencyStorage::class, FirMissingDependencyStorage(this))
    // The same default is also registered by registerDefaultComponents.
    register(FirPlatformSpecificCastChecker::class, FirPlatformSpecificCastChecker.Default)
}

/**
 * Registers common components that, as the function name states, are meant to be created only
 * after the session's extensions have been configured.
 *
 * NOTE(review): presumably [FirFunctionTypeKindServiceImpl] and
 * [FirProvidedDeclarationsForMetadataService.create] consult the registered extensions when they
 * are constructed; confirm against their implementations before moving these calls.
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerCommonComponentsAfterExtensionsAreConfigured() {
    register(FirFunctionTypeKindService::class, FirFunctionTypeKindServiceImpl(this))
    register(FirProvidedDeclarationsForMetadataService::class, FirProvidedDeclarationsForMetadataService.create(this))
}

/**
 * The caches factory used by sessions created for the CLI compiler
 * (see [registerCliCompilerOnlyComponents]).
 *
 * Always yields the thread-unsafe [FirThreadUnsafeCachesFactory].
 */
val firCachesFactoryForCliMode: FirCachesFactory
    get() {
        return FirThreadUnsafeCachesFactory
    }

/**
 * Registers components that are used only when this session belongs to the CLI compiler.
 *
 * This includes [firCachesFactoryForCliMode] as the caches factory, the dummy
 * [FirDummyCompilerLazyDeclarationResolver] as the lazy declaration resolver, and
 * [FirCliExceptionHandler] as the exception handler.
 *
 * @param languageVersionSettings used to create the default
 *   [FirLookupDefaultStarImportsInSourcesSettingHolder].
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerCliCompilerOnlyComponents(languageVersionSettings: LanguageVersionSettings) {
    register(FirCachesFactory::class, firCachesFactoryForCliMode)
    register(SealedClassInheritorsProvider::class, SealedClassInheritorsProviderImpl)
    register(FirLazyDeclarationResolver::class, FirDummyCompilerLazyDeclarationResolver)
    register(FirExceptionHandler::class, FirCliExceptionHandler)
    register(FirModulePrivateVisibilityChecker::class, FirModulePrivateVisibilityChecker.Standard(this))
    register(
        FirLookupDefaultStarImportsInSourcesSettingHolder::class,
        FirLookupDefaultStarImportsInSourcesSettingHolder.createDefault(languageVersionSettings)
    )

    // Compiler-plugin support: registered plugin annotations and predicate-based lookup.
    register(FirRegisteredPluginAnnotations::class, FirRegisteredPluginAnnotationsImpl(this))
    register(FirPredicateBasedProvider::class, FirPredicateBasedProviderImpl(this))
}

/**
 * Java-related symbol storages that can be passed to [registerJavaComponents], which then uses
 * them instead of creating fresh per-session instances.
 *
 * @property enhancementStorage used by [registerJavaComponents] as the session's [FirEnhancedSymbolsStorage].
 * @property mappedStorage used by [registerJavaComponents] as the session's [FirMappedSymbolStorage].
 */
class FirSharableJavaComponents(
    val enhancementStorage: FirEnhancedSymbolsStorage,
    val mappedStorage: FirMappedSymbolStorage,
) {
    /**
     * Builds both storages from the same [cachesFactory]: the enhancement storage first,
     * then the mapped storage.
     */
    constructor(cachesFactory: FirCachesFactory) : this(
        enhancementStorage = FirEnhancedSymbolsStorage(cachesFactory),
        mappedStorage = FirMappedSymbolStorage(cachesFactory),
    )
}

/**
 * Registers the Java/JVM-specific components of this session.
 *
 * Several keys registered here also receive common or default implementations from
 * [registerCommonComponents] or [registerDefaultComponents] (e.g. [FirEnumEntriesSupport],
 * [FirVisibilityChecker], [ConeCallConflictResolverFactory], [FirDefaultImportProviderHolder]).
 * This function registers the JVM implementation for the same key.
 *
 * @param javaModuleResolver wrapped in a [FirJavaModuleResolverProvider] and also passed to
 *   [FirAnnotationTypeQualifierResolver].
 * @param predefinedComponents optional storages to reuse. When `null`, a new
 *   [FirEnhancedSymbolsStorage] and a new [FirMappedSymbolStorage] are created for this session.
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerJavaComponents(
    javaModuleResolver: JavaModuleResolver,
    predefinedComponents: FirSharableJavaComponents? = null,
) {
    register(FirJavaModuleResolverProvider::class, FirJavaModuleResolverProvider(javaModuleResolver))
    // NOTE(review): reads `languageVersionSettings` from this session, so the language settings
    // component presumably must already be registered (e.g. by registerCommonComponents). Confirm.
    val jsr305State = languageVersionSettings.getFlag(JvmAnalysisFlags.javaTypeEnhancementState)
    register(
        FirAnnotationTypeQualifierResolver::class,
        FirAnnotationTypeQualifierResolver(this, jsr305State, javaModuleResolver)
    )
    // Reuse externally provided storages when given (presumably so they can be shared between
    // sessions); otherwise create session-local ones.
    register(FirEnhancedSymbolsStorage::class, predefinedComponents?.enhancementStorage ?: FirEnhancedSymbolsStorage(this))
    register(FirMappedSymbolStorage::class, predefinedComponents?.mappedStorage ?: FirMappedSymbolStorage(this))
    register(FirSyntheticPropertiesStorage::class, FirSyntheticPropertiesStorage(this))
    register(
        FirJvmDefaultModeComponent::class,
        FirJvmDefaultModeComponent(languageVersionSettings.getFlag(JvmAnalysisFlags.jvmDefaultMode))
    )
    register(PlatformSupertypeUpdater::class, JvmSupertypeUpdater(this))
    register(PlatformSpecificOverridabilityRules::class, JavaOverridabilityRules(this))
    register(FirDeserializationExtension::class, FirJvmDeserializationExtension(this))
    // JVM implementations for keys that registerCommonComponents also registers.
    register(FirEnumEntriesSupport::class, FirJvmEnumEntriesSupport(this))
    register(CompilerRequiredAnnotationEnhancementProvider::class, JavaCompilerRequiredAnnotationEnhancementProvider)
    register(FirAnnotationsPlatformSpecificSupportComponent::class, FirJvmAnnotationsPlatformSpecificSupportComponent)
    register(FirPrimaryConstructorSuperTypeCheckerPlatformComponent::class, FirJvmPrimaryConstructorSuperTypeCheckerPlatformComponent)

    // JVM implementations for keys that also receive defaults in registerDefaultComponents.
    register(FirVisibilityChecker::class, FirJavaVisibilityChecker)
    register(ConeCallConflictResolverFactory::class, JvmCallConflictResolverFactory)
    register(
        FirTypeSpecificityComparatorProvider::class,
        FirTypeSpecificityComparatorProvider(JvmTypeSpecificityComparator(typeContext))
    )
    register(FirPlatformClassMapper::class, FirJavaClassMapper(this))
    register(FirSyntheticNamesProvider::class, FirJavaSyntheticNamesProvider)
    register(FirOverridesBackwardCompatibilityHelper::class, FirJvmOverridesBackwardCompatibilityHelper)
    register(FirInlineCheckerPlatformSpecificComponent::class, FirJvmInlineCheckerComponent())
    register(FirGenericArrayClassLiteralSupport::class, FirGenericArrayClassLiteralSupport.Enabled)
    register(FirDelegatedMembersFilter::class, FirJvmDelegatedMembersFilter(this))
    register(FirPlatformUpperBoundsProvider::class, FirJavaNullabilityWarningUpperBoundsProvider(this))
    register(FirDefaultImportProviderHolder::class, FirDefaultImportProviderHolder(FirJvmDefaultImportProvider))
}

/**
 * Registers platform-neutral default implementations of platform-sensitive [FirSession] components.
 *
 * A function that registers platform-specific components for the same keys can override these
 * defaults afterwards, for example [registerJavaComponents] on the JVM.
 *
 * Note: [FirPlatformSpecificCastChecker.Default] is also registered by [registerCommonComponents].
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerDefaultComponents() {
    register(FirVisibilityChecker::class, FirVisibilityChecker.Default)
    register(ConeCallConflictResolverFactory::class, DefaultCallConflictResolverFactory)
    register(FirPlatformClassMapper::class, FirPlatformClassMapper.Default)
    register(FirOverridesBackwardCompatibilityHelper::class, FirDefaultOverridesBackwardCompatibilityHelper)
    register(FirDelegatedMembersFilter::class, FirDelegatedMembersFilter.Default)
    register(FirPlatformSpecificCastChecker::class, FirPlatformSpecificCastChecker.Default)
    // Default imports of the common (platform-independent) analyzer services.
    register(FirDefaultImportProviderHolder::class, FirDefaultImportProviderHolder(CommonPlatformAnalyzerServices))
    register(FirIdentityLessPlatformDeterminer::class, FirIdentityLessPlatformDeterminer.Default)
}

// -------------------------- Resolve components --------------------------

/**
 * Registers resolve-stage components that are the same on all platforms.
 *
 * The incremental-compilation trackers are optional. Each one is wrapped in a pass-through
 * component and registered only when it is supplied. A `null` tracker leaves the corresponding
 * component unregistered by this function.
 *
 * @param lookupTracker when non-null, registered as a [FirLookupTrackerComponent]; source elements
 *   are mapped to file paths through this session's [SourcesToPathsMapper].
 * @param enumWhenTracker when non-null, registered as a [FirEnumWhenTrackerComponent].
 * @param importTracker when non-null, registered as a [FirImportTrackerComponent].
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerResolveComponents(
    lookupTracker: LookupTracker? = null,
    enumWhenTracker: EnumWhenTracker? = null,
    importTracker: ImportTracker? = null,
) {
    register(FirQualifierResolver::class, FirQualifierResolverImpl(this))
    register(FirTypeResolver::class, FirTypeResolverImpl(this))
    register(CheckersComponent::class, CheckersComponent())
    register(FirNameConflictsTrackerComponent::class, FirNameConflictsTracker())
    register(FirModuleVisibilityChecker::class, FirModuleVisibilityChecker.Standard(this))
    register(SourcesToPathsMapper::class, SourcesToPathsMapper())

    lookupTracker?.let { tracker ->
        // The mapper is consulted lazily, each time the lambda is invoked.
        val sourceToPath: (KtSourceElement) -> String? = { element ->
            sourcesToPathsMapper.getSourceFilePath(element)
        }
        register(
            FirLookupTrackerComponent::class,
            IncrementalPassThroughLookupTrackerComponent(tracker, sourceToPath)
        )
    }

    enumWhenTracker?.let { tracker ->
        register(
            FirEnumWhenTrackerComponent::class,
            IncrementalPassThroughEnumWhenTrackerComponent(tracker)
        )
    }

    importTracker?.let { tracker ->
        register(
            FirImportTrackerComponent::class,
            IncrementalPassThroughImportTrackerComponent(tracker)
        )
    }

    register(FirExpectActualMatchingContextFactory::class, FirExpectActualMatchingContextImpl.Factory)
}

/**
 * Registers [moduleData] as this session's [FirModuleData] component.
 */
@OptIn(SessionConfiguration::class)
fun FirSession.registerModuleData(moduleData: FirModuleData) =
    register(FirModuleData::class, moduleData)
