/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.flow.impl;

import lombok.NonNull;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.concurrency.EventsProvider;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link FlowController} implementation that delegates host/device buffer synchronization to
 * native data-buffer sync calls and records device reads/writes on allocation points after
 * each operation.
 *
 * <p>Most methods use the {@link Allocator} supplied through {@link #init(Allocator)}; they
 * must not be called before {@code init}.
 */
public class SynchronousFlowController
implements FlowController {
    private static final Logger log = LoggerFactory.getLogger(SynchronousFlowController.class);
    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    protected EventsProvider eventsProvider = new EventsProvider();

    /**
     * Stores the allocator used by all subsequent prepare/register calls.
     *
     * @param allocator allocator providing device contexts and allocation points
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Asks native code to sync the point's data buffer to its primary buffer
     * ({@code dbSyncToPrimary}; presumably the host-side copy).
     *
     * @param point allocation point to synchronize; must not be null
     * @throws NullPointerException if {@code point} is null
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        // Same fail-fast contract as synchronizeToDevice, instead of an anonymous NPE
        // from point.getPtrDataBuffer().
        if (point == null) {
            throw new NullPointerException("point must not be null");
        }
        NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(point.getPtrDataBuffer());
    }

    /**
     * Asks native code to sync the point's data buffer to its special buffer
     * ({@code dbSyncToSpecial}; presumably the device-side copy).
     *
     * @param point allocation point to synchronize; must not be null
     * @throws NullPointerException if {@code point} is null
     */
    @Override
    public void synchronizeToDevice(@NonNull AllocationPoint point) {
        if (point == null) {
            throw new NullPointerException("point is marked non-null but is null");
        }
        NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToSpecial(point.getPtrDataBuffer());
    }

    /**
     * Synchronizes on the point's last recorded write event, if there is one.
     *
     * @param point allocation point whose pending write should be awaited
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        if (point.getLastWriteEvent() != null) {
            point.getLastWriteEvent().synchronize();
        }
    }

    /**
     * Prepares every operand for an operation that writes to all of them.
     *
     * <p>Each non-null, non-empty operand is decompressed. Its data is relocated if it lives
     * on another device. Its shape info is moved via the constant handler if it lives on
     * another device. Delayed memory is then promoted, and the operand is bound to the current
     * device context.
     *
     * @param operands arrays to prepare; {@code null} (or null elements) are tolerated
     * @return the current device context
     */
    @Override
    public CudaContext prepareActionAllWrite(INDArray ... operands) {
        CudaContext context = this.allocator.getDeviceContext();
        Integer cId = this.allocator.getDeviceId();
        if (operands == null) {
            return context;
        }
        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty()) {
                continue;
            }
            Nd4j.getCompressor().autoDecompress(operand);
            AllocationPoint pointData = this.allocator.getAllocationPoint(operand);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
            if (isOnOtherDevice(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(relocationTarget(operand));
            }
            if (isOnOtherDevice(pointShape, cId)) {
                ((JCublasNDArray) operand).setShapeInfoDataBuffer(
                        Nd4j.getConstantHandler().relocateConstantSpace(operand.shapeInfoDataBuffer()));
            }
            this.prepareDelayedMemory(operand);
            // Re-fetch: relocation/promotion above may have touched the array's buffers.
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Prepares a result array and its input operands for an operation on the current device.
     *
     * <p>Data on another device is relocated unless cross-device access is allowed and P2P is
     * available. Shape info on another device is rebuilt through the executioner. Operands are
     * additionally forced to the DEVICE location. String arrays ({@code isS()}) among the
     * operands are skipped.
     *
     * @param result   output array; may be {@code null} or empty, in which case it is skipped
     * @param operands input arrays; {@code null} (or null elements) are tolerated
     * @return the current device context
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray ... operands) {
        CudaContext context = this.allocator.getDeviceContext();
        Integer cId = this.allocator.getDeviceId();
        if (result != null && !result.isEmpty()) {
            Nd4j.getCompressor().autoDecompress(result);
            this.prepareDelayedMemory(result);
            AllocationPoint pointData = this.allocator.getAllocationPoint(result);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(result.shapeInfoDataBuffer());
            if (requiresDataRelocation(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(relocationTarget(result));
            }
            if (isOnOtherDevice(pointShape, cId)) {
                rebuildShapeInfo(result);
            }
            this.allocator.getAllocationPoint(result).setCurrentContext(context);
        }
        if (operands == null) {
            return context;
        }
        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty() || operand.isS()) {
                continue;
            }
            Nd4j.getCompressor().autoDecompress(operand);
            AllocationPoint pointData = this.allocator.getAllocationPoint(operand);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
            Nd4j.getAffinityManager().ensureLocation(operand, AffinityManager.Location.DEVICE);
            if (requiresDataRelocation(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(relocationTarget(operand));
            }
            if (isOnOtherDevice(pointShape, cId)) {
                rebuildShapeInfo(operand);
            }
            this.prepareDelayedMemory(operand);
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Synchronizes on the point's last write event, then on its last read event, if recorded.
     *
     * @param point allocation point to wait on
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        this.waitTillFinished(point);
        if (point.getLastReadEvent() != null) {
            point.getLastReadEvent().synchronize();
        }
    }

    /**
     * No-op in this controller: nothing is recorded for allocation-point-level actions.
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint ... operands) {
    }

    /**
     * Records a device write on every non-null, non-empty operand.
     *
     * <p>Empty arrays are skipped for consistency with {@link #prepareActionAllWrite}.
     *
     * @param context  context the operation ran on (unused here)
     * @param operands written arrays; {@code null} (or null elements) are tolerated
     */
    @Override
    public void registerActionAllWrite(CudaContext context, INDArray ... operands) {
        if (operands == null) {
            return;
        }
        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty()) {
                continue;
            }
            this.allocator.getAllocationPoint(operand).tickDeviceWrite();
        }
    }

    /**
     * Records a device write on the result (if present) and a device read on every operand.
     *
     * <p>Operand reads are recorded even when there is no result. {@link #prepareAction}
     * prepares operands for a {@code null} result too, and those operands are still read.
     *
     * @param context  context the operation ran on (unused here)
     * @param result   written array; may be {@code null} or empty
     * @param operands read arrays; {@code null} (or null/empty elements) are tolerated
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray ... operands) {
        if (result != null && !result.isEmpty()) {
            this.allocator.getAllocationPoint(result).tickDeviceWrite();
        }
        if (operands == null) {
            return;
        }
        for (INDArray operand : operands) {
            if (operand == null || operand.isEmpty()) {
                continue;
            }
            this.allocator.getAllocationPoint(operand).tickDeviceRead();
        }
    }

    /**
     * Binds the result and all non-null operand allocation points to the current device
     * context.
     *
     * @param result   output allocation point; may be {@code null}
     * @param operands input allocation points; {@code null} (or null elements) are tolerated
     * @return the current device context
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint ... operands) {
        CudaContext context = this.allocator.getDeviceContext();
        if (result != null) {
            result.setCurrentContext(context);
        }
        if (operands == null) {
            return context;
        }
        for (AllocationPoint operand : operands) {
            if (operand == null) {
                continue;
            }
            operand.setCurrentContext(context);
        }
        return context;
    }

    /**
     * Completes a transfer by synchronizing the stream it was issued on.
     *
     * @param streamUsed stream the transfer was issued on
     */
    @Override
    public void commitTransfer(cudaStream_t streamUsed) {
        streamUsed.synchronize();
    }

    /**
     * Under the DELAYED memory model, makes sure the array's data and shape info are
     * device-ready.
     *
     * <p>Data whose allocation status is not DEVICE is promoted. Shape info whose status is
     * HOST is relocated into constant space and re-attached to the array.
     *
     * @param array array to prepare; assumed non-null and non-empty (callers filter these)
     */
    protected void prepareDelayedMemory(INDArray array) {
        if (this.configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            // Data point must come from the array's data buffer, not its shape buffer;
            // otherwise the DEVICE check below tests the wrong allocation.
            AllocationPoint pointData = this.allocator.getAllocationPoint(array);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(array.shapeInfoDataBuffer());
            if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
                this.prepareDelayedMemory(array.data());
            }
            if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
                DataBuffer oShape = array.shapeInfoDataBuffer();
                DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
                // relocateConstantSpace returned the same buffer: move it into constant
                // space explicitly.
                if (nShape == oShape) {
                    Nd4j.getConstantHandler().moveToConstantSpace(nShape);
                }
                ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
            }
        }
    }

    /**
     * Asks the allocator's memory handler to promote the given buffer.
     *
     * @param buffer buffer to promote
     */
    protected void prepareDelayedMemory(DataBuffer buffer) {
        this.allocator.getMemoryHandler().promoteObject(buffer);
    }

    /**
     * @return the events provider owned by this controller
     */
    @Override
    public EventsProvider getEventsProvider() {
        return this.eventsProvider;
    }

    /** True when the point is bound to a real device (id >= 0) other than {@code deviceId}. */
    private static boolean isOnOtherDevice(AllocationPoint point, int deviceId) {
        return point.getDeviceId() >= 0 && point.getDeviceId() != deviceId;
    }

    /**
     * True when data on another device must be relocated. This is the case unless
     * cross-device access is allowed and P2P is available.
     */
    private static boolean requiresDataRelocation(AllocationPoint point, int deviceId) {
        return isOnOtherDevice(point, deviceId)
                && !(CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()
                        && NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable());
    }

    /** Returns the buffer to relocate: the original data buffer if set, else the array's own. */
    private static DataBuffer relocationTarget(INDArray array) {
        DataBuffer original = array.data().originalDataBuffer();
        return original == null ? array.data() : original;
    }

    /** Replaces the array's shape info with a freshly created one for the same geometry. */
    private static void rebuildShapeInfo(INDArray array) {
        ((JCublasNDArray) array).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(
                array.shape(), array.stride(), (long) array.elementWiseStride(),
                array.ordering(), array.dataType(), array.isEmpty()));
    }
}

