package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.ts.BooleanType
import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.EdgeLabels
import org.mule.weave.v2.ts.NothingType
import org.mule.weave.v2.ts.TypeHelper
import org.mule.weave.v2.ts.TypeNode
import org.mule.weave.v2.ts.WeaveType
import org.mule.weave.v2.ts.WeaveTypeResolutionContext
import org.mule.weave.v2.ts.WeaveTypeResolver

import scala.collection.Seq

/**
  * Type resolver for `if / else` expression nodes.
  *
  * The node reads three incoming edges, one per label: [[EdgeLabels.CONDITION]],
  * [[EdgeLabels.IF_LABEL]] and [[EdgeLabels.ELSE_LABEL]]. Only the first edge of each
  * label is used (`.head`).
  * NOTE(review): this assumes each label always has at least one incoming edge. `.head`
  * fails otherwise.
  */
object IfElseResolver extends WeaveTypeResolver {

  /**
    * Resolves the result type of the `if / else` node.
    *
    *  - If neither branch has a resolved incoming type yet, returns `None`.
    *  - If the condition's type is a boolean with a known value (`BooleanType(Some(v), _)`),
    *    only the branch that will be taken matters. Its type is returned if resolved,
    *    otherwise `None`.
    *  - Otherwise both branches may run, so their types are unified. A branch whose
    *    type is not resolved yet contributes [[NothingType]], which lets a partial result
    *    be produced (see [[supportsPartialResolution]]).
    *
    * @param node the `if / else` node being resolved
    * @param ctx  the resolution context (not used by this resolver)
    * @return the resolved type, or `None` if it cannot be determined yet
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val conditionEdge: Edge = node.incomingEdges(EdgeLabels.CONDITION).head
    val ifEdge: Edge = node.incomingEdges(EdgeLabels.IF_LABEL).head
    val elseEdge: Edge = node.incomingEdges(EdgeLabels.ELSE_LABEL).head
    if (ifEdge.incomingTypeDefined() || elseEdge.incomingTypeDefined()) {
      conditionEdge.mayBeIncomingType() match {
        case Some(BooleanType(Some(condition), _)) =>
          // The condition is statically known: only the taken branch contributes.
          definedType(if (condition) ifEdge else elseEdge)
        case _ =>
          // The condition is unknown, so either branch may run. An unresolved branch
          // contributes NothingType so the known branch still yields a (partial) type.
          val ifType: WeaveType = definedType(ifEdge).getOrElse(NothingType())
          val elseType: WeaveType = definedType(elseEdge).getOrElse(NothingType())
          Some(TypeHelper.unify(Seq(ifType, elseType)))
      }
    } else {
      None
    }
  }

  /**
    * Propagates an expected type from this node to its inputs.
    *
    * When an expected type is present:
    *  - the condition edge is expected to be a [[BooleanType]];
    *  - both branch edges receive the node's expected type.
    *
    * When no expected type is present, nothing is propagated.
    *
    * @param node                 the `if / else` node
    * @param incomingExpectedType the type expected of this node, if any
    * @param ctx                  the resolution context (not used by this resolver)
    * @return the (edge, expected type) pairs to propagate, or an empty sequence
    */
  override def resolveExpectedType(node: TypeNode, incomingExpectedType: Option[WeaveType], ctx: WeaveTypeResolutionContext): Seq[(Edge, WeaveType)] = {
    incomingExpectedType.fold(Seq.empty[(Edge, WeaveType)]) { expectedType =>
      val conditionEdge: Edge = node.incomingEdges(EdgeLabels.CONDITION).head
      val ifEdge: Edge = node.incomingEdges(EdgeLabels.IF_LABEL).head
      val elseEdge: Edge = node.incomingEdges(EdgeLabels.ELSE_LABEL).head
      Seq(
        (conditionEdge, BooleanType()),
        (ifEdge, expectedType),
        (elseEdge, expectedType))
    }
  }

  /**
    * Returns `true`: [[resolveReturnType]] can produce a type when only one branch is
    * resolved, with the missing branch treated as [[NothingType]].
    */
  override def supportsPartialResolution(): Boolean = true

  /** The edge's incoming type if it has been resolved, otherwise `None`. */
  private def definedType(edge: Edge): Option[WeaveType] =
    if (edge.incomingTypeDefined()) Some(edge.incomingType()) else None
}
