package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.scope.Reference
import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.{ BooleanType, EdgeLabels, TypeNode, WeaveType, WeaveTypeResolutionContext, WeaveTypeResolver }
import org.mule.weave.v2.utils.DataGraphDotEmitter

import scala.collection.{ Seq, breakOut }

/**
  * Resolver that re-types a node's original type using the constraints carried by a branch condition.
  *
  * The node's type is taken from its single `ORIGINAL_TYPE` incoming edge. When the node also has exactly
  * one `BRANCH_CONSTRAINT` incoming edge whose type is a [[BooleanType]], that boolean's constraints are
  * used to refine the original type for `ref`.
  *
  * @param branch `true` to apply the constraints' positive refinement (`enhancePositive`),
  *               `false` to apply the negative one (`enhanceNegative`)
  * @param ref    the reference whose type is being refined
  */
case class RetyperResolver(branch: Boolean, ref: Reference) extends WeaveTypeResolver {

  /**
    * Resolves the (possibly refined) type of `node`.
    *
    * @return the original type refined by the branch constraint when one applies; the unrefined original
    *         type otherwise; or `None` when the original type cannot be determined outside strict mode
    * @throws RuntimeException in strict mode, when the node does not have exactly one `ORIGINAL_TYPE` edge
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val incoming = node.incomingEdges()
    val originalTypeEdges = incoming.filter(_.label.contains(EdgeLabels.ORIGINAL_TYPE))
    val branchEdges = incoming.filter(_.label.contains(EdgeLabels.BRANCH_CONSTRAINT))

    val originalType: Option[WeaveType] = originalTypeEdges match {
      case Seq(edge) => edge.mayBeIncomingType()
      case _ =>
        if (ctx.currentParsingContext.strictMode) {
          // Optional diagnostic: dump the whole type graph before failing.
          if (sys.props.contains("type_check_debug")) {
            println(DataGraphDotEmitter.print(ctx.rootGraph, name = "Failure_Graph"))
          }
          // Report the count that actually failed the check (original-type edges), not the total edge count.
          throw new RuntimeException(s"Retyper resolver only works with nodes with exactly one original type edge but found '${originalTypeEdges.size}' with node ${node.astNode} at ${node.astNode.location().locationString}.")
        } else {
          None
        }
    }

    (originalType, branchEdges) match {
      case (Some(original), Seq(branchEdge)) =>
        branchEdge.mayBeIncomingType() match {
          case Some(BooleanType(_, constraints)) =>
            val refined =
              if (branch) constraints.enhancePositive(ref, original, ctx)
              else constraints.enhanceNegative(ref, original, ctx)
            Some(refined)
          // Branch edge does not (yet) carry a boolean type: no refinement possible.
          case _ => originalType
        }
      // No original type, or not exactly one branch-constraint edge: leave the type untouched.
      case _ => originalType
    }
  }

  /**
    * Propagates the incoming expected type, unchanged, to every incoming edge of `node`.
    *
    * @return one `(edge, expectedType)` pair per incoming edge, or an empty sequence when there is no
    *         incoming expected type
    */
  override def resolveExpectedType(node: TypeNode, incomingExpectedType: Option[WeaveType], ctx: WeaveTypeResolutionContext): Seq[(Edge, WeaveType)] = {
    incomingExpectedType match {
      case Some(expected) => node.incomingEdges().map((_, expected))
      case None           => Seq()
    }
  }
}
