/*
 * Decompiled with CFR 0.152.
 */
package ai.rapids.cudf;

import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.CuFileBuffer;
import ai.rapids.cudf.CuFileDriver;
import ai.rapids.cudf.CuFileHandle;
import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.MemoryBuffer;
import ai.rapids.cudf.nvcomp.BatchedLZ4Decompressor;
import ai.rapids.cudf.nvcomp.Decompressor;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs {@link Cleaner}s for registered objects once those objects have been garbage collected.
 *
 * <p>Each registered object is wrapped in a weak reference tied to a shared reference queue. A
 * daemon thread ("Cleaner Thread") takes collected references off that queue, makes the default
 * GPU current (if one was set), and calls the cleaner with not-clean logging requested. Whenever
 * that call returns {@code true}, {@link #leakCount} is incremented.
 *
 * <p>When the system property {@code ai.rapids.refcount.debug} is {@code true}, each cleaner
 * records a stack trace for every {@link Cleaner#addRef()} / {@link Cleaner#delRef()} call. A
 * shutdown hook is also installed that runs every still-registered cleaner before the JVM exits.
 */
public final class MemoryCleaner {
    /** Enables ref-count event recording and the shutdown sweep. */
    private static final boolean REF_COUNT_DEBUG = Boolean.getBoolean("ai.rapids.refcount.debug");
    private static final Logger log = LoggerFactory.getLogger(MemoryCleaner.class);
    /** Source of unique {@link Cleaner#id} values. */
    private static final AtomicLong idGen = new AtomicLong(0L);
    /** Number of weak-reference cleanups whose cleaner returned {@code true}. */
    static final AtomicLong leakCount = new AtomicLong();
    /**
     * Strongly holds every outstanding weak reference. A {@code Reference} object must itself stay
     * reachable to be enqueued, and this set is also what the RMM-blocker check and the shutdown
     * sweep iterate over.
     */
    private static final Set<CleanerWeakReference<?>> all =
            Collections.newSetFromMap(new ConcurrentHashMap<>());
    /** Queue the JVM appends registered references to once their referents are collected. */
    private static final ReferenceQueue<Object> collected = new ReferenceQueue<>();
    /** GPU to make current before running cleaners; the initial -1 means none has been set. */
    private static volatile int defaultGpu = -1;
    private static final Thread cleanerThread =
            new Thread(MemoryCleaner::drainCollectedReferences, "Cleaner Thread");

    /**
     * Body of the cleaner thread: runs the cleaner of each collected reference until the thread
     * is interrupted.
     */
    private static void drainCollectedReferences() {
        int currentGpuId = -1;
        try {
            while (true) {
                // Blocks until some registered object has been collected; interrupt ends the loop.
                CleanerWeakReference<?> next = (CleanerWeakReference<?>) collected.remove();
                // Read the volatile once so the device we set, remember, and log always agree.
                int gpu = defaultGpu;
                try {
                    if (currentGpuId != gpu) {
                        Cuda.setDevice(gpu);
                        currentGpuId = gpu;
                    }
                } catch (Throwable e) {
                    log.error("ERROR TRYING TO SET GPU ID TO " + gpu, e);
                }
                try {
                    next.clean();
                } catch (Throwable e) {
                    log.error("CAUGHT EXCEPTION WHILE TRYING TO CLEAN " + next, e);
                }
                all.remove(next);
            }
        } catch (InterruptedException e) {
            // Interruption is the stop signal; keep the flag set and let the thread exit.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Sets the GPU that the cleaner thread (and the debug shutdown sweep) make current before
     * running cleaners.
     *
     * @param defaultGpuId device id passed to {@code Cuda.setDevice}
     */
    static void setDefaultGpu(int defaultGpuId) {
        defaultGpu = defaultGpuId;
    }

    /**
     * Starts tracking {@code referent}: once it is collected, {@code cleaner} is run by the
     * cleaner thread.
     *
     * @param referent the object whose collection triggers cleanup
     * @param cleaner the cleaner to run
     * @param isRmmBlocker whether an unclean cleaner for this object counts for
     *     {@link #bestEffortHasRmmBlockers()}
     */
    private static <T> void track(T referent, Cleaner cleaner, boolean isRmmBlocker) {
        all.add(new CleanerWeakReference<>(referent, cleaner, collected, isRmmBlocker));
    }

    static void register(ColumnVector vec, Cleaner cleaner) {
        track(vec, cleaner, true);
    }

    static void register(HostColumnVectorCore vec, Cleaner cleaner) {
        track(vec, cleaner, false);
    }

    static void register(MemoryBuffer buf, Cleaner cleaner) {
        // Only device-side buffers are flagged as RMM blockers.
        track(buf, cleaner, buf instanceof BaseDeviceMemoryBuffer);
    }

    static void register(Cuda.Stream stream, Cleaner cleaner) {
        track(stream, cleaner, false);
    }

    static void register(Cuda.Event event, Cleaner cleaner) {
        track(event, cleaner, false);
    }

    public static void register(Decompressor.Metadata metadata, Cleaner cleaner) {
        track(metadata, cleaner, false);
    }

    public static void register(BatchedLZ4Decompressor.BatchedMetadata metadata, Cleaner cleaner) {
        track(metadata, cleaner, false);
    }

    static void register(CuFileDriver driver, Cleaner cleaner) {
        track(driver, cleaner, false);
    }

    static void register(CuFileBuffer buffer, Cleaner cleaner) {
        track(buffer, cleaner, false);
    }

    static void register(CuFileHandle handle, Cleaner cleaner) {
        track(handle, cleaner, false);
    }

    /**
     * Reports whether any tracked RMM-blocker object still has a cleaner that is not clean.
     * Iteration over the tracking set is weakly consistent with concurrent registration and
     * cleanup, so the answer is only a snapshot.
     *
     * @return {@code true} if at least one such cleaner was observed
     */
    static boolean bestEffortHasRmmBlockers() {
        return all.stream().anyMatch(cwr -> cwr.isRmmBlocker && !cwr.cleaner.isClean());
    }

    /** Joins the string forms of {@code it}'s elements with {@code delim}. */
    private static <T> String stringJoin(String delim, Iterable<T> it) {
        return StreamSupport.stream(it.spliterator(), false)
                .map(String::valueOf)
                .collect(Collectors.joining(delim));
    }

    /**
     * Debug-only shutdown hook. It requests a GC, stops the cleaner thread (waiting up to one
     * second), then runs every cleaner that is still registered. A failure on one cleaner does
     * not prevent the rest from being checked.
     */
    private static void sweepAtShutdown() {
        System.gc();
        cleanerThread.interrupt();
        try {
            cleanerThread.join(1000L);
        } catch (InterruptedException e) {
            // Keep sweeping, but do not lose the interrupt status.
            Thread.currentThread().interrupt();
        }
        int gpu = defaultGpu;
        if (gpu >= 0) {
            try {
                Cuda.setDevice(gpu);
            } catch (Throwable e) {
                log.error("ERROR TRYING TO SET GPU ID TO " + gpu, e);
            }
        }
        for (CleanerWeakReference<?> cwr : all) {
            try {
                cwr.clean();
            } catch (Throwable e) {
                log.error("CAUGHT EXCEPTION WHILE TRYING TO CLEAN " + cwr, e);
            }
        }
    }

    static {
        cleanerThread.setDaemon(true);
        cleanerThread.start();
        if (REF_COUNT_DEBUG) {
            Runtime.getRuntime().addShutdownHook(new Thread(MemoryCleaner::sweepAtShutdown));
        }
    }

    /** One recorded reference-count event: operation name, capture time and stack trace. */
    private static final class RefCountDebugItem {
        /** Cached, thread-safe formatter at millisecond precision, matching {@link #timeMs}. */
        private static final DateTimeFormatter TIME_FORMAT =
                DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS z")
                        .withZone(ZoneId.systemDefault());

        final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
        final long timeMs = System.currentTimeMillis();
        final String op;

        public RefCountDebugItem(String op) {
            this.op = op;
        }

        @Override
        public String toString() {
            return TIME_FORMAT.format(Instant.ofEpochMilli(timeMs)) + ": " + op + "\n"
                    + stringJoin("\n", Arrays.asList(stackTrace)) + "\n";
        }
    }

    /** Weak reference that also carries the cleaner to run once its referent is collected. */
    private static final class CleanerWeakReference<T> extends WeakReference<T> {
        private final Cleaner cleaner;
        final boolean isRmmBlocker;

        CleanerWeakReference(T orig, Cleaner cleaner, ReferenceQueue<? super T> collected,
                boolean isRmmBlocker) {
            super(orig, collected);
            this.cleaner = cleaner;
            this.isRmmBlocker = isRmmBlocker;
        }

        /**
         * Runs the cleaner with not-clean logging requested, and counts a leak when it returns
         * {@code true}.
         */
        public void clean() {
            if (cleaner.clean(true)) {
                leakCount.incrementAndGet();
            }
        }
    }

    /**
     * Base class for the cleanup logic attached to a registered object. When ref-count debugging
     * is enabled, it also keeps a history of {@link #addRef()} / {@link #delRef()} events.
     */
    public static abstract class Cleaner {
        /** INC/DEC history; {@code null} unless ref-count debugging is enabled. Guarded by this. */
        private final List<RefCountDebugItem> refCountDebug;
        public final long id = idGen.incrementAndGet();
        /** Volatile: set by user threads, read on the cleaner thread via {@link #clean}. */
        private volatile boolean leakExpected = false;

        public Cleaner() {
            this.refCountDebug = REF_COUNT_DEBUG ? new LinkedList<RefCountDebugItem>() : null;
        }

        /** Records an "INC" event with the current stack trace (debug mode only; else a no-op). */
        public final void addRef() {
            if (REF_COUNT_DEBUG && refCountDebug != null) {
                synchronized (this) {
                    refCountDebug.add(new RefCountDebugItem("INC"));
                }
            }
        }

        /** Records a "DEC" event with the current stack trace (debug mode only; else a no-op). */
        public final void delRef() {
            if (REF_COUNT_DEBUG && refCountDebug != null) {
                synchronized (this) {
                    refCountDebug.add(new RefCountDebugItem("DEC"));
                }
            }
        }

        /**
         * Logs {@code message}, this cleaner's id and every recorded INC/DEC event at error level.
         * Does nothing unless ref-count debugging is enabled.
         *
         * @param message prefix for the log entry
         */
        public final void logRefCountDebug(String message) {
            if (REF_COUNT_DEBUG && refCountDebug != null) {
                synchronized (this) {
                    log.error("{} (ID: {}): {}", message, id, stringJoin("\n", refCountDebug));
                }
            }
        }

        /**
         * Runs {@link #cleanImpl(boolean)}. The not-clean logging request is dropped if
         * {@link #noWarnLeakExpected()} was called.
         *
         * @param logErrorIfNotClean whether the implementation should log when it was not clean
         * @return the result of {@link #cleanImpl(boolean)}; the weak-reference path counts
         *     {@code true} as a leak
         */
        public final boolean clean(boolean logErrorIfNotClean) {
            return cleanImpl(logErrorIfNotClean && !leakExpected);
        }

        /** @return whether {@link #noWarnLeakExpected()} has been called */
        public final boolean isLeakExpected() {
            return leakExpected;
        }

        /**
         * Implementation-specific cleanup.
         *
         * @param logErrorIfNotClean whether to log if the resource had not been cleaned already
         * @return {@code true} to report the cleanup as a leak when called from the cleaner thread
         */
        protected abstract boolean cleanImpl(boolean logErrorIfNotClean);

        /** Suppresses not-clean logging for future {@link #clean(boolean)} calls. */
        public void noWarnLeakExpected() {
            leakExpected = true;
        }

        /** @return whether the underlying resource has already been released */
        public abstract boolean isClean();
    }
}

