/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.module.extension.internal.loader.utils;

import static java.lang.Enum.valueOf;
import static java.lang.String.format;
import static java.lang.reflect.Proxy.newProxyInstance;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toMap;

import static net.bytebuddy.jar.asm.Opcodes.ASM5;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.sdk.api.annotation.JavaVersionSupport;
import org.mule.sdk.api.meta.JavaVersion;

import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import net.bytebuddy.jar.asm.AnnotationVisitor;
import net.bytebuddy.jar.asm.ClassReader;
import net.bytebuddy.jar.asm.ClassVisitor;
import net.bytebuddy.jar.asm.Type;
import org.slf4j.Logger;

/**
 * Utility to extract annotations from a {@link Class} by introspecting it instead of directly accessing it.
 */
public class AnnotationsIntrospectorUtils {

  private static final Logger LOGGER = getLogger(AnnotationsIntrospectorUtils.class);

  public static Map<Class<? extends Annotation>, Annotation> extractAnnotations(Class<?> clazz) {
    List<Object> annotations = new ArrayList<>();

    try (InputStream classStream = getClassAsStream(clazz)) {
      ClassReader classReader = new ClassReader(classStream);
      int asmLevel = ASM5;
      classReader.accept(new ClassVisitor(asmLevel) {

        @Override
        public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
          return getAnnotationVisitor(descriptor, asmLevel, annotations, clazz.getClassLoader());
        }
      }, 0);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

    return annotations.stream()
        .collect(toMap(annotation -> ((Annotation) annotation).annotationType(), annotation -> (Annotation) annotation));
  }

  private static AnnotationVisitor getAnnotationVisitor(String descriptor, int asmLevel, List<Object> annotations,
                                                        ClassLoader classLoader) {
    String annotationClassName = descriptorToClassName(descriptor);
    try {
      Class<? extends Annotation> annotationClass = (Class<? extends Annotation>) classLoader.loadClass(annotationClassName);

      return getAnnotationVisitor(annotationClass, asmLevel, annotations, false, classLoader);
    } catch (Exception e) {
      throw new RuntimeException(format("Unexpected error while parsing annotation '%s'", annotationClassName), e);
    }
  }

  private static AnnotationVisitor getAnnotationVisitor(Class<? extends Annotation> annotationClass, int asmLevel,
                                                        List<Object> values, boolean isArray, ClassLoader classLoader) {
    Map<String, Object> annotationValues = new HashMap<>();

    return new AnnotationVisitor(asmLevel) {

      @Override
      public AnnotationVisitor visitAnnotation(String name, String descriptor) {
        return getAnnotationVisitor(descriptor, asmLevel, values, classLoader);
      }

      @Override
      public void visit(String name, Object value) {
        Object processedValue = value;
        if (value instanceof Type) {
          Type type = (Type) value;
          try {
            processedValue = classLoader.loadClass(type.getClassName());
          } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
          }
        }

        if (isArray) {
          values.add(processedValue);
        } else {
          annotationValues.put(name, processedValue);
        }
      }

      @Override
      public void visitEnum(String name, String descriptor, String value) {
        String enumClassName = descriptorToClassName(descriptor);
        try {
          Class<?> enumClass = classLoader.loadClass(enumClassName);
          Object enumValue = valueOf((Class<Enum>) enumClass, value);
          if (isArray) {
            values.add(enumValue);
          } else {
            annotationValues.put(name, enumValue);
          }
        } catch (IllegalArgumentException e) {
          if (!enumClassName.equals("org.mule.sdk.api.annotation.JavaVersionSupport")) {
            LOGGER.warn("Error while loading value '{}' for Enum '{}'", value, enumClassName);
          }

          // Invalid value found for `JavaVersionSupport`, ignore it
          LOGGER.debug("Invalid JavaVersionSupport value: '{}'", value);
        } catch (Exception e) {
          throw new RuntimeException(format("Unexpected error while parsing Enum '%s'", enumClassName), e);
        }
      }

      @Override
      public AnnotationVisitor visitArray(String name) {
        List<Object> arrayValues = new ArrayList<>();
        annotationValues.put(name, arrayValues);

        return getAnnotationVisitor(annotationClass, asmLevel, arrayValues, true, classLoader);
      }

      @Override
      public void visitEnd() {
        if (!isArray) {
          stream(annotationClass.getDeclaredMethods())
              .filter(method -> annotationValues.keySet().stream().noneMatch(valueName -> valueName.equals(method.getName())))
              .forEach(method -> annotationValues.put(method.getName(), method.getDefaultValue()));
          annotationValues.forEach((key, value) -> {
            if (value instanceof List) {
              List<?> list = (List<?>) value;
              Class<?> componentType;
              try {
                componentType = annotationClass.getDeclaredMethod(key).getReturnType().getComponentType();
              } catch (NoSuchMethodException e) {
                throw new RuntimeException(e);
              }
              Object typedArray = Array.newInstance(componentType, list.size());
              for (int i = 0; i < list.size(); i++) {
                Array.set(typedArray, i, list.get(i));
              }

              annotationValues.put(key, typedArray);
            }
          });
          // Create annotation proxy
          values.add(createAnnotationProxy(annotationClass, annotationValues));
        }
      }
    };
  }

  public static JavaVersion[] extractJavaVersionValues(Class<?> clazz) throws Exception {
    try (InputStream classStream = getClassAsStream(clazz)) {
      List<JavaVersion> javaVersions = new ArrayList<>();
      ClassReader reader = new ClassReader(classStream);
      int asmVersion = ASM5;
      reader.accept(new ClassVisitor(asmVersion) {

        @Override
        public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
          return new AnnotationVisitor(asmVersion) {

            @Override
            public AnnotationVisitor visitArray(String name) {
              return new AnnotationVisitor(asmVersion) {

                @Override
                public void visitEnum(String name, String desc, String value) {
                  try {
                    if (descriptor.endsWith(JavaVersionSupport.class.getSimpleName() + ";")) {
                      javaVersions.add(JavaVersion.valueOf(value));
                    }
                  } catch (IllegalArgumentException e) {
                    LOGGER.debug("Found unknown value '{}' in JavaVersionSupport annotation", value, e);
                  }
                }
              };
            }
          };
        }
      }, 0);

      return javaVersions.toArray(new JavaVersion[javaVersions.size()]);
    }
  }

  private static InputStream getClassAsStream(Class<?> clazz) {
    String className = clazz.getName().replace('.', '/') + ".class";
    InputStream classStream = clazz.getClassLoader().getResourceAsStream(className);
    if (classStream == null) {
      throw new RuntimeException(format("Unable to load extension class '%s' from class loader '%s'", className,
                                        clazz.getClassLoader()));
    }

    return classStream;
  }

  /**
   * Converts ASM descriptor to class name (e.g., {@code Lcom/example/MyAnnotation;} -> {@code com.example.MyAnnotation}) if
   * corresponds.
   *
   * @param descriptor the ASM descriptor to convert.
   * @return the class name in case the ASM descriptor was a class, or the descriptor otherwise.
   */
  private static String descriptorToClassName(String descriptor) {
    if (descriptor.startsWith("L") && descriptor.endsWith(";")) {
      return descriptor.substring(1, descriptor.length() - 1).replace('/', '.');
    }

    // Handle primitive types if needed
    switch (descriptor) {
      case "I":
        return "int";
      case "Z":
        return "boolean";
      case "B":
        return "byte";
      case "C":
        return "char";
      case "S":
        return "short";
      case "J":
        return "long";
      case "F":
        return "float";
      case "D":
        return "double";
      default:
        return descriptor;
    }
  }

  private static <A extends Annotation> A createAnnotationProxy(Class<A> annotationClass, Map<String, Object> values) {
    return (A) newProxyInstance(
                                annotationClass.getClassLoader(),
                                new Class[] {annotationClass},
                                new CustomAnnotationInvocationHandler(annotationClass, values));
  }

  private static class CustomAnnotationInvocationHandler implements InvocationHandler {

    private final Class<? extends Annotation> annotationType;
    private final Map<String, Object> memberValues;

    public CustomAnnotationInvocationHandler(Class<? extends Annotation> annotationType,
                                             Map<String, Object> memberValues) {
      this.annotationType = annotationType;
      this.memberValues = memberValues;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) {
      String methodName = method.getName();

      switch (methodName) {
        case "equals":
          return annotationEquals(args[0]);
        case "hashCode":
          return annotationHashCode();
        case "toString":
          return annotationToString();
        case "annotationType":
          return annotationType;
        default:
          Object value = memberValues.get(methodName);
          return value != null ? value : method.getDefaultValue();
      }
    }

    private boolean annotationEquals(Object obj) {
      if (obj == this) {
        return true;
      }
      if (!annotationType.isInstance(obj)) {
        return false;
      }

      for (Method method : annotationType.getDeclaredMethods()) {
        if (method.getParameterCount() == 0) {
          try {
            Object thisValue = memberValues.get(method.getName());
            Object otherValue = method.invoke(obj);

            if (!deepEquals(thisValue, otherValue)) {
              return false;
            }
          } catch (Exception e) {
            return false;
          }
        }
      }
      return true;
    }

    private boolean deepEquals(Object o1, Object o2) {
      if (o1 == o2) {
        return true;
      }
      if (o1 == null || o2 == null) {
        return false;
      }

      if (o1.getClass().isArray() && o2.getClass().isArray()) {
        return Arrays.deepEquals((Object[]) o1, (Object[]) o2);
      }

      return o1.equals(o2);
    }

    public int annotationHashCode() {
      int result = 0;
      for (Map.Entry<String, Object> entry : memberValues.entrySet()) {
        result += (127 * entry.getKey().hashCode()) ^
            (entry.getValue() != null ? entry.getValue().hashCode() : 0);
      }
      return result;
    }

    public String annotationToString() {
      StringBuilder sb = new StringBuilder("@");
      sb.append(annotationType.getName()).append("(");

      boolean first = true;
      boolean loneValue = memberValues.size() == 1;
      for (Map.Entry<String, Object> entry : memberValues.entrySet()) {
        if (!first) {
          sb.append(", ");
        }
        String key = entry.getKey();
        if (!loneValue || !"value".equals(key)) {
          sb.append(key).append('=');
        }
        sb.append(memberValueToString(entry.getValue()));
        first = false;
      }

      sb.append(")");
      return sb.toString();
    }

    private String memberValueToString(Object value) {
      return value.getClass().isArray() ? stream((Object[]) value).map(Object::toString).collect(joining(", ", "{", "}"))
          : value.toString();
    }
  }

}
