/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.compression;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.compression.impl.AbstractCompressor;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA threshold compressor, reporting the descriptor {@code "THRESHOLD"}.
 *
 * <p>Holds a single float threshold (default {@code 0.001f}). {@link #compress(DataBuffer)}
 * produces an INT buffer whose first three elements form a header:
 * {@code [numMatches, originalLength, Float.floatToIntBits(threshold)]}.
 */
public class CudaThreshold extends AbstractCompressor {
    private static final Logger log = LoggerFactory.getLogger(CudaThreshold.class);
    protected float threshold = 0.001f;

    /**
     * Returns the descriptor string identifying this compressor.
     *
     * @return {@code "THRESHOLD"}
     */
    public String getDescriptor() {
        return "THRESHOLD";
    }

    /**
     * Sets the threshold from the first configuration value.
     *
     * @param vars configuration values; {@code vars[0]} must be a {@link Number}, and its
     *     absolute value becomes the new threshold
     * @throws ND4JIllegalStateException if no value is supplied or {@code vars[0]} is not a Number
     */
    public void configure(Object ... vars) {
        // Guard the empty/null varargs case explicitly instead of failing with an index error.
        if (vars == null || vars.length == 0 || !(vars[0] instanceof Number)) {
            throw new ND4JIllegalStateException("Threshold value should be Number");
        }
        Number t = (Number) vars[0];
        this.threshold = FastMath.abs(t.floatValue());
        log.info("Setting threshold to [{}]", this.threshold);
    }

    /**
     * Compresses an array's data buffer and wraps the result in an array marked as compressed.
     *
     * @param array array whose data is compressed; its shape-info buffer is reused for the result
     * @return the compressed array, or {@code null} if {@link #compress(DataBuffer)} returned null
     */
    public INDArray compress(INDArray array) {
        // Commit pending executioner work before reading the data buffer.
        Nd4j.getExecutioner().commit();
        DataBuffer compressed = this.compress(array.data());
        if (compressed == null) {
            return null;
        }
        INDArray dup = Nd4j.createArrayFromShapeBuffer(compressed, array.shapeInfoDataBuffer());
        dup.markAsCompressed(true);
        return dup;
    }

    /**
     * Returns the compression type reported for this compressor.
     *
     * @return {@link CompressionType#LOSSLESS}
     */
    public CompressionType getCompressionType() {
        // NOTE(review): thresholding normally discards information; confirm LOSSLESS is intended.
        return CompressionType.LOSSLESS;
    }

    /**
     * Allocates the output buffer for decompressing a threshold-encoded INT buffer.
     *
     * @param buffer encoded buffer; header slot 1 holds the original element count
     * @param type data type of the buffer to allocate
     * @return a buffer of {@code type} with the original length
     * @throws UnsupportedOperationException if {@code buffer} is not of type INT
     */
    public DataBuffer decompress(DataBuffer buffer, DataType type) {
        if (buffer.dataType() != DataType.INT) {
            throw new UnsupportedOperationException();
        }
        // Header slot 1 is written by compress(DataBuffer) as the uncompressed length.
        long originalLength = buffer.getInt(1L);
        DataBuffer result = Nd4j.createBuffer(type, originalLength, false);
        // NOTE(review): no decode kernel is invoked here, so the encoded values are never written
        // into result; only its device copy is marked as the latest. Confirm before relying on it.
        AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite();
        return result;
    }

    /**
     * Builds the threshold-encoded INT buffer for {@code buffer} and runs the native prefix-sum
     * pass ({@code encodeThresholdP2Int}) over the per-block counters.
     *
     * @param buffer source data buffer
     * @return INT buffer of length {@code 3 + numMatches} starting with the 3-int header
     */
    public DataBuffer compress(DataBuffer buffer) {
        int numThreads = 1024;
        // Ceiling division: number of numThreads-sized blocks covering the input.
        int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0L ? 0 : 1));
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();

        DataBuffer blocksBuffer = createIntBuffer(numBlocks + 1, true);
        // Native extras array: slot 1 = CUDA stream, slot 2 (set below) = device pointer table.
        PointerPointer extras = new PointerPointer(32L).put(1L, (Pointer) context.getOldStream());
        AtomicAllocator.getInstance().getAllocationPoint(blocksBuffer).tickDeviceWrite();

        // NOTE(review): no encoding (P1) kernel is launched before this read; confirm where
        // blocksBuffer[0] is expected to be populated.
        int numMatches = blocksBuffer.getInt(0L);
        DataBuffer encodedBuffer = createIntBuffer(3 + numMatches, false);
        // Header: [numMatches, original length, threshold as raw float bits].
        encodedBuffer.put(0L, numMatches);
        encodedBuffer.put(1L, (int) buffer.length());
        encodedBuffer.put(2L, Float.floatToIntBits(this.threshold));
        AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite();

        int prefixThreads = 512;
        int numPrefixBlocks;

        // Pass 1: count levels for the recursive prefix sum.
        // NOTE(review): this tests numBlocks rather than numPrefixBlocks, so it can count one more
        // level than pass 2 fills; the trailing pointer slot then stays 0. Kept as-is.
        int level = 0;
        int numElts = numBlocks;
        do {
            numPrefixBlocks = prefixBlockCount(numElts, prefixThreads);
            if (numBlocks > 1) {
                ++level;
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);

        long[] pointers = new long[level];
        DataBuffer tempX = createDoubleBuffer(pointers.length);

        // Pass 2: allocate one INT scratch buffer per level that needs more than one block and
        // record its device address. The list holds references to those scratch buffers.
        List<DataBuffer> buffers = new ArrayList<>();
        level = 0;
        numElts = numBlocks;
        do {
            numPrefixBlocks = prefixBlockCount(numElts, prefixThreads);
            if (numPrefixBlocks > 1) {
                DataBuffer bf = createIntBuffer(numPrefixBlocks, false);
                buffers.add(bf);
                pointers[level++] = AtomicAllocator.getInstance().getPointer(bf).address();
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);

        // Stage the address table into tempX. The copy is blocking, so the native staging
        // pointer can be released immediately rather than waiting for GC.
        LongPointer hostPointers = new LongPointer(pointers);
        AtomicAllocator.getInstance().memcpyBlocking(tempX, hostPointers, pointers.length * Long.BYTES, 0L);
        hostPointers.deallocate();
        extras.put(2L, AtomicAllocator.getInstance().getPointer(tempX));

        DataBuffer offsetsBuffer = createIntBuffer(numBlocks, true);
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(
                extras,
                (IntPointer) AtomicAllocator.getInstance().getPointer(blocksBuffer),
                (long) numBlocks,
                (IntPointer) AtomicAllocator.getInstance().getPointer(offsetsBuffer));
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickDeviceWrite();
        // NOTE(review): no kernel writes encodedBuffer or buffer here, yet both are marked as
        // device-written; confirm this matches the intended (P3) encoding step.
        AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite();
        AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();

        // Touch the native argument holders after the call so they stay reachable up to here.
        extras.address();
        tempX.address();
        return encodedBuffer;
    }

    /**
     * Not supported by this compressor.
     *
     * @throws UnsupportedOperationException always
     */
    protected CompressedDataBuffer compressPointer(DataTypeEx srcType, Pointer srcPointer, int length, int elementSize) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the current threshold.
     *
     * @return the threshold value
     */
    public float getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the threshold as given. Unlike {@link #configure(Object...)}, no absolute value is taken.
     *
     * @param threshold new threshold value
     */
    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }

    /** Allocates an INT buffer in the current workspace if one is active, otherwise without one. */
    private static DataBuffer createIntBuffer(long length, boolean initialize) {
        return Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createInt(length, initialize)
                : Nd4j.getDataBufferFactory().createInt(length, initialize, Nd4j.getMemoryManager().getCurrentWorkspace());
    }

    /** Allocates a non-initialized DOUBLE buffer in the current workspace if one is active. */
    private static DataBuffer createDoubleBuffer(long length) {
        return Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createDouble(length, false)
                : Nd4j.getDataBufferFactory().createDouble(length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
    }

    /** Returns {@code max(1, ceil(numElts / (2 * threads)))}, the per-level prefix-sum block count. */
    private static int prefixBlockCount(int numElts, int threads) {
        return Math.max(1, (int) Math.ceil((float) numElts / (2.0f * (float) threads)));
    }
}

