/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.constant;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.constant.ConstantProtector;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.cache.ArrayDescriptor;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA {@link ConstantHandler} singleton.
 *
 * <p>Constant buffers themselves are created via {@code Nd4j.getExecutioner().createConstantBuffer(...)};
 * this class keeps per-device bookkeeping (constant-space pointer, offset counter, lock and a
 * per-device descriptor cache) which is lazily created by {@link #ensureMaps(Integer)} and reset by
 * {@link #purgeConstants()}.
 */
public class ProtectedCudaConstantHandler
implements ConstantHandler {
    private static final Logger log = LoggerFactory.getLogger(ProtectedCudaConstantHandler.class);
    private static final ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler();
    // Concurrent: read/iterated outside the instance lock (purgeConstants) while ensureMaps may insert.
    protected Map<Integer, AtomicLong> constantOffsets = new ConcurrentHashMap<Integer, AtomicLong>();
    protected Map<Integer, Semaphore> deviceLocks = new ConcurrentHashMap<Integer, Semaphore>();
    // Concurrent + volatile: ensureMaps() checks it without locking, and purgeConstants() swaps the whole map.
    protected volatile Map<Integer, Map<ArrayDescriptor, DataBuffer>> buffersCache =
            new ConcurrentHashMap<Integer, Map<ArrayDescriptor, DataBuffer>>();
    // Plain HashMap on purpose: only written under the instance lock, and it must be able to hold
    // whatever the native call returns (ConcurrentHashMap rejects null values).
    protected Map<Integer, Pointer> deviceAddresses = new HashMap<Integer, Pointer>();
    protected AtomicLong bytes = new AtomicLong(0L);
    protected FlowController flowController;
    protected static final ConstantProtector protector = ConstantProtector.getInstance();
    private static final int MAX_CONSTANT_LENGTH = 49152;
    private static final int MAX_BUFFER_LENGTH = 272;
    protected Semaphore lock = new Semaphore(1);
    private boolean resetHappened = false;

    /**
     * Returns the process-wide handler instance.
     *
     * @return the singleton {@code ProtectedCudaConstantHandler}
     */
    public static ProtectedCudaConstantHandler getInstance() {
        return ourInstance;
    }

    private ProtectedCudaConstantHandler() {
    }

    /**
     * Purges the {@link ConstantProtector}, resets every known device's constant offset to zero and
     * replaces the per-device descriptor caches with empty ones.
     */
    @Override
    public void purgeConstants() {
        protector.purgeProtector();
        this.resetHappened = true;
        log.info("Resetting Constants...");
        // Build the replacement cache completely before publishing it, so concurrent readers never
        // observe a partially-populated map.
        Map<Integer, Map<ArrayDescriptor, DataBuffer>> freshCache =
                new ConcurrentHashMap<Integer, Map<ArrayDescriptor, DataBuffer>>();
        for (Map.Entry<Integer, AtomicLong> entry : this.constantOffsets.entrySet()) {
            entry.getValue().set(0L);
            freshCache.put(entry.getKey(), new ConcurrentHashMap<ArrayDescriptor, DataBuffer>());
        }
        this.buffersCache = freshCache;
    }

    /**
     * Returns the number of cached descriptor entries for the given device, initialising the
     * device's bookkeeping first if needed.
     *
     * @param deviceId device whose cache is inspected
     * @return number of entries in that device's descriptor cache
     */
    protected int amountOfEntries(int deviceId) {
        this.ensureMaps(deviceId);
        return this.buffersCache.get(deviceId).size();
    }

    /**
     * Not supported by this handler.
     *
     * @throws RuntimeException always
     */
    @Override
    public synchronized long moveToConstantSpace(DataBuffer dataBuffer) {
        throw new RuntimeException("This code shouldn't be called, ever");
    }

    /**
     * Copies the contents of a CUDA data buffer into a constant buffer of the matching data type.
     *
     * @param dataBuffer source buffer; must be one of the supported CUDA buffer classes
     * @return constant buffer holding the same values
     * @throws IllegalStateException if the buffer class is not supported
     */
    @Override
    public DataBuffer relocateConstantSpace(DataBuffer dataBuffer) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        if (dataBuffer instanceof CudaIntDataBuffer) {
            int[] data = dataBuffer.asInt();
            return this.getConstantBuffer(data, DataType.INT);
        }
        if (dataBuffer instanceof CudaFloatDataBuffer) {
            float[] data = dataBuffer.asFloat();
            return this.getConstantBuffer(data, DataType.FLOAT);
        }
        if (dataBuffer instanceof CudaDoubleDataBuffer) {
            double[] data = dataBuffer.asDouble();
            return this.getConstantBuffer(data, DataType.DOUBLE);
        }
        if (dataBuffer instanceof CudaHalfDataBuffer) {
            // Half data is read back as float and re-encoded as HALF by the executioner.
            float[] data = dataBuffer.asFloat();
            return this.getConstantBuffer(data, DataType.HALF);
        }
        if (dataBuffer instanceof CudaLongDataBuffer) {
            long[] data = dataBuffer.asLong();
            return this.getConstantBuffer(data, DataType.LONG);
        }
        throw new IllegalStateException("Unknown CudaDataBuffer opType");
    }

    /**
     * Lazily creates the bookkeeping for {@code deviceId}: constant-space pointer, offset counter,
     * device lock and an empty descriptor cache.
     *
     * <p>The {@code buffersCache} entry acts as the "initialised" marker and is therefore inserted
     * last. If the native call fails, the device is left uninitialised and the next call retries.
     *
     * @param deviceId device to initialise
     * @throws RuntimeException wrapping any exception raised during initialisation
     */
    private void ensureMaps(Integer deviceId) {
        if (this.buffersCache.containsKey(deviceId)) {
            return;
        }
        if (this.flowController == null) {
            this.flowController = AtomicAllocator.getInstance().getFlowController();
        }
        try {
            synchronized (this) {
                if (!this.buffersCache.containsKey(deviceId)) {
                    Pointer cAddr = NativeOpsHolder.getInstance().getDeviceNativeOps().getConstantSpace();
                    this.deviceAddresses.put(deviceId, cAddr);
                    this.constantOffsets.put(deviceId, new AtomicLong(0L));
                    this.deviceLocks.put(deviceId, new Semaphore(1));
                    // Publish the guard entry only after everything else succeeded.
                    this.buffersCache.put(deviceId, new ConcurrentHashMap<ArrayDescriptor, DataBuffer>());
                }
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Creates a constant buffer from {@code array} via the current executioner. */
    @Override
    public DataBuffer getConstantBuffer(int[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /** Creates a constant buffer from {@code array} via the current executioner. */
    @Override
    public DataBuffer getConstantBuffer(float[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /** Creates a constant buffer from {@code array} via the current executioner. */
    @Override
    public DataBuffer getConstantBuffer(double[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /** Creates a constant buffer from {@code array} via the current executioner. */
    @Override
    public DataBuffer getConstantBuffer(long[] array, DataType type) {
        return Nd4j.getExecutioner().createConstantBuffer(array, type);
    }

    /**
     * Creates a constant buffer from a boolean array by first converting it with
     * {@code ArrayUtil.toLongs} and delegating to the {@code long[]} overload.
     */
    @Override
    public DataBuffer getConstantBuffer(boolean[] array, DataType dataType) {
        return this.getConstantBuffer(ArrayUtil.toLongs(array), dataType);
    }

    /**
     * Returns the value of this handler's byte counter.
     *
     * @return current value of {@code bytes}
     */
    @Override
    public long getCachedBytes() {
        return this.bytes.get();
    }
}

