package org.mule.weave.v2.module.java

import org.mule.weave.v2.core.functions.SecureQuaternaryFunctionValue
import org.mule.weave.v2.model.DefaultEvaluationContext
import org.mule.weave.v2.model.EvaluationContext
import org.mule.weave.v2.model.capabilities.UnknownLocationCapable
import org.mule.weave.v2.model.service.WeaveRuntimePrivilege
import org.mule.weave.v2.model.structure.KeyValuePair
import org.mule.weave.v2.model.types.ArrayType
import org.mule.weave.v2.model.types.ObjectType
import org.mule.weave.v2.model.types.StringType
import org.mule.weave.v2.model.values.ArrayValue
import org.mule.weave.v2.model.values.ObjectValue
import org.mule.weave.v2.model.values.Value
import org.mule.weave.v2.model.values.wrappers.WrapperValue
import org.mule.weave.v2.module.commons.java.value.JavaSchema
import org.mule.weave.v2.module.commons.java.value.JavaValue
import org.mule.weave.v2.module.commons.java.writer.JavaAdapter
import org.mule.weave.v2.module.pojo.ClassLoaderService
import org.mule.weave.v2.module.pojo.DefaultClassLoaderService
import org.mule.weave.v2.module.pojo.JavaDataFormat
import org.mule.weave.v2.module.pojo.exception.JavaFunctionCallException
import org.mule.weave.v2.module.reader.SourceProvider
import org.mule.weave.v2.module.writer.Writer
import org.mule.weave.v2.parser.exception.LocatableException

import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method

/**
  * Four-argument native function that reflectively calls into a Java class.
  *
  * Arguments, in order:
  *  1. the name of the class to load,
  *  2. the name of the method to call, or the literal `new` to call a constructor,
  *  3. an array whose items are coerced to strings and loaded as the declared parameter classes,
  *  4. an array with the argument values to pass.
  *
  * Declares [[WeaveRuntimePrivilege.JAVA_REFLECTION]] as its required privilege.
  */
class JavaInvokeFunction extends SecureQuaternaryFunctionValue {

  /** Class name. */
  override val First = StringType

  /** Method name (`new` selects a constructor). */
  override val Second = StringType

  /** Declared parameter class names. */
  override val Third = ArrayType

  /** Argument values. */
  override val Forth = ArrayType

  override val requiredPrivilege: WeaveRuntimePrivilege = WeaveRuntimePrivilege.JAVA_REFLECTION

  /**
    * Loads the target class and parameter classes, converts the arguments to Java, performs the
    * reflective call and reads the Java result back as a value.
    *
    * Failures are rethrown as [[JavaFunctionCallException]], except [[LocatableException]]s which
    * propagate untouched. For an [[InvocationTargetException]] the wrapped cause is reported.
    */
  override def onSecureExecution(classNameValue: First.V, secondValue: Second.V, thirdValue: Third.V, forthValue: Forth.V)(implicit ctx: EvaluationContext): Value[_] = {
    val targetClassName: String = classNameValue.evaluate.toString
    val targetMethodName: String = secondValue.evaluate.toString
    val declaredTypeNames = thirdValue.evaluate.toSeq().map(typeName => StringType.coerce(typeName).evaluate.toString)
    try {
      val declaredTypes: Seq[Class[_]] = declaredTypeNames.map(typeName => JavaInvocationHelper.loadClass(typeName))
      val targetClass = JavaInvocationHelper.loadClass(targetClassName)
      val weaveArguments: Seq[Value[_]] = forthValue.evaluate.toSeq()
      //We should cache this lookup to avoid resolving the constructor/method over and over
      val javaResult: AnyRef = targetMethodName match {
        case "new" =>
          val constructor = targetClass.getDeclaredConstructor(declaredTypes: _*)
          val javaArguments = JavaInvocationHelper.transformArgumentsToJava(ctx, declaredTypes, weaveArguments)
          constructor.newInstance(javaArguments: _*).asInstanceOf[AnyRef]
        case _ =>
          val method: Method = targetClass.getDeclaredMethod(targetMethodName, declaredTypes: _*)
          // Arguments for method calls are converted under a fresh context that only shares the service manager.
          val conversionCtx = DefaultEvaluationContext(serviceManager = ctx.serviceManager)
          val javaArguments = JavaInvocationHelper.transformArgumentsToJava(conversionCtx, declaredTypes, weaveArguments)
          // The receiver is null, so only methods that need no instance can be invoked here.
          method.invoke(null, javaArguments: _*)
      }
      val reader = JavaDataFormat.reader(SourceProvider(javaResult))
      reader.read(s"$targetClassName.$targetMethodName")
    } catch {
      case invocationFailure: InvocationTargetException =>
        throw new JavaFunctionCallException(targetMethodName, targetClassName, declaredTypeNames, forthValue, invocationFailure.getCause)
      case locatable: LocatableException =>
        throw locatable
      case other: Exception =>
        throw new JavaFunctionCallException(targetMethodName, targetClassName, declaredTypeNames, forthValue, other)
    }
  }
}

/**
  * Helpers used to load Java classes and to turn values into Java objects of an expected class.
  */
object JavaInvocationHelper {

  /**
    * Converts every argument to the Java class found at the same position in `paramTypes`.
    *
    * Arguments and classes are paired with `zip`, so entries without a counterpart on the other
    * side are not converted.
    */
  def transformArgumentsToJava(ctx: EvaluationContext, paramTypes: Seq[Class[_]], argsSeq: Seq[Value[_]]): Seq[AnyRef] = {
    argsSeq.zip(paramTypes).map {
      case (argument, javaType) => transformToJava(argument, javaType)(ctx)
    }
  }

  /**
    * Coerces `arg` to an object, adapts each entry value to `mapValueClass`, adapts the resulting
    * object to `mapClass` and writes it with a default-settings Java writer.
    *
    * @return the writer result, cast to `AnyRef`
    */
  def transformToJavaMap(arg: Value[_], mapValueClass: Class[_], mapClass: Class[_])(implicit ctx: EvaluationContext): AnyRef = {
    val writer: Writer = JavaDataFormat.defaultSettingsWriter(None)
    val sourceEntries = ObjectType.coerce(arg).evaluate.toIterator()
    val adaptedEntries = sourceEntries.map(entry => KeyValuePair(entry._1, adaptToJavaClass(entry._2, mapValueClass)))
    val adaptedObject = adaptToJavaClass(ObjectValue(adaptedEntries, UnknownLocationCapable), mapClass)
    writeAndGetResult(writer, adaptedObject).asInstanceOf[AnyRef]
  }

  /**
    * Coerces `argValue` to an array, adapts each item to `item`, adapts the resulting array to
    * `collectionClass` and writes it with a default-settings Java writer.
    *
    * @return the writer result, cast to `AnyRef`
    */
  def transformToJavaCollection(argValue: Value[_], item: Class[_], collectionClass: Class[_])(implicit ctx: EvaluationContext): AnyRef = {
    val writer: Writer = JavaDataFormat.defaultSettingsWriter(None)
    val sourceItems = ArrayType.coerce(argValue).evaluate.toIterator()
    val adaptedItems: Iterator[Value[_]] = sourceItems.map(element => adaptToJavaClass(element, item))
    val adaptedArray = adaptToJavaClass(ArrayValue(adaptedItems, UnknownLocationCapable), collectionClass)
    writeAndGetResult(writer, adaptedArray).asInstanceOf[AnyRef]
  }

  /**
    * Converts `argValue` into an instance suitable for `expectedJavaType`.
    *
    * When an adapter is registered for the expected class it performs the conversion; otherwise the
    * value is adapted to the class and written with a Java writer built from that class' loader.
    */
  def transformToJava(argValue: Value[_], expectedJavaType: Class[_])(implicit ctx: EvaluationContext): AnyRef = {
    ValueToJavaAdaptorManager
      .findAdapterForClass(expectedJavaType)
      .fold[AnyRef] {
        val writer: Writer = JavaDataFormat.defaultSettingsWriter(Some(expectedJavaType.getClassLoader))
        val adaptedValue: Value[_] = adaptToJavaClass(argValue, expectedJavaType)
        JavaAdapter.fromScalaToJava(writeAndGetResult(writer, adaptedValue)).asInstanceOf[AnyRef]
      } { adapter =>
        adapter.adaptToJavaValue(argValue, expectedJavaType, ctx)
      }
  }

  /**
    * Attaches a [[JavaSchema]] for `expectedJavaType` to `argValue` when needed.
    *
    * The value is returned untouched when it already carries a compatible class, or when the
    * expected class is `Object` (which gives no useful typing information).
    */
  def adaptToJavaClass(argValue: Value[_], expectedJavaType: Class[_])(implicit ctx: EvaluationContext): Value[_] = {
    val alreadyTyped: Boolean = doesHasClassDefined(argValue, expectedJavaType)
    if (alreadyTyped || classOf[Object].equals(expectedJavaType)) {
      argValue
    } else {
      // Re-coerce to the value's own base type, this time carrying the expected Java class as schema.
      val typedTarget = argValue.valueType(ctx).baseType.withSchema(Some(JavaSchema(expectedJavaType)))
      typedTarget.coerce(argValue)(ctx)
    }
  }

  /**
    * Tells whether `argValue` already carries a class compatible with `expectedJavaType`.
    *
    * Values without a schema never qualify. Java-backed values are checked against their
    * underlying instance, wrappers delegate to the wrapped value, and any other value is checked
    * through the class named in its schema, if that class can be loaded.
    */
  private def doesHasClassDefined(argValue: Value[_], expectedJavaType: Class[_])(implicit ctx: EvaluationContext): Boolean = {
    argValue.schema.exists { schema =>
      argValue match {
        case javaValue: JavaValue[_] =>
          expectedJavaType.isInstance(javaValue.underlying())
        case wrapper: WrapperValue =>
          doesHasClassDefined(wrapper.value, expectedJavaType)
        case _ =>
          schema.`class`.exists { declaredClassName =>
            val classLoaderService = ctx.serviceManager.lookupCustomService(classOf[ClassLoaderService], DefaultClassLoaderService)
            classLoaderService.loadClass(declaredClassName).exists(loadedClass => expectedJavaType.isAssignableFrom(loadedClass))
          }
      }
    }
  }

  /** Writes `value` with `writer` and hands back whatever the writer produced. */
  private def writeAndGetResult(writer: Writer, value: Value[_])(implicit ctx: EvaluationContext) = {
    writer.writeValue(value)(ctx)
    writer.result
  }

  /**
    * Loads `className` through [[ReflectionJavaClassLoaderHelper]], offering the current thread's
    * context class loader.
    */
  def loadClass(className: String)(implicit ctx: EvaluationContext): Class[_] = {
    val contextClassLoader = Thread.currentThread().getContextClassLoader
    ReflectionJavaClassLoaderHelper.loadClass(className, Some(contextClassLoader), UnknownLocationCapable)
  }
}
