/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceSession
extends AbstractSession<INDArray, Pair<SameDiffOp, OpContext>> {
    /** Class logger; trace level is used for per-op output listings and array deallocation messages. */
    private static final Logger log = LoggerFactory.getLogger(InferenceSession.class);
    /** Advice appended to the workspace-related errors raised while validating placeholder arrays. */
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
    /** Name suffix that identifies a Keras learning-phase placeholder among the graph inputs. */
    protected static final String KERAS_TRAIN_TEST = "keras_learning_phase";
    /** Memory manager through which this session allocates and releases arrays. */
    private SessionMemMgr mmgr;
    /**
     * Tracks, per array instance, the dependencies that still need it. Arrays reported by
     * {@code getNewAllSatisfiedList()} are handed back to {@link #mmgr} for release.
     */
    private IdentityDependencyTracker<INDArray, Dep> arrayUseTracker = new IdentityDependencyTracker<>();
    /** Op contexts by string key. NOTE(review): presumably keyed by op name - confirm against the users of this map. */
    private Map<String, OpContext> opContexts = new HashMap<>();

    /**
     * Creates an inference session for the given graph, using an {@link ArrayCacheMemoryMgr}
     * as the session memory manager.
     *
     * @param sameDiff the SameDiff instance to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        // The superclass constructor has to run first, so the explicit guard can only follow it.
        if (null == sameDiff) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        mmgr = new ArrayCacheMemoryMgr();
    }

    /**
     * Prepares the placeholder arrays for one execution and (re)initialises the array use tracker.
     * <p>
     * Steps performed:
     * <ol>
     *   <li>The array use tracker is cleared, then every CONSTANT and VARIABLE array of the graph is
     *       registered with a {@code ConstantDep} / {@code VariableDep}.</li>
     *   <li>For every graph input whose name ends with {@link #KERAS_TRAIN_TEST} and that the caller did not
     *       supply, a BOOL scalar holding {@code at.operation().isTrainingPhase()} is allocated and added
     *       to a copy of the placeholder map (the caller's map is never modified).</li>
     *   <li>Each placeholder name is checked against the graph, and attached arrays are validated against
     *       their parent workspace (scope must be open, generation ids must match; CIRCULAR workspaces
     *       are exempt).</li>
     *   <li>Arrays whose data type differs from the variable's data type are copied into a newly allocated
     *       array of the expected type.</li>
     * </ol>
     * Arrays allocated here (generated Keras scalars and cast copies) are registered with an
     * {@code ExecDoneDep}; arrays passed through unchanged are registered with a {@code PlaceholderDep}.
     *
     * @param placeholders placeholder name to array; may be null or empty
     * @param at           current position in training/inference; used to read the training-phase flag
     * @return a new map of placeholder name to (possibly cast) array, or {@code placeholders} itself when it is
     *         null or empty and nothing had to be generated
     * @throws ND4JIllegalStateException if a placeholder array refers to a closed or outdated workspace
     */
    @Override
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> placeholders, At at) {
        this.arrayUseTracker.clear();
        for (SDVariable v : this.sameDiff.variables()) {
            VariableType type = v.getVariableType();
            if (type == VariableType.CONSTANT) {
                this.arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.name()));
            } else if (type == VariableType.VARIABLE) {
                this.arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.name()));
            }
        }

        // Names of the Keras learning-phase placeholders generated below. Only these are owned by the
        // session; a caller-supplied array with the same suffix must not be tied to ExecDoneDep.
        Set<String> generatedKerasPhs = new HashSet<String>();
        List<String> phs = this.sameDiff.inputs();
        if (phs != null && !phs.isEmpty()) {
            for (String s : phs) {
                boolean userProvided = placeholders != null && placeholders.containsKey(s);
                if (!s.endsWith(KERAS_TRAIN_TEST) || userProvided) {
                    continue;
                }
                INDArray scalar = this.mmgr.allocate(false, DataType.BOOL, new long[0]).assign(at.operation().isTrainingPhase());
                if (generatedKerasPhs.isEmpty()) {
                    // Copy once before the first insertion: the incoming map may be null or unmodifiable.
                    Map<String, INDArray> copy = new HashMap<String, INDArray>();
                    if (placeholders != null) {
                        copy.putAll(placeholders);
                    }
                    placeholders = copy;
                }
                placeholders.put(s, scalar);
                generatedKerasPhs.add(s);
            }
        }

        if (placeholders == null || placeholders.isEmpty()) {
            return placeholders;
        }

        Map<String, INDArray> out = new HashMap<String, INDArray>();
        for (Map.Entry<String, INDArray> e : placeholders.entrySet()) {
            String name = e.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", (Object)name);
            INDArray arr = e.getValue();

            if (arr.isAttached()) {
                MemoryWorkspace ws = arr.data() == null ? null : arr.data().getParentWorkspace();
                if (ws != null && ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {
                    if (!ws.isScopeActive()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace [" + ws.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                    if (ws.getGenerationId() != arr.data().getGenerationId()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace [" + ws.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: " + arr.data().getGenerationId() + ". Workspace current iteration: " + ws.getGenerationId() + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                }
            }

            DataType dt = this.sameDiff.getVariable(name).dataType();
            if (generatedKerasPhs.contains(name)) {
                // Generated scalar: allocated by this session, so it may be released once execution is done.
                this.arrayUseTracker.addDependency(arr, new ExecDoneDep());
            } else if (arr.dataType() == dt) {
                this.arrayUseTracker.addDependency(arr, new PlaceholderDep(name));
            } else {
                // Wrong data type: copy into a session-allocated array of the variable's type.
                INDArray cast = this.mmgr.allocate(false, dt, arr.shape());
                cast.assign(arr);
                arr = cast;
                this.arrayUseTracker.addDependency(arr, new ExecDoneDep());
            }
            out.put(name, arr);
        }
        return out;
    }

    /**
     * Finishes an execution run from the point of view of array lifetime tracking.
     * <p>
     * Every OP-type exec step that the execution dependency tracker newly reports as all-satisfied
     * is marked as satisfied in the array use tracker, followed by the {@code ExecDoneDep}
     * marker. Any arrays the array use tracker then reports as newly all-satisfied are released through
     * the memory manager.
     *
     * @param output the outputs of the execution
     * @return {@code output}, unchanged
     */
    @Override
    protected Map<String, INDArray> postProcessOutput(Map<String, INDArray> output) {
        if (this.dt.hasNewAllSatisfied()) {
            for (AbstractSession.ExecStep es : this.dt.getNewAllSatisfiedList()) {
                if (es.getType() != AbstractSession.ExecType.OP) {
                    continue;
                }
                FrameIter fi = es.getFrameIter();
                OpDep od = new OpDep(es.getName(), fi.getFrame(), fi.getIteration(), fi.getParentFrame());
                this.arrayUseTracker.markSatisfied(od, true);
            }
        }

        // Signal "execution finished" to everything registered with an ExecDoneDep.
        this.arrayUseTracker.markSatisfied(new ExecDoneDep(), true);
        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            for (INDArray arr : this.arrayUseTracker.getNewAllSatisfiedList()) {
                this.mmgr.release(arr);
            }
        }
        return output;
    }

    /**
     * Executes one op for the given frame/iteration, notifies listeners, and updates array lifetime tracking.
     * <p>
     * Sequence:
     * <ol>
     *   <li>Active listeners receive {@code preOpExecution}.</li>
     *   <li>The op is executed via {@link #doExec}.</li>
     *   <li>Active listeners receive {@code opExecution} and one {@code activationAvailable} call per output.</li>
     *   <li>The op's arrays are cleared and the op context (if any) is purged.</li>
     *   <li>For each output array, a dependency is registered for every consuming op that is part of the
     *       current subgraph. The frame/iteration of that dependency depends on the consumer type
     *       (Enter, NextIteration, Exit, or any other op). Requested outputs in the "main" frame get a
     *       {@code ReqOutputDep}. Outputs with no consumers and no dependency are released immediately.</li>
     *   <li>This op is marked as satisfied in the array use tracker, and any arrays that thereby become
     *       unused are released.</li>
     * </ol>
     *
     * @param opPair           the op to execute together with its op context
     * @param outputFrameIter  frame and iteration the outputs belong to
     * @param opInputs         op inputs, passed through to {@link #doExec}; may be null
     * @param allIterInputs    all-iteration inputs, passed through to {@link #doExec}; may be null
     * @param constAndPhInputs names of constant/placeholder inputs, passed through to {@link #doExec}; may be null
     * @param listeners        listeners to notify; may be null or empty
     * @param at               current position; its frame/iteration is set to {@code outputFrameIter}
     * @param batch            current batch, forwarded to the listeners
     * @param allReqVariables  names of the variables requested as outputs of the whole execution
     * @return the op's output arrays; a Switch op yields null for the branch not taken
     */
    public INDArray[] getOutputs(Pair<SameDiffOp, OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
        SameDiffOp op = opPair.getFirst();
        at.setFrameIter(outputFrameIter);

        boolean haveListeners = listeners != null && !listeners.isEmpty();
        if (haveListeners) {
            SameDiffOp sdOp = this.sameDiff.getOps().get(op.getOp().getOwnName());
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    l.preOpExecution(this.sameDiff, at, sdOp, opPair.getSecond());
                }
            }
        }

        INDArray[] out = this.doExec(op.getOp(), opPair.getRight(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);

        if (log.isTraceEnabled()) {
            StringBuilder sb = new StringBuilder();
            sb.append(op.getName()).append(" - ").append(outputFrameIter).append(" outputs: ");
            List<String> opOutNames = op.getOutputsOfOp();
            for (int i = 0; i < out.length; ++i) {
                if (i > 0) {
                    sb.append(", ");
                }
                sb.append("(").append(i).append(" - ").append(opOutNames.get(i)).append(" = ").append(out[i] == null ? null : Long.valueOf(out[i].getId())).append(")");
            }
            log.trace(sb.toString());
        }

        if (haveListeners) {
            // Built lazily so that no map is created when every listener is inactive.
            Map<String, INDArray> namedOuts = null;
            for (Listener l : listeners) {
                if (!l.isActive(at.operation())) {
                    continue;
                }
                if (namedOuts == null) {
                    Map<String, INDArray> namedOutsBuilder = new HashMap<String, INDArray>();
                    List<String> outNames = op.getOutputsOfOp();
                    for (int i = 0; i < out.length; ++i) {
                        namedOutsBuilder.put(outNames.get(i), out[i]);
                    }
                    namedOuts = Collections.unmodifiableMap(namedOutsBuilder);
                }
                l.opExecution(this.sameDiff, at, batch, op, opPair.getSecond(), out);
                for (Map.Entry<String, INDArray> named : namedOuts.entrySet()) {
                    l.activationAvailable(this.sameDiff, at, batch, op, named.getKey(), named.getValue());
                }
            }
        }

        op.getOp().clearArrays();
        if (opPair.getSecond() != null) {
            opPair.getSecond().purge();
        }

        SameDiffOp o = this.sameDiff.getOps().get(op.getName());
        List<String> outVarNames = o.getOutputsOfOp();
        for (int i = 0; i < out.length; ++i) {
            if (out[i] == null && o.getOp() instanceof Switch) {
                // Branch of the Switch op that was not taken: nothing to track.
                continue;
            }
            String name = outVarNames.get(i);
            Variable v = this.sameDiff.getVariables().get(name);
            List<String> inputsForOps = v.getInputsForOp();
            if (inputsForOps != null) {
                for (String opName : inputsForOps) {
                    if (!this.subgraphOps.contains(opName)) {
                        continue;
                    }
                    DifferentialFunction consumer = this.sameDiff.getOps().get(opName).getOp();
                    Dep dep;
                    if (consumer instanceof Enter) {
                        Enter e = (Enter)consumer;
                        if (e.isConstant()) {
                            dep = new ExecDoneDep();
                        } else {
                            // Consumer runs at iteration 0 of the entered frame, whose parent is the current frame.
                            dep = new OpDep(opName, e.getFrameName(), 0, outputFrameIter);
                        }
                    } else if (consumer instanceof NextIteration) {
                        dep = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration() + 1, outputFrameIter.getParentFrame());
                    } else if (consumer instanceof Exit) {
                        FrameIter fi = outputFrameIter.getParentFrame();
                        dep = new OpDep(opName, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
                    } else {
                        dep = new OpDep(opName, outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
                    }
                    this.arrayUseTracker.addDependency(out[i], dep);
                }
            }

            if ("main".equals(outputFrameIter.getFrame()) && allReqVariables.contains(name)) {
                // Requested output of the whole execution: keep the array alive.
                this.arrayUseTracker.addDependency(out[i], new ReqOutputDep(name));
                continue;
            }
            boolean hasConsumers = inputsForOps != null && !inputsForOps.isEmpty();
            if (hasConsumers || this.arrayUseTracker.hasDependency(out[i])) {
                continue;
            }
            if (log.isTraceEnabled()) {
                log.trace("Found array id {} (output of {}) not required anywhere, deallocating", out[i].getId(), o.getName());
            }
            this.mmgr.release(out[i]);
        }

        OpDep executed = new OpDep(op.getName(), outputFrameIter.getFrame(), outputFrameIter.getIteration(), outputFrameIter.getParentFrame());
        this.arrayUseTracker.markSatisfied(executed, true);
        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            for (INDArray arr : this.arrayUseTracker.getNewAllSatisfiedList()) {
                if (log.isTraceEnabled()) {
                    log.trace("Closing array... id={}, {}", arr.getId(), arr.shapeInfoToString());
                }
                this.mmgr.release(arr);
            }
        }
        return out;
    }

    /**
     * Executes a single op and returns its output arrays.
     * <p>
     * Control-flow style ops are handled directly by forwarding already computed arrays from
     * {@code nodeOutputs}; they are not passed to the native executioner:
     * <ul>
     *   <li>Identity, Enter, Exit, NextIteration, LoopCond: return their single input array.</li>
     *   <li>Switch: returns the input in slot 0 when the predicate is false (0.0) and in slot 1 otherwise;
     *       the other slot is null.</li>
     *   <li>Merge: returns the first input that is available for the current frame/iteration.</li>
     * </ul>
     * Other special cases: tensor array ops are delegated to {@link #getOutputsHelperTensorArrayOps};
     * GradientBackwardsMarker yields a FLOAT scalar 1.0; ExternalErrorsFunction yields a copy of its gradient
     * placeholder array; Assert throws if its condition input is 0. Everything else is executed through
     * {@code Nd4j.exec} using the supplied op context.
     *
     * @param op               the op to execute
     * @param opContext        op context holding input/output arrays for ops run through {@code Nd4j.exec}
     * @param outputFrameIter  frame and iteration of the outputs
     * @param opInputs         inputs of the op; may be null
     * @param allIterInputs    all-iteration inputs of the op; may be null
     * @param constAndPhInputs names of constant/placeholder inputs; may be null
     * @return the output arrays of the op
     * @throws IllegalStateException         if an Assert op fails or a Merge op has no available input
     * @throws UnsupportedOperationException if the op type cannot be executed by this method
     */
    public INDArray[] doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs) {
        int totalInputs = (opInputs == null ? 0 : opInputs.size()) + (constAndPhInputs == null ? 0 : constAndPhInputs.size()) + (allIterInputs == null ? 0 : allIterInputs.size());
        // True when the op has neither regular nor all-iteration inputs, i.e. only constant/placeholder ones.
        boolean constPhInput = (opInputs == null || opInputs.isEmpty()) && (allIterInputs == null || allIterInputs.isEmpty());

        if (op instanceof Identity) {
            String[] argNames = ((Identity)op).argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object)argNames);
            AbstractSession.VarId vid = outputFrameIter.toVarId(argNames[0]);
            return new INDArray[]{this.nodeOutputs.get(vid)};
        }
        if (op instanceof Switch) {
            String[] argNames = ((Switch)op).argNames();
            AbstractSession.VarId vidPredicate = outputFrameIter.toVarId(argNames[1]);
            INDArray predicate = this.nodeOutputs.get(vidPredicate);
            if (predicate == null && constAndPhInputs != null && constAndPhInputs.contains(argNames[1])) {
                // Constant/placeholder predicates are stored under the main frame, iteration 0.
                predicate = this.nodeOutputs.get(new AbstractSession.VarId(argNames[1], "main", 0, null));
            }
            Preconditions.checkNotNull((Object)predicate, "Error during graph execution: Predicate array was null. VarId=%s", (Object)vidPredicate);
            Preconditions.checkState(predicate.isScalar() && predicate.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", (Object)predicate);
            INDArray input = this.nodeOutputs.get(outputFrameIter.toVarId(argNames[0]));
            if (predicate.getDouble(0L) == 0.0) {
                return new INDArray[]{input, null};
            }
            return new INDArray[]{null, input};
        }
        if (op instanceof Enter) {
            Enter e = (Enter)op;
            String[] input = e.argNames();
            Preconditions.checkState(input.length == 1, "Expected only 1 arg name for enter op: got %s", (Object)input);
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", (Object)e.getOwnName(), opInputs, constAndPhInputs);
            AbstractSession.VarId inputVarId = this.firstInputVarId(constPhInput, opInputs, allIterInputs, constAndPhInputs);
            INDArray enterInput = this.nodeOutputs.get(inputVarId);
            Preconditions.checkNotNull((Object)enterInput, "Could not get enter op \"%s\" input: output variable %s - %s", (Object)e.getOwnName(), (Object)e.outputVariablesNames(), (Object)outputFrameIter);
            return new INDArray[]{enterInput};
        }
        if (op instanceof Exit) {
            AbstractSession.VarId inputVarId = this.firstInputVarId(constPhInput, opInputs, allIterInputs, constAndPhInputs);
            return new INDArray[]{this.nodeOutputs.get(inputVarId)};
        }
        if (op instanceof NextIteration) {
            Preconditions.checkState(totalInputs == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", opInputs, constAndPhInputs);
            AbstractSession.VarId in = allIterInputs != null && !allIterInputs.isEmpty() ? allIterInputs.iterator().next() : opInputs.iterator().next();
            Preconditions.checkState(outputFrameIter.getFrame().equals(in.getFrame()), "Expected same frame for NextIteration input vs. output: got input %s, output %s", (Object)in, (Object)outputFrameIter);
            Preconditions.checkState(outputFrameIter.getIteration() == in.getIteration() + 1, "Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", (Object)in, (Object)outputFrameIter);
            INDArray inArr = this.nodeOutputs.get(in);
            if (inArr == null) {
                Preconditions.throwStateEx("Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", op.getOwnName(), this.sameDiff.getOps().get(op.getOwnName()).getOutputsOfOp().get(0), outputFrameIter.getFrame(), outputFrameIter.getIteration());
            }
            return new INDArray[]{inArr};
        }
        if (op instanceof Merge) {
            Merge m = (Merge)op;
            String[] in = this.sameDiff.getInputsForOp(op);
            for (String inputName : in) {
                AbstractSession.VarId vid = outputFrameIter.toVarId(inputName);
                if (!this.nodeOutputs.containsKey(vid)) {
                    continue;
                }
                // Argument order follows the message: first the input, then the merge node.
                log.trace("Returning input \"{}\" for merge node \"{}\"", inputName, m.getOwnName());
                INDArray arr = this.nodeOutputs.get(vid);
                Preconditions.checkState(arr != null, "Could not find output array for %s", (Object)vid);
                return new INDArray[]{arr};
            }
            throw new IllegalStateException("Merge node " + m.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(in) + ") - should not be executed at this point");
        }
        if (op instanceof LoopCond) {
            String[] argNames = ((LoopCond)op).argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object)argNames);
            AbstractSession.VarId vid = outputFrameIter.toVarId(argNames[0]);
            INDArray arr = this.nodeOutputs.get(vid);
            Preconditions.checkNotNull(arr, "Input to LoopCond op must not be null");
            Preconditions.checkState(arr.isScalar() && arr.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape", (Object)arr);
            return new INDArray[]{arr};
        }
        if (op instanceof BaseTensorOp) {
            return this.getOutputsHelperTensorArrayOps(op, outputFrameIter, opInputs, allIterInputs);
        }
        if (op instanceof GradientBackwardsMarker) {
            INDArray out = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(Float.valueOf(1.0f));
            return new INDArray[]{out};
        }
        if (op instanceof ExternalErrorsFunction) {
            String n = ((ExternalErrorsFunction)op).getGradPlaceholderName();
            INDArray arr = this.nodeOutputs.get(new AbstractSession.VarId(n, "main", 0, null));
            // Report the placeholder name: the array itself is null whenever this check fails.
            Preconditions.checkState(arr != null, "Could not find external errors placeholder array: %s", (Object)n);
            INDArray out = this.mmgr.allocate(false, arr.dataType(), arr.shape());
            out.assign(arr);
            return new INDArray[]{out};
        }
        if (op instanceof Assert) {
            Assert a = (Assert)op;
            boolean condition = opContext.getInputArray(0).getDouble(0L) != 0.0;
            if (!condition) {
                String s = "Assertion failed for operation \"" + op.getOwnName() + "\" during execution";
                if (a.numInputArguments() >= 3) {
                    INDArray msg = opContext.getInputArray(2);
                    if (msg != null && msg.dataType() == DataType.UTF8) {
                        s = s + ": " + msg.getString(0L);
                    }
                }
                if (a.numInputArguments() >= 5) {
                    INDArray arr = opContext.getInputArray(4);
                    s = s + "\n" + arr;
                }
                throw new IllegalStateException(s);
            }
            return opContext.getOutputArrays().toArray(new INDArray[0]);
        }
        if (op instanceof CustomOp) {
            Nd4j.exec((CustomOp)((Object)op), opContext);
            return opContext.getOutputArrays().toArray(new INDArray[0]);
        }
        if (op instanceof Op) {
            Nd4j.exec((Op)((Object)op), opContext);
            return new INDArray[]{opContext.getOutputArray(0)};
        }
        throw new UnsupportedOperationException("Execution not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Picks the VarId of the single input of an Enter/Exit op: the first constant/placeholder name
     * (looked up in the "main" frame, iteration 0) when the op only has such inputs, otherwise the first
     * all-iteration input if present, otherwise the first regular op input.
     */
    private AbstractSession.VarId firstInputVarId(boolean constPhInput, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs) {
        if (constPhInput) {
            return new AbstractSession.VarId(constAndPhInputs.iterator().next(), "main", 0, null);
        }
        if (allIterInputs != null && allIterInputs.size() > 0) {
            return allIterInputs.iterator().next();
        }
        return opInputs.iterator().next();
    }

    public INDArray[] getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        if (op instanceof TensorArray) {
            AbstractSession.VarId vid = outputFrameIter.toVarId(op.outputVariable().name());
            Preconditions.checkState(!this.tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", (Object)vid);
            this.tensorArrays.put(vid, new ArrayList());
            INDArray dummy = this.mmgr.allocate(false, DataType.BOOL, new long[0]).assign(true);
            INDArray scalar = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(0.0);
            return new INDArray[]{dummy, scalar};
        }
        if (op instanceof TensorArrayRead) {
            AbstractSession.VarId v;
            SDVariable idxSDV = op.arg(1);
            INDArray idxArr = this.getArray(idxSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", (Object)idxArr);
            int i = idxArr.getInt(0);
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = v = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (v == null && allIterInputs != null) {
                v = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState(v != null, "Could not find input %s", (Object)inTensorArray.name());
            while (this.sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                inTensorArray = this.sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
                v = v.getParentFrame().toVarId(inTensorArray.name());
            }
            List list = this.getTensorArrays().get(v);
            Preconditions.checkState(list != null, "Could not find TensorList for %s", (Object)v);
            Preconditions.checkState(list.size() > i, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", (Object)i, (Object)list.size(), (Object)v);
            INDArray out = (INDArray)list.get(i);
            return new INDArray[]{out};
        }
        if (op instanceof TensorArrayWrite) {
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState(tArr != null, "Could not find input %s", (Object)inTensorArray.name());
            while (this.sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                inTensorArray = this.sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
                tArr = tArr.getParentFrame().toVarId(inTensorArray.name());
            }
            String idxName = op.arg(1).name();
            SDVariable idxSDV = this.sameDiff.getVariable(idxName);
            INDArray idxArr = this.getArray(idxSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", (Object)idxArr);
            int idx = idxArr.getInt(0);
            String inName = op.arg(2).name();
            SDVariable inSDV = this.sameDiff.getVariable(inName);
            INDArray arr = this.getArray(inSDV, opInputs, allIterInputs);
            Preconditions.checkState(arr != null, "Could not find array for %s", (Object)inName);
            Preconditions.checkState(this.tensorArrays.containsKey(tArr), "Tensor array does not exist for %s", (Object)tArr);
            List l = (List)this.tensorArrays.get(tArr);
            while (l.size() <= idx) {
                l.add(null);
            }
            l.set(idx, arr);
            ExecDoneDep d = new ExecDoneDep();
            this.arrayUseTracker.addDependency(arr, d);
            INDArray scalar = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(0.0);
            return new INDArray[]{scalar};
        }
        if (op instanceof TensorArraySize) {
            List l;
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState((l = (List)this.tensorArrays.get(tArr)) != null, "Could not find TensorArray: %s", (Object)tArr);
            INDArray scalar = this.mmgr.allocate(false, DataType.INT, new long[0]).assign(l.size());
            return new INDArray[]{scalar};
        }
        if (op instanceof TensorArrayConcat) {
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            List l = (List)this.tensorArrays.get(tArr);
            Concat c = new Concat(0, l.toArray(new INDArray[0]));
            List<LongShapeDescriptor> shape = c.calculateOutputShape();
            INDArray out = this.mmgr.allocate(false, shape.get(0));
            c.setOutputArgument(0, out);
            Nd4j.exec(c);
            return new INDArray[]{out};
        }
        if (op instanceof TensorArrayGather) {
            List l;
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState((l = (List)this.tensorArrays.get(tArr)) != null, "Could not find TensorArray: %s", (Object)tArr);
            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = this.sameDiff.getVariable(indicesName);
            INDArray idxArr = this.getArray(indicesSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", (Object)idxArr, (Object)indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", (Object)idxArr.dataType(), (Object)indicesName);
            int[] idxArrInt = idxArr.toIntVector();
            ArrayList newList = new ArrayList();
            if (idxArrInt.length == 1 && idxArrInt[0] == -1) {
                newList.addAll(l);
            } else {
                for (int id : idxArrInt) {
                    Preconditions.checkState(id >= 0, "Index for TensorArrayGather must be >= 0, got %s", id);
                    newList.add(l.get(id));
                }
            }
            Stack s = new Stack(newList.toArray(new INDArray[0]), null, 0);
            List<LongShapeDescriptor> shape = s.calculateOutputShape();
            INDArray out = this.mmgr.allocate(false, shape.get(0));
            s.setOutputArgument(0, out);
            Nd4j.exec(s);
            return new INDArray[]{out};
        }
        if (op instanceof TensorArrayScatter) {
            List l;
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            TensorArray ta = (TensorArray)this.sameDiff.getVariableOutputOp(inTensorArray.name());
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState((l = (List)this.tensorArrays.get(tArr)) != null, "Could not find TensorArray: %s", (Object)tArr);
            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = this.sameDiff.getVariable(indicesName);
            INDArray idxArr = this.getArray(indicesSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", (Object)idxArr, (Object)indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", (Object)idxArr.dataType(), (Object)indicesName);
            int[] idxs = idxArr.toIntVector();
            String valuesName = op.arg(2).name();
            SDVariable valuesSDV = this.sameDiff.getVariable(valuesName);
            INDArray valuesArr = this.getArray(valuesSDV, opInputs, allIterInputs);
            while (l.size() <= idxs.length) {
                l.add(null);
            }
            if (idxs.length == 1 && idxs[0] == -1) {
                idxs = ArrayUtil.range(0, (int)valuesArr.size(0));
            }
            INDArrayIndex[] idx = ArrayUtil.nTimes(valuesArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            for (int i = 0; i < idxs.length; ++i) {
                idx[0] = NDArrayIndex.point(i);
                INDArray get = this.mmgr.dup(valuesArr.get(idx));
                int outIdx = idxs[i];
                if (valuesArr.rank() == 1 && get.rank() > 0) {
                    get = get.reshape(new long[0]);
                }
                l.set(outIdx, get);
                this.arrayUseTracker.addDependency(get, new ExecDoneDep());
            }
            INDArray scalar = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(0.0);
            return new INDArray[]{scalar};
        }
        if (op instanceof TensorArraySplit) {
            List l;
            AbstractSession.VarId tArr;
            SDVariable inTensorArray = op.arg(0);
            AbstractSession.VarId varId = tArr = opInputs == null ? null : InferenceSession.lookup(inTensorArray.name(), opInputs, false);
            if (tArr == null && allIterInputs != null) {
                tArr = InferenceSession.lookup(inTensorArray.name(), allIterInputs, false);
            }
            Preconditions.checkState((l = (List)this.tensorArrays.get(tArr)) != null, "Could not find TensorArray: %s", (Object)tArr);
            String splitName = op.arg(1).name();
            INDArray splitArr = this.getArray(this.sameDiff.getVariable(splitName), opInputs, allIterInputs);
            String sizeName = op.arg(2).name();
            SDVariable sizeSDV = this.sameDiff.getVariable(sizeName);
            INDArray sizeArr = this.getArray(sizeSDV, opInputs, allIterInputs);
            Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", (Object)sizeArr, (Object)sizeName);
            Preconditions.checkState(sizeArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", (Object)sizeArr.dataType(), (Object)sizeName);
            int[] sizes = sizeArr.toIntVector();
            while (l.size() <= sizes.length) {
                l.add(null);
            }
            INDArrayIndex[] idx = ArrayUtil.nTimes(splitArr.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            int soFar = 0;
            for (int i = 0; i < sizes.length; ++i) {
                idx[0] = NDArrayIndex.interval(soFar, soFar + sizes[i]);
                INDArray sub = this.mmgr.dup(splitArr.get(idx));
                l.set(i, sub);
                soFar += sizes[i];
                this.arrayUseTracker.addDependency(sub, new ExecDoneDep());
            }
            INDArray scalar = this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(0.0);
            return new INDArray[]{scalar};
        }
        throw new IllegalStateException("Execution support not yet implemented for: " + op.getClass().getName());
    }

    /**
     * Returns the array registered in the SameDiff instance for a constant or
     * {@link VariableType#VARIABLE} variable.
     *
     * @param variableName name of the variable to resolve
     * @return the result of {@code sameDiff.getArrForVarName(variableName)}
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if no variable with that name
     *                               is known, or if the variable is neither a constant nor of type
     *                               {@link VariableType#VARIABLE}
     */
    @Override
    public INDArray getConstantOrVariable(String variableName) {
        // Look the variable up once and reuse it for both the existence and the type check
        SDVariable v = this.sameDiff.getVariable(variableName);
        // An unknown name would otherwise surface as a bare NullPointerException on the next line
        Preconditions.checkState(v != null, "No variable with name \"%s\" was found", (Object)variableName);
        Preconditions.checkState(v.isConstant() || v.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant", (Object)variableName);
        return this.sameDiff.getArrForVarName(variableName);
    }

    /**
     * Looks up the op registered under {@code opName} and prepares an {@link OpContext} for it: input
     * arrays are resolved, op arguments are copied into the context and output arrays are allocated
     * through the session memory manager.
     * <p>
     * LoopCond, Enter, Exit, NextIteration, Merge, Switch and {@code BaseTensorOp} instances are
     * returned with a {@code null} context; no context is built for them here.
     *
     * @param opName            name of the op to parameterize
     * @param frameIter         frame/iteration the op is executed in; not consulted by this implementation
     * @param opInputs          inputs searched when resolving non-constant arguments; may be null
     * @param allIterInputs     further inputs searched alongside {@code opInputs}; may be null
     * @param constAndPhInputs  only its size is used, for the argument-count sanity check; may be null
     * @param placeholderValues arrays for placeholder arguments, keyed by variable name
     * @param allReqVariables   variable names; whether an op output is contained decides the first
     *                          argument passed to {@code mmgr.allocate(...)} for that output
     * @return the op paired with its prepared context (context is {@code null} for the op types listed above)
     * @throws NullPointerException  (via {@code Preconditions.checkNotNull}) if the op, its differential
     *                               function or one of its input arrays cannot be found
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if argument counts, placeholder
     *                               values or calculated output shapes are inconsistent
     */
    @Override
    public Pair<SameDiffOp, OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables) {
        SameDiffOp sdo = this.sameDiff.getOps().get(opName);
        // Report an unknown op name explicitly rather than failing with a bare NPE on sdo.getOp()
        Preconditions.checkNotNull((Object)sdo, "No op found with name \"%s\"", (Object)opName);
        DifferentialFunction df = sdo.getOp();
        Preconditions.checkNotNull((Object)df, "No differential function found with name \"%s\"", (Object)opName);
        if (df instanceof LoopCond || df instanceof Enter || df instanceof Exit || df instanceof NextIteration || df instanceof Merge || df instanceof Switch || df instanceof BaseTensorOp) {
            // No OpContext is prepared for these op types
            return new Pair<SameDiffOp, OpContext>(sdo, null);
        }

        // Sanity check: the number of declared arg names must match the number of supplied inputs
        String[] argNames = df.argNames();
        int numArgs = argNames == null ? 0 : argNames.length;
        int numNonConstIns = opInputs == null ? 0 : opInputs.size();
        int numNonConstInsAllIters = allIterInputs == null ? 0 : allIterInputs.size();
        int numConstPhIns = constAndPhInputs == null ? 0 : constAndPhInputs.size();
        int numInputs = numNonConstIns + numConstPhIns + numNonConstInsAllIters;
        if (numArgs != numInputs) {
            if (numArgs > 1) {
                // The same variable may appear as more than one argument: compare distinct names instead
                Set<String> uniqueArgNames = new HashSet<String>();
                Collections.addAll(uniqueArgNames, argNames);
                Preconditions.checkState(uniqueArgNames.size() == numInputs, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", (Object)df.getClass().getSimpleName(), (Object)opName, uniqueArgNames, opInputs, constAndPhInputs);
            } else {
                Preconditions.checkState(numArgs == numNonConstIns + numConstPhIns, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", (Object)df.getClass().getSimpleName(), (Object)opName, (Object)argNames, opInputs, constAndPhInputs);
            }
        }

        // Resolve one array per argument name, in argument order
        INDArray[] args = null;
        if (argNames != null && argNames.length > 0) {
            args = new INDArray[argNames.length];
            int i = 0;
            for (String s : argNames) {
                SDVariable v = this.sameDiff.getVariable(s);
                if (v.isConstant() || v.getVariableType() == VariableType.VARIABLE) {
                    // Constants and variables carry their own array
                    args[i] = v.getArr();
                } else if (v.isPlaceHolder()) {
                    Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(s), "No array was provided for required placeholder variable \"%s\"", (Object)s);
                    args[i] = placeholderValues.get(s);
                } else {
                    // Anything else is taken from the outputs recorded in this session
                    AbstractSession.VarId vid = InferenceSession.lookup(s, opInputs, allIterInputs, true);
                    args[i] = (INDArray)this.nodeOutputs.get(vid);
                }
                Preconditions.checkNotNull((Object)args[i], "Could not parameterize op %s: array %s (variable %s) is null", (Object)opName, (Object)i, (Object)v.name());
                ++i;
            }
        }

        // One OpContext per op name is cached and reused across calls
        OpContext oc = this.opContexts.get(opName);
        if (oc == null) {
            oc = Nd4j.getExecutioner().buildContext();
            this.opContexts.put(opName, oc);
        }

        if (df instanceof CustomOp) {
            DynamicCustomOp customOp = (DynamicCustomOp)df;
            if (args != null) {
                oc.setInputArrays(args);
            }
            if (df instanceof Identity) {
                // Identity: only the inputs are set; no op arguments or outputs are prepared here
                return new Pair<SameDiffOp, OpContext>(sdo, oc);
            }
            if (customOp.numIArguments() > 0) {
                oc.setIArguments(customOp.iArgs());
            }
            if (customOp.numDArguments() > 0) {
                oc.setDArguments(customOp.dArgs());
            }
            if (customOp.numTArguments() > 0) {
                oc.setTArguments(customOp.tArgs());
            }
            if (customOp.numBArguments() > 0) {
                oc.setBArguments(customOp.bArgs());
            }
            List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(oc);
            Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", (Object)customOp.opName(), (Object)customOp.getOwnName());
            String[] outNames = df.outputVariablesNames();
            Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", (Object)df.opName(), (Object)outShape.size(), (Object)outNames.length);
            for (int i = 0; i < outShape.size(); ++i) {
                LongShapeDescriptor reqShape = outShape.get(i);
                // The data type declared on the output variable takes precedence over the calculated one
                DataType dt = this.sameDiff.getVariable(outNames[i]).dataType();
                if (dt != reqShape.dataType()) {
                    reqShape = reqShape.asDataType(dt);
                }
                boolean isOutput = allReqVariables.contains(outNames[i]);
                INDArray out = this.mmgr.allocate(isOutput, reqShape);
                oc.setOutputArray(i, out);
            }
        } else if (df instanceof Op) {
            Op op = (Op)((Object)df);
            boolean axisArg = false;
            boolean emptyReduce = false;
            if (op instanceof ReduceOp && ((ReduceOp)op).getOpType() != Op.Type.REDUCE3 && numArgs == 2) {
                // Second argument of a (non-REDUCE3) reduction is the axis array, not a data input
                SDVariable axisArgVar = df.arg(1);
                Preconditions.checkState(axisArgVar.dataType().isIntType(), "Legacy op %s input 1 (axis) was expected to be an integer type, is %s", df.getClass(), (Object)axisArgVar.dataType());
                INDArray arr = this.getArray(axisArgVar, opInputs, allIterInputs);
                Preconditions.checkState(arr != null, "Could not get axis argument for op %s: %s", (Object)df.getOwnName(), df.getClass());
                if (!arr.isEmpty()) {
                    int[] axis = arr.toIntVector();
                    int rank = args[0].rank();
                    axis = Shape.normalizeAxis(rank, axis);
                    df.setDimensions(axis);
                    ((BaseReduceOp)op).setEmptyReduce(false);
                } else {
                    // Empty axis array: flag the op as an empty reduce and clear its dimensions
                    df.setDimensions(null);
                    emptyReduce = true;
                    ((BaseReduceOp)op).setEmptyReduce(true);
                }
                axisArg = true;
            } else if (op instanceof ScalarOp && numArgs == 2) {
                // Second argument of a scalar op is the scalar itself
                SDVariable scalarVar = df.arg(1);
                INDArray scalar = this.getArray(scalarVar, opInputs, allIterInputs);
                Preconditions.checkState(scalar != null, "Could not get scalar argument for op %s: %s", (Object)df.getOwnName(), df.getClass());
                Preconditions.checkState(scalar.isScalar(), "Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", (Object)df.getOwnName(), df.getClass(), (Object)scalar);
                ((ScalarOp)op).setScalar(scalar);
            }
            if (args != null && args.length > 0) {
                oc.setInputArray(0, args[0]);
                // The axis array of a reduction is not passed on as input 1
                if (args.length == 2 && !axisArg) {
                    oc.setInputArray(1, args[1]);
                }
            }
            boolean isOutput = allReqVariables.contains(((BaseOp)op).outputVariablesNames()[0]);
            if (emptyReduce) {
                // Output gets the data type and shape of input 0.
                // NOTE(review): allocate(...) is called with 'false' here rather than isOutput - confirm this is intentional
                INDArray z = this.mmgr.allocate(false, oc.getInputArray(0).dataType(), oc.getInputArray(0).shape());
                oc.setOutputArray(0, z);
            } else {
                List<LongShapeDescriptor> outputShape = ((BaseOp)op).calculateOutputShape(oc);
                Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
                LongShapeDescriptor lsd = outputShape.get(0);
                INDArray z = this.mmgr.allocate(isOutput, lsd);
                oc.setOutputArray(0, z);
            }
        }
        return new Pair<SameDiffOp, OpContext>(sdo, oc);
    }

    /**
     * Resolves the array for the given variable.
     * <p>
     * Variables of type CONSTANT or VARIABLE are delegated to {@link #getConstantOrVariable(String)};
     * every other variable is located among the supplied inputs and read from {@code nodeOutputs}.
     *
     * @param sdv           variable whose array is wanted
     * @param opInputs      first collection of inputs to search
     * @param allIterInputs second collection of inputs to search
     * @return the array for the variable
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if the variable is not found
     *                               among the supplied inputs
     */
    protected INDArray getArray(SDVariable sdv, Collection<AbstractSession.VarId> opInputs, Collection<AbstractSession.VarId> allIterInputs) {
        String varName = sdv.name();
        VariableType type = sdv.getVariableType();
        boolean heldBySameDiff = type == VariableType.CONSTANT || type == VariableType.VARIABLE;
        if (!heldBySameDiff) {
            AbstractSession.VarId id = InferenceSession.lookup(varName, opInputs, allIterInputs, false);
            Preconditions.checkState(id != null, "Could not find array for variable %s", (Object)sdv.name());
            return (INDArray)this.nodeOutputs.get(id);
        }
        return this.getConstantOrVariable(varName);
    }

    /**
     * @return the memory manager this session currently uses for array allocation
     */
    public SessionMemMgr getMmgr() {
        return mmgr;
    }

    /**
     * Replaces the memory manager used by this session.
     *
     * @param mmgr the memory manager to use from now on
     */
    public void setMmgr(SessionMemMgr mmgr) {
        this.mmgr = mmgr;
    }

    /**
     * @return the tracker that records which {@link Dep} dependencies are registered against each array
     */
    public IdentityDependencyTracker<INDArray, Dep> getArrayUseTracker() {
        return arrayUseTracker;
    }

    /**
     * Replaces the array dependency tracker used by this session.
     *
     * @param arrayUseTracker the tracker to use from now on
     */
    public void setArrayUseTracker(IdentityDependencyTracker<INDArray, Dep> arrayUseTracker) {
        this.arrayUseTracker = arrayUseTracker;
    }

    /**
     * Dependency marker with no state beyond what {@link Dep} provides. In this file it is registered
     * with the array use tracker for arrays stored into TensorArray lists.
     */
    protected static class ExecDoneDep extends Dep {
        @Override
        public String toString() {
            return "InferenceSession.ExecDoneDep()";
        }

        /**
         * Equal to another {@code ExecDoneDep} that accepts this instance via {@code canEqual} and
         * whose inherited {@link Dep} state is equal.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            return o instanceof ExecDoneDep
                    && ((ExecDoneDep)o).canEqual(this)
                    && super.equals(o);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ExecDoneDep;
        }

        /** Hash is exactly the inherited {@link Dep} hash; this class adds no fields. */
        @Override
        public int hashCode() {
            return super.hashCode();
        }
    }

    /**
     * Dependency identified by an output name, in addition to the frame information held by {@link Dep}.
     */
    protected static class ReqOutputDep extends Dep {
        protected String outputName;

        public ReqOutputDep(String outputName) {
            this.outputName = outputName;
        }

        public String getOutputName() {
            return this.outputName;
        }

        public void setOutputName(String outputName) {
            this.outputName = outputName;
        }

        @Override
        public String toString() {
            return "InferenceSession.ReqOutputDep(outputName=" + this.getOutputName() + ")";
        }

        /**
         * Equal when the other object is a {@code ReqOutputDep} that accepts this instance via
         * {@code canEqual}, the inherited state matches and both output names are equal (or both null).
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ReqOutputDep)) {
                return false;
            }
            ReqOutputDep that = (ReqOutputDep)o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = this.getOutputName();
            String theirs = that.getOutputName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ReqOutputDep;
        }

        /** Combines the inherited hash with the output name's hash (43 stands in for a null name). */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = super.hashCode();
            String name = this.getOutputName();
            int nameHash = name == null ? 43 : name.hashCode();
            return result * PRIME + nameHash;
        }
    }

    /**
     * Dependency identified by a constant's name, in addition to the frame information held by {@link Dep}.
     */
    protected static class ConstantDep extends Dep {
        protected String constName;

        public ConstantDep(String constName) {
            this.constName = constName;
        }

        public String getConstName() {
            return this.constName;
        }

        public void setConstName(String constName) {
            this.constName = constName;
        }

        @Override
        public String toString() {
            return "InferenceSession.ConstantDep(constName=" + this.getConstName() + ")";
        }

        /**
         * Equal when the other object is a {@code ConstantDep} that accepts this instance via
         * {@code canEqual}, the inherited state matches and both constant names are equal (or both null).
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ConstantDep)) {
                return false;
            }
            ConstantDep that = (ConstantDep)o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = this.getConstName();
            String theirs = that.getConstName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof ConstantDep;
        }

        /** Combines the inherited hash with the constant name's hash (43 stands in for a null name). */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = super.hashCode();
            String name = this.getConstName();
            int nameHash = name == null ? 43 : name.hashCode();
            return result * PRIME + nameHash;
        }
    }

    /**
     * Dependency identified by a variable's name, in addition to the frame information held by {@link Dep}.
     */
    protected static class VariableDep extends Dep {
        protected String varName;

        public VariableDep(String varName) {
            this.varName = varName;
        }

        public String getVarName() {
            return this.varName;
        }

        public void setVarName(String varName) {
            this.varName = varName;
        }

        @Override
        public String toString() {
            return "InferenceSession.VariableDep(varName=" + this.getVarName() + ")";
        }

        /**
         * Equal when the other object is a {@code VariableDep} that accepts this instance via
         * {@code canEqual}, the inherited state matches and both variable names are equal (or both null).
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof VariableDep)) {
                return false;
            }
            VariableDep that = (VariableDep)o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = this.getVarName();
            String theirs = that.getVarName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof VariableDep;
        }

        /** Combines the inherited hash with the variable name's hash (43 stands in for a null name). */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = super.hashCode();
            String name = this.getVarName();
            int nameHash = name == null ? 43 : name.hashCode();
            return result * PRIME + nameHash;
        }
    }

    /**
     * Dependency identified by a placeholder's name, in addition to the frame information held by {@link Dep}.
     */
    protected static class PlaceholderDep extends Dep {
        protected String phName;

        public PlaceholderDep(String phName) {
            this.phName = phName;
        }

        public String getPhName() {
            return this.phName;
        }

        public void setPhName(String phName) {
            this.phName = phName;
        }

        @Override
        public String toString() {
            return "InferenceSession.PlaceholderDep(phName=" + this.getPhName() + ")";
        }

        /**
         * Equal when the other object is a {@code PlaceholderDep} that accepts this instance via
         * {@code canEqual}, the inherited state matches and both placeholder names are equal (or both null).
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof PlaceholderDep)) {
                return false;
            }
            PlaceholderDep that = (PlaceholderDep)o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = this.getPhName();
            String theirs = that.getPhName();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof PlaceholderDep;
        }

        /** Combines the inherited hash with the placeholder name's hash (43 stands in for a null name). */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = super.hashCode();
            String name = this.getPhName();
            int nameHash = name == null ? 43 : name.hashCode();
            return result * PRIME + nameHash;
        }
    }

    /**
     * Dependency identified by an op name and an iteration number, in addition to the frame
     * information held by {@link Dep}.
     */
    public static class OpDep extends Dep {
        protected String opName;
        protected int iter;

        /**
         * Creates a dependency with op name and iteration only; the inherited frame and parent frame
         * are left unset by this constructor.
         */
        public OpDep(String opName, int iter) {
            this.opName = opName;
            this.iter = iter;
        }

        /**
         * Creates a fully specified dependency.
         *
         * @throws NullPointerException if {@code opName} or {@code frame} is null
         */
        protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) {
            if (opName == null) {
                throw new NullPointerException("opName is marked non-null but is null");
            }
            if (frame == null) {
                throw new NullPointerException("frame is marked non-null but is null");
            }
            this.opName = opName;
            this.iter = iter;
            this.frame = frame;
            this.parentFrame = parentFrame;
        }

        public String getOpName() {
            return this.opName;
        }

        public void setOpName(String opName) {
            this.opName = opName;
        }

        public int getIter() {
            return this.iter;
        }

        public void setIter(int iter) {
            this.iter = iter;
        }

        /** The parent frame is only included in the text when it is non-null. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("OpDep(");
            text.append(this.opName).append(",frame=").append(this.frame).append(",iter=").append(this.iter);
            if (this.parentFrame != null) {
                text.append(",parent=").append(this.parentFrame);
            }
            return text.append(")").toString();
        }

        /**
         * Equal when the other object is an {@code OpDep} that accepts this instance via
         * {@code canEqual}, the inherited state matches, and op name and iteration are both equal.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof OpDep)) {
                return false;
            }
            OpDep that = (OpDep)o;
            if (!that.canEqual(this) || !super.equals(o)) {
                return false;
            }
            String mine = this.getOpName();
            String theirs = that.getOpName();
            boolean sameName = mine == null ? theirs == null : mine.equals(theirs);
            return sameName && this.getIter() == that.getIter();
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof OpDep;
        }

        /** Combines the inherited hash with the op name's hash (43 for null) and the iteration. */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = super.hashCode();
            String name = this.getOpName();
            result = result * PRIME + (name == null ? 43 : name.hashCode());
            return result * PRIME + this.getIter();
        }
    }

    /**
     * Base class of the dependency types declared in this file. Holds the frame name and the parent
     * frame, and defines equality and hashing over those two values.
     */
    public static abstract class Dep {
        protected String frame;
        protected FrameIter parentFrame;

        public String getFrame() {
            return this.frame;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        /**
         * Equal when the other object is a {@code Dep} that accepts this instance via
         * {@code canEqual} and both frame and parent frame are equal (null matches only null).
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Dep)) {
                return false;
            }
            Dep that = (Dep)o;
            if (!that.canEqual(this)) {
                return false;
            }
            String myFrame = this.getFrame();
            String theirFrame = that.getFrame();
            boolean sameFrame = myFrame == null ? theirFrame == null : myFrame.equals(theirFrame);
            if (!sameFrame) {
                return false;
            }
            FrameIter myParent = this.getParentFrame();
            FrameIter theirParent = that.getParentFrame();
            return myParent == null ? theirParent == null : ((Object)myParent).equals(theirParent);
        }

        /** Hook consulted by {@code equals}; subclasses narrow it to their own type. */
        protected boolean canEqual(Object other) {
            return other instanceof Dep;
        }

        /** Hash over frame and parent frame; 43 stands in for a null value. */
        @Override
        public int hashCode() {
            final int PRIME = 59;
            String f = this.getFrame();
            int result = PRIME + (f == null ? 43 : f.hashCode());
            FrameIter parent = this.getParentFrame();
            return result * PRIME + (parent == null ? 43 : ((Object)parent).hashCode());
        }

        @Override
        public String toString() {
            return "InferenceSession.Dep(frame=" + this.getFrame() + ", parentFrame=" + this.getParentFrame() + ")";
        }
    }
}

