/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.grid.GridPointers;
import org.nd4j.linalg.api.ops.grid.OpDescriptor;
import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner;
import org.nd4j.linalg.jcublas.ops.executioner.aggregates.AggregateDescriptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA executioner that can hold back one pairwise {@link TransformOp} per thread so it may be
 * combined with the op that follows it into a {@link MetaOp}.
 *
 * <p>The held op lives in a thread-local slot ({@link #getQueueLength()} is therefore 0 or 1).
 * Any op that cannot be combined forces the held op to execute first. No fused meta-op kernel is
 * invoked here: combined ops are executed as two sequential launches (see {@link #exec(MetaOp)}).
 *
 * @deprecated kept for backward compatibility; {@link CudaExecutioner} executes ops immediately
 *     and reports errors without the one-op delay this class introduces.
 */
@Deprecated
public class CudaGridExecutioner
extends CudaExecutioner
implements GridExecutioner {
    /** Op currently held back on this thread, waiting for a possible fusion partner. */
    private final ThreadLocal<OpDescriptor> lastOp = new ThreadLocal<>();
    /** Per-thread op deque; only initialised for the constructing thread. */
    private final ThreadLocal<Deque<OpDescriptor>> deviceQueues = new ThreadLocal<>();
    /** Per-thread sequence number assigned to each queued aggregate. */
    private final ThreadLocal<AtomicLong> opCounter = ThreadLocal.withInitial(() -> new AtomicLong(0L));
    private final AtomicLong metaCounter = new AtomicLong(0L);
    private final AtomicLong execCounter = new AtomicLong(0L);
    private final List<WatchdogPair> watchdog = new CopyOnWriteArrayList<>();
    /** One aggregate queue per available device, indexed by device id. */
    private final List<Queue<AggregateDescriptor>> aggregates = new ArrayList<>();
    private static final Logger logger = LoggerFactory.getLogger(CudaGridExecutioner.class);
    private final AtomicBoolean experimental = new AtomicBoolean(false);

    /** Creates one aggregate queue per available device and reads the native "experimental" flag. */
    public CudaGridExecutioner() {
        deviceQueues.set(new ArrayDeque<>());
        int numDevices = nativeOps.getAvailableDevices();
        for (int device = 0; device < numDevices; device++) {
            aggregates.add(new ConcurrentLinkedQueue<>());
        }
        experimental.set(nativeOps.isExperimentalEnabled());
    }

    /**
     * Dispatches {@code op} by kind: reductions and index reductions run over the whole array,
     * scalar/transform/broadcast ops go through the fusion queue, anything else is executed directly.
     *
     * @return the op's z array after dispatch
     */
    @Override
    public INDArray exec(Op op) {
        checkForCompression(op);
        invokeWatchdog(op);
        if (op instanceof ReduceOp) {
            exec((ReduceOp) op, Integer.MAX_VALUE);
        } else if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op, Integer.MAX_VALUE);
        } else if (op instanceof ScalarOp || op instanceof TransformOp) {
            processAsGridOp(op, new int[0]);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op, null);
        } else {
            pushToGrid(new OpDescriptor(op));
        }
        return op.z();
    }

    /** Executes {@code descriptor}, flushing any held op first. */
    protected void pushToGrid(OpDescriptor descriptor) {
        pushToGrid(descriptor, true);
    }

    /**
     * Checks {@code op} against every array registered via {@link #addToWatchdog}.
     *
     * @throws IllegalStateException if an operand shares a device pointer with a watched array
     *     without being that same array object
     */
    protected void invokeWatchdog(Op op) {
        if (watchdog.isEmpty()) {
            return;
        }
        for (WatchdogPair pair : watchdog) {
            INDArray watched = pair.getArray();
            // The op references the watched array object itself: that is legitimate use.
            if (compareArrays(watched, op)) {
                continue;
            }
            if (compareDevicePointers(watched, op)) {
                throw new IllegalStateException("Op [" + op.getClass().getName()
                        + "] shares a device pointer with watched array [" + pair.getTag() + "]");
            }
            if (compareHostPointers(watched, op)) {
                logger.warn("Op [{}] shares a host pointer with watched array [{}]",
                        op.getClass().getName(), pair.getTag());
            }
        }
    }

    /** Returns true if any non-null operand of {@code op} has the same device address as {@code array}. */
    protected boolean compareDevicePointers(INDArray array, Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        long target = allocator.getPointer(array, context).address();
        for (INDArray operand : operandsOf(op)) {
            if (allocator.getPointer(operand, context).address() == target) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if any non-null operand of {@code op} has the same host address as {@code array}. */
    protected boolean compareHostPointers(INDArray array, Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        // Compare host against host; the watched array's device pointer can never equal a host pointer.
        long target = allocator.getHostPointer(array).address();
        for (INDArray operand : operandsOf(op)) {
            if (allocator.getHostPointer(operand).address() == target) {
                return true;
            }
        }
        return false;
    }

    /** Returns the non-null operands of {@code op}, in z, x, y order. */
    private static List<INDArray> operandsOf(Op op) {
        List<INDArray> operands = new ArrayList<>(3);
        if (op.z() != null) {
            operands.add(op.z());
        }
        if (op.x() != null) {
            operands.add(op.x());
        }
        if (op.y() != null) {
            operands.add(op.y());
        }
        return operands;
    }

    /** Returns true if {@code array} is (by identity) the x, y or z operand of {@code op}. */
    protected boolean compareArrays(INDArray array, Op op) {
        return op.x() == array || op.y() == array || op.z() == array;
    }

    /**
     * Executes the op described by {@code descriptor} on the underlying {@link CudaExecutioner}.
     *
     * @param descriptor op plus optional dimensions
     * @param flush whether to flush the currently held op first (meta/grid ops never flush)
     */
    protected void pushToGrid(OpDescriptor descriptor, boolean flush) {
        execCounter.incrementAndGet();
        Op op = descriptor.getOp();
        int[] dimensions = descriptor.getDimensions();
        if (op instanceof TransformOp) {
            flushIfRequested(flush);
            super.invoke((TransformOp) op, null);
        } else if (op instanceof Variance) {
            flushIfRequested(flush);
            // Keep the Variance-typed argument so any Variance-specific overload is selected.
            super.naiveExec((Variance) op, dimensions);
        } else if (op instanceof ReduceOp) {
            flushIfRequested(flush);
            super.naiveExec((ReduceOp) op, dimensions);
        } else if (op instanceof ScalarOp) {
            flushIfRequested(flush);
            super.invoke((ScalarOp) op, null);
        } else if (op instanceof BroadcastOp) {
            flushIfRequested(flush);
            BroadcastOp broadcastOp = (BroadcastOp) op;
            // NOTE(review): confirm super.exec(BroadcastOp) does not dispatch back into the
            // overridden invoke(BroadcastOp, OpContext) of this class.
            if (dimensions != null) {
                super.exec(broadcastOp);
            } else {
                super.invoke(broadcastOp, null);
            }
        } else if (op instanceof IndexAccumulation) {
            flushIfRequested(flush);
            int[] dims = dimensions == null ? new int[] {Integer.MAX_VALUE} : dimensions;
            super.invoke((IndexAccumulation) op, null, dims);
        } else if (op instanceof MetaOp) {
            metaCounter.incrementAndGet();
            exec((MetaOp) op);
        } else if (op instanceof GridOp) {
            exec((GridOp) op);
        }
    }

    /** Flushes the held op when {@code flush} is true. */
    private void flushIfRequested(boolean flush) {
        if (flush) {
            flushQueue();
        }
    }

    /** Number of meta ops pushed through {@link #pushToGrid(OpDescriptor, boolean)}. */
    public long getMetaCounter() {
        return metaCounter.get();
    }

    /** Number of calls to {@link #pushToGrid(OpDescriptor, boolean)}. */
    public long getExecutionCounter() {
        return execCounter.get();
    }

    /**
     * Either combines {@code op} with the held op, executes the held op and then handles
     * {@code op}, or (when nothing is held) holds or executes {@code op}.
     *
     * @throws RuntimeException wrapping any failure, naming the previously held op
     */
    protected void processAsGridOp(Op op, int ... dimension) {
        OpDescriptor last = lastOp.get();
        if (last == null) {
            enqueueOrExecute(op, dimension);
            return;
        }
        // getMetaOpType reads lastOp, so it must run before the slot is cleared.
        MetaType type = getMetaOpType(op, dimension);
        lastOp.remove();
        try {
            switch (type) {
                case NOT_APPLICABLE:
                    dequeueOp(last);
                    pushToGrid(last, false);
                    enqueueOrExecute(op, dimension);
                    break;
                case PREDICATE: {
                    dequeueOp(last);
                    PredicateMetaOp metaOp = new PredicateMetaOp(last, new OpDescriptor(op, dimension));
                    pushToGrid(new OpDescriptor(metaOp), false);
                    break;
                }
                case INVERTED_PREDICATE: {
                    OpDescriptor currentOp = new OpDescriptor(op, dimension);
                    dequeueOp(last);
                    dequeueOp(currentOp);
                    InvertedPredicateMetaOp metaOp = new InvertedPredicateMetaOp(last, currentOp);
                    pushToGrid(new OpDescriptor(metaOp), false);
                    break;
                }
                case POSTULATE: {
                    dequeueOp(last);
                    PostulateMetaOp metaOp = new PostulateMetaOp(last, new OpDescriptor(op, dimension));
                    pushToGrid(new OpDescriptor(metaOp), false);
                    break;
                }
                default:
                    throw new UnsupportedOperationException("Not supported MetaType: [" + type + "]");
            }
        } catch (Throwable t) {
            throw new RuntimeException("Error executing previous op: " + last.getOp().getClass().getName() + " - note that in some cases the error/stack trace may be delayed by 1 operation due to the asynchronous nature of ND4J's CUDA grid executioner.\nTo obtain the original error stack trace for debugging purposes, use nd4j-native backend, Nd4j.getExecutioner().commit() calls after ops, or set the following system property: set \"opexec\" to org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner", t);
        }
    }

    /**
     * Holds back a pairwise transform whose x, y and z all live on the current device; executes
     * any other op immediately.
     */
    private void enqueueOrExecute(Op op, int... dimension) {
        OpDescriptor descriptor = new OpDescriptor(op, dimension);
        if (op instanceof TransformOp && op.y() != null && onCurrentDeviceXYZ(op)) {
            enqueueOp(descriptor);
        } else {
            pushToGrid(descriptor, false);
        }
    }

    /** Returns true if x, y and z of {@code op} are all on the current device. */
    protected boolean onCurrentDeviceXYZ(Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        int deviceId = allocator.getDeviceId();
        int deviceX = allocator.getDeviceId(op.x());
        int deviceY = allocator.getDeviceId(op.y());
        int deviceZ = allocator.getDeviceId(op.z());
        return deviceId == deviceX && deviceY == deviceZ && deviceZ == deviceX;
    }

    /** Marks the operands of {@code descriptor} as enqueued and stores it as this thread's held op. */
    protected void enqueueOp(OpDescriptor descriptor) {
        markEnqueued(descriptor.getOp(), true);
        lastOp.set(descriptor);
    }

    /** Clears the enqueued mark on the operands of {@code descriptor}. */
    protected void dequeueOp(OpDescriptor descriptor) {
        markEnqueued(descriptor.getOp(), false);
    }

    /** Sets the enqueued flag on the allocation points of x, z and (if present) y. */
    private static void markEnqueued(Op op, boolean enqueued) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        allocator.getAllocationPoint(op.x()).markEnqueued(enqueued);
        allocator.getAllocationPoint(op.z()).markEnqueued(enqueued);
        if (op.y() != null) {
            allocator.getAllocationPoint(op.y()).markEnqueued(enqueued);
        }
    }

    /**
     * Decides how {@code op} can be combined with the held op.
     *
     * <p>In default mode only a pairwise transform followed by an eligible scalar op on the same
     * x/z array yields {@link MetaType#INVERTED_PREDICATE}. The experimental mode may yield
     * {@link MetaType#PREDICATE} or {@link MetaType#POSTULATE}.
     */
    protected MetaType getMetaOpType(Op op, int ... dimension) {
        OpDescriptor last = lastOp.get();
        if (last == null) {
            return MetaType.NOT_APPLICABLE;
        }
        Op previous = last.getOp();
        if (experimental.get()) {
            logger.debug("Experimental hook");
            if (previous instanceof ScalarOp || previous instanceof TransformOp) {
                return isMatchingZX(previous, op) ? MetaType.PREDICATE : MetaType.NOT_APPLICABLE;
            }
            if (previous instanceof ReduceOp && (op instanceof ScalarOp || op instanceof TransformOp) && op.y() == null) {
                return isMatchingZX(previous, op) ? MetaType.POSTULATE : MetaType.NOT_APPLICABLE;
            }
            return MetaType.NOT_APPLICABLE;
        }
        if (previous instanceof TransformOp && previous.y() != null && isFusableScalar(op)) {
            return isMatchingZX(previous, op) ? MetaType.INVERTED_PREDICATE : MetaType.NOT_APPLICABLE;
        }
        return MetaType.NOT_APPLICABLE;
    }

    /**
     * Returns true for a non-dimensional {@link ScalarOp} that is not ScalarMax/ScalarMin and whose
     * op number is outside the excluded set (7-11, 13, 16, 56-59). Presumably those ops have no
     * fused kernel; TODO confirm.
     */
    private static boolean isFusableScalar(Op op) {
        if (!(op instanceof ScalarOp) || ((ScalarOp) op).getDimension() != null) {
            return false;
        }
        if (op instanceof ScalarMax || op instanceof ScalarMin) {
            return false;
        }
        long opNum = op.opNum();
        boolean excluded = (opNum >= 7 && opNum <= 11) || opNum == 13 || opNum == 16
                || (opNum >= 56 && opNum <= 59);
        return !excluded;
    }

    /** Returns true when both ops use one and the same array as x and z. */
    protected boolean isMatchingZX(Op opA, Op opB) {
        return opA.x() == opB.x() && opA.z() == opB.z() && opA.x() == opB.z();
    }

    /** Returns true when the z of {@code opA} is an input (x or y) of {@code opB}. */
    protected boolean isMatchingZXY(Op opA, Op opB) {
        return opA.z() == opB.x() || opA.z() == opB.y();
    }

    /** Collects device pointers for the op and dimensions in {@code descriptor}. */
    protected GridPointers pointerizeOp(OpDescriptor descriptor) {
        return pointerizeOp(descriptor.getOp(), descriptor.getDimensions());
    }

    /**
     * Collects device pointers for x, z, optional y and, when dimensions are given, the dimension
     * buffer and TAD shape/offset buffers.
     */
    protected GridPointers pointerizeOp(Op op, int ... dimensions) {
        GridPointers pointers = new GridPointers(op, dimensions);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getDeviceContext();
        pointers.setX(allocator.getPointer(op.x(), context));
        pointers.setXShapeInfo(allocator.getPointer(op.x().shapeInfoDataBuffer(), context));
        pointers.setZ(allocator.getPointer(op.z(), context));
        pointers.setZShapeInfo(allocator.getPointer(op.z().shapeInfoDataBuffer(), context));
        pointers.setZLength(op.z().length());
        if (op.y() != null) {
            pointers.setY(allocator.getPointer(op.y(), context));
            pointers.setYShapeInfo(allocator.getPointer(op.y().shapeInfoDataBuffer(), context));
        }
        if (dimensions != null && dimensions.length > 0) {
            DataBuffer dimensionBuffer = Nd4j.getConstantHandler().getConstantBuffer(dimensions, DataType.INT);
            pointers.setDimensions(allocator.getPointer(dimensionBuffer, context));
            pointers.setDimensionsLength(dimensions.length);

            Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimensions);
            Pointer devTadShapeInfo = allocator.getPointer(tadBuffers.getFirst(), context);
            Pointer devTadOffsets = tadBuffers.getSecond() == null ? null : allocator.getPointer(tadBuffers.getSecond(), context);
            pointers.setTadShape(devTadShapeInfo);
            pointers.setTadOffsets(devTadOffsets);
        }
        return pointers;
    }

    /** Returns 1 if this thread currently holds an op, otherwise 0. */
    @Override
    public int getQueueLength() {
        return lastOp.get() == null ? 0 : 1;
    }

    /** Always returns -1; per-device queues are not tracked. */
    @Deprecated
    protected int getQueueLength(int deviceId) {
        return -1;
    }

    /** Grid building is not implemented; always returns null. */
    protected GridOp buildGrid() {
        return null;
    }

    /**
     * Ensures {@code op} has a z array shaped for the reduction along {@code dimension}.
     * Negative axes in {@code dimension} are normalised and sorted in place.
     *
     * @throws IllegalStateException if a provided z has the wrong shape
     */
    protected void buildZ(IndexAccumulation op, int ... dimension) {
        int[] dims = normalizeDimensions(op.x(), dimension);
        long[] retShape = reductionShape(op.x(), dims);
        if (op.z() == null || op.z() == op.x()) {
            op.setZ(Nd4j.createUninitialized(retShape));
        } else if (!Arrays.equals(retShape, op.z().shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + op + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(op.z().shape()));
        }
    }

    /**
     * Ensures {@code op} has a z array for the reduction along {@code dimension}.
     * Negative axes in {@code dimension} are normalised and sorted in place.
     *
     * @throws ND4JIllegalStateException if a provided z has the wrong length
     */
    protected void buildZ(ReduceOp op, int ... dimension) {
        int[] dims = normalizeDimensions(op.x(), dimension);
        long[] retShape = reductionShape(op.x(), dims);
        if (op.z() == null || op.z() == op.x()) {
            INDArray ret;
            if (op.isComplexAccumulation()) {
                long xT = op.x().tensorsAlongDimension(dims);
                long yT = op.y().tensorsAlongDimension(dims);
                ret = Nd4j.create(xT, yT);
            } else {
                ret = Nd4j.zeros(retShape);
            }
            op.setZ(ret);
        } else if (op.z().length() != ArrayUtil.prodLong(retShape)) {
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
        }
    }

    /**
     * Adds the rank of {@code x} to negative axes and then sorts, both in place. Normalising first
     * keeps the result sorted. Returns {Integer.MAX_VALUE} when every axis is reduced.
     */
    private static int[] normalizeDimensions(INDArray x, int[] dimension) {
        int rank = x.rank();
        for (int i = 0; i < dimension.length; i++) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);
        return dimension.length == rank ? new int[] {Integer.MAX_VALUE} : dimension;
    }

    /**
     * Shape of a reduction of {@code x} over {@code dimension}. Results are always at least rank 2:
     * a scalar becomes [1, 1], and a vector becomes a row (axis 0 reduced) or a column.
     */
    private static long[] reductionShape(INDArray x, int[] dimension) {
        if (Shape.wholeArrayDimension(dimension)) {
            return new long[] {1L, 1L};
        }
        long[] retShape = ArrayUtil.removeIndex(x.shape(), dimension);
        if (retShape.length == 1) {
            return dimension[0] == 0 ? new long[] {1L, retShape[0]} : new long[] {retShape[0], 1L};
        }
        if (retShape.length == 0) {
            return new long[] {1L, 1L};
        }
        return retShape;
    }

    /** Returns true when {@code dimension} requests a whole-array reduction. */
    private static boolean isWholeArray(int[] dimension) {
        return dimension == null || dimension.length == 0 || dimension[0] == Integer.MAX_VALUE;
    }

    /**
     * Runs a reduction. Whole-array reductions flush the held op and execute immediately. Others
     * get a z array and go through the fusion queue.
     */
    public INDArray exec(ReduceOp op, int ... dimension) {
        if (isWholeArray(dimension)) {
            // flush=true: the held op runs first, then naiveExec performs the reduction.
            pushToGrid(new OpDescriptor(op, new int[] {Integer.MAX_VALUE}));
        } else {
            buildZ(op, dimension);
            processAsGridOp(op, dimension);
        }
        return op.z();
    }

    /**
     * Runs an index reduction. Whole-array reductions flush and execute immediately. Others get a
     * z array and go through the fusion queue.
     */
    public INDArray exec(IndexAccumulation op, int ... dimension) {
        if (isWholeArray(dimension)) {
            flushQueue();
            buildZ(op, Integer.MAX_VALUE);
            super.invoke(op, null, new int[]{Integer.MAX_VALUE});
        } else {
            buildZ(op, dimension);
            processAsGridOp(op, dimension);
        }
        return op.z();
    }

    /** Routes a broadcast op through the fusion queue. */
    public INDArray exec(BroadcastOp op, int ... dimension) {
        processAsGridOp(op, dimension);
        return op.z();
    }

    /** Routes a broadcast op through the fusion queue; op contexts are not supported. */
    @Override
    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
        Preconditions.checkState(oc == null);
        processAsGridOp(op, op.getDimension());
        return null;
    }

    /** Routes a scalar op through the fusion queue; op contexts are not supported. */
    @Override
    protected CudaContext invoke(ScalarOp op, OpContext oc) {
        Preconditions.checkState(oc == null);
        processAsGridOp(op, (int[]) null);
        return null;
    }

    /** Routes a transform op through the fusion queue; op contexts are not supported. */
    @Override
    protected CudaContext invoke(TransformOp op, OpContext oc) {
        Preconditions.checkState(oc == null);
        processAsGridOp(op, (int[]) null);
        return null;
    }

    /** Attaches device pointers for both halves of {@code op}. */
    protected void prepareGrid(MetaOp op) {
        GridPointers ptrA = pointerizeOp(op.getFirstOpDescriptor());
        GridPointers ptrB = pointerizeOp(op.getSecondOpDescriptor());
        op.setFirstPointers(ptrA);
        op.setSecondPointers(ptrB);
    }

    /**
     * Executes a combined op. No fused kernel is invoked; the two halves run sequentially, first
     * then second, which preserves their semantics. Previously both ops were silently dropped.
     */
    @Override
    public void exec(MetaOp op) {
        pushToGrid(op.getFirstOpDescriptor(), false);
        pushToGrid(op.getSecondOpDescriptor(), false);
    }

    /** Grid execution is not implemented; this class never builds a GridOp ({@link #buildGrid()}). */
    @Override
    public void exec(GridOp op) {
    }

    /** Discards this thread's held op without executing it. */
    protected void purgeQueue() {
        lastOp.remove();
    }

    /**
     * Executes this thread's held op, if any.
     *
     * @throws UnsupportedOperationException in experimental mode when an op is held
     */
    @Override
    public void flushQueue() {
        OpDescriptor op = lastOp.get();
        if (op == null) {
            return;
        }
        if (experimental.get()) {
            throw new UnsupportedOperationException("Experimental flush isn't supported yet");
        }
        lastOp.remove();
        dequeueOp(op);
        pushToGrid(op, false);
    }

    /** Flushes the held op, then synchronizes the device context's special and old streams. */
    @Override
    public void flushQueueBlocking() {
        flushQueue();
        CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
        context.syncSpecialStream();
        context.syncOldStream();
    }

    /** Registers {@code array} so later ops that alias its memory are reported. */
    public void addToWatchdog(INDArray array, String tag) {
        watchdog.add(new WatchdogPair(array, tag));
    }

    @Override
    public INDArray exec(RandomOp op) {
        return exec(op, Nd4j.getRandom());
    }

    /** Flushes the held op, then executes the batch. */
    @Override
    public void exec(List<Aggregate> batch) {
        flushQueue();
        super.exec(batch);
    }

    /** Flushes the held op, then executes the aggregate. */
    @Override
    public void exec(Aggregate op) {
        flushQueue();
        super.exec(op);
    }

    /** Queues {@code op} keyed by the current thread id. */
    @Override
    public void aggregate(Aggregate op) {
        aggregate(op, Thread.currentThread().getId());
    }

    /** Queues {@code op} on the current thread's device, tagged with a per-thread sequence number. */
    @Override
    public void aggregate(Aggregate op, long key) {
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        Queue<AggregateDescriptor> queue = aggregates.get(deviceId);
        queue.add(new AggregateDescriptor(op, key, opCounter.get().getAndIncrement()));
    }

    /** Flushes the held op, then executes the random op. */
    @Override
    public INDArray exec(RandomOp op, Random rng) {
        flushQueue();
        return super.exec(op, rng);
    }

    /** Not implemented. */
    protected void buildAggregation() {
    }

    @Override
    public void push() {
        flushQueue();
    }

    @Override
    public void commit() {
        flushQueueBlocking();
    }

    /** An array registered with the watchdog plus a caller-supplied tag. */
    private static class WatchdogPair {
        private INDArray array;
        private String tag;

        public INDArray getArray() {
            return array;
        }

        public String getTag() {
            return tag;
        }

        public void setArray(INDArray array) {
            this.array = array;
        }

        public void setTag(String tag) {
            this.tag = tag;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof WatchdogPair)) {
                return false;
            }
            WatchdogPair other = (WatchdogPair) o;
            if (!other.canEqual(this)) {
                return false;
            }
            boolean sameArray = array == null ? other.array == null : array.equals(other.array);
            boolean sameTag = tag == null ? other.tag == null : tag.equals(other.tag);
            return sameArray && sameTag;
        }

        protected boolean canEqual(Object other) {
            return other instanceof WatchdogPair;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + (array == null ? 43 : array.hashCode());
            result = result * prime + (tag == null ? 43 : tag.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "CudaGridExecutioner.WatchdogPair(array=" + getArray() + ", tag=" + getTag() + ")";
        }

        public WatchdogPair() {
        }

        public WatchdogPair(INDArray array, String tag) {
            this.array = array;
            this.tag = tag;
        }
    }

    /** How an incoming op can be combined with the held op. */
    protected static enum MetaType {
        NOT_APPLICABLE,
        PREDICATE,
        INVERTED_PREDICATE,
        POSTULATE;

    }
}

