/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.memory.impl;

import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Memory provider that, on top of {@link CudaCachingZeroProvider}, keeps released device
 * allocations in per-device, per-shape caches so that later device allocations of the same
 * {@link AllocationShape} can reuse an existing device pointer instead of allocating anew.
 */
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
    /** deviceId -> (shape -> holder of cached device pointers for that shape). This class only adds entries. */
    protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder>> deviceCache = new ConcurrentHashMap<>();
    private static final Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);

    public CudaFullCachingProvider() {
        this.init();
    }

    /**
     * (Re)creates the per-device cached-amount counters: one zeroed {@link AtomicLong} for each
     * device reported by the affinity manager, indexed by device id.
     */
    public void init() {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        this.deviceCachedAmount = new ArrayList<>();
        for (int i = 0; i < numDevices; ++i) {
            this.deviceCachedAmount.add(new AtomicLong(0L));
        }
    }

    /**
     * Allocates memory, serving {@link AllocationStatus#DEVICE} requests smaller than the
     * configured {@code maximumDeviceAllocation} from the per-shape device cache when possible.
     *
     * <p>On a cache hit the returned pair carries only the reused device pointer, {@code point}
     * is marked as a DEVICE allocation on the current device, and {@code reqMemory} bytes are
     * moved from "cached" to "allocated" in the {@link MemoryTracker}. Misses and all other
     * requests are delegated to the superclass.
     *
     * @param shape    shape of the requested buffer
     * @param point    allocation point to populate
     * @param location requested memory location
     * @return the pointers for the allocation
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        // Only device allocations below the configured size limit are looked up in the cache.
        if (location != AllocationStatus.DEVICE
                || reqMemory >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceAllocation()) {
            return super.malloc(shape, point, location);
        }

        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureDeviceCacheHolder(deviceId, shape);
        CudaCachingZeroProvider.CacheHolder cache = this.deviceCache.get(deviceId).get(shape);
        Pointer pointer = cache == null ? null : cache.poll();
        if (pointer == null) {
            this.cacheDeviceMiss.incrementAndGet();
            return super.malloc(shape, point, location);
        }

        this.cacheDeviceHit.incrementAndGet();
        // The reused chunk no longer counts towards this device's cached total.
        ((AtomicLong) this.deviceCachedAmount.get(deviceId)).addAndGet(-reqMemory);

        PointersPair pair = new PointersPair();
        pair.setDevicePointer(pointer);
        point.setAllocationStatus(AllocationStatus.DEVICE);
        point.setDeviceId(deviceId);
        MemoryTracker.getInstance().incrementAllocatedAmount(deviceId, reqMemory);
        MemoryTracker.getInstance().decrementCachedAmount(deviceId, reqMemory);
        return pair;
    }

    /**
     * Releases an allocation.
     *
     * <p>For DEVICE allocations: constant allocations are ignored (neither cached nor released
     * here). Allocations larger than {@code maximumDeviceCacheableLength} are released through the
     * superclass, as are all allocations arriving while this device's cached amount is already at
     * or above {@code maximumDeviceCache}. Every other device allocation is pushed into the
     * per-shape device cache, and its bytes are moved from "allocated" to "cached" in the
     * {@link MemoryTracker}. Non-device allocations are delegated to the superclass.
     *
     * @param point allocation to release
     */
    @Override
    public void free(AllocationPoint point) {
        if (point.getAllocationStatus() != AllocationStatus.DEVICE) {
            super.free(point);
            return;
        }
        if (point.isConstant()) {
            return;
        }

        AllocationShape shape = point.getShape();
        int deviceId = point.getDeviceId();
        long address = point.getDevicePointer().address();
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        // Too big to cache, or the device cache is already full: really release it.
        if (reqMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCacheableLength()
                || ((AtomicLong) this.deviceCachedAmount.get(deviceId)).get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumDeviceCache()) {
            super.free(point);
            return;
        }

        this.ensureDeviceCacheHolder(deviceId, shape);
        CudaCachingZeroProvider.CacheHolder cache = this.deviceCache.get(deviceId).get(shape);
        // NOTE(review): guards against the point being re-homed to another device concurrently.
        if (point.getDeviceId() != deviceId) {
            throw new RuntimeException("deviceId changed!");
        }
        // Small and large cacheable chunks were handled identically before; one path suffices.
        cache.put(new CudaPointer(address));
        MemoryTracker.getInstance().incrementCachedAmount(deviceId, reqMemory);
        MemoryTracker.getInstance().decrementAllocatedAmount(deviceId, reqMemory);
    }

    /**
     * Lazily creates the per-device shape map and the {@link CudaCachingZeroProvider.CacheHolder}
     * for {@code shape} on {@code deviceId}. Each is created at most once.
     *
     * <p>The holder is handed this device's cached-amount counter (presumably it updates that
     * counter on {@code put} - confirm in {@code CudaCachingZeroProvider.CacheHolder}).
     *
     * <p>Previously the holder was created under {@code singleLock} with every exception silently
     * swallowed. The {@code finally} also released the lock even when {@code acquire()} had
     * failed. Creation now happens under this object's monitor, which the device-map creation
     * already used, and failures propagate to the caller. The unlocked fast path keeps the common
     * case lock-free.
     *
     * @param deviceId device whose cache is being prepared
     * @param shape    shape the holder is keyed by
     */
    protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
        ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder> shapes = this.deviceCache.get(deviceId);
        if (shapes != null && shapes.containsKey(shape)) {
            return;
        }
        synchronized (this) {
            // Re-check under the lock: another thread may have created either level meanwhile.
            shapes = this.deviceCache.get(deviceId);
            if (shapes == null) {
                shapes = new ConcurrentHashMap<>();
                this.deviceCache.put(deviceId, shapes);
            }
            if (!shapes.containsKey(shape)) {
                shapes.put(shape, new CudaCachingZeroProvider.CacheHolder(shape, (AtomicLong) this.deviceCachedAmount.get(deviceId)));
            }
        }
    }

    /**
     * Releases every cached device pointer for {@code deviceId} through {@code freeDevice} and
     * decrements the {@link MemoryTracker}'s cached amount accordingly. It then resets this
     * device's cached-amount counter to zero.
     *
     * <p>Devices that never had a cache are handled gracefully; previously they caused an NPE.
     * The tracker is decremented by {@link AllocationUtils#getRequiredMemory}, the same measure
     * {@code free}/{@code malloc} use when adjusting it. Previously it used
     * {@code shape.getNumberOfBytes()}.
     *
     * @param deviceId device whose cache should be emptied
     */
    @Override
    protected synchronized void purgeCache(int deviceId) {
        ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder> shapes = this.deviceCache.get(deviceId);
        if (shapes != null) {
            for (AllocationShape shape : shapes.keySet()) {
                CudaCachingZeroProvider.CacheHolder cache = shapes.get(shape);
                long bytes = AllocationUtils.getRequiredMemory(shape);
                Pointer ptr;
                while ((ptr = cache.poll()) != null) {
                    this.freeDevice(ptr, deviceId);
                    MemoryTracker.getInstance().decrementCachedAmount(deviceId, bytes);
                }
            }
        }
        ((AtomicLong) this.deviceCachedAmount.get(deviceId)).set(0L);
    }

    /**
     * Purges the device caches of every device that has one. It then delegates to the
     * superclass purge.
     */
    @Override
    public synchronized void purgeCache() {
        for (Integer device : this.deviceCache.keySet()) {
            this.purgeCache(device);
        }
        super.purgeCache();
    }
}

