/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.tad;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.tad.BasicTADManager;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * TAD manager that caches shape-info/offset buffer pairs separately for each device.
 *
 * <p>One cache map exists per device index reported by the affinity manager. Each map is keyed
 * by {@link TadDescriptor}. Cache misses are delegated to {@link BasicTADManager}, and the sizes
 * of newly cached buffers are added to the inherited {@code bytes} counter.
 */
public class DeviceTADManager
extends BasicTADManager {
    private static final Logger log = LoggerFactory.getLogger(DeviceTADManager.class);

    /**
     * Per-device caches, indexed by device id.
     *
     * <p>This field is volatile because {@link #purgeBuffers()} publishes a freshly built list.
     * Readers therefore observe either the previous list or the fully populated replacement,
     * never a partially filled one.
     */
    protected volatile List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> tadCache;

    public DeviceTADManager() {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> caches = new ArrayList<>(numDevices);
        for (int i = 0; i < numDevices; ++i) {
            caches.add(new ConcurrentHashMap<>());
        }
        this.tadCache = caches;
    }

    /**
     * Drops every cached TAD for all devices, then delegates to the parent implementation.
     *
     * <p>The replacement list is fully populated before it is assigned. Concurrent callers of
     * {@link #getTADOnlyShapeInfo(INDArray, int[])} therefore never see an empty or partial list.
     */
    @Override
    public void purgeBuffers() {
        log.info("Purging TAD buffers...");
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> fresh = new ArrayList<>(numDevices);
        for (int i = 0; i < numDevices; ++i) {
            log.info("Resetting device: [{}]", (Object)i);
            fresh.add(new ConcurrentHashMap<>());
        }
        this.tadCache = fresh;
        super.purgeBuffers();
    }

    /**
     * Returns the TAD shape-info/offsets pair for {@code array} along {@code dimension}.
     *
     * <p>The result is cached for the current device. Only the thread whose buffers actually
     * enter the cache adds their sizes to {@code bytes}: 4 bytes per element for the first
     * buffer and 8 per element for the second, when the second is present.
     *
     * @param array the array to build the TAD for
     * @param dimension dimensions to reduce along. When it has more than one element it is
     *     sorted IN PLACE, so the caller's array is mutated (kept for compatibility).
     * @return the cached or newly created buffer pair; never {@code null} from the cache path
     */
    @Override
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        // Sorting makes permutations of the same dimension set produce the same descriptor.
        if (dimension != null && dimension.length > 1) {
            Arrays.sort(dimension);
        }
        int deviceId = AtomicAllocator.getInstance().getDeviceId();
        // Read the cache list once so a concurrent purge cannot swap it mid-call.
        Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCache = this.tadCache.get(deviceId);
        TadDescriptor descriptor = new TadDescriptor(array, dimension);

        Pair<DataBuffer, DataBuffer> cached = deviceCache.get(descriptor);
        if (cached != null) {
            log.trace("Using TAD from cache...");
            return cached;
        }

        log.trace("Creating new TAD...");
        Pair<DataBuffer, DataBuffer> buffers = super.getTADOnlyShapeInfo(array, dimension);
        Pair<DataBuffer, DataBuffer> winner = deviceCache.putIfAbsent(descriptor, buffers);
        if (winner != null) {
            // Another thread cached this descriptor first. Reuse its buffers and skip
            // byte accounting for ours, since they never entered the cache.
            return winner;
        }
        this.bytes.addAndGet(buffers.getFirst().length() * 4L);
        if (buffers.getSecond() != null) {
            this.bytes.addAndGet(buffers.getSecond().length() * 8L);
        }
        return buffers;
    }
}

