/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.Aggressiveness;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.time.Ring;
import org.nd4j.jita.allocator.time.rings.LockedRing;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.constant.ConstantProtector;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.handler.impl.CudaZeroHandler;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide allocator used by the CUDA backend.
 *
 * <p>A single instance is created when this class is initialised and is returned by
 * {@link #getInstance()}. Pointer lookups, memory copies and allocation bookkeeping are
 * delegated to the current {@link MemoryHandler} (a {@link CudaZeroHandler} unless replaced via
 * {@link #setMemoryHandler(MemoryHandler)}); host synchronisation is forwarded to the native ops
 * layer.
 *
 * <p>{@code globalLock} serialises replacement of the configuration and of the memory handler and
 * is held for reading in {@link #getConfiguration()}; the delegating methods read those fields
 * without taking the lock.
 */
public class AtomicAllocator
implements Allocator {
    // Must be declared before INSTANCE: static initialisers run in textual order and the
    // constructor runs while INSTANCE is being initialised, so anything reached from it that logs
    // through this logger would otherwise observe null.
    private static final Logger log = LoggerFactory.getLogger(AtomicAllocator.class);
    private static final AtomicAllocator INSTANCE = new AtomicAllocator();
    private Configuration configuration;
    private transient MemoryHandler memoryHandler;
    private final AtomicLong allocationsCounter = new AtomicLong(0L);
    private final AtomicLong objectsTracker = new AtomicLong(0L);
    private final Map<Long, AllocationPoint> allocationsMap = new ConcurrentHashMap<Long, AllocationPoint>();
    private final ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock();
    private final ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock();
    private final AtomicBoolean shouldStop = new AtomicBoolean(false);
    // Never set to true within this class, so applyConfiguration(Configuration) always applies.
    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
    private final Ring deviceLong = new LockedRing(30);
    private final Ring deviceShort = new LockedRing(30);
    private final Ring zeroLong = new LockedRing(30);
    private final Ring zeroShort = new LockedRing(30);
    private final ConstantHandler constantHandler = Nd4j.getConstantHandler();
    private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis());
    protected static ConstantProtector protector;

    /**
     * Returns the singleton allocator.
     *
     * @return the shared instance
     * @throws RuntimeException if called re-entrantly while the instance is still being
     *     constructed during class initialisation (the only way {@code INSTANCE} can be null)
     */
    public static AtomicAllocator getInstance() {
        if (INSTANCE == null) {
            throw new RuntimeException("AtomicAllocator is NULL");
        }
        return INSTANCE;
    }

    /**
     * Reads the CUDA environment configuration, pushes it to native ops, and initialises the
     * default {@link CudaZeroHandler}.
     */
    private AtomicAllocator() {
        this.configuration = CudaEnvironment.getInstance().getConfiguration();
        this.applyConfiguration();
        this.memoryHandler = new CudaZeroHandler();
        this.memoryHandler.init(this.configuration, this);
        protector = ConstantProtector.getInstance();
    }

    /** Exposes the allocation-point map (keyed by object id) to subclasses and the package. */
    protected Map<Long, AllocationPoint> allocationsMap() {
        return this.allocationsMap;
    }

    /**
     * Notifies the CUDA environment that configuration was applied and forwards the current
     * debug/verbose/P2P/grid/thread settings to the native device ops.
     */
    public void applyConfiguration() {
        CudaEnvironment.getInstance().notifyConfigurationApplied();
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(this.configuration.isDebug());
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(this.configuration.isVerbose());
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableP2P(this.configuration.isCrossDeviceAccessAllowed());
        NativeOpsHolder.getInstance().getDeviceNativeOps().setGridLimit(this.configuration.getMaximumGridSize());
        // NOTE(review): the OMP thread counts are fed from the block-size settings — confirm this
        // mapping is intended.
        NativeOpsHolder.getInstance().getDeviceNativeOps().setOmpNumThreads(this.configuration.getMaximumBlockSize());
        NativeOpsHolder.getInstance().getDeviceNativeOps().setOmpMinThreads(this.configuration.getMinimumBlockSize());
    }

    /** Currently a no-op; nothing in this class starts the GC threads declared below. */
    protected void initDeviceCollectors() {
    }

    /** Returns the CUDA context supplied by the current memory handler. */
    @Override
    public CudaContext getDeviceContext() {
        return this.memoryHandler.getDeviceContext();
    }

    /**
     * Replaces the memory handler and initialises it with the current configuration.
     *
     * @param memoryHandler the new handler; must not be null
     * @throws NullPointerException if {@code memoryHandler} is null
     */
    @Override
    public void setMemoryHandler(@NonNull MemoryHandler memoryHandler) {
        if (memoryHandler == null) {
            throw new NullPointerException("memoryHandler is marked non-null but is null");
        }
        this.globalLock.writeLock().lock();
        try {
            this.memoryHandler = memoryHandler;
            this.memoryHandler.init(this.configuration, this);
        } finally {
            // Always release: a failing init() must not leave the global lock held forever.
            this.globalLock.writeLock().unlock();
        }
    }

    /**
     * Replaces the active configuration unless the allocator has been marked as initialised.
     *
     * @param configuration the new configuration; must not be null
     * @throws NullPointerException if {@code configuration} is null
     */
    @Override
    public void applyConfiguration(@NonNull Configuration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        this.globalLock.writeLock().lock();
        try {
            // Checked under the lock so the check and the assignment are atomic together.
            if (!this.wasInitialised.get()) {
                this.configuration = configuration;
            }
        } finally {
            this.globalLock.writeLock().unlock();
        }
    }

    /** Returns the active configuration, read under the global read lock. */
    @Override
    public Configuration getConfiguration() {
        this.globalLock.readLock().lock();
        try {
            return this.configuration;
        } finally {
            this.globalLock.readLock().unlock();
        }
    }

    /**
     * Returns the device pointer for {@code buffer} in the given context.
     *
     * @throws NullPointerException if {@code buffer} is null
     */
    @Override
    public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
        if (buffer == null) {
            throw new NullPointerException("buffer is marked non-null but is null");
        }
        return this.memoryHandler.getDevicePointer(buffer, context);
    }

    /** Returns the device pointer for {@code buffer} using the handler's current device context. */
    public Pointer getPointer(DataBuffer buffer) {
        return this.memoryHandler.getDevicePointer(buffer, this.getDeviceContext());
    }

    /** Legacy overload; {@code shape} and {@code isView} are ignored. */
    @Override
    @Deprecated
    public Pointer getPointer(DataBuffer buffer, AllocationShape shape, boolean isView, CudaContext context) {
        return this.memoryHandler.getDevicePointer(buffer, context);
    }

    /**
     * Returns the device pointer backing {@code array}'s data buffer.
     *
     * @throws UnsupportedOperationException if the array is empty or {@code isS()} is true
     *     (presumably a string array)
     */
    @Override
    public Pointer getPointer(INDArray array, CudaContext context) {
        if (array.isEmpty() || array.isS()) {
            throw new UnsupportedOperationException("Pew-pew");
        }
        return this.memoryHandler.getDevicePointer(array.data(), context);
    }

    /**
     * Synchronises {@code array} to host and returns its host pointer.
     *
     * @return the host pointer, or null for an empty array
     */
    @Override
    public Pointer getHostPointer(INDArray array) {
        if (array.isEmpty()) {
            return null;
        }
        this.synchronizeHostData(array);
        return this.memoryHandler.getHostPointer(array.data());
    }

    /** Returns the host pointer for {@code buffer} without synchronising it first. */
    @Override
    public Pointer getHostPointer(DataBuffer buffer) {
        return this.memoryHandler.getHostPointer(buffer);
    }

    /**
     * Synchronises the buffer behind {@code array} to its primary (host) copy. Empty arrays and
     * arrays for which {@code isS()} is true are skipped; when the data buffer reports an
     * {@code originalDataBuffer()}, that buffer is synchronised instead.
     */
    @Override
    public void synchronizeHostData(INDArray array) {
        if (array.isEmpty() || array.isS()) {
            return;
        }
        DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data() : array.data().originalDataBuffer();
        this.synchronizeHostData(buffer);
    }

    /**
     * Asks native ops to sync {@code buffer} to its primary copy.
     *
     * @throws ClassCastException if {@code buffer} is not a {@link BaseCudaDataBuffer}
     */
    @Override
    public void synchronizeHostData(DataBuffer buffer) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer)buffer).getOpaqueDataBuffer());
    }

    /** Returns the device id recorded in {@code array}'s allocation point. */
    public Integer getDeviceId(INDArray array) {
        return this.getAllocationPoint(array).getDeviceId();
    }

    /**
     * Frees the memory held by {@code point} and removes it from the allocation map.
     *
     * <p>For a DEVICE allocation the provider's {@code free} is called once; if a host pointer
     * also exists, the status is switched to HOST and {@code free} is called again (presumably
     * the provider frees whichever copy matches the current status — verify against the
     * MemoryProvider), then the handler forgets the DEVICE tracking. Otherwise, a non-null host
     * pointer is freed and forgotten as HOST.
     */
    public void freeMemory(AllocationPoint point) {
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            this.getMemoryHandler().getMemoryProvider().free(point);
            if (point.getHostPointer() != null) {
                point.setAllocationStatus(AllocationStatus.HOST);
                this.getMemoryHandler().getMemoryProvider().free(point);
                this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
            }
        } else if (point.getHostPointer() != null) {
            this.getMemoryHandler().getMemoryProvider().free(point);
            this.getMemoryHandler().forget(point, AllocationStatus.HOST);
        }
        this.allocationsMap.remove(point.getObjectId());
    }

    /**
     * Allocates memory at the location dictated by the configured memory model: the handler's
     * initial location for IMMEDIATE, HOST for DELAYED.
     *
     * @return the allocation point, or null if the memory model is neither of those
     */
    @Override
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, boolean initialize) {
        AllocationPoint point = null;
        if (this.configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
            point = this.allocateMemory(buffer, requiredMemory, this.memoryHandler.getInitialLocation(), initialize);
        } else if (this.configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            point = this.allocateMemory(buffer, requiredMemory, AllocationStatus.HOST, initialize);
        }
        return point;
    }

    /** Not supported by this allocator. */
    public AllocationPoint pickExternalBuffer(DataBuffer buffer) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    /** Not supported by this allocator. */
    @Override
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * Looks up a tracked allocation point by object id.
     *
     * @return the point, or null if {@code objectId} is not tracked
     * @throws NullPointerException if {@code objectId} is null
     */
    protected AllocationPoint getAllocationPoint(@NonNull Long objectId) {
        if (objectId == null) {
            throw new NullPointerException("objectId is marked non-null but is null");
        }
        return this.allocationsMap.get(objectId);
    }

    /** Untracks {@code objectId} and asks the handler to purge its zero-copy (host) memory. */
    protected void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
        this.allocationsMap.remove(objectId);
        this.memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
    }

    /** Asks the handler to purge the device memory of {@code objectId}. */
    protected void purgeDeviceObject(Long threadId, Integer deviceId, Long objectId, AllocationPoint point, boolean copyback) {
        this.memoryHandler.purgeDeviceObject(threadId, deviceId, objectId, point, copyback);
    }

    /**
     * Scans the host tracking points of {@code bucketId}, counting HOST-resident points that still
     * have a buffer. A HOST point whose buffer is gone is purged and then an
     * {@link UnsupportedOperationException} is thrown.
     *
     * @param aggressiveness currently unused by this implementation
     * @return freed bytes; always 0 in this implementation
     */
    protected synchronized long seekUnusedZero(Long bucketId, Aggressiveness aggressiveness) {
        long freeSpace = 0L;
        int totalElements = (int)this.memoryHandler.getAllocatedHostObjects(bucketId);
        // Never incremented: the purge path below throws before anything could be counted.
        int elementsDropped = 0;
        int elementsSurvived = 0;
        for (Long object : this.memoryHandler.getHostTrackingPoints(bucketId)) {
            AllocationPoint point = this.getAllocationPoint(object);
            if (point == null || point.getAllocationStatus() != AllocationStatus.HOST) {
                continue;
            }
            if (point.getBuffer() == null) {
                this.purgeZeroObject(bucketId, object, point, false);
                throw new UnsupportedOperationException("Pew-pew");
            }
            elementsSurvived++;
        }
        log.debug("Zero {} elements checked: [{}], deleted: {}, survived: {}", bucketId, totalElements, elementsDropped, elementsSurvived);
        return freeSpace;
    }

    /**
     * Scans the device tracking points of {@code deviceId}, counting points that still have a
     * buffer. Ids with no tracked allocation point are skipped. A DEVICE point whose buffer is
     * gone is purged and then an {@link UnsupportedOperationException} is thrown.
     *
     * @param aggressiveness currently unused by this implementation
     * @return freed bytes; always 0 in this implementation
     */
    protected long seekUnusedDevice(Long threadId, Integer deviceId, Aggressiveness aggressiveness) {
        long freeSpace = 0L;
        // Never incremented: the purge path below throws before anything could be counted.
        int elementsDropped = 0;
        int elementsMoved = 0;
        int elementsSurvived = 0;
        for (Long object : this.memoryHandler.getDeviceTrackingPoints(deviceId)) {
            AllocationPoint point = this.getAllocationPoint(object);
            // Same guard as seekUnusedZero: untracked ids used to cause a NullPointerException.
            if (point == null) {
                continue;
            }
            if (point.getBuffer() == null) {
                if (point.getAllocationStatus() != AllocationStatus.DEVICE) {
                    continue;
                }
                this.purgeDeviceObject(threadId, deviceId, object, point, false);
                throw new UnsupportedOperationException("Pew-pew");
            }
            elementsSurvived++;
        }
        log.debug("Thread/Device [{}/{}] elements purged: [{}]; Relocated: [{}]; Survivors: [{}]", threadId, deviceId, elementsDropped, elementsMoved, elementsSurvived);
        return freeSpace;
    }

    /** Stub: always returns 0. */
    public long getTotalAllocatedHostMemory() {
        return 0L;
    }

    /** Returns the number of allocation points currently tracked in the allocation map. */
    protected int getTotalTrackingPoints() {
        return this.allocationsMap.size();
    }

    /** Stub: always returns 0 regardless of {@code deviceId}. */
    public long getTotalAllocatedDeviceMemory(Integer deviceId) {
        return 0L;
    }

    /** Delegates to {@link MemoryHandler#memcpyAsync}. */
    @Override
    public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpyAsync(dstBuffer, srcPointer, length, dstOffset);
    }

    /** Delegates to {@link MemoryHandler#memcpySpecial}. */
    @Override
    public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpySpecial(dstBuffer, srcPointer, length, dstOffset);
    }

    /** Delegates to {@link MemoryHandler#memcpyDevice}. */
    @Override
    public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset, CudaContext context) {
        this.memoryHandler.memcpyDevice(dstBuffer, srcPointer, length, dstOffset, context);
    }

    /** Delegates to {@link MemoryHandler#memcpyBlocking}. */
    @Override
    public void memcpyBlocking(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpyBlocking(dstBuffer, srcPointer, length, dstOffset);
    }

    /** Delegates to {@link MemoryHandler#memcpy}. */
    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        this.memoryHandler.memcpy(dstBuffer, srcBuffer);
    }

    /** Records a host-side write on {@code buffer}'s allocation point. */
    @Override
    public void tickHostWrite(DataBuffer buffer) {
        this.getAllocationPoint(buffer).tickHostWrite();
    }

    /** Records a host-side write on the allocation point of {@code array}'s data buffer. */
    @Override
    public void tickHostWrite(INDArray array) {
        this.getAllocationPoint(array.data()).tickHostWrite();
    }

    /** Records a device-side write on the allocation point of {@code array}'s data buffer. */
    @Override
    public void tickDeviceWrite(INDArray array) {
        this.getAllocationPoint(array.data()).tickDeviceWrite();
    }

    /** Returns the allocation point of {@code array}'s data buffer. */
    @Override
    public AllocationPoint getAllocationPoint(INDArray array) {
        return this.getAllocationPoint(array.data());
    }

    /**
     * Returns the allocation point stored on {@code buffer}.
     *
     * @throws ClassCastException if {@code buffer} is not a {@link BaseCudaDataBuffer}
     */
    @Override
    public AllocationPoint getAllocationPoint(DataBuffer buffer) {
        return ((BaseCudaDataBuffer)buffer).getAllocationPoint();
    }

    /** Returns the device id reported by the current memory handler. */
    @Override
    public Integer getDeviceId() {
        return this.memoryHandler.getDeviceId();
    }

    /** Wraps the current device id in a new {@link CudaPointer}. */
    @Override
    public Pointer getDeviceIdPointer() {
        return new CudaPointer(this.getDeviceId().intValue());
    }

    /** Delegates to {@link MemoryHandler#registerAction}. */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray ... operands) {
        this.memoryHandler.registerAction(context, result, operands);
    }

    /** Returns the memory handler's flow controller. */
    @Override
    public FlowController getFlowController() {
        return this.memoryHandler.getFlowController();
    }

    /** Returns a constant INT buffer for {@code array} from Nd4j's constant handler. */
    @Override
    public DataBuffer getConstantBuffer(int[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.INT);
    }

    /** Returns a constant FLOAT buffer for {@code array} from Nd4j's constant handler. */
    @Override
    public DataBuffer getConstantBuffer(float[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.FLOAT);
    }

    /** Returns a constant DOUBLE buffer for {@code array} from Nd4j's constant handler. */
    @Override
    public DataBuffer getConstantBuffer(double[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.DOUBLE);
    }

    /** Moves {@code dataBuffer} into constant space and returns the same buffer instance. */
    @Override
    public DataBuffer moveToConstant(DataBuffer dataBuffer) {
        Nd4j.getConstantHandler().moveToConstantSpace(dataBuffer);
        return dataBuffer;
    }

    /** Returns the current memory handler. */
    @Override
    public MemoryHandler getMemoryHandler() {
        return this.memoryHandler;
    }

    /**
     * Raises the configured deallocation aggressiveness under memory pressure: to URGENT (never
     * lowering an already higher setting) when the object count exceeds {@code objectLimit} or
     * usage exceeds 75% of the maximum, and to IMMEDIATE whenever usage exceeds 85%.
     */
    private static Aggressiveness escalateAggressiveness(Aggressiveness configured, long allocatedObjects,
            long objectLimit, double allocatedMemory, double maximumMemory) {
        Aggressiveness result = configured;
        if ((allocatedObjects > objectLimit || allocatedMemory > maximumMemory * 0.75)
                && result.compareTo(Aggressiveness.URGENT) < 0) {
            result = Aggressiveness.URGENT;
        }
        if (allocatedMemory > maximumMemory * 0.85) {
            result = Aggressiveness.IMMEDIATE;
        }
        return result;
    }

    /**
     * Returns true when a GC sweep can be skipped: usage is below 25% of the maximum, the object
     * count is below {@code idleObjectLimit}, and the previous sweep ran less than 30 s ago.
     */
    private static boolean canSkipSweep(long allocatedObjects, long idleObjectLimit,
            double allocatedMemory, double maximumMemory, long lastCheck) {
        return allocatedMemory < maximumMemory * 0.25
                && allocatedObjects < idleObjectLimit
                && lastCheck > System.currentTimeMillis() - 30000L;
    }

    /** Daemon thread that periodically sweeps one device's tracking points. */
    private class DeviceGarbageCollectorThread
    extends Thread {
        private final Integer deviceId;
        private final AtomicBoolean terminate;

        public DeviceGarbageCollectorThread(Integer deviceId, AtomicBoolean terminate) {
            this.deviceId = deviceId;
            this.terminate = terminate;
            this.setName("device gc thread [" + deviceId + "]");
            this.setDaemon(true);
        }

        /**
         * Sleeps at least 5 s per iteration, then sweeps unless the device is idle. Stops when
         * {@code terminate} is set or the thread is interrupted.
         */
        @Override
        public void run() {
            log.info("Starting device GC for device: {}", this.deviceId);
            long lastCheck = System.currentTimeMillis();
            while (!this.terminate.get()) {
                try {
                    Thread.sleep(Math.max(AtomicAllocator.this.configuration.getMinimumTTLMilliseconds(), 5000L));
                } catch (InterruptedException e) {
                    // Restore the flag and stop; swallowing it made the thread un-interruptible.
                    Thread.currentThread().interrupt();
                    return;
                }
                Configuration config = AtomicAllocator.this.configuration;
                MemoryHandler handler = AtomicAllocator.this.memoryHandler;
                long allocatedObjects = handler.getAllocatedDeviceObjects(this.deviceId);
                double allocatedMemory = (double) handler.getAllocatedDeviceMemory(this.deviceId);
                double maximumMemory = (double) config.getMaximumDeviceAllocation();
                if (canSkipSweep(allocatedObjects, 500L, allocatedMemory, maximumMemory, lastCheck)) {
                    continue;
                }
                Aggressiveness aggressiveness = escalateAggressiveness(config.getGpuDeallocAggressiveness(),
                        allocatedObjects, 100000L, allocatedMemory, maximumMemory);
                AtomicAllocator.this.seekUnusedDevice(0L, this.deviceId, aggressiveness);
                lastCheck = System.currentTimeMillis();
            }
        }
    }

    /** Daemon thread that periodically sweeps one bucket's host (zero-copy) tracking points. */
    private class ZeroGarbageCollectorThread
    extends Thread {
        private final Long bucketId;
        private final AtomicBoolean terminate;

        public ZeroGarbageCollectorThread(Long bucketId, AtomicBoolean terminate) {
            this.bucketId = bucketId;
            this.terminate = terminate;
            this.setName("zero gc thread " + bucketId);
            this.setDaemon(true);
        }

        /**
         * Sleeps at least 10 s per iteration, then sweeps unless host memory is idle. Stops when
         * {@code terminate} is set or the thread is interrupted.
         */
        @Override
        public void run() {
            log.debug("Starting zero GC for thread: {}", this.bucketId);
            long lastCheck = System.currentTimeMillis();
            while (!this.terminate.get()) {
                try {
                    Thread.sleep(Math.max(AtomicAllocator.this.configuration.getMinimumTTLMilliseconds(), 10000L));
                } catch (InterruptedException e) {
                    // Restore the flag and stop; swallowing it made the thread un-interruptible.
                    Thread.currentThread().interrupt();
                    return;
                }
                Configuration config = AtomicAllocator.this.configuration;
                MemoryHandler handler = AtomicAllocator.this.memoryHandler;
                long allocatedObjects = handler.getAllocatedHostObjects(this.bucketId);
                double allocatedMemory = (double) handler.getAllocatedHostMemory();
                double maximumMemory = (double) config.getMaximumZeroAllocation();
                if (canSkipSweep(allocatedObjects, 5000L, allocatedMemory, maximumMemory, lastCheck)) {
                    continue;
                }
                Aggressiveness aggressiveness = escalateAggressiveness(config.getHostDeallocAggressiveness(),
                        allocatedObjects, 500000L, allocatedMemory, maximumMemory);
                AtomicAllocator.this.seekUnusedZero(this.bucketId, aggressiveness);
                lastCheck = System.currentTimeMillis();
            }
        }
    }
}

