/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaOpContext;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCuda;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudaExecutioner
extends DefaultOpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Device-side native backend; static, so shared by every executioner instance. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Source of TAD (tensor-along-dimension) shape-info / offset buffer pairs for dimension-wise ops. */
    protected static TADManager tadManager = new DeviceTADManager();
    /** Per-thread, lazily created 32-slot pointer array used to hand auxiliary pointers to native calls. */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    // NOTE(review): not touched in this part of the file; presumably populated lazily elsewhere.
    protected volatile transient Properties properties;
    /** Name of the last op run on the calling thread; the exec paths shown here set it only in debug mode. */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();
    /** Custom op descriptors; null until populated (presumably elsewhere in this class). */
    protected Map<String, CustomOpDescriptor> customOps = null;
    /** Mirrors the native backend's experimental-mode flag; set in the constructor. */
    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);

    public CudaExecutioner() {
        this.experimentalMode.set(nativeOps.isExperimentalEnabled());
    }

    /**
     * Returns the device native backend used by this executioner.
     *
     * @return the class-wide (static) {@link NativeOps} handle
     */
    public NativeOps getNativeOps() {
        final NativeOps backend = CudaExecutioner.nativeOps;
        return backend;
    }

    /**
     * Returns the name of the most recent op recorded for the calling thread.
     *
     * @return the recorded op name, or {@code null} if nothing was recorded on this thread
     */
    public String getLastOp() {
        final ThreadLocal<String> perThreadName = lastOp;
        return perThreadName.get();
    }

    /**
     * Executes a broadcast op ({@code BROADCAST} or {@code BROADCAST_BOOL}) on the device.
     *
     * <p>TAD shape info and offsets are obtained for both {@code x} and {@code z} along
     * {@code op.dimensions()} and handed to the native layer through this thread's
     * {@code extraz} pointer array.
     *
     * @param op the broadcast op; {@code x}, {@code y} and {@code z} are all dereferenced in both
     *           dispatch branches, so they are expected to be non-null
     * @return {@code op.z()}
     * @throws UnsupportedOperationException if the op type is neither BROADCAST nor BROADCAST_BOOL
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public INDArray exec(BroadcastOp op) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        int[] dimension = op.dimensions().toIntVector();
        // Lazily create this thread's reusable 32-slot extra-pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        // In debug mode, remember the op name so getLastOp() can report it.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        // TAD shape info (first) and offsets (second) for x along the broadcast dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        // Same TAD information for z; the null initialisers are overwritten right below.
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Fill slots 0..13 of the extra-pointer array; the native side presumably reads them by
        // position (TODO confirm against the NativeOps contract before reordering anything).
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        // NOTE(review): dimensionPointer is computed but never used; the native calls below pass
        // op.dimensions() directly.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        switch (op.getOpType()) {
            case BROADCAST: {
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer)xShapeInfo, null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                break;
            }
            case BROADCAST_BOOL: {
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer)xShapeInfo, null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer)AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
            }
        }
        // Surface native failures. When this throws, registerAction and the profiling hook-out
        // below are skipped.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingConfigurableHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Launches the native kernel for a reduction op whose result array is already in place
     * (exec(ReduceOp) allocates or validates {@code z} before calling this).
     *
     * <p>Dispatch: {@link Variance} ops go to the summary-stats kernels; ops with a {@code y} go
     * to the reduce3 kernels (all-pairs, scalar, or along dimensions); everything else goes to
     * the plain reduce kernels, scalar or along dimensions depending on whether {@code z} is a
     * scalar.
     *
     * @param op        the reduction to execute; on the non-empty path {@code op.z()} is expected
     *                  to be non-null, since it is dereferenced to choose the kernel
     * @param dimension dimensions to reduce along; each must be below {@code x}'s rank or equal
     *                  to {@code Integer.MAX_VALUE}
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is out of range, or if x and y lengths
     *                                   differ when y is not matched against a single x TAD
     *                                   (and the op is not an all-pairs accumulation)
     * @throws UnsupportedOperationException for reduce op types not handled below
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    protected INDArray naiveExec(ReduceOp op, int ... dimension) {
        DataType argsType;
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        // Empty reduction: copy x into the existing z (same shape required) or into a dup of x,
        // and return without launching a kernel.
        // NOTE(review): these early returns skip profilingConfigurableHookOut for the hook entered above.
        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            if (op.z() != null) {
                Preconditions.checkState((boolean)op.x().equalShapes(op.z()), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)op.x(), (Object)op.z());
                op.z().assign(op.x());
                return op.z();
            }
            op.setZ(op.x().dup());
            return op.z();
        }
        // Captured once; later branches use it to choose scalar vs. along-dimension kernels.
        INDArray ret = op.z();
        this.checkForCompression((Op)op);
        op.validateDataTypes();
        // Reject any dimension >= x's rank; Integer.MAX_VALUE is let through (presumably the
        // "whole array" marker -- confirm).
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD shape info (first) and offsets (second, may be null) for x along the reduction dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Fill slots 0..11 of this thread's extra-pointer array (positional, presumably mirrored
        // on the native side -- confirm before reordering).
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        // y's TAD information; also written into slots 12/13 when y is present.
        Pointer yDevTadOffsets = null;
        Pointer yDevTadShapeInfo = null;
        if (op.y() != null) {
            // y is NOT matched against a single x TAD: whole-array reduction, or the first x TAD
            // has a different length than y.
            if (dimension.length == 0 || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE || op.x().tensorAlongDimension(0L, dimension).length() != op.y().length()) {
                if (!op.isComplexAccumulation() && op.x().length() != op.y().length()) {
                    throw new ND4JIllegalStateException("Op.X [" + op.x().length() + "] and Op.Y [" + op.y().length() + "] lengths should match");
                }
                if (!op.z().isScalar()) {
                    Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
                    yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
                    DataBuffer yOffsets = (DataBuffer)yTadBuffers.getSecond();
                    yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
                    xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
                    xShapeInfoHostPointer.put(13L, yDevTadOffsets);
                }
            } else {
                // y has the same length as one x TAD: y's full shape info stands in for its TAD
                // shape, with a constant {0, 0} LONG buffer as offsets. Slot 13 is cleared.
                DataBuffer fakeOffsets = Nd4j.getConstantHandler().getConstantBuffer(new int[]{0, 0}, DataType.LONG);
                yDevTadOffsets = fakeOffsets == null ? null : AtomicAllocator.getInstance().getPointer(fakeOffsets, context);
                yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                xShapeInfoHostPointer.put(12L, AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                xShapeInfoHostPointer.put(13L, null);
            }
        }
        // Extra args are materialised in x's data type for REDUCE_LONG / REDUCE_BOOL, and in z's
        // data type for every other reduce type.
        switch (op.getOpType()) {
            case REDUCE_LONG: 
            case REDUCE_BOOL: {
                argsType = op.x().dataType();
                break;
            }
            default: {
                argsType = op.z().dataType();
            }
        }
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(argsType), context) : null;
        // NOTE(review): dimensionPointer is computed but never used below; the native calls pass
        // op.dimensions() directly.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op instanceof Variance) {
            // Summary statistics (variance): scalar or per-TAD kernel.
            if (ret.isScalar()) {
                nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()), ((Variance)op).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else {
                nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, ((Variance)op).isBiasCorrected(), (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (op.y() != null) {
            // Pairwise (reduce3) kernels.
            if (op.isComplexAccumulation()) {
                // All-pairs mode: x and y TAD offsets are wrapped in LongPointerWrapper.
                LongPointerWrapper dT = new LongPointerWrapper(devTadOffsets);
                LongPointerWrapper yT = new LongPointerWrapper(yDevTadOffsets);
                nativeOps.execReduce3All(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, (LongPointer)devTadShapeInfo, (LongPointer)dT, (LongPointer)yDevTadShapeInfo, (LongPointer)yT);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else if (ret.isScalar()) {
                nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context));
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else {
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.y(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)yDevTadShapeInfo, (LongPointer)yDevTadOffsets);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (ret.isScalar()) {
            // Plain reduction of the whole array to a scalar.
            switch (op.getOpType()) {
                case REDUCE_FLOAT: {
                    nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                    break;
                }
                case REDUCE_BOOL: {
                    nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                    break;
                }
                case REDUCE_LONG: {
                    nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                    break;
                }
                case REDUCE_SAME: {
                    nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                    break;
                }
                default: {
                    throw new UnsupportedOperationException();
                }
            }
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        } else {
            // Plain reduction along op.dimensions().
            switch (op.getOpType()) {
                case REDUCE_FLOAT: {
                    nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                    break;
                }
                case REDUCE_BOOL: {
                    nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                    break;
                }
                case REDUCE_SAME: {
                    nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                    break;
                }
                case REDUCE_LONG: {
                    nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, AtomicAllocator.getInstance().getPointer(op.z(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException();
                }
            }
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        }
        // Surface native failures (checked after registerAction in every branch above).
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes a variance op through the generic reduction path.
     *
     * @param op the variance op
     * @return the result array produced by the {@code ReduceOp} overload
     */
    public INDArray exec(Variance op) {
        // Give the argument the static type ReduceOp so overload resolution selects
        // exec(ReduceOp) instead of recursing into this method.
        final ReduceOp asReduction = (ReduceOp) op;
        return exec(asReduction);
    }

    /**
     * Executes a reduction op on the device, allocating the result array when needed.
     *
     * <p>Empty reductions short-circuit without touching the device: {@code x} is assigned into
     * an existing {@code z} (which must have the same shape as {@code x}) or duplicated into a
     * new {@code z}. A vector {@code x} with no {@code y}, whose reduction shape has as many
     * elements as {@code x} (more than one), is handed to {@code op.noOp()}.
     *
     * <p>If {@code z} is null or aliases {@code x}, a new result array of
     * {@code op.resultType()} is created. Otherwise the supplied {@code z} must have exactly as
     * many elements as the reduction shape. The kernel launch itself is delegated to
     * {@code naiveExec}.
     *
     * @param op the reduction to execute
     * @return {@code op.z()}, or the value of {@code op.noOp()} on the vector short-circuit path
     * @throws ND4JIllegalStateException if x/y TAD counts or sizes are incompatible, if a
     *         TAD-vs-TAD comparison is requested without dimensions, or if a supplied {@code z}
     *         has the wrong number of elements
     */
    public INDArray exec(ReduceOp op) {
        this.checkForCompression((Op)op);
        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            if (op.z() != null) {
                Preconditions.checkState((boolean)op.x().equalShapes(op.z()), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)op.x(), (Object)op.z());
                op.z().assign(op.x());
                return op.z();
            }
            op.setZ(op.x().dup());
            return op.z();
        }
        int[] dimension = op.dimensions().toIntVector();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        boolean wholeDims = Shape.wholeArrayDimension(dimension) || op.x().rank() == dimension.length || dimension.length == 0;
        // The reduction shape is derived from whichever of x / y has more elements (y on ties).
        INDArray shapeSource = op.y() == null ? op.x() : (op.x().length() > op.y().length() ? op.x() : op.y());
        long[] retShape = Shape.reductionShape(shapeSource, dimension, true, op.isKeepDims());
        // Use prodLong rather than the int-returning ArrayUtil.prod(long[]), which would overflow
        // for reduction shapes with more than Integer.MAX_VALUE elements.
        long retLength = ArrayUtil.prodLong(retShape);
        if (op.x().isVector() && op.x().length() == retLength && retLength > 1L && op.y() == null) {
            return op.noOp();
        }
        DataType dtype = op.resultType();
        if (op.z() == null || op.z() == op.x()) {
            INDArray ret;
            if (op.isComplexAccumulation()) {
                // All-pairs mode: one output per (x TAD, y TAD) combination.
                long xT = op.x().tensorsAlongDimension(dimension);
                long yT = op.y().tensorsAlongDimension(dimension);
                ret = Nd4j.createUninitialized(dtype, new long[]{xT, yT});
            } else {
                if (op.y() != null) {
                    if (op.x().length() == op.y().length()) {
                        if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
                            throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + ", dimension = " + Arrays.toString(dimension) + ")");
                        }
                    } else {
                        // y is compared against each TAD of x, so a dimension is mandatory and
                        // every x TAD must have exactly y's length.
                        if (dimension.length == 0) {
                            throw new ND4JIllegalStateException("TAD vs TAD comparison requires dimension (or other comparison mode was supposed to be used?)");
                        }
                        long xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
                        if (xTADSize != op.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTADSize + ", y size = " + op.y().length());
                        }
                    }
                }
                ret = Nd4j.create(dtype, retShape);
            }
            op.setZ(ret);
        } else {
            // An empty reduction shape denotes a scalar result, i.e. one element.
            long expectedLength = retShape.length == 0 ? 1L : retLength;
            if (op.z().length() != expectedLength) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
            }
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.naiveExec(op, dimension);
        this.profilingConfigurableHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes an index-reduction op ({@link IndexAccumulation}) on the device.
     *
     * <p>If the op has no result array, an uninitialized {@link DataType#LONG} array with the
     * computed reduction shape is allocated and set on the op. Two cases return without
     * launching a kernel:
     * <ul>
     *   <li>a vector {@code x} whose length equals {@code z}'s length, which returns {@code x} itself;</li>
     *   <li>an empty {@code z}.</li>
     * </ul>
     * These short-circuits happen before the profiling hook is entered, so every hook-in is
     * matched by the hook-out at the end.
     *
     * @param op the index reduction to execute
     * @return {@code op.z()}, or {@code op.x()} on the vector short-circuit path
     * @throws IllegalArgumentException (via {@code Preconditions.checkArgument}) if {@code x} is
     *         empty and a reduction axis has size 0
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public INDArray exec(IndexAccumulation op) {
        int[] dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
        if (op.x().isEmpty()) {
            for (int d : dimension) {
                Preconditions.checkArgument(op.x().shape()[d] != 0L, "IndexReduce can't be issued along axis with 0 in shape");
            }
        }
        if (op.z() == null) {
            long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
            op.setZ(Nd4j.createUninitialized(DataType.LONG, retShape));
        }
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (op.x().isVector() && op.x().length() == op.z().length()) {
            // NOTE(review): this returns the INPUT values rather than indices. When z has as many
            // elements as a vector x, the reduction runs over size-1 axes only; confirm the
            // intended result before changing this short-circuit.
            return op.x();
        }
        if (op.z().isEmpty()) {
            return op.z();
        }
        // Enter the profiler only once a kernel will actually be launched, so the hook-in is
        // always paired with the hook-out below.
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        // x and z are non-null here: both were dereferenced above.
        Pointer hostXShapeInfo = AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // TAD shape info (first) and offsets (second, may be null) for x along the reduction dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        // Slots 0..11 of this thread's extra-pointer array; the layout is positional.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        // Extra args are materialised in x's data type.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null;
        nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingConfigurableHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Dispatches a generic {@link Op} to the execution path for its concrete kind.
     *
     * <p>{@link CopyOp} is delegated to the superclass: {@code x} and {@code y} are first
     * synchronized to host memory, and afterwards {@code tickHostWrite} is called on {@code z}.
     * Every other op is routed by interface type (transform, reduce, scalar, broadcast, index
     * accumulation, random, custom). An op matching none of these types is not executed.
     *
     * @param op the op to execute
     * @return the op's result array {@code op.z()} (may be null if the op has no z)
     */
    public INDArray exec(Op op) {
        this.checkForCompression(op);
        if (op instanceof CopyOp) {
            if (op.x() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.y());
            }
            super.exec(op);
            if (op.z() != null) {
                AtomicAllocator.getInstance().tickHostWrite(op.z());
            }
            // Return the result like every other branch does, instead of null.
            return op.z();
        }
        if (op instanceof TransformOp) {
            TransformOp t = (TransformOp)op;
            this.invoke(t);
        } else if (op instanceof ReduceOp) {
            ReduceOp acc = (ReduceOp)op;
            this.invoke(acc, acc.dimensions().toIntVector());
        } else if (op instanceof ScalarOp) {
            ScalarOp sc = (ScalarOp)op;
            this.invoke(sc);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp)op;
            this.invoke(broadcastOp);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation indexAccumulation = (IndexAccumulation)op;
            this.invoke(indexAccumulation, indexAccumulation.dimensions().toIntVector());
        } else if (op instanceof RandomOp) {
            this.exec((RandomOp)op);
        } else if (op instanceof CustomOp) {
            this.exec((CustomOp)op);
        }
        return op.z();
    }

    /**
     * Runs a transform op and returns the same op instance, so the caller can read its output.
     *
     * @param op the transform op to execute
     * @return {@code op} itself, after it has been invoked
     */
    public TransformOp execAndReturn(TransformOp op) {
        Op asGenericOp = op;
        this.checkForCompression(asGenericOp);
        this.invoke(op);
        return op;
    }

    /**
     * Executes a broadcast op on the device via {@code execBroadcast} or {@code execBroadcastBool},
     * chosen by {@code op.getOpType()}.
     *
     * <p>TAD (tensor-along-dimension) shape info and offsets are computed for both x and z along
     * {@code op.getDimension()}. They are passed to the native layer through the extra-pointers
     * array.
     *
     * @param op the broadcast op; x, y and z are all dereferenced, so none may be null
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is neither BROADCAST nor BROADCAST_BOOL
     * @throws RuntimeException              if the native call reports an error
     */
    protected CudaContext invoke(BroadcastOp op) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        // Lazily create the reusable extra-pointers holder (presumably per-thread storage; confirm).
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        // NOTE(review): op.x() is dereferenced here, before the null check on the next lines,
        // so a null x is not actually supported by this method.
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD shape info / offsets for x along the broadcast dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension());
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        // TAD shape info / offsets for z along the same dimensions.
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension());
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Extra pointers: host x shape, stream, device id, context buffers, host y/z shapes,
        // then x TAD info (slots 9-11) and z TAD info (slots 12-13).
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // NOTE(review): dimensionPointer is computed but never passed to the native calls below;
        // they receive op.dimensions() instead.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context);
        switch (op.getOpType()) {
            case BROADCAST: {
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                break;
            }
            case BROADCAST_BOOL: {
                // The bool variant takes one more (null) argument than execBroadcast before the dimension arguments.
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown opType: " + op.getOpType());
            }
        }
        // Surface errors recorded by the native layer.
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingConfigurableHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes an index-reduction op ({@link IndexAccumulation}) on the device.
     *
     * <p>The requested dimensions are first normalized against the rank of x. When there are no
     * dimensions (or only the {@code Integer.MAX_VALUE} sentinel) and z is missing or aliases x,
     * a new shape-{@code []} {@code LONG} array is allocated as z.
     *
     * <p>Dispatch:
     * <ul>
     *   <li>{@code execIndexReduceScalar} when z is a scalar or the reduction spans all dimensions;</li>
     *   <li>{@code execIndexReduce} with the dimensions sorted in ascending order otherwise.</li>
     * </ul>
     *
     * <p>This method only reads the global CUDA debug setting and never changes it. It used to call
     * {@code enableDebug(true)} unconditionally, which left the backend in debug mode after every
     * index reduction.
     *
     * @param op        the index-reduction op; its z array receives the result
     * @param dimension dimensions to reduce along; may be {@code null} for a full reduction
     * @return always {@code null}
     * @throws ND4JIllegalStateException if a dimension is not below the rank of x (other than the
     *                                   {@code Integer.MAX_VALUE} sentinel)
     * @throws RuntimeException          if the native call reports an error
     */
    protected CudaContext invoke(IndexAccumulation op, int[] dimension) {
        dimension = Shape.normalizeAxis((int)op.x().rank(), (int[])dimension);
        // Full reduction with no usable output: allocate a scalar LONG result.
        if ((dimension == null || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) && (op.z() == op.x() || op.z() == null)) {
            op.setZ(Nd4j.createUninitialized((DataType)DataType.LONG, (long[])new long[0], (char)'c'));
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        if (dimension != null) {
            for (int i = 0; i < dimension.length; ++i) {
                if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
            }
        }
        // A scalar z is deliberately not passed to prepareAction.
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z().isScalar() ? null : op.z(), op.x(), op.y());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null;
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD lookup needs some dimension; fall back to axis 0 for full reductions.
        int[] fdimension = dimension;
        if (fdimension == null) {
            fdimension = new int[]{0};
        }
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
            AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y());
        } else {
            Arrays.sort(dimension);
            Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension));
            nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, dimensionPointer, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
            AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y());
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes a reduction op on the device.
     *
     * <p>Depending on the op, this runs one of:
     * <ul>
     *   <li>a plain reduce (REDUCE_FLOAT / SAME / BOOL / LONG);</li>
     *   <li>summary statistics (for {@code Variance});</li>
     *   <li>a pairwise reduce3, when y is present.</li>
     * </ul>
     * The scalar entry points are used when z is a scalar; the TAD-wise entry points otherwise.
     *
     * <p>Special cases handled before any native call:
     * <ul>
     *   <li>an "empty reduce" {@code BaseReduceOp}: z gets a copy of x (z must have x's shape);</li>
     *   <li>a {@code BaseReduceBoolOp} over an empty x with a full reduction: z is set to the
     *       op's {@code emptyValue()}.</li>
     * </ul>
     *
     * <p>A {@code null} dimension is treated as {@code {Integer.MAX_VALUE}}, i.e. a full reduction.
     * Multiple dimensions are sorted in place after normalization.
     *
     * @param op        the reduction op; z is allocated if {@code null}
     * @param dimension dimensions to reduce along; may be {@code null}
     * @return the device context obtained from {@code prepareAction}
     * @throws ND4JIllegalStateException     if a dimension exceeds the rank of x, if the x/y TAD
     *                                       counts or sizes do not match, or if a provided z has
     *                                       the wrong data type or shape
     * @throws UnsupportedOperationException for an unhandled reduce op type
     * @throws RuntimeException              if the native call reports an error
     */
    protected CudaContext invoke(ReduceOp op, int[] dimension) {
        Pointer yDevTadOffsets;
        // NOTE(review): prepareAction runs before z may be allocated further down, so a z created
        // by this method is not part of the prepared action. Confirm this is intended.
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        // Empty reduction: the result is simply x (copied into z, or duplicated as the new z).
        if (op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            if (op.z() != null) {
                Preconditions.checkState((boolean)op.x().equalShapes(op.z()), (String)"For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", (Object)op.x(), (Object)op.z());
                op.z().assign(op.x());
                return context;
            }
            op.setZ(op.x().dup());
            return context;
        }
        // Boolean full reduction over an empty input: the answer is the op's fixed empty value.
        if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
            if (op.z() == null) {
                op.setZ(Nd4j.scalar((boolean)((BaseReduceBoolOp)op).emptyValue()));
            } else {
                op.z().assign(((BaseReduceBoolOp)op).emptyValue());
            }
            return context;
        }
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        dimension = Shape.normalizeAxis((int)op.x().rank(), (int[])dimension);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Integer.MAX_VALUE is the "reduce over all dimensions" sentinel.
        if (dimension == null) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        if (dimension.length > 1) {
            Arrays.sort(dimension);
        }
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        // For an empty x no TAD is computed; x's own data buffer stands in and offsets are null.
        Pair tadBuffers = op.x().isEmpty() ? Pair.makePair((Object)op.x().data(), null) : tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = op.x().isEmpty() ? null : (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        long[] retShape = Shape.reductionShape((INDArray)op.x(), (int[])dimension, (boolean)true, (boolean)op.isKeepDims());
        // Pairwise (reduce3) validation.
        // Equal lengths: x and y must produce the same number of TADs.
        // Otherwise: y must be exactly one x-TAD in length.
        if (op.y() != null) {
            if (op.x().length() == op.y().length()) {
                if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
                    throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) + ", dimension = " + Arrays.toString(dimension) + ")");
                }
            } else {
                long xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);
                if (xTADSize != op.y().length()) {
                    throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTADSize + ", y size = " + op.y().length());
                }
            }
        }
        // Allocate z with the op's result type, or verify that a caller-supplied z matches it.
        DataType dataType = op.resultType();
        if (op.z() == null) {
            INDArray ret = Nd4j.createUninitialized((DataType)dataType, (long[])retShape);
            op.setZ(ret);
        } else if (op.z().dataType() != dataType || !Arrays.equals(retShape, op.z().shape())) {
            throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape) + " but has datatype " + op.z().dataType() + " and shape " + Arrays.toString(op.z().shape()));
        }
        // Extra args use x's type for bool outputs and REDUCE_LONG ops, and z's type otherwise.
        DataBuffer eb = op.extraArgsDataBuff(op.z().dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? op.x().dataType() : op.z().dataType());
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        Pair yTadBuffers = op.y() == null ? null : tadManager.getTADOnlyShapeInfo(op.y(), dimension);
        Pointer yDevTadShapeInfo = op.y() == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
        DataBuffer yOffsets = op.y() == null ? null : (DataBuffer)yTadBuffers.getSecond();
        // NOTE(review): the local "pointer" is never read; this looks like a decompiler artifact.
        Pointer pointer = yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
        // y's TAD info goes into extra-pointer slots 12 and 13.
        if (op.y() != null) {
            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
        }
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        op.validateDataTypes();
        if (op.z().isScalar()) {
            // Full reduction to a single value.
            if (op instanceof Variance) {
                nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, ((Variance)op).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else if (op.y() != null) {
                Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT: {
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
                        break;
                    }
                    case REDUCE_BOOL: {
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
                        break;
                    }
                    case REDUCE_SAME: {
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
                        break;
                    }
                    case REDUCE_LONG: {
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo);
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException();
                    }
                }
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else {
            // Reduction along specific dimensions (TAD-wise).
            Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
            if (op.y() != null) {
                Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), dimensionPointer, null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)yDevTadShapeInfo, (LongPointer)yDevTadOffsets);
            } else if (op instanceof Variance) {
                nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, ((Variance)op).isBiasCorrected(), (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets);
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT: {
                        nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                        break;
                    }
                    case REDUCE_SAME: {
                        nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                        break;
                    }
                    case REDUCE_BOOL: {
                        nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                        break;
                    }
                    case REDUCE_LONG: {
                        nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, extraArgs, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, (Pointer)((IntPointer)op.dimensions().data().addressPointer()), (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null);
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException();
                    }
                }
            }
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        this.profilingConfigurableHookOut((Op)op, st);
        // Presumably flushes/synchronizes queued device work before returning. TODO confirm.
        Nd4j.getExecutioner().commit();
        return context;
    }

    /**
     * Executes a scalar op along dimensions (TAD-wise) via {@code execScalarTad} or
     * {@code execScalarBoolTad}.
     *
     * <p>y is passed in the argument position that follows z in the native calls. Presumably these
     * are the per-TAD scalar values; confirm against the native signature.
     *
     * <p>Side effect: {@code dimension} is sorted in place, so the caller's array is modified.
     *
     * <p>NOTE(review): unlike the other entry points, this method does not lazily initialize
     * {@code extraz}; it relies on the caller (e.g. {@code invoke(ScalarOp)}) having done so.
     *
     * @param op        the scalar op; x, y and z are all dereferenced, so none may be null
     * @param dimension dimensions along which TADs are taken; sorted in place
     * @return always {@code null}
     * @throws UnsupportedOperationException if the op type is neither SCALAR nor SCALAR_BOOL
     * @throws RuntimeException              if the native call reports an error
     */
    protected CudaContext intercept(ScalarOp op, int[] dimension) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        Arrays.sort(dimension);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // TAD shape info / offsets for x and for z along the (sorted) dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        // Extra args are typed by z's data type here.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context) : null;
        // NOTE(review): dimensionPointer is computed but never used; the native calls receive
        // op.dimensions() instead.
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        switch (op.getOpType()) {
            case SCALAR: {
                nativeOps.execScalarTad(extraPointers, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, extraArgs, null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)devTadShapeInfoZ, (LongPointer)devTadOffsetsZ);
                break;
            }
            case SCALAR_BOOL: {
                nativeOps.execScalarBoolTad(extraPointers, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, (LongPointer)hostYShapeInfo, y, (LongPointer)yShapeInfo, extraArgs, null, (LongPointer)op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null, (LongPointer)devTadShapeInfo, (LongPointer)devTadOffsets, (LongPointer)devTadShapeInfoZ, (LongPointer)devTadOffsetsZ);
                break;
            }
            default: {
                throw new UnsupportedOperationException();
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // NOTE(review): registers through the flow controller directly, whereas the other
        // entry points call AtomicAllocator.registerAction.
        AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
        this.profilingConfigurableHookOut((Op)op, st);
        return null;
    }

    /**
     * Runs a scalar op and hands back its output array.
     *
     * @param op the scalar op to execute
     * @return the op's z array after invocation
     */
    public INDArray exec(ScalarOp op) {
        this.invoke(op);
        INDArray result = op.z();
        return result;
    }

    /**
     * Executes a scalar op (x op scalar -> z) on the device via {@code execScalar} or
     * {@code execScalarBool}.
     *
     * <p>If the op carries a dimensions array, execution is delegated to the TAD-wise
     * {@code intercept(ScalarOp, int[])} path instead.
     *
     * @param op the scalar op; x and z must have the same length
     * @return always {@code null}
     * @throws ND4JIllegalStateException     if x and z differ in length
     * @throws UnsupportedOperationException if the op type is neither SCALAR nor SCALAR_BOOL
     * @throws RuntimeException              if the native call reports an error
     */
    protected CudaContext invoke(ScalarOp op) {
        long st = this.profilingConfigurableHookIn((Op)op, new DataBuffer[0]);
        this.checkForCompression((Op)op);
        // The check compares x against z, so the message names z, not y.
        if (op.x().length() != op.z().length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        // Dimension-wise scalar ops take the TAD path.
        if (op.dimensions() != null) {
            this.intercept(op, op.dimensions().toIntVector());
            return null;
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        // The scalar operand takes the "y" slot in the host shape-info pointers.
        Pointer hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        // Bool scalar ops type their extra args by x; all others by z.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? op.x().dataType() : op.z().dataType()), context) : null;
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null});
        switch (op.getOpType()) {
            case SCALAR_BOOL: {
                nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, (LongPointer)hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs);
                break;
            }
            case SCALAR: {
                nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(), null, (LongPointer)hostXShapeInfo, x, (LongPointer)xShapeInfo, null, (LongPointer)hostZShapeInfo, z, (LongPointer)zShapeInfo, null, (LongPointer)hostYShapeInfo, AtomicAllocator.getInstance().getPointer(op.scalar(), context), (LongPointer)AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context), extraArgs);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar());
        this.profilingConfigurableHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes a transform op (unary or pairwise) on the current CUDA device.
     *
     * <p>If {@code op.y()} is set the op runs as a pairwise transform and X, Y and Z must all have
     * the same length; otherwise it is dispatched on {@code op.getOpType()} to the matching unary
     * transform. When {@code op.z()} is {@code null}, an uninitialized array with the op's result
     * type and X's shape and ordering is created and set as Z.
     *
     * @param op transform op to execute; {@code op.x()} must be non-null
     * @return always {@code null}
     * @throws ND4JIllegalStateException if X, Y and Z lengths differ in the pairwise case
     * @throws UnsupportedOperationException if a unary op has an unsupported transform type
     * @throws RuntimeException if the native call reports an error
     */
    protected CudaContext invoke(TransformOp op) {
        long st = this.profilingConfigurableHookIn((Op) op, new DataBuffer[0]);
        this.checkForCompression((Op) op);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        INDArray ret = null;
        // X is required: it is dereferenced unconditionally here.
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer hostXShapeInfo = AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        if (op.z() == null) {
            ret = Nd4j.createUninitialized((DataType) op.resultType(), (long[]) op.x().shape(), (char) op.x().ordering());
            op.setZ(ret);
        }
        // Boolean transform types take their extra args in X's data type; all others use Z's.
        Pointer extraArgs = null;
        if (op.extraArgs() != null) {
            boolean boolOp = op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL;
            extraArgs = allocator.getPointer(op.extraArgsDataBuff(boolOp ? op.x().dataType() : op.z().dataType()), context);
        }
        // Z is non-null from here on (it was created above if missing).
        Pointer hostZShapeInfo = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        op.validateDataTypes(this.experimentalMode.get());
        Pointer z = allocator.getPointer(op.z(), context);
        Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
        // This path passes no TAD, dimension or return-shape data: those extra-pointer slots are
        // null and the dimension-count slot holds 0.
        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;
        Pointer devTadOffsets = null;
        Pointer hostMaxTadShapeInfo = null;
        Pointer devMaxTadShapeInfo = null;
        Pointer devMaxTadOffsets = null;
        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        Pointer retPointer = null;
        Pointer retHostShape = null;
        Pointer dimensionCount = new CudaPointer(0L);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{hostXShapeInfo, context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, hostMaxTadShapeInfo, devMaxTadShapeInfo, devMaxTadOffsets, dimensionDevPointer, dimensionHostPointer, retPointer, dimensionCount, retHostShape});
        if (op.y() != null) {
            Pointer y = allocator.getPointer(op.y(), context);
            Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
            if (op.x().length() != op.y().length() || op.x().length() != op.z().length()) {
                throw new ND4JIllegalStateException("X, Y and Z arguments should have the same length for PairwiseTransform");
            }
            switch (op.getOpType()) {
                case TRANSFORM_BOOL:
                case PAIRWISE_BOOL: {
                    nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                default: {
                    nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
            }
        } else {
            switch (op.getOpType()) {
                case TRANSFORM_ANY: {
                    nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_FLOAT: {
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_BOOL: {
                    nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_SAME: {
                    nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                case TRANSFORM_STRICT: {
                    nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
                }
            }
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        // NOTE(review): these calls have no visible effect; presumably they keep extraArgs and a
        // freshly created Z referenced until after the native call. Kept as in the original.
        if (extraArgs != null) {
            extraArgs.address();
        }
        if (ret != null) {
            ret.elementWiseStride();
        }
        this.profilingConfigurableHookOut((Op) op, st);
        return null;
    }

    /**
     * Allocates the int-typed parameter surface for a batch and attaches it to the batch.
     *
     * <p>The buffer holds four times the sample's required batch memory size, counted in int
     * elements. The {@code false} flag presumably skips initialisation — confirm against the
     * data-buffer factory.
     *
     * @param batch batch that receives the new surface via {@code setParamsSurface}
     * @return the newly created surface buffer
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        long surfaceLength = 4L * batch.getSample().getRequiredBatchMemorySize();
        DataBuffer surface = Nd4j.getDataBufferFactory().createInt(surfaceLength, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Packs all aggregates of {@code batch} into one parameter "surface" buffer and launches them
     * with a single {@code execAggregateBatch} native call using {@code batch.opNum()}.
     *
     * <p>The surface is filled through its host pointer. Its sections, in order, are:
     * <ul>
     *   <li>a per-aggregate header of {@code maxTypes} counters;</li>
     *   <li>indexing arguments;</li>
     *   <li>int-array arguments;</li>
     *   <li>real arguments;</li>
     *   <li>argument-array pointers;</li>
     *   <li>shape-buffer pointers.</li>
     * </ul>
     * Each section reserves {@code Batch.getBatchLimit() * 16} times its per-aggregate width.
     * Section offsets are rescaled by data type because sections are written through differently
     * typed views: int, float/double/half, or pointer-sized entries.
     *
     * <p>NOTE(review): the offset rescaling uses the global {@code Nd4j.dataType()}. The view used
     * for real arguments, and the type passed to the native call, come instead from the first
     * argument of the first aggregate. If the two differ, the offsets do not match the view —
     * confirm this is intended.
     *
     * <p>NOTE(review): for an empty batch {@code dataType} stays {@code null} and is still passed to
     * {@code FlatBuffersMapper.getDataTypeAsByte}; confirm callers never pass an empty batch.
     *
     * @param batch batch of aggregates to execute
     * @throws UnsupportedOperationException if the first argument's data type is not FLOAT,
     *         DOUBLE or HALF
     * @throws RuntimeException if the native call reports an error
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        BaseCudaDataBuffer surfaceBuffer = (BaseCudaDataBuffer)this.getBuffer(batch);
        surfaceBuffer.lazyAllocateHostPointer();
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        // Int view over the host copy of the surface; other typed views are derived from it below.
        IntPointer pointer = new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)).asIntPointer();
        AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer);
        // Header width per aggregate: counts of arguments, shapes, indexing args, real args, int arrays.
        int maxTypes = 5;
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxArraySize = batch.getSample().maxIntArraySize();
        // NOTE(review): the meaning of the factor 16 in the section sizes is not visible here.
        int indexPos = maxTypes * (Batch.getBatchLimit() * 16);
        int intArraysPos = indexPos + batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16);
        // Convert from int units to units of the real-argument view (halved for DOUBLE, doubled for HALF).
        int realPos = (intArraysPos + maxIntArrays * maxArraySize * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataType.DOUBLE ? 2 : 1);
        if (Nd4j.dataType() == DataType.HALF) {
            realPos *= 2;
        }
        // Convert from real-view units to units of the PointerPointer view used for the pointer sections.
        int argsPos = (realPos + batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataType.FLOAT ? 2 : 1);
        if (Nd4j.dataType() == DataType.HALF) {
            argsPos /= 4;
        }
        int shapesPos = argsPos + batch.getSample().maxArguments() * (Batch.getBatchLimit() * 16);
        // Taken from the first argument of the first aggregate; stays null for an empty batch.
        DataType dataType = null;
        for (int i = 0; i < batch.getNumAggregates(); ++i) {
            int e;
            Aggregate op = (Aggregate)batch.getAggregates().get(i);
            if (i == 0) {
                dataType = ((INDArray)op.getArguments().get(0)).dataType();
            }
            // Header: five counters for aggregate i.
            int idx = i * maxTypes;
            pointer.put((long)idx, op.getArguments().size());
            pointer.put((long)(idx + 1), op.getShapes().size());
            pointer.put((long)(idx + 2), op.getIndexingArguments().size());
            pointer.put((long)(idx + 3), op.getRealArguments().size());
            pointer.put((long)(idx + 4), op.getIntArrayArguments().size());
            // Indexing arguments: stride per aggregate is the sample's maxIndexArguments().
            for (int e2 = 0; e2 < op.getIndexingArguments().size(); ++e2) {
                idx = indexPos + i * batch.getSample().maxIndexArguments();
                pointer.put((long)(idx + e2), ((Integer)op.getIndexingArguments().get(e2)).intValue());
            }
            // Int-array arguments: each aggregate gets maxIntArrays slots of maxArraySize ints; null arrays are skipped.
            int bsize = maxIntArrays * maxArraySize;
            for (int e3 = 0; e3 < op.getIntArrayArguments().size(); ++e3) {
                int step = i * bsize + e3 * maxArraySize;
                if (op.getIntArrayArguments().get(e3) == null) continue;
                for (int x = 0; x < ((int[])op.getIntArrayArguments().get(e3)).length; ++x) {
                    idx = intArraysPos + step + x;
                    pointer.put((long)idx, ((int[])op.getIntArrayArguments().get(e3))[x]);
                }
            }
            // Real arguments, written through a view matching dataType (HALF values stored as 16-bit shorts).
            // NOTE(review): the stride here is op.maxRealArguments(), whereas the other sections use the
            // sample's limits; confirm they always agree.
            switch (dataType) {
                case FLOAT: {
                    FloatPointer realPtr = new FloatPointer((Pointer)pointer);
                    for (e = 0; e < op.getRealArguments().size(); ++e) {
                        idx = realPos + i * op.maxRealArguments();
                        realPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).floatValue());
                    }
                    break;
                }
                case DOUBLE: {
                    DoublePointer dPtr = new DoublePointer((Pointer)pointer);
                    for (e = 0; e < op.getRealArguments().size(); ++e) {
                        idx = realPos + i * op.maxRealArguments();
                        dPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).doubleValue());
                    }
                    break;
                }
                case HALF: {
                    ShortPointer sPtr = new ShortPointer((Pointer)pointer);
                    for (e = 0; e < op.getRealArguments().size(); ++e) {
                        idx = realPos + i * op.maxRealArguments();
                        sPtr.put((long)(idx + e), BaseDataBuffer.fromFloat((float)((Number)op.getRealArguments().get(e)).floatValue()));
                    }
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unknown data type");
                }
            }
            // Device pointers of argument arrays and shape buffers; null entries are skipped and
            // each referenced buffer is marked as device-written.
            PointerPointer ptrPtr = new PointerPointer((Pointer)pointer);
            for (e = 0; e < op.getArguments().size(); ++e) {
                idx = argsPos + i * batch.getSample().maxArguments();
                if (op.getArguments().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((INDArray)op.getArguments().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((INDArray)op.getArguments().get(e)).tickDeviceWrite();
            }
            for (e = 0; e < op.getShapes().size(); ++e) {
                idx = shapesPos + i * batch.getSample().maxShapes();
                if (op.getShapes().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((DataBuffer)op.getShapes().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((DataBuffer)op.getShapes().get(e)).tickDeviceWrite();
            }
        }
        // Mark the host copy of the surface as the latest write, presumably so that
        // getPointer(surfaceBuffer, context) below transfers it to the device — confirm.
        surfacePoint.tickHostWrite();
        // Launch parameters: [1] stream, [2] grid size (capped by the configured maximum),
        // [3] threads per instance, [4] shared memory size.
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer)context.getOldStream());
        extraArgs.put(2L, (Pointer)new CudaPointer(Math.min(batch.getNumAggregates(), CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize())));
        extraArgs.put(3L, (Pointer)new CudaPointer(batch.getSample().getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer)new CudaPointer(batch.getSample().getSharedMemorySize()));
        nativeOps.execAggregateBatch(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), (int)FlatBuffersMapper.getDataTypeAsByte((DataType)dataType));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        // NOTE(review): the surface is marked host-written again after the launch; intent is not visible here.
        surfacePoint.tickHostWrite();
    }

    /**
     * Executes a list of aggregates by splitting it into batches of at most 8192 elements.
     *
     * <p>Each batch is executed in turn, then the device context's old stream is synchronized
     * once at the end. An empty list is a no-op.
     *
     * @param batch aggregates to execute
     */
    public void exec(List<Aggregate> batch) {
        if (batch.isEmpty()) {
            return;
        }
        List chunks = Batch.getBatches(batch, 8192);
        for (Object chunk : chunks) {
            this.exec((Batch) chunk);
        }
        AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
    }

    /**
     * Executes a single aggregate op through one {@code execAggregate} native call.
     *
     * <p>Argument arrays, shape buffers and int-array arguments are passed to the native side as
     * device-side tables of raw pointers. Indexing arguments are passed as an int buffer. Real
     * arguments are passed as a buffer of the first argument's data type. Missing (null)
     * arguments, shapes and int arrays are passed as a 0 address.
     *
     * @param op aggregate to execute; its first argument must be a non-null array, whose data type
     *           is used for the real-argument buffer and for native dispatch
     * @throws RuntimeException if the native call reports an error
     */
    public void exec(Aggregate op) {
        int numArguments = op.getArguments().size();
        int numShapeArguments = op.getShapes().size();
        int numIndexArguments = op.getIndexingArguments().size();
        int numIntArrays = op.getIntArrayArguments().size();
        int numRealArguments = op.getRealArguments().size();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        // Launch parameters: [1] stream, [2] grid size (one instance), [3] threads, [4] shared memory.
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer) context.getOldStream());
        extraArgs.put(2L, (Pointer) new CudaPointer(1L));
        extraArgs.put(3L, (Pointer) new CudaPointer(op.getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer) new CudaPointer(op.getSharedMemorySize()));
        // The temporary buffers created below are referenced only by raw addresses once their
        // pointers are taken. Keep them strongly reachable until the native call returns, so they
        // cannot be garbage-collected (and their memory released) while still in use.
        List<Object> keepAlive = new ArrayList<Object>();
        DataType dataType = ((INDArray) op.getArguments().get(0)).dataType();

        long[] arguments = new long[numArguments];
        for (int x = 0; x < numArguments; ++x) {
            INDArray argument = (INDArray) op.getArguments().get(x);
            if (argument == null) {
                continue; // slot stays 0L
            }
            arguments[x] = allocator.getPointer(argument, context).address();
            allocator.getAllocationPoint(argument).tickDeviceWrite();
        }
        DataBuffer tempX = AllocationUtils.getPointersBuffer(arguments);
        keepAlive.add(tempX);
        PointerPointer xPtr = new PointerPointer(allocator.getPointer(tempX, context));

        long[] shapes = new long[numShapeArguments];
        for (int x = 0; x < numShapeArguments; ++x) {
            DataBuffer shape = (DataBuffer) op.getShapes().get(x);
            if (shape == null) {
                continue; // slot stays 0L
            }
            shapes[x] = allocator.getPointer(shape, context).address();
            allocator.getAllocationPoint(shape).tickDeviceWrite();
        }
        DataBuffer tempS = AllocationUtils.getPointersBuffer(shapes);
        keepAlive.add(tempS);
        PointerPointer sPtr = new PointerPointer(allocator.getPointer(tempS, context));

        long[] ints = new long[numIntArrays];
        for (int x = 0; x < numIntArrays; ++x) {
            int[] values = (int[]) op.getIntArrayArguments().get(x);
            if (values == null) {
                continue; // slot stays 0L
            }
            DataBuffer intBuf = Nd4j.getDataBufferFactory().createInt(values);
            keepAlive.add(intBuf);
            ints[x] = allocator.getPointer(intBuf, context).address();
        }
        DataBuffer tempI = AllocationUtils.getPointersBuffer(ints);
        keepAlive.add(tempI);
        PointerPointer iPtr = new PointerPointer(allocator.getPointer(tempI, context));

        int[] indexes = new int[numIndexArguments];
        for (int x = 0; x < numIndexArguments; ++x) {
            indexes[x] = (Integer) op.getIndexingArguments().get(x);
        }
        DataBuffer intBuffer = Nd4j.getDataBufferFactory().createInt(indexes);
        keepAlive.add(intBuffer);

        double[] reals = new double[numRealArguments];
        for (int x = 0; x < numRealArguments; ++x) {
            reals[x] = ((Number) op.getRealArguments().get(x)).doubleValue();
        }
        INDArray realsBuffer = Nd4j.create((double[]) reals, (long[]) new long[]{reals.length}, (DataType) dataType);
        keepAlive.add(realsBuffer);

        nativeOps.execAggregate(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer) allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, allocator.getPointer(realsBuffer.data(), context), numRealArguments, (int) FlatBuffersMapper.getDataTypeAsByte((DataType) dataType));
        // First use of keepAlive after the native call: the temporaries stay reachable up to here.
        keepAlive.clear();
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }

    /**
     * Executes a random op with the default ND4J random generator.
     *
     * @param op random op to execute
     * @return the op's Z array, as returned by {@code exec(RandomOp, Random)}
     */
    public INDArray exec(RandomOp op) {
        Random defaultRng = Nd4j.getRandom();
        return this.exec(op, defaultRng);
    }

    /**
     * Executes a random op using the given generator's native state.
     *
     * <p>Dispatches to the three-operand kernel when X and Y are present, to the two-operand
     * kernel when only X is present, and otherwise to the single-output kernel. Z is always
     * written.
     *
     * @param op  random op to execute; must have a non-null Z
     * @param rng generator whose native state pointer is passed to the kernel
     * @return {@code op.z()}
     * @throws IllegalStateException if {@code rng} has no native state pointer
     * @throws ND4JIllegalStateException if {@code op.z()} is {@code null}
     * @throws RuntimeException if the native call reports an error
     */
    public INDArray exec(RandomOp op, Random rng) {
        long st = this.profilingConfigurableHookIn((Op) op, new DataBuffer[0]);
        this.checkForCompression((Op) op);
        if (rng.getStatePointer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        // Z is dereferenced unconditionally below, so fail fast with a clear message instead of an
        // NPE part-way through setup.
        if (op.z() == null) {
            throw new ND4JIllegalStateException("Random op [" + op.opName() + "] requires a non-null Z array");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.opName());
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostZShapeInfo = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        PointerPointer extraZZ = this.extraz.get().put(new Pointer[]{hostZShapeInfo, context.getOldStream(), allocator.getDeviceIdPointer()});
        Pointer hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        if (op.x() != null && op.y() != null) {
            nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), null, (LongPointer) hostXShapeInfo, allocator.getPointer(op.x(), context), (LongPointer) allocator.getPointer(op.x().shapeInfoDataBuffer(), context), null, (LongPointer) hostYShapeInfo, allocator.getPointer(op.y(), context), (LongPointer) allocator.getPointer(op.y().shapeInfoDataBuffer(), context), null, (LongPointer) hostZShapeInfo, allocator.getPointer(op.z(), context), (LongPointer) allocator.getPointer(op.z().shapeInfoDataBuffer(), context), allocator.getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
        } else if (op.x() != null) {
            nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), null, (LongPointer) hostXShapeInfo, allocator.getPointer(op.x(), context), (LongPointer) allocator.getPointer(op.x().shapeInfoDataBuffer(), context), null, (LongPointer) hostZShapeInfo, allocator.getPointer(op.z(), context), (LongPointer) allocator.getPointer(op.z().shapeInfoDataBuffer(), context), allocator.getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
        } else {
            nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), null, (LongPointer) hostZShapeInfo, allocator.getPointer(op.z(), context), (LongPointer) allocator.getPointer(op.z().shapeInfoDataBuffer(), context), allocator.getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
        }
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, op.z(), op.x(), op.y());
        this.profilingConfigurableHookOut((Op) op, st);
        return op.z();
    }

    /**
     * Returns environment properties extended with CUDA device information.
     *
     * <p>The first call builds and caches the properties: super-class information, per-device
     * name/memory/compute-capability maps, backend name, BLAS vendor, free host memory and
     * memory bandwidth. Later calls refresh only the per-device free/total memory, the free memory
     * and the bandwidth entries of the cached object.
     *
     * @return the cached (and refreshed) properties object
     */
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties != null) {
            List cachedDevices = (List) this.properties.get("cuda.devicesInformation");
            for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
                Map entry = (Map) cachedDevices.get(device);
                entry.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory(device));
                entry.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory(device));
            }
            this.properties.put("cuda.devicesInformation", cachedDevices);
            this.properties.put("memory.free", (Object) (Pointer.maxBytes() - Pointer.totalBytes()));
            this.properties.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
            return this.properties;
        }

        Properties fresh = super.getEnvironmentInformation();
        List<Map<String, Object>> devices = new ArrayList<Map<String, Object>>();
        for (int device = 0; device < nativeOps.getAvailableDevices(); device++) {
            Map<String, Object> info = new HashMap<String, Object>();
            info.put("cuda.deviceName", nativeOps.getDeviceName(device));
            info.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory(device));
            info.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory(device));
            info.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor(device)));
            info.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor(device)));
            devices.add(info);
        }
        fresh.put("backend", "CUDA");
        fresh.put("cuda.availableDevices", (Object) nativeOps.getAvailableDevices());
        fresh.put("cuda.devicesInformation", devices);
        fresh.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        fresh.put("memory.free", (Object) (Pointer.maxBytes() - Pointer.totalBytes()));
        fresh.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        this.properties = fresh;
        return this.properties;
    }

    /**
     * Returns the TAD manager used by this executioner.
     *
     * @return the {@code tadManager} field
     */
    public TADManager getTADManager() {
        final TADManager manager = tadManager;
        return manager;
    }

    /**
     * Prints environment information by delegating entirely to the super-class implementation.
     */
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
    }

    /**
     * Synchronizes the current device context: first its old stream, then its special stream.
     */
    public void commit() {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        allocator.getDeviceContext().syncOldStream();
        allocator.getDeviceContext().syncSpecialStream();
    }

    /**
     * Threshold-encodes {@code input} on the current CUDA device using three native passes.
     *
     * <ol>
     *   <li>P1 fills {@code blocksBuffer}; element 0 is then read back as the number of matches.</li>
     *   <li>P2 derives {@code offsetsBuffer} from {@code blocksBuffer}, using per-level scratch
     *       buffers whose addresses are handed over through extras slot 2.</li>
     *   <li>P3 writes the encoded entries into the result buffer.</li>
     * </ol>
     *
     * <p>The result is an INT array backed by {@code 4 + numMatches} ints. The header is:
     * <ul>
     *   <li>[0] number of matches;</li>
     *   <li>[1] input length;</li>
     *   <li>[2] float bits of {@code threshold};</li>
     *   <li>[3] 0.</li>
     * </ul>
     * The entries written by P3 follow the header. The input buffer is marked device-written
     * afterwards, so the native code is expected to update it in place — confirm the exact
     * semantics.
     *
     * @param input     array to encode
     * @param threshold encoding threshold (passed to the native passes as a float)
     * @param boundary  optional cap on the number of encoded matches; {@code null} means no cap
     * @return the encoded array, or {@code null} when fewer than two matches are found
     * @throws RuntimeException if any native pass reports an error
     */
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        DataBuffer buffer = input.data();
        int numThreads = 1024;
        // Ceiling division: one block per numThreads elements.
        int numBlocks = (int) ((buffer.length() + numThreads - 1) / numThreads);
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        DataBuffer blocksBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createInt((long) (numBlocks + 1), true)
                : Nd4j.getDataBufferFactory().createInt((long) (numBlocks + 1), true, Nd4j.getMemoryManager().getCurrentWorkspace());
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(1L, (Pointer) context.getOldStream());

        // Pass 1: fills blocksBuffer; element 0 becomes the total number of matches.
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1(extras, AtomicAllocator.getInstance().getPointer(buffer), (LongPointer) AtomicAllocator.getInstance().getHostPointer(input.shapeInfoDataBuffer()), buffer.length(), (IntPointer) AtomicAllocator.getInstance().getPointer(blocksBuffer), (float) threshold);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().getAllocationPoint(blocksBuffer).tickDeviceWrite();
        int numMatches = blocksBuffer.getInt(0L);
        // Fewer than two matches: no encoded array is produced.
        if (numMatches < 2) {
            return null;
        }
        if (boundary != null && numMatches > boundary) {
            numMatches = boundary;
            blocksBuffer.put(0L, numMatches);
        }

        DataBuffer encodedBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createInt((long) (4 + numMatches), false)
                : Nd4j.getDataBufferFactory().createInt((long) (4 + numMatches), false, Nd4j.getMemoryManager().getCurrentWorkspace());
        AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite();
        // Header: match count, original length, threshold bits, then 0 (bitmapEncode writes 1 in this slot).
        encodedBuffer.put(0L, numMatches);
        encodedBuffer.put(1L, (int) buffer.length());
        encodedBuffer.put(2L, Float.floatToIntBits((float) threshold));
        AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickHostWrite();
        encodedBuffer.put(3L, 0);

        int prefixThreads = 512;
        // Holds the per-level scratch buffers whose raw addresses go into `pointers`.
        List<DataBuffer> buffers = new ArrayList<DataBuffer>();

        // First pass over the levels: size the table of scratch-buffer pointers.
        // NOTE(review): this counts every iteration whenever numBlocks > 1, while the second pass
        // only fills levels with more than one prefix block, so the last slot of `pointers` stays
        // 0. Kept unchanged to avoid altering the table handed to the native step.
        int level = 0;
        int numElts = numBlocks;
        do {
            int numPrefixBlocks = Math.max(1, (int) Math.ceil((float) numElts / (2.0f * (float) prefixThreads)));
            if (numBlocks > 1) {
                ++level;
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);
        long[] pointers = new long[level];

        DataBuffer tempX = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createDouble((long) pointers.length, false)
                : Nd4j.getDataBufferFactory().createDouble((long) pointers.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());

        // Second pass: allocate one scratch buffer per level that has more than one prefix block.
        level = 0;
        numElts = numBlocks;
        do {
            int numPrefixBlocks = Math.max(1, (int) Math.ceil((float) numElts / (2.0f * (float) prefixThreads)));
            if (numPrefixBlocks > 1) {
                DataBuffer bf = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                        ? Nd4j.getDataBufferFactory().createInt((long) numPrefixBlocks, false)
                        : Nd4j.getDataBufferFactory().createInt((long) numPrefixBlocks, false, Nd4j.getMemoryManager().getCurrentWorkspace());
                buffers.add(bf);
                pointers[level++] = AtomicAllocator.getInstance().getPointer(bf).address();
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);

        // Copy the scratch-buffer address table (8 bytes per entry) into tempX and pass it in extras slot 2.
        AtomicAllocator.getInstance().memcpyBlocking(tempX, (Pointer) new LongPointer(pointers), pointers.length * 8, 0L);
        extras.put(2L, AtomicAllocator.getInstance().getPointer(tempX));

        DataBuffer offsetsBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createInt((long) numBlocks, true)
                : Nd4j.getDataBufferFactory().createInt((long) numBlocks, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        // Pass 2: per-block offsets from blocksBuffer into offsetsBuffer.
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(extras, (IntPointer) AtomicAllocator.getInstance().getPointer(blocksBuffer), (long) numBlocks, (IntPointer) AtomicAllocator.getInstance().getPointer(offsetsBuffer));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickDeviceWrite();

        // Pass 3: write encoded entries into encodedBuffer.
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3(extras, AtomicAllocator.getInstance().getPointer(buffer), (LongPointer) AtomicAllocator.getInstance().getHostPointer(input.shapeInfoDataBuffer()), (IntPointer) AtomicAllocator.getInstance().getPointer(offsetsBuffer), buffer.length(), (IntPointer) AtomicAllocator.getInstance().getPointer(encodedBuffer));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite();
        AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();
        return Nd4j.createArrayFromShapeBuffer((DataBuffer) encodedBuffer, (DataBuffer) input.shapeInfoDataBuffer());
    }

    /**
     * Threshold-encodes {@code input} with no cap on the number of encoded matches.
     *
     * @param input     array to encode
     * @param threshold encoding threshold
     * @return the encoded array, or {@code null} when the three-argument overload returns null
     */
    public INDArray thresholdEncode(INDArray input, double threshold) {
        Integer noBoundary = null;
        return thresholdEncode(input, threshold, noBoundary);
    }

    /**
     * Decodes a threshold-encoded INT array into {@code target} on the current CUDA device.
     *
     * <p>Reads the header of {@code encoded}: element 0 is the compressed length and element 1
     * the original length, which must equal {@code target.length()}.
     *
     * @param encoded threshold-encoded array; its buffer must be INT-typed
     * @param target  destination array, written by the native decoder
     * @return {@code target}
     * @throws UnsupportedOperationException if the encoded buffer is not INT-typed
     * @throws ND4JIllegalStateException if the stored original length differs from the target length
     * @throws RuntimeException if the native call reports an error
     */
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        DataBuffer encodedData = encoded.data();
        if (encodedData.dataType() != DataType.INT) {
            throw new UnsupportedOperationException();
        }
        long compressedLength = encodedData.getInt(0L);
        long originalLength = encodedData.getInt(1L);
        long targetLength = target.length();
        if (targetLength != originalLength) {
            throw new ND4JIllegalStateException("originalLength [" + originalLength + "] stored in encoded array doesn't match target length [" + targetLength + "]");
        }
        DataBuffer result = target.data();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(1L, (Pointer) context.getOldStream());
        Pointer source = allocator.getPointer(encodedData);
        Pointer destination = allocator.getPointer(result);
        nativeOps.decodeThreshold(extras, source, compressedLength, destination, (LongPointer) allocator.getHostPointer(target.shapeInfoDataBuffer()));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getAllocationPoint(result).tickDeviceWrite();
        return target;
    }

    /**
     * Bitmap-encodes {@code indArray} into the INT buffer of {@code target} on the device.
     *
     * <p>{@code target} must hold exactly {@code length / 16 + 5} ints. The first four ints are
     * written on the host:
     * <ul>
     *   <li>[0] and [1] the input length;</li>
     *   <li>[2] the float bits of {@code threshold};</li>
     *   <li>[3] 1 (thresholdEncode writes 0 in this slot).</li>
     * </ul>
     * The rest is filled by the native encoder.
     *
     * @param indArray  array to encode
     * @param target    INT array receiving the encoding
     * @param threshold encoding threshold (passed as a float)
     * @return the value returned by the native {@code encodeBitmap} call
     * @throws ND4JIllegalStateException if {@code target} has the wrong length or data type
     * @throws RuntimeException if the native call reports an error
     */
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        long length = indArray.length();
        long requiredLength = length / 16L + 5L;
        DataBuffer buffer = target.data();
        if (buffer.length() != requiredLength) {
            throw new ND4JIllegalStateException("Length of target array should be " + requiredLength);
        }
        if (buffer.dataType() != DataType.INT) {
            throw new ND4JIllegalStateException("Target array should have INT dataType");
        }
        buffer.put(0L, (int) length);
        buffer.put(1L, (int) length);
        buffer.put(2L, Float.floatToIntBits((float) threshold));
        buffer.put(3L, 1);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(indArray, new INDArray[0]);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Pointer[] slots = new Pointer[]{allocator.getHostPointer(indArray), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()};
        PointerPointer extras = this.extraz.get().put(slots);
        long encoded = nativeOps.encodeBitmap(extras, allocator.getPointer(indArray, context), (LongPointer) allocator.getHostPointer(indArray.shapeInfoDataBuffer()), length, (IntPointer) allocator.getPointer(buffer, context), (float) threshold);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, indArray, new INDArray[0]);
        allocator.getAllocationPoint(buffer).tickDeviceWrite();
        return encoded;
    }

    /**
     * Decodes a bitmap-encoded array into {@code target} on the current CUDA device.
     *
     * @param encoded bitmap-encoded array
     * @param target  destination array; its length is passed to the native decoder
     * @return {@code target}
     * @throws RuntimeException if the native call reports an error
     */
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(target, new INDArray[0]);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Pointer[] slots = new Pointer[]{allocator.getHostPointer(target), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()};
        PointerPointer extras = this.extraz.get().put(slots);
        Pointer source = allocator.getPointer(encoded.data(), context);
        long targetLength = target.length();
        Pointer destination = allocator.getPointer(target, context);
        nativeOps.decodeBitmap(extras, source, targetLength, destination, (LongPointer) allocator.getHostPointer(target.shapeInfoDataBuffer()));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, target, new INDArray[0]);
        return target;
    }

    /**
     * Returns the custom-op registry reported by the native library, building it on first use.
     *
     * <p>The native side returns a {@code ';'}-separated list of records. Each record is a
     * {@code ':'}-separated tuple {@code name:hash:numInputs:numOutputs:allowsInplace:numTArgs:numIArgs},
     * where {@code allowsInplace} is {@code 1} for true. Empty records are skipped. The result is
     * cached in {@code customOps` and returned unmodifiable. An empty native list yields an empty map
     * and a warning.
     *
     * @return unmodifiable map from op name to its descriptor
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }
        String rawList = nativeOps.getAllCustomOps();
        if (rawList == null || rawList.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }
        Map<String, CustomOpDescriptor> descriptors = new HashMap<>();
        for (String record : rawList.split(";")) {
            if (record == null || record.isEmpty()) {
                continue;
            }
            String[] fields = record.split(":");
            CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            descriptors.put(fields[0], descriptor);
        }
        this.customOps = Collections.unmodifiableMap(descriptors);
        return this.customOps;
    }

    /**
     * Copies a native shape-info buffer into a Java array and wraps it as a descriptor.
     *
     * <p>Element 0 is read as the rank. {@code rank * 2 + 4} elements are copied in total, and the
     * shape, stride, element-wise stride, order, data type and empty flag are then decoded from that
     * copy.
     *
     * @param ptr pointer to the native shape-info buffer
     * @return descriptor equivalent to the native shape info
     */
    protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) {
        int rank = (int) ptr.get(0L);
        long[] shapeInfo = new long[rank * 2 + 4];
        int idx = 0;
        while (idx < shapeInfo.length) {
            shapeInfo[idx] = ptr.get((long) idx);
            idx++;
        }
        boolean isEmpty = ArrayOptionsHelper.arrayType(shapeInfo) == ArrayType.EMPTY;
        long[] dims = Shape.shape(shapeInfo);
        long[] strides = Shape.stride(shapeInfo);
        long ews = (long) Shape.elementWiseStride(shapeInfo);
        char ordering = (char) Shape.order(shapeInfo);
        DataType dataType = ArrayOptionsHelper.dataType(shapeInfo);
        return LongShapeDescriptor.fromShape(dims, strides, ews, ordering, dataType, isEmpty);
    }

    /*
     * WARNING - void declaration
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op) {
        void var12_15;
        if (op == null) {
            throw new NullPointerException("op is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        String lc = op.opName().toLowerCase();
        long hash = op.opHash();
        ArrayList<LongShapeDescriptor> result = new ArrayList<LongShapeDescriptor>();
        if (op.numInputArguments() < 1 && op.getDescriptor().getNumInputs() != -2) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: number of input args was 0", (Object)op.getClass().getName());
            }
            return Collections.emptyList();
        }
        PointerPointer inputBuffers = new PointerPointer((long)(op.inputArguments().length * 2));
        PointerPointer inputShapes = new PointerPointer((long)op.inputArguments().length);
        int cnt = 0;
        for (INDArray iNDArray : op.inputArguments()) {
            if (!iNDArray.isEmpty()) {
                inputBuffers.put((long)cnt, iNDArray.data().addressPointer());
                inputBuffers.put((long)(cnt + op.inputArguments().length), AtomicAllocator.getInstance().getPointer(iNDArray.data()));
            }
            inputShapes.put((long)cnt++, iNDArray.shapeInfoDataBuffer().addressPointer());
        }
        LongPointer iArgs = op.iArgs().length > 0 ? new LongPointer((long)op.iArgs().length) : null;
        cnt = 0;
        long[] lArray = op.iArgs();
        int n = lArray.length;
        boolean bl = false;
        while (var12_15 < n) {
            long i = lArray[var12_15];
            iArgs.put((long)cnt++, i);
            ++var12_15;
        }
        DoublePointer tArgs = op.tArgs().length > 0 ? new DoublePointer((long)op.tArgs().length) : null;
        BooleanPointer bArgs = op.bArgs().length > 0 ? new BooleanPointer((long)op.bArgs().length) : null;
        cnt = 0;
        for (boolean b : op.bArgs()) {
            bArgs.put((long)cnt++, b);
        }
        cnt = 0;
        for (double t : op.tArgs()) {
            tArgs.put((long)cnt++, t);
        }
        OpaqueShapeList opaqueShapeList = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (opaqueShapeList == null) {
            throw new RuntimeException();
        }
        int e = 0;
        while ((long)e < nativeOps.getShapeListSize(opaqueShapeList)) {
            result.add(this.getShapeFromPointer(new PagedPointer((Pointer)nativeOps.getShape(opaqueShapeList, (long)e)).asLongPointer()));
            ++e;
        }
        nativeOps.deleteShapeList((Pointer)opaqueShapeList);
        return result;
    }

    /**
     * Executes a custom op, allocating output arrays first when the op has none.
     *
     * <p>If the op has no output arguments and is not an in-place call, output shapes are inferred
     * via {@link #calculateOutputShape(CustomOp)}, and new arrays of those shapes are attached as
     * outputs. The op then runs through a fresh {@link CudaOpContext}. The RNG states left in that
     * context are written back to {@code Nd4j.getRandom()}.
     *
     * @param op the op to execute
     * @return the output arrays recorded in the execution context
     * @throws ND4JIllegalStateException if outputs are missing and cannot be inferred. When
     *     inference itself failed, the original failure is attached as the cause.
     * @throws ND4JOpProfilerException rethrown unchanged if raised during execution
     * @throws RuntimeException wrapping any other failure during execution
     */
    public INDArray[] exec(CustomOp op) {
        Nd4j.getExecutioner().commit();
        if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
            String message = "Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified";
            List<LongShapeDescriptor> shapes;
            try {
                shapes = this.calculateOutputShape(op);
                for (LongShapeDescriptor shape : shapes) {
                    op.addOutputArgument(Nd4j.create(shape));
                }
            } catch (Exception e) {
                // Keep the underlying failure so shape-inference errors stay diagnosable.
                throw new ND4JIllegalStateException(message, e);
            }
            if (shapes.isEmpty()) {
                throw new ND4JIllegalStateException(message);
            }
        }
        String name = op.opName();
        try (CudaOpContext context = (CudaOpContext)this.buildContext()) {
            context.markInplace(op.isInplaceCall());
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
            context.setInputArrays(op.inputArguments());
            context.setOutputArrays(op.outputArguments());
            context.setBArguments(op.bArgs());
            context.setIArguments(op.iArgs());
            context.setTArguments(op.tArgs());
            INDArray[] result = this.exec(op, context);
            // Propagate RNG state advanced by the native op back to the Java-side generator.
            Pair<Long, Long> states = context.getRngStates();
            Nd4j.getRandom().setStates(states.getFirst().longValue(), states.getSecond().longValue());
            return result;
        }
        catch (ND4JOpProfilerException e) {
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException("Op [" + name + "] execution failed", e);
        }
    }

    /**
     * Sets the debug flag on this executioner and forwards the same value to the native library.
     *
     * @param reallyEnable {@code true} to turn debug mode on, {@code false} to turn it off
     */
    public void enableDebugMode(boolean reallyEnable) {
        final boolean flag = reallyEnable;
        debug.set(flag);
        nativeOps.enableDebugMode(flag);
    }

    /**
     * Sets the verbose flag on this executioner and forwards the same value to the native library.
     *
     * @param reallyEnable {@code true} to turn verbose mode on, {@code false} to turn it off
     */
    public void enableVerboseMode(boolean reallyEnable) {
        final boolean flag = reallyEnable;
        verbose.set(flag);
        nativeOps.enableVerboseMode(flag);
    }

    /**
     * Hands a serialized graph to the native layer under the given identifier.
     *
     * @param id    identifier to register the graph under
     * @param graph pointer to the graph data passed straight to the native call
     * @throws RuntimeException if the native layer reports an error
     */
    public void registerGraph(long id, Pointer graph) {
        nativeOps.registerGraph(null, id, graph);
        if (nativeOps.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Executes the graph stored on the native side under {@code id} (see
     * {@link #registerGraph(long, Pointer)}) with the given inputs, and copies its variables back
     * into new arrays.
     *
     * <p>The native variable set is always released once results have been copied out, including
     * when the graph reports a failure status or the native layer flags an error.
     *
     * @param id         identifier passed to the native stored-graph executor
     * @param map        input arrays keyed by name
     * @param reverseMap integer index sent to the native side for each name in {@code map}
     * @return arrays keyed by the variable names reported natively, in native iteration order
     * @throws NullPointerException if {@code map} or {@code reverseMap} is null, or if a name in
     *     {@code map} has no entry in {@code reverseMap}
     * @throws ND4JIllegalStateException if the executed graph reports a non-OK status
     * @throws RuntimeException if the native layer reports an error
     */
    public Map<String, INDArray> executeGraph(long id, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> reverseMap) {
        if (map == null) {
            throw new NullPointerException("map is marked @NonNull but is null");
        }
        if (reverseMap == null) {
            throw new NullPointerException("reverseMap is marked @NonNull but is null");
        }
        Nd4j.getExecutioner().commit();
        PointerPointer ptrBuffers = new PointerPointer((long)(map.size() * 2));
        PointerPointer ptrShapes = new PointerPointer((long)(map.size() * 2));
        IntPointer ptrIndices = new IntPointer((long)map.size());
        int cnt = 0;
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            INDArray array = entry.getValue();
            ptrBuffers.put((long)cnt, AtomicAllocator.getInstance().getHostPointer(array));
            ptrShapes.put((long)cnt, AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()));
            ptrIndices.put((long)cnt, reverseMap.get(entry.getKey()).intValue());
            ++cnt;
        }
        OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        try {
            OpStatus status = OpStatus.byNumber((int)nativeOps.getVariablesSetStatus(result));
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }
            LinkedHashMap<String, INDArray> newMap = new LinkedHashMap<String, INDArray>();
            long variableCount = nativeOps.getVariablesSetSize(result);
            for (long e = 0; e < variableCount; e++) {
                OpaqueVariable var = nativeOps.getVariable(result, e);
                LongPointer shapeInfo = nativeOps.getVariableShape(var);
                Pointer buffer = nativeOps.getVariableBuffer(var);
                int rank = (int)shapeInfo.get(0L);
                long[] jshape = new long[rank * 2 + 4];
                for (int i = 0; i < jshape.length; ++i) {
                    jshape[i] = shapeInfo.get((long)i);
                }
                long[] shapeOf = Shape.shapeOf((long[])jshape);
                long[] stridesOf = Shape.stridesOf((long[])jshape);
                char order = Shape.order((long[])jshape);
                // NOTE(review): the array is created with the default data type, and the copy size uses
                // Nd4j.sizeOfDataType(). If the native variable has a different dtype, the copied
                // bytes will not match. Confirm whether the dtype should come from jshape.
                INDArray array = Nd4j.create((long[])shapeOf, (long[])stridesOf, (long)0L, (char)order);
                Pointer.memcpy((Pointer)AtomicAllocator.getInstance().getHostPointer(array), (Pointer)buffer, (long)(ArrayUtil.prod((long[])shapeOf) * Nd4j.sizeOfDataType()));
                // Data was written on the host side; mark it so the device copy is refreshed.
                AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
                newMap.put(nativeOps.getVariableName(var), array);
            }
            if (nativeOps.lastErrorCode() != 0) {
                throw new RuntimeException(nativeOps.lastErrorMessage());
            }
            return newMap;
        } finally {
            // Previously skipped on the failure paths above, which leaked the native set.
            nativeOps.deleteVariablesSet(result);
        }
    }

    /**
     * Asks the native layer to unregister the graph stored under {@code id}.
     *
     * @param id identifier the graph was registered under
     * @throws RuntimeException if the native layer reports an error
     */
    public void forgetGraph(long id) {
        nativeOps.unregisterGraph(null, id);
        if (nativeOps.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * Forwards an element-count threshold to the native library.
     *
     * <p>What the native side uses it for is not visible here; presumably it affects when
     * parallel execution kicks in. TODO confirm.
     *
     * @param threshold value passed unchanged to {@code nativeOps.setElementThreshold}
     */
    public void setElementsThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setElementThreshold(value);
    }

    /**
     * Forwards a TAD threshold to the native library.
     *
     * @param threshold value passed unchanged to {@code nativeOps.setTADThreshold}
     */
    public void setTadThreshold(int threshold) {
        final int value = threshold;
        nativeOps.setTADThreshold(value);
    }

    /**
     * Identifies this executioner as the CUDA backend.
     *
     * @return always {@link OpExecutioner.ExecutionerType#CUDA}
     */
    public OpExecutioner.ExecutionerType type() {
        final OpExecutioner.ExecutionerType backend = OpExecutioner.ExecutionerType.CUDA;
        return backend;
    }

    /**
     * Reads the string stored at position {@code index} of a UTF-8 buffer.
     *
     * <p>The buffer's long indexer yields an address. That address is wrapped as a native
     * {@code utf8string}, whose byte buffer is trimmed to the reported length before decoding.
     *
     * @param buffer buffer whose indexer holds the string addresses
     * @param index  position of the string to read
     * @return the decoded string
     */
    public String getString(Utf8Buffer buffer, long index) {
        LongIndexer addresses = (LongIndexer) buffer.indexer();
        long address = addresses.get(index);
        Nd4jCuda.utf8string utf8 = new Nd4jCuda.utf8string((Pointer) new PagedPointer(address));
        long byteLength = (long) utf8._length();
        return utf8._buffer().capacity(byteLength).getString();
    }

    /**
     * Reports the current value of the experimental-mode flag.
     *
     * @return the value held by {@code experimentalMode}
     */
    public boolean isExperimentalMode() {
        final boolean enabled = experimentalMode.get();
        return enabled;
    }

    /**
     * Applies a scatter update of {@code updates} into {@code array} along {@code axis} on the GPU.
     *
     * <p>TAD information is fetched for both {@code array} and {@code updates}. The second element
     * of the updates' TAD pair must have exactly as many entries as {@code indices}. The update
     * kind is passed to native code as {@code op.ordinal()}.
     *
     * @param op      update operation to apply
     * @param array   array being updated
     * @param indices positions to update; its length is passed natively as an {@code int}
     * @param updates source values
     * @param axis    dimensions used to build the TADs
     * @throws NullPointerException  if any {@code @NonNull} argument is null
     * @throws IllegalStateException if the number of updates does not match the number of indices
     * @throws RuntimeException      if the native layer reports an error
     */
    public void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, @NonNull int[] axis) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        if (indices == null) {
            throw new NullPointerException("indices is marked @NonNull but is null");
        }
        if (updates == null) {
            throw new NullPointerException("updates is marked @NonNull but is null");
        }
        if (axis == null) {
            throw new NullPointerException("axis is marked @NonNull but is null");
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(array, indices, updates);
        Pair xTad = tadManager.getTADOnlyShapeInfo(array, axis);
        Pair yTad = tadManager.getTADOnlyShapeInfo(updates, axis);
        // first = TAD shape info, second = TAD offsets (presumably, as with TadPack) - confirm
        DataBuffer xTadShape = (DataBuffer) xTad.getFirst();
        DataBuffer xTadOffsets = (DataBuffer) xTad.getSecond();
        DataBuffer yTadShape = (DataBuffer) yTad.getFirst();
        DataBuffer yTadOffsets = (DataBuffer) yTad.getSecond();
        if (yTadOffsets.length() != indices.length()) {
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(new Pointer[]{null, context.getOldStream()});
        nativeOps.scatterUpdate(extras, op.ordinal(), (int) indices.length(),
                null, (LongPointer) allocator.getHostPointer(xTadShape), null,
                allocator.getPointer(array, context),
                (LongPointer) allocator.getPointer(xTadShape),
                (LongPointer) allocator.getPointer(xTadOffsets),
                null, (LongPointer) allocator.getHostPointer(yTadShape), null,
                allocator.getPointer(updates, context),
                (LongPointer) allocator.getPointer(yTadShape),
                (LongPointer) allocator.getPointer(yTadOffsets),
                allocator.getHostPointer(indices),
                (LongPointer) allocator.getHostPointer(indices.shapeInfoDataBuffer()),
                allocator.getPointer(indices, context),
                (LongPointer) allocator.getPointer(indices.shapeInfoDataBuffer(), context));
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        allocator.getFlowController().registerAction(context, array, indices, updates);
    }

    /**
     * Creates a fresh CUDA-backed op context.
     *
     * @return a new {@link CudaOpContext}
     */
    public OpContext buildContext() {
        final OpContext fresh = new CudaOpContext();
        return fresh;
    }

    /**
     * Runs a custom op through the native executor using a prepared op context.
     *
     * <p>The context is bound to the current device context's stream and buffers. After a
     * successful run, outputs and inputs are registered with the allocator and the profiling hook
     * is closed.
     *
     * @param op      the op to run
     * @param context a {@link CudaOpContext} holding inputs, outputs and arguments
     * @return the context's output arrays, or an empty array if it has none
     * @throws RuntimeException if the native layer reports an error or a non-zero status
     */
    public INDArray[] exec(CustomOp op, OpContext context) {
        long started = this.profilingConfigurableHookIn(op);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext cudaContext = allocator.getDeviceContext();
        CudaOpContext cudaOpContext = (CudaOpContext) context;
        cudaOpContext.setCudaStream(cudaContext.getOldStream(), cudaContext.getBufferReduction(), cudaContext.getBufferAllocation());
        int status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        if (status != 0) {
            throw new RuntimeException("Op [" + op.opName() + "] execution failed");
        }
        for (INDArray output : op.outputArguments()) {
            allocator.registerAction(cudaContext, output, new INDArray[0]);
        }
        allocator.registerAction(cudaContext, null, op.inputArguments());
        this.profilingConfigurableHookOut(op, started);
        List<INDArray> outputs = context.getOutputArrays();
        return outputs.isEmpty() ? new INDArray[0] : outputs.toArray(new INDArray[outputs.size()]);
    }

    /**
     * Collects summary statistics for an array using the native inspector.
     *
     * <p>Host data is synchronized first. The native call fills a {@code DebugInfo} structure,
     * whose fields are copied into the returned statistics object.
     *
     * @param array the array to inspect; must not be null
     * @return min/max/mean/std-dev plus counts of inf, NaN, negative, positive and zero values
     * @throws NullPointerException if {@code array} is null
     * @throws RuntimeException     if the native layer reports an error
     */
    public INDArrayStatistics inspectArray(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        Nd4jCuda.DebugInfo info = new Nd4jCuda.DebugInfo();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext cudaCtx = allocator.getDeviceContext();
        allocator.synchronizeHostData(array);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Pointer[] extraArgs = new Pointer[]{
                null,
                cudaCtx.getOldStream(),
                allocator.getDeviceIdPointer(),
                cudaCtx.getBufferAllocation(),
                cudaCtx.getBufferReduction(),
                cudaCtx.getBufferScalar(),
                cudaCtx.getBufferSpecial()
        };
        PointerPointer extras = this.extraz.get().put(extraArgs);
        nativeOps.inspectArray(extras,
                allocator.getHostPointer(array),
                (LongPointer) allocator.getHostPointer(array.shapeInfoDataBuffer()),
                allocator.getPointer(array, cudaCtx),
                (LongPointer) allocator.getPointer(array.shapeInfoDataBuffer()),
                (Pointer) info);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        return INDArrayStatistics.builder()
                .minValue(info._minValue())
                .maxValue(info._maxValue())
                .meanValue(info._meanValue())
                .stdDevValue(info._stdDevValue())
                .countInf(info._infCount())
                .countNaN(info._nanCount())
                .countNegative(info._negativeCount())
                .countPositive(info._positiveCount())
                .countZero(info._zeroCount())
                .build();
    }

    /**
     * Builds a shape-info buffer through the native constant-buffer helper.
     *
     * <p>Any pending native error is surfaced before the call. The native constant handle is
     * deleted after its primary and special pointers are wrapped in a {@link CudaLongDataBuffer}.
     *
     * @param shape             dimension sizes
     * @param stride            per-dimension strides
     * @param elementWiseStride element-wise stride to record
     * @param order             array ordering character
     * @param dtype             data type to record
     * @param empty             whether the shape describes an empty array
     * @return a long buffer holding the shape info
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        int rank = shape.length;
        LongPointer shapePtr = new LongPointer(shape);
        LongPointer stridePtr = new LongPointer(stride);
        OpaqueConstantDataBuffer constantBuffer = nativeOps.shapeBuffer(rank, shapePtr, stridePtr, dtype.toInt(), order, elementWiseStride, empty);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        CudaLongDataBuffer shapeInfo = new CudaLongDataBuffer(
                nativeOps.getConstantDataBufferPrimary(constantBuffer),
                nativeOps.getConstantDataBufferSpecial(constantBuffer),
                (long) Shape.shapeInfoLength((long) rank));
        nativeOps.deleteShapeBuffer(constantBuffer);
        return shapeInfo;
    }

    /**
     * Requests TAD shape info and offsets for {@code array} along {@code dimension} from native code.
     *
     * <p>The primary and special pointers of the native pack are wrapped in long buffers. The
     * native pack handle is then deleted.
     *
     * @param array     source array
     * @param dimension dimensions defining the TADs
     * @return pack of the TAD shape-info buffer and the offsets buffer
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        LongPointer hostShapeInfo = (LongPointer) array.shapeInfoDataBuffer().addressPointer();
        OpaqueTadPack tadPack = nativeOps.tadOnlyShapeInfo(hostShapeInfo, new IntPointer(dimension), dimension.length);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        CudaLongDataBuffer shapeBuffer = new CudaLongDataBuffer(
                (Pointer) nativeOps.getPrimaryShapeInfo(tadPack),
                (Pointer) nativeOps.getSpecialShapeInfo(tadPack),
                (long) nativeOps.getShapeInfoLength(tadPack));
        CudaLongDataBuffer offsetBuffer = new CudaLongDataBuffer(
                (Pointer) nativeOps.getPrimaryOffsets(tadPack),
                (Pointer) nativeOps.getSpecialOffsets(tadPack),
                nativeOps.getNumberOfTads(tadPack));
        nativeOps.deleteTadPack(tadPack);
        return new TadPack((DataBuffer) shapeBuffer, (DataBuffer) offsetBuffer);
    }

    /**
     * Creates a constant data buffer from {@code long} values through the native constant helper.
     *
     * <p>The native handle's primary and special pointers are wrapped in a buffer that is marked
     * constant. The handle itself is not deleted here.
     *
     * @param values      values to place in the buffer
     * @param desiredType data type of the resulting buffer
     * @return a constant buffer of {@code values.length} elements
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantDataBuffer handle = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        Pointer hostSide = (Pointer) nativeOps.getConstantDataBufferPrimary(handle);
        Pointer deviceSide = (Pointer) nativeOps.getConstantDataBufferSpecial(handle);
        DataBuffer constant = Nd4j.createBuffer(hostSide, deviceSide, (long) values.length, (DataType) desiredType);
        constant.setConstant(true);
        return constant;
    }

    /**
     * Creates a constant data buffer from {@code double} values through the native constant helper.
     *
     * <p>The native handle's primary and special pointers are wrapped in a buffer that is marked
     * constant. The handle itself is not deleted here.
     *
     * @param values      values to place in the buffer
     * @param desiredType data type of the resulting buffer
     * @return a constant buffer of {@code values.length} elements
     * @throws RuntimeException if the native layer reports an error before or after the call
     */
    public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        OpaqueConstantDataBuffer handle = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
        if (nativeOps.lastErrorCode() != 0) {
            throw new RuntimeException(nativeOps.lastErrorMessage());
        }
        Pointer hostSide = (Pointer) nativeOps.getConstantDataBufferPrimary(handle);
        Pointer deviceSide = (Pointer) nativeOps.getConstantDataBufferSpecial(handle);
        DataBuffer constant = Nd4j.createBuffer(hostSide, deviceSide, (long) values.length, (DataType) desiredType);
        constant.setConstant(true);
        return constant;
    }

    /**
     * Runs the native light benchmark suite.
     *
     * @param printOut flag forwarded unchanged to the native call
     * @return the string produced by the native suite
     */
    public String runLightBenchmarkSuit(boolean printOut) {
        final String report = nativeOps.runLightBenchmarkSuit(printOut);
        return report;
    }

    /**
     * Runs the native full benchmark suite.
     *
     * @param printOut flag forwarded unchanged to the native call
     * @return the string produced by the native suite
     */
    public String runFullBenchmarkSuit(boolean printOut) {
        final String report = nativeOps.runFullBenchmarkSuit(printOut);
        return report;
    }
}

