package org.mule.weave.v2.ts

import org.mule.weave.v2.annotations.WeaveApi
import org.mule.weave.v2.parser.ast.variables.NameIdentifier
import org.mule.weave.v2.ts.TypeHelper.resolveIntersection
import org.mule.weave.v2.utils.IdentityHashMap

import scala.collection.mutable

object WeaveTypeTraverse {

  /**
    * Maps the specified type recursively.
    *
    * The callback is tried first on every visited node. When it is defined for a node, its result replaces
    * that node and the node's children are not visited by this traversal. When it is not defined, the
    * node's children are mapped and the node is rebuilt if any child changed. A node whose mapping yields
    * the very same instance is returned unchanged.
    *
    * @param weaveType         The type to map
    * @param callback          The callback applied to every visited node
    * @param recursionDetector Detector used when following reference types
    * @return The transformed type
    */
  @WeaveApi(Seq("data-weave-agent"))
  def treeMap(weaveType: WeaveType, callback: PartialFunction[WeaveType, WeaveType], recursionDetector: RecursionDetector[WeaveType] = TypeHelper.createRecursionDetector()): WeaveType = {
    internalTreeMap(weaveType, callback, recursionDetector)
  }

  private def internalTreeMap(weaveType: WeaveType, callback: PartialFunction[WeaveType, WeaveType], recursionDetector: RecursionDetector[WeaveType]): WeaveType = {

    // Maps a reference type. A reference without type arguments (or without a resolver) is only followed when
    // it points to a type parameter; that parameter is then mapped and returned in place of the reference.
    // A reference with type arguments maps its arguments and is rebuilt only if one of them changed.
    def mapReferenceType(rt: ReferenceType): WeaveType = {
      rt match {
        case srt: SimpleReferenceType => {
          if (srt.typeParams.isEmpty || srt.typeRef.refResolver().isEmpty) {
            // NOTE(review): if the detector memoizes the processor result per referenced type (as
            // RecursionDetector.resolve does), a second, distinct reference to the same non-type-parameter
            // type gets the first reference instance back instead of itself — confirm this is intended.
            recursionDetector.resolve(srt, {
              case tp: TypeParameter => {
                internalTreeMap(tp, callback, recursionDetector)
              }
              case _ => {
                srt
              }
            })
          } else {
            val params: Seq[WeaveType] = srt.typeParams.get
            val mappedParams: Seq[WeaveType] = params.map((tp) => internalTreeMap(tp, callback, recursionDetector))
            if (params.size != mappedParams.size || mappedParams.zip(params).exists((pair) => !(pair._1 eq pair._2))) {
              val resolver: DefaultTypeReference = new DefaultTypeReference(srt.refName, Some(mappedParams), srt.typeRef.refResolver().get)
              srt.copy(typeParams = Some(mappedParams), typeRef = resolver)
            } else {
              srt
            }
          }
        }
        case TypeSelectorType(prefix, refName, referencedType, referenceResolver) => {
          val mappedType = internalTreeMap(referencedType, callback, recursionDetector)
          if (mappedType eq referencedType) {
            rt
          } else {
            // Assumes the callback keeps the selected type a ReferenceType; otherwise this cast fails.
            TypeSelectorType(prefix, refName, mappedType.asInstanceOf[ReferenceType], referenceResolver)
          }
        }
      }
    }

    val mappedType: WeaveType = callback
      .orElse[WeaveType, WeaveType]({
        case ot @ ObjectType(properties, _, _) => {
          // Properties that no longer map to a KeyValuePairType are dropped, hence the size check below.
          val mappedProperties = properties
            .map((kvt) => internalTreeMap(kvt, callback, recursionDetector))
            .collect({ case kvt: KeyValuePairType => kvt })
          if (properties.size != mappedProperties.size || mappedProperties.zip(properties).exists((pair) => !(pair._1 eq pair._2))) {
            val objectType = ot.copy(properties = mappedProperties)
            objectType.label(weaveType.label())
            objectType
          } else {
            ot
          }
        }
        case kvt @ KeyValuePairType(key, value, _, _) => {
          val mappedKey = internalTreeMap(key, callback, recursionDetector)
          val mappedValue = internalTreeMap(value, callback, recursionDetector)
          if (!(key eq mappedKey) || !(value eq mappedValue)) {
            kvt.copy(key = mappedKey, value = mappedValue)
          } else {
            kvt
          }
        }
        case rt: ReferenceType => {
          mapReferenceType(rt)
        }
        case kt @ KeyType(name, attrs) => {
          val mappedName = internalTreeMap(name, callback, recursionDetector)
          // Attributes that no longer map to a NameValuePairType are dropped, so a size difference also
          // means the key changed (zip alone would miss a dropped trailing attribute).
          val mappedAttributes = attrs
            .map(internalTreeMap(_, callback, recursionDetector))
            .collect({ case nvt: NameValuePairType => nvt })
          if (!(mappedName eq name) || attrs.size != mappedAttributes.size || attrs.zip(mappedAttributes).exists((pair) => !(pair._1 eq pair._2))) {
            //Something has changed
            KeyType(mappedName, mappedAttributes)
          } else {
            kt
          }
        }
        case nvp @ NameValuePairType(name, value, optional) => {
          val mappedName = internalTreeMap(name, callback, recursionDetector)
          val mappedValue = internalTreeMap(value, callback, recursionDetector)
          if (!(mappedName eq name) || !(mappedValue eq value)) {
            NameValuePairType(mappedName, mappedValue, optional)
          } else {
            nvp
          }
        }
        case tt @ TypeType(t) => {
          val weaveTypeType = internalTreeMap(t, callback, recursionDetector)
          if (weaveTypeType eq t) {
            tt
          } else {
            val typeType = TypeType(weaveTypeType)
            typeType.label(weaveType.label())
            typeType
          }
        }
        case at @ ArrayType(of) => {
          val mappedItem = internalTreeMap(of, callback, recursionDetector)
          if (mappedItem eq of) {
            at
          } else {
            val arrayType = ArrayType(mappedItem)
            arrayType.label(weaveType.label())
            arrayType
          }
        }
        case ut @ UnionType(of) => {
          val mappedTypes = of.map(internalTreeMap(_, callback, recursionDetector))
          val newTypes = mappedTypes.zip(of).exists((pair) => !(pair._1 eq pair._2))
          if (newTypes) {
            val unionType = UnionType(mappedTypes)
            unionType.label(weaveType.label())
            unionType
          } else {
            ut
          }
        }
        case it @ IntersectionType(of) => {
          val weaveTypes = of.map(internalTreeMap(_, callback, recursionDetector))
          if (of.zip(weaveTypes).exists((pair) => !(pair._1 eq pair._2))) {
            // Re-intersect so the result goes through TypeHelper's intersection logic.
            TypeHelper.intersec(weaveTypes)
          } else {
            it
          }
        }
        case FunctionType(typeParams, params, returnType, overloads, name, customReturnTypeResolver) => {
          // Function types are always rebuilt. A plain function maps its type-parameter bounds, parameter types
          // and return type; an overloaded function only maps its overloads.
          val functionType = if (overloads.isEmpty) {
            FunctionType(
              typeParams.map(tp => {
                val top = tp.top.map(internalTreeMap(_, callback, recursionDetector))
                val bottom = tp.bottom.map(internalTreeMap(_, callback, recursionDetector))
                TypeParameter(tp.name, top, bottom, tp.instanceId, tp.noImplicitBounds)
              }),
              params.map((ftp) => {
                FunctionTypeParameter(ftp.name, internalTreeMap(ftp.wtype, callback, recursionDetector), ftp.optional, ftp.defaultValueType)
              }),
              internalTreeMap(returnType, callback, recursionDetector),
              overloads,
              name,
              customReturnTypeResolver)
          } else {
            FunctionType(typeParams, params, returnType, overloads.map(internalTreeMap(_, callback, recursionDetector)).collect({ case ft: FunctionType => ft }), name, customReturnTypeResolver)
          }
          functionType.label(weaveType.label())
          functionType
        }
        case DynamicReturnType(arguments, node, typeGraph, scope, name, expectedReturnType, resolver) =>
          // Always rebuilt: argument types and the expected return type (if any) are mapped.
          val mappedParameters = arguments.map((ftp) => {
            FunctionTypeParameter(ftp.name, internalTreeMap(ftp.wtype, callback, recursionDetector), ftp.optional, ftp.defaultValueType)
          })
          val mappedExpectedReturnType = expectedReturnType.map(internalTreeMap(_, callback, recursionDetector))
          DynamicReturnType(mappedParameters, node, typeGraph, scope, name, mappedExpectedReturnType, resolver)
        case wt => wt
      })
      .apply(weaveType)

    if (mappedType eq weaveType) {
      //If we are returning same type keep old type
      weaveType
    } else {
      WeaveTypeCloneHelper.copyAdditionalTypeInformation(weaveType, mappedType)
      mappedType
    }
  }

  /**
    * Checks recursively whether the callback holds somewhere in the specified type.
    *
    * When the callback is defined for a node, its result is the answer for that node and the node's children
    * are not inspected. Otherwise the node's children (including type-parameter bounds and referenced types)
    * are inspected, and the result is true if any of them yields true.
    *
    * @param weaveType         The type to inspect
    * @param callback          The predicate, tried first on every visited node
    * @param recursionDetector Detector used when following reference types
    * @return True if the callback yields true for some visited node
    */
  def treeExists(weaveType: WeaveType, callback: PartialFunction[WeaveType, Boolean], recursionDetector: RecursionDetector[Boolean]): Boolean = {
    callback
      .orElse[WeaveType, Boolean]({
        case ObjectType(properties, _, _) => {
          properties.exists(treeExists(_, callback, recursionDetector))
        }
        case KeyValuePairType(key, value, _, _) => {
          treeExists(key, callback, recursionDetector) || treeExists(value, callback, recursionDetector)
        }
        case rt: ReferenceType => {
          recursionDetector.resolve(rt, (referencedType) => treeExists(referencedType, callback, recursionDetector))
        }
        case KeyType(name, attrs) => {
          treeExists(name, callback, recursionDetector) ||
            attrs.exists(treeExists(_, callback, recursionDetector))
        }
        case NameValuePairType(name, value, _) => {
          treeExists(name, callback, recursionDetector) ||
            treeExists(value, callback, recursionDetector)
        }
        case TypeType(t) => {
          treeExists(t, callback, recursionDetector)
        }
        case ArrayType(of) => {
          treeExists(of, callback, recursionDetector)
        }
        case UnionType(of) => {
          of.exists(treeExists(_, callback, recursionDetector))
        }
        case IntersectionType(of) => {
          of.exists(treeExists(_, callback, recursionDetector))
        }
        case FunctionType(_, params, returnType, _, _, _) => {
          params.exists((pt) => treeExists(pt.wtype, callback, recursionDetector)) ||
            treeExists(returnType, callback, recursionDetector)
        }
        case TypeParameter(_, top, bottom, _, _) => {
          top.exists(treeExists(_, callback, recursionDetector)) ||
            bottom.exists(treeExists(_, callback, recursionDetector))
        }
        case _ => false
      })
      .apply(weaveType)
  }

  /** Returns true if the type contains, at any depth, an array whose item type is Nothing. */
  def containsArrayOfNothing(weaveType: WeaveType): Boolean = {
    exists(weaveType, {
      case ArrayType(_: NothingType) => true
      case _                         => false
    })
  }

  /** Returns true if the type contains a TypeParameter at any depth. */
  def containsTypeParameter(weaveType: WeaveType): Boolean = {
    exists(weaveType, {
      case _: TypeParameter => true
      case _                => false
    })
  }

  /** Returns true if the type contains a DynamicReturnType at any depth. */
  def containsDynamicReturnType(weaveType: WeaveType): Boolean = {
    exists(weaveType, {
      case _: DynamicReturnType => true
      case _                    => false
    })
  }

  /**
    * Checks whether two types are structurally equal.
    *
    * References on the left are resolved directly. References on the right are followed through the
    * recursion detector, whose short-circuit result (true by default) is used when a cycle is reached.
    * Intersections are simplified with `resolveIntersection` before they are compared.
    *
    * @param left              The first type to compare
    * @param right             The second type to compare
    * @param recursionDetector Detector used when following reference types on the right-hand side
    * @return True if both types are structurally equal
    */
  def equalsWith(left: WeaveType, right: WeaveType, recursionDetector: RecursionDetector[Boolean] = RecursionDetector((_, _) => true)): Boolean = {
    (left eq right) || {
      left match {
        case lt: ReferenceType => {
          equalsWith(lt.resolveType(), right, recursionDetector)
        }
        case IntersectionType(of) =>
          val resolvedLeft = resolveIntersection(of)
          resolvedLeft match {
            case lit: IntersectionType => {
              right match {
                case rit: IntersectionType => {
                  resolveIntersection(rit.of) match {
                    case resolvedRightIT: IntersectionType => {
                      // Both sides remain irreducible intersections: compare their members pairwise, in order.
                      compareSeq(resolvedRightIT.of, lit.of, recursionDetector)
                    }
                    case _ => {
                      false
                    }
                  }
                }
                case _ => false
              }
            }
            case _ => {
              equalsWith(resolvedLeft, right, recursionDetector)
            }
          }

        case _ => {
          right match {
            case IntersectionType(of) => {
              val resolvedIntersection = resolveIntersection(of)
              resolvedIntersection match {
                case _: IntersectionType => false
                case _                   => equalsWith(left, resolvedIntersection, recursionDetector)
              }
            }
            case rt: ReferenceType => {
              recursionDetector.resolve(rt, (referencedType) => equalsWith(left, referencedType, recursionDetector))
            }
            case ObjectType(rProperties, rClose, rOrdered) => {
              left match {
                case ObjectType(lProperties, lClose, lOrdered) => {
                  if (rClose == lClose && rOrdered == lOrdered) {
                    compareSeq(lProperties, rProperties, recursionDetector)
                  } else {
                    false
                  }
                }
                case _ => false
              }
            }
            case KeyValuePairType(rKey, rValue, rOptional, rRepeated) => {
              left match {
                case KeyValuePairType(lKey, lValue, lOptional, lRepeated) => {
                  rOptional == lOptional &&
                    rRepeated == lRepeated &&
                    equalsWith(lKey, rKey, recursionDetector) &&
                    equalsWith(lValue, rValue, recursionDetector)
                }
                case _ => false
              }
            }
            case KeyType(rName, rAttributes) => {
              left match {
                case KeyType(lName, lAttributes) => {
                  equalsWith(lName, rName, recursionDetector) &&
                    compareSeq(lAttributes, rAttributes, recursionDetector)
                }
                case _ => false
              }
            }
            case NameValuePairType(expectedName, expectedValue, rOptional) => {
              left match {
                case NameValuePairType(assignedName, assignedValue, lOptional) => {
                  (rOptional == lOptional) && equalsWith(assignedName, expectedName, recursionDetector) && equalsWith(assignedValue, expectedValue, recursionDetector)
                }
                case _ => false
              }
            }
            case NameType(rName) => {
              left match {
                case NameType(lName) => {
                  rName.isDefined == lName.isDefined &&
                    (rName.isEmpty || rName.get.equals(lName.get))
                }
                case _ => false
              }
            }
            case ArrayType(rType) => {
              left match {
                case ArrayType(lType) => equalsWith(lType, rType, recursionDetector)
                case _                => false
              }
            }
            case UnionType(rTypes) => {
              left match {
                case UnionType(lTypes) => {
                  compareSeq(lTypes, rTypes, recursionDetector)
                }
                case _ => false
              }
            }
            case TypeType(rTypeType) => {
              left match {
                case TypeType(lTypeType) => equalsWith(lTypeType, rTypeType, recursionDetector)
                case _                   => false
              }
            }
            case rtp: TypeParameter => {
              // Only the bottom bounds of the two type parameters are compared.
              left match {
                case ltp: TypeParameter => {
                  val isDefined = rtp.bottom.isDefined == ltp.bottom.isDefined
                  if (isDefined && rtp.bottom.isDefined) {
                    equalsWith(ltp.bottom.get, rtp.bottom.get, recursionDetector)
                  } else {
                    isDefined
                  }
                }
                case _ => false
              }
            }
            case FunctionType(_, rArguments, rReturnType, rOverloads, _, _) => {
              // Type parameters and names are not compared; overloaded functions compare only their overloads.
              left match {
                case FunctionType(_, lArguments, lReturnType, lOverloads, _, _) => {
                  if (rOverloads.isEmpty && lOverloads.isEmpty) {
                    if (equalsWith(rReturnType, lReturnType, recursionDetector) && rArguments.size == lArguments.size) {
                      rArguments.isEmpty || rArguments.zip(lArguments).forall((t) => equalsWith(t._1.wtype, t._2.wtype, recursionDetector))
                    } else {
                      false
                    }
                  } else {
                    compareSeq(lOverloads, rOverloads, recursionDetector)
                  }
                }
                case _ => false
              }
            }
            case StringType(rvalue) => {
              left match {
                case StringType(lvalue) => rvalue == lvalue
                case _                  => false
              }
            }
            case NumberType(rvalue) => {
              left match {
                case NumberType(lvalue) => rvalue == lvalue
                case _                  => false
              }
            }
            case BooleanType(rvalue, _) => {
              left match {
                case BooleanType(lvalue, _) => rvalue == lvalue
                case _                      => false
              }
            }
            // Any other type: equal when `right` is an instance of the left-hand side's runtime class.
            case _ => left.getClass.isInstance(right)
          }
        }
      }
    }
  }

  // Pairwise, order-sensitive structural comparison of two type sequences of the same length.
  private def compareSeq(lTypes: Seq[WeaveType], rTypes: Seq[WeaveType], recursionDetector: RecursionDetector[Boolean]) = {
    (rTypes.size == lTypes.size) &&
      (rTypes.isEmpty || lTypes.zip(rTypes).forall((lrType) => equalsWith(lrType._1, lrType._2, recursionDetector)))
  }

  /**
    * Traverses the WeaveType collecting all the items that the callback returns for each visited node.
    * For a reference type without type arguments, the callback is called once on the referenced type, which is
    * not traversed any further. For a reference type with type arguments, the type arguments are traversed.
    *
    * @param weaveType the type to traverse
    * @param callback  The callback that returns the items to collect for each node
    * @tparam T The type of the collected items
    * @return The Seq with all the collected items
    */
  def shallowCollectAll[T](weaveType: WeaveType, callback: (WeaveType) => Seq[T]): Seq[T] = {
    flatMap(weaveType, callback, RecursionDetector[Seq[T]]((_, _) => Seq[T]()))
  }

  private def flatMap[T](weaveType: WeaveType, callback: (WeaveType) => Seq[T], stack: RecursionDetector[Seq[T]]): Seq[T] = {
    weaveType match {
      case ot @ ObjectType(properties, _, _) => {
        callback(ot) ++ properties.flatMap(flatMap(_, callback, stack))
      }
      case kvt @ KeyValuePairType(key, value, _, _) => {
        callback(kvt) ++ flatMap(key, callback, stack) ++ flatMap(value, callback, stack)
      }
      case simpleReferenceType: SimpleReferenceType => {
        val baseResult = callback(simpleReferenceType)
        val recResult = if (simpleReferenceType.typeParams.isEmpty) {
          // We treat it as a simple node as we don't want to follow reference
          stack.resolve(simpleReferenceType, (referencedType) => callback(referencedType))
        } else {
          simpleReferenceType.typeParams.getOrElse(Seq.empty)
            .flatMap((tp) => flatMap(tp, callback, stack))
        }
        baseResult ++ recResult
      }
      case tst: TypeSelectorType => {
        callback(tst) ++ flatMap(tst.referencedType, callback, stack)
      }
      case kt @ KeyType(name, attrs) => {
        callback(kt) ++ flatMap(name, callback, stack) ++ attrs.flatMap((nvp) => flatMap(nvp, callback, stack))
      }
      case nvpt @ NameValuePairType(name, value, _) => {
        callback(nvpt) ++ flatMap(name, callback, stack) ++ flatMap(value, callback, stack)
      }
      case tt @ TypeType(t)          => callback(tt) ++ flatMap(t, callback, stack)
      case at @ ArrayType(of)        => callback(at) ++ flatMap(of, callback, stack)
      case ut @ UnionType(of)        => callback(ut) ++ of.flatMap(flatMap(_, callback, stack))
      case it @ IntersectionType(of) => callback(it) ++ of.flatMap(flatMap(_, callback, stack))
      case ft @ FunctionType(_, params, returnType, overloads, _, _) => {
        val baseResult = callback(ft)
        val recResult = if (overloads.isEmpty) {
          params.flatMap((param) => flatMap(param.wtype, callback, stack)) ++
            flatMap(returnType, callback, stack)
        } else {
          overloads.flatMap(flatMap(_, callback, stack))
        }
        baseResult ++ recResult
      }
      case _ => callback(weaveType)
    }
  }

  /**
    * Returns true if the callback holds for the type itself or for any type nested inside it.
    * References are followed through `stack`; when a cycle is reached, the detector's short-circuit
    * result is used (false by default).
    *
    * @param weaveType The type to inspect
    * @param callback  The predicate, evaluated on the node before its children
    * @param stack     Detector used when following reference types
    * @return True if the predicate holds for some visited node
    */
  def exists(weaveType: WeaveType, callback: (WeaveType) => Boolean, stack: RecursionDetector[Boolean] = RecursionDetector((_, _) => false)): Boolean = {
    callback(weaveType) || {
      weaveType match {
        case ObjectType(properties, _, _) => {
          properties.exists(exists(_, callback, stack))
        }
        case KeyValuePairType(key, value, _, _) => {
          exists(key, callback, stack) || exists(value, callback, stack)
        }
        case rt: ReferenceType => {
          stack.resolve(rt, (referencedType) => exists(referencedType, callback, stack))
        }
        case KeyType(name, attrs) => {
          exists(name, callback, stack) || attrs.exists((nvp) => exists(nvp, callback, stack))
        }
        case NameValuePairType(name, value, _) => {
          exists(name, callback, stack) || exists(value, callback, stack)
        }
        case TypeType(t)            => exists(t, callback, stack)
        case drt: DynamicReturnType => drt.typeParameters.exists((param) => exists(param.wtype, callback, stack))
        case ArrayType(of)          => exists(of, callback, stack)
        case UnionType(of)          => of.exists(exists(_, callback, stack))
        case IntersectionType(of)   => of.exists(exists(_, callback, stack))
        case FunctionType(_, params, returnType, overloads, _, _) => {
          if (overloads.isEmpty) {
            params.exists((param) => exists(param.wtype, callback, stack)) ||
              exists(returnType, callback, stack)
          } else {
            overloads.exists(exists(_, callback, stack))
          }
        }
        case _ => false
      }
    }
  }

}

/**
  * Guards traversals that follow reference types. It keeps the referenced types currently being processed,
  * so a cycle can be detected, and memoizes the result computed for every referenced type it finished.
  *
  * @param recursionShortCircuit Called when a reference points to a type that is still being processed. It
  *                              receives the reference's name and a thunk that looks up the memoized result.
  * @param stack                 Referenced types currently being processed (most recent on top)
  * @tparam T The result produced for each referenced type
  */
class RecursionDetector[T](recursionShortCircuit: (NameIdentifier, () => T) => T, val stack: mutable.Stack[WeaveType] = mutable.Stack()) {

  /** Memoized results, keyed by the identity of the referenced type. */
  val references: IdentityHashMap[WeaveType, T] = IdentityHashMap()

  /** Returns true if this exact instance (compared with `eq`) is currently being processed. */
  def alreadyInStack(id: WeaveType): Boolean = {
    stack.exists(candidate => candidate eq id)
  }

  /**
    * Resolves `rt` and returns the result for its referenced type. A memoized result is reused when
    * present. If the referenced type is still being processed, the short-circuit callback decides the
    * result. Otherwise `processor` runs on the referenced type and its result is memoized.
    */
  def resolve(rt: ReferenceType, processor: (WeaveType) => T): T = {
    val target = rt.resolveType()
    if (references.contains(target)) {
      references(target)
    } else if (alreadyInStack(target)) {
      recursionShortCircuit(rt.nameIdentifier(), () => references(target))
    } else {
      stack.push(target)
      val computed = processor(target)
      val finished: WeaveType = stack.pop()
      references.+=((finished, computed))
      computed
    }
  }
}

object RecursionDetector {

  /** Creates a detector that starts with an empty processing stack. */
  def apply[T](recursiveResultCallBack: (NameIdentifier, () => T) => T): RecursionDetector[T] =
    new RecursionDetector[T](recursiveResultCallBack)

  /**
    * Creates a detector whose processing stack initially holds every type on `parent`'s stack, so cycles
    * through those types are detected. The new detector starts with its own, empty results memo.
    */
  def withParent[T](recursiveResultCallBack: (NameIdentifier, () => T) => T, parent: RecursionDetector[T]): RecursionDetector[T] = {
    val inherited = new mutable.Stack[WeaveType]()
    inherited.pushAll(parent.stack)
    new RecursionDetector[T](recursiveResultCallBack, stack = inherited)
  }
}

/**
  * Recursion detector for traversals that walk two types at the same time. Pairs are tracked using their
  * reference-resolved members, which are compared by identity (see WeaveTypePair).
  *
  * @param recursionShortCircuit Called when a pair is reached again while it is still being processed. It
  *                              receives the pair as given to `resolve` (references not resolved) and a thunk
  *                              that looks up the stored result for the resolved pair.
  * @param storeResults          When true, the result of every processed pair is stored and reused
  * @tparam T The result produced for each pair
  */
class DoubleRecursionDetector[T](val recursionShortCircuit: (WeaveTypePair, () => T) => T, val storeResults: Boolean = true) {

  // Results of processed pairs; only filled when `storeResults` is enabled.
  private val resultForReference: mutable.HashMap[WeaveTypePair, T] = mutable.HashMap()
  // Pairs currently being processed, most recent on top.
  private val stack: mutable.Stack[WeaveTypePair] = mutable.Stack()

  /** Returns true if the given pair is currently being processed. */
  def alreadyInStack(id: WeaveTypePair): Boolean = {
    stack.contains(id)
  }

  /**
    * Resolves any reference members of the pair and returns the result for the resolved pair. A stored
    * result is reused when present. If the pair is still being processed, the short-circuit callback
    * decides the result. Otherwise `processor` runs on the resolved pair.
    */
  def resolve(leftType: WeaveType, rightType: WeaveType, processor: (WeaveTypePair) => T): T = {
    def dereference(candidate: WeaveType): WeaveType = candidate match {
      case reference: ReferenceType => reference.resolveType()
      case plain: WeaveType         => plain
    }

    val resolvedPair = new WeaveTypePair(dereference(leftType), dereference(rightType))
    if (resultForReference.contains(resolvedPair)) {
      lookup(resolvedPair)
    } else if (alreadyInStack(resolvedPair)) {
      recursionShortCircuit(new WeaveTypePair(leftType, rightType), () => lookup(resolvedPair))
    } else {
      stack.push(resolvedPair)
      val outcome = processor(resolvedPair)
      val finished: WeaveTypePair = stack.pop()
      if (storeResults) {
        resultForReference.update(finished, outcome)
      }
      outcome
    }
  }

  // Returns the stored result for the pair, failing loudly when none has been stored yet.
  private def lookup(pair: WeaveTypePair): T = {
    resultForReference.get(pair) match {
      case Some(found) => found
      case None        => throw new RuntimeException("Unable to find result for : " + pair)
    }
  }

}

object DoubleRecursionDetector {

  /** Creates a detector that stores and reuses the result of every processed pair. */
  def apply[T](recursiveResultCallBack: (WeaveTypePair, () => T) => T): DoubleRecursionDetector[T] =
    new DoubleRecursionDetector[T](recursiveResultCallBack, storeResults = true)
}

/**
  * A pair of types with identity semantics. Two pairs are equal only when both their left and right members
  * are the very same instances, and the hash code is derived from the members' identity hash codes.
  *
  * @param left  The left-hand type
  * @param right The right-hand type
  */
class WeaveTypePair(val left: WeaveType, val right: WeaveType) {

  override def toString: String = s"l = ${left.toString} ,  r = $right"

  /**
    * Returns `left` if it is a reference type; otherwise casts `right` to a reference type. The cast throws
    * ClassCastException if `right` is not one.
    */
  def ref(): ReferenceType = left match {
    case leftRef: ReferenceType => leftRef
    case _                      => right.asInstanceOf[ReferenceType]
  }

  override def equals(obj: Any): Boolean = obj match {
    case other: WeaveTypePair => (other.left eq left) && (other.right eq right)
    case _                    => false
  }

  override def hashCode(): Int = 31 * System.identityHashCode(left) + System.identityHashCode(right)
}