package amf.plugins.domain.shapes.resolution.stages.shape_normalization

import amf.core.annotations._
import amf.core.metamodel.domain.ShapeModel
import amf.core.metamodel.domain.extensions.PropertyShapeModel
import amf.core.model.domain._
import amf.core.model.domain.extensions.PropertyShape
import amf.core.parser.{Annotations, FieldEntry}
import amf.plugins.domain.shapes.metamodel._
import amf.plugins.domain.shapes.models._
import amf.plugins.features.validation.CoreValidations.ResolutionValidation

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

private[stages] object ShapeCanonizer {

  /** Creates a canonizer bound to `context` and returns the result of normalizing `s` with it. */
  def apply(s: Shape, context: NormalizationContext): Shape = {
    val canonizer = ShapeCanonizer()(context)
    canonizer.normalize(s)
  }
}

sealed case class ShapeCanonizer()(implicit val context: NormalizationContext) extends ShapeNormalizer {

  /**
    * Asks the shape's annotations to reject every annotation that is not a `PerpetualAnnotation`
    * and returns the same shape instance that was received.
    */
  protected def cleanUnnecessarySyntax(shape: Shape): Shape = {
    shape.annotations.reject {
      case _: PerpetualAnnotation => false
      case _                      => true
    }
    shape
  }

  // Read by normalizeAction: while this is true, canonical shapes are not added to the context cache.
  private var withoutCaching = false

  /**
    * Runs `fn` with caching disabled and returns its result.
    *
    * The previous value of the flag is restored when `fn` finishes, even if it throws. Restoring the
    * previous value (instead of always resetting to `false`) keeps caching disabled for an enclosing
    * no-caching scope when these calls are nested.
    *
    * @param fn computation to evaluate while caching is disabled
    * @return whatever `fn` returns
    */
  private def runWithoutCaching[T](fn: () => T): T = {
    val previous = withoutCaching
    withoutCaching = true
    try fn()
    finally withoutCaching = previous
  }

  /** Calls `normalize` on `s` through `runWithoutCaching`. */
  private def normalizeWithoutCaching(s: Shape): Shape = {
    val uncachedNormalize: () => Shape = () => normalize(s)
    runWithoutCaching(uncachedNormalize)
  }

  /** Calls `normalizeAction` on `s` through `runWithoutCaching`. */
  private def actionWithoutCaching(s: Shape): Shape = {
    val uncachedAction: () => Shape = () => normalizeAction(s)
    runWithoutCaching(uncachedAction)
  }

  /**
    * Canonizes one shape: cleans its annotations, dispatches to the canonical form that matches the
    * concrete shape type, and then updates the context cache.
    *
    * @param shape shape to canonize
    * @return the canonical shape produced by the matching `canonicalXxx` method
    */
  override protected def normalizeAction(shape: Shape): Shape = {
    cleanUnnecessarySyntax(shape)

    val result: Shape = shape match {
      case u: UnionShape     => canonicalUnion(u)
      case s: ScalarShape    => canonicalScalar(s)
      case a: ArrayShape     => canonicalArray(a)
      case m: MatrixShape    => canonicalMatrix(m)
      case t: TupleShape     => canonicalTuple(t)
      case p: PropertyShape  => canonicalProperty(p)
      case f: FileShape      => canonicalShape(f)
      case n: NilShape       => canonicalShape(n)
      case o: NodeShape      => canonicalNode(o)
      case r: RecursiveShape => r
      case y: AnyShape       => canonicalAny(y)
    }

    // a shape must never be added to the cache while it is not resolved yet
    val cachingEnabled = !withoutCaching
    if (cachingEnabled) context.cache + result
    context.cache.updateFixPointsAndClosures(result, withoutCaching)

    result
  }

  /**
    * Normalizes the shapes referenced by the logical constraints (`and`, `or`, `xone`, `not`) of `shape`.
    * A constraint is only rewritten when its field currently holds a value; the annotations of the
    * previous field value are passed along when the normalized value is set.
    */
  protected def canonicalLogicalConstraints(shape: Shape): Unit = {
    Option(shape.fields.getValue(ShapeModel.And)).foreach { previous =>
      val normalizedAnd = shape.and.map(normalize)
      shape.setArrayWithoutId(ShapeModel.And, normalizedAnd, previous.annotations)
    }

    Option(shape.fields.getValue(ShapeModel.Or)).foreach { previous =>
      val normalizedOr = shape.or.map(normalize)
      shape.setArrayWithoutId(ShapeModel.Or, normalizedOr, previous.annotations)
    }

    Option(shape.fields.getValue(ShapeModel.Xone)).foreach { previous =>
      val normalizedXone = shape.xone.map(normalize)
      shape.setArrayWithoutId(ShapeModel.Xone, normalizedXone, previous.annotations)
    }

    Option(shape.fields.getValue(ShapeModel.Not)).foreach { previous =>
      val normalizedNot = normalize(shape.not)
      shape.set(ShapeModel.Not, normalizedNot, previous.annotations)
    }
  }

  /** Normalizes the logical constraints and, when the shape inherits from others, resolves the inheritance. */
  private def canonicalShape(any: Shape) = {
    canonicalLogicalConstraints(any)
    if (any.inherits.isEmpty) any
    else canonicalInheritance(any)
  }

  /**
    * Normalizes the logical constraints of an any-shape. Shapes with inheritance go through
    * `canonicalInheritance`; the rest are passed to `AnyShapeAdjuster`.
    */
  private def canonicalAny(any: AnyShape) = {
    canonicalLogicalConstraints(any)
    if (any.inherits.isEmpty) AnyShapeAdjuster(any)
    else canonicalInheritance(any)
  }

  /**
    * Normalizes the logical constraints of a scalar and resolves its inheritance when it has any;
    * otherwise the scalar itself is returned.
    */
  protected def canonicalScalar(scalar: ScalarShape): Shape = {
    canonicalLogicalConstraints(scalar)
    val hasSuperTypes = Option(scalar.inherits).exists(_.nonEmpty)
    if (hasSuperTypes) canonicalInheritance(scalar)
    else scalar
  }

  /**
    * Resolves the inheritance of `shape`.
    *
    * Two paths exist:
    *  - "endpoint simple inheritance" (see `endpointSimpleInheritance`): the single inherited shape is
    *    normalized and returned in place of `shape`.
    *  - otherwise: the `Inherits` field is removed, the shape is normalized without caching and then
    *    combined with every normalized super type through `context.minShape`.
    *
    * @param shape shape whose `inherits` must be resolved
    * @return the normalized referenced shape (first path) or the accumulated minimal shape (second path)
    */
  protected def canonicalInheritance(shape: Shape): Shape = {
    if (endpointSimpleInheritance(shape)) {
      val referencedShape = shape.inherits.head
      aggregateExamples(shape, referencedShape)
      // propagate the auto-generated-name marker to the shape that will replace this one
      if (shape.annotations.contains(classOf[AutoGeneratedName])) referencedShape.add(AutoGeneratedName())
      if (!referencedShape
            .isInstanceOf[RecursiveShape]) // the referenced shape is marked as resolved so the graph emitter can extract it to a declaration if it is a declared element
        referencedShape.annotations += ResolvedInheritance()
      normalize(referencedShape)
    } else {
      val superTypes = shape.inherits
      // only computed when editing info is kept: recursive shapes are kept as they are,
      // any other inherited shape is replaced by a link named after it
      val oldInherits: Seq[Shape] = if (context.keepEditingInfo) shape.inherits.collect {
        case rec: RecursiveShape => rec
        case shape: Shape        => shape.link(shape.name.value()).asInstanceOf[Shape]
      } else Nil
      shape.fields.removeField(ShapeModel.Inherits) // the shape is resolved here without its inherits (and without caching), because it will be added to the cache later, once it is fully resolved
      var accShape: Shape                             = normalizeWithoutCaching(shape)
      var superShapeswithDiscriminator: Seq[AnyShape] = Nil
      var inheritedIds: Seq[String]                   = Nil

      // fold every super type into accShape, one at a time
      superTypes.foreach { superNode =>
        val canonicalSuperNode = normalize(superNode)

        // we save this information to connect the references once we have computed the minShape
        if (hasDiscriminator(canonicalSuperNode))
          superShapeswithDiscriminator = superShapeswithDiscriminator ++ Seq(canonicalSuperNode.asInstanceOf[NodeShape])

        // collect the id of the super type plus, for inheritance chains, the ids it already inherited
        canonicalSuperNode match {
          case chain: InheritanceChain => inheritedIds ++= (Seq(canonicalSuperNode.id) ++ chain.inheritedIds)
          case _                       => inheritedIds :+= canonicalSuperNode.id
        }
        val newMinShape = context.minShape(accShape, canonicalSuperNode)
        accShape = actionWithoutCaching(newMinShape)
      }
      if (context.keepEditingInfo) accShape.annotations += InheritedShapes(oldInherits.map(_.id))
      if (!shape.id.equals(accShape.id)) {
        context.cache.registerMapping(shape.id, accShape.id)
        accShape.withId(shape.id) // the id must be overridden; otherwise the cached parent shape would be overwritten
      }

      // adjust inheritance chain if discriminator is defined
      accShape match {
        case any: AnyShape => superShapeswithDiscriminator.foreach(_.linkSubType(any))
        case _             => // ignore
      }

      // we set the full set of inherited IDs
      accShape match {
        case chain: InheritanceChain => chain.inheritedIds ++= inheritedIds
        case _                       => // ignore
      }

      shape match {
        // If the shape is a declaration that inherits from only one type and does not declare any property, then it inherits the examples
        case s: NodeShape if isSimpleInheritanceDeclaration(s, superTypes) =>
          aggregateExamples(accShape, superTypes.head)
        case _ => // Nothing to do
      }

      accShape
    }
  }

  /** True when `n` is a declared element with exactly one inherited shape and no properties of its own. */
  private def isSimpleInheritanceDeclaration(n: NodeShape, inherits: Seq[Shape]): Boolean = {
    val isDeclaration = n.annotations.contains(classOf[DeclaredElement])
    isDeclaration && inherits.size == 1 && n.properties.isEmpty
  }

  /**
    * Copies the examples of `from` into `to`.
    *
    * An example counts as already present in `to` when the ids match, or when both the trimmed raw
    * text and the name match. For those only the tracking information is merged; any other example
    * is annotated as `LocalElement` and appended to the examples of `to`.
    */
  private def copyExamples(from: AnyShape, to: AnyShape): Unit = {
    def trimmedRaw(example: Example): String = example.raw.option().getOrElse("").trim

    def sameExample(candidate: Example, existing: Example): Boolean =
      candidate.id == existing.id ||
        (trimmedRaw(candidate) == trimmedRaw(existing) && candidate.name.value() == existing.name.value())

    for (candidate <- from.examples) {
      to.examples.find(existing => sameExample(candidate, existing)) match {
        case Some(duplicated) =>
          copyTracking(candidate, duplicated)
        case None =>
          candidate.annotations += LocalElement()
          to.setArrayWithoutId(AnyShapeModel.Examples, to.examples :+ candidate)
      }
    }
  }

  /**
    * Merges the `TrackedElement` parents of `duplicate` into `receiver`.
    * Nothing happens when `duplicate` carries no `TrackedElement`. When `receiver` already has one,
    * it is rejected from the receiver annotations and re-added with the parents of both.
    */
  private def copyTracking(duplicate: Example, receiver: Example): Unit = {
    for (fromDuplicate <- duplicate.annotations.find(classOf[TrackedElement])) {
      val targetAnnotations = receiver.annotations
      val merged = receiver.annotations.find(classOf[TrackedElement]) match {
        case Some(fromReceiver) =>
          receiver.annotations.reject(_.isInstanceOf[TrackedElement])
          TrackedElement(fromReceiver.parents ++ fromDuplicate.parents)
        case None =>
          TrackedElement(fromDuplicate.parents)
      }
      targetAnnotations += merged
    }
  }

  /**
    * Moves examples between `shape` and `referencedShape` (only when both are `AnyShape`s) and then
    * makes sure the receiving shape has uniquely named examples.
    *
    * Direction: when `shape` is a declared element it receives the examples of the referenced shape;
    * otherwise the referenced shape receives the examples of `shape`.
    */
  protected def aggregateExamples(shape: Shape, referencedShape: Shape): Unit =
    (shape, referencedShape) match {
      case (accShape: AnyShape, refShape: AnyShape) =>
        val accIsDeclared = accShape.annotations.contains(classOf[DeclaredElement])
        val source        = if (accIsDeclared) refShape else accShape
        val target        = if (accIsDeclared) accShape else refShape

        copyExamples(source, target)

        // names are only (re)assigned when there is more than one example
        if (target.examples.size > 1) {
          val usedNames: mutable.Set[String] = mutable.Set()

          // first "example_<i>" (i = 0, 1, 2, ...) that has not been used so far
          def firstFreeName(index: Int): String = {
            val candidate = s"example_$index"
            if (usedNames.contains(candidate)) firstFreeName(index + 1) else candidate
          }

          target.examples.foreach { example =>
            val unnamedOrRepeated = example.name.option().isEmpty || usedNames.contains(example.name.value())
            if (unnamedOrRepeated) {
              val generated = firstFreeName(0)
              usedNames.add(generated)
              example.withName(generated)
            } else usedNames.add(example.name.value())
          }
        }

      case _ => // Nothing to do
    }

  /**
    * Tells whether `shape` is a non-declared `AnyShape` with a single super type whose own fields
    * (ignoring inherits, examples and name) are all accepted by `validFields` against that super type.
    */
  def endpointSimpleInheritance(shape: Shape): Boolean = shape match {
    case any: AnyShape if !any.annotations.contains(classOf[DeclaredElement]) && any.inherits.size == 1 =>
      val parent        = any.inherits.head
      val skippedFields = Seq(ShapeModel.Inherits, AnyShapeModel.Examples, AnyShapeModel.Name)
      val ownFields     = shape.fields.fields().filter(entry => !skippedFields.contains(entry.field))
      validFields(ownFields, parent)
    case _ => false
  }

  // To be a simple inheritance, all the effective fields of the shape must be the same in the superType
  /**
    * Checks that every entry is compatible with `superType`.
    *
    * An entry is compatible when the super type has the same field with an equal value, or when the
    * field is `NodeShapeModel.Closed` and the super type is not a `NodeShape`.
    *
    * @param entries   field entries of the inheriting shape to check
    * @param superType shape the entries are compared against
    * @return true when all entries are compatible (also for an empty `entries`)
    */
  private def validFields(entries: Iterable[FieldEntry], superType: Shape): Boolean =
    // forall stops at the first incompatible entry, without the non-local `return` from inside a lambda
    entries.forall { e =>
      superType.fields.entry(e.field) match {
        case Some(s) if s.value.value.equals(e.value.value) => true
        case _                                              => e.field == NodeShapeModel.Closed && !superType.isInstanceOf[NodeShape]
      }
    }

  /** True only for node shapes whose discriminator has a value. */
  protected def hasDiscriminator(shape: Shape): Boolean = shape match {
    case node: NodeShape => node.discriminator.option().nonEmpty
    case _               => false
  }

  /**
    * Canonical form of an array shape.
    *
    * With inheritance, the result comes from `canonicalInheritance`. Otherwise the items (if any) are
    * normalized and set back on the array; when the normalized items are themselves an array, the
    * array is converted with `toMatrixShape`.
    */
  protected def canonicalArray(array: ArrayShape): Shape = {
    canonicalLogicalConstraints(array)
    if (array.inherits.isEmpty) {
      Option(array.items) match {
        case None => array
        case Some(items) =>
          val canonicalItems = normalize(items)
          array.annotations += ExplicitField()
          // replace the previous items with the canonical ones
          array.fields.removeField(ArrayShapeModel.Items)
          array.fields.setWithoutId(ArrayShapeModel.Items, canonicalItems)
          canonicalItems match {
            case _: ArrayShape => array.toMatrixShape // array of arrays -> matrix
            case _             => array
          }
      }
    } else {
      canonicalInheritance(array)
    }
  }

  /**
    * Canonical form of a matrix shape.
    *
    * With inheritance, the result comes from `canonicalInheritance`. Otherwise the items (if any) are
    * normalized and the result depends on what they became:
    *  - a union: every member is wrapped in a clone of the matrix and the union itself is returned,
    *    renamed when the matrix has a name value;
    *  - an array: the matrix keeps that array as items;
    *  - anything else: the matrix is converted with `toArrayShape` and receives the items.
    */
  protected def canonicalMatrix(matrix: MatrixShape): Shape = {
    canonicalLogicalConstraints(matrix)
    if (matrix.inherits.isEmpty) {
      Option(matrix.items).fold[Shape](matrix) { items =>
        val canonicalItems = normalize(items)
        matrix.fields.removeField(ArrayShapeModel.Items)
        canonicalItems match {
          case unionItems: UnionShape =>
            val wrappedMembers = unionItems.anyOf.map {
              case arrayMember: ArrayShape => matrix.cloneShape(Some(context.errorHandler)).withItems(arrayMember)
              case otherMember             => matrix.cloneShape(Some(context.errorHandler)).toArrayShape.withItems(otherMember)
            }
            unionItems.setArrayWithoutId(UnionShapeModel.AnyOf, wrappedMembers)
            Option(matrix.fields.getValue(ShapeModel.Name)) match {
              case Some(name) => unionItems.withName(name.toString)
              case _          => unionItems
            }
          case arrayItems: ArrayShape => matrix.withItems(arrayItems)
          case _                      => matrix.toArrayShape.withItems(canonicalItems)
        }
      }
    } else {
      canonicalInheritance(matrix)
    }
  }

  /**
    * Canonical form of a tuple shape.
    *
    * With inheritance, the result comes from `canonicalInheritance`. Otherwise every item is
    * normalized, in order, and the normalized items are set on `TupleShapeModel.TupleItems`, keeping
    * the annotations of the previous field value when there was one.
    *
    * The previous implementation accumulated the items into a `Seq[Seq[Shape]]` that always held
    * exactly one sequence, so its "several combinations" branch could never run. That branch
    * (which returned an empty union) has been removed.
    *
    * @param tuple tuple shape to canonize
    * @return the same tuple with canonical items, or the shape produced by resolving its inheritance
    */
  protected def canonicalTuple(tuple: TupleShape): Shape = {
    canonicalLogicalConstraints(tuple)
    if (tuple.inherits.nonEmpty) {
      canonicalInheritance(tuple)
    } else {
      val canonicalItems: Seq[Shape] = tuple.items.map(normalize)
      val itemsAnnotations =
        Option(tuple.fields.getValue(TupleShapeModel.TupleItems)).map(_.annotations).getOrElse(Annotations())
      tuple.fields.setWithoutId(TupleShapeModel.TupleItems, AmfArray(canonicalItems), itemsAnnotations)
      tuple
    }
  }

  /**
    * Canonical form of a node shape.
    *
    * The node is always annotated with `ExplicitField`. With inheritance, the result comes from
    * `canonicalInheritance`. Otherwise each property is normalized and the node's properties are
    * replaced by the normalized ones. A property whose normalization is not a `PropertyShape` is
    * reported as a violation and the original property is kept.
    */
  protected def canonicalNode(node: NodeShape): Shape = {
    canonicalLogicalConstraints(node)
    node.add(ExplicitField())
    if (node.inherits.nonEmpty) canonicalInheritance(node)
    else {
      // carries the InheritanceProvenance annotation over when only the original property has it
      def keepInheritanceProvenance(original: PropertyShape, canonical: PropertyShape): Unit = {
        val alreadyAnnotated = canonical.annotations.find(classOf[InheritanceProvenance]).isDefined
        if (!alreadyAnnotated)
          original.annotations
            .find(classOf[InheritanceProvenance])
            .foreach(provenance => canonical.annotations += provenance)
      }

      val canonicalProperties: Seq[PropertyShape] = node.properties.map { original =>
        normalize(original) match {
          case canonical: PropertyShape =>
            keepInheritanceProvenance(original, canonical)
            canonical
          case other =>
            context.errorHandler.violation(ResolutionValidation,
                                           other.id,
                                           None,
                                           s"Resolution error: Expecting property shape, found $other",
                                           other.position(),
                                           other.location())
            original
        }
      }

      node.setArrayWithoutId(NodeShapeModel.Properties, canonicalProperties)
    }
  }

  /**
    * Normalizes the range of the property and sets it back on the `Range` field, passing the
    * annotations of the current field value. Returns the same property.
    */
  protected def canonicalProperty(property: PropertyShape): Shape = {
    val canonicalRange   = normalize(property.range)
    val rangeAnnotations = property.fields.getValue(PropertyShapeModel.Range).annotations
    property.fields.setWithoutId(PropertyShapeModel.Range, canonicalRange, rangeAnnotations)
    property
  }

  /**
    * Canonical form of a union shape.
    *
    * With inheritance, the result comes from `canonicalInheritance`. Otherwise every member is
    * normalized without caching; members that are unions themselves contribute their own members,
    * so the resulting `anyOf` is flattened by one level. The annotations of the previous `AnyOf`
    * value are passed along when there was one.
    */
  protected def canonicalUnion(union: UnionShape): Shape = {
    if (union.inherits.isEmpty) {
      val flattenedMembers: Seq[Shape] = union.anyOf.flatMap { member =>
        normalizeWithoutCaching(member) match {
          case nested: UnionShape => nested.anyOf
          case single: Shape      => Seq(single)
        }
      }

      val anyOfAnnotations =
        Option(union.fields.getValue(UnionShapeModel.AnyOf)).fold(Annotations())(_.annotations)

      union.fields.setWithoutId(UnionShapeModel.AnyOf, AmfArray(flattenedMembers), anyOfAnnotations)
      union
    } else {
      canonicalInheritance(union)
    }
  }
}
