/*
 * Decompiled with CFR 0.152.
 */
package ai.rapids.cudf;

import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.HostMemoryBuffer;
import ai.rapids.cudf.MemoryBuffer;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A process-wide pool of pinned host memory that hands out 8-byte-aligned sections.
 *
 * <p>The backing memory is one region obtained from {@code Cuda.hostAllocPinned} when the pool is
 * constructed. {@link #initialize(long, int)} constructs the pool on a background daemon thread;
 * the first caller that needs the pool afterwards waits for that construction to finish. Requests
 * are carved from the largest free section, and freed sections are merged back with adjacent free
 * sections. {@link #allocate(long)} falls back to {@code HostMemoryBuffer.allocate} when the pool
 * cannot satisfy a request.
 *
 * <p>Per-pool state is guarded by the pool instance's monitor; the singleton bookkeeping is
 * guarded by the {@code PinnedMemoryPool.class} monitor.
 */
public final class PinnedMemoryPool implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(PinnedMemoryPool.class);
    /** Every allocation is rounded up to a multiple of this many bytes. */
    private static final long ALIGNMENT = 8L;
    private static volatile PinnedMemoryPool singleton_ = null;
    // volatile: read without the class lock on the fast path of getSingleton().
    private static volatile Future<PinnedMemoryPool> initFuture = null;
    private final long pinnedPoolBase;
    // Free sections, largest first (ties broken by lowest address).
    private final PriorityQueue<MemorySection> freeHeap = new PriorityQueue<>(new SortedBySize());
    private int numAllocatedSections = 0;
    private long availableBytes;
    // Set once the pinned region has been released; no further allocations are served.
    private boolean closed = false;

    /**
     * Returns the pool, waiting for a pending {@link #initialize(long, int)} to finish if needed.
     *
     * @return the pool, or {@code null} if it was never initialized or has been shut down
     * @throws RuntimeException if constructing the pool on the init thread failed or the wait was
     *     interrupted
     */
    private static PinnedMemoryPool getSingleton() {
        PinnedMemoryPool pool = singleton_;
        if (pool != null) {
            return pool;
        }
        if (initFuture == null) {
            return null;
        }
        synchronized (PinnedMemoryPool.class) {
            if (singleton_ == null) {
                // Re-read under the lock: shutdown() may have cleared it since the check above.
                Future<PinnedMemoryPool> pending = initFuture;
                if (pending == null) {
                    return null;
                }
                try {
                    singleton_ = pending.get();
                } catch (InterruptedException e) {
                    // Preserve the interrupt for callers further up the stack.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Error initializing pinned memory pool", e);
                } catch (Exception e) {
                    throw new RuntimeException("Error initializing pinned memory pool", e);
                }
                initFuture = null;
            }
            return singleton_;
        }
    }

    /**
     * Starts initializing the pool without selecting a device; same as
     * {@code initialize(poolSize, -1)}.
     *
     * @param poolSize number of bytes of pinned memory to allocate for the pool
     * @throws IllegalStateException if the pool is already initialized
     */
    public static synchronized void initialize(long poolSize) {
        initialize(poolSize, -1);
    }

    /**
     * Starts initializing the pool on a background daemon thread named "pinned pool init".
     *
     * @param poolSize number of bytes of pinned memory to allocate for the pool
     * @param gpuId device to select before allocating, or a negative value to skip selection
     * @throws IllegalStateException if the pool is already initialized (a pending initialization
     *     is waited for before this check completes)
     */
    public static synchronized void initialize(long poolSize, int gpuId) {
        if (isInitialized()) {
            throw new IllegalStateException("Can only initialize the pool once.");
        }
        ExecutorService initService = Executors.newSingleThreadExecutor(runnable -> {
            Thread t = new Thread(runnable, "pinned pool init");
            t.setDaemon(true);
            return t;
        });
        initFuture = initService.submit(() -> new PinnedMemoryPool(poolSize, gpuId));
        initService.shutdown();
    }

    /**
     * Reports whether a pool exists, waiting for any pending initialization to finish first.
     *
     * @return {@code true} if the pool has been (or just finished being) initialized
     */
    public static boolean isInitialized() {
        return getSingleton() != null;
    }

    /**
     * Releases the pool's pinned memory and forgets the singleton so that
     * {@link #initialize(long, int)} may be called again.
     */
    public static synchronized void shutdown() {
        PinnedMemoryPool pool = getSingleton();
        if (pool != null) {
            pool.close();
        }
        initFuture = null;
        singleton_ = null;
    }

    /**
     * Tries to allocate a buffer from the pool.
     *
     * @param bytes requested length in bytes; must not be negative
     * @return a pooled buffer, or {@code null} if there is no pool or no large enough free section
     * @throws IllegalArgumentException if {@code bytes} is negative
     */
    public static HostMemoryBuffer tryAllocate(long bytes) {
        HostMemoryBuffer result = null;
        PinnedMemoryPool pool = getSingleton();
        if (pool != null) {
            result = pool.tryAllocateInternal(bytes);
        }
        return result;
    }

    /**
     * Allocates from the pool, falling back to {@code HostMemoryBuffer.allocate(bytes, false)}
     * when the pool cannot satisfy the request.
     *
     * @param bytes requested length in bytes
     * @return a buffer of {@code bytes} length
     */
    public static HostMemoryBuffer allocate(long bytes) {
        HostMemoryBuffer result = tryAllocate(bytes);
        if (result == null) {
            result = HostMemoryBuffer.allocate(bytes, false);
        }
        return result;
    }

    /**
     * Returns the number of bytes currently free in the pool (not necessarily contiguous).
     *
     * @return free bytes, or 0 if there is no pool
     */
    public static long getAvailableBytes() {
        PinnedMemoryPool pool = getSingleton();
        if (pool != null) {
            return pool.getAvailableBytesInternal();
        }
        return 0L;
    }

    private PinnedMemoryPool(long poolSize, int gpuId) {
        if (gpuId > -1) {
            Cuda.setDevice(gpuId);
            Cuda.freeZero();
        }
        this.pinnedPoolBase = Cuda.hostAllocPinned(poolSize);
        this.freeHeap.add(new MemorySection(this.pinnedPoolBase, poolSize));
        this.availableBytes = poolSize;
    }

    /**
     * Releases the pinned region. Subsequent calls are no-ops, and later allocation attempts on
     * this instance return {@code null}.
     */
    @Override
    public synchronized void close() {
        if (this.closed) {
            return;
        }
        assert this.numAllocatedSections == 0
            : this.numAllocatedSections + " pinned sections still outstanding at close";
        this.closed = true;
        Cuda.freePinned(this.pinnedPoolBase);
    }

    /**
     * Carves an aligned section out of the largest free section.
     *
     * @param bytes requested length; the section reserved is rounded up to {@link #ALIGNMENT}
     * @return a buffer of exactly {@code bytes} length, or {@code null} if the pool is closed or
     *     no free section is large enough
     */
    private synchronized HostMemoryBuffer tryAllocateInternal(long bytes) {
        if (bytes < 0) {
            throw new IllegalArgumentException("Cannot allocate a negative number of bytes: " + bytes);
        }
        if (this.closed) {
            log.debug("Pinned memory pool is closed");
            return null;
        }
        if (this.freeHeap.isEmpty()) {
            log.debug("No free pinned memory left");
            return null;
        }
        MemorySection largest = this.freeHeap.peek();
        // Checking the unaligned size first also keeps the rounding below from overflowing when
        // bytes is close to Long.MAX_VALUE.
        if (bytes > largest.size) {
            log.debug("Insufficient pinned memory. {} needed, {} found", bytes, largest.size);
            return null;
        }
        long alignedBytes = (bytes + ALIGNMENT - 1L) / ALIGNMENT * ALIGNMENT;
        if (largest.size < alignedBytes) {
            log.debug("Insufficient pinned memory. {} needed, {} found", alignedBytes, largest.size);
            return null;
        }
        log.debug("Allocating {}/{} bytes pinned from {} FREE COUNT {} OUTSTANDING COUNT {}",
            bytes, alignedBytes, largest, this.freeHeap.size(), this.numAllocatedSections);
        // Remove before mutating: the heap ordering depends on size and address.
        this.freeHeap.remove(largest);
        MemorySection allocated;
        if (largest.size == alignedBytes) {
            allocated = largest;
        } else {
            allocated = largest.splitOff(alignedBytes);
            this.freeHeap.add(largest);
        }
        ++this.numAllocatedSections;
        this.availableBytes -= allocated.size;
        log.debug("Allocated {} free {} outstanding {}",
            allocated, this.freeHeap, this.numAllocatedSections);
        return new HostMemoryBuffer(allocated.baseAddress, bytes,
            new PinnedHostBufferCleaner(this, allocated, bytes));
    }

    /**
     * Returns a section to the free heap, merging it with every adjacent free section.
     *
     * @param section a section previously handed out by this pool
     */
    private synchronized void free(MemorySection section) {
        log.debug("Freeing {} with {} outstanding {}", section, this.freeHeap, this.numAllocatedSections);
        this.availableBytes += section.size;
        boolean anyReplaced;
        do {
            anyReplaced = false;
            Iterator<MemorySection> it = this.freeHeap.iterator();
            while (it.hasNext()) {
                MemorySection current = it.next();
                if (!section.canCombine(current)) {
                    continue;
                }
                // section is not in the heap, so mutating it here cannot break heap ordering.
                it.remove();
                anyReplaced = true;
                section.combineWith(current);
            }
        } while (anyReplaced);
        this.freeHeap.add(section);
        --this.numAllocatedSections;
        log.debug("After freeing {} outstanding {}", this.freeHeap, this.numAllocatedSections);
    }

    private synchronized long getAvailableBytesInternal() {
        return this.availableBytes;
    }

    /**
     * Returns a pooled buffer's section to the pool that allocated it (not whatever pool happens
     * to be the singleton at cleanup time).
     */
    private static final class PinnedHostBufferCleaner extends MemoryBuffer.MemoryBufferCleaner {
        private final PinnedMemoryPool pool;
        private MemorySection section;
        private final long origLength;

        PinnedHostBufferCleaner(PinnedMemoryPool pool, MemorySection section, long length) {
            this.pool = Objects.requireNonNull(pool, "pool");
            this.section = section;
            this.origLength = length;
        }

        /**
         * Frees the section if it has not been freed yet.
         *
         * @param logErrorIfNotClean if true and cleanup was needed, the buffer is reported as leaked
         * @return {@code true} if a section was freed by this call
         */
        @Override
        protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
            boolean neededCleanup = false;
            long origAddress = 0L;
            if (this.section != null) {
                origAddress = this.section.baseAddress;
                try {
                    this.pool.free(this.section);
                } finally {
                    // Never free the same section twice, even if free() threw.
                    this.section = null;
                }
                neededCleanup = true;
            }
            if (neededCleanup && logErrorIfNotClean) {
                log.error("A PINNED HOST BUFFER WAS LEAKED (ID: " + this.id + " " + Long.toHexString(origAddress) + ")");
                this.logRefCountDebug("Leaked pinned host buffer");
            }
            return neededCleanup;
        }

        @Override
        public boolean isClean() {
            return this.section == null;
        }
    }

    /** A contiguous address range [baseAddress, baseAddress + size) inside the pinned region. */
    private static class MemorySection {
        private long baseAddress;
        private long size;

        MemorySection(long baseAddress, long size) {
            this.baseAddress = baseAddress;
            this.size = size;
        }

        /** True when {@code other} ends where this starts, or starts where this ends. */
        boolean canCombine(MemorySection other) {
            boolean ret = other.baseAddress + other.size == this.baseAddress
                || this.baseAddress + this.size == other.baseAddress;
            log.trace("CAN {} COMBINE WITH {} ? {}", this, other, ret);
            return ret;
        }

        /** Grows this section in place to cover the adjacent {@code other} section as well. */
        void combineWith(MemorySection other) {
            assert this.canCombine(other);
            log.trace("COMBINING {} AND {}", this, other);
            this.baseAddress = Math.min(this.baseAddress, other.baseAddress);
            this.size = other.size + this.size;
            log.trace("COMBINED TO {}\n", this);
        }

        /**
         * Splits the first {@code newSize} bytes off into a new section; this section keeps the
         * remainder.
         */
        MemorySection splitOff(long newSize) {
            assert this.size > newSize;
            MemorySection ret = new MemorySection(this.baseAddress, newSize);
            this.baseAddress += newSize;
            this.size -= newSize;
            return ret;
        }

        @Override
        public String toString() {
            return "PINNED: " + this.size + " bytes (0x" + Long.toHexString(this.baseAddress) + " to 0x" + Long.toHexString(this.baseAddress + this.size) + ")";
        }
    }

    /** Orders sections by descending size, then ascending base address. */
    private static class SortedBySize implements Comparator<MemorySection> {
        private SortedBySize() {
        }

        @Override
        public int compare(MemorySection s0, MemorySection s1) {
            int ret = Long.compare(s1.size, s0.size);
            if (ret == 0) {
                ret = Long.compare(s0.baseAddress, s1.baseAddress);
            }
            return ret;
        }
    }
}

