/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas;

import java.io.ObjectStreamException;
import java.util.List;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.BaseNDArrayProxy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.JvmShapeInfo;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JCublasNDArray
extends BaseNDArray {
    private static final Logger log = LoggerFactory.getLogger(JCublasNDArray.class);

    /**
     * Wraps already-existing buffers without going through any of the
     * argument-taking superclass constructors.
     *
     * @param buffer        stored as this array's {@code data}
     * @param shapeInfo     stored as this array's {@code shapeInformation}
     * @param javaShapeInfo wrapped in a new {@link JvmShapeInfo} and stored as {@code jvmShapeInfo}
     */
    public JCublasNDArray(DataBuffer buffer, CudaLongDataBuffer shapeInfo, long[] javaShapeInfo) {
        this.data = buffer;
        this.shapeInformation = shapeInfo;
        this.jvmShapeInfo = new JvmShapeInfo(javaShapeInfo);
    }

    // ---------------------------------------------------------------------
    // Delegating constructors. Each one below contains nothing but a call to
    // the superclass constructor with the same argument list; no CUDA-specific
    // work is done in this class for any of them.
    // ---------------------------------------------------------------------

    /** Forwards {@code data} to {@code super(data)}. */
    public JCublasNDArray(double[][] data) {
        super(data);
    }

    /** Forwards {@code data} and {@code ordering} to {@code super(data, ordering)}. */
    public JCublasNDArray(double[][] data, char ordering) {
        super(data, ordering);
    }

    /** Forwards {@code shape} and {@code buffer} to {@code super(shape, buffer)}. */
    public JCublasNDArray(int[] shape, DataBuffer buffer) {
        super(shape, buffer);
    }

    /** Forwards all arguments to {@code super(data, shape, ordering)}. */
    public JCublasNDArray(float[] data, int[] shape, char ordering) {
        super(data, shape, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, offset, ordering)}. */
    public JCublasNDArray(float[] data, int[] shape, long offset, char ordering) {
        super(data, shape, offset, ordering);
    }

    /** Forwards all arguments to {@code super(shape, stride, offset, ordering)}. */
    public JCublasNDArray(int[] shape, int[] stride, long offset, char ordering) {
        super(shape, stride, offset, ordering);
    }

    /** Forwards all arguments to {@code super(shape, stride, offset, ordering, initialize)} ({@code int[]} shape/stride variant). */
    public JCublasNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) {
        super(shape, stride, offset, ordering, initialize);
    }

    /** Forwards all arguments to {@code super(shape, stride, offset, ordering, initialize)} ({@code long[]} shape/stride variant). */
    public JCublasNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) {
        super(shape, stride, offset, ordering, initialize);
    }

    /** Forwards all arguments to {@code super(shape, stride, ordering)}. */
    public JCublasNDArray(int[] shape, int[] stride, char ordering) {
        super(shape, stride, ordering);
    }

    /** Forwards all arguments to {@code super(shape, offset, ordering)} ({@code int[]} shape variant). */
    public JCublasNDArray(int[] shape, long offset, char ordering) {
        super(shape, offset, ordering);
    }

    /** Forwards all arguments to {@code super(shape, offset, ordering)} ({@code long[]} shape variant). */
    public JCublasNDArray(long[] shape, long offset, char ordering) {
        super(shape, offset, ordering);
    }

    /** Forwards {@code shape} to {@code super(shape)} ({@code int[]} variant). */
    public JCublasNDArray(int[] shape) {
        super(shape);
    }

    /** Forwards {@code shape} to {@code super(shape)} ({@code long[]} variant). */
    public JCublasNDArray(long[] shape) {
        super(shape);
    }

    /** Forwards all arguments to {@code super(newRows, newColumns, ordering)}. */
    public JCublasNDArray(int newRows, int newColumns, char ordering) {
        super(newRows, newColumns, ordering);
    }

    /** Forwards all arguments to {@code super(slices, shape, ordering)} ({@code int[]} shape variant). */
    public JCublasNDArray(List<INDArray> slices, int[] shape, char ordering) {
        super(slices, shape, ordering);
    }

    /** Forwards all arguments to {@code super(slices, shape, ordering)} ({@code long[]} shape variant). */
    public JCublasNDArray(List<INDArray> slices, long[] shape, char ordering) {
        super(slices, shape, ordering);
    }

    /** Forwards all arguments to {@code super(slices, shape, stride, ordering)}. */
    public JCublasNDArray(List<INDArray> slices, int[] shape, int[] stride, char ordering) {
        super(slices, shape, stride, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, ordering)}. */
    public JCublasNDArray(float[] data, int[] shape, int[] stride, char ordering) {
        super(data, shape, stride, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, offset, ordering)} ({@code float[]} data, {@code int[]} shape/stride). */
    public JCublasNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) {
        super(data, shape, stride, offset, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, offset, ordering)} ({@code float[]} data, {@code long[]} shape/stride). */
    public JCublasNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) {
        super(data, shape, stride, offset, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, offset, ordering)} ({@code double[]} data, {@code long[]} shape/stride). */
    public JCublasNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) {
        super(data, shape, stride, offset, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, strides)} ({@code int[]} data). */
    public JCublasNDArray(int[] data, int[] shape, int[] strides) {
        super(data, shape, strides);
    }

    /** Forwards {@code data} and {@code shape} to {@code super(data, shape)} ({@code int[]} shape variant). */
    public JCublasNDArray(DataBuffer data, int[] shape) {
        super(data, shape);
    }

    /** Forwards {@code data} and {@code shape} to {@code super(data, shape)} ({@code long[]} shape variant). */
    public JCublasNDArray(DataBuffer data, long[] shape) {
        super(data, shape);
    }

    /** Forwards all arguments to {@code super(buffer, shape, offset)}. */
    public JCublasNDArray(DataBuffer buffer, int[] shape, long offset) {
        super(buffer, shape, offset);
    }

    /** Forwards {@code data} and {@code shape} to {@code super(data, shape)} ({@code float[]} data). */
    public JCublasNDArray(float[] data, int[] shape) {
        super(data, shape);
    }

    /** Forwards all arguments to {@code super(data, shape, offset)}. */
    public JCublasNDArray(float[] data, int[] shape, long offset) {
        super(data, shape, offset);
    }

    /** Forwards all arguments to {@code super(shape, stride, offset)}. */
    public JCublasNDArray(int[] shape, int[] stride, long offset) {
        super(shape, stride, offset);
    }

    /** Forwards {@code shape} and {@code stride} to {@code super(shape, stride)}. */
    public JCublasNDArray(int[] shape, int[] stride) {
        super(shape, stride);
    }

    /** Forwards {@code shape} and {@code offset} to {@code super(shape, offset)}. */
    public JCublasNDArray(int[] shape, long offset) {
        super(shape, offset);
    }

    /** Forwards {@code shape} and {@code ordering} to {@code super(shape, ordering)}. */
    public JCublasNDArray(int[] shape, char ordering) {
        super(shape, ordering);
    }

    /** Forwards {@code newRows} and {@code newColumns} to {@code super(newRows, newColumns)}. */
    public JCublasNDArray(int newRows, int newColumns) {
        super(newRows, newColumns);
    }

    /** Forwards {@code slices} and {@code shape} to {@code super(slices, shape)} ({@code int[]} shape variant). */
    public JCublasNDArray(List<INDArray> slices, int[] shape) {
        super(slices, shape);
    }

    /** Forwards {@code slices} and {@code shape} to {@code super(slices, shape)} ({@code long[]} shape variant). */
    public JCublasNDArray(List<INDArray> slices, long[] shape) {
        super(slices, shape);
    }

    /** Forwards all arguments to {@code super(slices, shape, stride)}. */
    public JCublasNDArray(List<INDArray> slices, int[] shape, int[] stride) {
        super(slices, shape, stride);
    }

    /** Forwards all arguments to {@code super(data, shape, stride)} ({@code float[]} data). */
    public JCublasNDArray(float[] data, int[] shape, int[] stride) {
        super(data, shape, stride);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, offset)} ({@code float[]} data). */
    public JCublasNDArray(float[] data, int[] shape, int[] stride, long offset) {
        super(data, shape, stride, offset);
    }

    /** Forwards {@code data} to {@code super(data)} ({@code float[]} variant). */
    public JCublasNDArray(float[] data) {
        super(data);
    }

    /**
     * Copy constructor: allocates a new array sized {@code rows() x columns()} of the
     * given array and then replaces its data buffer with a duplicate of the source's data.
     *
     * <p>The duplicate is taken from {@code doubleMatrix}, not from the freshly
     * allocated {@code this}; duplicating {@code this} would only copy the new,
     * still-unpopulated buffer and silently discard the source contents.
     *
     * <p>NOTE(review): the shape information of the new instance comes from the
     * {@code long[]} shape constructor, while the data buffer comes from
     * {@code doubleMatrix.dup()}. This presumes both end up with the same data type
     * and ordering - confirm for sources whose data type differs from the default.
     *
     * @param doubleMatrix the array to copy; its {@code rows()} and {@code columns()}
     *                     define the shape of the new array
     */
    public JCublasNDArray(JCublasNDArray doubleMatrix) {
        this(new long[]{doubleMatrix.rows(), doubleMatrix.columns()});
        this.data = doubleMatrix.dup().data();
    }

    // ---------------------------------------------------------------------
    // More delegating constructors. Apart from the empty no-arg constructor,
    // each one contains only a call to the superclass constructor.
    // ---------------------------------------------------------------------

    /** Forwards all arguments to {@code super(data, shape, stride, offset)} ({@code double[]} data). */
    public JCublasNDArray(double[] data, int[] shape, int[] stride, long offset) {
        super(data, shape, stride, offset);
    }

    /** Forwards {@code floats} to {@code super(floats)}. */
    public JCublasNDArray(float[][] floats) {
        super(floats);
    }

    /** Forwards {@code data} and {@code ordering} to {@code super(data, ordering)} ({@code float[][]} data). */
    public JCublasNDArray(float[][] data, char ordering) {
        super(data, ordering);
    }

    /** Forwards all arguments to {@code super(buffer, shape, offset, ordering)}. */
    public JCublasNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) {
        super(buffer, shape, offset, ordering);
    }

    /** No-argument constructor with an empty body; nothing is initialised by this class. */
    public JCublasNDArray() {
    }

    /** Forwards {@code buffer} to {@code super(buffer)}. */
    public JCublasNDArray(DataBuffer buffer) {
        super(buffer);
    }

    /** Forwards all arguments to {@code super(buffer, shape, stride, offset, ordering)}. */
    public JCublasNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) {
        super(buffer, shape, stride, offset, ordering);
    }

    /** Forwards all arguments to {@code super(buffer, shape, stride, offset, ordering, dataType)}. */
    public JCublasNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, DataType dataType) {
        super(buffer, shape, stride, offset, ordering, dataType);
    }

    /** Forwards all arguments to {@code super(buffer, shape, stride, offset, ews, ordering, dataType)}. */
    public JCublasNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering, DataType dataType) {
        super(buffer, shape, stride, offset, ews, ordering, dataType);
    }

    /** Forwards all arguments to {@code super(buffer, shape, stride, ordering, dataType)}. */
    public JCublasNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType dataType) {
        super(buffer, shape, stride, ordering, dataType);
    }

    /** Forwards {@code data} and {@code order} to {@code super(data, order)}. */
    public JCublasNDArray(float[] data, char order) {
        super(data, order);
    }

    /** Forwards {@code floatBuffer}, cast to {@link DataBuffer}, and {@code order} to the superclass. */
    public JCublasNDArray(FloatBuffer floatBuffer, char order) {
        super((DataBuffer)floatBuffer, order);
    }

    /** Forwards all arguments to {@code super(buffer, shape, strides)}. */
    public JCublasNDArray(DataBuffer buffer, int[] shape, int[] strides) {
        super(buffer, shape, strides);
    }

    /** Forwards all arguments to {@code super(data, shape, ordering)} ({@code double[]} data, {@code int[]} shape). */
    public JCublasNDArray(double[] data, int[] shape, char ordering) {
        super(data, shape, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, ordering)} ({@code double[]} data, {@code long[]} shape). */
    public JCublasNDArray(double[] data, long[] shape, char ordering) {
        super(data, shape, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, ordering)} ({@code float[]} data, {@code long[]} shape). */
    public JCublasNDArray(float[] data, long[] shape, char ordering) {
        super(data, shape, ordering);
    }

    /** Forwards all arguments to {@code super(data, shape, stride, offset, ordering)} ({@code double[]} data, {@code int[]} shape/stride). */
    public JCublasNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) {
        super(data, shape, stride, offset, ordering);
    }

    /**
     * Duplicates this array.
     *
     * <p>When the array is compressed and its ordering equals {@code Nd4j.order()},
     * the data buffer is duplicated as-is, paired with this array's shape-info buffer,
     * and the result is marked as compressed. In every other case the superclass
     * {@code dup()} is used and the executioner is committed before returning.
     *
     * @return the duplicate array
     */
    public INDArray dup() {
        boolean compressedInDefaultOrder = this.isCompressed() && this.ordering() == Nd4j.order().charValue();
        if (!compressedInDefaultOrder) {
            INDArray duplicate = super.dup();
            Nd4j.getExecutioner().commit();
            return duplicate;
        }
        DataBuffer dataCopy = this.data().dup();
        INDArray compressedCopy = Nd4j.createArrayFromShapeBuffer(dataCopy, this.shapeInfoDataBuffer());
        compressedCopy.markAsCompressed(true);
        return compressedCopy;
    }

    /**
     * Duplicates this array with the requested ordering.
     *
     * <p>A compressed array whose ordering already equals {@code order} is copied by
     * duplicating its data buffer directly and marking the result as compressed;
     * otherwise the call is handed to the superclass {@code dup(char)}.
     *
     * @param order ordering requested for the duplicate
     * @return the duplicate array
     */
    public INDArray dup(char order) {
        if (!this.isCompressed() || this.ordering() != order) {
            return super.dup(order);
        }
        DataBuffer dataCopy = this.data().dup();
        INDArray compressedCopy = Nd4j.createArrayFromShapeBuffer(dataCopy, this.shapeInfoDataBuffer());
        compressedCopy.markAsCompressed(true);
        return compressedCopy;
    }

    /**
     * Compares this array with another object by delegating entirely to the
     * superclass implementation; no comparison logic is added in this class.
     *
     * @param o object to compare against
     * @return whatever {@code super.equals(o)} returns
     */
    public boolean equals(Object o) {
        return super.equals(o);
    }

    /**
     * Returns the superclass string form of this array.
     *
     * <p>Unless {@code isS()} reports true, the allocator is first asked to
     * synchronize this array's host data, so the text is built from current values.
     * NOTE(review): {@code isS()} presumably identifies string-typed arrays - confirm.
     *
     * @return the result of {@code super.toString()}
     */
    @Override
    public String toString() {
        boolean needsHostSync = !this.isS();
        if (needsHostSync) {
            AtomicAllocator.getInstance().synchronizeHostData((INDArray) this);
        }
        return super.toString();
    }

    /**
     * Replaces this array's shape-information buffer and rebuilds the JVM-side
     * shape info from the buffer's {@code long} contents.
     *
     * @param buffer the new shape-information buffer
     */
    public void setShapeInfoDataBuffer(DataBuffer buffer) {
        this.shapeInformation = buffer;
        long[] shapeInfoValues = buffer.asLong();
        this.jvmShapeInfo = new JvmShapeInfo(shapeInfoValues);
    }

    /**
     * Returns a {@link BaseNDArrayProxy} wrapping this array.
     *
     * <p>The name and signature match the {@code writeReplace} hook of Java object
     * serialization, so the proxy is what gets written in place of this instance.
     *
     * @return a new proxy built from this array
     * @throws ObjectStreamException declared by the hook signature; not thrown directly here
     */
    private Object writeReplace() throws ObjectStreamException {
        return new BaseNDArrayProxy((INDArray)this);
    }

    /**
     * Calls {@code push()} on the executioner and then runs the superclass
     * {@code permutei} with the same arguments.
     *
     * @param rearrange passed unchanged to {@code super.permutei}
     * @return the array returned by the superclass
     */
    public INDArray permutei(int ... rearrange) {
        // NOTE(review): push() presumably flushes queued ops before the permute - confirm.
        Nd4j.getExecutioner().push();
        INDArray permuted = super.permutei(rearrange);
        return permuted;
    }

    /**
     * Builds a {@link LongShapeDescriptor} from this array's shape, stride,
     * element-wise stride, ordering, data type and empty flag.
     *
     * @return the descriptor produced by {@code LongShapeDescriptor.fromShape}
     */
    public LongShapeDescriptor shapeDescriptor() {
        // Read the properties in the same order the arguments are passed.
        long[] dims = this.shape();
        long[] strides = this.stride();
        long ews = this.elementWiseStride();
        char order = this.ordering();
        DataType type = this.dataType();
        boolean empty = this.isEmpty();
        return LongShapeDescriptor.fromShape(dims, strides, ews, order, type, empty);
    }

    /**
     * Convenience overload equivalent to {@code unsafeDuplication(true)}.
     *
     * @return the duplicate produced by the blocking variant
     */
    public INDArray unsafeDuplication() {
        return this.unsafeDuplication(true);
    }

    /**
     * Duplicates this array by issuing a raw {@code memcpyAsync} between the
     * allocation points of this array and a newly created buffer of the same kind.
     *
     * <p>The new buffer is created in the current workspace when there is one.
     * Source and destination pointers (host or device) are chosen from the allocation
     * status of each side; any combination other than device/device, host/device or
     * device/host falls through to a host-to-host copy. The copy is queued on the
     * "old" stream when {@code blocking} is true and on the "special" stream otherwise,
     * and that same stream is synchronized before the method returns.
     *
     * @param blocking when true, the executioner is pushed first and the old stream is
     *                 used; when false, the special stream is used
     * @return the newly created duplicate
     */
    public INDArray unsafeDuplication(boolean blocking) {
        WorkspaceUtils.assertValidArray((INDArray) this, "Cannot duplicate array");

        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        DataBuffer target;
        if (workspace == null) {
            target = Nd4j.getDataBufferFactory().createSame(this.data, false);
        } else {
            target = Nd4j.getDataBufferFactory().createSame(this.data, false, workspace);
        }
        INDArray duplicate = Nd4j.createArrayFromShapeBuffer(target, this.shapeInfoDataBuffer());
        if (blocking) {
            Nd4j.getExecutioner().push();
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        AllocationPoint source = allocator.getAllocationPoint((INDArray) this);
        AllocationPoint destination = allocator.getAllocationPoint(duplicate);

        long started = PerformanceTracker.getInstance().helperStartTransaction();

        AllocationStatus dstStatus = destination.getAllocationStatus();
        AllocationStatus srcStatus = source.getAllocationStatus();
        Pointer stream = blocking ? context.getOldStream() : context.getSpecialStream();
        long byteCount = this.data.length() * this.data.getElementSize();

        MemcpyDirection direction;
        if (dstStatus == AllocationStatus.DEVICE && srcStatus == AllocationStatus.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(destination.getDevicePointer(), source.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToDevice, stream);
            destination.tickDeviceWrite();
            direction = MemcpyDirection.DEVICE_TO_DEVICE;
        } else if (dstStatus == AllocationStatus.HOST && srcStatus == AllocationStatus.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(destination.getHostPointer(), source.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToHost, stream);
            destination.tickHostWrite();
            direction = MemcpyDirection.DEVICE_TO_HOST;
        } else if (dstStatus == AllocationStatus.DEVICE && srcStatus == AllocationStatus.HOST) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(destination.getDevicePointer(), source.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToDevice, stream);
            destination.tickDeviceWrite();
            direction = MemcpyDirection.HOST_TO_DEVICE;
        } else {
            // Every remaining status combination is handled as a host-to-host copy.
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(destination.getHostPointer(), source.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToHost, stream);
            destination.tickHostWrite();
            direction = MemcpyDirection.HOST_TO_HOST;
        }

        // Wait on the same stream the copy was queued on.
        if (blocking) {
            context.syncOldStream();
        } else {
            context.syncSpecialStream();
        }

        PerformanceTracker.getInstance().helperRegisterTransaction(destination.getDeviceId(), started, destination.getNumberOfBytes(), direction);
        return duplicate;
    }

    /**
     * Copies this array into the workspace identified by {@code id} and returns the copy.
     *
     * <p>The array itself is returned unchanged when any of the following holds:
     * it is not attached; no workspace with the given id exists; the target workspace
     * is already the current workspace; or this array's data buffer already has the
     * target as its parent workspace.
     *
     * <p>Otherwise the target is made the current workspace for the duration of the
     * copy. A view is copied with {@code dup(ordering())}; a non-view is copied with a
     * single {@code memcpyAsync} into a new buffer of this array's data type. The
     * previously current workspace is always restored, including when the copy throws.
     *
     * @param id identifier of the target workspace
     * @return this array, or a copy of it created while the target workspace was current
     * @throws ND4JIllegalStateException if {@code memcpyAsync} returns 0
     */
    public INDArray leverageTo(String id) {
        if (!this.isAttached()) {
            return this;
        }
        if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) {
            return this;
        }
        WorkspaceUtils.assertValidArray((INDArray) this, "Cannot leverage INDArray to new workspace");

        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id);
        if (current == target) {
            return this;
        }
        if (this.data.getParentWorkspace() == target) {
            return this;
        }

        Nd4j.getMemoryManager().setCurrentWorkspace(target);
        try {
            if (this.isView()) {
                INDArray viewCopy = this.dup(this.ordering());
                Nd4j.getExecutioner().commit();
                return viewCopy;
            }

            Nd4j.getExecutioner().commit();

            // Allocate with this array's own data type (as migrate() does) so the buffer
            // and the byte count below agree with the shape info reused for the copy.
            DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false);
            AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
            AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
            CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);

            long perfD = PerformanceTracker.getInstance().helperStartTransaction();
            long byteCount = this.length() * Nd4j.sizeOfDataType(buffer.dataType());

            MemcpyDirection direction;
            boolean failed;
            if (pointSrc.isActualOnDeviceSide()) {
                direction = MemcpyDirection.DEVICE_TO_DEVICE;
                failed = NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0;
            } else {
                direction = MemcpyDirection.HOST_TO_DEVICE;
                failed = NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0;
            }
            if (failed) {
                throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            context.syncOldStream();

            // Report the direction that was actually used for the copy.
            PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointSrc.getNumberOfBytes(), direction);

            INDArray copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
            pointDst.tickHostRead();
            pointDst.tickDeviceWrite();
            AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
            return copy;
        } finally {
            // Restore the caller's workspace even if the copy failed.
            Nd4j.getMemoryManager().setCurrentWorkspace(current);
        }
    }

    /**
     * Copies this array into the current workspace and returns the copy.
     *
     * <p>If there is no current workspace the array itself is returned. A view is
     * copied with {@code dup(ordering())}. A non-view is copied with a single
     * {@code memcpyAsync} into a new buffer of this array's data type; the source
     * pointer is the device pointer when the source allocation reports it is actual on
     * the device side, and the host pointer otherwise. After the copy, the destination
     * allocation point's device id is set to the workspace's device id if they differ.
     *
     * @return this array, or a copy of it created in the current workspace
     * @throws ND4JIllegalStateException if {@code memcpyAsync} returns 0
     */
    public INDArray migrate() {
        // The message names this operation; it previously reused leverageTo's wording.
        WorkspaceUtils.assertValidArray((INDArray) this, "Cannot migrate INDArray to current workspace");

        MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (current == null) {
            return this;
        }
        if (this.isView()) {
            return this.dup(this.ordering());
        }

        Nd4j.getExecutioner().commit();

        DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false);
        AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);

        long perfD = PerformanceTracker.getInstance().helperStartTransaction();
        long byteCount = this.length() * Nd4j.sizeOfDataType(buffer.dataType());

        MemcpyDirection direction;
        boolean failed;
        if (pointSrc.isActualOnDeviceSide()) {
            direction = MemcpyDirection.DEVICE_TO_DEVICE;
            failed = NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), byteCount, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0;
        } else {
            direction = MemcpyDirection.HOST_TO_DEVICE;
            failed = NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), byteCount, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0;
        }
        if (failed) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        context.syncOldStream();

        PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointDst.getNumberOfBytes(), direction);

        // Reuse the workspace that was null-checked above instead of looking it up again.
        if (pointDst.getDeviceId() != current.getDeviceId()) {
            pointDst.setDeviceId(current.getDeviceId());
        }

        INDArray copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        pointDst.tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
        return copy;
    }
}

