/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.semantickernel.semanticfunctions;

import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableType;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.hooks.FunctionInvokedEvent;
import com.microsoft.semantickernel.hooks.FunctionInvokingEvent;
import com.microsoft.semantickernel.hooks.KernelHooks;
import com.microsoft.semantickernel.orchestration.FunctionResult;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.semanticfunctions.InputVariable;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionMetadata;
import com.microsoft.semantickernel.semanticfunctions.MethodDetails;
import com.microsoft.semantickernel.semanticfunctions.OutputVariable;
import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
 * A {@link KernelFunction} implemented by reflectively invoking a Java {@link Method} on a target
 * object.
 *
 * <p>Metadata that is not supplied explicitly (function name, description, input parameters and
 * return parameter) is derived from the method signature and its {@link DefineKernelFunction} and
 * {@link KernelFunctionParameter} annotations.
 *
 * @param <T> the type of the value carried by the {@link FunctionResult}
 */
public class KernelFunctionFromMethod<T>
extends KernelFunction<T> {
    private static final Logger LOGGER = LoggerFactory.getLogger(KernelFunctionFromMethod.class);

    /**
     * Value a context variable holds when no input was provided for it; seeing it at invocation
     * time means the variable has not been set and no real default value was specified.
     */
    private static final String NO_INPUT_PROVIDED = "SKFunctionParameters__NO_INPUT_PROVIDED";

    /**
     * Matches the names {@link Parameter#getName()} synthesizes ({@code arg0}, {@code arg1}, ...,
     * {@code arg10}, ...) when the class file carries no parameter names.
     */
    private static final Pattern SYNTHETIC_ARG_NAME = Pattern.compile("arg(\\d+)");

    /** Boxed/primitive pairs that {@link #isPrimative} treats as interchangeable. */
    private static final Class<?>[][] BOXED_PRIMITIVE_PAIRS = {
        {Byte.class, Byte.TYPE},
        {Integer.class, Integer.TYPE},
        {Long.class, Long.TYPE},
        {Double.class, Double.TYPE},
        {Float.class, Float.TYPE},
        {Short.class, Short.TYPE},
        {Boolean.class, Boolean.TYPE},
        {Character.class, Character.TYPE},
    };

    /** Adapter that performs the actual reflective invocation. */
    private final ImplementationFunc<T> function;

    private KernelFunctionFromMethod(ImplementationFunc<T> implementationFunc, @Nullable String pluginName, String functionName, @Nullable String description, @Nullable List<InputVariable> parameters, OutputVariable<?> returnParameter) {
        super(new KernelFunctionMetadata(pluginName, functionName, description, parameters, returnParameter), null);
        this.function = implementationFunc;
    }

    /**
     * Creates a kernel function that invokes {@code method} on {@code target}.
     *
     * @param method the method to invoke
     * @param target the instance {@code method} is invoked on
     * @param pluginName the plugin the function belongs to; may be {@code null}
     * @param functionName the function name; {@code null} or empty uses the method name
     * @param description the description; {@code null} or empty uses
     *     {@link DefineKernelFunction#description()} when that is non-empty
     * @param parameters the input parameters; {@code null} or empty derives them from the method's
     *     parameters
     * @param returnParameter the output variable; {@code null} derives it from the method's return
     *     type and {@link DefineKernelFunction#returnDescription()}
     * @param <T> the type of the value carried by the {@link FunctionResult}
     * @return the new kernel function
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static <T> KernelFunction<T> create(Method method, Object target, @Nullable String pluginName, @Nullable String functionName, @Nullable String description, @Nullable List<InputVariable> parameters, @Nullable OutputVariable<?> returnParameter) {
        MethodDetails methodDetails = KernelFunctionFromMethod.getMethodDetails(functionName, method, target);
        if (description == null || description.isEmpty()) {
            description = methodDetails.getDescription();
        }
        if (parameters == null || parameters.isEmpty()) {
            parameters = methodDetails.getParameters();
        }
        if (returnParameter == null) {
            returnParameter = methodDetails.getReturnParameter();
        }
        return new KernelFunctionFromMethod(methodDetails.getFunction(), pluginName, methodDetails.getName(), description, parameters, returnParameter);
    }

    /**
     * Collects the name, description, invocation adapter, inputs and output of {@code method}.
     *
     * @param functionName explicit function name, or {@code null}/empty to use the method name
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static MethodDetails getMethodDetails(@Nullable String functionName, Method method, Object target) {
        DefineKernelFunction annotation = method.getAnnotation(DefineKernelFunction.class);
        String description = null;
        String returnDescription = null;
        if (annotation != null) {
            // Empty annotation strings are treated as "not specified".
            if (!annotation.description().isEmpty()) {
                description = annotation.description();
            }
            if (!annotation.returnDescription().isEmpty()) {
                returnDescription = annotation.returnDescription();
            }
        }
        String name = functionName == null || functionName.isEmpty() ? method.getName() : functionName;
        return new MethodDetails(
            name,
            description,
            KernelFunctionFromMethod.getFunction(method, target),
            KernelFunctionFromMethod.getParameters(method),
            new OutputVariable(returnDescription, method.getReturnType()));
    }

    /**
     * Builds the adapter that invokes {@code method} on {@code instance}.
     *
     * <p>The adapter runs the invoking hooks (which may replace the arguments), resolves a value for
     * each Java parameter, invokes the method, wraps the result in a {@link FunctionResult} and
     * finally runs the invoked hooks on that result.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> ImplementationFunc<T> getFunction(Method method, Object instance) {
        return (kernel, function, arguments, variableType, invocationContext) -> {
            InvocationContext context = invocationContext == null
                ? InvocationContext.builder().build()
                : invocationContext;
            // Hooks attached to this invocation take precedence over the kernel's global hooks.
            KernelHooks kernelHooks = context.getKernelHooks() != null
                ? context.getKernelHooks()
                : kernel.getGlobalKernelHooks();
            assert (kernelHooks != null) : "getGlobalKernelHooks() should never return null!";

            FunctionInvokingEvent updatedState = kernelHooks.executeHooks(new FunctionInvokingEvent(function, arguments));
            KernelFunctionArguments updatedArguments = updatedState != null ? updatedState.getArguments() : arguments;
            try {
                List<Object> args = Arrays.stream(method.getParameters())
                    .map(KernelFunctionFromMethod.getParameters(method, updatedArguments, kernel, context))
                    .collect(Collectors.toList());

                Mono mono;
                try {
                    // Only a declared return type that IS a Mono (or a subtype) may be cast; the
                    // reverse check would also match Object/Publisher and fail on plain values.
                    mono = Mono.class.isAssignableFrom(method.getReturnType())
                        ? (Mono) method.invoke(instance, args.toArray())
                        : KernelFunctionFromMethod.invokeAsyncFunction(method, instance, args);
                } catch (Exception e) {
                    return Mono.error(new SKException("Function threw an exception: " + method.getName(), e));
                }
                if (mono == null) {
                    return Mono.error(new SKException("Function returned a null Mono: " + method.getName()));
                }

                return mono.map(it -> {
                    if (variableType != null) {
                        if (!variableType.getClazz().isAssignableFrom(it.getClass())) {
                            throw new SKException(String.format(
                                "Return parameter type from %s.%s does not match the expected type %s, actual type was %s",
                                function.getPluginName(), function.getName(),
                                variableType.getClazz().getName(), it.getClass().getName()));
                        }
                        return new FunctionResult<Object>(new ContextVariable<Object>(variableType, it), it);
                    }
                    // No requested type: use the type registered for the declared return type,
                    // otherwise fall back to a default type built around a NoopConverter.
                    Class<?> returnParameterType = function.getMetadata().getOutputVariableType().getType();
                    ContextVariableType contextVariableType = KernelFunctionFromMethod.getContextVariableType(context, returnParameterType);
                    if (contextVariableType == null) {
                        contextVariableType = KernelFunctionFromMethod.getDefaultContextVariableType(returnParameterType);
                    }
                    if (contextVariableType != null) {
                        return new FunctionResult<Object>(new ContextVariable<Object>(contextVariableType, it), it);
                    }
                    throw new SKException(String.format("Return parameter type from %s.%s does not match the expected type %s", function.getPluginName(), function.getName(), it.getClass().getName()));
                }).map(it -> {
                    FunctionInvokedEvent updatedResult = kernelHooks.executeHooks(new FunctionInvokedEvent(function, updatedArguments, it));
                    return updatedResult.getResult();
                });
            } catch (Exception e) {
                return Mono.error(e);
            }
        };
    }

    /**
     * Looks up the context variable type registered for {@code clazz} in the invocation context.
     *
     * @return the registered type, or {@code null} when {@code clazz} is {@code null} or the lookup
     *     fails
     */
    @Nullable
    @SuppressWarnings("unchecked")
    private static <T> ContextVariableType<T> getContextVariableType(InvocationContext invocationContext, @Nullable Class<?> clazz) {
        if (clazz == null) {
            return null;
        }
        try {
            return (ContextVariableType<T>) invocationContext.getContextVariableTypes().getVariableTypeForClass(clazz);
        } catch (SKException | ClassCastException ignored) {
            // Best effort: a failed lookup is not fatal here, the caller falls back to a default type.
            return null;
        }
    }

    /**
     * Builds a context variable type for {@code clazz} that uses a {@code NoopConverter}.
     *
     * @return the new type, or {@code null} when {@code clazz} is {@code null}
     */
    @Nullable
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> ContextVariableType<T> getDefaultContextVariableType(@Nullable Class<?> clazz) {
        if (clazz == null) {
            return null;
        }
        try {
            ContextVariableTypeConverter.NoopConverter noopConverter = new ContextVariableTypeConverter.NoopConverter(clazz);
            return new ContextVariableType(noopConverter, clazz);
        } catch (ClassCastException ignored) {
            // Best effort, mirroring getContextVariableType: no default type can be built.
            return null;
        }
    }

    /**
     * Invokes a synchronous {@code method} lazily, on the bounded-elastic scheduler.
     *
     * @return a Mono emitting the method's return value, or completing empty when the method returns
     *     {@code null} (which includes every {@code void} method)
     */
    private static Mono<Object> invokeAsyncFunction(Method method, Object instance, List<Object> args) {
        // Mono.fromCallable already runs once per subscription and maps a null result to an empty
        // Mono, so no extra defer or null filtering is required.
        return Mono.fromCallable(() -> {
            try {
                return method.invoke(instance, args.toArray());
            } catch (InvocationTargetException e) {
                // Report what the method itself threw rather than the reflective wrapper.
                throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, "Function threw an exception: " + method.getName(), e.getCause());
            } catch (IllegalAccessException e) {
                throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, "Unable to access function " + method.getName(), e);
            }
        }).subscribeOn(Schedulers.boundedElastic());
    }

    /**
     * Returns a resolver mapping each Java parameter of {@code method} to the value it is invoked
     * with: the arguments object itself, the kernel, or a value from {@link #getArgumentValue}.
     */
    private static Function<Parameter, Object> getParameters(Method method, @Nullable KernelFunctionArguments context, Kernel kernel, InvocationContext invocationContext) {
        return parameter -> {
            Class<?> parameterType = parameter.getType();
            if (KernelFunctionArguments.class.isAssignableFrom(parameterType)) {
                return context;
            }
            if (Kernel.class.isAssignableFrom(parameterType)) {
                return kernel;
            }
            return KernelFunctionFromMethod.getArgumentValue(method, context, parameter, kernel, invocationContext);
        };
    }

    /**
     * Resolves the value passed for a single Java parameter.
     *
     * <p>Lookup order: the argument named after the parameter; for single-parameter methods the
     * {@code "input"} argument or the only argument present; then the annotation's default value.
     * For annotated parameters the value is then converted by {@link #convertArgument}.
     *
     * @return the value to pass; may be {@code null}
     * @throws AIException if a required value is missing, the annotated type does not fit the
     *     parameter, or conversion fails
     */
    @Nullable
    private static Object getArgumentValue(Method method, @Nullable KernelFunctionArguments context, Parameter parameter, Kernel kernel, InvocationContext invocationContext) {
        String variableName = KernelFunctionFromMethod.getGetVariableName(parameter);
        ContextVariable<?> arg = context == null ? null : context.get(variableName);

        // A single-parameter function can be fed from "input", or from the only argument supplied.
        if (arg == null && context != null && method.getParameters().length == 1) {
            if (context.containsKey("input")) {
                arg = context.get("input");
            } else if (context.size() == 1) {
                arg = context.values().iterator().next();
            }
        }

        KernelFunctionParameter annotation = parameter.getAnnotation(KernelFunctionParameter.class);

        // Nothing supplied: fall back to the annotation's default value.
        if (arg == null && annotation != null) {
            Class<?> defaultType = annotation.type();
            ContextVariableType<?> cvType = invocationContext.getContextVariableTypes().getVariableTypeForClass(defaultType);
            if (cvType != null) {
                Object defaultValue = cvType.getConverter().fromPromptString(annotation.defaultValue());
                arg = ContextVariable.convert(defaultValue, defaultType, invocationContext.getContextVariableTypes());
            }
            if (arg != null && NO_INPUT_PROVIDED.equals(arg.getValue())) {
                if (!annotation.required()) {
                    return null;
                }
                throw new AIException(AIException.ErrorCodes.INVALID_CONFIGURATION, "Attempted to invoke function " + method.getDeclaringClass().getName() + "." + method.getName() + ". The context variable \"" + variableName + "\" has not been set, and no default value is specified.");
            }
        }

        if (arg == null && SYNTHETIC_ARG_NAME.matcher(variableName).matches()) {
            LOGGER.warn(KernelFunctionFromMethod.formErrorMessage(method, parameter));
        }
        if (arg != null && NO_INPUT_PROVIDED.equals(arg.getValue())) {
            if (SYNTHETIC_ARG_NAME.matcher(parameter.getName()).matches()) {
                throw new AIException(AIException.ErrorCodes.INVALID_CONFIGURATION, KernelFunctionFromMethod.formErrorMessage(method, parameter));
            }
            throw new AIException(AIException.ErrorCodes.INVALID_CONFIGURATION, "Unknown arg " + parameter.getName());
        }
        if (Kernel.class.isAssignableFrom(parameter.getType())) {
            return kernel;
        }
        if (annotation == null || annotation.type() == null) {
            // NOTE(review): un-annotated parameters receive the ContextVariable itself (or null),
            // not its value; this only fits parameters declared as ContextVariable — confirm.
            return arg;
        }
        return KernelFunctionFromMethod.convertArgument(method, parameter, annotation.type(), arg, invocationContext);
    }

    /**
     * Converts {@code arg} into the value passed for a parameter annotated with {@code type}.
     *
     * <p>Tries, in order: returning the raw value when the parameter type accepts it (including
     * boxed/primitive pairs), the variable's own converter, the converter registered for
     * {@code type}, and finally the prompt-string form when {@code type} is {@link String}.
     */
    @Nullable
    private static Object convertArgument(Method method, Parameter parameter, Class<?> type, @Nullable ContextVariable<?> arg, InvocationContext invocationContext) {
        if (!parameter.getType().isAssignableFrom(type)) {
            throw new AIException(AIException.ErrorCodes.INVALID_CONFIGURATION, "Annotation on method: " + method.getName() + " requested conversion to type: " + type.getName() + ", however this cannot be assigned to parameter of type: " + parameter.getType());
        }
        if (arg != null) {
            Class<?> argClass = arg.getType().getClazz();
            if (parameter.getType().isAssignableFrom(argClass) || KernelFunctionFromMethod.isPrimative(argClass, parameter.getType())) {
                return arg.getValue();
            }
            ContextVariableTypeConverter<?> c = arg.getType().getConverter();
            Object converted = c.toObject(invocationContext.getContextVariableTypes(), arg.getValue(), parameter.getType());
            if (converted != null) {
                return converted;
            }
        }
        Object value = arg;
        ContextVariableType<?> converter = invocationContext.getContextVariableTypes().getVariableTypeForClass(type);
        if (converter != null) {
            try {
                value = converter.getConverter().fromObject(arg);
            } catch (NumberFormatException nfe) {
                throw new AIException(AIException.ErrorCodes.INVALID_CONFIGURATION, "Invalid value for " + parameter.getName() + " expected " + type.getSimpleName() + " but got " + arg, nfe);
            }
        }
        if (value == null && type.equals(String.class) && arg != null) {
            ContextVariableTypeConverter<?> c = arg.getType().getConverter();
            value = c.toPromptString(invocationContext.getContextVariableTypes(), arg.getValue());
        }
        return value;
    }

    /**
     * Returns whether {@code argType} and {@code param} are the same primitive kind, each either
     * boxed or unboxed (for example {@code Integer} and {@code int}).
     */
    private static boolean isPrimative(Class<?> argType, Class<?> param) {
        for (Class<?>[] pair : BOXED_PRIMITIVE_PAIRS) {
            boolean argMatches = argType == pair[0] || argType == pair[1];
            boolean paramMatches = param == pair[0] || param == pair[1];
            if (argMatches && paramMatches) {
                return true;
            }
        }
        return false;
    }

    /** Returns the annotated variable name, or the Java parameter name when none is given. */
    private static String getGetVariableName(Parameter parameter) {
        KernelFunctionParameter annotation = parameter.getAnnotation(KernelFunctionParameter.class);
        if (annotation == null || annotation.name() == null || annotation.name().isEmpty()) {
            return parameter.getName();
        }
        return annotation.name();
    }

    /** Explains that a parameter name could not be determined, e.g. a synthetic {@code argN}. */
    private static String formErrorMessage(Method method, Parameter parameter) {
        Matcher matcher = SYNTHETIC_ARG_NAME.matcher(parameter.getName());
        // Callers may reach here with a non-synthetic parameter name; only cite the index if known.
        String argumentNumber = matcher.find()
            ? " this is argument number " + matcher.group(1) + " to the function,"
            : "";
        return "For the function " + method.getDeclaringClass().getName() + "." + method.getName() + ", the unknown parameter name was detected as \"" + parameter.getName() + "\"" + argumentNumber + " this indicates that the argument name for this function was removed during compilation and semantic-kernel is unable to determine the name of the parameter. To support this function the argument must be annotated with @KernelFunctionParameter. Alternatively the function was invoked with a required context variable missing and no default value.";
    }

    /** Derives input-variable metadata for every caller-supplied parameter of {@code method}. */
    private static List<InputVariable> getParameters(Method method) {
        return Arrays.stream(method.getParameters())
            .map(KernelFunctionFromMethod::toKernelParameterMetadata)
            .filter(Objects::nonNull)
            .collect(Collectors.toList());
    }

    /**
     * Converts a Java parameter to input-variable metadata.
     *
     * @return the metadata, or {@code null} for {@link Kernel} and {@link KernelFunctionArguments}
     *     parameters, which are filled in directly rather than exposed as inputs
     */
    @Nullable
    private static InputVariable toKernelParameterMetadata(Parameter parameter) {
        Class<?> type = parameter.getType();
        if (Kernel.class.isAssignableFrom(type) || KernelFunctionArguments.class.isAssignableFrom(type)) {
            return null;
        }
        // Use the same name getArgumentValue looks the value up by (annotation name, else Java name).
        String name = KernelFunctionFromMethod.getGetVariableName(parameter);
        String description = null;
        String defaultValue = null;
        boolean isRequired = true;
        KernelFunctionParameter annotation = parameter.getAnnotation(KernelFunctionParameter.class);
        if (annotation != null) {
            description = annotation.description();
            defaultValue = annotation.defaultValue();
            isRequired = annotation.required();
            type = annotation.type();
        }
        List<String> enumValues = KernelFunctionFromMethod.getEnumOptions(type);
        return InputVariable.build(name, type, description, defaultValue, enumValues, isRequired);
    }

    /**
     * Returns the {@code toString()} form of every constant of {@code type} when it is an enum.
     *
     * @param type the class to inspect
     * @return the constants' string forms, or {@code null} when {@code type} is not an enum
     */
    @Nullable
    public static List<String> getEnumOptions(Class<?> type) {
        if (!type.isEnum()) {
            return null;
        }
        return Arrays.stream(type.getEnumConstants())
            .map(Object::toString)
            .collect(Collectors.toList());
    }

    /** Returns a new, empty {@link Builder}. */
    public static <T> Builder<T> builder() {
        return new Builder<>();
    }

    @Override
    public Mono<FunctionResult<T>> invokeAsync(Kernel kernel, @Nullable KernelFunctionArguments arguments, @Nullable ContextVariableType<T> variableType, @Nullable InvocationContext invocationContext) {
        return this.function.invokeAsync(kernel, this, arguments, variableType, invocationContext);
    }

    /**
     * Builder for {@link KernelFunctionFromMethod}; a method and a target are required, all other
     * values are optional and derived from the method when absent.
     *
     * @param <T> the type of the value carried by the {@link FunctionResult}
     */
    public static class Builder<T> {
        @Nullable
        private Method method;
        @Nullable
        private Object target;
        @Nullable
        private String pluginName;
        @Nullable
        private String functionName;
        @Nullable
        private String description;
        @Nullable
        private List<InputVariable> parameters;
        @Nullable
        private OutputVariable<?> returnParameter;

        @SuppressFBWarnings(value={"EI_EXPOSE_REP2"})
        public Builder<T> withMethod(Method method) {
            this.method = method;
            return this;
        }

        public Builder<T> withTarget(Object target) {
            this.target = target;
            return this;
        }

        public Builder<T> withPluginName(String pluginName) {
            this.pluginName = pluginName;
            return this;
        }

        public Builder<T> withFunctionName(String functionName) {
            this.functionName = functionName;
            return this;
        }

        public Builder<T> withDescription(String description) {
            this.description = description;
            return this;
        }

        /** Sets the input parameters; the list is copied defensively. */
        public Builder<T> withParameters(List<InputVariable> parameters) {
            this.parameters = new ArrayList<>(parameters);
            return this;
        }

        public Builder<T> withReturnParameter(OutputVariable<?> returnParameter) {
            this.returnParameter = returnParameter;
            return this;
        }

        /**
         * Builds the function.
         *
         * @throws SKException if no method or no target object was provided
         */
        public KernelFunction<T> build() {
            if (this.method == null) {
                throw new SKException("To build a KernelFunctionFromMethod, a method must be provided");
            }
            if (this.target == null) {
                throw new SKException("To build a KernelFunctionFromMethod, a target object must be provided");
            }
            return KernelFunctionFromMethod.create(this.method, this.target, this.pluginName, this.functionName, this.description, this.parameters, this.returnParameter);
        }
    }

    /**
     * Adapter that performs the actual invocation of a {@link KernelFunctionFromMethod}.
     *
     * @param <T> the type of the value carried by the {@link FunctionResult}
     */
    @FunctionalInterface
    public static interface ImplementationFunc<T> {
        Mono<FunctionResult<T>> invokeAsync(Kernel kernel, KernelFunction<T> function, @Nullable KernelFunctionArguments arguments, @Nullable ContextVariableType<T> variableType, @Nullable InvocationContext invocationContext);

        /**
         * Blocking variant of {@link #invokeAsync}.
         *
         * @return the emitted result, or {@code null} if the asynchronous result completes empty
         */
        default FunctionResult<T> invoke(Kernel kernel, KernelFunction<T> function, @Nullable KernelFunctionArguments arguments, @Nullable ContextVariableType<T> variableType, @Nullable InvocationContext invocationContext) {
            return this.invokeAsync(kernel, function, arguments, variableType, invocationContext).block();
        }
    }
}

