/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.concurrency;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tracks, per device id, an "allocated" counter and a "reserved" counter of memory sizes
 * (units are whatever callers pass as {@code memorySize}).
 *
 * <p>Both counters are created lazily the first time a device id is seen. Each device also has a
 * {@link ReentrantReadWriteLock}. Locks for ids {@code 0..numberOfDevices-1} are created eagerly in
 * the constructor; any other id gets its lock created on first use. {@link #addToAllocation} and the
 * getters take the read lock, while {@link #subFromAllocation} and
 * {@link #reserveAllocationIfPossible} take the write lock.
 */
public class DeviceAllocationsTracker {
    private static final Logger log = LoggerFactory.getLogger(DeviceAllocationsTracker.class);

    private final Configuration configuration;
    private final Map<Integer, ReentrantReadWriteLock> deviceLocks = new ConcurrentHashMap<Integer, ReentrantReadWriteLock>();
    private final Map<Integer, AtomicLong> memoryTackled = new ConcurrentHashMap<Integer, AtomicLong>();
    private final Map<Integer, AtomicLong> reservedSpace = new ConcurrentHashMap<Integer, AtomicLong>();

    /**
     * Creates a tracker and pre-creates one lock per device reported by the affinity manager.
     *
     * @param configuration allocator configuration; must not be {@code null}
     * @throws NullPointerException if {@code configuration} is {@code null}
     */
    public DeviceAllocationsTracker(@NonNull Configuration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        this.configuration = configuration;
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int device = 0; device < numDevices; ++device) {
            this.deviceLocks.put(device, new ReentrantReadWriteLock());
        }
    }

    /**
     * Makes sure the allocated and reserved counters exist for {@code deviceId}.
     *
     * @param threadId currently unused; kept for API compatibility
     * @param deviceId device whose counters should exist
     */
    protected void ensureThreadRegistered(Long threadId, Integer deviceId) {
        // computeIfAbsent is atomic per key on ConcurrentHashMap, so no global lock is required
        // and there is no lock to leak if construction fails.
        this.memoryTackled.computeIfAbsent(deviceId, k -> new AtomicLong(0L));
        this.reservedSpace.computeIfAbsent(deviceId, k -> new AtomicLong(0L));
    }

    /**
     * Returns the lock for {@code deviceId}, creating it if this id was not known at construction.
     */
    private ReentrantReadWriteLock lockFor(Integer deviceId) {
        return this.deviceLocks.computeIfAbsent(deviceId, k -> new ReentrantReadWriteLock());
    }

    /**
     * Moves {@code memorySize} from the device's reserved counter into its allocated counter.
     *
     * @param threadId   calling thread id; must not be {@code null}
     * @param deviceId   target device
     * @param memorySize amount to add to the allocated counter (and subtract from reserved)
     * @return the allocated counter after the addition
     * @throws NullPointerException if {@code threadId} is {@code null}
     */
    public long addToAllocation(@NonNull Long threadId, Integer deviceId, long memorySize) {
        if (threadId == null) {
            throw new NullPointerException("threadId is marked non-null but is null");
        }
        this.ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.ReadLock readLock = this.lockFor(deviceId).readLock();
        // Acquire before try: if acquisition fails, finally must not attempt an unlock.
        readLock.lock();
        try {
            long allocated = this.memoryTackled.get(deviceId).addAndGet(memorySize);
            this.subFromReservedSpace(deviceId, memorySize);
            return allocated;
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Subtracts {@code memorySize} from the device's allocated counter.
     *
     * @param threadId   calling thread id (passed through to registration)
     * @param deviceId   target device
     * @param memorySize amount to subtract
     * @return the allocated counter after the subtraction
     */
    public long subFromAllocation(Long threadId, Integer deviceId, long memorySize) {
        this.ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.WriteLock writeLock = this.lockFor(deviceId).writeLock();
        writeLock.lock();
        try {
            return this.memoryTackled.get(deviceId).addAndGet(-memorySize);
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Adds {@code memorySize} to the device's reserved counter.
     *
     * <p>No capacity check is performed: this method currently always returns {@code true}.
     *
     * @param threadId   calling thread id (passed through to registration)
     * @param deviceId   target device
     * @param memorySize amount to reserve
     * @return always {@code true}
     */
    public boolean reserveAllocationIfPossible(Long threadId, Integer deviceId, long memorySize) {
        this.ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.WriteLock writeLock = this.lockFor(deviceId).writeLock();
        writeLock.lock();
        try {
            this.addToReservedSpace(deviceId, memorySize);
            return true;
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Registers the device's counters if needed and returns its allocated counter.
     *
     * @param threadId calling thread id (passed through to registration)
     * @param deviceId device to query
     * @return current allocated counter for {@code deviceId}
     */
    public long getAllocatedSize(Long threadId, Integer deviceId) {
        this.ensureThreadRegistered(threadId, deviceId);
        // getAllocatedSize(Integer) takes the device read lock itself.
        return this.getAllocatedSize(deviceId);
    }

    /**
     * Returns the allocated counter for {@code deviceId}, or {@code 0} if the device was never seen.
     *
     * @param deviceId device to query
     * @return current allocated counter, or {@code 0}
     */
    public long getAllocatedSize(Integer deviceId) {
        AtomicLong allocated = this.memoryTackled.get(deviceId);
        if (allocated == null) {
            return 0L;
        }
        ReentrantReadWriteLock.ReadLock readLock = this.lockFor(deviceId).readLock();
        readLock.lock();
        try {
            return allocated.get();
        } finally {
            readLock.unlock();
        }
    }

    /**
     * Returns the reserved counter for {@code deviceId}, or {@code 0} if the device was never seen
     * (consistent with {@link #getAllocatedSize(Integer)}).
     *
     * @param deviceId device to query
     * @return current reserved counter, or {@code 0}
     */
    public long getReservedSpace(Integer deviceId) {
        AtomicLong reserved = this.reservedSpace.get(deviceId);
        return reserved == null ? 0L : reserved.get();
    }

    /**
     * Adds {@code memorySize} to the device's reserved counter, registering the device if needed.
     */
    protected void addToReservedSpace(Integer deviceId, long memorySize) {
        this.ensureThreadRegistered(Thread.currentThread().getId(), deviceId);
        this.reservedSpace.get(deviceId).addAndGet(memorySize);
    }

    /**
     * Subtracts {@code memorySize} from the device's reserved counter, registering the device if
     * needed.
     */
    protected void subFromReservedSpace(Integer deviceId, long memorySize) {
        this.ensureThreadRegistered(Thread.currentThread().getId(), deviceId);
        this.reservedSpace.get(deviceId).addAndGet(-memorySize);
    }
}

