/*
 * Decompiled with CFR 0.152.
 */
package top.dreamlike.panama.generator.proxy;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.classfile.AccessFlags;
import java.lang.classfile.ClassBuilder;
import java.lang.classfile.ClassFile;
import java.lang.classfile.CodeBuilder;
import java.lang.constant.ClassDesc;
import java.lang.constant.ConstantDesc;
import java.lang.constant.ConstantDescs;
import java.lang.constant.DynamicCallSiteDesc;
import java.lang.foreign.AddressLayout;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.CallSite;
import java.lang.invoke.ConstantCallSite;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.AccessFlag;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Supplier;
import top.dreamlike.panama.generator.annotation.CLib;
import top.dreamlike.panama.generator.annotation.NativeFunction;
import top.dreamlike.panama.generator.annotation.Pointer;
import top.dreamlike.panama.generator.exception.StructException;
import top.dreamlike.panama.generator.helper.ClassFileHelper;
import top.dreamlike.panama.generator.helper.DowncallContext;
import top.dreamlike.panama.generator.helper.FunctionPointer;
import top.dreamlike.panama.generator.helper.NativeAddressable;
import top.dreamlike.panama.generator.helper.NativeGeneratorHelper;
import top.dreamlike.panama.generator.proxy.InvokeDynamicFactory;
import top.dreamlike.panama.generator.proxy.NativeImageHelper;
import top.dreamlike.panama.generator.proxy.NativeLookup;
import top.dreamlike.panama.generator.proxy.StructProxyGenerator;

/**
 * Generates implementations of annotated Java interfaces whose abstract methods are bound to
 * native functions via the FFM {@link Linker}, and creates upcall stubs for Java methods.
 *
 * <p>Proxy classes are emitted with the {@code java.lang.classfile} API and defined in the
 * interface's own package. Each abstract method dispatches either through a per-method
 * {@code static final MethodHandle} field initialised in {@code <clinit>} ("plain" mode) or through
 * an {@code invokedynamic} call site ("indy" mode).
 */
public class NativeCallGenerator {
    /** {@code dlsym(String)} as a handle; bound per symbol to resolve addresses lazily at call time. */
    private static final MethodHandle DLSYM_MH;
    private final NativeLookup nativeLibLookup;
    private final ClassFile classFile = ClassFile.of();
    /** Name of the {@code public static final} field on every generated proxy that holds its generator. */
    static final String GENERATOR_FIELD_NAME = "_generator";
    /** Reflective handle to {@link #generateInGeneratorContext}, invoked from generated {@code <clinit>} code. */
    private static final Method GENERATE_IN_GENERATOR_CONTEXT;
    public volatile boolean use_lmf = !NativeImageHelper.inExecutable();
    /** Symbol name -> resolved native address. */
    private final Map<String, MemorySegment> foreignFunctionAddressCache = new ConcurrentHashMap<>();
    private volatile boolean use_indy = !NativeImageHelper.inExecutable();
    /** Interface -> factory producing a new proxy instance. */
    final Map<Class<?>, Supplier<Object>> ctorCaches = new ConcurrentHashMap<>();
    private final StructProxyGenerator structProxyGenerator;

    /** Creates a generator backed by a fresh {@link StructProxyGenerator}. */
    public NativeCallGenerator() {
        this(new StructProxyGenerator());
    }

    /**
     * Creates a generator that uses the given struct generator for layouts and struct proxies.
     *
     * @param structProxyGenerator generator used to extract layouts and enhance returned/passed structs
     */
    public NativeCallGenerator(StructProxyGenerator structProxyGenerator) {
        this.structProxyGenerator = structProxyGenerator;
        this.nativeLibLookup = new NativeLookup();
    }

    /**
     * {@code invokedynamic} bootstrap: links a call site in a generated proxy to its native downcall.
     *
     * <p>Reads the generator from the proxy's {@value #GENERATOR_FIELD_NAME} field, resolves the
     * method on the proxy's first implemented interface, and binds its handle eagerly.
     */
    public static CallSite indyFactory(MethodHandles.Lookup lookup, String methodName, MethodType methodType, Object ... args) throws Throwable {
        Class<?> lookupClass = lookup.lookupClass();
        NativeCallGenerator generator = (NativeCallGenerator) lookupClass.getField(GENERATOR_FIELD_NAME).get(null);
        Class<?> targetInterface = lookupClass.getInterfaces()[0];
        Method method = targetInterface.getMethod(methodName, methodType.parameterArray());
        MethodHandle nativeCallMH = generator.nativeMethodHandle(method, false);
        return new ConstantCallSite(nativeCallMH);
    }

    /** Switches subsequently generated proxies to {@code invokedynamic} dispatch. */
    public void indyMode() {
        this.use_indy = true;
    }

    /**
     * Returns an implementation of {@code nativeInterface}; the proxy factory is generated once and cached.
     *
     * @throws StructException if the proxy cannot be generated or instantiated
     */
    public <T> T generate(Class<T> nativeInterface) {
        Objects.requireNonNull(nativeInterface);
        try {
            return nativeInterface.cast(this.ctorCaches.computeIfAbsent(nativeInterface, this::bind).get());
        }
        catch (Throwable throwable) {
            throw new StructException("should not reach here!", throwable);
        }
    }

    /** True when the parameter is passed to native code as an address rather than by value. */
    private static boolean needTransToPointer(Parameter parameter) {
        Class<?> typeClass = parameter.getType();
        return MemorySegment.class.isAssignableFrom(typeClass)
                || NativeAddressable.class.isAssignableFrom(typeClass)
                || !typeClass.isPrimitive() && parameter.getAnnotation(Pointer.class) != null;
    }

    /**
     * Called from generated {@code <clinit>} code to build the downcall handle for one interface method.
     *
     * @param methodType the method's type with the declaring interface prepended as a leading
     *                   (receiver) parameter; that parameter is dropped before the lookup
     */
    public MethodHandle generateInGeneratorContext(Class<?> interfaceClass, String methodName, MethodType methodType) throws NoSuchMethodException {
        MethodType withoutReceiver = methodType.dropParameterTypes(0, 1);
        Method method = interfaceClass.getDeclaredMethod(methodName, withoutReceiver.parameterArray());
        return this.nativeMethodHandle(method);
    }

    /**
     * Creates a native upcall stub that invokes {@code method}.
     *
     * <p>Primitive parameters map to their value layouts; all other parameters are received as
     * addresses, re-sized to the parameter type's layout and wrapped via the struct generator.
     * Non-primitive return values are converted back to a segment.
     *
     * @param scope    arena that owns the stub
     * @param method   target method
     * @param receiver exactly one receiver instance when {@code method} is not static
     * @throws IllegalArgumentException if the receiver is missing or of the wrong type
     * @throws StructException          if the stub cannot be created
     */
    public MemorySegment generateUpcall(Arena scope, Method method, Object ... receiver) {
        boolean isStatic = Modifier.isStatic(method.getModifiers());
        if (!isStatic) {
            if (receiver.length != 1) {
                throw new IllegalArgumentException("receiver length must be 1");
            }
            // isInstance also rejects a null receiver instead of throwing an NPE
            if (!method.getDeclaringClass().isInstance(receiver[0])) {
                throw new IllegalArgumentException("receiver type must be assignable from method declaring class");
            }
        }
        try {
            MethodHandle handle = MethodHandles.lookup().unreflect(method);
            if (!isStatic) {
                handle = handle.bindTo(receiver[0]);
            }
            Parameter[] parameters = method.getParameters();
            MemoryLayout[] memoryLayouts = new MemoryLayout[parameters.length];
            ArrayList<Integer> pointerIndex = new ArrayList<>(parameters.length / 2);
            for (int i = 0; i < parameters.length; ++i) {
                Class<?> type = parameters[i].getType();
                if (type.isPrimitive()) {
                    memoryLayouts[i] = NativeLookup.primitiveMapToMemoryLayout(type);
                    continue;
                }
                memoryLayouts[i] = ValueLayout.ADDRESS;
                pointerIndex.add(i);
            }
            Class<?> returnType = method.getReturnType();
            FunctionDescriptor fd;
            if (returnType == void.class) {
                fd = FunctionDescriptor.ofVoid(memoryLayouts);
            } else {
                // MemoryLayout, not AddressLayout: primitive returns map to non-address value layouts
                MemoryLayout returnMemoryLayout = returnType.isPrimitive()
                        ? NativeLookup.primitiveMapToMemoryLayout(returnType)
                        : ValueLayout.ADDRESS;
                fd = FunctionDescriptor.of(returnMemoryLayout, memoryLayouts);
            }
            for (int i : pointerIndex) {
                Class<?> argType = parameters[i].getType();
                MethodHandle argEnhanceMH = StructProxyGenerator.ENHANCE_MH
                        .asType(StructProxyGenerator.ENHANCE_MH.type().changeReturnType(argType))
                        .bindTo(this.structProxyGenerator)
                        .bindTo(argType);
                handle = MethodHandles.filterArguments(handle, i, argEnhanceMH);
                // the raw upcall address is zero-length; resize it to the struct layout before wrapping
                MemoryLayout layout = this.structProxyGenerator.extract(argType);
                handle = MethodHandles.filterArguments(handle, i, MethodHandles.insertArguments(NativeGeneratorHelper.DEREFERENCE, 1, layout.byteSize()));
            }
            if (!returnType.isPrimitive()) {
                handle = MethodHandles.filterReturnValue(handle, NativeGeneratorHelper.TRANSFORM_OBJECT_TO_STRUCT_MH.asType(NativeGeneratorHelper.TRANSFORM_OBJECT_TO_STRUCT_MH.type().changeParameterType(0, returnType)));
            }
            return Linker.nativeLinker().upcallStub(handle, fd, scope);
        }
        catch (Throwable t) {
            throw new StructException("should not reach here!", t);
        }
    }

    /** Same as {@link #generateUpcall} but wraps the stub in a {@link FunctionPointer}. */
    public FunctionPointer generateUpcallFP(Arena scope, Method method, Object ... receiver) {
        return new FunctionPointer(this.generateUpcall(scope, method, receiver));
    }

    /** Switches subsequently generated proxies to per-method {@code MethodHandle} field dispatch. */
    public void plainMode() {
        this.use_indy = false;
    }

    private MethodHandle nativeMethodHandle(Method method) {
        return this.nativeMethodHandle(method, false);
    }

    /**
     * Builds the downcall handle for {@code method}, adapted to the Java signature.
     *
     * @param lazy when true, the symbol is resolved on each invocation through the address cache;
     *             otherwise it is resolved now and bound into the handle
     * @throws IllegalArgumentException if fast (critical) mode is combined with errno capture
     */
    MethodHandle nativeMethodHandle(Method method, boolean lazy) {
        DowncallContext downcallContext = this.parseDowncallContext(method);
        String functionName = downcallContext.functionName();
        FunctionDescriptor fd = downcallContext.fd();
        Linker.Option[] options = downcallContext.ops();
        boolean returnPointer = downcallContext.returnPointer();
        boolean needCaptureStatue = downcallContext.needCaptureStatue();
        ArrayList<Integer> rawMemoryIndex = downcallContext.rawMemoryIndex();
        if (needCaptureStatue && downcallContext.fast()) {
            throw new IllegalArgumentException("fast mode cant capture errno");
        }
        // address-less downcall handle: leading parameter is the target address
        MethodHandle methodHandle = Linker.nativeLinker().downcallHandle(fd, options);
        if (lazy) {
            MethodHandle dlsymMH = DLSYM_MH.bindTo(this).bindTo(functionName);
            methodHandle = MethodHandles.collectArguments(methodHandle, 0, dlsymMH);
        } else {
            methodHandle = methodHandle.bindTo(this.dlsym(functionName));
        }
        if (needCaptureStatue) {
            // the capture-state segment is now the leading parameter; supply it internally
            methodHandle = MethodHandles.collectArguments(methodHandle, 0, NativeLookup.AllocateErrorBuffer_MH);
        }
        Parameter[] parameters = method.getParameters();
        for (int i : rawMemoryIndex) {
            methodHandle = MethodHandles.filterArguments(methodHandle, i, NativeGeneratorHelper.TRANSFORM_OBJECT_TO_STRUCT_MH.asType(NativeGeneratorHelper.TRANSFORM_OBJECT_TO_STRUCT_MH.type().changeParameterType(0, parameters[i].getType())));
        }
        for (int i = 0; i < parameters.length; ++i) {
            Class<?> parameterType = parameters[i].getType();
            if (!parameterType.isArray()) {
                continue;
            }
            // primitive arrays are passed as heap segments (requires critical + allowHeapAccess)
            MethodHandle wrapperMH = NativeLookup.heapAccessMH(parameterType.getComponentType());
            methodHandle = MethodHandles.filterArguments(methodHandle, i, wrapperMH);
        }
        if (needCaptureStatue) {
            methodHandle = NativeLookup.fillErrorNoAfterReturn(methodHandle);
        }
        if (MemorySegment.class.isAssignableFrom(method.getReturnType())) {
            return methodHandle;
        }
        if (returnPointer) {
            MemoryLayout returnLayout = fd.returnLayout().get();
            methodHandle = MethodHandles.filterReturnValue(methodHandle, MethodHandles.insertArguments(NativeGeneratorHelper.DEREFERENCE, 1, returnLayout.byteSize()));
            MethodHandle returnEnhance = StructProxyGenerator.ENHANCE_MH.asType(StructProxyGenerator.ENHANCE_MH.type().changeReturnType(method.getReturnType())).bindTo(this.structProxyGenerator).bindTo(method.getReturnType());
            methodHandle = MethodHandles.filterReturnValue(methodHandle, returnEnhance);
        }
        return methodHandle;
    }

    /**
     * Derives the function descriptor, linker options and argument-adaptation info for {@code method}
     * from its signature and optional {@link NativeFunction} annotation.
     *
     * @throws IllegalArgumentException for unsupported return types or non-primitive array parameters
     */
    DowncallContext parseDowncallContext(Method method) {
        NativeFunction function = method.getAnnotation(NativeFunction.class);
        String functionName = function == null || Objects.requireNonNullElse(function.value(), "").isBlank() ? method.getName() : function.value();
        Class<?> returnType = method.getReturnType();
        boolean returnPointer = !returnType.isPrimitive() && function != null && function.returnIsPointer();
        ArrayList<Linker.Option> options = new ArrayList<>(2);
        boolean needFast = function != null && function.fast();
        boolean needHeap = function != null && function.allowPassHeap();
        boolean needCaptureStatue = function != null && function.needErrorNo();
        if (needCaptureStatue) {
            options.add(Linker.Option.captureCallState("errno"));
        }
        if (NativeLookup.primitiveMapToMemoryLayout(returnType) == null && !returnPointer && !MemorySegment.class.isAssignableFrom(returnType) && returnType != Void.TYPE) {
            throw new IllegalArgumentException(String.valueOf(method) + " must return primitive type or is marked returnIsPointer");
        }
        ArrayList<Integer> rawMemoryIndex = new ArrayList<>();
        Parameter[] parameters = method.getParameters();
        MemoryLayout[] layouts = new MemoryLayout[parameters.length];
        for (int i = 0; i < parameters.length; ++i) {
            Class<?> typeClass = parameters[i].getType();
            if (NativeCallGenerator.needTransToPointer(parameters[i])) {
                layouts[i] = ValueLayout.ADDRESS;
                rawMemoryIndex.add(i);
                continue;
            }
            if (typeClass.isArray()) {
                if (!typeClass.getComponentType().isPrimitive()) {
                    throw new IllegalArgumentException("array must be primitive type");
                }
                // heap arrays can only be passed to critical downcalls
                needHeap = true;
                needFast = true;
                layouts[i] = ValueLayout.ADDRESS;
                continue;
            }
            layouts[i] = this.structProxyGenerator.extract(typeClass);
            if (!typeClass.isPrimitive()) {
                rawMemoryIndex.add(i);
            }
        }
        MemoryLayout returnLayout = MemorySegment.class.isAssignableFrom(returnType) || returnPointer ? ValueLayout.ADDRESS : this.structProxyGenerator.extract(returnType);
        if (needFast) {
            options.add(Linker.Option.critical(needHeap));
        }
        FunctionDescriptor fd = returnType == Void.TYPE ? FunctionDescriptor.ofVoid(layouts) : FunctionDescriptor.of(returnLayout, layouts);
        return new DowncallContext(fd, options.toArray(Linker.Option[]::new), functionName, returnPointer, needCaptureStatue, rawMemoryIndex);
    }

    /**
     * Emits and defines the proxy class for {@code nativeInterface} in the lookup's package.
     *
     * <p>Default, static, bridge and synthetic methods are skipped: only abstract interface methods
     * are bound to native functions.
     */
    private Class<?> generateRuntimeProxyClass(MethodHandles.Lookup lookup, Class<?> nativeInterface) throws IllegalAccessException {
        String className = this.generateProxyClassName(nativeInterface);
        ClassDesc thisClassDesc = ClassDesc.ofDescriptor("L" + className.replace(".", "/") + ";");
        byte[] thisClass = this.classFile.build(thisClassDesc, classBuilder -> {
            classBuilder.withInterfaceSymbols(ClassFileHelper.toDesc(nativeInterface));
            classBuilder.withField(GENERATOR_FIELD_NAME, ClassFileHelper.toDesc(NativeCallGenerator.class), AccessFlags.ofField(AccessFlag.PUBLIC, AccessFlag.STATIC, AccessFlag.FINAL).flagsMask());
            ArrayList<Consumer<CodeBuilder>> clinits = new ArrayList<>();
            // _generator = <current generator from NativeGeneratorHelper>
            clinits.add(it -> {
                ClassFileHelper.invoke(it, NativeGeneratorHelper.FETCH_CURRENT_NATIVE_CALL_GENERATOR);
                it.putstatic(thisClassDesc, GENERATOR_FIELD_NAME, ClassFileHelper.toDesc(NativeCallGenerator.class));
            });
            // load the interface's @CLib library before any handle is resolved
            clinits.add(it -> {
                it.ldc(ClassFileHelper.toDesc(nativeInterface));
                ClassFileHelper.invoke(it, NativeGeneratorHelper.LOAD_SO);
            });
            classBuilder.withMethodBody("<init>", ConstantDescs.MTD_void, ClassFile.ACC_PUBLIC, it -> {
                it.aload(0);
                it.invokespecial(ConstantDescs.CD_Object, "<init>", ConstantDescs.MTD_void);
                it.return_();
            });
            // overloaded methods share a name, so handle field names must be de-duplicated
            Set<String> usedHandleFieldNames = new HashSet<>();
            for (Method method : nativeInterface.getMethods()) {
                if (method.isBridge() || method.isDefault() || method.isSynthetic() || Modifier.isStatic(method.getModifiers())) {
                    continue;
                }
                if (this.use_indy) {
                    this.invokeByIndy(method, classBuilder, className);
                    continue;
                }
                String baseFieldName = method.getName() + "_native_method_handle";
                String fieldName = baseFieldName;
                for (int suffix = 2; !usedHandleFieldNames.add(fieldName); ++suffix) {
                    fieldName = baseFieldName + "_" + suffix;
                }
                clinits.add(this.invokeByMh(method, classBuilder, className, fieldName));
            }
            classBuilder.withMethodBody("<clinit>", ConstantDescs.MTD_void, AccessFlag.STATIC.mask(), it -> {
                clinits.forEach(init -> init.accept(it));
                it.return_();
            });
        });
        if (this.structProxyGenerator.classDataPeek != null) {
            this.structProxyGenerator.classDataPeek.accept(className, thisClass);
        }
        return lookup.defineClass(thisClass);
    }

    /**
     * Emits {@code method} as an {@code invokeExact} on a static handle field named {@code mhFieldName}.
     *
     * @return the {@code <clinit>} fragment that initialises that field via
     *         {@link #generateInGeneratorContext}
     */
    private Consumer<CodeBuilder> invokeByMh(Method method, ClassBuilder thisClass, String className, String mhFieldName) {
        ClassDesc thisClassDesc = ClassDesc.of(className);
        ClassDesc methodHandleDesc = ClassFileHelper.toDesc(MethodHandle.class);
        thisClass.withMethodBody(method.getName(), ClassFileHelper.toMethodDescriptor(method), AccessFlags.ofMethod(AccessFlag.PUBLIC).flagsMask(), it -> {
            it.getstatic(thisClassDesc, mhFieldName, methodHandleDesc);
            ClassFileHelper.invokeMethodHandleExactWithAllArgs(method, it);
        });
        thisClass.withField(mhFieldName, methodHandleDesc, AccessFlags.ofField(AccessFlag.PUBLIC, AccessFlag.STATIC, AccessFlag.FINAL).flagsMask());
        return it -> {
            it.getstatic(thisClassDesc, GENERATOR_FIELD_NAME, ClassFileHelper.toDesc(NativeCallGenerator.class));
            ClassDesc nativeInterfaceClassDesc = ClassFileHelper.toDesc(method.getDeclaringClass());
            it.ldc((ConstantDesc) nativeInterfaceClassDesc);
            it.ldc((ConstantDesc) method.getName());
            // the receiver type is prepended; generateInGeneratorContext drops it again
            it.ldc((ConstantDesc) ClassFileHelper.toMethodDescriptor(method).insertParameterTypes(0, nativeInterfaceClassDesc));
            ClassFileHelper.invoke(it, GENERATE_IN_GENERATOR_CONTEXT);
            it.putstatic(thisClassDesc, mhFieldName, methodHandleDesc);
        };
    }

    /** Emits {@code method} as an {@code invokedynamic} bootstrapped by {@code InvokeDynamicFactory}. */
    private void invokeByIndy(Method method, ClassBuilder thisClass, String className) {
        thisClass.withMethodBody(method.getName(), ClassFileHelper.toMethodDescriptor(method), AccessFlags.ofMethod(AccessFlag.PUBLIC).flagsMask(), it -> {
            ClassFileHelper.loadAllArgs(method, it);
            it.invokedynamic(DynamicCallSiteDesc.of(ConstantDescs.ofCallsiteBootstrap(ClassFileHelper.toDesc(InvokeDynamicFactory.class), "nativeCallIndyFactory", ConstantDescs.CD_CallSite), method.getName(), ClassFileHelper.toMethodDescriptor(method)));
            ClassFileHelper.returnValue(it, method.getReturnType());
        });
    }

    /**
     * Finds or generates the proxy class for {@code nativeInterface} and returns a factory for it.
     *
     * @throws StructException wrapping any failure, including a non-interface argument
     */
    Supplier<Object> bind(Class<?> nativeInterface) {
        try {
            if (!nativeInterface.isInterface()) {
                throw new IllegalArgumentException(String.valueOf(nativeInterface) + " is not interface");
            }
            String className = this.generateProxyClassName(nativeInterface);
            // read by the generated <clinit> through FETCH_CURRENT_NATIVE_CALL_GENERATOR
            NativeGeneratorHelper.CURRENT_GENERATOR.set(this);
            MethodHandles.Lookup lookup = MethodHandles.privateLookupIn(nativeInterface, MethodHandles.lookup());
            Class<?> aClass;
            try {
                // a proxy may already exist (e.g. pre-generated); reuse it
                aClass = lookup.findClass(className);
            }
            catch (ClassNotFoundException ignored) {
                // not generated yet: emit it now
                aClass = this.generateRuntimeProxyClass(lookup, nativeInterface);
            }
            if (!this.structProxyGenerator.skipInit) {
                lookup.ensureInitialized(aClass);
            }
            MethodHandle ctor = MethodHandles.lookup().findConstructor(aClass, MethodType.methodType(void.class));
            if (this.use_lmf) {
                return NativeGeneratorHelper.ctorBinder(ctor);
            }
            // adapt once instead of on every get()
            MethodHandle erasedCtor = ctor.asType(ctor.type().changeReturnType(Object.class));
            return () -> {
                try {
                    return (Object) erasedCtor.invokeExact();
                }
                catch (Throwable t) {
                    throw new StructException("failed to instantiate native proxy " + className, t);
                }
            };
        }
        catch (Throwable e) {
            throw new StructException("should not reach here!", e);
        }
        finally {
            NativeGeneratorHelper.CURRENT_GENERATOR.remove();
        }
    }

    /** Chooses between a LambdaMetafactory-bound constructor and a plain {@code MethodHandle} invoke. */
    public void useLmf(boolean use_lmf) {
        this.use_lmf = use_lmf;
    }

    /** Binary name of the proxy for {@code nativeInterface}; indy proxies get an extra suffix. */
    String generateProxyClassName(Class<?> nativeInterface) {
        String proxyName = nativeInterface.getName() + "_native_call_enhance";
        return this.use_indy ? proxyName + "_indy" : proxyName;
    }

    /** Resolves (and caches) the address of native symbol {@code name}. */
    private MemorySegment dlsym(String name) {
        Objects.requireNonNull(name);
        return this.foreignFunctionAddressCache.computeIfAbsent(name, this.nativeLibLookup::findOrException);
    }

    /**
     * Loads the native library declared by {@code @CLib} on {@code interfaceClass}, if any.
     *
     * <p>With {@code isLib()} the value is a library name passed to {@link Runtime#loadLibrary}.
     * Otherwise the value is a class-path resource ({@code inClassPath()}) or a file path; its bytes
     * are copied to a temp file (deleted on exit), which is loaded after the copy is fully closed.
     *
     * @throws IllegalArgumentException if a class-path resource cannot be found
     * @throws IOException              if reading the library or writing the temp file fails
     */
    public static void loadSo(Class<?> interfaceClass) throws IOException {
        CLib annotation = interfaceClass.getAnnotation(CLib.class);
        if (annotation == null || annotation.value().isBlank()) {
            return;
        }
        if (annotation.isLib()) {
            Runtime.getRuntime().loadLibrary(annotation.value());
            return;
        }
        // classes loaded by the bootstrap loader report a null class loader
        ClassLoader classLoader = Objects.requireNonNullElseGet(interfaceClass.getClassLoader(), ClassLoader::getSystemClassLoader);
        File file;
        try (InputStream inputStream = annotation.inClassPath() ? classLoader.getResourceAsStream(annotation.value()) : new FileInputStream(annotation.value())) {
            if (inputStream == null) {
                throw new IllegalArgumentException("cant find clib! classloader: " + String.valueOf(classLoader) + ", path: " + annotation.value());
            }
            String tmpFileName = interfaceClass.getSimpleName() + "_" + String.valueOf(UUID.randomUUID()) + "_dynamic.so";
            file = File.createTempFile(tmpFileName, ".tmp");
            file.deleteOnExit();
            try (FileOutputStream fileOutputStream = new FileOutputStream(file)) {
                inputStream.transferTo(fileOutputStream);
            }
        }
        System.load(file.getAbsolutePath());
    }

    static {
        try {
            DLSYM_MH = MethodHandles.lookup().findVirtual(NativeCallGenerator.class, "dlsym", MethodType.methodType(MemorySegment.class, String.class));
            GENERATE_IN_GENERATOR_CONTEXT = NativeCallGenerator.class.getMethod("generateInGeneratorContext", Class.class, String.class, MethodType.class);
        }
        catch (IllegalAccessException | NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }
}

