package dogacel.kotlinx.protobuf.gen

import com.google.protobuf.Descriptors
import com.google.protobuf.compiler.PluginProtos
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.CodeBlock
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeName
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asTypeName
import dogacel.kotlinx.protobuf.gen.DefaultValues.defaultValueOf
import dogacel.kotlinx.protobuf.gen.Utils.toFirstLower
import dogacel.kotlinx.protobuf.gen.Utils.toLowerCamelCaseIf
import kotlinx.serialization.Serializable
import java.nio.file.Path
import kotlin.io.path.Path

/**
 * A link pairs a protobuf type with the Kotlin [TypeName] that represents it in generated code.
 *
 * Links are collected up front, before any code is generated. This way the generation of one type never
 * depends on the generation of another, and types can be generated in any order.
 *
 * - The protobuf side is a [Descriptors.GenericDescriptor], the common interface of all descriptors.
 * - The Kotlin side is a [TypeName] from the `kotlinpoet` library.
 */
private typealias Link = Pair<Descriptors.GenericDescriptor, TypeName>

/**
 * Lookup table from a protobuf type descriptor to the Kotlin [TypeName] generated for it, i.e. a collection
 * of [Link]s keyed by their descriptor.
 */
private typealias TypeLinks = Map<Descriptors.GenericDescriptor, TypeName>

/**
 * A class that generates the Kotlin code for the given protobuf files.
 */
/**
 * Generates Kotlin source code, using `kotlinpoet`, for protobuf files.
 *
 * - Messages are generated as `@Serializable` classes (data classes when they have at least one field).
 * - Enums are generated as `@Serializable` enum classes.
 * - When [CodeGeneratorOptions.generateServices] is set, each service is generated as an abstract class with
 *   one `open suspend` function per RPC.
 *
 * Well-known types are not generated; see [shouldGenerateClass].
 */
class CodeGenerator {
    /** Kotlin type for every message and enum declared in [filesInOrder] or in anything those files import. */
    private val typeLinks: TypeLinks

    /** Files to generate code for, in the order their output is produced. */
    private val filesInOrder: List<Descriptors.FileDescriptor>

    private val options: CodeGeneratorOptions

    /**
     * Create a code generator for the given [PluginProtos.CodeGeneratorRequest]. This request is created by
     * the protoc compiler.
     *
     * @throws IllegalStateException if a file imports a file that does not appear before it in the request.
     */
    constructor(
        request: PluginProtos.CodeGeneratorRequest,
        options: CodeGeneratorOptions = CodeGeneratorOptions(),
    ) {
        this.options = options
        // https://protobuf.dev/reference/java/api-docs/com/google/protobuf/compiler/PluginProtos.CodeGeneratorRequest
        // FileDescriptorProtos for all files in files_to_generate and everything
        // they import.  The files will appear in topological order, so each file
        // appears before any file that imports it.
        //
        // NOTE(review): every file in `proto_file` is generated, including imported ones. The plugin protocol
        // says output should only be produced for `file_to_generate`; confirm that this is intended.
        val builtFiles = mutableMapOf<String, Descriptors.FileDescriptor>()
        filesInOrder =
            request.protoFileList.map { file ->
                builtFiles.getOrPut(file.name) {
                    // Relies on the topological order above: every import has already been built.
                    val deps =
                        file.dependencyList.map { dep ->
                            builtFiles[dep] ?: error("Dependency $dep not found for file ${file.name}")
                        }
                    Descriptors.FileDescriptor.buildFrom(file, deps.toTypedArray())
                }
            }
        typeLinks = collectTypeLinks(filesInOrder)
    }

    /**
     * Create a code generator for the given [Descriptors.FileDescriptor]s. This is useful if you want to
     * generate code for a subset of the files in a proto package.
     *
     * - When [CodeGeneratorOptions.autoGenerateDependencies] is set, code is also generated for every file
     *   the given descriptors (transitively) import. Imported files are ordered before the files importing
     *   them.
     * - Otherwise, only the given descriptors are generated, in the order they are passed.
     */
    constructor(
        vararg fileDescriptors: Descriptors.FileDescriptor,
        options: CodeGeneratorOptions = CodeGeneratorOptions(),
    ) {
        this.options = options
        filesInOrder =
            if (options.autoGenerateDependencies) {
                // Extract dependencies and topologically sort the descriptors.
                sortByDependencyDepth(fileDescriptors.toList())
            } else {
                // Assume dependencies and files come in order.
                fileDescriptors.toList()
            }
        typeLinks = collectTypeLinks(filesInOrder)
    }

    /**
     * Generate the source files to the given [path].
     */
    fun generateFiles(path: Path = Path("./generated")) {
        generateFileSpecs().forEach { file ->
            file.writeTo(path)
        }
    }

    /**
     * Generate the file specs that contain the generated code without persisting them.
     */
    fun generateFileSpecs(): List<FileSpec> {
        return filesInOrder.map { fileDescriptor ->
            generateSingleFile(fileDescriptor).build()
        }
    }

    /**
     * Whether to generate a class for the given descriptor or not. Currently, we do not create classes for
     * well-known types. They are imported from "io.github.dogacel:kotlinx-protobuf-runtime-common" library.
     */
    fun shouldGenerateClass(descriptor: Descriptors.Descriptor): Boolean {
        return options.wellKnownTypes.getFor(descriptor) == null
    }

    /**
     * Topologically sort [roots] together with all the files they transitively import.
     *
     * Each file is assigned the length of the longest import chain that reaches it from [roots]. Files are
     * returned by decreasing depth, so every file comes after everything it imports. Files with equal depth
     * keep the order in which they were first discovered.
     */
    private fun sortByDependencyDepth(roots: List<Descriptors.FileDescriptor>): List<Descriptors.FileDescriptor> {
        val depths = mutableMapOf<Descriptors.FileDescriptor, Int>()

        fun assignDepth(
            files: List<Descriptors.FileDescriptor>,
            depth: Int,
        ) {
            for (file in files) {
                val known = depths[file]
                // A file already recorded at this depth or deeper had its imports visited at greater depths,
                // so walking it again cannot change any depth. Skipping it avoids re-walking diamond-shaped
                // import graphs once per path.
                if (known != null && known >= depth) continue
                depths[file] = depth
                assignDepth(file.dependencies, depth + 1)
                assignDepth(file.publicDependencies, depth + 1)
            }
        }

        assignDepth(roots, 0)
        return depths.entries.sortedByDescending { it.value }.map { it.key }
    }

    /**
     * Compute the Kotlin package for [fileDescriptor]: its proto package, prefixed with [packagePrefix]
     * when one is set.
     *
     * Empty components are dropped. This keeps a file without a `package` statement from ending up in a
     * package with a trailing dot (e.g. `com.example.`), which is not a valid Kotlin package name.
     */
    private fun kotlinPackageOf(
        fileDescriptor: Descriptors.FileDescriptor,
        packagePrefix: String,
    ): String = listOf(packagePrefix, fileDescriptor.`package`).filter { it.isNotEmpty() }.joinToString(".")

    /**
     * Generate the code for the given [Descriptors.FileDescriptor]. Returns a [FileSpec.Builder] so users
     * can add additional code to the file.
     *
     * A file contains:
     * - classes for top-level messages that are not well-known types;
     * - all top-level enums;
     * - services, when [CodeGeneratorOptions.generateServices] is set.
     *
     * @param fileDescriptor [Descriptors.FileDescriptor] to generate code for.
     * @return [FileSpec.Builder] that contains the generated code.
     */
    private fun generateSingleFile(fileDescriptor: Descriptors.FileDescriptor): FileSpec.Builder {
        val packageName = kotlinPackageOf(fileDescriptor, options.packagePrefix)

        // Only the last path segment of the proto file name is used, e.g. `foo/bar.proto` becomes `bar.proto`.
        val fileName = fileDescriptor.name.substringAfterLast('/')
        val fileSpec = FileSpec.builder(packageName, fileName)

        fileDescriptor.messageTypes.forEach { messageType ->
            if (shouldGenerateClass(messageType)) {
                val typeSpec = generateSingleClass(messageType)
                fileSpec.addType(typeSpec.build())
            }
        }

        fileDescriptor.enumTypes.forEach { enumType ->
            val typeSpec = generateSingleEnum(enumType)
            fileSpec.addType(typeSpec.build())
        }

        if (options.generateServices) {
            fileDescriptor.services.forEach { service ->
                val typeSpec = generateSingleService(service, options.generateGrpcServices)
                fileSpec.addType(typeSpec.build())
            }
        }

        return fileSpec
    }

    /**
     * Check if the given message descriptor contains any fields or not.
     *
     * This is used to avoid generating data classes with 0 properties because they are not allowed in Kotlin.
     *
     * @param messageDescriptor the descriptor to check if it has any fields.
     * @return whether the given message descriptor contains any fields or not.
     */
    fun hasNoField(messageDescriptor: Descriptors.Descriptor): Boolean {
        return messageDescriptor.fields.isEmpty()
    }

    /**
     * Get the overridden methods for a message with no fields.
     *
     * Such a message cannot be a data class, so it gets hand-written implementations:
     * - `toString` returns the message name;
     * - `hashCode` returns a constant derived from the message name;
     * - `equals` holds for any instance of the same class.
     *
     * @param messageDescriptor descriptor of the message to generate empty override methods for.
     * @return a list of [FunSpec]s that contain the overridden `toString`, `hashCode` and `equals`.
     */
    private fun getEmptyOverrideFunSpecs(messageDescriptor: Descriptors.Descriptor): List<FunSpec> {
        return listOf(
            FunSpec.builder("toString")
                .addModifiers(KModifier.OVERRIDE)
                .addStatement("return %S", messageDescriptor.name)
                .returns(String::class)
                .build(),
            FunSpec.builder("hashCode")
                .addModifiers(KModifier.OVERRIDE)
                .addStatement("return %L", messageDescriptor.name.hashCode())
                .returns(Int::class)
                .build(),
            FunSpec.builder("equals")
                .addModifiers(KModifier.OVERRIDE)
                .addParameter("other", Any::class.asTypeName().copy(nullable = true))
                // %N escapes the class name instead of splicing it raw into the format string.
                .addStatement("return other is %N", messageDescriptor.name)
                .returns(Boolean::class)
                .build(),
        )
    }

    /**
     * Generate a single parameter for the given [Descriptors.FieldDescriptor]. Returns a
     * [ParameterSpec.Builder] so users can add additional code to the parameter.
     *
     * A parameter contains a name, type and default value. Parameters are used in constructors.
     *
     * @param fieldDescriptor [Descriptors.FieldDescriptor] to generate code for.
     * @return [ParameterSpec.Builder] that contains the generated code.
     */
    private fun generateSingleParameter(fieldDescriptor: Descriptors.FieldDescriptor): ParameterSpec.Builder {
        val fieldTypeName = TypeNames.typeNameOf(fieldDescriptor, typeLinks)
        val fieldName = fieldDescriptor.name.toLowerCamelCaseIf(options.useCamelCase)

        val defaultValue = defaultValueOf(fieldDescriptor, typeLinks)

        return ParameterSpec.builder(fieldName, fieldTypeName)
            .addAnnotations(Annotations.annotationsOf(fieldDescriptor))
            .defaultValue("%L", defaultValue)
    }

    /**
     * Generate a single class for the given [Descriptors.Descriptor]. Returns a [TypeSpec.Builder] so users
     * can add additional code to the class.
     *
     * A class contains subclasses, enums, parameters and properties. Map-entry nested types and well-known
     * types are not generated as nested classes.
     *
     * @param messageDescriptor [Descriptors.Descriptor] to generate code for.
     * @return [TypeSpec.Builder] that contains the generated code.
     */
    private fun generateSingleClass(messageDescriptor: Descriptors.Descriptor): TypeSpec.Builder {
        val typeSpec =
            TypeSpec.classBuilder(messageDescriptor.name)
                .addAnnotation(Serializable::class)

        if (hasNoField(messageDescriptor)) {
            typeSpec.addFunctions(getEmptyOverrideFunSpecs(messageDescriptor))
        } else {
            typeSpec.addModifiers(KModifier.DATA)
        }

        // A Data class needs a primary constructor with all the parameters.
        val parameters = messageDescriptor.fields.map { generateSingleParameter(it).build() }
        val constructorSpec =
            FunSpec
                .constructorBuilder()
                .addParameters(parameters)

        // A trick to handle oneof fields. We need to make sure that only one of the fields is set.
        // Validation is done in `init` block so objects in invalid states can't be initialized.
        messageDescriptor.oneofs.forEach { oneOfDescriptor ->
            if (oneOfDescriptor.fields.isNotEmpty()) {
                val codeSpec = CodeBlock.builder()
                codeSpec.addStatement("require(")
                codeSpec.indent()
                codeSpec.addStatement("listOfNotNull(")
                codeSpec.indent()
                oneOfDescriptor.fields.forEach {
                    // %N escapes field names that are Kotlin keywords (e.g. `object`), exactly as
                    // kotlinpoet does for the constructor parameters they refer to.
                    codeSpec.addStatement("%N,", it.name.toLowerCamelCaseIf(options.useCamelCase))
                }
                codeSpec.unindent()
                codeSpec.addStatement(").size <= 1")
                codeSpec.unindent()
                codeSpec.addStatement(") { %S }", "Should only contain one of ${oneOfDescriptor.name}.")
                constructorSpec.addCode(codeSpec.build())
            }
        }

        typeSpec.primaryConstructor(constructorSpec.build())

        // A data class should define all parameters in constructors as parameters using `val` keyword.
        messageDescriptor.fields.forEach { fieldDescriptor ->
            val type = TypeNames.typeNameOf(fieldDescriptor, typeLinks)
            val fieldName = fieldDescriptor.name.toLowerCamelCaseIf(options.useCamelCase)

            typeSpec.addProperty(
                PropertySpec.builder(fieldName, type)
                    // Kotlinpoet only folds the property into the primary constructor when the initializer
                    // renders exactly like the parameter name. %N makes keyword names match their escaped form.
                    .initializer("%N", fieldName)
                    .build(),
            )
        }

        // Recursively generate nested classes and enums.
        val nestedTypes =
            messageDescriptor.nestedTypes.filterNot {
                it.options.mapEntry
            }
                .filter { shouldGenerateClass(it) }
                .map {
                    generateSingleClass(it).build()
                }
        typeSpec.addTypes(nestedTypes)

        val nestedEnums =
            messageDescriptor.enumTypes.map {
                generateSingleEnum(it).build()
            }
        typeSpec.addTypes(nestedEnums)

        return typeSpec
    }

    /**
     * Generate a single enum for the given [Descriptors.EnumDescriptor]. Returns a [TypeSpec.Builder] so users
     * can add additional code to the enum.
     *
     * An enum contains enum constants.
     *
     * @param enumDescriptor [Descriptors.EnumDescriptor] to generate code for.
     * @return [TypeSpec.Builder] that contains the generated code.
     */
    private fun generateSingleEnum(enumDescriptor: Descriptors.EnumDescriptor): TypeSpec.Builder {
        val typeSpec =
            TypeSpec
                .enumBuilder(enumDescriptor.name)
                .addAnnotation(Serializable::class)

        enumDescriptor.values.forEach { valueDescriptor ->
            typeSpec.addEnumConstant(
                valueDescriptor.name,
                TypeSpec.anonymousClassBuilder()
                    .addAnnotations(Annotations.annotationsOf(valueDescriptor))
                    .build(),
            )
        }

        return typeSpec
    }

    /**
     * Generate a single service for the given [Descriptors.ServiceDescriptor]. Returns a [TypeSpec.Builder] so
     * users can add additional code to the service.
     *
     * A service is an abstract class containing one method per RPC.
     *
     * @param serviceDescriptor [Descriptors.ServiceDescriptor] to generate code for.
     * @param isGrpcCompatible whether to generate service as a bindable gRPC or not.
     *   NOTE(review): currently unused; the generated class is the same either way.
     * @return [TypeSpec.Builder] that contains the generated code.
     */
    private fun generateSingleService(
        serviceDescriptor: Descriptors.ServiceDescriptor,
        isGrpcCompatible: Boolean,
    ): TypeSpec.Builder {
        val typeSpec =
            TypeSpec
                .classBuilder(serviceDescriptor.name)
                .addModifiers(KModifier.ABSTRACT)

        serviceDescriptor.methods.forEach {
            val methodFunSpec = generateSingleMethod(it)
            typeSpec.addFunction(methodFunSpec.build())
        }

        return typeSpec
    }

    /**
     * Generate a single method for the given [Descriptors.MethodDescriptor]. Returns a [FunSpec.Builder] so
     * users can add additional code to the method.
     *
     * The generated function:
     * - is `open suspend`;
     * - is named after the RPC with its first letter lowered;
     * - takes one parameter named after the request message;
     * - has the body `return TODO()`.
     *
     * The request and response types come from `TypeNames.typeNameOf`. NOTE(review): any handling of
     * streaming RPCs (e.g. wrapping a streamed request or response in a `Flow`) would have to happen there;
     * this function does not inspect the streaming flags itself.
     *
     * @param methodDescriptor [Descriptors.MethodDescriptor] to generate code for.
     * @return [FunSpec.Builder] that contains the generated code.
     */
    private fun generateSingleMethod(methodDescriptor: Descriptors.MethodDescriptor): FunSpec.Builder {
        val (requestType, responseType) = TypeNames.typeNameOf(methodDescriptor, typeLinks)

        val requestParamSpec =
            ParameterSpec
                .builder(methodDescriptor.inputType.name.toLowerCamelCaseIf().toFirstLower(), requestType)

        return FunSpec
            .builder(methodDescriptor.name.toFirstLower())
            .addModifiers(KModifier.OPEN)
            .addModifiers(KModifier.SUSPEND)
            .addParameter(requestParamSpec.build())
            .returns(responseType)
            .addCode("return TODO()")
    }

    /**
     * Get a [Link] to the given [Descriptors.EnumDescriptor].
     *
     * @param simpleNames if class is nested, all parent class names.
     */
    private fun getEnumLink(
        enumDescriptor: Descriptors.EnumDescriptor,
        packageName: String,
        simpleNames: List<String>,
    ): Link {
        return Link(
            enumDescriptor,
            ClassName(packageName, simpleNames + enumDescriptor.name),
        )
    }

    /**
     * Get all [Link]s for the type defined in the given [Descriptors.Descriptor] and its nested types.
     * Those nested types can be other messages (including map-entry types) or enums.
     *
     * A well-known type links to its runtime-library type, and its nested types get no links.
     *
     * @param simpleNames if class is nested, all parent class names.
     */
    private fun getAllLinks(
        descriptor: Descriptors.Descriptor,
        packageName: String,
        simpleNames: List<String>,
    ): List<Link> {
        val wellKnownType = options.wellKnownTypes.getFor(descriptor)

        if (wellKnownType != null) {
            return listOf(Link(descriptor, wellKnownType))
        }

        val messages =
            descriptor.nestedTypes.flatMap { nestedType ->
                getAllLinks(nestedType, packageName, simpleNames + descriptor.name)
            }

        val enums =
            descriptor.enumTypes.map {
                getEnumLink(it, packageName, simpleNames + descriptor.name)
            }

        val self =
            Link(
                descriptor,
                ClassName(packageName, simpleNames + descriptor.name),
            )
        return (messages + enums + self)
    }

    /**
     * Get the [Link]s for every message (including nested ones) and enum declared directly in
     * [fileDescriptor]. Imported files are not included.
     *
     * @param packagePrefix Prefix to add to the package name of the file descriptor.
     */
    private fun getFileLinks(
        fileDescriptor: Descriptors.FileDescriptor,
        packagePrefix: String,
    ): List<Link> {
        val packageName = kotlinPackageOf(fileDescriptor, packagePrefix)

        val messages =
            fileDescriptor.messageTypes.flatMap {
                getAllLinks(it, packageName, listOf())
            }

        val enums =
            fileDescriptor.enumTypes.map {
                getEnumLink(it, packageName, listOf())
            }

        return messages + enums
    }

    /**
     * Collect the [Link]s of every message and enum declared in [files] or in any file they (transitively)
     * import, including public imports.
     *
     * Each file is visited once, even when it is reachable through several import paths. Walking it again
     * would only produce identical entries.
     */
    private fun collectTypeLinks(files: List<Descriptors.FileDescriptor>): TypeLinks {
        val visited = mutableSetOf<Descriptors.FileDescriptor>()
        val links = mutableMapOf<Descriptors.GenericDescriptor, TypeName>()

        fun visit(file: Descriptors.FileDescriptor) {
            if (!visited.add(file)) return
            file.publicDependencies.forEach { visit(it) }
            file.dependencies.forEach { visit(it) }
            links.putAll(getFileLinks(file, options.packagePrefix))
        }

        files.forEach { visit(it) }
        return links
    }
}
