package org.mule.weave.v2.ts.resolvers

import org.mule.weave.v2.grammar.AdditionOpId
import org.mule.weave.v2.grammar.AllAttributesSelectorOpId
import org.mule.weave.v2.grammar.AllSchemaSelectorOpId
import org.mule.weave.v2.grammar.AttributeValueSelectorOpId
import org.mule.weave.v2.grammar.BinaryOpIdentifier
import org.mule.weave.v2.grammar.DescendantsSelectorOpId
import org.mule.weave.v2.grammar.DivisionOpId
import org.mule.weave.v2.grammar.DynamicSelectorOpId
import org.mule.weave.v2.grammar.FilterSelectorOpId
import org.mule.weave.v2.grammar.GreaterOrEqualThanOpId
import org.mule.weave.v2.grammar.GreaterThanOpId
import org.mule.weave.v2.grammar.LeftShiftOpId
import org.mule.weave.v2.grammar.LessOrEqualThanOpId
import org.mule.weave.v2.grammar.LessThanOpId
import org.mule.weave.v2.grammar.MetadataAdditionOpId
import org.mule.weave.v2.grammar.MetadataInjectorOpId
import org.mule.weave.v2.grammar.MinusOpId
import org.mule.weave.v2.grammar.MultiAttributeValueSelectorOpId
import org.mule.weave.v2.grammar.MultiValueSelectorOpId
import org.mule.weave.v2.grammar.MultiplicationOpId
import org.mule.weave.v2.grammar.NamespaceSelectorOpId
import org.mule.weave.v2.grammar.NotOpId
import org.mule.weave.v2.grammar.ObjectKeyValueSelectorOpId
import org.mule.weave.v2.grammar.OpIdentifier
import org.mule.weave.v2.grammar.RangeSelectorOpId
import org.mule.weave.v2.grammar.RightShiftOpId
import org.mule.weave.v2.grammar.SchemaValueSelectorOpId
import org.mule.weave.v2.grammar.SimilarOpId
import org.mule.weave.v2.grammar.SubtractionOpId
import org.mule.weave.v2.grammar.UnaryOpIdentifier
import org.mule.weave.v2.grammar.ValueSelectorOpId
import org.mule.weave.v2.sdk.SystemFunctionDefinitions
import org.mule.weave.v2.ts.Edge
import org.mule.weave.v2.ts.EdgeLabels
import org.mule.weave.v2.ts.FunctionType
import org.mule.weave.v2.ts.TypeNode
import org.mule.weave.v2.ts.WeaveType
import org.mule.weave.v2.ts.WeaveTypeResolutionContext
import org.mule.weave.v2.ts.WeaveTypeResolver
import org.mule.weave.v2.ts.resolvers.FunctionCallNodeResolver.calculateExpectedTypeByParameter

/**
  * Type resolver for operator nodes.
  *
  * The operator identified by `opIdentifier` is mapped to the [[FunctionType]] of a system function
  * declared in [[SystemFunctionDefinitions]]. Both return-type and expected-type resolution are then
  * delegated to [[FunctionCallNodeResolver]], so an operator is typed like a call to that function.
  *
  * @param opIdentifier the unary or binary operator handled by this resolver
  */
class OpNodeTypeResolver(opIdentifier: OpIdentifier) extends WeaveTypeResolver {

  /**
    * Resolves the result type of the operator.
    *
    * The types arriving on the node's `ARGUMENT` edges are used as call arguments for the
    * operator's backing function type.
    *
    * @param node the operator node being resolved
    * @param ctx  the type-resolution context
    * @return whatever [[FunctionCallNodeResolver.resolveReturnType]] yields for those arguments
    */
  override def resolveReturnType(node: TypeNode, ctx: WeaveTypeResolutionContext): Option[WeaveType] = {
    val argTypes: Seq[WeaveType] = node.incomingTypes(EdgeLabels.ARGUMENT)
    // NOTE(review): the third argument is passed as an empty Seq; presumably this means
    // "no explicit type arguments" for the operator call — confirm against FunctionCallNodeResolver.
    FunctionCallNodeResolver.resolveReturnType(getFunctionType(), argTypes, Seq(), node, ctx)
  }

  /**
    * Looks up the system function definition that backs `opIdentifier`.
    *
    * The inner matches list only the operators supported here. Any other unary or binary
    * operator, or an identifier that is neither, fails with a `MatchError`.
    *
    * @return the function type used to type-check this operator
    */
  private def getFunctionType(): FunctionType = {
    opIdentifier match {
      case unary: UnaryOpIdentifier =>
        unary match {
          case MinusOpId                 => SystemFunctionDefinitions.minusUnary
          case AllSchemaSelectorOpId     => SystemFunctionDefinitions.allSchemaSelector
          case DescendantsSelectorOpId   => SystemFunctionDefinitions.descendantsSelectorOps
          case NotOpId                   => SystemFunctionDefinitions.notUnary
          case AllAttributesSelectorOpId => SystemFunctionDefinitions.allAttributes
          case NamespaceSelectorOpId     => SystemFunctionDefinitions.namespaceSelector
        }
      case binary: BinaryOpIdentifier =>
        binary match {
          // All ordering/similarity comparisons share a single comparator signature.
          case GreaterThanOpId | GreaterOrEqualThanOpId | LessThanOpId | LessOrEqualThanOpId | SimilarOpId =>
            SystemFunctionDefinitions.comparator
          case AdditionOpId                    => SystemFunctionDefinitions.addition
          case SubtractionOpId                 => SystemFunctionDefinitions.minusOps
          case DivisionOpId                    => SystemFunctionDefinitions.divide
          case MultiplicationOpId              => SystemFunctionDefinitions.multiply
          case RightShiftOpId                  => SystemFunctionDefinitions.rightShift
          case LeftShiftOpId                   => SystemFunctionDefinitions.leftShift
          case ValueSelectorOpId               => SystemFunctionDefinitions.valueSelectorOps
          case MultiValueSelectorOpId          => SystemFunctionDefinitions.multiValueSelectorOps
          case AttributeValueSelectorOpId      => SystemFunctionDefinitions.attributeSelectorOps
          case MultiAttributeValueSelectorOpId => SystemFunctionDefinitions.multiAttributeSelectorOps
          case DynamicSelectorOpId             => SystemFunctionDefinitions.dynamicSelectorOps
          case SchemaValueSelectorOpId         => SystemFunctionDefinitions.schemaSelector
          case FilterSelectorOpId              => SystemFunctionDefinitions.filterSelector
          case ObjectKeyValueSelectorOpId      => SystemFunctionDefinitions.objectKeyValueSelectorOps
          case RangeSelectorOpId               => SystemFunctionDefinitions.rangeSelectorOps
          case MetadataInjectorOpId            => SystemFunctionDefinitions.metadataInjector
          case MetadataAdditionOpId            => SystemFunctionDefinitions.metadataAddition
        }
    }
  }

  /**
    * Computes the expected type for each of the operator's parameters and pairs them, in order,
    * with the node's incoming edges.
    *
    * `zip` truncates to the shorter sequence, so surplus edges or parameter types are dropped.
    *
    * @param node                 the operator node
    * @param incomingExpectedType the type expected for the operator's result, if known
    * @param ctx                  the type-resolution context
    * @return (edge, expected type) pairs for the node's incoming edges
    */
  override def resolveExpectedType(node: TypeNode, incomingExpectedType: Option[WeaveType], ctx: WeaveTypeResolutionContext): Seq[(Edge, WeaveType)] = {
    val paramsTypes: Seq[WeaveType] = calculateExpectedTypeByParameter(getFunctionType(), incomingExpectedType, ctx)
    // NOTE(review): this pairs *all* incoming edges, whereas resolveReturnType only reads ARGUMENT
    // edges. It assumes every incoming edge is an argument edge in parameter order — confirm.
    node.incomingEdges().zip(paramsTypes)
  }
}
