/*
 * Decompiled with CFR 0.152.
 */
package org.mvnsearch.chatgpt.model.function;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.RecordComponent;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.mvnsearch.chatgpt.model.function.ChatGPTJavaFunction;
import org.mvnsearch.chatgpt.model.function.GPTFunction;
import org.mvnsearch.chatgpt.model.function.Parameter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
 * Reflection helpers that turn {@link GPTFunction}-annotated Java methods into
 * {@link ChatGPTJavaFunction} declarations, invoke them with JSON arguments, and
 * render results as plain text.
 */
public class GPTFunctionUtils {
    private static final Logger log = LoggerFactory.getLogger(GPTFunctionUtils.class);

    /**
     * Shared mapper used for argument deserialization and result rendering: unknown JSON
     * properties are ignored and property names use the snake_case naming strategy.
     */
    public static final ObjectMapper objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false).setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);

    /** JSON schema type names for the scalar Java types a {@link Parameter} field (or list item) may use. */
    private static final Map<Class<?>, String> SCALAR_SCHEMA_TYPES = Map.ofEntries(
            Map.entry(Integer.class, "integer"), Map.entry(Integer.TYPE, "integer"),
            Map.entry(Long.class, "integer"), Map.entry(Long.TYPE, "integer"),
            Map.entry(Short.class, "integer"), Map.entry(Short.TYPE, "integer"),
            Map.entry(Byte.class, "integer"), Map.entry(Byte.TYPE, "integer"),
            Map.entry(Boolean.class, "boolean"), Map.entry(Boolean.TYPE, "boolean"),
            Map.entry(Double.class, "number"), Map.entry(Double.TYPE, "number"),
            Map.entry(Float.class, "number"), Map.entry(Float.TYPE, "number"),
            Map.entry(String.class, "string"));

    /**
     * Collects every {@link GPTFunction}-annotated method declared on {@code clazz}, its
     * superclasses (excluding {@code Object}) and all of its interfaces.
     *
     * @param clazz the type to scan
     * @return a mutable map from function name (annotation {@code name}, or the method name
     *         when blank) to its declaration; on a name clash the last method seen wins
     * @throws IllegalArgumentException if an annotated method declares no parameter
     * @throws Exception if a {@link Parameter} field of a request type maps to an unsupported
     *         JSON "object" type
     */
    public static Map<String, ChatGPTJavaFunction> extractFunctions(Class<?> clazz) throws Exception {
        Map<String, ChatGPTJavaFunction> functionDeclares = new HashMap<>();
        Set<Class<?>> classesToSearchForFunctions = getAllClassesInType(clazz);
        log.debug("starting to look for functions in {}. found {}", clazz.getName(), classesToSearchForFunctions);
        Set<Method> methods = classesToSearchForFunctions.stream()
                .flatMap(cc -> Stream.of(cc.getDeclaredMethods()))
                // javac copies annotations onto bridge methods, whose parameter types are erased
                // (e.g. Object); registering those would deserialize arguments into the wrong type.
                .filter(m -> !m.isBridge() && !m.isSynthetic())
                .filter(m -> m.isAnnotationPresent(GPTFunction.class))
                .collect(Collectors.toSet());
        log.debug("found {} methods.", methods.size());
        if (log.isDebugEnabled() && !methods.isEmpty()) {
            log.debug("======================================================================");
            log.debug("class {}", clazz.getName());
            for (Method m : methods) {
                log.debug("\t {} {}({})", m.getReturnType().getName(), m.getName(), Arrays.toString(m.getParameterTypes()));
            }
        }
        for (Method method : methods) {
            GPTFunction gptFunctionAnnotation = method.getAnnotation(GPTFunction.class);
            String functionName = gptFunctionAnnotation.name().isEmpty() ? method.getName() : gptFunctionAnnotation.name();
            functionDeclares.put(functionName, buildFunction(method, gptFunctionAnnotation, functionName));
        }
        return functionDeclares;
    }

    /** Builds the declaration for one annotated method from its first parameter's {@link Parameter} fields. */
    private static ChatGPTJavaFunction buildFunction(Method method, GPTFunction gptFunctionAnnotation, String functionName) throws Exception {
        if (method.getParameterCount() == 0) {
            throw new IllegalArgumentException("GPTFunction method must declare a request parameter: "
                    + method.getDeclaringClass().getName() + "." + method.getName());
        }
        ChatGPTJavaFunction gptJavaFunction = new ChatGPTJavaFunction();
        gptJavaFunction.setJavaMethod(method);
        gptJavaFunction.setName(functionName);
        gptJavaFunction.setDescription(gptFunctionAnnotation.value());
        // Only the first parameter is described; callGPTFunction deserializes the arguments JSON into it.
        Class<?> requestClazz = method.getParameterTypes()[0];
        gptJavaFunction.setParameterType(requestClazz);
        for (Field field : requestClazz.getDeclaredFields()) {
            Parameter functionParamAnnotation = field.getAnnotation(Parameter.class);
            if (functionParamAnnotation == null) {
                continue;
            }
            String fieldName = functionParamAnnotation.name().isEmpty() ? field.getName() : functionParamAnnotation.name();
            String fieldType = getJsonSchemaType(field.getType());
            if ("array".equals(fieldType)) {
                Class<?> itemClazz = parseArrayItemClass(field.getGenericType());
                gptJavaFunction.addArrayProperty(fieldName, getJsonSchemaType(itemClazz), functionParamAnnotation.value());
            } else if ("object".equals(fieldType)) {
                // Report the request class that actually owns the offending field.
                throw new Exception("Object type not supported: " + requestClazz.getName() + "." + field.getName());
            } else {
                gptJavaFunction.addProperty(fieldName, fieldType, functionParamAnnotation.value());
            }
            if (isRequired(field, functionParamAnnotation)) {
                gptJavaFunction.addRequired(fieldName);
            }
        }
        return gptJavaFunction;
    }

    /**
     * A field is required when its {@link Parameter} says so, or when it carries a runtime-visible
     * annotation whose fully-qualified name ends with "nonnull" (case-insensitive, e.g. javax.annotation.Nonnull).
     */
    private static boolean isRequired(Field field, Parameter functionParamAnnotation) {
        if (functionParamAnnotation.required()) {
            return true;
        }
        for (Annotation annotation : field.getAnnotations()) {
            if (annotation.annotationType().getName().toLowerCase(Locale.ROOT).endsWith("nonnull")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns {@code clazz}, all interfaces it implements (as reported by Spring's
     * {@link ClassUtils#getAllInterfacesForClass}), and its superclass chain up to but
     * excluding {@code Object}.
     *
     * @param clazz the starting type
     * @return a new mutable set of the collected types
     */
    public static Set<Class<?>> getAllClassesInType(Class<?> clazz) {
        log.debug("finding types for {}", clazz.getName());
        Set<Class<?>> all = new HashSet<>();
        all.add(clazz);
        all.addAll(Arrays.asList(ClassUtils.getAllInterfacesForClass(clazz)));
        for (Class<?> parent = clazz.getSuperclass(); parent != null && !parent.equals(Object.class); parent = parent.getSuperclass()) {
            all.add(parent);
        }
        return all;
    }

    /**
     * Maps a Java field type to a JSON schema type name: scalar wrappers/primitives and
     * {@code String} to their schema names, any {@link Collection} to "array", and
     * everything else to "object".
     */
    private static String getJsonSchemaType(Class<?> clazz) {
        String scalarType = SCALAR_SCHEMA_TYPES.get(clazz);
        if (scalarType != null) {
            return scalarType;
        }
        if (Collection.class.isAssignableFrom(clazz)) {
            return "array";
        }
        return "object";
    }

    /**
     * Deserializes {@code argumentsJson} into the function's parameter type and invokes the
     * function's Java method on {@code target} (made accessible first).
     *
     * @return whatever the invoked method returns
     * @throws Exception JSON mapping failures, or reflective invocation failures
     *         (e.g. {@link java.lang.reflect.InvocationTargetException})
     */
    public static Object callGPTFunction(Object target, ChatGPTJavaFunction function, String argumentsJson) throws Exception {
        log.info("attempting to call GPTFunction on target [{}] with arguments [{}]",
                target == null ? "null" : target.getClass().getName(), argumentsJson);
        Method javaMethod = function.getJavaMethod();
        ReflectionUtils.makeAccessible(javaMethod);
        Class<?> parameterType = function.getParameterType();
        Object param = objectMapper.readValue(argumentsJson, parameterType);
        return javaMethod.invoke(target, param);
    }

    /**
     * Renders a value as text: {@code null} becomes "", a {@code String} is returned as-is, and
     * anything else is pretty-printed as JSON. A value that serializes to a bare JSON string
     * (e.g. an enum) is decoded so quotes and escape sequences are not leaked.
     * Serialization failures are logged and yield "".
     */
    public static String toTextPlain(@Nullable Object object) {
        if (object == null) {
            return "";
        }
        if (object instanceof String text) {
            return text;
        }
        try {
            String jsonText = objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(object);
            if (jsonText.length() >= 2 && jsonText.startsWith("\"") && jsonText.endsWith("\"")) {
                // Decode the JSON string literal rather than stripping quotes, so "\n" etc. are unescaped.
                return objectMapper.readValue(jsonText, String.class);
            }
            return jsonText;
        } catch (Exception e) {
            // Best effort: callers get "" rather than an exception, but the failure is no longer silent.
            log.warn("failed to convert value of type {} to plain text", object.getClass().getName(), e);
            return "";
        }
    }

    /**
     * Returns the element class of a parameterized collection type when it is one of the
     * supported scalar types, otherwise {@code Object.class} (raw, wildcard, nested or custom
     * element types).
     */
    private static Class<?> parseArrayItemClass(Type genericType) {
        if (genericType instanceof ParameterizedType parameterizedType) {
            Type[] typeArguments = parameterizedType.getActualTypeArguments();
            if (typeArguments.length == 1
                    && typeArguments[0] instanceof Class<?> itemClass
                    && SCALAR_SCHEMA_TYPES.containsKey(itemClass)) {
                return itemClass;
            }
        }
        return Object.class;
    }

    /**
     * Reads every component of a record, in declaration order, via its accessor.
     * A component whose accessor cannot be invoked is logged and left {@code null}.
     *
     * @param obj a record instance
     * @return the component values
     * @throws IllegalArgumentException if {@code obj} is not a record
     */
    public static Object[] convertRecordToArray(@NotNull Object obj) {
        RecordComponent[] components = obj.getClass().getRecordComponents();
        if (components == null) {
            throw new IllegalArgumentException("not a record: " + obj.getClass().getName());
        }
        Object[] args = new Object[components.length];
        for (int i = 0; i < components.length; i++) {
            Method accessor = components[i].getAccessor();
            try {
                // Accessors of non-public records would otherwise fail with IllegalAccessException.
                ReflectionUtils.makeAccessible(accessor);
                args[i] = accessor.invoke(obj);
            } catch (Exception e) {
                log.warn("failed to read record component {} of {}", components[i].getName(), obj.getClass().getName(), e);
                args[i] = null;
            }
        }
        return args;
    }
}

