package org.mule.weave.v2.module.java

import org.mule.weave.v2.core.RuntimeConfigProperties
import org.mule.weave.v2.grammar.AsOpId
import org.mule.weave.v2.grammar.literals.TypeLiteral
import org.mule.weave.v2.module.java.JavaClassHelper.toWeaveType
import org.mule.weave.v2.module.pojo.DefaultClassLoaderService
import org.mule.weave.v2.parser.SafeStringBasedParserInput
import org.mule.weave.v2.parser.ast.LocationInjectorHelper
import org.mule.weave.v2.parser.ast.annotation.AnnotationArgumentNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationArgumentsNode
import org.mule.weave.v2.parser.ast.annotation.AnnotationNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallNode
import org.mule.weave.v2.parser.ast.functions.FunctionCallParametersNode
import org.mule.weave.v2.parser.ast.functions.FunctionNode
import org.mule.weave.v2.parser.ast.functions.FunctionParameter
import org.mule.weave.v2.parser.ast.functions.FunctionParameters
import org.mule.weave.v2.parser.ast.header.directives.FunctionDirectiveNode
import org.mule.weave.v2.parser.ast.header.directives.TypeDirective
import org.mule.weave.v2.parser.ast.header.directives.VarDirective
import org.mule.weave.v2.parser.ast.header.directives.VersionDirective
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode
import org.mule.weave.v2.parser.ast.structure.ArrayNode
import org.mule.weave.v2.parser.ast.structure.QuotedStringNode
import org.mule.weave.v2.parser.ast.structure.schema.SchemaNode
import org.mule.weave.v2.parser.ast.structure.schema.SchemaPropertyNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.types.UnionTypeNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode
import org.mule.weave.v2.parser.phase._
import org.mule.weave.v2.sdk.WeaveResource

import java.lang.annotation.Annotation
import java.lang.reflect.Executable
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.lang.reflect.Parameter
import java.lang.reflect.Type
import java.net.URL
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * [[ModuleLoader]] that exposes a Java class available to `DefaultClassLoaderService` as a DataWeave module.
  *
  * The module name elements are joined with `.` to form the class name (e.g. elements `java`, `util`, `UUID`
  * map to `java.util.UUID`). The generated module contains:
  *  - one function per public static method, delegating to `dw::java::internal::Reflection::invoke`;
  *  - one `new` function per public constructor, delegating to the same `invoke` function;
  *  - one variable per enum constant (for enums) or per public static field (read through
  *    `dw::java::internal::Reflection::field`);
  *  - when `javaModuleLoaderLoadTypes` is enabled, type directives for the class and the Java types it references.
  *
  * Nothing is resolved or loaded when `RuntimeConfigProperties.DISABLE_JAVA_MODULE_LOADER` is set.
  */
class JavaModuleLoader extends ModuleLoader {

  /**
    * Builds the module for `nameIdentifier` if it names a class that can be loaded.
    *
    * @return `None` when the loader is disabled or the class cannot be loaded; otherwise a successful parsing
    *         result whose resource name is the class file URL (or a synthetic `pkg/Name.class` path).
    */
  override def loadModule(nameIdentifier: NameIdentifier, parsingContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    if (RuntimeConfigProperties.DISABLE_JAVA_MODULE_LOADER) {
      None
    } else {
      val className: String = getClassName(nameIdentifier)
      loadClass(className).map((clazz) => {
        val url: String = classFileUrl(clazz, className)
        val input: ParsingContentInput = ParsingContentInput(WeaveResource(url, ""), nameIdentifier, SafeStringBasedParserInput(""))
        val module: ModuleNode =
          if (parsingContext.settings.javaModuleLoaderLoadTypes) {
            buildModule(clazz, nameIdentifier, parsingContext)
          } else {
            buildModuleWithoutTypes(clazz, nameIdentifier, parsingContext)
          }
        val withPosition: ModuleNode = LocationInjectorHelper.injectPosition(module)
        val result: ParsingResult[ModuleNode] = ParsingResult(input, withPosition)
        SuccessResult(result, parsingContext)
      })
    }
  }

  /**
    * Returns the external form of the URL of `clazz`'s class file. When that URL is unavailable, returns a
    * synthetic path derived from `className` (`a.b.C` -> `a/b/C.class`).
    */
  private def classFileUrl(clazz: Class[_], className: String): String = {
    // Class.getResource resolves a relative name against the class's package. The last segment of the binary
    // name keeps the `Outer$Inner` form of nested classes, whereas getSimpleName would drop the `Outer$` part.
    val binaryName = clazz.getName
    val classFileName = binaryName.substring(binaryName.lastIndexOf('.') + 1) + ".class"
    val classFile: Option[URL] = Option(clazz.getResource(classFileName))
    // Literal char replacement: `replaceAll(".", ...)` would treat `.` as a regex matching every character.
    classFile.map(_.toExternalForm).getOrElse(className.replace('.', '/') + ".class")
  }

  /** Joins the module name elements with `.` to obtain the fully qualified Java class name. */
  private def getClassName(nameIdentifier: NameIdentifier) = nameIdentifier.nameElements().mkString(".")

  /** Attempts to load `className` through `DefaultClassLoaderService`. */
  private def loadClass(className: String): Option[Class[_]] = {
    DefaultClassLoaderService.loadClass(className)
  }

  /**
    * Builds the module for `clazz`, including type directives.
    *
    * The class itself is described by a type directive named `Class`. Every entry collected in `references`
    * (which is handed to `toWeaveType` as `typesReferences`) becomes an additional type directive annotated
    * with `@Internal(permits = [])`.
    */
  def buildModule(clazz: Class[_], nameIdentifier: NameIdentifier, parsingContext: ParsingContext): ModuleNode = {
    val typeDirectiveIdSuffix = "Class"
    val references: mutable.Map[String, WeaveTypeNode] = mutable.Map()
    val rootClassName = clazz.getName
    val getWeaveType: Type => WeaveTypeNode = (t: Type) => toWeaveType(t, typeDirectiveIdSuffix = Some(typeDirectiveIdSuffix), rootClassName = Some(rootClassName), typesReferences = references)

    // Every toWeaveType call receives `references`, so all of them (members and root type) must run before
    // the map is turned into type directives below.
    val functions = buildFunctions(clazz, getWeaveType, parsingContext)
    val constructors = buildConstructors(clazz, getWeaveType)
    val varDirectives: Seq[VarDirective] = buildVariables(clazz, getWeaveType)
    val rootTypeDirective = TypeDirective(NameIdentifier(typeDirectiveIdSuffix), None, toWeaveType(clazz, expanded = true, typeDirectiveIdSuffix = Some(typeDirectiveIdSuffix), rootClassName = Some(rootClassName), typesReferences = references))

    val typeDirectives: ArrayBuffer[TypeDirective] = ArrayBuffer(rootTypeDirective)
    typeDirectives ++= references.map {
      case (referenceName, referenceType) => TypeDirective(NameIdentifier(referenceName), None, referenceType, Seq(internalAnnotation))
    }

    ModuleNode(nameIdentifier, Seq(new VersionDirective()) ++ typeDirectives ++ varDirectives ++ functions ++ constructors)
  }

  /** A fresh `@Internal(permits = [])` annotation node, attached to the directives of referenced types. */
  private def internalAnnotation: AnnotationNode =
    AnnotationNode(NameIdentifier("Internal"), Some(AnnotationArgumentsNode(Seq(AnnotationArgumentNode(NameIdentifier("permits"), ArrayNode(Seq()))))))

  /** Builds the module for `clazz` without type directives; member types come from plain `toWeaveType`. */
  def buildModuleWithoutTypes(clazz: Class[_], nameIdentifier: NameIdentifier, parsingContext: ParsingContext): ModuleNode = {
    val getWeaveType: Type => WeaveTypeNode = (t: Type) => toWeaveType(t)

    val functions = buildFunctions(clazz, getWeaveType, parsingContext)
    val constructors = buildConstructors(clazz, getWeaveType)
    val varDirectives: Seq[VarDirective] = buildVariables(clazz, getWeaveType)

    ModuleNode(nameIdentifier, Seq(new VersionDirective()) ++ varDirectives ++ functions ++ constructors)
  }

  /**
    * One function directive per public static method declared by `clazz`. The directives are sorted with
    * `MethodOrdering` when `javaModuleLoaderDeterministicFunctionsOrdering` is set, and with
    * `MethodOrderingByName` otherwise.
    */
  def buildFunctions(clazz: Class[_], getWeaveType: Type => WeaveTypeNode, parsingContext: ParsingContext): Array[FunctionDirectiveNode] = {
    val ordering = if (parsingContext.settings.javaModuleLoaderDeterministicFunctionsOrdering) MethodOrdering else MethodOrderingByName
    clazz.getDeclaredMethods
      .filter(method => Modifier.isStatic(method.getModifiers) && Modifier.isPublic(method.getModifiers))
      .sorted(ordering)
      .map(method => buildInvokeFunction(clazz, method.getName, method, getWeaveType, mapReturnType(method, getWeaveType)))
  }

  /** One `new` function directive per public constructor declared by `clazz`, returning `clazz`'s type. */
  def buildConstructors(clazz: Class[_], getWeaveType: Type => WeaveTypeNode): Array[FunctionDirectiveNode] = {
    clazz.getDeclaredConstructors
      .filter((constructor) => Modifier.isPublic(constructor.getModifiers))
      .map((constructor) => buildInvokeFunction(clazz, "new", constructor, getWeaveType, getWeaveType(clazz)))
  }

  /**
    * Builds a function directive called `name` whose parameters mirror `executable`'s. Its body is
    * `Reflection::invoke(<class name>, name, [<parameter type names>], [<parameter values>])`.
    *
    * `returnType` is by-name so it is evaluated after the parameter types, in the same order as the
    * original inline construction (relevant because type mapping may record entries in a shared map).
    */
  private def buildInvokeFunction(clazz: Class[_], name: String, executable: Executable, getWeaveType: Type => WeaveTypeNode, returnType: => WeaveTypeNode): FunctionDirectiveNode = {
    val parameters: Seq[Parameter] = executable.getParameters.toSeq
    FunctionDirectiveNode(
      NameIdentifier(name),
      FunctionNode(
        FunctionParameters(parameters.map((param) => {
          FunctionParameter(NameIdentifier(param.getName), None, Some(mapParameterType(param, getWeaveType)))
        })),
        FunctionCallNode(
          VariableReferenceNode(javaInvokeMethodName),
          FunctionCallParametersNode(Seq(
            QuotedStringNode(clazz.getName),
            QuotedStringNode(name),
            ArrayNode(executable.getParameterTypes.toSeq.map((paramType) => QuotedStringNode(paramType.getName))),
            ArrayNode(parameters.map((parameter) => VariableReferenceNode(NameIdentifier(parameter.getName))))))),
        Some(returnType)))
  }

  /**
    * For enums: one variable per constant, bound to `"<name>" as Enum {class: "<class name>"}`.
    * Otherwise: one typed variable per public static field, read through `Reflection::field`.
    */
  def buildVariables(clazz: Class[_], getWeaveType: Type => WeaveTypeNode): Seq[VarDirective] = {
    if (clazz.isEnum) {
      val constants = clazz.getEnumConstants
      constants.map((constant) => {
        val name = constant.asInstanceOf[Enum[_]].name()
        // Built per constant so each directive owns its own AST nodes.
        val classSchema = SchemaPropertyNode(QuotedStringNode("class"), QuotedStringNode(clazz.getName), None)
        val quotedString = BinaryOpNode(AsOpId, QuotedStringNode(name), TypeReferenceNode(NameIdentifier("Enum"), None, Some(SchemaNode(Seq(classSchema)))))
        VarDirective(NameIdentifier(name), quotedString)
      })
    } else {
      clazz.getDeclaredFields
        .filter((f) => Modifier.isStatic(f.getModifiers) && Modifier.isPublic(f.getModifiers))
        .map((f) => {
          val value = FunctionCallNode(VariableReferenceNode(javaFieldMethodName), FunctionCallParametersNode(Seq(QuotedStringNode(clazz.getName), QuotedStringNode(f.getName))))
          VarDirective(NameIdentifier(f.getName), value, Some(getWeaveType(f.getGenericType)))
        })
    }
  }

  /**
    * Maps a parameter's generic type. Primitive parameters, and parameters annotated with any annotation
    * whose name ends in `.NonNull`, `.Nonnull` or `.NotNull`, keep the mapped type as is. Every other
    * parameter is widened to `<type> | Null`.
    */
  def mapParameterType(param: Parameter, getWeaveType: Type => WeaveTypeNode): WeaveTypeNode = {
    val paramType = param.getType
    val declaredAnnotations: Array[Annotation] = param.getDeclaredAnnotations
    val nonNullParam: Boolean = paramType.isPrimitive || declaredAnnotations.exists((annotation) => {
      val name: String = annotation.annotationType().getName
      name.endsWith(".NonNull") || name.endsWith(".Nonnull") || name.endsWith(".NotNull")
    })
    if (nonNullParam) {
      getWeaveType(param.getParameterizedType)
    } else {
      UnionTypeNode(Seq(getWeaveType(param.getParameterizedType), TypeReferenceNode(NameIdentifier(TypeLiteral.NULL_TYPE_NAME))))
    }
  }

  /**
    * Maps a method's generic return type. When the method carries an annotation whose name ends in
    * `.Nullable`, `Null` is added to the type: appended to an existing union, or unioned with the type.
    */
  def mapReturnType(method: Method, getWeaveType: Type => WeaveTypeNode): WeaveTypeNode = {
    val returnType = getWeaveType(method.getGenericReturnType)
    val nullable: Boolean = method.getDeclaredAnnotations.exists(_.annotationType().getName.endsWith(".Nullable"))
    if (nullable) {
      val nullType = TypeReferenceNode(NameIdentifier(TypeLiteral.NULL_TYPE_NAME))
      returnType match {
        case UnionTypeNode(elems, _, _, _) => UnionTypeNode(elems :+ nullType)
        case _                             => UnionTypeNode(Seq(returnType, nullType))
      }
    } else {
      returnType
    }
  }

  /** Reference to the DataWeave function used to invoke static methods and constructors (fresh node per call). */
  def javaInvokeMethodName: NameIdentifier = {
    NameIdentifier("dw::java::internal::Reflection::invoke")
  }

  /** Reference to the DataWeave function used to read static fields (fresh node per call). */
  def javaFieldMethodName: NameIdentifier = {
    NameIdentifier("dw::java::internal::Reflection::field")
  }

  override def name(): Option[String] = Some("java")

  /** A module can be resolved when the loader is enabled and the corresponding class can be loaded. */
  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean = {
    if (RuntimeConfigProperties.DISABLE_JAVA_MODULE_LOADER) {
      false
    } else {
      loadClass(getClassName(nameIdentifier)).isDefined
    }
  }
}
