package com.github.unidbg.linux.android;

import com.github.unidbg.Emulator;
import com.github.unidbg.Module;
import com.github.unidbg.Svc;
import com.github.unidbg.Symbol;
import com.github.unidbg.arm.Arm64Svc;
import com.github.unidbg.arm.backend.Backend;
import com.github.unidbg.arm.context.RegisterContext;
import com.github.unidbg.linux.LinuxModule;
import com.github.unidbg.linux.struct.dl_phdr_info;
import com.github.unidbg.memory.Memory;
import com.github.unidbg.memory.MemoryBlock;
import com.github.unidbg.memory.SvcMemory;
import com.github.unidbg.pointer.UnidbgPointer;
import com.github.unidbg.pointer.UnidbgStructure;
import com.github.unidbg.spi.Dlfcn;
import com.github.unidbg.spi.InitFunction;
import com.github.unidbg.unix.struct.DlInfo;
import com.sun.jna.Pointer;
import keystone.Keystone;
import keystone.KeystoneArchitecture;
import keystone.KeystoneEncoded;
import keystone.KeystoneMode;
import net.fornwall.jelf.ElfDynamicStructure;
import net.fornwall.jelf.ElfFile;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import unicorn.Arm64Const;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * AArch64 emulation of the {@code libdl.so} entry points.
 * <p>
 * Every supported symbol is backed by an {@link Arm64Svc}; {@link #hook} registers it and returns
 * its address. {@code dl_iterate_phdr} and {@code dlopen} also install a small machine-code stub
 * that consumes 8-byte values the Java handler pushes onto the emulated stack. This is how guest
 * callbacks and library init functions are executed after the SVC returns.
 */
public class ArmLD64 extends Dlfcn {

    private static final Log log = LogFactory.getLog(ArmLD64.class);

    /** Backend used by {@link #dlopen(Memory, String, Emulator)} to write back the adjusted SP. */
    private final Backend backend;

    ArmLD64(Backend backend, SvcMemory svcMemory) {
        super(svcMemory);
        this.backend = backend;
    }

    /**
     * Provides emulated implementations of the {@code libdl.so} symbols {@code dl_iterate_phdr},
     * {@code dlerror}, {@code dlclose}, {@code dlopen}, {@code dladdr}, {@code dlsym} and
     * {@code dl_unwind_find_exidx}.
     *
     * @param svcMemory   memory where the SVC stubs are registered and allocated
     * @param libraryName library the symbol is being resolved from
     * @param symbolName  symbol being resolved
     * @param old         previously resolved address (only logged here)
     * @return address of the registered replacement, or {@code 0} when this hook does not
     *         provide {@code symbolName} from {@code libraryName}
     */
    @Override
    public long hook(final SvcMemory svcMemory, String libraryName, String symbolName, long old) {
        if (!"libdl.so".equals(libraryName)) {
            return 0;
        }
        if (log.isDebugEnabled()) {
            log.debug("link " + symbolName + ", old=0x" + Long.toHexString(old));
        }
        switch (symbolName) {
            case "dl_iterate_phdr":
                return svcMemory.registerSvc(new Arm64Svc() {
                    /** One dl_phdr_info per module; allocated in handle(), released in handleCallback(). */
                    private MemoryBlock block;

                    @Override
                    public UnidbgPointer onRegister(SvcMemory svcMemory, int svcNumber) {
                        try (Keystone keystone = new Keystone(KeystoneArchitecture.Arm64, KeystoneMode.LittleEndian)) {
                            // Branch targets are byte offsets from the start of this stub.
                            // handle() leaves on the stack, from SP upwards: {cb, info, size, data}
                            // frames followed by a NULL terminator.
                            KeystoneEncoded encoded = keystone.assemble(Arrays.asList(
                                    "sub sp, sp, #0x10",
                                    "stp x29, x30, [sp]",
                                    "svc #0x" + Integer.toHexString(svcNumber),

                                    // 0x0c: pop the next callback; NULL means every module was visited
                                    "ldr x7, [sp]",
                                    "add sp, sp, #0x8",
                                    "cmp x7, #0",
                                    "b.eq #0x58",
                                    // pop (info, size, data) into x0..x2 and invoke the callback
                                    "ldr x0, [sp]",
                                    "add sp, sp, #0x8",
                                    "ldr x1, [sp]",
                                    "add sp, sp, #0x8",
                                    "ldr x2, [sp]",
                                    "add sp, sp, #0x8",
                                    "blr x7",
                                    // a zero result continues with the next module
                                    "cmp w0, #0",
                                    "b.eq #0xc",

                                    // 0x40: non-zero result - discard the remaining frames, keeping x0
                                    "ldr x7, [sp]",
                                    "add sp, sp, #0x8",
                                    "cmp x7, #0",
                                    "b.eq #0x58",
                                    "add sp, sp, #0x18",
                                    "b 0x40",

                                    // 0x58: callback syscall with x4 = this SVC's number
                                    // (expected to route to handleCallback(), which frees the block)
                                    "mov x8, #0",
                                    "mov x4, #0x" + Integer.toHexString(svcNumber),
                                    "mov x16, #0x" + Integer.toHexString(Svc.CALLBACK_SYSCALL_NUMBER),
                                    "svc #0",

                                    "ldp x29, x30, [sp]",
                                    "add sp, sp, #0x10",
                                    "ret"));
                            byte[] code = encoded.getMachineCode();
                            UnidbgPointer pointer = svcMemory.allocate(code.length, "dl_iterate_phdr");
                            pointer.write(0, code, 0, code.length);
                            if (log.isDebugEnabled()) {
                                log.debug("dl_iterate_phdr: pointer=" + pointer);
                            }
                            return pointer;
                        }
                    }

                    /**
                     * Builds a dl_phdr_info for every loaded module that has an ELF file. It pushes one
                     * {cb, info, size, data} frame per module above a NULL terminator for the stub to consume.
                     */
                    @Override
                    public long handle(Emulator<?> emulator) {
                        if (block != null) {
                            // The previous iteration has not reached handleCallback(); nesting is unsupported.
                            throw new IllegalStateException();
                        }

                        RegisterContext context = emulator.getContext();
                        UnidbgPointer cb = context.getPointerArg(0);
                        UnidbgPointer data = context.getPointerArg(1);

                        Collection<Module> modules = emulator.getMemory().getLoadedModules();
                        List<LinuxModule> list = new ArrayList<>();
                        for (Module module : modules) {
                            LinuxModule lm = (LinuxModule) module;
                            if (lm.elfFile != null) {
                                list.add(lm);
                            }
                        }
                        // Push in reverse so that the stub pops frames in getLoadedModules() order.
                        Collections.reverse(list);
                        final int size = UnidbgStructure.calculateSize(dl_phdr_info.class);
                        block = emulator.getMemory().malloc(size * list.size(), true);
                        UnidbgPointer ptr = block.getPointer();
                        Backend backend = emulator.getBackend();
                        UnidbgPointer sp = UnidbgPointer.register(emulator, Arm64Const.UC_ARM64_REG_SP);
                        if (log.isDebugEnabled()) {
                            log.debug("dl_iterate_phdr cb=" + cb + ", data=" + data + ", size=" + list.size() + ", sp=" + sp);
                        }

                        boolean pushed = false;
                        try {
                            sp = sp.share(-8, 0);
                            sp.setLong(0, 0); // NULL-terminated

                            for (LinuxModule module : list) {
                                dl_phdr_info info = new dl_phdr_info(ptr);
                                info.dlpi_addr = UnidbgPointer.pointer(emulator, module.base);
                                assert info.dlpi_addr != null;
                                ElfDynamicStructure dynamicStructure = module.dynamicStructure;
                                if (dynamicStructure != null && dynamicStructure.soName > 0 && dynamicStructure.dt_strtab_offset > 0) {
                                    // DT_SONAME string inside the module's mapped string table
                                    info.dlpi_name = info.dlpi_addr.share(dynamicStructure.dt_strtab_offset + dynamicStructure.soName);
                                } else {
                                    info.dlpi_name = module.createPathMemory(svcMemory);
                                }
                                info.dlpi_phdr = info.dlpi_addr.share(module.elfFile.ph_offset);
                                info.dlpi_phnum = module.elfFile.num_ph;
                                info.pack();

                                // Pushed downwards, so the stub pops cb, info, size, data in that order.
                                sp = sp.share(-8, 0);
                                sp.setPointer(0, data); // data

                                sp = sp.share(-8, 0);
                                sp.setLong(0, size); // size

                                sp = sp.share(-8, 0);
                                sp.setPointer(0, ptr); // dl_phdr_info

                                sp = sp.share(-8, 0);
                                sp.setPointer(0, cb); // callback

                                ptr = ptr.share(size, 0);
                            }
                            pushed = true;

                            // When a frame exists the stub reloads x0 before the first callback. Only when
                            // no module was pushed is this value returned to the guest, and 0 is then correct.
                            return 0;
                        } finally {
                            backend.reg_write(Arm64Const.UC_ARM64_REG_SP, sp.peer);
                            if (!pushed) {
                                // handleCallback() will not run for a failed setup; release the block here
                                // so later calls are not rejected by the guard above.
                                block.free();
                                block = null;
                            }
                        }
                    }

                    @Override
                    public void handleCallback(Emulator<?> emulator) {
                        super.handleCallback(emulator);

                        if (block == null) {
                            throw new IllegalStateException();
                        }
                        block.free();
                        block = null;
                    }
                }).peer;
            case "dlerror":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public long handle(Emulator<?> emulator) {
                        return error.peer;
                    }
                }).peer;
            case "dlclose":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public long handle(Emulator<?> emulator) {
                        RegisterContext context = emulator.getContext();
                        long handle = context.getLongArg(0);
                        if (log.isDebugEnabled()) {
                            log.debug("dlclose handle=0x" + Long.toHexString(handle));
                        }
                        return dlclose(emulator.getMemory(), handle);
                    }
                }).peer;
            case "dlopen":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public UnidbgPointer onRegister(SvcMemory svcMemory, int svcNumber) {
                        // Pre-encoded A64 stub (14 instructions). After the SVC, the stack holds pending
                        // init-function addresses, a NULL terminator and the dlopen return value (see dlopen()).
                        ByteBuffer buffer = ByteBuffer.allocate(56);
                        buffer.order(ByteOrder.LITTLE_ENDIAN);
                        buffer.putInt(0xd10043ff); // 0x00: "sub sp, sp, #0x10"
                        buffer.putInt(0xa9007bfd); // 0x04: "stp x29, x30, [sp]"
                        buffer.putInt(Arm64Svc.assembleSvc(svcNumber)); // 0x08: "svc #0x" + Integer.toHexString(svcNumber)
                        buffer.putInt(0xf94003e7); // 0x0c: "ldr x7, [sp]", pop next init function
                        buffer.putInt(0x910023ff); // 0x10: "add sp, sp, #0x8", manipulated stack in dlopen
                        buffer.putInt(0xf10000ff); // 0x14: "cmp x7, #0"
                        buffer.putInt(0x54000060); // 0x18: "b.eq #0x24", NULL terminator reached
                        buffer.putInt(0x10ffff9e); // 0x1c: "adr lr, #-0x10", return to 0x0c (ldr x7, [sp])
                        buffer.putInt(0xd61f00e0); // 0x20: "br x7", call init function
                        buffer.putInt(0xf94003e0); // 0x24: "ldr x0, [sp]", pop the return value pushed by dlopen()
                        buffer.putInt(0x910023ff); // 0x28: "add sp, sp, #0x8"
                        buffer.putInt(0xa9407bfd); // 0x2c: "ldp x29, x30, [sp]"
                        buffer.putInt(0x910043ff); // 0x30: "add sp, sp, #0x10"
                        buffer.putInt(0xd65f03c0); // 0x34: "ret"
                        byte[] code = buffer.array();
                        UnidbgPointer pointer = svcMemory.allocate(code.length, "dlopen");
                        pointer.write(0, code, 0, code.length);
                        return pointer;
                    }

                    @Override
                    public long handle(Emulator<?> emulator) {
                        RegisterContext context = emulator.getContext();
                        Pointer filename = context.getPointerArg(0);
                        int flags = context.getIntArg(1);
                        String path = filename.getString(0);
                        if (log.isDebugEnabled()) {
                            log.debug("dlopen filename=" + path + ", flags=" + flags);
                        }
                        return dlopen(emulator.getMemory(), path, emulator);
                    }
                }).peer;
            case "dladdr":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public long handle(Emulator<?> emulator) {
                        RegisterContext context = emulator.getContext();
                        long addr = context.getLongArg(0);
                        Pointer info = context.getPointerArg(1);
                        if (log.isDebugEnabled()) {
                            log.debug("dladdr addr=0x" + Long.toHexString(addr) + ", info=" + info);
                        }
                        Module module = emulator.getMemory().findModuleByAddress(addr);
                        if (module == null) {
                            return 0;
                        }

                        Symbol symbol = module.findClosestSymbolByAddress(addr, true);

                        DlInfo dlInfo = new DlInfo(info);
                        dlInfo.dli_fname = module.createPathMemory(svcMemory);
                        dlInfo.dli_fbase = UnidbgPointer.pointer(emulator, module.base);
                        if (symbol != null) {
                            dlInfo.dli_sname = symbol.createNameMemory(svcMemory);
                            dlInfo.dli_saddr = UnidbgPointer.pointer(emulator, symbol.getAddress());
                        }
                        dlInfo.pack();
                        return 1;
                    }
                }).peer;
            case "dlsym":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public long handle(Emulator<?> emulator) {
                        RegisterContext context = emulator.getContext();
                        long handle = context.getLongArg(0);
                        Pointer symbol = context.getPointerArg(1);
                        String symbolName = symbol.getString(0);
                        if (log.isDebugEnabled()) {
                            log.debug("dlsym handle=0x" + Long.toHexString(handle) + ", symbol=" + symbolName);
                        }
                        return dlsym(emulator, handle, symbolName);
                    }
                }).peer;
            case "dl_unwind_find_exidx":
                return svcMemory.registerSvc(new Arm64Svc() {
                    @Override
                    public long handle(Emulator<?> emulator) {
                        // Not implemented: logs the arguments and reports "not found".
                        RegisterContext context = emulator.getContext();
                        Pointer pc = context.getPointerArg(0);
                        Pointer pcount = context.getPointerArg(1);
                        log.info("dl_unwind_find_exidx pc=" + pc + ", pcount=" + pcount);
                        return 0;
                    }
                }).peer;
        }
        return 0;
    }

    /**
     * Loads {@code filename} and prepares the emulated stack for the {@code dlopen} stub.
     * <p>
     * Starting at the current SP, 8-byte slots are pushed downwards in this order:
     * <ul>
     * <li>the value the stub will return in x0 (module base, or 0 on failure);</li>
     * <li>a NULL terminator;</li>
     * <li>on success, the address of each pending init function of every loaded module with no
     * unresolved symbols.</li>
     * </ul>
     * The adjusted SP is written back in all cases.
     *
     * @return the module base, or 0 if loading failed (the dlerror message is then set)
     */
    private long dlopen(Memory memory, String filename, Emulator<?> emulator) {
        Pointer pointer = UnidbgPointer.register(emulator, Arm64Const.UC_ARM64_REG_SP);
        try {
            Module module = memory.dlopen(filename, false);
            pointer = pointer.share(-8); // return value
            if (module == null) {
                pointer.setLong(0, 0);

                pointer = pointer.share(-8); // NULL-terminated
                pointer.setLong(0, 0);

                if (!"libnetd_client.so".equals(filename)) {
                    log.info("dlopen failed: " + filename);
                } else if(log.isDebugEnabled()) {
                    log.debug("dlopen failed: " + filename);
                }
                this.error.setString(0, "Resolve library " + filename + " failed");
                return 0;
            } else {
                pointer.setLong(0, module.base);

                pointer = pointer.share(-8); // NULL-terminated
                pointer.setLong(0, 0);

                for (Module md : memory.getLoadedModules()) {
                    LinuxModule m = (LinuxModule) md;
                    if (!m.getUnresolvedSymbol().isEmpty()) {
                        // Skipped modules keep their init functions queued for a later dlopen.
                        continue;
                    }
                    for (InitFunction initFunction : m.initFunctionList) {
                        long address = initFunction.getAddress();
                        if (address == 0) {
                            continue;
                        }
                        if (log.isDebugEnabled()) {
                            log.debug("[" + m.name + "]PushInitFunction: 0x" + Long.toHexString(address));
                        }
                        pointer = pointer.share(-8); // init array
                        pointer.setLong(0, address);
                    }
                    m.initFunctionList.clear();
                }

                return module.base;
            }
        } finally {
            backend.reg_write(Arm64Const.UC_ARM64_REG_SP, ((UnidbgPointer) pointer).peer);
        }
    }

    /**
     * Unloads the library identified by {@code handle}.
     *
     * @return 0 on success, -1 on failure (the dlerror message is then set)
     */
    private int dlclose(Memory memory, long handle) {
        if (memory.dlclose(handle)) {
            return 0;
        } else {
            this.error.setString(0, "dlclose 0x" + Long.toHexString(handle) + " failed");
            return -1;
        }
    }

}
