package org.mule.weave.v2.runtime.core.operator.relational

import org.mule.weave.v2.core.functions.BinaryFunctionValue
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.capabilities.UnknownLocationCapable
import org.mule.weave.v2.model.types.AnyType
import org.mule.weave.v2.model.types.DateTimeType
import org.mule.weave.v2.model.types.LocalDateTimeType
import org.mule.weave.v2.model.types.LocalDateType
import org.mule.weave.v2.model.types.LocalTimeType
import org.mule.weave.v2.model.types.StringType
import org.mule.weave.v2.model.types.TimeType
import org.mule.weave.v2.model.values.BooleanValue
import org.mule.weave.v2.model.values.DateTimeValue
import org.mule.weave.v2.model.values.LocalDateTimeValue
import org.mule.weave.v2.model.values.TimeValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.runtime.core.exception.InvalidComparisonException

import java.time.ZoneId
import java.time.ZoneOffset

/**
  * Shared operand-reconciliation logic for relational operators.
  *
  * Both operands are accepted as `AnyType`. Before the concrete comparison in [[doCompare]] runs,
  * the operands are brought to a common base type:
  *  - identical base types are compared as-is;
  *  - a `String` operand is coerced to the other operand's type (carrying that operand's schema);
  *  - `LocalDateTime` vs `DateTime`: the local value is placed in the `GMT-0` zone;
  *  - `LocalDate` vs `DateTime`/`LocalDateTime`: the date is widened to a `LocalDateTime` at 00:00
  *    and the dispatch is repeated;
  *  - `LocalTime` vs `Time`: the local value is given the UTC offset.
  * Every other combination raises an [[InvalidComparisonException]].
  */
trait BaseRelationalOperator extends BinaryFunctionValue {

  override val L = AnyType

  override val R = AnyType

  override def doExecute(leftValue: L.V, rightValue: R.V)(implicit ctx: EvaluationContext): Value[_] = {
    val leftType = leftValue.valueType.baseType
    val rightType = rightValue.valueType.baseType

    if (leftType == rightType) {
      doCompare(leftValue, rightValue)
    } else if (leftType eq StringType) {
      // The String side adopts the other side's type, honouring that side's schema.
      val target = rightValue.valueType.withSchema(rightValue.schema)
      doCompare(target.coerce(leftValue), rightValue)
    } else if (rightType eq StringType) {
      val target = leftValue.valueType.withSchema(leftValue.schema)
      doCompare(leftValue, target.coerce(rightValue))
    } else if ((leftType eq LocalDateTimeType) && (rightType eq DateTimeType)) {
      doCompare(localDateTimeToDateTime(leftValue), rightValue)
    } else if ((leftType eq DateTimeType) && (rightType eq LocalDateTimeType)) {
      doCompare(leftValue, localDateTimeToDateTime(rightValue))
    } else if ((leftType eq LocalDateType) && ((rightType eq DateTimeType) || (rightType eq LocalDateTimeType))) {
      // Re-dispatch: the widened LocalDateTime may now hit one of the branches above.
      doExecute(localDateToLocalDateTime(leftValue), rightValue)
    } else if ((rightType eq LocalDateType) && ((leftType eq DateTimeType) || (leftType eq LocalDateTimeType))) {
      doExecute(leftValue, localDateToLocalDateTime(rightValue))
    } else if ((leftType eq TimeType) && (rightType eq LocalTimeType)) {
      doCompare(leftValue, localTimeToTime(rightValue))
    } else if ((leftType eq LocalTimeType) && (rightType eq TimeType)) {
      doCompare(localTimeToTime(leftValue), rightValue)
    } else {
      throw new InvalidComparisonException(leftValue.valueType, rightValue.valueType, location())
    }
  }

  /**
    * Widens a LocalDate value into a LocalDateTime value at 00:00.
    *
    * @param value an operand whose base type is LocalDate
    */
  private def localDateToLocalDateTime(value: Value[_])(implicit ctx: EvaluationContext) = {
    val date = LocalDateType.coerce(value).evaluate
    LocalDateTimeValue(date.atTime(0, 0), UnknownLocationCapable)
  }

  /**
    * Interprets a LocalDateTime value in the `GMT-0` zone, yielding a DateTime value.
    *
    * NOTE(review): `localTimeToTime` uses `ZoneOffset.UTC` instead. Both are zero-offset, but the
    * resulting zone ids differ; confirm whether `doCompare` is sensitive to that before unifying.
    *
    * @param value an operand whose base type is LocalDateTime
    */
  private def localDateTimeToDateTime(value: Value[_])(implicit ctx: EvaluationContext) = {
    val local = LocalDateTimeType.coerce(value).evaluate
    DateTimeValue(local.atZone(ZoneId.of("GMT-0")), UnknownLocationCapable)
  }

  /**
    * Attaches the UTC offset to a LocalTime value, yielding a Time value.
    *
    * @param value an operand whose base type is LocalTime
    */
  private def localTimeToTime(value: Value[_])(implicit ctx: EvaluationContext) = {
    val local = LocalTimeType.coerce(value).evaluate
    TimeValue(local.atOffset(ZoneOffset.UTC), UnknownLocationCapable)
  }

  /**
    * Performs the operator-specific comparison once [[doExecute]] has reconciled the operand types.
    *
    * @param leftValue  the (possibly coerced/converted) left operand
    * @param rightValue the (possibly coerced/converted) right operand
    * @return the boolean outcome of the comparison
    */
  protected def doCompare(leftValue: Value[_], rightValue: Value[_])(implicit ctx: EvaluationContext): BooleanValue

}
