/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.memory.impl;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MemoryProvider} that allocates and releases host and device memory directly through
 * {@link NativeOps} ({@code mallocHost}/{@code mallocDevice}, {@code freeHost}/{@code freeDevice}).
 * This class keeps no cache of its own: {@link #purgeCache()} and {@link #purgeCache(int)} do nothing.
 *
 * <p>Every successful allocation and every release is reported to {@link MemoryTracker}. Device
 * allocations are also reported to {@link AllocationsTracker}. Requests smaller than one byte are
 * rounded up to one byte. The same rounding is applied on release, so tracker increments and
 * decrements stay balanced.
 */
public class CudaDirectProvider implements MemoryProvider {
    /** Free device memory (0x3200000 bytes = 50 MiB) that {@link #pingDeviceForFreeMemory} keeps in reserve. */
    protected static final long DEVICE_RESERVED_SPACE = 0x3200000L;
    private static final Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // NOTE(review): not read anywhere in this class; presumably used by subclasses — confirm.
    protected volatile ConcurrentHashMap<Long, Integer> validator = new ConcurrentHashMap<>();
    // NOTE(review): not read anywhere in this class.
    private AtomicLong emergencyCounter = new AtomicLong(0L);

    /**
     * Allocates the memory required by {@code shape} at {@code location}.
     *
     * @param shape    shape whose required byte size is allocated (rounded up to at least 1 byte)
     * @param point    allocation point whose pointers and status are updated
     * @param location {@code HOST} or {@code DEVICE}
     * @return the pointers for the new allocation, or {@code null} if a device allocation still
     *         fails after {@link #purgeCache(int)} and a GC request
     * @throws RuntimeException      if the host allocation fails
     * @throws IllegalStateException if {@code location} is neither {@code HOST} nor {@code DEVICE}
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        switch (location) {
            case HOST:
                return allocateHost(shape, point);
            case DEVICE:
                return allocateDevice(shape, point);
            default:
                throw new IllegalStateException("Unsupported location for malloc: [" + location + "]");
        }
    }

    /** Allocates host memory via {@code NativeOps.mallocHost} and installs it as the pointers of {@code point}. */
    private PointersPair allocateHost(AllocationShape shape, AllocationPoint point) {
        long reqMem = requiredMemory(shape);
        Pointer pointer = this.nativeOps.mallocHost(reqMem, 0);
        if (pointer == null) {
            throw new RuntimeException("Can't allocate [HOST] memory: " + reqMem + "; threadId: " + Thread.currentThread().getId());
        }
        CudaPointer hostPointer = new CudaPointer(pointer);
        PointersPair devicePointerInfo = new PointersPair();
        PointersPair existing = point.getPointers();
        if (existing == null || existing.getDevicePointer() == null) {
            // No device pointer yet: mark the point as host-resident and let the host buffer
            // also serve as its device pointer.
            point.setAllocationStatus(AllocationStatus.HOST);
            devicePointerInfo.setDevicePointer(new CudaPointer(hostPointer, reqMem));
        } else {
            // Keep the existing device pointer; only the host side is new.
            devicePointerInfo.setDevicePointer(point.getDevicePointer());
        }
        devicePointerInfo.setHostPointer(new CudaPointer(hostPointer, reqMem));
        point.setPointers(devicePointerInfo);
        MemoryTracker.getInstance().incrementAllocatedHostAmount(reqMem);
        return devicePointerInfo;
    }

    /**
     * Allocates device memory on the current device via {@code NativeOps.mallocDevice}.
     * After one failure it calls {@link #purgeCache(int)}, requests GC, and retries once.
     * Returns {@code null} if the retry also fails.
     */
    private PointersPair allocateDevice(AllocationShape shape, AllocationPoint point) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        long reqMem = requiredMemory(shape);
        AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, deviceId, reqMem);
        Pointer pointer = this.nativeOps.mallocDevice(reqMem, deviceId.intValue(), 0);
        if (pointer == null) {
            this.purgeCache(deviceId);
            Nd4j.getMemoryManager().invokeGc();
            pointer = this.nativeOps.mallocDevice(reqMem, deviceId.intValue(), 0);
            if (pointer == null) {
                // Nothing was allocated, so undo the tracker entry recorded above.
                AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, deviceId, reqMem);
                return null;
            }
        }
        CudaPointer devicePointer = new CudaPointer(pointer);
        PointersPair devicePointerInfo = point.getPointers();
        if (devicePointerInfo == null) {
            devicePointerInfo = new PointersPair();
        }
        devicePointerInfo.setDevicePointer(new CudaPointer(devicePointer, reqMem));
        point.setAllocationStatus(AllocationStatus.DEVICE);
        point.setDeviceId(deviceId);
        MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMem);
        return devicePointerInfo;
    }

    /**
     * Releases the memory held by {@code point}, based on its allocation status.
     * Constant device allocations are left untouched.
     *
     * @param point allocation to release
     * @throws RuntimeException      if the native free call reports failure (returns 0)
     * @throws IllegalStateException if the point is neither {@code HOST} nor {@code DEVICE}
     */
    @Override
    public void free(AllocationPoint point) {
        switch (point.getAllocationStatus()) {
            case HOST: {
                // Same rounding as malloc, so the tracker decrement matches the increment.
                long reqMem = requiredMemory(point.getShape());
                NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
                long result = ops.freeHost(point.getPointers().getHostPointer());
                if (result == 0L) {
                    throw new RuntimeException("Can't deallocate [HOST] memory...");
                }
                MemoryTracker.getInstance().decrementAllocatedHostAmount(reqMem);
                break;
            }
            case DEVICE: {
                if (point.isConstant()) {
                    return;
                }
                long reqMem = requiredMemory(point.getShape());
                NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
                AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Integer.valueOf(point.getDeviceId()), reqMem);
                PointersPair pointers = point.getPointers();
                long result = ops.freeDevice(pointers.getDevicePointer(), 0);
                if (result == 0L) {
                    throw new RuntimeException("Can't deallocate [DEVICE] memory...");
                }
                MemoryTracker.getInstance().decrementAllocatedAmount(point.getDeviceId(), reqMem);
                break;
            }
            default: {
                throw new IllegalStateException("Can't free memory on target [" + point.getAllocationStatus() + "]");
            }
        }
    }

    /**
     * Checks whether {@code requiredMemory} bytes can be allocated while still leaving at least
     * {@link #DEVICE_RESERVED_SPACE} bytes free.
     *
     * @param deviceId       not used by this implementation
     * @param requiredMemory number of bytes the caller wants to allocate
     * @return {@code true} if free memory minus {@code requiredMemory} is at least the reserve
     */
    @Override
    public boolean pingDeviceForFreeMemory(Integer deviceId, long requiredMemory) {
        // NOTE(review): -1 is passed instead of deviceId; presumably it selects the current
        // device — confirm against NativeOps.getDeviceFreeMemory.
        long freeMem = this.nativeOps.getDeviceFreeMemory(-1);
        return freeMem - requiredMemory >= DEVICE_RESERVED_SPACE;
    }

    /** Frees a host pointer via {@code NativeOps.freeHost}; the native result is ignored. */
    protected void freeHost(Pointer pointer) {
        NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
        ops.freeHost(pointer);
    }

    /**
     * Frees a device pointer via {@code NativeOps.freeDevice}; the native result is ignored.
     * NOTE(review): {@code deviceId} is not forwarded. 0 is passed, as in {@link #free}.
     */
    protected void freeDevice(Pointer pointer, int deviceId) {
        NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();
        ops.freeDevice(pointer, 0);
    }

    /** Called before retrying a failed device allocation; does nothing in this class. */
    protected void purgeCache(int deviceId) {
    }

    /** Does nothing in this class. */
    @Override
    public void purgeCache() {
    }

    /** Returns the byte size required by {@code shape}, rounded up to at least 1 byte. */
    private static long requiredMemory(AllocationShape shape) {
        return Math.max(1L, AllocationUtils.getRequiredMemory(shape));
    }
}

