package org.mule.weave.v2.module.avro.schema

import org.apache.avro.Schema
import org.apache.avro.Schema.Type
import org.apache.avro.SchemaParseException
import org.mule.weave.v2.grammar.Identifiers
import org.mule.weave.v2.parser.MessageCollector
import org.mule.weave.v2.parser.SafeStringBasedParserInput
import org.mule.weave.v2.parser.ast.header.directives.TypeDirective
import org.mule.weave.v2.parser.ast.header.directives.VersionDirective
import org.mule.weave.v2.parser.ast.header.directives.VersionMajor
import org.mule.weave.v2.parser.ast.header.directives.VersionMinor
import org.mule.weave.v2.parser.ast.module.ModuleNode
import org.mule.weave.v2.parser.ast.structure.StringNode
import org.mule.weave.v2.parser.ast.types.KeyTypeNode
import org.mule.weave.v2.parser.ast.types.KeyValueTypeNode
import org.mule.weave.v2.parser.ast.types.LiteralTypeNode
import org.mule.weave.v2.parser.ast.types.NameTypeNode
import org.mule.weave.v2.parser.ast.types.ObjectTypeNode
import org.mule.weave.v2.parser.ast.types.TypeReferenceNode
import org.mule.weave.v2.parser.ast.types.UnionTypeNode
import org.mule.weave.v2.parser.ast.types.WeaveTypeNode
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.parser.location.UnknownLocation
import org.mule.weave.v2.parser.phase.ModuleLoader
import org.mule.weave.v2.parser.phase.ParsingContentInput
import org.mule.weave.v2.parser.phase.ParsingContext
import org.mule.weave.v2.parser.phase.ParsingResult
import org.mule.weave.v2.parser.phase.PhaseResult
import org.mule.weave.v2.sdk.NameIdentifierHelper
import org.mule.weave.v2.sdk.WeaveResource
import org.mule.weave.v2.sdk.WeaveResourceResolver
import org.mule.weave.v2.sdk.WeaveResourceResolverAware

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Module loader that exposes an Avro schema as a DataWeave module made only of type directives.
  *
  * The schema is looked up through the configured [[WeaveResourceResolver]] using the module name
  * identifier with the `.avsc` extension first and `.json` second. Every named Avro type
  * (record, enum, fixed) becomes a type directive whose name is derived from the Avro full name
  * (see `typeName`). When the top level schema is not a named type, an additional `Root` type
  * directive is emitted for it.
  */
class AvroSchemaModuleLoader extends ModuleLoader with WeaveResourceResolverAware {

  private val DEFAULT_ROOT_TYPE_NAME = "Root"
  private val NAMED_AVRO_TYPES = Set(Type.ENUM, Type.FIXED, Type.RECORD)
  // Extensions are tried in this order; the first one that resolves wins.
  private val SCHEMA_FILE_EXTENSIONS = List(".avsc", ".json")
  private var resolver: WeaveResourceResolver = _

  /**
    * Loads the Avro schema that backs the given module name and converts it into a module node.
    *
    * @param nameIdentifier The name of the module being requested.
    * @param moduleContext  The parsing context of the requester (not used by this loader).
    * @return `None` when no schema resource exists for the name; otherwise a phase result that
    *         either carries the generated module or an invalid schema error message.
    */
  override def loadModule(nameIdentifier: NameIdentifier, moduleContext: ParsingContext): Option[PhaseResult[ParsingResult[ModuleNode]]] = {
    resolveResource(nameIdentifier).map(resource => parseModule(nameIdentifier, resource))
  }

  /**
    * Parses the content of the resource as an Avro schema and builds the module node out of it.
    * Schema problems are reported through the message collector instead of being thrown.
    */
  private def parseModule(nameIdentifier: NameIdentifier, resource: WeaveResource): PhaseResult[ParsingResult[ModuleNode]] = {
    def invalidSchema(cause: Exception): PhaseResult[ParsingResult[ModuleNode]] = {
      new PhaseResult(None, MessageCollector().error(InvalidAvroSchemaMessage(cause.getMessage), UnknownLocation))
    }

    val schemaData = resource.content()
    try {
      val schema = new Schema.Parser().parse(schemaData)
      val typeDirectives: Seq[TypeDirective] = generateTypeDirectives(schema)
      val versionDirectives = Seq(VersionDirective(VersionMajor("2"), VersionMinor("0")))
      val moduleNode: ModuleNode = ModuleNode(nameIdentifier, versionDirectives ++ typeDirectives)
      val input: ParsingContentInput = ParsingContentInput(resource, nameIdentifier, SafeStringBasedParserInput(schemaData))
      new PhaseResult(Some(ParsingResult(input, moduleNode)), MessageCollector())
    } catch {
      case spe: SchemaParseException =>
        invalidSchema(spe)
      // Other Avro validation failures (for example an AvroTypeException caused by an invalid
      // field default) are not SchemaParseExceptions; report them too instead of crashing.
      case are: org.apache.avro.AvroRuntimeException =>
        invalidSchema(are)
    }
  }

  /**
    * Resolves the schema resource for the module name, trying each supported extension in order.
    * The lookup is lazy: later extensions are only resolved when the earlier ones are missing.
    */
  private def resolveResource(nameIdentifier: NameIdentifier): Option[WeaveResource] = {
    SCHEMA_FILE_EXTENSIONS.iterator
      .map(ext => resolver.resolvePath(NameIdentifierHelper.toWeaveFilePath(nameIdentifier, NameIdentifierHelper.fileSeparator, ext)))
      .find(_.isDefined)
      .flatten
  }

  /**
    * Generates the type directives of the module: an optional `Root` alias followed by one
    * directive per named Avro type, in the order in which they were first encountered.
    */
  private def generateTypeDirectives(schema: Schema): Seq[TypeDirective] = {
    val ctx = new AvroSchemaTransformationContext()
    val rootType = toWeaveTypeNode(schema, ctx)
    val directives = new ArrayBuffer[TypeDirective]()
    // Decide on the schema kind rather than on the computed name: a named root type that happens
    // to be called `Root` is already registered in the context and must not be declared twice.
    // NOTE: an unnamed root that nests a named type called `Root` still collides with this alias.
    if (!NAMED_AVRO_TYPES.contains(schema.getType)) {
      directives += TypeDirective(NameIdentifier(DEFAULT_ROOT_TYPE_NAME), None, rootType)
    }
    directives ++= ctx.typeDirectives()
    directives
  }

  /** Builds a reference to a type by name, with no type arguments. */
  private def typeRef(name: String): WeaveTypeNode = TypeReferenceNode(NameIdentifier(name), None)

  /** Name of the Avro logical type attached to the schema, if any. */
  private def logicalTypeName(schema: Schema): Option[String] = Option(schema.getLogicalType).map(_.getName)

  /** DataWeave type name for `bytes` and `fixed`: `Number` when tagged as `decimal`, `Binary` otherwise. */
  private def binaryOrDecimalTypeName(schema: Schema): String = logicalTypeName(schema) match {
    case Some("decimal") => "Number"
    case _               => "Binary"
  }

  /**
    * Converts an Avro schema into the DataWeave type node that describes it.
    *
    * Primitive types map to references to the built-in DataWeave types, taking the logical type
    * into account for `bytes`, `int`, `long` and `fixed`. Arrays map to `Array<T>`, maps to
    * `Dictionary<T>` and unions to a union type (or `Nothing` when the union has no branches).
    * Named types (record, enum, fixed) are registered in the context and a reference to the
    * registered name is returned.
    *
    * @param schema The Avro schema to convert.
    * @param ctx    Collects the named types found while traversing the schema.
    */
  private def toWeaveTypeNode(schema: Schema, ctx: AvroSchemaTransformationContext): WeaveTypeNode = schema.getType match {
    case Type.STRING =>
      typeRef("String")
    case Type.BYTES =>
      typeRef(binaryOrDecimalTypeName(schema))
    case Type.INT =>
      val dwType = logicalTypeName(schema) match {
        case Some("date")        => "Date"
        case Some("time-millis") => "LocalTime"
        case _                   => "Number"
      }
      typeRef(dwType)
    case Type.LONG =>
      val dwType = logicalTypeName(schema) match {
        case Some("time-micros")                                             => "LocalTime"
        case Some("timestamp-millis") | Some("timestamp-micros")             => "DateTime"
        case Some("local-timestamp-millis") | Some("local-timestamp-micros") => "LocalDateTime"
        case _                                                               => "Number"
      }
      typeRef(dwType)
    case Type.FLOAT | Type.DOUBLE =>
      typeRef("Number")
    case Type.BOOLEAN =>
      typeRef("Boolean")
    case Type.NULL =>
      typeRef("Null")
    case Type.FIXED =>
      val tName = typeName(schema)
      ctx.addType(tName, typeRef(binaryOrDecimalTypeName(schema)))
      typeRef(tName)
    case Type.ENUM =>
      // An enum is modeled as the union of its symbols as string literal types.
      val types: mutable.Buffer[WeaveTypeNode] = schema.getEnumSymbols.asScala.map(e => LiteralTypeNode(StringNode(e).withQuotation('\"')))
      val tName = typeName(schema)
      ctx.addType(tName, UnionTypeNode(types))
      typeRef(tName)
    case Type.ARRAY =>
      val elementType = toWeaveTypeNode(schema.getElementType, ctx)
      TypeReferenceNode(NameIdentifier("Array"), Some(Seq(elementType)))
    case Type.MAP =>
      val valueType = toWeaveTypeNode(schema.getValueType, ctx)
      TypeReferenceNode(NameIdentifier("Dictionary"), Some(Seq(valueType)))
    case Type.UNION =>
      val types = schema.getTypes.asScala.map(t => toWeaveTypeNode(t, ctx))
      if (types.isEmpty) {
        typeRef("Nothing")
      } else {
        UnionTypeNode(types)
      }
    case Type.RECORD =>
      val tName = typeName(schema)
      if (!ctx.isTypeReference(tName)) {
        // Reserve the name before visiting the fields so that recursive records resolve to a
        // reference instead of recursing forever.
        ctx.addTypeName(tName)
        val fields = schema.getFields.asScala.map(field => {
          val fieldType = toWeaveTypeNode(field.schema(), ctx)
          // A field may be absent when it accepts null or when it declares a default value.
          val isOptional = field.schema().isNullable || field.hasDefaultValue
          KeyValueTypeNode(KeyTypeNode(NameTypeNode(Some(field.name()))), fieldType, repeated = false, optional = isOptional)
        })
        ctx.addType(tName, ObjectTypeNode(fields, None, None, close = true))
      }
      typeRef(tName)
  }

  /**
    * Derives the DataWeave type name from the Avro full name: `_` is doubled, `.` becomes `_`,
    * a leading `__` (that is, a leading `_` in the Avro name) is replaced by `x___`, and names
    * that clash with a DataWeave keyword are prefixed with `a`.
    */
  private def typeName(schema: Schema): String = {
    val dwName = schema.getFullName.replace("_", "__").replace(".", "_").replaceFirst("^__", "x___")
    if (Identifiers.keywords.contains(dwName)) "a" + dwName else dwName
  }

  /** The name under which this loader is registered. */
  override def name(): Option[String] = Some("avroschema")

  /** Receives the resolver used to locate the schema resources. */
  override def resolver(resolver: WeaveResourceResolver): Unit = {
    this.resolver = resolver
  }

  /** A module can be resolved when a schema resource exists for its name. */
  override def canResolveModule(nameIdentifier: NameIdentifier): Boolean = resolveResource(nameIdentifier).isDefined

  /**
    * Accumulates the named types discovered while converting a schema, keeping the order in
    * which each name was first registered.
    */
  private class AvroSchemaTransformationContext {

    // `None` marks a name that has been reserved but whose definition is still being built.
    private val types = mutable.LinkedHashMap[String, Option[WeaveTypeNode]]()

    /** Reserves a name so that it is seen as already known while its definition is built. */
    def addTypeName(name: String): Unit = types += (name -> None)

    /** Registers (or completes) the definition of a named type. */
    def addType(name: String, aType: WeaveTypeNode): Unit = types += (name -> Some(aType))

    /** Whether the name has already been reserved or defined. */
    def isTypeReference(name: String): Boolean = types.contains(name)

    /** One type directive per defined type; names that were only reserved are skipped. */
    def typeDirectives(): Seq[TypeDirective] = {
      types.collect {
        case (name, Some(aType)) => TypeDirective(NameIdentifier(name), None, aType)
      }.toSeq
    }
  }
}
