package amf.xml.internal.transformer.vistors

import amf.core.client.scala.model.domain.extensions.PropertyShape
import amf.core.client.scala.model.domain.{RecursiveShape, ScalarNode, Shape}
import amf.core.internal.annotations.LexicalInformation
import amf.xml.internal.transformer.vistors.TypeToXmlSchemaVisitor.EnableUnionValidation
import amf.shapes.client.scala.model.domain.{AnyShape, ArrayShape, FileShape, NilShape, NodeShape, ScalarShape, UnionShape, XMLSerializer}
import amf.shapes.internal.domain.metamodel.AnyShapeModel.XMLSerialization
import amf.shapes.internal.domain.parser.TypeDefXsdMapping
import amf.shapes.internal.spec.common.TypeDef.{BinaryType, BoolType, ByteType, DateOnlyType, DateTimeOnlyType, DateTimeType, DoubleType, FloatType, IntType, LongType, NumberType, PasswordType, StrType, TimeOnlyType}
import amf.xml.internal.utils.XMLChar
import org.apache.ws.commons.schema.XmlSchemaContentProcessing.SKIP
import org.apache.ws.commons.schema._
import org.apache.ws.commons.schema.constants.Constants._

import java.util
import javax.xml.namespace.QName
import scala.collection.JavaConverters._
import scala.collection.mutable

/**
 * Translates AMF shapes into Apache XmlSchema (`org.apache.ws.commons.schema`) components.
 *
 * The entry point is [[transformShape]], which wraps a shape in an element and recursively visits it.
 * Object types that get a name (links and node shapes inheriting from another node shape) are cached by
 * name in `types`, so recursive structures resolve to the already-created complex type instead of
 * recursing forever.
 *
 * An instance keeps mutable traversal state (an element stack and an unsynchronized `HashMap` type cache),
 * so a single instance should not be used concurrently.
 *
 * @param collection collection used to look up built-in XSD types by QName
 * @param schema     schema that every component created here belongs to
 */
case class TypeToXmlSchemaVisitor(collection: XmlSchemaCollection, schema: XmlSchema)
  extends TypeVisitor[XmlSchemaType] {

  /** Used as `maxOccurs` wherever repetition is unbounded. */
  val Unbounded: Long = Long.MaxValue

  /** Elements currently being built, innermost on top; array cardinality and nillability target the top one. */
  private val currentElement: util.Stack[XmlSchemaElement] = new util.Stack[XmlSchemaElement]

  /** Named complex types created so far, keyed by their sanitized type name. */
  private val types: util.Map[String, XmlSchemaType] = new util.HashMap[String, XmlSchemaType]()

  /**
   * Creates an element called `name` whose type is derived from `shape`.
   *
   * The element is created as top-level only when no other element is being built. While `shape` is
   * visited, the element sits on top of `currentElement`. It is popped in a `finally` block so the stack
   * stays balanced even when visiting throws.
   *
   * @param name  element name, used as-is
   * @param shape shape describing the element content
   * @return the new element, referencing a named type by QName or embedding an anonymous one
   */
  def transformShape(name: String, shape: Shape): XmlSchemaElement = {
    val schemaElement = new XmlSchemaElement(schema, currentElement.isEmpty)
    schemaElement.setName(name)
    currentElement.push(schemaElement)
    try {
      Option(visit(shape)).foreach { xmlSchemaType =>
        // Named (global) types are referenced by QName; anonymous ones are inlined.
        if (xmlSchemaType.getQName != null) schemaElement.setSchemaTypeName(xmlSchemaType.getQName)
        else schemaElement.setSchemaType(xmlSchemaType)
      }
    } finally {
      currentElement.pop()
    }
    schemaElement
  }

  /** Dispatches on the concrete shape type; specific shapes are matched before the `AnyShape` catch-all. */
  private def visit(shape: Shape): XmlSchemaType = {
    shape match {
      case s: ScalarShape => visitScalar(s)
      case s: ArrayShape => visitArray(s)
      case s: FileShape => visitFile(s)
      case s: NodeShape => visitObject(s)
      case s: UnionShape => visitUnion(s)
      case _: AnyShape | _: RecursiveShape => visitAny()
      case _ => throw new RuntimeException("Unsupported Shape")
    }
  }

  /**
   * Builds an anonymous simple type restricting `xsd:string` with the shape's length, enumeration and
   * pattern constraints.
   *
   * Only scalar enumeration values become `xsd:enumeration` facets; other data nodes are skipped.
   */
  private def visitString(stringShape: ScalarShape): XmlSchemaType = { //TODO we should check if the scalar type is string
    val simpleType = new XmlSchemaSimpleType(schema, false)

    val content = new XmlSchemaSimpleTypeRestriction
    content.setBaseTypeName(XSD_STRING)

    // Facet values are the plain field values, not the AMF field wrappers.
    stringShape.minLength.option().foreach { length =>
      val minLength = new XmlSchemaMinLengthFacet
      minLength.setValue(length)
      content.getFacets.add(minLength)
    }

    stringShape.maxLength.option().foreach { length =>
      val maxLength = new XmlSchemaMaxLengthFacet
      maxLength.setValue(length)
      content.getFacets.add(maxLength)
    }

    Option(stringShape.values).getOrElse(Seq.empty).foreach {
      case scalar: ScalarNode =>
        scalar.value.option().foreach { enumValue =>
          val enumFacet = new XmlSchemaEnumerationFacet
          enumFacet.setValue(enumValue)
          content.getFacets.add(enumFacet)
        }
      case _ => // non-scalar values cannot be expressed as an xsd:enumeration facet
    }

    stringShape.pattern.option().foreach { rawPattern =>
      val patternFacet = new XmlSchemaPatternFacet
      // XSD patterns are implicitly anchored, so a leading '^' and trailing '$' are removed.
      val pattern = rawPattern.replaceAll("^\\^", "").replaceAll("\\$$", "")
      patternFacet.setValue(pattern)
      content.getFacets.add(patternFacet)
    }

    simpleType.setContent(content)
    simpleType
  }

  /**
   * Builds a complex type with an `xsd:all` particle holding one element per (non-attribute) property.
   *
   * Properties come from the link target or the first inherited node shape when there is one (see
   * [[namedDefinition]]); otherwise they come from the shape itself. Open shapes (`closed` not true) also
   * get an unbounded, skip-processed `xsd:any`. A named type is registered in `types` before its
   * properties are visited, so a recursive reference to the same name returns this type.
   */
  override def visitObject(nodeShape: NodeShape): XmlSchemaType = {
    val typeName = getTypeName(nodeShape)

    if (typeName != null && types.containsKey(typeName)) { // With this we support recursive structures
      types.get(typeName)
    } else {
      val value = new XmlSchemaComplexType(schema, typeName != null)

      if (typeName != null) {
        value.setName(typeName)
        types.put(typeName, value)
      }

      val properties = namedDefinition(nodeShape).map(_.properties).getOrElse(nodeShape.properties)

      // Properties are emitted in source order; those without LexicalInformation sort as line 0.
      val items: mutable.MutableList[XmlSchemaElement] = mutable.MutableList()
      properties
        .sortBy(_.annotations.find(classOf[LexicalInformation]).map(_.range.start.line).getOrElse(0))
        .foreach(property => visitPropertyShape(value, items, property))

      val xmlSchemaAll = new XmlSchemaAll
      val xmlSchemaAllItems = xmlSchemaAll.getItems
      xmlSchemaAllItems.addAll(items.asJava)

      if (!nodeShape.closed.value()) {
        val schemaAny = new XmlSchemaAny
        schemaAny.setMinOccurs(0)
        schemaAny.setMaxOccurs(Unbounded)
        schemaAny.setProcessContent(SKIP)
        xmlSchemaAllItems.add(schemaAny)
      }

      value.setParticle(xmlSchemaAll)
      value
    }
  }

  /**
   * Emits one property of `value`.
   *
   * When the range's XML serialization marks it as an attribute, the property becomes an attribute of
   * `value`. It is required when `minCount` is at least 1, and only a simple-type result is attached.
   * Otherwise it becomes a child element appended to `items`, made optional when `minCount` is 0.
   * The XML serialization name wins over the sanitized property name.
   */
  private def visitPropertyShape(value: XmlSchemaComplexType, items: mutable.MutableList[XmlSchemaElement], property: PropertyShape): Unit = {
    val propertyValueShape = property.range

    val xmlFacetsOption: Option[XMLSerializer] = propertyValueShape.fields.entry(XMLSerialization)
      .map(_ => propertyValueShape.fields.field(XMLSerialization).asInstanceOf[XMLSerializer])

    val name: String =
      xmlFacetsOption
        .flatMap(x => x.name.option())
        .getOrElse(toValidSchemaName(property.name.value()))

    xmlFacetsOption match {
      case Some(xmlFacets) if xmlFacets.attribute.value() =>
        val xmlSchemaAttribute = new XmlSchemaAttribute(schema, false)
        xmlSchemaAttribute.setName(name)

        if (property.minCount.option().exists(_ > 0)) xmlSchemaAttribute.setUse(XmlSchemaUse.REQUIRED)

        visit(propertyValueShape) match {
          case simpleType: XmlSchemaSimpleType =>
            if (simpleType.getQName != null) xmlSchemaAttribute.setSchemaTypeName(simpleType.getQName)
            else xmlSchemaAttribute.setSchemaType(simpleType)
          case _ => // attributes can only carry simple types; other results are not attached
        }

        value.getAttributes.add(xmlSchemaAttribute)
      case _ =>
        val schemaElement = transformShape(name, propertyValueShape)
        if (property.minCount.option().contains(0)) schemaElement.setMinOccurs(0) // property not required
        items += schemaElement
    }
  }

  /**
   * The node shape whose name and properties define `nodeShape`'s XSD type: its link target when it is a
   * link, otherwise its first inherited shape.
   *
   * Returns `None` when that shape is absent or is not a [[NodeShape]]. In that case `nodeShape` is
   * emitted as an anonymous type built from its own properties.
   */
  private def namedDefinition(nodeShape: NodeShape): Option[NodeShape] =
    if (nodeShape.isLink) nodeShape.linkTarget.collect { case target: NodeShape => target }
    else nodeShape.inherits.headOption.collect { case parent: NodeShape => parent }

  /**
   * Global type name for `shape`, or `null` when it should be emitted as an anonymous type.
   *
   * A node shape is named after its [[namedDefinition]] when that shape has a non-empty name.
   */
  private def getTypeName(shape: Shape): String =
    shape match {
      case _: ScalarShape | _: FileShape | _: NilShape => null
      case a: ArrayShape if !a.items.isLink => null
      case n: NodeShape =>
        namedDefinition(n).flatMap(_.name.option()).filter(_.nonEmpty).map(toValidSchemaName).orNull
      case _ => toValidSchemaName(shape.name.value())
    }

  /**
   * Turns `typeName` into a valid XML name.
   *
   * A `_` is prepended when the first character cannot start a name (or the input is empty). Each `|`
   * becomes `or`, and any other invalid character becomes `_`.
   */
  private def toValidSchemaName(typeName: String): String = {
    val sb = new StringBuilder

    if (typeName.isEmpty || !XMLChar.isNameStart(typeName.charAt(0))) sb.append('_')

    typeName.foreach { c =>
      if (XMLChar.isName(c)) sb.append(c)
      else if (c == '|') sb.append("or")
      else sb.append('_')
    }

    sb.toString
  }

  /** Anonymous simple type restricting `baseType` with inclusive minimum/maximum facets from the shape. */
  private def createNumberSchemaType(numberTypeDefinition: ScalarShape, baseType: QName): XmlSchemaSimpleType = {
    val simpleType = new XmlSchemaSimpleType(schema, false)
    val content = new XmlSchemaSimpleTypeRestriction
    content.setBaseTypeName(baseType)
    numberTypeDefinition.minimum.option().foreach { minimum =>
      val minFacet = new XmlSchemaMinInclusiveFacet
      setNumberValue(minimum, baseType, minFacet)
      content.getFacets.add(minFacet)
    }
    numberTypeDefinition.maximum.option().foreach { maximum =>
      val maxFacet = new XmlSchemaMaxInclusiveFacet
      setNumberValue(maximum, baseType, maxFacet)
      content.getFacets.add(maxFacet)
    }
    simpleType.setContent(content)
    simpleType
  }

  /**
   * Stores `value` in `facet`, as an integral number for integral base types.
   *
   * `xsd:integer` is unbounded, so it is widened to `Long` rather than `Int`. `Double.toInt` would
   * saturate at `Int.MaxValue` and silently change bounds outside the `Int` range.
   */
  private def setNumberValue(value: Double, baseType: QName, facet: XmlSchemaFacet): Unit =
    if (baseType == XSD_INTEGER || baseType == XSD_LONG) facet.setValue(value.toLong)
    else facet.setValue(value)

  /** Files are represented as the built-in `base64Binary` type. */
  override def visitFile(fileShape: FileShape): XmlSchemaType = collection.getTypeByQName(XSD_BASE64)

  /** Marks the element being built as nillable and returns the built-in type registered under `XSD_ANY`. */
  override def visitNil(nilShape: NilShape): XmlSchemaType = {
    this.currentElement.peek.setNillable(true)
    collection.getTypeByQName(XSD_ANY)
  }

  /**
   * Maps an array onto repeated elements.
   *
   * When the array's XML serialization is `wrapped`, returns an anonymous complex type whose sequence
   * holds one repeated item element. Otherwise the item type is returned directly, and the array
   * cardinality is applied to the element currently being built.
   */
  override def visitArray(arrayShape: ArrayShape): XmlSchemaType = {
    val itemShape = if (arrayShape.inherits.nonEmpty) arrayShape.inherits.head else arrayShape.items

    val xmlFacets = arrayShape.fields.entry(XMLSerialization).map(_ => arrayShape.xmlSerialization)

    if (xmlFacets.exists(xml => xml.wrapped.value())) { // This is for the inside element not the wrapped. So this one is the tag for the item type
      // The item element is named after the item's XML facet name when set, otherwise after the item shape.
      val xmlSerialization = itemShape match {
        case i: AnyShape => i.fields.entry(XMLSerialization).map(_ => i.xmlSerialization)
        case _ => None
      }
      val name = xmlSerialization.flatMap(_.name.option()).getOrElse(itemShape.name.value())

      val transform = transformShape(name, itemShape)
      addArrayCardinality(arrayShape, transform)
      val value = new XmlSchemaComplexType(schema, false)
      val xmlSchemaSequence = new XmlSchemaSequence
      value.setParticle(xmlSchemaSequence)
      xmlSchemaSequence.getItems.add(transform)
      value
    } else {
      val value = visit(itemShape)
      addArrayCardinality(arrayShape, currentElement.peek)
      value
    }
  }

  /** Returns `stringValue`, or `defaultValue` when it is `null`. */
  def defaultTo[T](stringValue: T, defaultValue: T): T = if (stringValue == null) defaultValue else stringValue

  /** Copies `minItems` (default 0) and `maxItems` (default unbounded) onto the element's occurrence bounds. */
  private def addArrayCardinality(arrayShape: ArrayShape, transform: XmlSchemaElement): Unit = {
    transform.setMinOccurs(arrayShape.minItems.option().map(_.toLong).getOrElse(0L))
    if (arrayShape.maxItems.nonNull) transform.setMaxOccurs(arrayShape.maxItems.value())
    else transform.setMaxOccurs(Unbounded)
  }

  /**
   * Maps a union onto an `xsd:choice`, or onto the open "any" type when union validation is disabled.
   *
   * Complex members are wrapped in a named group (an `all` particle is re-expressed as a sequence) and
   * referenced from the choice. Other members become anonymous elements of that type.
   */
  override def visitUnion(unionShape: UnionShape): XmlSchemaType =
    if (!EnableUnionValidation) visitAny(unionShape)
    else {
      val choice = new XmlSchemaChoice()
      unionShape.anyOf.foreach { shape =>
        val member: XmlSchemaChoiceMember = visit(shape) match {
          case complexType: XmlSchemaComplexType =>
            val group = new XmlSchemaGroup(schema)
            // Group names are sanitized like every other emitted name.
            group.setName(shape.name.option().map(toValidSchemaName).orNull)
            complexType.getParticle match {
              case particle: XmlSchemaAll =>
                val items = particle.getItems.asScala.map(_.asInstanceOf[XmlSchemaSequenceMember])
                val sequence = new XmlSchemaSequence
                sequence.getItems.addAll(items.asJava)
                group.setParticle(sequence)
              case particle: XmlSchemaGroupParticle =>
                group.setParticle(particle)
              case _ => throw new RuntimeException("Unexpected particle")
            }
            val groupRef = new XmlSchemaGroupRef()
            groupRef.setRefName(group.getQName)
            groupRef

          case otherType =>
            val element = new XmlSchemaElement(schema, false)
            element.setSchemaType(otherType)
            element
        }
        choice.getItems.add(member)
      }
      val complexType = new XmlSchemaComplexType(schema, false)
      complexType.setParticle(choice)
      complexType.setAnyAttribute(new XmlSchemaAnyAttribute)
      complexType
    }

  /**
   * Maps a scalar onto an XSD type, chosen from its data type and (when present) its format.
   *
   * Strings and numbers get restricted anonymous types; other scalars use built-in types. Unknown scalar
   * types raise a `RuntimeException`.
   */
  override def visitScalar(scalarShape: ScalarShape): XmlSchemaType = {
    val dataType = scalarShape.dataType.value()
    val format = scalarShape.format
    val typeDef = if (format != null) TypeDefXsdMapping.typeDef(dataType, format.value()) else TypeDefXsdMapping.typeDef(dataType)

    typeDef match {
      case StrType | PasswordType => visitString(scalarShape)
      case IntType => createNumberSchemaType(scalarShape, XSD_INTEGER)
      case LongType => createNumberSchemaType(scalarShape, XSD_LONG)
      case FloatType => createNumberSchemaType(scalarShape, XSD_FLOAT)
      case DoubleType => createNumberSchemaType(scalarShape, XSD_DOUBLE)
      case NumberType => createNumberSchemaType(scalarShape, XSD_DOUBLE)
      case BoolType => collection.getTypeByQName(XSD_BOOLEAN)
      case DateTimeType | DateTimeOnlyType => collection.getTypeByQName(XSD_DATETIME)
      case TimeOnlyType => collection.getTypeByQName(XSD_TIME)
      case DateOnlyType => collection.getTypeByQName(XSD_DATE)
      case ByteType => collection.getTypeByQName(XSD_BYTE)
      case BinaryType => collection.getTypeByQName(XSD_BASE64)
      case _ => throw new RuntimeException("Unsupported type")
    }
  }

  /** Open content type accepting any child elements (see `createAny`). */
  def visitAny(): XmlSchemaComplexType = createAny

  /** Open content type accepting any child elements (see `createAny`); the shape itself is not inspected. */
  override def visitAny(anyShape: AnyShape): XmlSchemaComplexType = createAny

  /** Anonymous complex type whose choice holds a single `xsd:any` (0..unbounded, processContents=skip). */
  private def createAny: XmlSchemaComplexType = {
    val value = new XmlSchemaComplexType(schema, false)
    val choice = new XmlSchemaChoice
    value.setParticle(choice)
    val schemaAny = new XmlSchemaAny
    schemaAny.setMinOccurs(0)
    schemaAny.setMaxOccurs(Unbounded)
    schemaAny.setProcessContent(SKIP)
    choice.getItems.add(schemaAny)
    value
  }
}


object TypeToXmlSchemaVisitor {
  /** System property that enables `xsd:choice` generation for union shapes. */
  private[transformer] val UnionValidationProperty = "amf.plugins.xml.validateUnion"

  /**
   * Whether union shapes are translated into an `xsd:choice` rather than the open "any" content type.
   *
   * The property is re-read on every call, so it can be toggled at runtime. Only a case-insensitive
   * "true" enables the feature. An unset or malformed value is treated as `false` instead of throwing,
   * as `String.toBoolean` would.
   */
  def EnableUnionValidation: Boolean =
    java.lang.Boolean.parseBoolean(System.getProperty(UnionValidationProperty, "false"))
}
