/*
 * Decompiled with CFR 0.152.
 */
package ai.rapids.cudf.nvcomp;

import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.MemoryCleaner;
import ai.rapids.cudf.nvcomp.Decompressor;
import ai.rapids.cudf.nvcomp.NvcompJni;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static entry points for decompressing batches of LZ4 buffers on the device through the
 * {@code NvcompJni} native bindings.
 *
 * <p>Two usage styles are offered:
 * <ul>
 *   <li>a step-by-step API ({@link #getMetadata}, {@link #getTempSize}, {@link #getOutputSizes}
 *       and the five-argument {@code decompressAsync}) where the caller owns every buffer, and</li>
 *   <li>a convenience {@link #decompressAsync(BaseDeviceMemoryBuffer[], Cuda.Stream)} that
 *       allocates the output and temporary buffers itself.</li>
 * </ul>
 */
public class BatchedLZ4Decompressor {
    // Named after this class so that leak reports are attributed to the batched LZ4 decompressor.
    private static final Logger log = LoggerFactory.getLogger(BatchedLZ4Decompressor.class);

    /**
     * Builds batched decompression metadata for a set of compressed input buffers.
     *
     * @param inputs buffers whose addresses and lengths are handed to the native call
     * @param stream stream whose native handle is handed to the native call
     * @return a wrapper around the native metadata handle; the caller must close it
     */
    public static BatchedMetadata getMetadata(BaseDeviceMemoryBuffer[] inputs, Cuda.Stream stream) {
        long[] inputAddrs = new long[inputs.length];
        long[] inputSizes = new long[inputs.length];
        collectAddressesAndLengths(inputs, inputAddrs, inputSizes);
        long handle = NvcompJni.batchedLZ4DecompressGetMetadata(inputAddrs, inputSizes, stream.getStream());
        return new BatchedMetadata(handle);
    }

    /**
     * Queries the temporary storage size reported by the native library for a batch.
     *
     * @param metadata metadata previously obtained from {@link #getMetadata}
     * @return the value returned by {@code NvcompJni.batchedLZ4DecompressGetTempSize}
     *         (presumably a size in bytes; confirm against the native API)
     */
    public static long getTempSize(BatchedMetadata metadata) {
        return NvcompJni.batchedLZ4DecompressGetTempSize(metadata.getMetadata());
    }

    /**
     * Queries the output sizes reported by the native library for a batch.
     *
     * @param metadata   metadata previously obtained from {@link #getMetadata}
     * @param numOutputs number of outputs passed through to the native call
     * @return the array returned by {@code NvcompJni.batchedLZ4DecompressGetOutputSize}
     */
    public static long[] getOutputSizes(BatchedMetadata metadata, int numOutputs) {
        return NvcompJni.batchedLZ4DecompressGetOutputSize(metadata.getMetadata(), numOutputs);
    }

    /**
     * Launches a batched decompression into caller-provided output buffers.
     *
     * @param inputs     compressed input buffers
     * @param tempBuffer temporary storage handed to the native call
     * @param metadata   metadata previously obtained from {@link #getMetadata}
     * @param outputs    destination buffers, one per input
     * @param stream     stream whose native handle is handed to the native call
     * @throws IllegalArgumentException if {@code inputs} and {@code outputs} differ in length
     */
    public static void decompressAsync(BaseDeviceMemoryBuffer[] inputs, BaseDeviceMemoryBuffer tempBuffer, BatchedMetadata metadata, BaseDeviceMemoryBuffer[] outputs, Cuda.Stream stream) {
        int numBuffers = inputs.length;
        if (outputs.length != numBuffers) {
            throw new IllegalArgumentException("buffer count mismatch, " + numBuffers + " inputs and " + outputs.length + " outputs");
        }
        long[] inputAddrs = new long[numBuffers];
        long[] inputSizes = new long[numBuffers];
        collectAddressesAndLengths(inputs, inputAddrs, inputSizes);
        long[] outputAddrs = new long[numBuffers];
        long[] outputSizes = new long[numBuffers];
        collectAddressesAndLengths(outputs, outputAddrs, outputSizes);
        NvcompJni.batchedLZ4DecompressAsync(inputAddrs, inputSizes, tempBuffer.getAddress(), tempBuffer.getLength(), metadata.getMetadata(), outputAddrs, outputSizes, stream.getStream());
    }

    /**
     * Launches a batched decompression, allocating the output and temporary buffers internally.
     *
     * <p>The native metadata created for the call is always destroyed before this method
     * returns. If anything fails after output buffers have been allocated, every allocated
     * output buffer is closed and the original failure is rethrown; failures raised while
     * closing are attached to it as suppressed exceptions.
     *
     * @param inputs compressed input buffers
     * @param stream stream whose native handle is handed to the native calls
     * @return newly allocated output buffers, one per input, owned by the caller
     */
    public static DeviceMemoryBuffer[] decompressAsync(BaseDeviceMemoryBuffer[] inputs, Cuda.Stream stream) {
        int numBuffers = inputs.length;
        long[] inputAddrs = new long[numBuffers];
        long[] inputSizes = new long[numBuffers];
        collectAddressesAndLengths(inputs, inputAddrs, inputSizes);
        long metadata = NvcompJni.batchedLZ4DecompressGetMetadata(inputAddrs, inputSizes, stream.getStream());
        try {
            long[] outputSizes = NvcompJni.batchedLZ4DecompressGetOutputSize(metadata, numBuffers);
            long[] outputAddrs = new long[numBuffers];
            DeviceMemoryBuffer[] outputs = new DeviceMemoryBuffer[numBuffers];
            try {
                for (int i = 0; i < numBuffers; ++i) {
                    // Store the buffer before touching it so the cleanup path can see it.
                    outputs[i] = DeviceMemoryBuffer.allocate(outputSizes[i]);
                    outputAddrs[i] = outputs[i].getAddress();
                }
                long tempSize = NvcompJni.batchedLZ4DecompressGetTempSize(metadata);
                // NOTE(review): the temporary buffer is closed as soon as the launch call returns.
                // Presumably the allocator or native layer keeps that safe for work still queued
                // on the stream -- confirm.
                try (DeviceMemoryBuffer tempBuffer = DeviceMemoryBuffer.allocate(tempSize)) {
                    NvcompJni.batchedLZ4DecompressAsync(inputAddrs, inputSizes, tempBuffer.getAddress(), tempBuffer.getLength(), metadata, outputAddrs, outputSizes, stream.getStream());
                }
            } catch (Throwable t) {
                closeAllSuppressing(outputs, t);
                throw t;
            }
            return outputs;
        } finally {
            NvcompJni.batchedLZ4DecompressDestroyMetadata(metadata);
        }
    }

    /**
     * Copies the address and length of each buffer into the matching slot of the given arrays.
     */
    private static void collectAddressesAndLengths(BaseDeviceMemoryBuffer[] buffers, long[] addrs, long[] sizes) {
        for (int i = 0; i < buffers.length; ++i) {
            BaseDeviceMemoryBuffer buffer = buffers[i];
            addrs[i] = buffer.getAddress();
            sizes[i] = buffer.getLength();
        }
    }

    /**
     * Closes every non-null buffer, recording close failures on {@code primary} instead of
     * letting them replace it or stop the remaining buffers from being closed.
     */
    private static void closeAllSuppressing(DeviceMemoryBuffer[] buffers, Throwable primary) {
        for (DeviceMemoryBuffer buffer : buffers) {
            if (buffer != null) {
                try {
                    buffer.close();
                } catch (Throwable closeFailure) {
                    primary.addSuppressed(closeFailure);
                }
            }
        }
    }

    /**
     * Owner of a native batched-LZ4 metadata handle.
     *
     * <p>The handle is tracked by a cleaner registered with {@link MemoryCleaner}; calling
     * {@link #close()} destroys the native metadata. Closing more than once is an error.
     */
    public static class BatchedMetadata
    implements AutoCloseable {
        private final BatchedMetadataCleaner cleaner;
        private final long id;
        private boolean closed = false;

        /**
         * Wraps a native metadata handle, registers its cleaner and takes a reference on it.
         *
         * @param metadata native metadata handle to take ownership of
         */
        BatchedMetadata(long metadata) {
            this.cleaner = new BatchedMetadataCleaner(metadata);
            this.id = this.cleaner.id;
            MemoryCleaner.register(this, (MemoryCleaner.Cleaner)this.cleaner);
            this.cleaner.addRef();
        }

        /**
         * @return the native metadata handle currently held by the cleaner ({@code 0} once it
         *         has been destroyed)
         */
        long getMetadata() {
            return this.cleaner.metadata;
        }

        /**
         * @return the result of {@code NvcompJni.isLZ4Metadata} for the held handle
         */
        public boolean isLZ4Metadata() {
            return NvcompJni.isLZ4Metadata(this.getMetadata());
        }

        /**
         * Drops this object's reference and destroys the native metadata.
         *
         * @throws IllegalStateException if this object has already been closed
         */
        @Override
        public synchronized void close() {
            if (this.closed) {
                this.cleaner.logRefCountDebug("double free " + this);
                throw new IllegalStateException("Close called too many times " + this);
            }
            this.cleaner.delRef();
            this.cleaner.clean(false);
            this.closed = true;
        }

        @Override
        public String toString() {
            return "LZ4 BATCHED METADATA (ID: " + this.id + " " + Long.toHexString(this.cleaner.metadata) + ")";
        }

        /**
         * Cleaner that destroys the native metadata handle exactly once. A handle value of
         * {@code 0} means the metadata has already been released.
         */
        private static class BatchedMetadataCleaner
        extends MemoryCleaner.Cleaner {
            private long metadata;

            BatchedMetadataCleaner(long metadata) {
                this.metadata = metadata;
            }

            /**
             * Destroys the native metadata if it is still held.
             *
             * @param logErrorIfNotClean when {@code true} and cleanup was actually needed, the
             *                           handle is reported as leaked
             * @return {@code true} if a live handle was found and released by this call
             */
            @Override
            protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
                long address = this.metadata;
                boolean neededCleanup = address != 0L;
                if (neededCleanup) {
                    try {
                        NvcompJni.batchedLZ4DecompressDestroyMetadata(address);
                    } finally {
                        // Clear the handle even if the native call throws so that it is never
                        // destroyed a second time.
                        this.metadata = 0L;
                    }
                    if (logErrorIfNotClean) {
                        log.error("LZ4 BATCHED METADATA WAS LEAKED (Address: " + Long.toHexString(address) + ")");
                        this.logRefCountDebug("Leaked event");
                    }
                }
                return neededCleanup;
            }

            /**
             * @return {@code true} once the native handle has been released (handle is zero)
             */
            @Override
            public boolean isClean() {
                // cleanImpl zeroes the handle after destroying it, so zero is the clean state.
                return this.metadata == 0L;
            }
        }
    }
}

