package org.mule.weave.v2.module.avro

import org.apache.avro.AvroRuntimeException
import org.apache.avro.Conversion
import org.apache.avro.Conversions
import org.apache.avro.LogicalType
import org.apache.avro.Schema
import org.apache.avro.Schema.Field
import org.apache.avro.Schema.Type
import org.apache.avro.file.DataFileReader
import org.apache.avro.file.SeekableInput
import org.apache.avro.generic.GenericData.EnumSymbol
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.generic.GenericRecord
import org.mule.weave.v2.core.io.SeekableStream
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.capabilities.UnknownLocationCapable
import org.mule.weave.v2.model.structure.AlreadyMaterializedObjectSeq
import org.mule.weave.v2.model.structure.ArraySeq
import org.mule.weave.v2.model.structure.EagerObjectSeq
import org.mule.weave.v2.model.structure.IndexedObjectSeq
import org.mule.weave.v2.model.structure.KeyValuePair
import org.mule.weave.v2.model.structure.ObjectSeq
import org.mule.weave.v2.model.structure.QualifiedName
import org.mule.weave.v2.model.structure.SimpleObjectSeq
import org.mule.weave.v2.model.values.ArrayValue
import org.mule.weave.v2.model.values.KeyValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.module.DataFormat
import org.mule.weave.v2.module.avro.exception.AvroReadingException
import org.mule.weave.v2.module.option.ConfigurableSchemaSetting
import org.mule.weave.v2.module.pojo.reader.ReflectionJavaValueConverter
import org.mule.weave.v2.module.reader.Reader
import org.mule.weave.v2.module.reader.SourceProvider
import org.mule.weave.v2.module.reader.SourceProviderAwareReader
import org.mule.weave.v2.parser.location.SimpleLocation

import java.io.ByteArrayInputStream
import scala.collection.JavaConverters._

/**
 * DataWeave [[Reader]] for Avro object container files.
 *
 * The whole input is exposed as an array value whose elements are the file's records, each
 * wrapped as an object (see [[AvroObjectSeq]]). If `settings.schema` holds a schema, it is
 * parsed and handed to the `GenericDatumReader`. Otherwise the datum reader is created without
 * an explicit schema (presumably relying on the schema stored in the file header; confirm
 * against the Avro `DataFileReader` docs).
 *
 * @param sourceProvider provides the raw Avro bytes
 * @param settings       reader settings, including the optional schema
 */
class AvroReader(override val sourceProvider: SourceProvider, val settings: AvroReaderSettings)(implicit ctx: EvaluationContext) extends Reader with SourceProviderAwareReader {

  override def dataFormat: Option[DataFormat[_, _]] = Some(new AvroDataFormat)

  /** Seekable Avro input over the source stream; built on first use. */
  lazy val input = new AvroInput(SeekableStream(sourceProvider.asInputStream))

  /**
   * Opens the Avro container and maps every record to an object value.
   *
   * Each element is labelled `root(<index>)` for location/path reporting.
   *
   * @param name unused by this reader
   * @return an array value backed by the record iterator
   */
  override protected def doRead(name: String): Value[_] = {
    val datumReader: GenericDatumReader[GenericRecord] = settings.schema match {
      case Some(ConfigurableSchemaSetting.SchemaResult(schemaBytes)) =>
        val parsedSchema = new Schema.Parser().parse(new ByteArrayInputStream(schemaBytes))
        new GenericDatumReader[GenericRecord](parsedSchema)
      case _ =>
        new GenericDatumReader[GenericRecord]()
    }

    val records: Iterator[GenericRecord] =
      new DataFileReader[GenericRecord](input, datumReader).iterator().asScala

    ArrayValue(
      ArraySeq(
        records.zipWithIndex.map {
          case (record, position) =>
            new AvroObjectValue(new AvroObjectSeq(record), () => s"root(${position})")
        },
        materializedValues = true),
      UnknownLocationCapable)
  }
}

/**
 * Eager, indexed [[ObjectSeq]] view over a single Avro [[GenericRecord]].
 *
 * Every field of the record's schema is exposed as one key/value pair. Pairs are keyed by the
 * field name and ordered by the schema's field list. Values are built on access: the field's
 * schema is resolved (for unions), an optional logical-type conversion is applied, and the
 * result is mapped to a DataWeave value by [[ReflectionJavaValueConverter]].
 *
 * Index or name lookups that do not correspond to a schema field return `null`.
 *
 * @param record the Avro record backing this sequence
 */
class AvroObjectSeq(record: GenericRecord) extends EagerObjectSeq with IndexedObjectSeq with SimpleObjectSeq with AlreadyMaterializedObjectSeq {

  /**
   * Finds the schema that actually describes `value`.
   *
   * For a union, member schemas are tried in declaration order and the first match wins.
   * For other schemas the rules are:
   *  - `NULL` matches only `null`
   *  - `ARRAY` matches a non-null `java.util.Collection` or JVM array
   *  - `ENUM` matches an [[EnumSymbol]] or `String` listed in the enum's symbols
   *  - other types match when `value` is an instance of the class registered in
   *    `ConversionFactory.typeToClassMap`, or always when no class is registered
   *
   * @param schema the declared schema, possibly a union
   * @param value  the raw value read from the record; may be `null`
   * @return the matching schema, or `None` when nothing matches
   */
  def calculateValueRealSchema(schema: Schema, value: Any): Option[Schema] = {
    schema.getType match {
      case Type.NULL =>
        if (value == null) Some(schema) else None

      case Type.ARRAY =>
        // Reject null before touching getClass; otherwise unions such as ["array", "null"]
        // throw a NullPointerException for null values. Avro's generic reader yields arrays
        // as GenericData.Array (a java.util.List), so collections are accepted as well as
        // JVM arrays.
        value match {
          case null                            => None
          case _: java.util.Collection[_]      => Some(schema)
          case other if other.getClass.isArray => Some(schema)
          case _                               => None
        }

      case Type.UNION =>
        schema.getTypes.asScala.iterator
          .map(member => calculateValueRealSchema(member, value))
          .collectFirst { case Some(matched) => matched }

      case Type.ENUM =>
        value match {
          case symbol: EnumSymbol if schema.getEnumSymbols.contains(symbol.toString) => Some(schema)
          case text: String if schema.getEnumSymbols.contains(text)                  => Some(schema)
          case _                                                                     => None
        }

      case otherType =>
        ConversionFactory.typeToClassMap.get(otherType) match {
          case Some(expectedClass) => if (expectedClass.isInstance(value)) Some(schema) else None
          // No registered class for this type: accept any value.
          case None => Some(schema)
        }
    }
  }

  /** Looks up the schema field named by `key`; `None` when the record has no such field. */
  private def avroFieldNamed(key: Value[QualifiedName])(implicit ctx: EvaluationContext): Option[Field] =
    Option(record.getSchema.getField(key.evaluate.name))

  /** Builds the `fieldName -> value` pair for `field`. */
  private def toKeyValuePair(field: Field)(implicit ctx: EvaluationContext): KeyValuePair =
    KeyValuePair(KeyValue(field.name()), toAvroValue(field))

  /**
   * Converts the record's value for `field` into a DataWeave value.
   *
   * For union fields, the member schema matching the actual value is used. If no member
   * matches, the union schema itself is used. When that schema's logical type has a
   * conversion in `ConversionFactory`, the raw value is converted before being handed to
   * [[ReflectionJavaValueConverter]].
   *
   * @throws AvroReadingException if the logical-type conversion fails
   */
  private def toAvroValue(field: Field)(implicit ctx: EvaluationContext): Value[_] = {
    val value: AnyRef = record.get(field.pos())
    val declaredSchema = field.schema()
    val realSchema =
      if (declaredSchema.isUnion) calculateValueRealSchema(declaredSchema, value).getOrElse(declaredSchema)
      else declaredSchema

    val logicalType: LogicalType = realSchema.getLogicalType
    val maybeConverter: Option[Conversion[_]] = ConversionFactory.getConversion(logicalType)
    val convertedValue = maybeConverter match {
      case Some(converter) =>
        try {
          Conversions.convertToLogicalType(value, realSchema, logicalType, converter)
        } catch {
          // Surface conversion failures as reading errors located at the offending field.
          case ae @ (_: AvroRuntimeException | _: IllegalArgumentException) =>
            throw new AvroReadingException(ae.getMessage, SimpleLocation(field.name()))
        }
      case None => value
    }
    ReflectionJavaValueConverter.convert(convertedValue, () => field.name())
  }

  /** Returns the pair at `index` in schema field order, or `null` when out of range. */
  override def apply(index: Long)(implicit ctx: EvaluationContext): KeyValuePair = {
    if (index >= 0 && index < size()) {
      toKeyValuePair(record.getSchema.getFields.get(index.toInt))
    } else {
      null
    }
  }

  /** Returns the pair for the field named by `key`, or `null` if there is no such field. */
  override def selectKeyValue(key: Value[QualifiedName])(implicit ctx: EvaluationContext): KeyValuePair = {
    avroFieldNamed(key) match {
      case Some(field) => toKeyValuePair(field)
      case None        => null
    }
  }

  /** Returns the value for the field named by `key`, or `null` if there is no such field. */
  override def selectValue(key: Value[QualifiedName])(implicit ctx: EvaluationContext): Value[_] = {
    avroFieldNamed(key) match {
      // Build the value directly; no need to allocate a KeyValuePair just to unwrap it.
      case Some(field) => toAvroValue(field)
      case None        => null
    }
  }

  /** Wraps the single matching pair (if any) in an [[ObjectSeq]]. */
  override def allKeyValuesOf(key: Value[QualifiedName])(implicit ctx: EvaluationContext): Option[ObjectSeq] = {
    keyValueOf(key).map(kvp => ObjectSeq(kvp))
  }

  /** Number of fields in the record's schema. */
  override def size()(implicit ctx: EvaluationContext): Long = {
    record.getSchema.getFields.size()
  }

  /** True when the record's schema declares no fields. */
  override def isEmpty()(implicit ctx: EvaluationContext): Boolean = {
    record.getSchema.getFields.isEmpty
  }

  /** Returns the pair for `key` with its schema position, or `null` if there is no such field. */
  override def keyValueOfWithIndex(key: Value[QualifiedName])(implicit ctx: EvaluationContext): IndexedSelection[KeyValuePair] = {
    avroFieldNamed(key) match {
      case Some(field) => IndexedSelection(field.pos(), toKeyValuePair(field))
      case None        => null
    }
  }

  /** Returns the value for `key` with its schema position, or `null` if there is no such field. */
  override def selectValueWithIndex(key: Value[QualifiedName])(implicit ctx: EvaluationContext): IndexedSelection[Value[_]] = {
    avroFieldNamed(key) match {
      case Some(field) => IndexedSelection(field.pos(), toAvroValue(field))
      case None        => null
    }
  }
}

/**
 * Adapts a DataWeave [[SeekableStream]] to Avro's [[SeekableInput]] so that it can be consumed
 * by a [[DataFileReader]]. Every operation delegates directly to the wrapped stream.
 *
 * @param seekableStream the stream exposed to Avro
 */
class AvroInput(seekableStream: SeekableStream) extends SeekableInput {

  /** Repositions the wrapped stream at `p`. */
  override def seek(p: Long): Unit = {
    seekableStream.seek(p)
  }

  /** Current position of the wrapped stream. */
  override def tell(): Long = {
    seekableStream.position()
  }

  /** Size reported by the wrapped stream. */
  override def length(): Long = {
    seekableStream.size()
  }

  /** Reads up to `len` bytes into `b` starting at `off`; returns the wrapped stream's result. */
  override def read(b: Array[Byte], off: Int, len: Int): Int = {
    seekableStream.read(b, off, len)
  }

  /** Closes the wrapped stream. */
  override def close(): Unit = {
    seekableStream.close()
  }
}
