/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import com.google.flatbuffers.FlatBufferBuilder;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.ArgumentInterceptor;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.autodiff.samediff.SameDiffLambda;
import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.TrainingSession;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDBitwise;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDImage;
import org.nd4j.autodiff.samediff.ops.SDLinalg;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.ND4JFileUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Sets;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.GraphDef;

public class SameDiff
extends SDBaseOps {
    private static final Logger log = LoggerFactory.getLogger(SameDiff.class);
    protected static final String GRAD_FN_KEY = "grad";
    // Graph structure: variable name -> variable metadata, and op name -> op metadata (both insertion-ordered).
    private final Map<String, Variable> variables = new LinkedHashMap<String, Variable>();
    private final Map<String, SameDiffOp> ops = new LinkedHashMap<String, SameDiffOp>();
    // Inference sessions keyed by Thread#getId(); arrays of ARRAY-type variables are looked up here.
    private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<Long, InferenceSession>();
    // Array storage for CONSTANT and VARIABLE type variables respectively; replaceable via setArrayHolders(...).
    private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
    private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
    // Placeholder arrays, keyed first by Thread#getId() and then by variable name.
    private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<Long, Map<String, INDArray>>();
    // Names of loss variables (not referenced in this part of the file).
    private final List<String> lossVariables = new ArrayList<String>();
    // Listeners registered through setListeners(...) / addListeners(...).
    private final List<Listener> listeners = new ArrayList<Listener>();
    // Currently open name scopes, outermost first; joined with "/" by currentNameScope().
    private final List<NameScope> nameScopes = new ArrayList<NameScope>();
    // Further state (outputs, training configuration/updaters, id counter); not referenced in this part of the file.
    private List<String> outputs;
    private TrainingConfig trainingConfig;
    private boolean initializedTraining;
    private Map<String, GradientUpdater> updaterMap;
    private int variableId = 0;
    // Op namespaces bound to this instance; also exposed through accessor methods of the same name.
    public final SDMath math = new SDMath(this);
    public final SDRandom random = new SDRandom(this);
    public final SDNN nn = new SDNN(this);
    public final SDCNN cnn = new SDCNN(this);
    public final SDRNN rnn = new SDRNN(this);
    public final SDLoss loss = new SDLoss(this);
    public final SDImage image = new SDImage(this);
    public final SDBitwise bitwise = new SDBitwise(this);
    public final SDLinalg linalg = new SDLinalg(this);
    // Named sub-graphs registered via putSubFunction(...); arrays associated with this graph are propagated to them.
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered = new AtomicBoolean(false);
    // Toggled by enableDebugMode() / disableDebugging().
    private boolean debugMode;
    // Stack of argument interceptors; the topmost non-paused one is applied by addArgsFor(String[], ...).
    private Stack<ArgumentInterceptor> argumentInterceptors = new Stack();
    // Interceptors currently paused (skipped when choosing the interceptor to use).
    private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<ArgumentInterceptor>();
    private Set<String> blockNames = new HashSet<String>();
    boolean logExecution = true;
    private SameDiff parent;
    private SameDiff child;

    /**
     * Returns the math op namespace bound to this graph.
     *
     * @return the {@link SDMath} namespace
     */
    public SDMath math() {
        return math;
    }

    /**
     * Returns the random op namespace bound to this graph.
     *
     * @return the {@link SDRandom} namespace
     */
    public SDRandom random() {
        return random;
    }

    /**
     * Returns the neural-network op namespace bound to this graph.
     *
     * @return the {@link SDNN} namespace
     */
    public SDNN nn() {
        return nn;
    }

    /**
     * Returns the convolution op namespace bound to this graph.
     *
     * @return the {@link SDCNN} namespace
     */
    public SDCNN cnn() {
        return cnn;
    }

    /**
     * Returns the recurrent op namespace bound to this graph.
     *
     * @return the {@link SDRNN} namespace
     */
    public SDRNN rnn() {
        return rnn;
    }

    /**
     * Returns the loss op namespace bound to this graph.
     *
     * @return the {@link SDLoss} namespace
     */
    public SDLoss loss() {
        return loss;
    }

    /**
     * Returns the image op namespace bound to this graph.
     *
     * @return the {@link SDImage} namespace
     */
    public SDImage image() {
        return image;
    }

    /**
     * Returns the bitwise op namespace bound to this graph.
     *
     * @return the {@link SDBitwise} namespace
     */
    public SDBitwise bitwise() {
        return bitwise;
    }

    /**
     * Returns the linear-algebra op namespace bound to this graph.
     *
     * @return the {@link SDLinalg} namespace
     */
    public SDLinalg linalg() {
        return linalg;
    }

    /**
     * Turns debug mode off for this instance.
     *
     * @return this instance, to allow call chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Turns debug mode on for this instance.
     *
     * @return this instance, to allow call chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * Removes every registered listener and installs the given ones instead.
     *
     * @param listeners the listeners to install
     */
    public void setListeners(Listener ... listeners) {
        this.listeners.clear();
        addListeners(listeners);
    }

    /**
     * Removes every registered listener and installs the given ones instead.
     *
     * @param listeners the listeners to install
     */
    public void setListeners(Collection<? extends Listener> listeners) {
        this.listeners.clear();
        addListeners(listeners);
    }

    /**
     * Appends the given listeners to the registered ones.
     *
     * @param listeners the listeners to add
     */
    public void addListeners(Listener ... listeners) {
        List<Listener> asList = Arrays.asList(listeners);
        addListeners(asList);
    }

    /**
     * Appends the given listeners to the registered ones.
     *
     * @param listeners the listeners to add
     */
    public void addListeners(Collection<? extends Listener> listeners) {
        this.listeners.addAll(listeners);
    }

    /**
     * Returns the live (mutable) list of registered listeners.
     *
     * @return the internal listener list
     */
    public List<Listener> getListeners() {
        return listeners;
    }

    /**
     * Replaces the holders used to store VARIABLE and CONSTANT arrays.
     *
     * @param variableArrayHolder new holder for VARIABLE-type arrays; must not be null
     * @param constantArrayHolder new holder for CONSTANT-type arrays; must not be null
     * @param initialize          if true, each new holder is first initialized from the holder it replaces
     * @throws NullPointerException if either holder is null
     */
    public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize) {
        if (variableArrayHolder == null) {
            throw new NullPointerException("variableArrayHolder is marked non-null but is null");
        }
        if (constantArrayHolder == null) {
            throw new NullPointerException("constantArrayHolder is marked non-null but is null");
        }
        ArrayHolder previousVariables = variablesArrays;
        ArrayHolder previousConstants = constantArrays;
        if (initialize) {
            variableArrayHolder.initFrom(previousVariables);
            constantArrayHolder.initFrom(previousConstants);
        }
        variablesArrays = variableArrayHolder;
        constantArrays = constantArrayHolder;
    }

    /**
     * Returns the full name of the currently open name scope.
     *
     * @return all open scope names joined with "/" (outermost first), or null if no scope is open
     */
    public String currentNameScope() {
        if (nameScopes.isEmpty()) {
            return null;
        }
        StringBuilder joined = new StringBuilder();
        for (int idx = 0; idx < nameScopes.size(); idx++) {
            if (idx > 0) {
                joined.append('/');
            }
            joined.append(nameScopes.get(idx).getName());
        }
        return joined.toString();
    }

    /**
     * Prefixes a name with the current name scope, unless it already carries that prefix.
     *
     * @param name the base name
     * @return {@code scope + "/" + name} when a scope is open and the name is not already prefixed,
     *         otherwise {@code name} unchanged
     */
    protected String nameWithScope(String name) {
        final String scope = currentNameScope();
        if (scope != null && !name.startsWith(scope + "/")) {
            return scope + "/" + name;
        }
        return name;
    }

    /**
     * Pushes a name scope onto the list of open scopes.
     *
     * @param nameScope the scope to open
     */
    void addNameScope(NameScope nameScope) {
        nameScopes.add(nameScope);
    }

    /**
     * Closes the innermost open name scope; scopes must be closed in reverse order of opening.
     *
     * @param nameScope the scope to close; must equal the innermost open scope
     */
    void closeNameScope(NameScope nameScope) {
        Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined");
        int last = nameScopes.size() - 1;
        Preconditions.checkState(nameScopes.get(last).equals(nameScope), "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", (Object)nameScope, (Object)currentNameScope());
        nameScopes.remove(last);
    }

    /**
     * Creates a new name scope bound to this instance and opens it.
     *
     * @param nameScope the name of the scope
     * @return the newly opened scope
     */
    public NameScope withNameScope(String nameScope) {
        NameScope opened = new NameScope(this, nameScope);
        addNameScope(opened);
        return opened;
    }

    /**
     * Returns all ops whose name starts with the given scope's name.
     * NOTE(review): this is a plain string-prefix match, so scope "a" also matches ops named "ab/...".
     *
     * @param scope the scope to match
     * @return matching ops in graph insertion order
     */
    public List<SameDiffOp> getOpsInScope(NameScope scope) {
        List<SameDiffOp> matching = new ArrayList<SameDiffOp>();
        for (SameDiffOp candidate : ops.values()) {
            if (candidate.getName().startsWith(scope.getName())) {
                matching.add(candidate);
            }
        }
        return matching;
    }

    /**
     * Returns all ops whose name starts with the given scope name.
     *
     * @param scope the scope name to match
     * @return matching ops in graph insertion order
     */
    public List<SameDiffOp> getOpsInScope(String scope) {
        NameScope wrapped = new NameScope(this, scope);
        return getOpsInScope(wrapped);
    }

    /**
     * Returns all variables whose name starts with the given scope's name (plain string-prefix match).
     *
     * @param scope the scope to match
     * @return matching variables in the order returned by {@code variables()}
     */
    public List<SDVariable> getVariablesInScope(NameScope scope) {
        List<SDVariable> matching = new ArrayList<SDVariable>();
        for (SDVariable candidate : variables()) {
            if (candidate.name().startsWith(scope.getName())) {
                matching.add(candidate);
            }
        }
        return matching;
    }

    /**
     * Returns all variables whose name starts with the given scope name.
     *
     * @param scope the scope name to match
     * @return matching variables
     */
    public List<SDVariable> getVariablesInScope(String scope) {
        NameScope wrapped = new NameScope(this, scope);
        return getVariablesInScope(wrapped);
    }

    /**
     * Copies this graph's variables and ops into another SameDiff instance.
     *
     * <p>Each variable is cloned and registered in {@code sameDiff} via {@code var(...)}. For non-ARRAY
     * variables that currently hold an array, that array is associated with the new variable. Each op is
     * cloned through FlatBuffers serialization ({@code FlatBuffersMapper.cloneViaSerialize}), re-bound to
     * {@code sameDiff}, and its input/output wiring is registered there under the op's original own-name.
     *
     * @param sameDiff the target instance that receives the copied variables and ops
     * @return the last variable in {@code sameDiff.variables()} after copying (fails if that list is empty)
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        // NOTE(review): filled with identity pairs (idx -> idx) and never read in this method.
        HashMap<Integer, Integer> thisVertexIdToNew = new HashMap<Integer, Integer>();
        int idx = 1;
        for (SDVariable var : this.variables()) {
            SDVariable clone = var.clone(this);
            SDVariable newVar = sameDiff.var(clone);
            // ARRAY-type variables are skipped here; their values are not stored in the array holders.
            if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null) {
                sameDiff.associateArrayWithVariable(var.getArr(), newVar);
            }
            thisVertexIdToNew.put(idx, idx);
            clone.setSameDiff(sameDiff);
            ++idx;
        }
        // Variable name -> position in this graph's variable map; passed to the serialization-based op clone.
        HashMap<String, Integer> reverseMap = new HashMap<String, Integer>();
        int count = 0;
        for (Variable v : this.variables.values()) {
            reverseMap.put(v.getName(), count++);
        }
        // NOTE(review): populated but never read within this method.
        LinkedHashMap<String, DifferentialFunction> newFunctions = new LinkedHashMap<String, DifferentialFunction>();
        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction function = op.getOp();
            DifferentialFunction clone = FlatBuffersMapper.cloneViaSerialize(this, function, reverseMap);
            clone.setSameDiff(sameDiff);
            clone.setOwnName(function.getOwnName());
            // NOTE(review): putOpForId only inserts when the id is absent, yet this call runs only when the
            // id is present, and it passes the original function rather than the clone. As written it either
            // does nothing or throws (when the existing entry has a null op) - confirm intent.
            if (sameDiff.opExists(function.getOwnName())) {
                sameDiff.putOpForId(function.getOwnName(), function);
            }
            newFunctions.put(function.getOwnName(), clone);
            SDVariable[] argsForFunction = function.args();
            SDVariable[] outputsForFunction = function.outputVariables();
            // Inputs are registered against the clone and outputs against the original function;
            // both resolve to the same op entry because the clone was given the original own-name above.
            sameDiff.addArgsFor(argsForFunction, clone);
            sameDiff.addOutgoingFor(outputsForFunction, function);
            for (SDVariable arg : clone.args()) {
                arg.setSameDiff(sameDiff);
            }
            for (SDVariable output : clone.outputVariables()) {
                output.setSameDiff(sameDiff);
            }
            // NOTE(review): stores this graph's original SameDiffOp (which holds the original function),
            // replacing any entry addArgsFor created for the clone above - confirm this is intended.
            sameDiff.ops.put(function.getOwnName(), op);
        }
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Checks whether an op with the given name is registered.
     *
     * @param id the op name
     * @return true if the op exists in this graph
     */
    public boolean opExists(String id) {
        return ops.containsKey(id);
    }

    /**
     * Returns the op that produces the given variable.
     *
     * @param variableName the variable name; must exist in this graph
     * @return the producing op, or null if the variable is not the output of any op
     */
    public DifferentialFunction getVariableOutputOp(String variableName) {
        Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", (Object)variableName);
        String producingOp = variables.get(variableName).getOutputOfOp();
        return producingOp == null ? null : ops.get(producingOp).getOp();
    }

    /**
     * Looks up an op by name.
     *
     * @param id the op name; must not be null
     * @return the op registered under {@code id}
     * @throws NullPointerException      if {@code id} is null
     * @throws ND4JIllegalStateException if no op with that name exists
     */
    public DifferentialFunction getOpById(@NonNull String id) {
        if (id == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (ops.containsKey(id)) {
            return ops.get(id).getOp();
        }
        throw new ND4JIllegalStateException("No function with id " + id + " found!");
    }

    /**
     * Registers {@code function} under {@code id} if no op with that name exists yet.
     *
     * <p>NOTE(review): when an entry already exists, this throws only if that entry's op is null; an
     * existing non-null op is silently kept and {@code function} is ignored. The message
     * "already exists" suggests the condition may be inverted - confirm intent before relying on it.
     *
     * @param id       the op name
     * @param function the op to register
     * @throws ND4JIllegalStateException if an entry exists for {@code id} whose op is null
     */
    public void putOpForId(String id, DifferentialFunction function) {
        if (!ops.containsKey(id)) {
            ops.put(id, SameDiffOp.builder().name(id).op(function).build());
            return;
        }
        if (ops.get(id).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
    }

    /**
     * Returns the names of the input variables of the given op.
     *
     * @param function the op; must not be null and must be registered in this graph
     * @return input variable names, or null if the op has no input list
     * @throws NullPointerException      if {@code function} is null
     * @throws ND4JIllegalStateException if the op is not registered
     */
    public String[] getInputsForOp(@NonNull DifferentialFunction function) {
        if (function == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        String opName = function.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Unknown function instance id found: \"" + opName + "\"");
        }
        List<String> inputNames = ops.get(opName).getInputsToOp();
        if (inputNames == null) {
            return null;
        }
        return inputNames.toArray(new String[inputNames.size()]);
    }

    /**
     * Returns the names of the output variables of the given op.
     *
     * @param function the op; must be registered in this graph
     * @return output variable names, or null if the op has no output list
     * @throws ND4JIllegalStateException if the op is not registered
     */
    public String[] getOutputsForOp(DifferentialFunction function) {
        String opName = function.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> outputNames = ops.get(opName).getOutputsOfOp();
        if (outputNames == null) {
            return null;
        }
        return outputNames.toArray(new String[outputNames.size()]);
    }

    /**
     * Resolves the output variables of the given op.
     *
     * @param function the op; must be registered in this graph
     * @return the op's output variables, in output order (entries are whatever {@code getVariable} returns)
     * @throws ND4JIllegalStateException if the op is unknown or has no output list
     */
    public SDVariable[] getOutputVariablesForOp(DifferentialFunction function) {
        String[] outputNames = this.getOutputsForOp(function);
        if (outputNames == null) {
            // This path concerns missing outputs; the previous message wrongly reported missing inputs.
            throw new ND4JIllegalStateException("No outputs found for function " + function);
        }
        SDVariable[] vars = new SDVariable[outputNames.length];
        for (int i = 0; i < outputNames.length; ++i) {
            vars[i] = this.getVariable(outputNames[i]);
        }
        return vars;
    }

    /**
     * Resolves the input variables of the given op.
     *
     * @param function the op; must be registered in this graph
     * @return the op's input variables, in argument order
     * @throws ND4JIllegalStateException if the op has no input list or an input name does not resolve
     */
    public SDVariable[] getInputVariablesForOp(DifferentialFunction function) {
        String[] names = getInputsForOp(function);
        if (names == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + function);
        }
        SDVariable[] resolved = new SDVariable[names.length];
        int idx = 0;
        for (String n : names) {
            SDVariable found = getVariable(n);
            if (found == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            resolved[idx++] = found;
        }
        return resolved;
    }

    /**
     * Stores an array for a CONSTANT, VARIABLE or PLACEHOLDER variable.
     * Placeholder arrays are stored per calling thread. No cast or detach is performed.
     *
     * @param varName the variable name; must exist in this graph
     * @param arr     the array to store
     * @throws NullPointerException          if either argument is null
     * @throws UnsupportedOperationException for any other variable type (e.g. ARRAY)
     */
    public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr) {
        if (varName == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        if (arr == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable with name \"%s\" exists", (Object)varName);
        SDVariable target = getVariable(varName);
        if (target.isConstant()) {
            constantArrays.setArray(varName, arr);
            return;
        }
        if (target.getVariableType() == VariableType.VARIABLE) {
            variablesArrays.setArray(varName, arr);
            return;
        }
        if (!target.isPlaceHolder()) {
            throw new UnsupportedOperationException("Cannot set variable of type " + target.getVariableType() + " using this method");
        }
        long threadId = Thread.currentThread().getId();
        Map<String, INDArray> perThread = placeholdersPerThread.get(threadId);
        if (perThread == null) {
            perThread = new HashMap<String, INDArray>();
            placeholdersPerThread.put(threadId, perThread);
        }
        perThread.put(varName, arr);
    }

    /**
     * Checks whether an array is currently available for the named variable.
     * ARRAY and PLACEHOLDER values are looked up for the calling thread only.
     *
     * @param varName the variable name
     * @return true if an array is stored/available for the variable
     */
    public boolean arrayAlreadyExistsForVarName(String varName) {
        SDVariable sdVar = getVariable(varName);
        VariableType type = sdVar.getVariableType();
        long threadId = Thread.currentThread().getId();
        switch (type) {
            case VARIABLE:
                return variablesArrays.hasArray(varName);
            case CONSTANT:
                return constantArrays.hasArray(varName);
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                return session != null && session.contains(varName, "main", 0, null);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> perThread = placeholdersPerThread.get(threadId);
                return perThread != null && perThread.containsKey(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Returns the array currently associated with the named variable.
     * ARRAY values come from the calling thread's inference session; PLACEHOLDER values from the
     * calling thread's placeholder map.
     *
     * @param varName the variable name; must exist in this graph
     * @return the array, or null if none is available (VARIABLE lookups defer to the array holder)
     * @throws NullPointerException if {@code varName} is null
     */
    public INDArray getArrForVarName(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        Preconditions.checkState(variables.containsKey(varName), "No variable found with name \"%s\"", (Object)varName);
        VariableType type = variables.get(varName).getVariable().getVariableType();
        switch (type) {
            case VARIABLE:
                return variablesArrays.getArray(varName);
            case CONSTANT:
                return constantArrays.hasArray(varName) ? constantArrays.getArray(varName) : null;
            case ARRAY: {
                InferenceSession session = sessions.get(Thread.currentThread().getId());
                return session == null ? null : (INDArray)session.get(varName, "main", 0, null, false);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> perThread = placeholdersPerThread.get(Thread.currentThread().getId());
                return perThread == null || !perThread.containsKey(varName) ? null : perThread.get(varName);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + type);
        }
    }

    /**
     * Associates an array with the variable of the given name; see the {@code SDVariable} overload.
     *
     * @param arr      the array to associate
     * @param variable the variable name; must exist in this graph
     * @throws NullPointerException if {@code variable} is null
     */
    public void associateArrayWithVariable(INDArray arr, @NonNull String variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        Preconditions.checkState(variables.containsKey(variable), "Cannot associate array with variable \"%s\": variable \"%s\" does not exist in this SameDiff instance", (Object)variable, (Object)variable);
        SDVariable target = getVariable(variable);
        associateArrayWithVariable(arr, target);
    }

    /**
     * Associates an array with the given variable, storing it according to the variable's type.
     *
     * <p>If the array's data type differs from the variable's, the array is cast first, and attached
     * arrays are detached before storage. VARIABLE and CONSTANT arrays go to their respective array
     * holders. PLACEHOLDER arrays are checked against the placeholder shape (when one is declared) and
     * stored for the calling thread. ARRAY-type variables are rejected. Afterwards the array is also
     * associated with any same-named variable in registered sub-function instances.
     *
     * @param arr      the array to associate; must not be null
     * @param variable the target variable; must not be null
     * @throws ND4JIllegalArgumentException if {@code arr} or {@code variable} is null
     * @throws UnsupportedOperationException if the variable is of type ARRAY
     */
    public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
        if (variable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (arr == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }
        if (variable.dataType() != arr.dataType()) {
            arr = arr.castTo(variable.dataType());
        }
        Preconditions.checkState(variable.dataType() == arr.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", (Object)variable.name(), (Object)variable.dataType(), (Object)arr.dataType());
        // Ensure the calling thread has an inference session (created even if the type switch below rejects the variable).
        if (this.sessions.get(Thread.currentThread().getId()) == null) {
            this.sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
        }
        // Attached arrays are detached before being stored.
        if (arr.isAttached()) {
            arr = arr.detach();
        }
        switch (variable.getVariableType()) {
            case VARIABLE: {
                this.variablesArrays.setArray(variable.name(), arr);
                break;
            }
            case CONSTANT: {
                this.constantArrays.setArray(variable.name(), arr);
                break;
            }
            case ARRAY: {
                throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for this type of variable is calculated ");
            }
            case PLACEHOLDER: {
                // Shape check is skipped when the placeholder has no declared shape.
                long[] phShape = variable.placeholderShape();
                Preconditions.checkState(phShape == null || Shape.shapeMatchesPlaceholder(phShape, arr.shape()), "Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:shape is wrong rank or does not match on one or more dimensions", (Object)arr, (Object)phShape);
                long tid = Thread.currentThread().getId();
                if (!this.placeholdersPerThread.containsKey(tid)) {
                    this.placeholdersPerThread.put(tid, new HashMap());
                }
                this.placeholdersPerThread.get(tid).put(variable.name(), arr);
                break;
            }
            default: {
                throw new IllegalStateException("Unknown variable type: " + (Object)((Object)variable.getVariableType()));
            }
        }
        // Propagate the (possibly cast/detached) array to same-named variables in sub-function graphs.
        if (this.sameDiffFunctionInstances != null && this.sameDiffFunctionInstances.size() > 0) {
            for (Map.Entry<String, SameDiff> e : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = e.getValue();
                SDVariable v = sd.getVariable(variable.name());
                if (v == null) continue;
                sd.associateArrayWithVariable(arr, v);
            }
        }
    }

    /**
     * Stores an array for a VARIABLE or CONSTANT type variable, replacing any array already stored for it.
     *
     * <p>If {@code arr} is a view it is copied with {@code dup()} before being stored. Unlike
     * {@code associateArrayWithVariable}, no data type cast or detach is performed here.
     *
     * @param arr      the array to store; must not be null
     * @param variable the target variable; must be of type VARIABLE or CONSTANT
     * @throws NullPointerException if either argument is null
     */
    public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable) {
        if (arr == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        // Fixed typo in the user-facing message ("VARIBLE" -> "VARIABLE").
        Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT, "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", (Object)variable.name(), (Object)variable.getVariableType());
        if (arr.isView()) {
            arr = arr.dup();
        }
        if (variable.getVariableType() == VariableType.VARIABLE) {
            this.variablesArrays.setArray(variable.name(), arr);
        } else {
            this.constantArrays.setArray(variable.name(), arr);
        }
    }

    /**
     * Registers a named sub-graph. Re-registering the same instance under the same name is allowed.
     *
     * @param name      the sub-function name
     * @param nameSpace the sub-graph instance
     * @throws ND4JIllegalStateException if a different instance is already registered under {@code name}
     */
    public void putSubFunction(String name, SameDiff nameSpace) {
        boolean conflicting = sameDiffFunctionInstances.containsKey(name) && sameDiffFunctionInstances.get(name) != nameSpace;
        if (conflicting) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        sameDiffFunctionInstances.put(name, nameSpace);
    }

    /**
     * Builds a new map from variable name to SDVariable, in graph insertion order.
     *
     * @return a fresh map (changes to it do not affect this graph)
     */
    public Map<String, SDVariable> variableMap() {
        Map<String, SDVariable> byName = new LinkedHashMap<String, SDVariable>();
        for (Variable meta : variables.values()) {
            byName.put(meta.getName(), meta.getVariable());
        }
        return byName;
    }

    /**
     * Returns the names of registered sub-functions.
     *
     * @return a live key-set view of the internal sub-function map
     */
    public Collection<String> definedFunctionNames() {
        return sameDiffFunctionInstances.keySet();
    }

    /**
     * Creates an empty graph whose op namespaces are bound to itself.
     * Private; instances are presumably obtained via a factory method not visible in this chunk.
     */
    private SameDiff() {
        super(null);
        this.sd = this;
        this.fieldVariableResolutionMapping = HashBasedTable.create();
        this.sameDiffFunctionInstances = new LinkedHashMap<String, SameDiff>();
    }

    /**
     * Binds the given variable to this SameDiff instance if it currently belongs to a different one.
     *
     * @param function the variable to bind; must not be null
     * @param <X>      the concrete SDVariable type
     * @return the same variable instance
     */
    public <X extends SDVariable> X setupFunction(X function) {
        Preconditions.checkNotNull(function, "Passed in function must not be null!");
        // After the null check, X extends SDVariable makes an instanceof test redundant.
        if (function.getSameDiff() != this) {
            function.setSameDiff(this);
        }
        return function;
    }

    /**
     * Registers the given variables as the outputs of {@code function}; delegates to the name-based overload.
     *
     * @param variables the output variables
     * @param function  the producing op
     */
    public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function) {
        String[] names = new String[variables.length];
        int pos = 0;
        for (SDVariable v : variables) {
            names[pos++] = v.name();
        }
        addOutgoingFor(names, function);
    }

    /**
     * Registers the named variables as the outputs of {@code function}.
     *
     * <p>All arguments are validated before any graph state is modified, so a bad name cannot leave the
     * op with outputs set but variables only partially linked back to it.
     *
     * @param varNames names of the output variables; none may be null and all must exist in this graph
     * @param function the producing op; its own name must be set and it must already be registered
     * @throws ND4JIllegalStateException if the function has no own name, is not registered, already has
     *         outputs, or if {@code varNames} is null / contains null or unknown names
     */
    public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
        String opName = function.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        SameDiffOp op = this.ops.get(opName);
        if (op == null) {
            // Previously surfaced as a NullPointerException.
            throw new ND4JIllegalStateException("No op with name \"" + opName + "\" is registered in this SameDiff instance");
        }
        List<String> existingOutputs = op.getOutputsOfOp();
        if (existingOutputs != null && !existingOutputs.isEmpty()) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
        }
        if (varNames == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String varName : varNames) {
            if (varName == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
            if (!this.variables.containsKey(varName)) {
                throw new ND4JIllegalStateException("Cannot set output \"" + varName + "\" of op \"" + opName + "\": no variable with this name exists");
            }
        }
        // Copy so the stored list is mutable and not backed by the caller's array.
        op.setOutputsOfOp(new ArrayList<String>(Arrays.asList(varNames)));
        for (String resultName : varNames) {
            this.variables.get(resultName).setOutputOfOp(opName);
        }
    }

    /**
     * Pushes an argument interceptor onto the interceptor stack.
     *
     * @param interceptor the interceptor; must not be null
     * @throws NullPointerException if {@code interceptor} is null
     */
    public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        argumentInterceptors.push(interceptor);
    }

    /**
     * Checks whether the given interceptor is currently paused.
     *
     * @param interceptor the interceptor; must not be null
     * @return true if paused
     */
    private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        return pausedArgumentInterceptors.contains(interceptor);
    }

    /**
     * Finds the interceptor that should currently be applied: the topmost one on the stack that is not paused.
     *
     * @return that interceptor, or null if the stack is empty or every interceptor is paused
     */
    private ArgumentInterceptor getArgumentInterceptorToUse() {
        for (int idx = argumentInterceptors.size() - 1; idx >= 0; idx--) {
            ArgumentInterceptor candidate = argumentInterceptors.elementAt(idx);
            if (!isArgumentInterceptorPaused(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Pops the topmost argument interceptor; does nothing if none are registered.
     */
    public void removeArgumentInterceptor() {
        if (argumentInterceptors.isEmpty()) {
            return;
        }
        argumentInterceptors.pop();
    }

    /**
     * Pauses the topmost argument interceptor.
     * Throws {@link java.util.EmptyStackException} if no interceptor is registered.
     */
    public void pauseArgumentInterceptor() {
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.add(top);
    }

    /**
     * Pauses the given argument interceptor.
     *
     * @param interceptor the interceptor; must not be null
     */
    public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        pausedArgumentInterceptors.add(interceptor);
    }

    /**
     * Unpauses the topmost argument interceptor.
     * Throws {@link java.util.EmptyStackException} if no interceptor is registered.
     */
    public void unpauseArgumentInterceptor() {
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.remove(top);
    }

    /**
     * Unpauses the given argument interceptor.
     *
     * @param interceptor the interceptor; must not be null
     */
    public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor) {
        if (interceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        pausedArgumentInterceptors.remove(interceptor);
    }

    /**
     * Registers the named variables as the inputs of {@code function}.
     *
     * <p>If an argument interceptor is active (the topmost non-paused one on the interceptor stack), each
     * input is first passed through it. The corresponding entry of {@code variables} is then replaced, in
     * place, by the name of the variable the interceptor returns. The interceptor is paused while it runs,
     * so nested {@code addArgsFor} calls made during interception skip it. It is always unpaused
     * afterwards, even if interception throws.
     *
     * <p>If no op is registered under the function's own name, a new {@link SameDiffOp} wrapping
     * {@code function} is created. The op's input list is then set, and the op name is added (once) to
     * each input variable's list of consuming ops. Inputs are validated before any graph state changes.
     *
     * @param variables names of the input variables; may be modified in place by an argument interceptor
     * @param function  the consuming op; its own name must already be set
     * @throws ND4JIllegalStateException if the function has no own name or an input name is unknown
     */
    public void addArgsFor(String[] variables, DifferentialFunction function) {
        String opName = function.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        ArgumentInterceptor interceptor = this.getArgumentInterceptorToUse();
        if (interceptor != null) {
            this.pauseArgumentInterceptor(interceptor);
            try {
                for (int i = 0; i < variables.length; ++i) {
                    variables[i] = interceptor.intercept(this.getVariable(variables[i])).name();
                }
            } finally {
                // Without this, an exception during interception would leave the interceptor paused forever.
                this.unpauseArgumentInterceptor(interceptor);
            }
        }
        // Validate every input before mutating the graph, so a bad name cannot leave a half-wired op.
        for (String variableName : variables) {
            if (!this.variables.containsKey(variableName)) {
                throw new ND4JIllegalStateException("Cannot add input \"" + variableName + "\" to op \"" + opName + "\": no variable with this name exists");
            }
        }
        if (!this.ops.containsKey(opName)) {
            this.ops.put(opName, SameDiffOp.builder().name(opName).op(function).build());
        }
        // Copy so the stored list is mutable and not backed by the caller's array.
        this.ops.get(opName).setInputsToOp(new ArrayList<String>(Arrays.asList(variables)));
        for (String variableName : variables) {
            Variable meta = this.variables.get(variableName);
            List<String> funcs = meta.getInputsForOp();
            if (funcs == null) {
                funcs = new ArrayList<String>();
                meta.setInputsForOp(funcs);
            }
            if (!funcs.contains(opName)) {
                funcs.add(opName);
            }
        }
    }

    /**
     * Convenience overload: converts the variables to their names and delegates to
     * {@link #addArgsFor(String[], DifferentialFunction)}.
     *
     * @param variables input variables; no element may be null
     * @param function  function whose inputs are being set
     * @throws ND4JIllegalStateException if any element of {@code variables} is null
     */
    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
        int count = variables.length;
        String[] names = new String[count];
        int idx = 0;
        while (idx < count) {
            SDVariable current = variables[idx];
            if (current == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            names[idx] = current.name();
            idx++;
        }
        this.addArgsFor(names, function);
    }

    /**
     * Replaces the {@code i}-th input of {@code function} with {@code newArg}.
     * <p>
     * Updates the op's input list, registers {@code function} as a consumer of the new variable,
     * and removes it from the old variable's consumer list, but only if the function no longer
     * references the old variable through any other argument.
     *
     * @param i        zero-based index of the argument to replace
     * @param newArg   replacement variable; must not be null
     * @param function function whose argument is replaced; must not be null
     * @throws NullPointerException     if {@code newArg} or {@code function} is null
     * @throws IllegalArgumentException if {@code i} is negative or not less than the function's
     *                                  argument count (raised via {@code Preconditions.checkArgument})
     */
    public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function) {
        if (newArg == null) {
            throw new NullPointerException("newArg is marked non-null but is null");
        }
        if (function == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        String opName = function.getOwnName();
        int numArgs = function.args().length;
        // Check both bounds: a negative index would otherwise fail later with an unrelated exception.
        Preconditions.checkArgument(i >= 0 && i < numArgs, "Index out of range: function " + opName + " only has " + numArgs + " args but you are trying to replace the argument at " + i);
        String oldName = function.arg(i).name();
        String newName = newArg.name();
        // Work on a copy: the stored input list may be a fixed-size or shared view.
        List<String> newInputs = new ArrayList<String>(this.ops.get(opName).getInputsToOp());
        newInputs.set(i, newName);
        this.ops.get(opName).setInputsToOp(newInputs);
        List<String> newConsumers = this.variables.get(newName).getInputsForOp();
        if (newConsumers == null) {
            newConsumers = new ArrayList<String>();
            this.variables.get(newName).setInputsForOp(newConsumers);
        }
        if (!newConsumers.contains(opName)) {
            newConsumers.add(opName);
        }
        // The old variable may still feed this function through another argument slot.
        List<String> oldConsumers = this.variables.get(oldName).getInputsForOp();
        if (oldConsumers != null && !ArrayUtils.contains(function.argNames(), oldName)) {
            oldConsumers.remove(opName);
        }
    }

    /**
     * Reports whether {@code function} has at least one registered input.
     *
     * @param function function to check
     * @return {@code true} if the function is registered as an op and its input list is non-empty;
     *         {@code false} if it has no inputs or is not registered in this instance
     */
    public boolean hasArgs(DifferentialFunction function) {
        SameDiffOp op = this.ops.get(function.getOwnName());
        if (op == null) {
            // Previously an unexplained NPE; an unregistered function simply has no registered args.
            return false;
        }
        List<String> inputs = op.getInputsToOp();
        return inputs != null && !inputs.isEmpty();
    }

    /**
     * Discards placeholder arrays held by this instance and, recursively, by every sub-function
     * instance in {@code sameDiffFunctionInstances}.
     *
     * @param allThreads if true, clear the placeholders of every thread; otherwise only those
     *                   stored under the calling thread's id
     */
    public void clearPlaceholders(boolean allThreads) {
        if (!allThreads) {
            this.placeholdersPerThread.remove(Thread.currentThread().getId());
        } else {
            this.placeholdersPerThread.clear();
        }
        for (SameDiff child : this.sameDiffFunctionInstances.values()) {
            child.clearPlaceholders(allThreads);
        }
    }

    /**
     * Drops the input arrays attached to every op in this instance, and recursively in every
     * sub-function instance.
     * <p>
     * For legacy {@code Op}s, x is cleared, and y is cleared only when it is currently set. For
     * {@code DynamicCustomOp}s, the input-argument list is reset to null. Other op kinds are left
     * untouched.
     */
    public void clearOpInputs() {
        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction fn = op.getOp();
            if (fn instanceof Op) {
                // Typed local: the decompiled "Object o" version could not call setX / y / setY.
                Op legacyOp = (Op) fn;
                legacyOp.setX(null);
                if (legacyOp.y() != null) {
                    legacyOp.setY(null);
                }
            } else if (fn instanceof DynamicCustomOp) {
                ((DynamicCustomOp) fn).setInputArguments(null);
            }
        }
        for (SameDiff sd : this.sameDiffFunctionInstances.values()) {
            sd.clearOpInputs();
        }
    }

    /**
     * Returns the functions backing every registered op, in the iteration order of the op map.
     *
     * @return a new array; modifying it does not affect this instance
     */
    public DifferentialFunction[] ops() {
        DifferentialFunction[] result = new DifferentialFunction[this.ops.size()];
        int idx = 0;
        for (SameDiffOp entry : this.ops.values()) {
            result[idx++] = entry.getOp();
        }
        return result;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     * <p>
     * {@code equals} compares the {@code variables} and {@code ops} maps. Equal maps always have
     * equal key sets, so hashing the key sets guarantees that equal instances hash alike. The
     * previous implementation mixed in {@code super.hashCode()} (identity-based), which broke this
     * guarantee.
     *
     * @return hash derived from the variable and op names
     */
    @Override
    public int hashCode() {
        int result = this.variables != null ? this.variables.keySet().hashCode() : 0;
        result = 31 * result + (this.ops != null ? this.ops.keySet().hashCode() : 0);
        return result;
    }

    /**
     * Two instances are equal when they are of the same runtime class and both their
     * {@code variables} and {@code ops} maps are equal.
     * <p>
     * Both map comparisons are always evaluated; no short-circuit is applied.
     *
     * @param o object to compare with
     * @return {@code true} if {@code o} is an equal SameDiff instance
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != this.getClass()) {
            return false;
        }
        SameDiff other = (SameDiff)o;
        return this.variables.equals(other.variables) & this.ops.equals(other.ops);
    }

    /**
     * Static factory: returns a new instance built with the no-argument constructor.
     *
     * @return a newly constructed {@code SameDiff}
     */
    public static SameDiff create() {
        SameDiff fresh = new SameDiff();
        return fresh;
    }

    /**
     * Creates a copy of this instance by serializing it with {@code asFlatBuffers(true)} and
     * deserializing the result with {@code fromFlatBuffers}.
     *
     * @return the deserialized copy
     * @throws RuntimeException wrapping any {@link IOException} raised during deserialization
     */
    public SameDiff dup() {
        ByteBuffer serialized = this.asFlatBuffers(true);
        SameDiff copy;
        try {
            copy = SameDiff.fromFlatBuffers(serialized);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return copy;
    }

    /**
     * Sums the products of the shapes of every variable that has a non-null shape.
     * <p>
     * Uses {@code ArrayUtil.prodLong}. The previous {@code ArrayUtil.prod(long...)} returns an int,
     * so large variables overflowed before the widening cast to long.
     *
     * @return total element count over all variables with a known shape
     */
    public long numElements() {
        long total = 0L;
        for (SDVariable variable : this.variables()) {
            long[] shape = variable.getShape();
            if (shape == null) {
                continue;
            }
            total += ArrayUtil.prodLong(shape);
        }
        return total;
    }

    /**
     * Lists the names of all placeholder variables, in the iteration order of the variable map.
     *
     * @return a new mutable list of placeholder names
     */
    public List<String> inputs() {
        List<String> placeholderNames = new ArrayList<String>();
        for (String name : this.variables.keySet()) {
            if (this.isPlaceHolder(name)) {
                placeholderNames.add(name);
            }
        }
        return placeholderNames;
    }

    /**
     * Returns the configured output variable names.
     *
     * @return the stored list itself (not a copy), or null if no outputs have been set
     */
    public List<String> outputs() {
        List<String> configured = this.outputs;
        return configured;
    }

    /**
     * Varargs convenience overload of {@link #setOutputs(List)}.
     *
     * @param outputs output variable names; null clears the outputs
     */
    public void setOutputs(String ... outputs) {
        List<String> asList = null;
        if (outputs != null) {
            asList = Arrays.asList(outputs);
        }
        this.setOutputs(asList);
    }

    /**
     * Sets the output variable names of this graph.
     *
     * @param outputs output variable names, each of which must exist in this instance; null clears
     *                the outputs. The list is stored as-is, not copied.
     * @throws IllegalArgumentException if a name does not correspond to a variable
     *                                  (raised via {@code Preconditions.checkArgument})
     */
    public void setOutputs(List<String> outputs) {
        if (outputs != null) {
            for (String s : outputs) {
                // Pass s: the message has a %s placeholder that was previously left unfilled.
                Preconditions.checkArgument(this.variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name", s);
            }
        }
        this.outputs = outputs;
    }

    /**
     * Returns all variables of this instance.
     *
     * @return a new list holding the values of {@code variableMap()}
     */
    public List<SDVariable> variables() {
        Map<String, SDVariable> byName = this.variableMap();
        return new ArrayList<SDVariable>(byName.values());
    }

    /**
     * Returns the names of the variables marked as losses.
     *
     * @return an unmodifiable view of the internal loss-variable list
     */
    public List<String> getLossVariables() {
        List<String> view = Collections.unmodifiableList(this.lossVariables);
        return view;
    }

    /**
     * Replaces the set of loss variables with the given names.
     * <p>
     * Each name is validated by {@link #addLossVariable(String)}. The cached gradient function
     * ({@code GRAD_FN_KEY}) is then removed, so a later training call will recreate it.
     *
     * @param lossVariableNames names of the new loss variables; must not be null
     * @throws NullPointerException if {@code lossVariableNames} is null
     */
    public void setLossVariables(String ... lossVariableNames) {
        if (lossVariableNames == null) {
            throw new NullPointerException("lossVariableNames is marked non-null but is null");
        }
        this.lossVariables.clear();
        int idx = 0;
        while (idx < lossVariableNames.length) {
            this.addLossVariable(lossVariableNames[idx]);
            idx++;
        }
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
    }

    /**
     * Convenience overload: replaces the loss variables with the given variables, by name.
     *
     * @param lossVariables the new loss variables; must not be null
     * @throws NullPointerException if {@code lossVariables} is null
     */
    public void setLossVariables(SDVariable ... lossVariables) {
        if (lossVariables == null) {
            throw new NullPointerException("lossVariables is marked non-null but is null");
        }
        String[] names = new String[lossVariables.length];
        int pos = 0;
        for (SDVariable lossVar : lossVariables) {
            names[pos++] = lossVar.name();
        }
        this.setLossVariables(names);
    }

    /**
     * Marks an existing variable as a loss to be minimized. Has no effect if it is already a loss.
     *
     * @param variableName name of the variable; must exist, have a floating-point data type, and be
     *                     of ARRAY variable type
     * @throws NullPointerException  if {@code variableName} is null
     * @throws IllegalStateException if any of the above conditions is violated
     *                               (raised via {@code Preconditions.checkState})
     */
    public void addLossVariable(@NonNull String variableName) {
        if (variableName == null) {
            throw new NullPointerException("variableName is marked non-null but is null");
        }
        Preconditions.checkState(this.hasVariable(variableName), "No variable with name \"%s\" exists", (Object)variableName);
        SDVariable candidate = this.getVariable(variableName);
        Preconditions.checkState(candidate.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized. SDVariable \"%s\" has datatype %s", (Object)variableName, (Object)candidate.dataType());
        Preconditions.checkState(candidate.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized. SDVariable \"%s\" has variable type %s", (Object)variableName, (Object)candidate.getVariableType());
        if (this.lossVariables.contains(variableName)) {
            return;
        }
        this.lossVariables.add(variableName);
    }

    /**
     * Convenience overload: marks {@code variable} as a loss, by name.
     *
     * @param variable the variable; must not be null
     * @throws NullPointerException if {@code variable} is null
     */
    public void addLossVariable(@NonNull SDVariable variable) {
        if (variable != null) {
            this.addLossVariable(variable.name());
            return;
        }
        throw new NullPointerException("variable is marked non-null but is null");
    }

    /**
     * Sets the training configuration used by {@code fit}, evaluation and updater initialization.
     *
     * @param trainingConfig the configuration; stored as-is, null is accepted
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        TrainingConfig cfg = trainingConfig;
        this.trainingConfig = cfg;
    }

    /**
     * Performs one training pass over a single {@code DataSet}, without incrementing the epoch
     * count and without validation.
     *
     * @param dataSet   the training data; must not be null
     * @param listeners additional listeners for this call; must not be null
     * @return the training history
     */
    public History fit(@NonNull DataSet dataSet, Listener ... listeners) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        return this.fit(single, 1, false, null, 1, listeners);
    }

    /**
     * Performs one training pass over a single {@code MultiDataSet}, without incrementing the epoch
     * count and without validation.
     *
     * @param dataSet   the training data; must not be null
     * @param listeners additional listeners for this call; must not be null
     * @return the training history
     */
    public History fit(@NonNull MultiDataSet dataSet, Listener ... listeners) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet);
        return this.fit(single, 1, false, null, 1, listeners);
    }

    /**
     * Trains for {@code numEpochs} epochs over {@code iter}, validating on {@code validationIter}
     * at the given frequency. Delegates to the {@link #fit()} builder.
     *
     * @param iter                training data; must not be null
     * @param numEpochs           number of epochs
     * @param validationIter      validation data
     * @param validationFrequency validation frequency, in epochs
     * @param listeners           additional listeners; must not be null
     * @return the training history
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .validate(validationIter, validationFrequency)
                .listeners(listeners)
                .exec();
    }

    /**
     * Trains for {@code numEpochs} epochs over {@code iter}, without validation. Delegates to the
     * {@link #fit()} builder.
     *
     * @param iter      training data; must not be null
     * @param numEpochs number of epochs
     * @param listeners additional listeners; must not be null
     * @return the training history
     */
    public History fit(@NonNull DataSetIterator iter, int numEpochs, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .listeners(listeners)
                .exec();
    }

    /**
     * Trains for {@code numEpochs} epochs over {@code iter}, incrementing the epoch count, and
     * validates on {@code validationIter} at the given frequency.
     *
     * @param iter                training data; must not be null
     * @param numEpochs           number of epochs
     * @param validationIter      validation data; may be null
     * @param validationFrequency validation frequency, in epochs
     * @param listeners           additional listeners; must not be null
     * @return the training history
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        final boolean incrementEpochs = true;
        return this.fit(iter, numEpochs, incrementEpochs, validationIter, validationFrequency, listeners);
    }

    /**
     * Trains for {@code numEpochs} epochs over {@code iter}, without validation. Delegates to the
     * {@link #fit()} builder.
     *
     * @param iter      training data; must not be null
     * @param numEpochs number of epochs
     * @param listeners additional listeners; must not be null
     * @return the training history
     */
    public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return this.fit()
                .train(iter, numEpochs)
                .listeners(listeners)
                .exec();
    }

    /**
     * Starts a fluent training configuration bound to this instance.
     *
     * @return a new {@code FitConfig} for this instance
     */
    public FitConfig fit() {
        FitConfig config = new FitConfig(this);
        return config;
    }

    /**
     * Wraps the training and validation iterators in {@code AsyncMultiDataSetIterator} (queue
     * size 3) when they report async support, runs {@link #fitHelper}, and always shuts down any
     * wrapper it created, even if training fails.
     *
     * @param iter                training data; must not be null
     * @param numEpochs           number of epochs
     * @param incrementEpochCount whether epoch-level bookkeeping is performed
     * @param validationData      validation data; may be null
     * @param validationFrequency validation frequency, in epochs
     * @param listeners           additional listeners; must not be null
     * @return the training history
     */
    protected synchronized History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, Listener ... listeners) {
        if (iter == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        final boolean wrapTraining = iter.asyncSupported();
        final boolean wrapValidation = validationData != null && validationData.asyncSupported();
        MultiDataSetIterator trainIter = wrapTraining ? new AsyncMultiDataSetIterator(iter, 3, true) : iter;
        MultiDataSetIterator validIter = wrapValidation ? new AsyncMultiDataSetIterator(validationData, 3, true) : validationData;
        try {
            return this.fitHelper(trainIter, numEpochs, incrementEpochCount, validIter, validationFrequency, Arrays.asList(listeners));
        } finally {
            if (wrapTraining) {
                ((AsyncMultiDataSetIterator)trainIter).shutdown();
            }
            if (wrapValidation) {
                ((AsyncMultiDataSetIterator)validIter).shutdown();
            }
        }
    }

    /*
     * WARNING - void declaration
     */
    protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners) {
        void var18_23;
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Preconditions.checkNotNull(iter, "Iterator must not be null");
        Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs);
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before training. Use setTrainingConfig(TrainingConfig)");
        Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that does not support resetting");
        HistoryListener history = new HistoryListener(this.trainingConfig);
        ArrayList<Listener> activeListeners = new ArrayList<Listener>();
        activeListeners.add(history);
        for (Listener l : this.listeners) {
            if (!l.isActive(Operation.TRAINING)) continue;
            activeListeners.add(l);
        }
        for (Listener l : listeners) {
            if (!l.isActive(Operation.TRAINING)) continue;
            activeListeners.add(l);
        }
        this.validateListenerActivations(activeListeners, Operation.TRAINING);
        this.validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION);
        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }
        boolean performedValidation = false;
        int trainThreadNum = 0;
        long jThreadId = Thread.currentThread().getId();
        boolean hasListeners = !activeListeners.isEmpty();
        At at = At.builder().epoch(this.trainingConfig.getEpochCount()).iteration(this.trainingConfig.getIterationCount()).trainingThreadNum(trainThreadNum).javaThreadNum(jThreadId).operation(Operation.TRAINING).build();
        LossCurve lossCurve = null;
        HashSet<String> requiredVars = new HashSet<String>();
        for (Listener listener : activeListeners) {
            Object s;
            ListenerVariables lv = listener.requiredVariables(this);
            if (lv == null || (s = lv.trainingVariables()) == null) continue;
            requiredVars.addAll((Collection<String>)s);
        }
        ArrayList<Listener> listenersWitHistory = new ArrayList<Listener>(listeners);
        for (Listener l : this.listeners) {
            if (listenersWitHistory.contains(l)) continue;
            listenersWitHistory.add(l);
        }
        listenersWitHistory.add(history);
        SameDiff sameDiff = this.getFunction(GRAD_FN_KEY);
        if (sameDiff == null) {
            this.createGradFunction();
            SameDiff sameDiff2 = this.getFunction(GRAD_FN_KEY);
        }
        TrainingSession ts = new TrainingSession((SameDiff)var18_23);
        var18_23.setTrainingConfig(this.trainingConfig);
        for (Listener listener : activeListeners) {
            listener.operationStart((SameDiff)var18_23, Operation.TRAINING);
        }
        LinkedHashSet<String> paramsToTrain = new LinkedHashSet<String>();
        for (Variable v : this.variables.values()) {
            if (v.getVariable().getVariableType() != VariableType.VARIABLE) continue;
            paramsToTrain.add(v.getName());
        }
        Object var21_28 = null;
        for (int i = 0; i < numEpochs; ++i) {
            if (incrementEpochCount && hasListeners) {
                at.setEpoch(this.trainingConfig.getEpochCount());
                for (Listener l : activeListeners) {
                    l.epochStart(this, at);
                }
            }
            long epochStartTime = System.currentTimeMillis();
            double[] lossSums = null;
            List<String> lossNames = null;
            int lossCount = 0;
            while (iter.hasNext()) {
                Map<String, INDArray> placeholders;
                long dataEnd;
                long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
                MultiDataSet ds = (MultiDataSet)iter.next();
                long l = dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
                if (!performedValidation) {
                    Preconditions.checkState(this.trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(), "The number of dataset feature mapping variables set in the training configuration (%s) must match the number of dataset feature arrays (%s)", this.trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays());
                    List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
                    int lblSize = labelMapping == null ? 0 : labelMapping.size();
                    Preconditions.checkState(lblSize == ds.numLabelsArrays(), "The number of dataset label mapping variables set in the training configuration (%s) must match the number of dataset label arrays (%s)", lblSize, ds.numLabelsArrays());
                    performedValidation = true;
                }
                if (hasListeners) {
                    at.setIteration(this.trainingConfig.getIterationCount());
                    for (Listener l2 : activeListeners) {
                        l2.iterationStart(this, at, ds, dataEnd - dataStart);
                    }
                }
                Preconditions.checkState((placeholders = this.toPlaceholderMap(ds)).size() > 0, "No placeholder variables were set for training");
                if (!this.initializedTraining) {
                    this.initializeTraining();
                }
                Loss loss = ts.trainingIteration(this.trainingConfig, placeholders, paramsToTrain, this.updaterMap, ds, this.getLossVariables(), listenersWitHistory, at);
                if (lossSums == null) {
                    lossSums = (double[])loss.getLosses().clone();
                } else {
                    for (int j = 0; j < lossSums.length; ++j) {
                        int n = j;
                        lossSums[n] = lossSums[n] + loss.getLosses()[j];
                    }
                }
                ++lossCount;
                this.trainingConfig.incrementIterationCount();
            }
            long epochTime = System.currentTimeMillis() - epochStartTime;
            if (incrementEpochCount) {
                void var21_29;
                lossNames = var21_29.getLossNames();
                int j = 0;
                while (j < lossSums.length) {
                    int n = j++;
                    lossSums[n] = lossSums[n] / (double)lossCount;
                }
                lossCurve = lossCurve != null ? lossCurve.addLossAndCopy(lossSums, lossNames) : new LossCurve(lossSums, lossNames);
            }
            if (incrementEpochCount) {
                if (hasListeners) {
                    boolean doStop = false;
                    Listener stopped = null;
                    for (Listener l : activeListeners) {
                        ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
                        if (res != ListenerResponse.STOP || i >= numEpochs - 1) continue;
                        doStop = true;
                        stopped = l;
                    }
                    if (doStop) {
                        log.info("Stopping training early.  Listener " + stopped + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                        for (Listener l1 : activeListeners) {
                            l1.operationEnd(this, Operation.TRAINING);
                        }
                        if (i < numEpochs - 1) {
                            iter.reset();
                        }
                        if (incrementEpochCount) {
                            this.trainingConfig.incrementEpochCount();
                        }
                        return history.getReport();
                    }
                    if (validationData != null && (validationFrequency <= 0 || i % validationFrequency == 0)) {
                        long validationStart = System.currentTimeMillis();
                        this.outputHelper(validationData, new At(at.epoch(), 0, 0, 0L, null, Operation.TRAINING_VALIDATION), listenersWitHistory, new String[0]);
                        long validationTime = System.currentTimeMillis() - validationStart;
                        boolean doStopV = false;
                        Listener stoppedV = null;
                        for (Listener l : activeListeners) {
                            ListenerResponse res = l.validationDone(this, at, validationTime);
                            if (res != ListenerResponse.STOP || i >= numEpochs - 1) continue;
                            doStopV = true;
                            stoppedV = l;
                        }
                        if (doStopV) {
                            log.info("Stopping training early from validation.  Listener " + stoppedV + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                            for (Listener l1 : activeListeners) {
                                l1.operationEnd(this, Operation.TRAINING);
                            }
                            if (i < numEpochs - 1) {
                                iter.reset();
                            }
                            if (incrementEpochCount) {
                                this.trainingConfig.incrementEpochCount();
                            }
                            return history.getReport();
                        }
                    }
                }
                this.trainingConfig.incrementEpochCount();
            }
            if (i >= numEpochs - 1) continue;
            iter.reset();
        }
        for (Listener l1 : activeListeners) {
            l1.operationEnd(this, Operation.TRAINING);
        }
        return history.getReport();
    }

    /**
     * Verifies that every variable a listener requires for {@code op} exists in this graph.
     * Listeners that return null from {@code requiredVariables} are skipped.
     *
     * @param listeners listeners to check
     * @param op        operation whose required variables are checked
     * @throws IllegalStateException if a required variable is missing
     *                               (raised via {@code Preconditions.checkState})
     */
    private void validateListenerActivations(List<Listener> listeners, Operation op) {
        for (Listener listener : listeners) {
            ListenerVariables required = listener.requiredVariables(this);
            if (required == null) {
                continue;
            }
            for (String name : required.requiredVariables(op)) {
                Preconditions.checkState(this.variables.containsKey(name), "Listener %s requested variable %s that is not defined in this SameDiff graph", (Object)listener, (Object)name);
            }
        }
    }

    /**
     * Sums the score of every configured {@code Regularization} over every trainable
     * floating-point parameter (VARIABLE type, FP data type).
     *
     * @return total regularization score; 0.0 if no regularization is configured
     * @throws IllegalStateException if no training configuration has been set
     */
    public double calcRegularizationScore() {
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
        List<Regularization> regularizations = this.trainingConfig.getRegularization();
        if (regularizations == null || regularizations.isEmpty()) {
            return 0.0;
        }
        double loss = 0.0;
        for (Variable v : this.variables.values()) {
            SDVariable sdv = v.getVariable();
            if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType()) {
                continue;
            }
            // Fetch the parameter array once per variable rather than once per regularization.
            INDArray arr = sdv.getArr();
            for (Regularization r : regularizations) {
                loss += r.score(arr, this.trainingConfig.getIterationCount(), this.trainingConfig.getEpochCount());
            }
        }
        return loss;
    }

    /**
     * Creates one {@code GradientUpdater} per trainable floating-point variable, storing them in
     * {@code updaterMap}. Runs only once; later calls return immediately.
     * <p>
     * The state-view size comes from the configured updater. No view array is allocated when that
     * size is 0.
     *
     * @throws ND4JIllegalStateException if no training configuration has been set
     */
    protected void initializeTraining() {
        if (this.initializedTraining) {
            return;
        }
        if (this.trainingConfig == null) {
            throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
        }
        this.updaterMap = new HashMap<String, GradientUpdater>();
        for (Variable entry : this.variables.values()) {
            SDVariable sdv = entry.getVariable();
            boolean trainableFloat = sdv.getVariableType() == VariableType.VARIABLE && sdv.dataType().isFPType();
            if (!trainableFloat) {
                continue;
            }
            INDArray param = sdv.getArr();
            long stateSize = this.trainingConfig.getUpdater().stateSize(param.length());
            INDArray stateView = null;
            if (stateSize != 0L) {
                stateView = Nd4j.createUninitialized(param.dataType(), 1L, stateSize);
            }
            GradientUpdater updater = this.trainingConfig.getUpdater().instantiate(stateView, false);
            updater.setStateViewArray(stateView, param.shape(), param.ordering(), true);
            this.updaterMap.put(entry.getName(), updater);
        }
        this.initializedTraining = true;
    }

    /**
     * Maps the arrays of a {@code MultiDataSet} to placeholder names, using the training
     * configuration.
     * <p>
     * Features and labels are mapped by position. For the feature- and label-mask mappings, a null
     * name skips that position; the position index still advances.
     *
     * @param ds the minibatch
     * @return a new map from placeholder name to array
     */
    private Map<String, INDArray> toPlaceholderMap(MultiDataSet ds) {
        HashMap<String, INDArray> placeholders = new HashMap<String, INDArray>();
        List<String> featureNames = this.trainingConfig.getDataSetFeatureMapping();
        for (int idx = 0; idx < featureNames.size(); idx++) {
            placeholders.put(featureNames.get(idx), ds.getFeatures(idx));
        }
        List<String> labelNames = this.trainingConfig.getDataSetLabelMapping();
        if (labelNames != null) {
            for (int idx = 0; idx < labelNames.size(); idx++) {
                placeholders.put(labelNames.get(idx), ds.getLabels(idx));
            }
        }
        List<String> featureMaskNames = this.trainingConfig.getDataSetFeatureMaskMapping();
        if (featureMaskNames != null) {
            for (int idx = 0; idx < featureMaskNames.size(); idx++) {
                String name = featureMaskNames.get(idx);
                if (name != null) {
                    placeholders.put(name, ds.getFeaturesMaskArray(idx));
                }
            }
        }
        List<String> labelMaskNames = this.trainingConfig.getDataSetLabelMaskMapping();
        if (labelMaskNames != null) {
            for (int idx = 0; idx < labelMaskNames.size(); idx++) {
                String name = labelMaskNames.get(idx);
                if (name != null) {
                    placeholders.put(name, ds.getLabelsMaskArray(idx));
                }
            }
        }
        return placeholders;
    }

    /**
     * Evaluates {@code outputVariable} over {@code iterator} with the given evaluations and listeners.
     *
     * @param iterator       evaluation data; must not be null
     * @param outputVariable name of the variable to evaluate; must not be null
     * @param listeners      listeners for this evaluation; must not be null
     * @param evaluations    evaluations to perform; must be non-null and non-empty
     * @throws NullPointerException     if any non-null argument is null
     * @throws IllegalArgumentException if no evaluations are given
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List<Listener> listeners, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        } else if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        } else if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        Preconditions.checkArgument(evaluations.length > 0, "No evaluations were passed to the evaluate method");
        Listener[] listenerArray = listeners.toArray(new Listener[0]);
        this.evaluate()
                .data(iterator)
                .evaluate(outputVariable, evaluations)
                .listeners(listenerArray)
                .exec();
    }

    /**
     * Evaluates {@code outputVariable} over {@code iterator} with the given evaluations.
     *
     * @param iterator       evaluation data; must not be null
     * @param outputVariable name of the variable to evaluate; must not be null
     * @param evaluations    evaluations to perform; must not be null
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        } else if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        } else if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        this.evaluate()
                .data(iterator)
                .evaluate(outputVariable, evaluations)
                .exec();
    }

    /**
     * Evaluates several variables, one evaluation each, against label array 0 of the data.
     *
     * @param iterator      evaluation data; must not be null
     * @param variableEvals variable name to the evaluation to apply; must not be null
     * @param listeners     listeners for this evaluation; must not be null
     */
    public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map<String, IEvaluation> variableEvals, Listener ... listeners) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        } else if (variableEvals == null) {
            throw new NullPointerException("variableEvals is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Map<String, Integer> labelIndices = new HashMap<String, Integer>();
        Map<String, List<IEvaluation>> evalLists = new HashMap<String, List<IEvaluation>>();
        for (Map.Entry<String, IEvaluation> e : variableEvals.entrySet()) {
            labelIndices.put(e.getKey(), 0);
            evalLists.put(e.getKey(), Collections.singletonList(e.getValue()));
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        this.evaluate(adapted, evalLists, labelIndices, listeners);
    }

    /**
     * Evaluates several variables, each with a list of evaluations, against label array 0 of the data.
     *
     * @param iterator      evaluation data
     * @param variableEvals variable name to the evaluations to apply
     * @param listeners     listeners for this evaluation; must not be null
     */
    public void evaluateMultiple(DataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Listener ... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Map<String, Integer> labelIndices = new HashMap<String, Integer>();
        for (String name : variableEvals.keySet()) {
            labelIndices.put(name, 0);
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        this.evaluate(adapted, variableEvals, labelIndices, listeners);
    }

    /**
     * Evaluates {@code outputVariable} against the label array at {@code labelIndex}.
     *
     * @param iterator       evaluation data; must not be null
     * @param outputVariable name of the variable to evaluate; must not be null
     * @param labelIndex     index of the label array to compare against
     * @param listeners      listeners for this evaluation; must not be null
     * @param evaluations    evaluations to perform; must be non-null and non-empty
     * @throws IllegalArgumentException if no evaluations are given
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List<Listener> listeners, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        } else if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        } else if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        } else if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        Preconditions.checkArgument(evaluations.length > 0, "No evaluations were passed to the evaluate method");
        Listener[] listenerArray = listeners.toArray(new Listener[0]);
        this.evaluate()
                .data(iterator)
                .evaluate(outputVariable, labelIndex, evaluations)
                .listeners(listenerArray)
                .exec();
    }

    /**
     * Evaluates {@code outputVariable} against the label array at {@code labelIndex}.
     *
     * @param iterator       evaluation data; must not be null
     * @param outputVariable name of the variable to evaluate; must not be null
     * @param labelIndex     index of the label array to compare against
     * @param evaluations    evaluations to perform; must not be null
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, IEvaluation ... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        } else if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        } else if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        this.evaluate()
                .data(iterator)
                .evaluate(outputVariable, labelIndex, evaluations)
                .exec();
    }

    /**
     * Evaluates the given variables using a default {@code At} for {@code Operation.EVALUATION}.
     *
     * @param iterator               evaluation data
     * @param variableEvals          variable name to the evaluations to apply
     * @param predictionLabelMapping variable name to the label-array index it is compared against
     * @param listeners              listeners for this evaluation
     */
    public void evaluate(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, Listener ... listeners) {
        At evaluationAt = At.defaultAt(Operation.EVALUATION);
        this.evaluateHelper(iterator, variableEvals, predictionLabelMapping, evaluationAt, listeners);
    }

    /**
     * Starts a fluent evaluation configuration bound to this instance.
     *
     * @return a new {@code EvaluationConfig} for this instance
     */
    public EvaluationConfig evaluate() {
        EvaluationConfig config = new EvaluationConfig(this);
        return config;
    }

    /**
     * Runs evaluation over all minibatches of {@code iterator}.
     * <p>
     * For each minibatch, the variables named by the keys of {@code variableEvals} are computed,
     * together with any evaluation variables requested by active listeners. Every evaluation
     * registered for a variable is then updated with the label array and label mask at the index
     * that {@code predictionLabelMapping} gives for that variable.
     *
     * @param iterator               data source; it is reset first if it is exhausted and supports reset
     * @param variableEvals          variable name to the evaluations updated from that variable's values
     * @param predictionLabelMapping variable name to label index; must have the same key set as variableEvals
     * @param at                     operation/iteration tracker; the iteration is incremented once per minibatch
     * @param listeners              listeners for this call, in addition to this instance's listeners
     * @throws NullPointerException  if {@code listeners} is null
     * @throws IllegalStateException if no training config is set or the two maps' key sets differ
     */
    private void evaluateHelper(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, At at, Listener ... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");
        Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());

        // Call-specific listeners first, then this instance's listeners (same order as before)
        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        this.validateListenerActivations(activeListeners, at.operation());
        for (Listener l : activeListeners) {
            l.operationStart(this, at.operation());
        }

        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        // Outputs to compute: evaluated variables plus anything listeners ask for
        Set<String> requiredVars = new HashSet<String>(variableEvals.keySet());
        for (Listener l : activeListeners) {
            ListenerVariables v = l.requiredVariables(this);
            if (v != null) {
                requiredVars.addAll(v.evaluationVariables());
            }
        }
        String[] requiredVarsArr = requiredVars.toArray(new String[0]);

        while (iterator.hasNext()) {
            MultiDataSet ds = (MultiDataSet)iterator.next();
            Map<String, INDArray> placeholderMap = this.toPlaceholderMap(ds);
            Map<String, INDArray> m = this.directExecHelper(placeholderMap, at, ds, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
            for (Map.Entry<String, List<IEvaluation>> e : variableEvals.entrySet()) {
                List<IEvaluation> evals = e.getValue();
                if (evals.isEmpty()) {
                    // Nothing to update; also avoids fetching labels that would not be used
                    continue;
                }
                INDArray prediction = m.get(e.getKey());
                // Label and mask depend only on the variable, not on the individual evaluation
                int labelIdx = predictionLabelMapping.get(e.getKey());
                INDArray label = ds.getLabels(labelIdx);
                INDArray mask = ds.getLabelsMaskArray(labelIdx);
                for (IEvaluation eval : evals) {
                    eval.eval(label, prediction, mask);
                }
            }
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, at.operation());
        }
    }

    /**
     * Computes the requested outputs for a single DataSet.
     *
     * @param dataSet input data, converted to a MultiDataSet
     * @param outputs names of the variables to compute
     * @return map from output name to computed array for this single batch
     * @throws NullPointerException if dataSet or outputs is null
     */
    public Map<String, INDArray> output(@NonNull DataSet dataSet, String ... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        MultiDataSetIterator singleBatch = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        List<Map<String, INDArray>> batches = this.outputBatches(singleBatch, outputs);
        return batches.get(0);
    }

    /**
     * Computes the requested outputs for a single MultiDataSet.
     *
     * @param dataSet input data
     * @param outputs names of the variables to compute
     * @return map from output name to computed array for this single batch
     * @throws NullPointerException if dataSet or outputs is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSet dataSet, String ... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        MultiDataSetIterator singleBatch = new SingletonMultiDataSetIterator(dataSet);
        List<Map<String, INDArray>> batches = this.outputBatches(singleBatch, outputs);
        return batches.get(0);
    }

    /**
     * Computes the requested outputs over a DataSetIterator, attaching extra listeners.
     *
     * @param iterator  input data
     * @param listeners additional listeners for this call
     * @param outputs   names of the variables to compute
     * @return the result of {@code OutputConfig.exec()} for this configuration
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, String ... outputs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        Listener[] extraListeners = listeners.toArray(new Listener[0]);
        return this.output()
                .data(iterator)
                .output(outputs)
                .listeners(extraListeners)
                .exec();
    }

    /**
     * Computes the requested outputs over a DataSetIterator.
     *
     * @param dataSet input data iterator
     * @param outputs names of the variables to compute
     * @return the result of {@code OutputConfig.exec()} for this configuration
     * @throws NullPointerException if dataSet or outputs is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator dataSet, String ... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        return this.output()
                .data(dataSet)
                .output(outputs)
                .exec();
    }

    /**
     * Computes the requested outputs separately for each batch of a DataSetIterator.
     *
     * @param iterator  input data
     * @param listeners additional listeners for this call (must not be null)
     * @param outputs   names of the variables to compute
     * @return one output map per batch
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String ... outputs) {
        Listener[] extraListeners = listeners.toArray(new Listener[0]);
        return this.output()
                .data(iterator)
                .output(outputs)
                .listeners(extraListeners)
                .execBatches();
    }

    /**
     * Computes the requested outputs separately for each batch of a DataSetIterator.
     *
     * @param iterator input data
     * @param outputs  names of the variables to compute
     * @return one output map per batch
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String ... outputs) {
        return this.output()
                .data(iterator)
                .output(outputs)
                .execBatches();
    }

    /**
     * Runs inference over every batch of a MultiDataSetIterator and stacks the per-batch results.
     *
     * @param iterator  input data
     * @param listeners additional listeners for this call
     * @param outputs   names of the variables to compute (loss variables if empty)
     * @return per-output results combined by {@code SameDiffUtils.stackOutputs}
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull List<Listener> listeners, String ... outputs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        At inferenceAt = At.defaultAt(Operation.INFERENCE);
        List<Map<String, INDArray>> perBatch = this.outputHelper(iterator, inferenceAt, listeners, outputs);
        return SameDiffUtils.stackOutputs(perBatch);
    }

    /**
     * Computes the requested outputs over a MultiDataSetIterator.
     *
     * @param dataSet input data iterator
     * @param outputs names of the variables to compute
     * @return the result of {@code OutputConfig.exec()} for this configuration
     * @throws NullPointerException if dataSet or outputs is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator dataSet, String ... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        return this.output()
                .data(dataSet)
                .output(outputs)
                .exec();
    }

    /**
     * Runs inference over every batch of a MultiDataSetIterator, returning one map per batch.
     *
     * @param iterator  input data
     * @param listeners additional listeners for this call (must not be null)
     * @param outputs   names of the variables to compute (loss variables if empty)
     * @return one output map per batch
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String ... outputs) {
        At inferenceAt = At.defaultAt(Operation.INFERENCE);
        return this.outputHelper(iterator, inferenceAt, listeners, outputs);
    }

    /**
     * Computes the requested outputs separately for each batch of a MultiDataSetIterator.
     *
     * @param iterator input data
     * @param outputs  names of the variables to compute
     * @return one output map per batch
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String ... outputs) {
        return this.output()
                .data(iterator)
                .output(outputs)
                .execBatches();
    }

    /**
     * Starts a fluent output (inference) configuration bound to this SameDiff instance.
     *
     * @return a new OutputConfig for this instance
     */
    public OutputConfig output() {
        OutputConfig config = new OutputConfig(this);
        return config;
    }

    /**
     * Runs the graph over every minibatch of {@code iterator} and collects the requested outputs.
     * <p>
     * Active listeners are the call-specific listeners followed by this instance's listeners. They
     * receive operationStart/operationEnd around the whole run, and iterationStart/iterationDone
     * around each minibatch. Variables that listeners request are passed as required activations:
     * validation variables for TRAINING_VALIDATION, inference variables otherwise.
     *
     * @param iterator  data source; it is reset first if it is exhausted and supports reset
     * @param at        operation/iteration tracker; the iteration is incremented once per minibatch
     * @param listeners listeners for this call
     * @param outputs   variables to compute; if empty, this instance's loss variables are used
     * @return one output map per minibatch, in iteration order
     * @throws NullPointerException  if listeners or outputs is null
     * @throws IllegalStateException if no training config is set
     */
    private List<Map<String, INDArray>> outputHelper(MultiDataSetIterator iterator, At at, @NonNull List<Listener> listeners, String ... outputs) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");

        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(at.operation())) {
                activeListeners.add(l);
            }
        }
        this.validateListenerActivations(activeListeners, at.operation());
        for (Listener l : activeListeners) {
            l.operationStart(this, at.operation());
        }
        boolean hasListeners = !activeListeners.isEmpty();

        List<String> neededOutputs = outputs.length != 0 ? Arrays.asList(outputs) : this.getLossVariables();
        String[] neededOutputsArr = neededOutputs.toArray(new String[0]);
        List<Map<String, INDArray>> predictions = new ArrayList<Map<String, INDArray>>();

        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        boolean validation = at.operation() == Operation.TRAINING_VALIDATION;
        Set<String> requiredVars = new HashSet<String>();
        for (Listener l : activeListeners) {
            ListenerVariables v = l.requiredVariables(this);
            // A listener may report no required variables; evaluateHelper guards the same case
            if (v == null) {
                continue;
            }
            if (validation) {
                requiredVars.addAll(v.validationVariables());
            } else {
                requiredVars.addAll(v.inferenceVariables());
            }
        }

        while (iterator.hasNext()) {
            // Data-loading time is only measured when someone will receive it
            long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
            MultiDataSet ds = (MultiDataSet)iterator.next();
            long dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
            Map<String, INDArray> placeholderMap = this.toPlaceholderMap(ds);
            for (Listener l : activeListeners) {
                l.iterationStart(this, at, ds, dataEnd - dataStart);
            }
            Map<String, INDArray> outs = this.directExecHelper(placeholderMap, at, ds, requiredVars, activeListeners, neededOutputsArr);
            for (Listener l : activeListeners) {
                l.iterationDone(this, at, ds, null);
            }
            predictions.add(outs);
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, at.operation());
        }
        return predictions;
    }

    /**
     * Starts a fluent single-batch output configuration bound to this SameDiff instance.
     *
     * @return a new BatchOutputConfig for this instance
     */
    public BatchOutputConfig batchOutput() {
        BatchOutputConfig config = new BatchOutputConfig(this);
        return config;
    }

    /**
     * Executes a batch-output configuration with {@code outputAll()} set, using the given placeholders.
     *
     * @param placeholders placeholder name to value
     * @return the result of the configuration's {@code exec()}
     */
    public Map<String, INDArray> outputAll(Map<String, INDArray> placeholders) {
        return this.batchOutput()
                .outputAll()
                .inputs(placeholders)
                .exec();
    }

    /**
     * Computes a single output variable for one batch of placeholder values.
     *
     * @param placeholders placeholder name to value
     * @param output       name of the variable to compute
     * @return the computed array
     */
    public INDArray outputSingle(Map<String, INDArray> placeholders, String output) {
        return this.batchOutput()
                .output(output)
                .inputs(placeholders)
                .execSingle();
    }

    /**
     * Computes the named outputs for one batch of placeholder values.
     *
     * @param placeholders placeholder name to value
     * @param outputs      names of the variables to compute
     * @return output name to computed array
     * @throws NullPointerException if outputs is null
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        String[] outputNames = outputs.toArray(new String[0]);
        return this.batchOutput()
                .output(outputNames)
                .inputs(placeholders)
                .output();
    }

    /**
     * Computes the named outputs for one batch of placeholder values.
     *
     * @param placeholders placeholder name to value
     * @param outputs      names of the variables to compute
     * @return output name to computed array
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, String ... outputs) {
        return this.batchOutput()
                .output(outputs)
                .inputs(placeholders)
                .output();
    }

    /**
     * Computes the named outputs for one batch of placeholder values as an INFERENCE operation.
     *
     * @param placeholders placeholder name to value
     * @param listeners    additional listeners for this call (may be null)
     * @param outputs      names of the variables to compute
     * @return output name to computed array
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<Listener> listeners, String ... outputs) {
        Operation inference = Operation.INFERENCE;
        return this.batchOutputHelper(placeholders, listeners, inference, outputs);
    }

    /**
     * Executes the graph once for the given placeholders and returns the requested outputs.
     * <p>
     * Active listeners are this instance's listeners followed by the call-specific ones. They are
     * validated before any listener is notified, so a validation failure never leaves a listener
     * with an operationStart that has no matching operationEnd.
     *
     * @param placeholders placeholder name to value
     * @param listeners    additional listeners for this call; may be null
     * @param operation    operation type reported to listeners; defaults to INFERENCE when null
     * @param outputs      names of the variables to compute
     * @return output name to computed array
     */
    protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, List<Listener> listeners, Operation operation, String ... outputs) {
        if (operation == null) {
            operation = Operation.INFERENCE;
        }
        List<Listener> activeListeners = new ArrayList<Listener>();
        for (Listener l : this.listeners) {
            if (l.isActive(operation)) {
                activeListeners.add(l);
            }
        }
        if (listeners != null) {
            for (Listener l : listeners) {
                if (l.isActive(operation)) {
                    activeListeners.add(l);
                }
            }
        }
        // Validate first (as evaluateHelper/outputHelper do), then notify listeners
        this.validateListenerActivations(activeListeners, operation);
        for (Listener l : activeListeners) {
            l.operationStart(this, operation);
        }
        Map<String, INDArray> ret = this.directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);
        for (Listener l : activeListeners) {
            l.operationEnd(this, operation);
        }
        return ret;
    }

    /**
     * Executes the graph through this thread's InferenceSession.
     * <p>
     * An InferenceSession is created (and logged) the first time a thread calls this method. If
     * {@code placeholders} is null and this instance has inputs, the placeholder values stored for
     * the current thread are used instead.
     *
     * @param placeholders        placeholder name to value, or null to use the per-thread values
     * @param at                  operation/iteration info; {@code At.defaultAt()} when null
     * @param batch               current minibatch (may be null)
     * @param requiredActivations extra variables the session should also compute
     * @param activeListeners     listeners to notify during execution
     * @param outputs             names of the variables to return; must be non-empty
     * @return output name to computed array
     * @throws IllegalStateException if no outputs are specified
     */
    protected Map<String, INDArray> directExecHelper(Map<String, INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String ... outputs) {
        At effectiveAt = at == null ? At.defaultAt() : at;
        Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");

        long threadId = Thread.currentThread().getId();
        InferenceSession session = this.sessions.get(threadId);
        if (session == null) {
            log.info("Creating new InferenceSession for thread {}", (Object)threadId);
            session = new InferenceSession(this);
            this.sessions.put(threadId, session);
        }

        Map<String, INDArray> effectivePlaceholders = placeholders;
        List<String> phNames = this.inputs();
        if (effectivePlaceholders == null && phNames != null) {
            effectivePlaceholders = this.placeholdersPerThread.get(threadId);
        }
        // outputs is guaranteed non-null by the precondition above
        return session.output(Arrays.asList(outputs), effectivePlaceholders, batch, requiredActivations, activeListeners, effectiveAt);
    }

    /**
     * Creates a constant filled with ones, using Nd4j's default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable one(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.one(name, defaultType, shape);
    }

    /**
     * Creates a constant filled with ones, using Nd4j's default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable one(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.one(name, defaultType, shape);
    }

    /**
     * Creates a constant of the given type filled with ones (int shape variant).
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable one(String name, DataType dataType, int ... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return this.one(name, dataType, longShape);
    }

    /**
     * Creates a constant of the given type filled with ones.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable one(String name, DataType dataType, long ... shape) {
        INDArray ones = Nd4j.ones(dataType, shape);
        return this.constant(name, ones);
    }

    /**
     * Creates a constant filled with zeros, using Nd4j's default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable zero(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.zero(name, defaultType, shape);
    }

    /**
     * Creates a constant filled with zeros, using Nd4j's default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable zero(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.zero(name, defaultType, shape);
    }

    /**
     * Creates a constant of the given type filled with zeros.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable zero(String name, DataType dataType, long ... shape) {
        INDArray zeros = Nd4j.zeros(dataType, shape);
        return this.constant(name, zeros);
    }

    /**
     * Creates a constant of the given type filled with zeros (int shape variant).
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable zero(String name, DataType dataType, int ... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return this.zero(name, dataType, longShape);
    }

    /**
     * Creates a constant with an automatically generated name.
     *
     * @param constant value of the constant
     * @return the new constant variable
     * @throws NullPointerException if constant is null
     */
    public SDVariable constant(@NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        String generatedName = this.getNewVarName();
        return this.constant(generatedName, constant);
    }

    /**
     * Creates a CONSTANT variable holding {@code constant}.
     * <p>
     * A null or empty name is replaced by a generated one. View arrays are duplicated outside any
     * workspace before being stored. The array is registered under the name the new SDVariable
     * reports via {@code name()}.
     *
     * @param name     requested name; must not already exist
     * @param constant value of the constant
     * @return the new constant variable
     * @throws NullPointerException  if constant is null
     * @throws IllegalStateException if a variable with {@code name} already exists
     */
    public SDVariable constant(String name, @NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        Preconditions.checkState(!this.variables.containsKey(name), "Variable with name \"%s\" already exists", (Object)name);
        String requestedName = (name == null || name.length() < 1) ? this.getNewVarName() : name;
        INDArray value = constant;
        if (value.isView()) {
            // Store a standalone copy rather than a view of some other array
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();) {
                value = value.dup();
            }
        }
        SDVariable created = new SDVariable(requestedName, VariableType.CONSTANT, this, value.shape(), value.dataType());
        String finalName = created.name();
        this.variables.put(finalName, Variable.builder().name(finalName).variable(created).build());
        this.constantArrays.setArray(finalName, value);
        return created;
    }

    /**
     * Creates a PLACEHOLDER variable with the given name, type and shape.
     *
     * @param name     variable name; must not already exist
     * @param dataType placeholder data type
     * @param shape    placeholder shape
     * @return the new placeholder variable
     * @throws NullPointerException  if name is null
     * @throws IllegalStateException if a variable with this name already exists
     */
    public SDVariable placeHolder(@NonNull String name, DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        boolean nameTaken = this.variables.containsKey(name);
        Preconditions.checkState(!nameTaken, "Variable already exists with name %s", (Object)name);
        SDVariable placeholder = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType);
        this.variables.put(name, Variable.builder().name(name).variable(placeholder).build());
        return placeholder;
    }

    /**
     * Creates a trainable VARIABLE initialized by {@code weightInitScheme}.
     *
     * @param name             variable name
     * @param weightInitScheme scheme used to create the initial array
     * @param dataType         variable data type
     * @param shape            variable shape
     * @return the new variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked non-null but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        VariableType trainable = VariableType.VARIABLE;
        return this.var(name, trainable, weightInitScheme, dataType, shape);
    }

    /**
     * Creates a variable of the given type.
     * <p>
     * An empty name is replaced by a generated one. Otherwise the name is passed through
     * {@code generateNewVarName(name, 0)}. For VARIABLE type, the initial array is created with
     * {@code weightInitScheme} outside any workspace and registered under the variable's name.
     *
     * @param name             requested variable name
     * @param variableType     type of the variable to create
     * @param weightInitScheme initializer; required when variableType is VARIABLE
     * @param dataType         variable data type
     * @param shape            variable shape; may be null, but must not contain zeros
     * @return the new variable
     * @throws NullPointerException     if name or variableType is null
     * @throws IllegalArgumentException if shape contains a zero, or the resolved name already exists
     * @throws IllegalStateException    if variableType is VARIABLE and weightInitScheme is null
     */
    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (variableType == null) {
            throw new NullPointerException("variableType is marked non-null but is null");
        }
        if (shape != null) {
            for (long l : shape) {
                Preconditions.checkArgument(l != 0L, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", (Object)shape);
            }
        }
        // name is non-null here, so only the empty case needs a generated name
        name = name.length() < 1 ? this.getNewVarName() : this.generateNewVarName(name, 0);
        if (this.variables.containsKey(name)) {
            // Only mention the name scope when one is actually active
            if (this.nameScopes.isEmpty()) {
                throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
            }
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + this.currentNameScope() + "\")");
        }
        Preconditions.checkState(variableType != VariableType.VARIABLE || weightInitScheme != null, "A weight initalization scheme must be provided when creating a VARIABLE type SDVariables - variable name: \"%s\"", (Object)name);
        SDVariable ret = new SDVariable(name, variableType, this, shape, dataType);
        this.addVariable(ret);
        if (variableType == VariableType.VARIABLE) {
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();) {
                INDArray vArr = weightInitScheme.create(dataType, shape);
                this.variablesArrays.setArray(name, vArr);
            }
        }
        return ret;
    }

    /**
     * Creates a trainable VARIABLE whose type and shape come from a shape descriptor.
     *
     * @param name             variable name
     * @param shape            descriptor supplying data type and shape
     * @param weightInitScheme scheme used to create the initial array
     * @return the new variable
     * @throws NullPointerException if name or shape is null
     */
    public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        DataType descriptorType = shape.dataType();
        long[] dims = shape.getShape();
        return this.var(name, weightInitScheme, descriptorType, dims);
    }

    /**
     * Creates a variable of the given type and shape.
     * <p>
     * If {@code Shape.isPlaceholderShape(shape)} is true, a placeholder is created instead.
     * Otherwise a VARIABLE initialized with {@code ZeroInitScheme} is created.
     *
     * @param name     variable name
     * @param dataType data type
     * @param shape    shape; must not be null
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, long ... shape) {
        // Check the array itself; checking "shape != null" would pass a never-null Boolean
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(shape)) {
            return this.placeHolder(name, dataType, shape);
        }
        return this.var(name, (WeightInitScheme)new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a VARIABLE from a shape descriptor, initialized with {@code ZeroInitScheme}.
     *
     * @param name      variable name
     * @param shapeDesc descriptor supplying data type and shape; must not be null
     * @return the new variable
     */
    public SDVariable var(String name, LongShapeDescriptor shapeDesc) {
        // Check the descriptor itself; checking "shapeDesc != null" would pass a never-null Boolean
        Preconditions.checkNotNull(shapeDesc, "Invalid shape: shape may not be null");
        return this.var(name, shapeDesc, new ZeroInitScheme());
    }

    /**
     * Creates a variable using Nd4j's default floating point type (int shape variant).
     *
     * @param name  variable name
     * @param shape shape
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, int ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.var(name, defaultType, shape);
    }

    /**
     * Creates a variable using Nd4j's default floating point type.
     *
     * @param name  variable name
     * @param shape shape
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, long ... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.var(name, defaultType, shape);
    }

    /**
     * Creates a trainable VARIABLE with Nd4j's default floating point type.
     *
     * @param name             variable name
     * @param weightInitScheme scheme used to create the initial array
     * @param shape            variable shape
     * @return the new variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, long ... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return this.var(name, weightInitScheme, defaultType, shape);
    }

    /**
     * Creates a variable of the given type and int shape.
     * <p>
     * Placeholder shapes (per {@code Shape.isPlaceholderShape}) produce a placeholder. Other shapes
     * produce a VARIABLE initialized with {@code ZeroInitScheme}.
     *
     * @param name     variable name
     * @param dataType data type
     * @param shape    shape; must not be null
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, int ... shape) {
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        boolean placeholderShape = Shape.isPlaceholderShape(shape);
        long[] longShape = ArrayUtil.toLongArray(shape);
        return placeholderShape
                ? this.placeHolder(name, dataType, longShape)
                : this.var(name, (WeightInitScheme)new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Adds a copy of {@code v} to this SameDiff instance, or returns an existing one.
     * <p>
     * If a variable with v's name already exists here and has an array, that variable is returned.
     * Otherwise a new one is created according to v's type:
     * <ul>
     * <li>VARIABLE: the array is duplicated outside any workspace;</li>
     * <li>ARRAY: no array is stored;</li>
     * <li>CONSTANT: delegates to {@code constant(name, arr)};</li>
     * <li>PLACEHOLDER: delegates to {@code placeHolder(...)}.</li>
     * </ul>
     *
     * @param v variable to import
     * @return the existing or newly created variable
     * @throws NullPointerException     if v is null
     * @throws IllegalArgumentException if v has no name
     * @throws RuntimeException         for any other variable type
     */
    public SDVariable var(@NonNull SDVariable v) {
        if (v == null) {
            throw new NullPointerException("v is marked non-null but is null");
        }
        String vName = v.name();
        if (this.variables.containsKey(vName) && this.variables.get(vName).getVariable().getArr() != null) {
            return this.variables.get(vName).getVariable();
        }
        if (vName == null || vName.length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        VariableType type = v.getVariableType();
        switch (type) {
            case VARIABLE: {
                SDVariable created = new SDVariable(vName, type, this, v.getShape(), v.dataType());
                this.addVariable(created);
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();) {
                    this.variablesArrays.setArray(vName, v.getArr().dup());
                }
                return created;
            }
            case ARRAY:
                return this.addVariable(new SDVariable(vName, type, this, v.getShape(), v.dataType()));
            case CONSTANT:
                return this.constant(vName, v.getArr());
            case PLACEHOLDER:
                return this.placeHolder(vName, v.dataType(), v.placeholderShape());
            default:
                throw new RuntimeException("Unknown/not supported variable type: " + type);
        }
    }

    /**
     * Generates a fresh variable name based on the "sd_var" prefix.
     *
     * @return the generated name
     */
    private String getNewVarName() {
        String prefix = "sd_var";
        return this.generateNewVarName(prefix, 0, false);
    }

    /**
     * Creates a variable with a generated name (int shape variant).
     *
     * @param dataType data type
     * @param shape    shape
     * @return the new variable or placeholder
     */
    public SDVariable var(DataType dataType, int ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, dataType, shape);
    }

    /**
     * Creates a variable with a generated name.
     *
     * @param dataType data type
     * @param shape    shape
     * @return the new variable or placeholder
     */
    public SDVariable var(DataType dataType, long ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, dataType, shape);
    }

    /**
     * Creates a trainable VARIABLE with a generated name.
     *
     * @param weightInitScheme scheme used to create the initial array
     * @param dataType         data type
     * @param shape            shape
     * @return the new variable
     */
    public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long ... shape) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, weightInitScheme, dataType, shape);
    }

    /**
     * Creates a VARIABLE with a generated name holding {@code arr}.
     *
     * @param arr initial value
     * @return the new variable
     */
    public SDVariable var(INDArray arr) {
        String generatedName = this.getNewVarName();
        return this.var(generatedName, arr);
    }

    /**
     * Creates a VARIABLE holding {@code arr}.
     * <p>
     * An attached array is detached. A non-attached array that is already the array of another
     * variable (checked by reference identity) is duplicated, so two variables never share one
     * array. A null or empty name is replaced by a generated one.
     *
     * @param name requested name
     * @param arr  initial value; must be floating point and non-empty
     * @return the new variable
     * @throws NullPointerException     if arr is null
     * @throws IllegalArgumentException if a variable with this name already has an array, or arr is empty
     * @throws IllegalStateException    if arr is not a floating point type
     */
    public SDVariable var(String name, @NonNull INDArray arr) {
        if (arr == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        boolean existsWithArray = this.variables.containsKey(name) && this.variables.get(name).getVariable().getArr() != null;
        if (existsWithArray) {
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
        }
        Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type: provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\nFor non floating point types, these should be created as placeholders or constants instead.", (Object)arr.dataType());
        Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", (Object)arr);
        String varName = (name == null || name.length() < 1) ? this.getNewVarName() : name;
        INDArray stored = arr;
        if (stored.isAttached()) {
            stored = stored.detach();
        } else {
            for (String existing : this.variablesArrays.arrayNames()) {
                if (this.variablesArrays.getArray(existing) == stored) {
                    stored = stored.dup();
                    break;
                }
            }
        }
        SDVariable created = new SDVariable(varName, VariableType.VARIABLE, this, stored.shape(), stored.dataType());
        this.associateArrayWithVariable(stored, created);
        this.addVariable(created);
        return created;
    }

    /**
     * Converts a single variable to CONSTANT type in place.
     *
     * @param variable variable to convert
     * @return the same {@code variable} instance, now of CONSTANT type
     * @throws NullPointerException if variable is null
     */
    public SDVariable convertToConstant(@NonNull SDVariable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        List<SDVariable> single = Collections.singletonList(variable);
        this.convertToConstants(single);
        return variable;
    }

    /**
     * Converts the given variables to CONSTANT type, in place.
     * <p>
     * Does nothing if the list is empty or every entry is already a constant. Otherwise:
     * <ul>
     * <li>cached inference sessions and the cached gradient function are discarded;</li>
     * <li>each variable's array is moved from the variable store to the constant store;</li>
     * <li>the variable's name is removed from every per-thread placeholder map.</li>
     * </ul>
     * If training has been initialized, each variable's updater is also removed (closing any
     * closeable updater state arrays), and its name is dropped from the training config's dataset
     * feature/label/mask mappings.
     *
     * @param variables variables to convert
     * @throws NullPointerException  if {@code variables} is null
     * @throws IllegalStateException if any non-constant entry is of type ARRAY
     */
    public void convertToConstants(List<SDVariable> variables) {
        // Explicit message, matching convertToVariables
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
        if (variables.isEmpty()) {
            return;
        }
        boolean allConst = true;
        for (SDVariable variable : variables) {
            if (variable.getVariableType() == VariableType.CONSTANT) continue;
            allConst = false;
            Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", (Object)variable);
        }
        if (allConst) {
            return;
        }
        // The graph's structure changes, so cached execution state is no longer valid
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
        for (SDVariable variable : variables) {
            String n = variable.name();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object)arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object)variable);
            this.constantArrays.setArray(n, arr);
            this.variablesArrays.removeArray(n);
            for (Map<String, INDArray> phMap : this.placeholdersPerThread.values()) {
                phMap.remove(n);
            }
            variable.setVariableType(VariableType.CONSTANT);
        }
        if (this.trainingConfig != null && this.initializedTraining) {
            for (SDVariable v : variables) {
                // Constants are not trained: release the updater and its state arrays
                GradientUpdater gu = this.updaterMap.remove(v.name());
                Map<String, INDArray> state = gu == null ? null : gu.getState();
                if (state != null) {
                    for (INDArray stateArr : state.values()) {
                        if (stateArr.closeable()) {
                            stateArr.close();
                        }
                    }
                }
                if (this.trainingConfig.getDataSetFeatureMapping() != null && this.trainingConfig.getDataSetFeatureMapping().contains(v.name())) {
                    ArrayList<String> newFM = new ArrayList<String>(this.trainingConfig.getDataSetFeatureMapping());
                    newFM.remove(v.name());
                    this.trainingConfig.setDataSetFeatureMapping(newFM);
                }
                if (this.trainingConfig.getDataSetLabelMapping() != null && this.trainingConfig.getDataSetLabelMapping().contains(v.name())) {
                    ArrayList<String> newLM = new ArrayList<String>(this.trainingConfig.getDataSetLabelMapping());
                    newLM.remove(v.name());
                    this.trainingConfig.setDataSetLabelMapping(newLM);
                }
                if (this.trainingConfig.getDataSetFeatureMaskMapping() != null && this.trainingConfig.getDataSetFeatureMaskMapping().contains(v.name())) {
                    ArrayList<String> newFMM = new ArrayList<String>(this.trainingConfig.getDataSetFeatureMaskMapping());
                    newFMM.remove(v.name());
                    this.trainingConfig.setDataSetFeatureMaskMapping(newFMM);
                }
                if (this.trainingConfig.getDataSetLabelMaskMapping() != null && this.trainingConfig.getDataSetLabelMaskMapping().contains(v.name())) {
                    ArrayList<String> newLMM = new ArrayList<String>(this.trainingConfig.getDataSetLabelMaskMapping());
                    newLMM.remove(v.name());
                    this.trainingConfig.setDataSetLabelMaskMapping(newLMM);
                }
            }
        }
    }

    /**
     * Converts a single floating point constant (or placeholder) to VARIABLE type in place.
     *
     * @param constant variable to convert; must be floating point
     * @return the same {@code constant} instance, now of VARIABLE type
     * @throws NullPointerException  if constant is null
     * @throws IllegalStateException if constant is not a floating point type
     */
    public SDVariable convertToVariable(@NonNull SDVariable constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        boolean floatingPoint = constant.dataType().isFPType();
        Preconditions.checkState(floatingPoint, "Only floating point SDVariables can be converted to variables, datatype of %s is %s", (Object)constant.name(), (Object)constant.dataType());
        List<SDVariable> single = Collections.singletonList(constant);
        this.convertToVariables(single);
        return constant;
    }

    /**
     * Converts the given constants/placeholders to trainable VARIABLE type, in place.
     * <p>
     * Does nothing if the list is empty or every entry is already a VARIABLE. Otherwise:
     * <ul>
     * <li>cached inference sessions and the cached gradient function are discarded;</li>
     * <li>each array is moved from the constant store to the variable store;</li>
     * <li>the name is removed from every per-thread placeholder map.</li>
     * </ul>
     * If training has been initialized, an updater is created for each variable that does not
     * already have one.
     * <p>
     * Entries whose type changes must be floating point, the same requirement
     * {@code convertToVariable} and {@code var(String, INDArray)} enforce.
     *
     * @param constants variables to convert
     * @throws NullPointerException  if constants is null
     * @throws IllegalStateException if an entry is of type ARRAY, or a non-VARIABLE entry is not floating point
     */
    public void convertToVariables(@NonNull List<SDVariable> constants) {
        if (constants == null) {
            throw new NullPointerException("constants is marked non-null but is null");
        }
        if (constants.isEmpty()) {
            return;
        }
        boolean allAlreadyVariables = true;
        for (SDVariable variable : constants) {
            Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", (Object)variable);
            if (variable.getVariableType() != VariableType.VARIABLE) {
                allAlreadyVariables = false;
                // Same requirement as convertToVariable: only FP arrays can be trained
                Preconditions.checkState(variable.dataType().isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", (Object)variable.name(), (Object)variable.dataType());
            }
        }
        if (allAlreadyVariables) {
            return;
        }
        // The graph's structure changes, so cached execution state is no longer valid
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove(GRAD_FN_KEY);
        for (SDVariable variable : constants) {
            String n = variable.name();
            INDArray arr = variable.getArr();
            Preconditions.checkNotNull((Object)arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", (Object)variable);
            this.variablesArrays.setArray(n, arr);
            this.constantArrays.removeArray(n);
            for (Map<String, INDArray> m : this.placeholdersPerThread.values()) {
                m.remove(n);
            }
            variable.setVariableType(VariableType.VARIABLE);
        }
        if (this.trainingConfig != null && this.initializedTraining) {
            for (SDVariable v : constants) {
                if (this.updaterMap.containsKey(v.name())) continue;
                INDArray arr = v.getArr();
                long thisSize = this.trainingConfig.getUpdater().stateSize(arr.length());
                if (thisSize > 0L) {
                    // Updater needs state: allocate a [1, stateSize] view array for it
                    INDArray stateArr = Nd4j.create(arr.dataType(), 1L, thisSize);
                    GradientUpdater u = this.trainingConfig.getUpdater().instantiate(stateArr, false);
                    u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true);
                    this.updaterMap.put(v.name(), u);
                } else {
                    GradientUpdater u = this.trainingConfig.getUpdater().instantiate((INDArray)null, true);
                    this.updaterMap.put(v.name(), u);
                }
            }
        }
    }

    /**
     * Converts the datatypes of the named variables.
     * <p>
     * Each entry is validated before anything is modified. A variable whose type already matches is skipped.
     * For changed VARIABLE and CONSTANT variables, the stored array is cast to the new type. For changed
     * PLACEHOLDER variables, only the calling thread's placeholder value (if present) is cast. If anything
     * changed, cached sessions are cleared. The new datatypes are then propagated forward: every op that
     * consumes an affected variable has its output datatypes recalculated, transitively.
     *
     * @param dataTypeMap map from variable name to the datatype it should be converted to
     * @throws NullPointerException  if {@code dataTypeMap} is null
     * @throws IllegalStateException if a name does not exist, refers to an ARRAY type variable, or a
     *                               non-placeholder conversion would switch between numerical and
     *                               non-numerical types
     */
    public void convertDataTypes(@NonNull Map<String, DataType> dataTypeMap) {
        if (dataTypeMap == null) {
            throw new NullPointerException("dataTypeMap is marked non-null but is null");
        }
        if (dataTypeMap.isEmpty()) {
            return;
        }
        // Validate every entry first so an invalid request cannot leave the graph half-converted
        for (Map.Entry<String, DataType> entry : dataTypeMap.entrySet()) {
            String name = entry.getKey();
            Preconditions.checkState(this.variables.containsKey(name), "Cannot change datatype of variable \"%s\": No variable with this name exists", (Object)name);
            SDVariable v = this.variables.get(name).getVariable();
            // The message has a %s placeholder; pass the name so it is actually filled in
            Preconditions.checkState(v.getVariableType() != VariableType.ARRAY, "Cannot change datatype of ARRAY type variable \"%s\": datatype of ARRAY type variables is determined by the datatypes of their inputs plus corresponding ", (Object)name);
            if (v.getVariableType() == VariableType.PLACEHOLDER) {
                // Placeholders skip the numerical/non-numerical compatibility check
                continue;
            }
            Preconditions.checkState(v.dataType().isNumerical() == entry.getValue().isNumerical(), "Cannot convert variables between numerical and non-numerical types: attempting to convert variable \"%s\" from %s to %s", (Object)name, (Object)v.dataType(), (Object)entry.getValue());
        }

        boolean anyChanged = false;
        for (Map.Entry<String, DataType> entry : dataTypeMap.entrySet()) {
            String name = entry.getKey();
            DataType newType = entry.getValue();
            SDVariable v = this.variables.get(name).getVariable();
            if (v.dataType() == newType) {
                continue;
            }
            v.setDataType(newType);
            switch (v.getVariableType()) {
                case VARIABLE: {
                    INDArray arr = this.variablesArrays.removeArray(name);
                    this.variablesArrays.setArray(name, arr.castTo(newType));
                    break;
                }
                case CONSTANT: {
                    INDArray arr = this.constantArrays.removeArray(name);
                    this.constantArrays.setArray(name, arr.castTo(newType));
                    break;
                }
                case PLACEHOLDER: {
                    // Only the current thread's placeholder map is consulted here
                    Map<String, INDArray> m = this.placeholdersPerThread.get(Thread.currentThread().getId());
                    if (m != null && m.containsKey(name)) {
                        m.put(name, m.get(name).castTo(newType));
                    }
                    break;
                }
                default:
                    throw new IllegalStateException("Cannot convert array type variable");
            }
            anyChanged = true;
        }
        if (!anyChanged) {
            return;
        }

        this.sessions.clear();

        // Seed the work queue with every op that consumes one of the converted variables
        Set<String> seenOps = new HashSet<String>();
        LinkedList<String> opQueue = new LinkedList<String>();
        for (String name : dataTypeMap.keySet()) {
            List<String> consumers = this.variables.get(name).getInputsForOp();
            if (consumers == null) {
                continue;
            }
            for (String opName : consumers) {
                if (seenOps.add(opName)) {
                    opQueue.add(opName);
                }
            }
        }

        // Breadth-first propagation: recompute each op's output types, then enqueue that op's consumers
        while (!opQueue.isEmpty()) {
            String opName = opQueue.remove();
            SameDiffOp o = this.ops.get(opName);
            List<String> inVars = o.getInputsToOp();
            ArrayList<DataType> inDTypes = new ArrayList<DataType>();
            if (inVars != null) {
                for (String in : inVars) {
                    inDTypes.add(this.variables.get(in).getVariable().dataType());
                }
            }
            List<DataType> outDtypes = o.getOp().calculateOutputDataTypes(inDTypes);
            List<String> outVars = o.getOutputsOfOp();
            for (int i = 0; i < outVars.size(); ++i) {
                Variable var = this.variables.get(outVars.get(i));
                var.getVariable().setDataType(outDtypes.get(i));
                if (var.getInputsForOp() == null) {
                    continue;
                }
                for (String next : var.getInputsForOp()) {
                    if (seenOps.add(next)) {
                        opQueue.add(next);
                    }
                }
            }
        }
    }

    /**
     * Renames a variable and rewrites every reference to the old name that this instance tracks.
     * <p>
     * The following are updated:
     * <ul>
     *   <li>op inputs, op control dependencies and op outputs;</li>
     *   <li>variable control-dependency links;</li>
     *   <li>stored constant, variable and per-thread placeholder arrays;</li>
     *   <li>the training configuration's dataset/loss mappings;</li>
     *   <li>nested SameDiff function instances that contain a variable with the old name;</li>
     *   <li>the list of loss variables.</li>
     * </ul>
     * All occurrences of {@code from} in each list are replaced, not just the first.
     *
     * @param from existing variable name
     * @param to   new variable name; must not already exist
     * @throws IllegalStateException if {@code from} does not exist or {@code to} already exists
     */
    public void renameVariable(String from, String to) {
        Preconditions.checkState(this.variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", (Object)from);
        Preconditions.checkState(!this.variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", (Object)from, (Object)to, (Object)to);
        Variable v = this.variables.get(from);
        v.setName(to);
        v.getVariable().setVarName(to);

        // Ops that consume this variable as an input
        if (v.getInputsForOp() != null) {
            for (String opName : v.getInputsForOp()) {
                SameDiffOp op = this.ops.get(opName);
                ArrayList<String> newInputs = new ArrayList<String>(op.getInputsToOp());
                Collections.replaceAll(newInputs, from, to);
                op.setInputsToOp(newInputs);
            }
        }
        // Ops that list this variable as a control dependency
        if (v.getControlDepsForOp() != null) {
            for (String opName : v.getControlDepsForOp()) {
                SameDiffOp op = this.ops.get(opName);
                ArrayList<String> newCDs = new ArrayList<String>(op.getControlDeps());
                Collections.replaceAll(newCDs, from, to);
                op.setControlDeps(newCDs);
            }
        }
        // Variables that list this variable as a control dependency
        if (v.getControlDepsForVar() != null) {
            for (String varName : v.getControlDepsForVar()) {
                Variable other = this.variables.get(varName);
                ArrayList<String> newCDs = new ArrayList<String>(other.getControlDeps());
                Collections.replaceAll(newCDs, from, to);
                other.setControlDeps(newCDs);
            }
        }
        // Reverse links held by the variables named in this variable's control dependencies
        if (v.getControlDeps() != null) {
            for (String varName : v.getControlDeps()) {
                Variable other = this.variables.get(varName);
                ArrayList<String> newCDsFor = new ArrayList<String>(other.getControlDepsForVar());
                Collections.replaceAll(newCDsFor, from, to);
                other.setControlDepsForVar(newCDsFor);
            }
        }
        // The op that produces this variable
        if (v.getOutputOfOp() != null) {
            SameDiffOp producer = this.ops.get(v.getOutputOfOp());
            ArrayList<String> newOutputs = new ArrayList<String>(producer.getOutputsOfOp());
            Collections.replaceAll(newOutputs, from, to);
            producer.setOutputsOfOp(newOutputs);
        }

        this.variables.remove(from);
        this.variables.put(to, v);

        VariableType type = v.getVariable().getVariableType();
        if (type == VariableType.CONSTANT && this.constantArrays.hasArray(from)) {
            this.constantArrays.rename(from, to);
        }
        if (type == VariableType.VARIABLE && this.variablesArrays.hasArray(from)) {
            this.variablesArrays.rename(from, to);
        }
        if (type == VariableType.PLACEHOLDER) {
            for (Map<String, INDArray> m : this.placeholdersPerThread.values()) {
                if (m == null || !m.containsKey(from)) {
                    continue;
                }
                m.put(to, m.remove(from));
            }
        }

        if (this.trainingConfig != null) {
            if (this.trainingConfig.getDataSetFeatureMapping() != null && this.trainingConfig.getDataSetFeatureMapping().contains(from)) {
                ArrayList<String> l = new ArrayList<String>(this.trainingConfig.getDataSetFeatureMapping());
                Collections.replaceAll(l, from, to);
                this.trainingConfig.setDataSetFeatureMapping(l);
            }
            if (this.trainingConfig.getDataSetLabelMapping() != null && this.trainingConfig.getDataSetLabelMapping().contains(from)) {
                ArrayList<String> l = new ArrayList<String>(this.trainingConfig.getDataSetLabelMapping());
                Collections.replaceAll(l, from, to);
                this.trainingConfig.setDataSetLabelMapping(l);
            }
            if (this.trainingConfig.getDataSetFeatureMaskMapping() != null && this.trainingConfig.getDataSetFeatureMaskMapping().contains(from)) {
                ArrayList<String> l = new ArrayList<String>(this.trainingConfig.getDataSetFeatureMaskMapping());
                Collections.replaceAll(l, from, to);
                this.trainingConfig.setDataSetFeatureMaskMapping(l);
            }
            if (this.trainingConfig.getDataSetLabelMaskMapping() != null && this.trainingConfig.getDataSetLabelMaskMapping().contains(from)) {
                ArrayList<String> l = new ArrayList<String>(this.trainingConfig.getDataSetLabelMaskMapping());
                Collections.replaceAll(l, from, to);
                this.trainingConfig.setDataSetLabelMaskMapping(l);
            }
            if (this.trainingConfig.getLossVariables() != null && this.trainingConfig.getLossVariables().contains(from)) {
                ArrayList<String> l = new ArrayList<String>(this.trainingConfig.getLossVariables());
                Collections.replaceAll(l, from, to);
                this.trainingConfig.setLossVariables(l);
            }
        }

        for (SameDiff sameDiff : this.sameDiffFunctionInstances.values()) {
            if (sameDiff.hasVariable(from)) {
                sameDiff.renameVariable(from, to);
            }
        }

        // Replace every occurrence, consistent with the lists above. Replacing only the first
        // occurrence would leave a stale name that no longer exists in the graph.
        Collections.replaceAll(this.lossVariables, from, to);
    }

    /**
     * Removes a variable from the inputs of an op.
     * <p>
     * If {@code varName} is one of the op's arguments, every occurrence of it is removed from the op's
     * recorded input list. The op's name is then removed from the variable's list of consuming ops.
     *
     * @param varName  name of the variable to detach
     * @param function the op to detach it from
     */
    public void removeArgFromOp(String varName, DifferentialFunction function) {
        String opName = function.getOwnName();
        for (SDVariable arg : function.args()) {
            if (!arg.name().equals(varName)) {
                continue;
            }
            SameDiffOp op = this.ops.get(opName);
            // Iterate the recorded inputs themselves rather than indexing them by args.length,
            // so the two lists can never get out of step
            ArrayList<String> newArgs = new ArrayList<String>();
            for (String input : op.getInputsToOp()) {
                if (!input.equals(varName)) {
                    newArgs.add(input);
                }
            }
            op.setInputsToOp(newArgs);
            break;
        }
        List<String> consumers = this.variables.get(varName).getInputsForOp();
        if (consumers != null) {
            consumers.remove(opName);
        }
    }

    /**
     * Looks up a variable by name.
     *
     * @param name variable name
     * @return the variable, or {@code null} if no variable with that name exists
     */
    public SDVariable getVariable(String name) {
        Variable entry = this.variables.get(name);
        if (entry == null) {
            return null;
        }
        return entry.getVariable();
    }

    /**
     * Checks whether a variable with the given name is registered in this instance.
     *
     * @param name variable name
     * @return true if the name is present
     */
    public boolean hasVariable(String name) {
        final boolean present = this.variables.containsKey(name);
        return present;
    }

    /**
     * Returns the gradient variable for the named variable.
     * <p>
     * The gradient recorded on this instance's variable is returned if set. Otherwise, the gradient is
     * looked up in the gradient function, if one exists and contains the variable.
     *
     * @param varName variable name
     * @return the gradient variable, or {@code null} if none is defined
     * @throws IllegalStateException if the variable does not exist or is not a floating point type
     */
    public SDVariable getGradForVariable(String varName) {
        Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists", (Object)varName);
        SDVariable v = this.getVariable(varName);
        // Arguments ordered to match the message: datatype first, then the variable name
        Preconditions.checkState(v.dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating point variables have gradients", (Object)v.dataType(), (Object)varName);
        SDVariable grad = this.variables.get(varName).getGradient();
        if (grad != null) {
            return grad;
        }
        SameDiff gradFn = this.sameDiffFunctionInstances.get(GRAD_FN_KEY);
        if (gradFn != null && gradFn.variables.containsKey(varName)) {
            return gradFn.variables.get(varName).getGradient();
        }
        return null;
    }

    /**
     * Determines whether the named variable has a gradient.
     * <p>
     * Only non-constant floating point variables are eligible. For those, a gradient must be returned by
     * {@link #getGradForVariable(String)}.
     *
     * @param varName variable name
     * @return true if a gradient variable exists for it
     * @throws IllegalStateException if the variable does not exist
     */
    public boolean variableHasGradient(String varName) {
        Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists", (Object)varName);
        SDVariable candidate = this.getVariable(varName);
        boolean eligible = candidate.dataType().isFPType() && !candidate.isConstant();
        return eligible && this.getGradForVariable(varName) != null;
    }

    /**
     * Records {@code variable} as the gradient of the named variable.
     *
     * @param variableName name of the variable whose gradient is being set
     * @param variable     the gradient variable; must not be null
     * @throws IllegalStateException      if no variable with that name exists
     * @throws ND4JIllegalStateException  if {@code variable} is null
     */
    public void setGradientForVariableName(String variableName, SDVariable variable) {
        Preconditions.checkState(this.variables.containsKey(variableName), "No variable exists with name \"%s\"", (Object)variableName);
        if (variable != null) {
            this.variables.get(variableName).setGradient(variable);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + variableName);
    }

    /**
     * Returns the gradient for the named variable, as defined in the gradient function.
     * <p>
     * The gradient function is created first if it does not yet exist.
     *
     * @param varName variable name
     * @return the gradient variable from the gradient function, or {@code null} if none is defined
     * @throws IllegalStateException if the gradient function has no variable with that name
     */
    public SDVariable grad(String varName) {
        if (!this.sameDiffFunctionInstances.containsKey(GRAD_FN_KEY)) {
            this.createGradFunction();
        }
        SameDiff gradFn = this.getFunction(GRAD_FN_KEY);
        SDVariable var = gradFn.getVariable(varName);
        // Fail with a descriptive message instead of a bare NullPointerException on var.name()
        Preconditions.checkState(var != null, "Cannot get gradient for variable \"%s\": no variable with this name exists in the gradient function", (Object)varName);
        return gradFn.getGradForVariable(var.name());
    }

    /**
     * Creates a variable holding a scalar array with the given double value.
     * <p>
     * The array is created while workspaces are scoped out via
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, presumably so it is not attached to a
     * currently open workspace.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the new variable
     */
    public SDVariable scalar(String name, double value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable holding a scalar array with the given float value.
     * The array is created with workspaces scoped out.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the new variable
     */
    public SDVariable scalar(String name, float value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable holding a scalar array with the given int value.
     * The array is created with workspaces scoped out.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the new variable
     */
    public SDVariable scalar(String name, int value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable holding a scalar array with the given long value.
     * The array is created with workspaces scoped out.
     *
     * @param name  variable name
     * @param value scalar value
     * @return the new variable
     */
    public SDVariable scalar(String name, long value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a variable holding a scalar array of the given datatype and value.
     * The array is created with workspaces scoped out.
     *
     * @param name     variable name
     * @param dataType datatype of the scalar array
     * @param value    scalar value
     * @return the new variable
     */
    public SDVariable scalar(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.var(name, Nd4j.scalar(dataType, value));
        }
    }

    /**
     * Creates an unnamed scalar constant with the given double value.
     *
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(double value) {
        String noName = null;
        return this.constant(noName, value);
    }

    /**
     * Creates a constant holding a scalar array with the given double value.
     * The array is created with workspaces scoped out.
     *
     * @param name  constant name (may be null)
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(String name, double value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates an unnamed scalar constant with the given float value.
     *
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(float value) {
        String noName = null;
        return this.constant(noName, value);
    }

    /**
     * Creates a constant holding a scalar array with the given float value.
     * The array is created with workspaces scoped out.
     *
     * @param name  constant name (may be null)
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(String name, float value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates an unnamed scalar constant with the given int value.
     *
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(int value) {
        String noName = null;
        return this.constant(noName, value);
    }

    /**
     * Creates a constant holding a scalar array with the given int value.
     * The array is created with workspaces scoped out.
     *
     * @param name  constant name (may be null)
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(String name, int value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates an unnamed scalar constant with the given long value.
     *
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(long value) {
        String noName = null;
        return this.constant(noName, value);
    }

    /**
     * Creates a constant holding a scalar array with the given long value.
     * The array is created with workspaces scoped out.
     *
     * @param name  constant name (may be null)
     * @param value scalar value
     * @return the new constant
     */
    public SDVariable constant(String name, long value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(value));
        }
    }

    /**
     * Creates a constant holding a scalar array of the given datatype and value.
     * The array is created with workspaces scoped out.
     *
     * @param name     constant name (may be null)
     * @param dataType datatype of the scalar array
     * @param value    scalar value
     * @return the new constant
     */
    public SDVariable constant(String name, DataType dataType, Number value) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return this.constant(name, Nd4j.scalar(dataType, value));
        }
    }

    /**
     * Registers a variable with this instance under its own name.
     * <p>
     * Re-adding an equal variable is allowed. Note that it replaces the stored {@code Variable} entry with a
     * freshly built one that holds only the name and the variable.
     *
     * @param variable variable to add; must belong to this SameDiff instance
     * @return the same variable
     * @throws IllegalStateException    if the variable belongs to a different SameDiff instance
     * @throws IllegalArgumentException if a different variable with the same name already exists
     */
    public SDVariable addVariable(SDVariable variable) {
        // A single ownership check suffices; the original repeated an identical check further down
        Preconditions.checkState(variable.getSameDiff() == this, "Samediff instance must be the same.");
        if (this.variables.containsKey(variable.name()) && !this.variables.get(variable.name()).getVariable().equals(variable)) {
            throw new IllegalArgumentException("Variable with name \"" + variable.name() + "\" already exists");
        }
        this.variables.put(variable.name(), Variable.builder().name(variable.name()).variable(variable).build());
        return variable;
    }

    /**
     * Creates, or reuses, the output variables of an op and sets the op as their creator.
     * <p>
     * Existing variables named {@code baseName} (output 0) or {@code baseName:i} (output i) are reused.
     * Otherwise new ARRAY variables are created, with names from {@code generateNewVarName} for custom ops,
     * or {@code baseName} for base ops. When not importing, output datatypes come from the op's
     * {@code calculateOutputDataTypes}. When importing, new outputs are created without a datatype. If the
     * op has no outputs recorded yet, the outputs are registered via {@code addOutgoingFor}.
     *
     * @param function the op
     * @param baseName base name for outputs; if null, the op's own name is used, then its op name
     * @param isImport true during graph import, in which case output datatypes are not inferred
     * @return the op's output variables
     * @throws ND4UnresolvedOutputVariables if the number of outputs of a custom op cannot be determined
     * @throws RuntimeException            if the op is neither a CustomOp nor a BaseOp
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport) {
        if (baseName == null) {
            baseName = function.getOwnName();
        }
        if (baseName == null) {
            baseName = function.opName();
        }
        List<DataType> outputDataTypes = null;
        if (!isImport) {
            ArrayList<DataType> inputDataTypes = new ArrayList<DataType>();
            List<String> fnInputs = this.ops.get(function.getOwnName()).getInputsToOp();
            if (fnInputs != null) {
                for (String var : fnInputs) {
                    inputDataTypes.add(this.variables.get(var).getVariable().dataType());
                }
            }
            outputDataTypes = function.calculateOutputDataTypes(inputDataTypes);
        }

        if (function instanceof CustomOp) {
            CustomOp customOp = (CustomOp) function;
            int numOutputs = function.getNumOutputs();
            if (numOutputs <= 0) {
                CustomOpDescriptor descriptor = customOp.getDescriptor();
                if (descriptor != null) {
                    numOutputs = descriptor.getNumOutputs();
                }
                if (numOutputs <= 0) {
                    throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override getNumOutputs() to specify number of outputs if required");
                }
            }
            SDVariable[] ret = new SDVariable[numOutputs];
            Preconditions.checkState(isImport || outputDataTypes != null && outputDataTypes.size() == numOutputs, "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)", (Object)(outputDataTypes == null ? null : Integer.valueOf(outputDataTypes.size())), (Object)numOutputs, outputDataTypes, (Object)function.getClass().getSimpleName());
            for (int i = 0; i < numOutputs; ++i) {
                SDVariable var = i == 0 ? this.getVariable(baseName) : this.getVariable(baseName + ":" + i);
                if (var == null) {
                    DataType dataType = isImport ? null : outputDataTypes.get(i);
                    var = this.var(this.generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[])null);
                }
                var.setCreator(function);
                ret[i] = var;
            }
            if (this.getOutputsForOp(function) == null) {
                this.addOutgoingFor(ret, function);
            }
            return ret;
        }

        if (function instanceof BaseOp) {
            SDVariable out = this.getVariable(baseName);
            if (out == null) {
                // outputDataTypes is null on import; previously .get(0) was called unconditionally,
                // which threw a NullPointerException (the CustomOp branch already guarded this)
                DataType dataType = isImport ? null : outputDataTypes.get(0);
                out = this.var(baseName, VariableType.ARRAY, null, dataType, (long[])null);
            }
            out.setCreator(function);
            SDVariable[] ret = new SDVariable[]{out};
            if (this.getOutputsForOp(function) == null) {
                this.addOutgoingFor(ret, function);
            }
            return ret;
        }
        throw new RuntimeException("Unknown op type: " + function.getClass());
    }

    /**
     * Creates or reuses the output variables of an op outside of import.
     * The base name is the op's own name, falling back to its op name.
     *
     * @param function the op
     * @return the op's output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
        String base = function.getOwnName();
        if (base == null) {
            base = function.opName();
        }
        return this.generateOutputVariableForOp(function, base, false);
    }

    /**
     * Looks up a previously defined function instance by name.
     *
     * @param functionName function name
     * @return the function's SameDiff instance, or {@code null} if none is defined under that name
     */
    public SameDiff getFunction(String functionName) {
        SameDiff instance = this.sameDiffFunctionInstances.get(functionName);
        return instance;
    }

    /**
     * Creates a new TensorArray bound to this instance.
     *
     * @param dataType datatype of the tensor array
     * @return the new TensorArray
     */
    public TensorArray tensorArray(DataType dataType) {
        TensorArray array = new TensorArray(this, dataType);
        // Result intentionally unused. The call is kept because outputVariables() may have side effects
        // (presumably creating the op's output variables) — TODO confirm before removing.
        array.outputVariables();
        return array;
    }

    /**
     * Invokes the graph of a named function instance on another SameDiff instance.
     *
     * @param functionName name of a function previously defined on this instance
     * @param with         target SameDiff instance
     * @return the result of {@code invokeGraphOn}
     */
    public SDVariable invokeFunctionOn(String functionName, SameDiff with) {
        return this.sameDiffFunctionInstances.get(functionName).invokeGraphOn(with);
    }

    /**
     * Defines a named sub-function, unless one with that name already exists.
     * <p>
     * A new SameDiff instance is created as this instance's child, with this instance as its parent. A
     * variable is created in it for each of {@code variables}, and {@code functionDefinition} is run against
     * it. {@code this.child} is always cleared before returning, including when the definition throws.
     *
     * @param function           function name
     * @param functionDefinition callback that builds the function's graph
     * @param variables          variables to copy into the sub-function as its inputs
     * @return the (new or existing) function instance
     */
    public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) {
        try {
            if (!this.sameDiffFunctionInstances.containsKey(function)) {
                SameDiff sub = SameDiff.create();
                this.child = sub;
                sub.parent = this;
                SDVariable[] ret = new SDVariable[variables.length];
                for (int i = 0; i < ret.length; ++i) {
                    ret[i] = sub.var(variables[i]);
                }
                functionDefinition.define(sub, null, ret);
                this.sameDiffFunctionInstances.put(function, sub);
            }
        } finally {
            // Previously skipped if define(...) threw, leaving a stale child reference behind
            this.child = null;
        }
        return this.sameDiffFunctionInstances.get(function);
    }

    /**
     * Defines a named sub-function with an empty input map.
     *
     * @param function           function name
     * @param functionDefinition callback that builds the function's graph
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition) {
        Map<String, INDArray> noInputs = new LinkedHashMap<String, INDArray>();
        this.defineFunction(function, functionDefinition, noInputs);
    }

    /**
     * Defines a named sub-function from an input map. Does nothing if a function with that name already
     * exists.
     *
     * @param function           function name
     * @param functionDefinition callback that builds the function's graph
     * @param inputs             input arrays passed to the definition
     */
    public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String, INDArray> inputs) {
        if (this.sameDiffFunctionInstances.containsKey(function)) {
            return;
        }
        SameDiff created = SameDiff.create();
        functionDefinition.define(created, inputs, null);
        this.sameDiffFunctionInstances.put(function, created);
    }

    /**
     * Calculates gradients for the named variables.
     *
     * @param placeholderVals placeholder values to feed
     * @param variables       names of the variables to calculate gradients for; at least one is required
     * @return map from variable name to gradient array
     * @throws NullPointerException     if {@code variables} is null
     * @throws IllegalArgumentException if no variable names are given
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, String ... variables) {
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
        Preconditions.checkArgument(variables.length > 0, "No variables were specified");
        List<String> names = Arrays.asList(variables);
        return this.calculateGradients(placeholderVals, names);
    }

    /**
     * Calculates gradients for the named variables. Only the gradients part of
     * {@link #calculateGradientsAndOutputs} is returned.
     *
     * @param placeholderVals placeholder values to feed
     * @param variables       names of the variables to calculate gradients for; must be non-empty
     * @return map from variable name to gradient array
     * @throws NullPointerException     if {@code variables} is null
     * @throws IllegalArgumentException if {@code variables} is empty
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
        if (variables == null) {
            throw new NullPointerException("variables is marked non-null but is null");
        }
        Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
        return this.calculateGradientsAndOutputs(placeholderVals, null, variables).getGradients();
    }

    /**
     * Executes the gradient function once, in TRAINING mode, to obtain outputs and gradients.
     * <p>
     * The gradient function is created if it does not exist, and this instance's listeners are set on it.
     * Variables that have no gradient variable are omitted from the gradient map.
     *
     * @param placeholderVals placeholder values to feed
     * @param outputVars      names of variables whose values should be returned (may be null)
     * @param gradientVars    names of variables whose gradients should be returned (may be null)
     * @return outputs and gradients; each map is null when the corresponding argument was null
     * @throws IllegalArgumentException if neither collection contains any names
     * @throws IllegalStateException    if a gradient variable name does not exist
     */
    public OutAndGrad calculateGradientsAndOutputs(Map<String, INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars) {
        boolean wantsOutputs = outputVars != null && !outputVars.isEmpty();
        boolean wantsGradients = gradientVars != null && !gradientVars.isEmpty();
        Preconditions.checkArgument(wantsOutputs || wantsGradients, "No variables were specified for either output or gradients");
        if (this.getFunction(GRAD_FN_KEY) == null) {
            this.createGradFunction();
        }

        // Names to request from the gradient function: the raw outputs plus each gradient variable's name
        List<String> requested = new ArrayList<String>();
        if (outputVars != null) {
            requested.addAll(outputVars);
        }
        if (gradientVars != null) {
            for (String name : gradientVars) {
                Preconditions.checkState(this.variables.containsKey(name), "No variable with name \"%s\" exists in the SameDiff instance", (Object)name);
                SDVariable gradVar = this.getVariable(name).getGradient();
                if (gradVar != null) {
                    requested.add(gradVar.name());
                }
            }
        }

        SameDiff gradFn = this.getFunction(GRAD_FN_KEY);
        gradFn.setListeners(this.listeners);
        Map<String, INDArray> results = gradFn.batchOutputHelper(placeholderVals, null, Operation.TRAINING, requested.toArray(new String[0]));

        HashMap<String, INDArray> outputs = null;
        if (outputVars != null) {
            outputs = new HashMap<String, INDArray>();
            for (String name : outputVars) {
                outputs.put(name, results.get(name));
            }
        }
        HashMap<String, INDArray> gradients = null;
        if (gradientVars != null) {
            gradients = new HashMap<String, INDArray>();
            for (String name : gradientVars) {
                SDVariable gradVar = this.getVariable(name).getGradient();
                if (gradVar != null) {
                    gradients.put(name, results.get(gradVar.name()));
                }
            }
        }
        return new OutAndGrad(outputs, gradients);
    }

    /**
     * Checks whether the gradient function has been created.
     *
     * @return true if a function is registered under the gradient function key
     */
    public boolean hasGradientFunction() {
        final boolean defined = this.sameDiffFunctionInstances.containsKey(GRAD_FN_KEY);
        return defined;
    }

    /**
     * Creates the gradient function without requiring gradients for any specific variables.
     * Delegates to {@code createGradFunction(String...)} with a null array.
     */
    public void createGradFunction() {
        this.createGradFunction((String[]) null);
    }

    public void createGradFunction(final String ... variablesRequiringGradients) {
        if (this.lossVariables.isEmpty()) {
            if (this.trainingConfig != null && this.trainingConfig.getLossVariables() != null && !this.trainingConfig.getLossVariables().isEmpty()) {
                this.lossVariables.addAll(this.trainingConfig.getLossVariables());
            } else {
                String[] lossInferred = this.bestGuessLossVariables();
                if (lossInferred.size() == 1) {
                    String outName = lossInferred.get(0);
                    String opName = this.variables.get(outName).getOutputOfOp();
                    if (opName == null || !(this.ops.get(opName).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set.Use SameDiff.setLossVariables() or SDVariable.markAsLoss() to override", lossInferred.get(0));
                    }
                    this.lossVariables.add((String)lossInferred.get(0));
                } else if (lossInferred.isEmpty()) {
                    for (SameDiffOp o : this.ops.values()) {
                        if (!(o.getOp() instanceof ExternalErrorsFunction)) continue;
                        List<String> l = o.getOutputsOfOp();
                        this.lossVariables.add(l.get(0));
                    }
                }
            }
        }
        Preconditions.checkState(!this.lossVariables.isEmpty(), "Cannot create gradient function: No loss variables (variables to minimize) have been specified. Loss variables are the variables that represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()");
        if (log.isTraceEnabled()) {
            log.trace("Defining function \"grad\"");
        }
        if (variablesRequiringGradients != null && variablesRequiringGradients.length > 0) {
            for (String s : variablesRequiringGradients) {
                Preconditions.checkArgument(this.variables.containsKey(s), "Cannot ensure gradient exists for variable: no variable with name \"%s\" exists", (Object)s);
                DataType dt = this.variables.get(s).getVariable().dataType();
                Preconditions.checkState(dt.isFPType(), "Cannot ensure gradient exists for variable \"%s\": variable is not a floating point SDVariable. Only floating point SDVariables have gradients defined - variable has type %s", (Object)s, (Object)dt);
            }
        }
        final SameDiff outer = this;
        this.defineFunction(GRAD_FN_KEY, new SameDiffFunctionDefinition(){

            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                List<String> inputsToOp;
                ArrayList allFunctions;
                sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
                if (SameDiff.this.debugMode) {
                    sameDiff.enableDebugMode();
                }
                outer.invokeGraphOn(sameDiff);
                if (SameDiff.this.debugMode) {
                    Preconditions.checkState(sameDiff.ops.keySet().equals(SameDiff.this.ops.keySet()), "ops keysets not equal");
                }
                if ((allFunctions = new ArrayList(sameDiff.ops.values())).isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                }
                for (SameDiffOp op : allFunctions) {
                    SDVariable[] outputs;
                    SDVariable[] args;
                    DifferentialFunction func = op.getOp();
                    for (SDVariable arg : args = func.args()) {
                        arg.setSameDiff(sameDiff);
                    }
                    for (SDVariable output : outputs = func.outputVariables()) {
                        output.setSameDiff(sameDiff);
                    }
                    func.setSameDiff(sameDiff);
                }
                ArrayList<Object> finalOutputs = new ArrayList<Object>(SameDiff.this.lossVariables.size());
                SDVariable initialGrad = sameDiff.var("one-var", Nd4j.scalar(1.0f));
                for (String s : SameDiff.this.lossVariables) {
                    Preconditions.checkNotNull(s, "Encountered null value in loss variables. Null loss variables are not allowed. Use SameDiff.setLossVariables with non-null array names to fix");
                    Preconditions.checkState(SameDiff.this.variables.containsKey(s), "Specified loss function variable \"%s\" does not exist", (Object)s);
                    Object v = ((Variable)SameDiff.this.variables.get(s)).getVariable();
                    Preconditions.checkState(((SDVariable)v).dataType().isFPType(), "Specified loss function variable \"%s\" is not a floatingpoint variable (datatype: %s). Only floating point variables may be used as loss function variable", (Object)s, (Object)((SDVariable)v).dataType());
                    v = ((SDVariable)v).sum(new int[0]);
                    if (((SDVariable)v).dataType() == initialGrad.dataType()) {
                        sameDiff.setGradientForVariableName(((SDVariable)v).name(), initialGrad);
                    } else {
                        sameDiff.setGradientForVariableName(((SDVariable)v).name(), initialGrad.castTo(((SDVariable)v).dataType()));
                    }
                    if (finalOutputs.contains(v)) {
                        log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", (Object)s);
                        continue;
                    }
                    finalOutputs.add(v);
                }
                if (log.isTraceEnabled()) {
                    String s;
                    Object[] initialOutputsStr = ((SameDiffOp)allFunctions.get(allFunctions.size() - 1)).getOp().outputVariablesNames();
                    s = initialOutputsStr == null ? "null" : Arrays.toString(initialOutputsStr);
                    log.trace("Defining backward function: initial outputs {}", (Object)s);
                }
                HashSet<String> allFpVarsConnectedToLoss = new HashSet<String>();
                LinkedList<String> toProcess = new LinkedList<String>();
                for (String s : SameDiff.this.lossVariables) {
                    if (toProcess.contains(s)) continue;
                    toProcess.add(s);
                }
                while (!toProcess.isEmpty()) {
                    Variable v;
                    String next = (String)toProcess.remove();
                    if (allFpVarsConnectedToLoss.contains(next) || !(v = (Variable)SameDiff.this.variables.get(next)).getVariable().dataType().isFPType()) continue;
                    allFpVarsConnectedToLoss.add(v.getName());
                    if (v.getOutputOfOp() == null) continue;
                    String opName = v.getOutputOfOp();
                    SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                    List<String> opInputs = op.getInputsToOp();
                    if (opInputs == null) continue;
                    for (String s : opInputs) {
                        Variable inputVar = (Variable)SameDiff.this.variables.get(s);
                        if (!inputVar.getVariable().dataType().isFPType()) continue;
                        toProcess.add(s);
                    }
                }
                HashSet minimalSubgraphVars = new HashSet(allFpVarsConnectedToLoss);
                LinkedList<String> leafFPVars = new LinkedList<String>();
                for (String s : allFpVarsConnectedToLoss) {
                    boolean isUserRequested;
                    Variable v = (Variable)SameDiff.this.variables.get(s);
                    if (v.getVariable().getVariableType() == VariableType.ARRAY) {
                        String opName = v.getOutputOfOp();
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        inputsToOp = op.getInputsToOp();
                        boolean anyInputsInSubgraph = false;
                        if (inputsToOp != null) {
                            for (String string : inputsToOp) {
                                if (!allFpVarsConnectedToLoss.contains(string)) continue;
                                anyInputsInSubgraph = true;
                                break;
                            }
                        }
                        if (!anyInputsInSubgraph) {
                            leafFPVars.add(s);
                        }
                    }
                    VariableType vt = v.getVariable().getVariableType();
                    boolean bl = isUserRequested = variablesRequiringGradients != null && ArrayUtils.contains(variablesRequiringGradients, s);
                    if (vt != VariableType.CONSTANT && vt != VariableType.PLACEHOLDER || isUserRequested) continue;
                    leafFPVars.add(s);
                }
                while (!leafFPVars.isEmpty()) {
                    String nextLeaf = (String)leafFPVars.remove();
                    Variable v = (Variable)SameDiff.this.variables.get(nextLeaf);
                    minimalSubgraphVars.remove(nextLeaf);
                    List<String> inputsTo = v.getInputsForOp();
                    if (inputsTo == null || inputsTo.isEmpty()) continue;
                    for (String opName : inputsTo) {
                        List<String> list;
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        List<String> inputsToOp2 = op.getInputsToOp();
                        boolean anyPresent = false;
                        for (String string : inputsToOp2) {
                            if (!minimalSubgraphVars.contains(string) && (variablesRequiringGradients == null || !ArrayUtils.contains(variablesRequiringGradients, string))) continue;
                            anyPresent = true;
                            break;
                        }
                        if (anyPresent || (list = op.getOutputsOfOp()) == null) continue;
                        for (String s3 : list) {
                            if (leafFPVars.contains(s3)) continue;
                            leafFPVars.add(s3);
                        }
                    }
                }
                Preconditions.checkState(!minimalSubgraphVars.isEmpty(), "Cannot differentiate graph relative to the specified loss function variables %s: graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", (Object)SameDiff.this.lossVariables);
                LinkedList<String> availableForDiff = new LinkedList<String>();
                for (Object lossVar : finalOutputs) {
                    Variable v = (Variable)sameDiff.variables.get(((SDVariable)lossVar).name());
                    if (v.getOutputOfOp() == null) continue;
                    String opName = v.getOutputOfOp();
                    availableForDiff.add(opName);
                }
                HashMap prerequisites = new HashMap();
                for (String var : minimalSubgraphVars) {
                    Variable variable = (Variable)SameDiff.this.variables.get(var);
                    List<String> inputsForOp = variable.getInputsForOp();
                    if (inputsForOp == null) continue;
                    ArrayList<String> req = new ArrayList<String>();
                    for (String string : inputsForOp) {
                        SameDiffOp sameDiffOp = (SameDiffOp)SameDiff.this.ops.get(string);
                        List<String> opOutputs = sameDiffOp.getOutputsOfOp();
                        boolean anyOpOutputsRequired = false;
                        if (opOutputs != null) {
                            for (String s : opOutputs) {
                                if (!minimalSubgraphVars.contains(s)) continue;
                                anyOpOutputsRequired = true;
                                break;
                            }
                        }
                        if (!anyOpOutputsRequired) continue;
                        req.add(string);
                    }
                    prerequisites.put(variable.getName(), req);
                }
                HashSet<String> differentiatedOps = new HashSet<String>();
                while (!availableForDiff.isEmpty()) {
                    List<Object> outputsOfOp;
                    String dfName = (String)availableForDiff.remove();
                    DifferentialFunction df = ((SameDiffOp)sameDiff.ops.get(dfName)).getOp();
                    if (df instanceof GradientBackwardsMarker) {
                        SameDiffOp op = (SameDiffOp)sameDiff.ops.get(df.getOwnName());
                        inputsToOp = op.getInputsToOp();
                        outputsOfOp = Collections.emptyList();
                    } else {
                        inputsToOp = ((SameDiffOp)sameDiff.ops.get(df.getOwnName())).getInputsToOp();
                        outputsOfOp = ((SameDiffOp)sameDiff.ops.get(df.getOwnName())).getOutputsOfOp();
                    }
                    ArrayList<SDVariable> grads = new ArrayList<SDVariable>();
                    for (String string : outputsOfOp) {
                        SDVariable g;
                        SDVariable v = sameDiff.getVariable(string);
                        SDVariable sDVariable = g = v.hasGradient() ? v.gradient() : null;
                        if (g == null) {
                            if (!v.dataType().isFPType()) {
                                grads.add(null);
                                continue;
                            }
                            SDVariable gTemp = sameDiff.zerosLike(v);
                            grads.add(gTemp);
                            continue;
                        }
                        grads.add(g);
                    }
                    List<SDVariable> list = df.diff(grads);
                    differentiatedOps.add(df.getOwnName());
                    for (String s : inputsToOp) {
                        Variable v = (Variable)sameDiff.variables.get(s);
                        String opName = v.getOutputOfOp();
                        if (opName == null || differentiatedOps.contains(opName)) continue;
                        boolean isRequiredOp = false;
                        SameDiffOp op = (SameDiffOp)SameDiff.this.ops.get(opName);
                        if (op.getInputsToOp() != null) {
                            List<String> opInputs = op.getInputsToOp();
                            boolean anyInputsRequired = false;
                            for (String s2 : opInputs) {
                                if (!minimalSubgraphVars.contains(s2)) continue;
                                anyInputsRequired = true;
                                break;
                            }
                            if (anyInputsRequired && !differentiatedOps.contains(op.getName())) {
                                isRequiredOp = true;
                            }
                        }
                        if (!isRequiredOp) continue;
                        boolean allAvailable = true;
                        SameDiffOp o = (SameDiffOp)sameDiff.ops.get(opName);
                        for (String opOutput : o.getOutputsOfOp()) {
                            Variable outVar = (Variable)SameDiff.this.variables.get(opOutput);
                            if (!outVar.getVariable().dataType().isFPType() || !minimalSubgraphVars.contains(outVar.getName())) continue;
                            if (outVar.getVariable().gradient() == null) {
                                allAvailable = false;
                                break;
                            }
                            List prereqs = (List)prerequisites.get(outVar.getName());
                            if (prereqs == null || (allAvailable &= differentiatedOps.containsAll(prereqs))) continue;
                            break;
                        }
                        if (!allAvailable || availableForDiff.contains(o.getOp().getOwnName())) continue;
                        availableForDiff.add(o.getOp().getOwnName());
                    }
                }
                for (String s : minimalSubgraphVars) {
                    SDVariable v;
                    SDVariable g;
                    if (SameDiff.this.lossVariables.contains(s) || (g = (v = ((Variable)SameDiff.this.variables.get(s)).getVariable()).gradient()) != null) continue;
                    throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + s + "\" was calculated");
                }
                return new SDVariable[]{sameDiff.var(SameDiff.GRAD_FN_KEY, DataType.FLOAT, 1)};
            }
        });
        this.associateSameDiffWithOpsAndVariables();
    }

    /**
     * Heuristically picks variables that look like graph "end points" and may therefore be loss variables.
     *
     * <p>A variable is rejected when it is a constant or placeholder, when it is an input to any op,
     * when it is a control dependency for an op or for another variable, or when it is a
     * floating-point output of an {@code Assert} or {@code Switch} op.
     *
     * @return names of the remaining candidate variables, in iteration order of the variables map
     */
    protected List<String> bestGuessLossVariables() {
        List<String> candidates = new ArrayList<String>();
        for (Variable candidate : this.variables.values()) {
            SDVariable sdv = candidate.getVariable();
            if (sdv.isConstant() || sdv.isPlaceHolder()) {
                continue;
            }
            // Anything consumed by an op is an intermediate value, not an end point
            List<String> consumers = candidate.getInputsForOp();
            if (consumers != null && !consumers.isEmpty()) {
                continue;
            }
            List<String> opControlDeps = candidate.getControlDepsForOp();
            if (opControlDeps != null && !opControlDeps.isEmpty()) {
                continue;
            }
            List<String> varControlDeps = candidate.getControlDepsForVar();
            if (varControlDeps != null && !varControlDeps.isEmpty()) {
                continue;
            }
            String producer = candidate.getOutputOfOp();
            if (producer != null && sdv.dataType().isFPType()) {
                DifferentialFunction producerFn = this.ops.get(producer).getOp();
                // Outputs of Assert / Switch ops are excluded from the guess
                if (producerFn instanceof Assert || producerFn instanceof Switch) {
                    continue;
                }
            }
            candidates.add(candidate.getName());
        }
        return candidates;
    }

    /**
     * Reports whether the named variable is a placeholder.
     *
     * @param varName name of a variable registered in this instance
     * @return true if that variable is a placeholder
     * @throws IllegalStateException (via Preconditions) if no variable with that name exists
     */
    public boolean isPlaceHolder(String varName) {
        boolean known = this.variables.containsKey(varName);
        Preconditions.checkState(known, "No variable present in SameDiff instance with name \"%s\"", (Object)varName);
        Variable entry = this.variables.get(varName);
        return entry.getVariable().isPlaceHolder();
    }

    /**
     * Renames a variable, applying the current name scope, and updates the graph's references to it.
     *
     * <p>When {@code newVarName} is non-null, it is prefixed with the current name scope unless it
     * already starts with {@code scope + "/"}. When it is null and the variable's current name is
     * already registered to a different SDVariable, a fresh name is generated via
     * {@code generateNewVarName}. No rename happens when the resulting name is null or equals the
     * current name.
     *
     * @param varToUpdate variable to rename; must not be null
     * @param newVarName  requested new name, or null to keep the current name unless it clashes
     * @return {@code varToUpdate} (renamed in place when a rename occurred)
     * @throws NullPointerException  if {@code varToUpdate} is null
     * @throws IllegalStateException if the (scoped) new name already belongs to a different SDVariable
     */
    public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
        if (varToUpdate == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        String targetName = newVarName;
        if (targetName != null) {
            String scope = this.currentNameScope();
            if (scope != null && !targetName.startsWith(scope + "/")) {
                targetName = scope + "/" + targetName;
            }
            boolean takenByOther = this.variables.containsKey(targetName)
                    && varToUpdate != this.variables.get(targetName).getVariable();
            if (takenByOther) {
                throw new IllegalStateException("Variable name \"" + targetName + "\" already exists for a different SDVariable");
            }
        } else {
            String currentName = varToUpdate.name();
            boolean clashes = this.variables.containsKey(currentName)
                    && this.variables.get(currentName).getVariable() != varToUpdate;
            if (clashes) {
                targetName = this.generateNewVarName(currentName, 0);
            }
        }
        if (targetName == null) {
            return varToUpdate;
        }
        String oldVarName = varToUpdate.name();
        if (oldVarName.equals(targetName)) {
            return varToUpdate;
        }
        varToUpdate.setVarName(targetName);
        this.renameVariable(oldVarName, targetName);
        return varToUpdate;
    }

    /**
     * Applies {@code updateVariableNameAndReference} to each variable in turn.
     *
     * <p>The name array is validated before any variable is touched. A too-short array therefore
     * fails fast, instead of renaming the first variables and then failing part-way through with
     * an index error.
     *
     * @param variablesToUpdate variables to rename
     * @param newVariableNames  new names, positionally matched to {@code variablesToUpdate}; may be
     *                          null (each variable then gets {@code null} as its requested name).
     *                          Extra trailing names are ignored.
     * @return the updated variables, in the same order as {@code variablesToUpdate}
     * @throws IllegalArgumentException if {@code newVariableNames} is non-null and shorter than
     *                                  {@code variablesToUpdate}
     */
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames) {
        int numVariables = variablesToUpdate.length;
        if (newVariableNames != null && newVariableNames.length < numVariables) {
            throw new IllegalArgumentException("Expected at least " + numVariables
                    + " new variable names (one per variable to update), but got " + newVariableNames.length);
        }
        SDVariable[] updatedVariables = new SDVariable[numVariables];
        for (int i = 0; i < numVariables; ++i) {
            SDVariable varToUpdate = variablesToUpdate[i];
            String name = newVariableNames == null ? null : newVariableNames[i];
            updatedVariables[i] = this.updateVariableNameAndReference(varToUpdate, name);
        }
        return updatedVariables;
    }

    /**
     * Points every variable, every op, and every op's argument and output variables back at this
     * SameDiff instance via {@code setSameDiff(this)}.
     */
    protected void associateSameDiffWithOpsAndVariables() {
        for (SDVariable sdVar : this.variableMap().values()) {
            sdVar.setSameDiff(this);
        }
        for (SameDiffOp sdOp : this.ops.values()) {
            DifferentialFunction fn = sdOp.getOp();
            fn.setSameDiff(this);
            SDVariable[] inputs = fn.args();
            if (inputs != null) {
                for (int i = 0; i < inputs.length; i++) {
                    inputs[i].setSameDiff(this);
                }
            }
            SDVariable[] results = fn.outputVariables();
            if (results == null) {
                continue;
            }
            for (int i = 0; i < results.length; i++) {
                results[i].setSameDiff(this);
            }
        }
    }

    /**
     * Writes a fixed-template FlatNode named {@code name} into {@code bufferBuilder}.
     *
     * <p>{@code scope} is only null-checked here and is not otherwise used.
     *
     * @param name          node name; its string offset is passed as both of the first two int fields
     * @param scope         must not be null (unused beyond the check)
     * @param bufferBuilder builder to write into; must not be null
     * @return offset of the created FlatNode
     */
    protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull FlatBufferBuilder bufferBuilder) {
        if (scope == null) {
            throw new NullPointerException("scope is marked non-null but is null");
        }
        if (bufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked non-null but is null");
        }
        int nameOffset = bufferBuilder.createString(name);
        // NOTE(review): (byte)119 looks like an op-type code and 10L a hard-coded op number - confirm against the FlatNode schema
        return FlatNode.createFlatNode(bufferBuilder, nameOffset, nameOffset, (byte)119, 10L, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Splits a {@code "name:index"} variable reference into its name and output index.
     *
     * <p>A name without any ':' maps to index 0. Otherwise the text after the last ':' is parsed as
     * the index, and everything before it (joined back with ':') is the name. For example,
     * {@code "a:b:1"} becomes {@code ("a:b", 1)}. Note that {@code String.split} drops trailing
     * empty segments, so trailing colons are ignored before the index is located.
     *
     * @param varName variable reference; must not be null
     * @return pair of (base name, output index)
     * @throws NumberFormatException if the segment taken as the index is not an integer
     * @throws NullPointerException  if {@code varName} is null
     */
    public static Pair<String, Integer> parseVariable(@NonNull String varName) {
        if (varName == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        if (varName.indexOf(':') < 0) {
            return Pair.pairOf(varName, 0);
        }
        String[] parts = varName.split(":");
        int last = parts.length - 1;
        Integer outputIndex = Integer.valueOf(parts[last]);
        String baseName = String.join(":", Arrays.copyOfRange(parts, 0, last));
        return Pair.pairOf(baseName, outputIndex);
    }

    /**
     * Serializes this instance to FlatBuffers using graph id 0.
     *
     * @param configuration       executor configuration to embed; must not be null
     * @param includeUpdaterState whether to include updater state
     * @return serialized FlatGraph buffer
     * @throws NullPointerException if {@code configuration} is null
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        final long defaultGraphId = 0L;
        return this.asFlatBuffers(defaultGraphId, configuration, includeUpdaterState);
    }

    /**
     * Serializes this SameDiff instance to a FlatBuffers {@code FlatGraph}.
     *
     * <p>The graph contains:
     * <ul>
     *   <li>one {@code FlatVariable} per variable, with array data for constants, placeholders and
     *       VARIABLE-type variables when an array is present, and the shape for placeholders;</li>
     *   <li>one {@code FlatNode} per op;</li>
     *   <li>the placeholder names and the loss variable names;</li>
     *   <li>the training configuration as JSON, if set;</li>
     *   <li>updater state, optionally.</li>
     * </ul>
     * {@code Nd4j.getExecutioner().commit()} is called before anything is written.
     *
     * <p>Side effect: after the buffer is finished, the index assigned to each exported variable is
     * written back via {@code Variable.setVariableIndex}.
     *
     * @param graphId             graph id stored in the FlatGraph
     * @param configuration       executor configuration to embed; must not be null
     * @param includeUpdaterState whether to include updater state (written only when updater state exists)
     * @return the builder's data buffer for the finished graph
     * @throws NullPointerException if {@code configuration} is null
     */
    public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder bufferBuilder = new FlatBufferBuilder(1024);
        AtomicInteger idCounter = new AtomicInteger(0);
        List<Integer> flatVariables = new ArrayList<Integer>();
        List<Integer> flatNodes = new ArrayList<Integer>();
        // Single snapshot of the variables: iterated below and also handed to the op mapper
        ArrayList<SDVariable> variableList = new ArrayList<SDVariable>(this.variables());
        LinkedHashMap<String, Integer> reverseMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> forwardMap = new LinkedHashMap<String, Integer>();
        LinkedHashMap<String, Integer> framesMap = new LinkedHashMap<String, Integer>();
        // Identity-keyed so that all outputs of the same op instance share one node id
        IdentityHashMap<DifferentialFunction, Integer> idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();

        for (SDVariable variable : variableList) {
            INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr();
            String varName = variable.name();
            log.trace("Exporting variable: [{}]", (Object)varName);
            Variable meta = this.variables.get(varName);

            int varIdx;
            int outputNum;
            String producingOp = meta.getOutputOfOp();
            if (producingOp != null) {
                DifferentialFunction df = this.ops.get(producingOp).getOp();
                Integer knownIdx = idxForOps.get(df);
                if (knownIdx == null) {
                    varIdx = idCounter.incrementAndGet();
                    idxForOps.put(df, varIdx);
                } else {
                    varIdx = knownIdx;
                }
                Object[] outNames = df.outputVariablesNames();
                outputNum = ArrayUtils.indexOf(outNames, varName);
                Preconditions.checkState(outputNum >= 0, "Variable name \"%s\" not found in list of outputs: %s", (Object)varName, (Object)outNames);
            } else {
                varIdx = idCounter.incrementAndGet();
                outputNum = 0;
            }
            reverseMap.put(varName, varIdx);
            log.trace("Adding [{}] as [{}]", (Object)varName, (Object)varIdx);

            int name = bufferBuilder.createString(varName);
            int id = IntPair.createIntPair(bufferBuilder, varIdx, outputNum);
            byte varType = (byte)variable.getVariableType().ordinal();
            int array = 0;
            if (variable.isConstant() || variable.isPlaceHolder() || variable.getVariableType() == VariableType.VARIABLE) {
                array = arr == null ? 0 : arr.toFlatArray(bufferBuilder);
            }
            int shape = 0;
            if (variable.getVariableType() == VariableType.PLACEHOLDER) {
                long[] shp = variable.getShape();
                if (shp != null) {
                    shape = FlatVariable.createShapeVector(bufferBuilder, shp);
                }
            }

            int controlDeps = 0;
            int[] cds = FlatBuffersMapper.mapOrNull(meta.getControlDeps(), bufferBuilder);
            if (cds != null) {
                controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds);
            }
            int controlDepsForOp = 0;
            int[] cdsForOp = FlatBuffersMapper.mapOrNull(meta.getControlDepsForOp(), bufferBuilder);
            if (cdsForOp != null) {
                controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp);
            }
            int controlDepsForVar = 0;
            int[] cdsForVar = FlatBuffersMapper.mapOrNull(meta.getControlDepsForVar(), bufferBuilder);
            if (cdsForVar != null) {
                controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar);
            }
            int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar);
            flatVariables.add(flatVariable);
        }

        for (SameDiffOp op : this.ops.values()) {
            DifferentialFunction func = op.getOp();
            Integer fnId = idxForOps.get(func);
            flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
        }

        // No explicit graph outputs are exported: an empty vector is written in the outputs slot
        int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, new int[0]);
        int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
        int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));

        // Collect placeholder names in one pass, then write the strings in the same order
        List<String> placeholderNames = new ArrayList<String>();
        for (SDVariable v : this.variables()) {
            if (v.isPlaceHolder()) {
                placeholderNames.add(v.name());
            }
        }
        int[] placeholderOffsets = new int[placeholderNames.size()];
        for (int i = 0; i < placeholderOffsets.length; ++i) {
            placeholderOffsets[i] = bufferBuilder.createString(placeholderNames.get(i));
        }
        int placeholdersOffset = FlatGraph.createPlaceholdersVector(bufferBuilder, placeholderOffsets);

        List<String> lossVars = this.getLossVariables();
        int[] lossVarOffsets = new int[lossVars == null ? 0 : lossVars.size()];
        for (int i = 0; i < lossVarOffsets.length; ++i) {
            lossVarOffsets[i] = bufferBuilder.createString(lossVars.get(i));
        }
        int lossVarOffset = FlatGraph.createLossVariablesVector(bufferBuilder, lossVarOffsets);

        int trainingConfigOffset = 0;
        if (this.trainingConfig != null) {
            String json = this.trainingConfig.toJson();
            trainingConfigOffset = bufferBuilder.createString(json);
        }

        int updaterStateOffset = 0;
        if (includeUpdaterState && this.updaterMap != null && !this.updaterMap.isEmpty()) {
            int[] updaterOffsets = new int[this.updaterMap.size()];
            int updaterNum = 0;
            for (Map.Entry<String, GradientUpdater> g : this.updaterMap.entrySet()) {
                int paramNameOffset = bufferBuilder.createString(g.getKey());
                int stateKeyOffset = 0;
                int stateValuesOffset = 0;
                Map<String, INDArray> state = g.getValue().getState();
                if (state != null && !state.isEmpty()) {
                    int[] keysOffsets = new int[state.size()];
                    int[] valuesOffsets = new int[state.size()];
                    int i = 0;
                    for (Map.Entry<String, INDArray> e : state.entrySet()) {
                        keysOffsets[i] = bufferBuilder.createString(e.getKey());
                        valuesOffsets[i] = e.getValue().toFlatArray(bufferBuilder);
                        ++i;
                    }
                    stateKeyOffset = UpdaterState.createUpdaterStateKeysVector(bufferBuilder, keysOffsets);
                    stateValuesOffset = UpdaterState.createUpdaterStateValuesVector(bufferBuilder, valuesOffsets);
                }
                updaterOffsets[updaterNum++] = UpdaterState.createUpdaterState(bufferBuilder, paramNameOffset, stateKeyOffset, stateValuesOffset);
            }
            updaterStateOffset = FlatGraph.createUpdaterStateVector(bufferBuilder, updaterOffsets);
        }

        // Must be created before createFlatGraph starts its table
        int configurationOffset = configuration.getFlatConfiguration(bufferBuilder);
        int fg = FlatGraph.createFlatGraph(bufferBuilder, graphId, variablesOffset, nodesOffset, outputsOffset, configurationOffset, placeholdersOffset, lossVarOffset, trainingConfigOffset, updaterStateOffset);
        bufferBuilder.finish(fg);

        synchronized (this) {
            for (Map.Entry<String, Integer> e : reverseMap.entrySet()) {
                this.variables.get(e.getKey()).setVariableIndex(e.getValue());
            }
        }
        return bufferBuilder.dataBuffer();
    }

    /**
     * Serializes this instance with the default executor configuration and reads it back as a FlatGraph.
     *
     * @param includeUpdaterState whether to include updater state
     * @return root FlatGraph view over the serialized buffer
     */
    public FlatGraph asFlatGraph(boolean includeUpdaterState) {
        ByteBuffer serialized = this.asFlatBuffers(includeUpdaterState);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this instance with the given graph id and configuration and reads it back as a FlatGraph.
     *
     * @param graphId             graph id stored in the FlatGraph
     * @param configuration       executor configuration to embed
     * @param includeUpdaterState whether to include updater state
     * @return root FlatGraph view over the serialized buffer
     */
    public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState) {
        ByteBuffer serialized = this.asFlatBuffers(graphId, configuration, includeUpdaterState);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this instance using a default executor configuration.
     *
     * <p>The defaults are: VARIABLE_SPACE output mode, SEQUENTIAL execution, profiling disabled,
     * and timings gathered.
     *
     * @param includeUpdaterState whether to include updater state
     * @return serialized FlatGraph buffer
     */
    public ByteBuffer asFlatBuffers(boolean includeUpdaterState) {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return this.asFlatBuffers(defaults, includeUpdaterState);
    }

    /**
     * Saves this instance to a file in FlatBuffers format.
     *
     * @param file             destination file; must not be null
     * @param saveUpdaterState whether to include updater state
     * @throws RuntimeException wrapping any IOException raised while writing
     */
    public void save(@NonNull File file, boolean saveUpdaterState) {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        try {
            this.asFlatFile(file, saveUpdaterState);
        } catch (IOException ioe) {
            throw new RuntimeException("Error saving SameDiff instance to file", ioe);
        }
    }

    /**
     * Saves this instance to an output stream in FlatBuffers format.
     *
     * <p>The graph is first written to a temporary file, which is then copied into the stream.
     * The temporary file is deleted afterwards. The supplied stream (wrapped in a
     * {@link BufferedOutputStream} if it is not one already) is closed when the copy finishes.
     *
     * @param outputStream destination stream; must not be null; closed by this method
     * @param saveUpdater  whether to include updater state
     * @throws RuntimeException wrapping IO failures while copying to the stream
     */
    public void save(@NonNull OutputStream outputStream, boolean saveUpdater) {
        if (outputStream == null) {
            throw new NullPointerException("outputStream is marked non-null but is null");
        }
        File tempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            this.save(tempFile, saveUpdater);
            OutputStream target = outputStream instanceof BufferedOutputStream
                    ? outputStream
                    : new BufferedOutputStream(outputStream);
            try (OutputStream os = target;
                 InputStream is = new BufferedInputStream(new FileInputStream(tempFile))) {
                IOUtils.copy(is, os);
            } catch (IOException e) {
                throw new RuntimeException("Error writing to output stream (or reading from temp file)", e);
            }
        } finally {
            tempFile.delete();
        }
    }

    /**
     * Loads a SameDiff instance from a FlatBuffers file.
     *
     * @param file             source file; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws RuntimeException wrapping any IOException raised while reading
     */
    public static SameDiff load(@NonNull File file, boolean loadUpdaterState) {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        SameDiff loaded;
        try {
            loaded = SameDiff.fromFlatFile(file, loadUpdaterState);
        } catch (IOException ioe) {
            throw new RuntimeException("Error loading SameDiff instance from file", ioe);
        }
        return loaded;
    }

    /**
     * Loads a SameDiff instance from a stream containing FlatBuffers data.
     *
     * <p>The stream is copied to a temporary file, which is fully written and closed before it is
     * parsed. The temporary file is deleted afterwards. The supplied input stream is not closed
     * by this method.
     *
     * @param is               source stream; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws RuntimeException wrapping any IOException raised while copying or parsing
     */
    public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState) {
        if (is == null) {
            throw new NullPointerException("is is marked non-null but is null");
        }
        File tempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            // Close (and so flush) the temp file before reading it back
            try (OutputStream os = new BufferedOutputStream(new FileOutputStream(tempFile))) {
                IOUtils.copy(is, os);
            }
            return SameDiff.fromFlatFile(tempFile, loadUpdaterState);
        } catch (IOException e) {
            throw new RuntimeException("Error loading SameDiff instance from file", e);
        } finally {
            tempFile.delete();
        }
    }

    /**
     * Writes this instance to a FlatBuffers file, including updater state.
     *
     * @param file destination file; must not be null
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        final boolean withUpdaterState = true;
        this.asFlatFile(file, withUpdaterState);
    }

    /**
     * Writes this instance to a FlatBuffers file using the default executor configuration.
     *
     * <p>Exactly the buffer's readable region {@code [position, limit)} is written. This works
     * whether or not the buffer is backed by an accessible array, and regardless of its array offset.
     *
     * @param file             destination file; must not be null
     * @param withUpdaterState whether to include updater state
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        ByteBuffer fb = this.asFlatBuffers(withUpdaterState);
        try (FileOutputStream fos = new FileOutputStream(file);
             BufferedOutputStream bos = new BufferedOutputStream(fos)) {
            if (fb.hasArray()) {
                bos.write(fb.array(), fb.arrayOffset() + fb.position(), fb.remaining());
            } else {
                byte[] bytes = new byte[fb.remaining()];
                // duplicate() so the returned buffer's position is left untouched
                fb.duplicate().get(bytes);
                bos.write(bytes);
            }
        }
    }

    /**
     * Writes this instance to a FlatBuffers file using the given executor configuration.
     *
     * <p>Exactly the buffer's readable region {@code [position, limit)} is written. This works
     * whether or not the buffer is backed by an accessible array, and regardless of its array offset.
     *
     * @param file                destination file; must not be null
     * @param configuration       executor configuration to embed; must not be null
     * @param includeUpdaterState whether to include updater state
     * @throws IOException if writing fails
     */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        ByteBuffer fb = this.asFlatBuffers(configuration, includeUpdaterState);
        try (FileOutputStream fos = new FileOutputStream(file);
             BufferedOutputStream bos = new BufferedOutputStream(fos)) {
            if (fb.hasArray()) {
                bos.write(fb.array(), fb.arrayOffset() + fb.position(), fb.remaining());
            } else {
                byte[] bytes = new byte[fb.remaining()];
                // duplicate() so the returned buffer's position is left untouched
                fb.duplicate().get(bytes);
                bos.write(bytes);
            }
        }
    }

    /**
     * Reads a SameDiff instance from a FlatBuffers file, restoring updater state.
     *
     * @param file source file; must not be null
     * @return the loaded instance
     * @throws IOException if reading fails
     */
    public static SameDiff fromFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        final boolean loadUpdaterState = true;
        return SameDiff.fromFlatFile(file, loadUpdaterState);
    }

    /**
     * Reads the whole file into memory and deserializes it as a FlatBuffers SameDiff graph.
     *
     * @param file             source file; must not be null
     * @param loadUpdaterState whether to restore updater state
     * @return the loaded instance
     * @throws IOException if reading fails
     */
    public static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        byte[] contents;
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            contents = IOUtils.toByteArray(in);
        }
        return SameDiff.fromFlatBuffers(ByteBuffer.wrap(contents), loadUpdaterState);
    }

    /**
     * Deserializes a SameDiff instance from a FlatBuffers buffer, restoring updater state.
     *
     * @param bbIn buffer containing a serialized FlatGraph
     * @return the loaded instance
     * @throws IOException if deserialization fails
     */
    public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException {
        final boolean loadUpdaterState = true;
        return SameDiff.fromFlatBuffers(bbIn, loadUpdaterState);
    }

    /**
     * Reconstructs a SameDiff instance from a serialized FlatGraph.
     *
     * <p>The graph is restored in phases:
     * <ol>
     *   <li>Variables: name, type, shape, data type and control dependencies. For non-ARRAY
     *       variables, any stored array is also restored.</li>
     *   <li>Ops: inputs are resolved through the (node id, output index) pairs recorded in the
     *       graph. Ops are then linked to their output variables, creating placeholder VARIABLE
     *       entries for outputs that were not serialized.</li>
     *   <li>Loss variables and the training configuration.</li>
     *   <li>Optionally, the updater state.</li>
     * </ol>
     *
     * @param bbIn             buffer holding a serialized FlatGraph
     * @param loadUpdaterState if true and the graph contains updater state, restore it
     * @return the reconstructed SameDiff instance
     * @throws IOException           declared for API compatibility
     * @throws IllegalStateException if an op input refers to a (node id, output index) pair for
     *                               which no variable was registered, or if updater state is
     *                               present but no training configuration was serialized
     */
    public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException {
        FlatGraph fg = FlatGraph.getRootAsFlatGraph(bbIn);
        int numOps = fg.nodesLength();
        int numVars = fg.variablesLength();
        List<FlatNode> ops = new ArrayList<FlatNode>(numOps);
        for (int i = 0; i < numOps; ++i) {
            ops.add(fg.nodes(i));
        }
        List<FlatVariable> vars = new ArrayList<FlatVariable>(numVars);
        for (int i = 0; i < numVars; ++i) {
            vars.add(fg.variables(i));
        }
        SameDiff sd = SameDiff.create();

        // Lookup tables used to resolve op inputs/outputs back to the restored variables
        Map<Pair<Integer, Integer>, SDVariable> variablesByNodeAndOutNum = new HashMap<Pair<Integer, Integer>, SDVariable>();
        Map<String, List<SDVariable>> variablesByName = new HashMap<String, List<SDVariable>>();

        // ---- Phase 1: variables ----
        for (FlatVariable v : vars) {
            int shapeLength = v.shapeLength();
            long[] shape = new long[shapeLength];
            for (int j = 0; j < shapeLength; ++j) {
                shape[j] = v.shape(j);
            }
            String n = v.name();
            DataType dtype = FlatBuffersMapper.getDataTypeFromByte(v.dtype());
            VariableType vt = VariableType.values()[v.variabletype()];
            SDVariable var = new SDVariable(n, vt, sd, shape, dtype);
            sd.variables.put(n, Variable.builder().name(n).variable(var).build());
            Variable v2 = sd.variables.get(n);

            if (v.controlDepsLength() > 0) {
                int num = v.controlDepsLength();
                List<String> l = new ArrayList<String>(num);
                for (int j = 0; j < num; ++j) {
                    l.add(v.controlDeps(j));
                }
                v2.setControlDeps(l);
            }
            if (v.controlDepForOpLength() > 0) {
                int num = v.controlDepForOpLength();
                List<String> l = new ArrayList<String>(num);
                for (int j = 0; j < num; ++j) {
                    l.add(v.controlDepForOp(j));
                }
                v2.setControlDepsForOp(l);
            }
            if (v.controlDepsForVarLength() > 0) {
                int num = v.controlDepsForVarLength();
                List<String> l = new ArrayList<String>(num);
                for (int j = 0; j < num; ++j) {
                    l.add(v.controlDepsForVar(j));
                }
                v2.setControlDepsForVar(l);
            }

            // ARRAY-type variables are not restored from a stored array here
            FlatArray fa = v.ndarray();
            if (fa != null && vt != VariableType.ARRAY) {
                INDArray arr;
                try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    arr = Nd4j.createFromFlatArray(fa);
                }
                sd.setArrayForVariable(n, arr);
            }

            IntPair id = v.id();
            variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(id.first(), id.second()), var);
            List<SDVariable> sameName = variablesByName.get(n);
            if (sameName == null) {
                sameName = new ArrayList<SDVariable>();
                variablesByName.put(n, sameName);
            }
            sameName.add(var);
        }

        // ---- Phase 2: ops ----
        for (FlatNode fn : ops) {
            DifferentialFunction df = FlatBuffersMapper.fromFlatNode(fn);
            String name = fn.name();
            df.setSameDiff(sd);
            df.setOwnName(name);
            if (sd.ops.containsKey(name)) {
                sd.ops.get(name).setOp(df);
            } else {
                sd.ops.put(name, SameDiffOp.builder().name(name).op(df).build());
            }
            String ownName = df.getOwnName();
            int opId = fn.id();

            // Resolve every (node id, output index) input pair to the variable registered for it
            int numInputs = fn.inputPairedLength();
            String[] inputNames = new String[numInputs];
            for (int j = 0; j < numInputs; ++j) {
                IntPair pair = fn.inputPaired(j);
                int nodeId = pair.first();
                int nodeOutNum = pair.second();
                SDVariable varIn = variablesByNodeAndOutNum.get(new Pair<Integer, Integer>(nodeId, nodeOutNum));
                if (varIn == null) {
                    throw new IllegalStateException("Op \"" + name + "\" input " + j + " refers to (node " + nodeId
                            + ", output " + nodeOutNum + ") but no variable was registered for that pair;"
                            + " the serialized graph may be corrupt or its ops out of order");
                }
                inputNames[j] = varIn.name();
            }

            SameDiffOp op = sd.ops.get(ownName);
            op.setInputsToOp(Arrays.asList(inputNames));
            if (fn.controlDepsLength() > 0) {
                int l = fn.controlDepsLength();
                List<String> list = new ArrayList<String>(l);
                for (int j = 0; j < l; ++j) {
                    list.add(fn.controlDeps(j));
                }
                op.setControlDeps(list);
            }
            if (fn.varControlDepsLength() > 0) {
                int l = fn.varControlDepsLength();
                List<String> list = new ArrayList<String>(l);
                for (int j = 0; j < l; ++j) {
                    list.add(fn.varControlDeps(j));
                }
                op.setVarControlDeps(list);
            }
            if (fn.controlDepForLength() > 0) {
                int l = fn.controlDepForLength();
                List<String> list = new ArrayList<String>(l);
                for (int j = 0; j < l; ++j) {
                    list.add(fn.controlDepFor(j));
                }
                op.setControlDepFor(list);
            }

            // Record this op as a consumer of each of its inputs (without duplicates)
            for (String inName : inputNames) {
                Variable inVar = sd.getVariables().get(inName);
                if (inVar.getInputsForOp() == null) {
                    inVar.setInputsForOp(new ArrayList<String>());
                }
                if (!inVar.getInputsForOp().contains(ownName)) {
                    inVar.getInputsForOp().add(ownName);
                }
            }

            // Link outputs: use same-named serialized variables when their count matches the op's
            // output count; otherwise fall back to the output names stored on the node.
            List<SDVariable> varsForOp = variablesByName.get(name);
            int numOutputs = df.getNumOutputs();
            if (numOutputs <= 0) {
                numOutputs = fn.outputLength();
            }
            String[] varNames;
            if (varsForOp != null && varsForOp.size() == numOutputs) {
                varNames = new String[varsForOp.size()];
                for (int j = 0; j < varNames.length; ++j) {
                    varNames[j] = varsForOp.get(j).name();
                    sd.getVariables().get(varNames[j]).setOutputOfOp(ownName);
                }
                op.setOutputsOfOp(Arrays.asList(varNames));
            } else {
                int outputNamesLength = fn.outputNamesLength();
                varNames = new String[outputNamesLength];
                for (int j = 0; j < outputNamesLength; ++j) {
                    String n = fn.outputNames(j);
                    varNames[j] = n;
                    if (!sd.variables.containsKey(n)) {
                        SDVariable var = new SDVariable(n, VariableType.VARIABLE, sd, null, null);
                        sd.variables.put(n, Variable.builder().name(n).variable(var).build());
                        variablesByNodeAndOutNum.put(new Pair<Integer, Integer>(opId, j), var);
                    }
                    sd.getVariables().get(n).setOutputOfOp(ownName);
                }
                op.setOutputsOfOp(Arrays.asList(varNames));
            }
            for (int j = 0; j < varNames.length; ++j) {
                Pair<Integer, Integer> p = new Pair<Integer, Integer>(opId, j);
                if (!variablesByNodeAndOutNum.containsKey(p)) {
                    variablesByNodeAndOutNum.put(p, sd.getVariable(varNames[j]));
                }
            }
        }

        // ---- Phase 3: loss variables and training configuration ----
        for (int i = 0; i < fg.lossVariablesLength(); ++i) {
            sd.addLossVariable(fg.lossVariables(i));
        }
        String tc = fg.trainingConfig();
        if (tc != null) {
            sd.trainingConfig = TrainingConfig.fromJson(tc);
        }

        // ---- Phase 4: optional updater state ----
        if (loadUpdaterState && fg.updaterStateLength() > 0) {
            if (sd.trainingConfig == null) {
                // Updater instances are created from the training config's updater
                throw new IllegalStateException("Cannot restore updater state: the serialized graph contains"
                        + " updater state but no training configuration");
            }
            sd.updaterMap = new HashMap<String, GradientUpdater>();
            int n = fg.updaterStateLength();
            for (int i = 0; i < n; ++i) {
                UpdaterState us = fg.updaterState(i);
                String paramName = us.paramName();
                int nKeys = us.updaterStateKeysLength();
                Map<String, INDArray> m = new HashMap<String, INDArray>();
                for (int j = 0; j < nKeys; ++j) {
                    String key = us.updaterStateKeys(j);
                    FlatArray fa = us.updaterStateValues(j);
                    m.put(key, Nd4j.createFromFlatArray(fa));
                }
                GradientUpdater gu = sd.trainingConfig.getUpdater().instantiate(m, false);
                sd.updaterMap.put(paramName, gu);
            }
            sd.initializedTraining = true;
        }
        return sd;
    }

    /**
     * Renders this graph's FlatBuffers form (serialized without updater state) as text.
     *
     * <p>The output has two sections:
     * <ul>
     *   <li>Variables: id, name and shape-info buffer for each variable. The values are printed
     *       in full below 50 elements; otherwise only the first 50 are printed.</li>
     *   <li>Ops: id, name, op type, op number (or custom-op name looked up by hash) and input
     *       pairs for each op. Each op's id and name are also written to the log at INFO
     *       level.</li>
     * </ul>
     *
     * @return the textual dump
     */
    public String asFlatPrint() {
        FlatGraph graph = FlatGraph.getRootAsFlatGraph(this.asFlatBuffers(false));
        StringBuilder out = new StringBuilder();

        out.append("\nExternal variables:\n\n");
        int varCount = graph.variablesLength();
        for (int vIdx = 0; vIdx < varCount; ++vIdx) {
            FlatVariable fv = graph.variables(vIdx);
            INDArray values = null;
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                FlatArray flat = fv.ndarray();
                if (flat != null) {
                    values = Nd4j.createFromFlatArray(flat);
                }
            }
            out.append(fv.id().first()).append(":<").append(fv.name()).append("> ");
            if (values != null) {
                out.append(Arrays.toString(values.shapeInfoDataBuffer().asInt())).append("; Values: ");
                if (values.data() == null) {
                    out.append("<empty array>");
                } else if (values.dataType() == DataType.UTF8) {
                    out.append("<string array>");
                } else if (values.length() < 50L) {
                    out.append(Arrays.toString(values.data().asFloat()).replaceAll(" ", ""));
                } else {
                    // Large arrays: print only the first 50 elements
                    out.append("[");
                    for (int k = 0; k < 50; ++k) {
                        if (k != 0) {
                            out.append(",");
                        }
                        out.append(values.data().getFloat(k));
                    }
                    out.append("]");
                }
                out.append(";\n");
            } else {
                out.append("<no array>").append("; Values: ").append("<no array>").append(";\n");
            }
        }

        Map<String, CustomOpDescriptor> customOps = Nd4j.getExecutioner().getCustomOperations();
        out.append("\nOps sequence:\n\n");
        int nodeCount = graph.nodesLength();
        for (int nIdx = 0; nIdx < nodeCount; ++nIdx) {
            FlatNode node = graph.nodes(nIdx);
            log.info("{}:<{}>", (Object) node.id(), (Object) node.name());
            Op.Type opType = FlatBuffersMapper.getTypeFromByte(node.opType());
            out.append(node.id()).append(":<").append(node.name()).append("> ").append((Object) opType);
            if (opType == Op.Type.CUSTOM) {
                // Custom ops are identified by hash; the last matching descriptor wins
                String opName = null;
                for (Map.Entry<String, CustomOpDescriptor> entry : customOps.entrySet()) {
                    if (entry.getValue().getHash() == node.opNum()) {
                        opName = entry.getKey();
                    }
                }
                out.append(": ").append(opName != null ? opName : "unknown");
            } else {
                out.append(": ").append(node.opNum());
            }
            out.append("; Inputs: {");
            int pairCount = node.inputPairedLength();
            for (int p = 0; p < pairCount; ++p) {
                if (p > 0) {
                    out.append(", ");
                }
                IntPair pair = node.inputPaired(p);
                out.append("[").append(pair.first()).append(":").append(pair.second()).append("]");
            }
            out.append("};").append(" OpNum: {").append(node.opNum()).append("};").append("\n");
        }
        return out.toString();
    }

    /**
     * Builds a human-readable summary of this graph. The summary contains:
     * <ul>
     *   <li>variable/op/function-definition counts and the loss variables;</li>
     *   <li>a table of all variables with array shape (or placeholder shape), variable type, data
     *       type, producing op and consuming ops;</li>
     *   <li>a table of all ops with their inputs and outputs;</li>
     *   <li>if any exist, the SameDiff function definitions with their own counts.</li>
     * </ul>
     *
     * @return the formatted summary
     */
    public String summary() {
        Map<String, SDVariable> varMap = this.variableMap();
        DifferentialFunction[] functions = this.ops();
        int countVarsWithArrays = 0;
        for (String s : varMap.keySet()) {
            if (this.getArrForVarName(s) != null) {
                ++countVarsWithArrays;
            }
        }
        StringBuilder sb = new StringBuilder();
        String format = "%-25s%-20s";
        sb.append("--- Summary ---\n");
        sb.append(String.format(format, "Variables:", varMap.size()))
                .append(" (").append(countVarsWithArrays).append(" with arrays)").append("\n")
                .append(String.format(format, "Functions:", functions.length)).append("\n")
                .append(String.format(format, "SameDiff Function Defs:", this.sameDiffFunctionInstances.size())).append("\n")
                .append("Loss function variables: ").append(this.getLossVariables()).append("\n\n");

        sb.append("--- Variables ---\n");
        // Map each variable to the FIRST op (in this.ops iteration order) listing it as an output.
        // Built in one pass rather than re-scanning every op for every variable (O(V*O)).
        Map<String, String> producerOpName = new HashMap<String, String>();
        for (SameDiffOp op : this.ops.values()) {
            List<String> outputsOfOp = op.getOutputsOfOp();
            if (outputsOfOp == null) {
                continue;
            }
            for (String outName : outputsOfOp) {
                if (!producerOpName.containsKey(outName)) {
                    producerOpName.put(outName, op.getName());
                }
            }
        }
        Map<String, String> outputOfFn = new HashMap<String, String>();
        int maxLengthOutputOf = 22;
        int maxLengthOfName = 8;
        for (String s : varMap.keySet()) {
            String producer = producerOpName.get(s);
            String outputOf;
            if (producer == null) {
                outputOf = "<none>";
            } else {
                DifferentialFunction d = this.getOpById(producer);
                outputOf = d.getOwnName() + "(" + d.opName() + ")";
            }
            outputOfFn.put(s, outputOf);
            maxLengthOutputOf = Math.max(maxLengthOutputOf, outputOf.length());
            maxLengthOfName = Math.max(maxLengthOfName, s.length());
        }
        // Two extra columns of padding between table fields
        maxLengthOfName += 2;
        maxLengthOutputOf += 2;
        format = "%-" + maxLengthOfName + "s%-20s%-20s%-20s%-" + maxLengthOutputOf + "s%-20s";
        sb.append(String.format(format, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-", "- Output Of Function -", "- Inputs To Functions -")).append("\n");
        for (String s : varMap.keySet()) {
            INDArray arr = this.getArrForVarName(s);
            String arrayShape = "-";
            if (arr != null) {
                arrayShape = Arrays.toString(arr.shape());
            } else {
                SDVariable v = varMap.get(s);
                if (v.isPlaceHolder()) {
                    long[] phShape = v.placeholderShape();
                    if (phShape != null) {
                        arrayShape = Arrays.toString(phShape);
                    }
                }
            }
            String varType = this.getVariable(s).getVariableType().toString();
            String dtype = this.getVariable(s).dataType().toString();
            List<String> argNames = this.variables.get(s).getInputsForOp();
            String dfArrStr = argNames == null ? "" : argNames.toString();
            sb.append(String.format(format, s, arrayShape, varType, dtype, outputOfFn.get(s), dfArrStr)).append("\n");
        }

        sb.append("\n\n--- Functions ---\n");
        List<String> dfInputStr = new ArrayList<String>();
        List<String> dfOutputStr = new ArrayList<String>();
        int maxInLength = 10;
        int maxOutLength = 11;
        int maxOpNameLength = 17;
        int maxDfClassNameLength = 10;
        for (DifferentialFunction df : functions) {
            Object[] argNames = df.argNames();
            Object[] outNames = df.outputVariablesNames();
            String argStr = Arrays.toString(argNames);
            String outStr = Arrays.toString(outNames);
            maxInLength = Math.max(maxInLength, argStr.length());
            maxOutLength = Math.max(maxOutLength, outStr.length());
            dfInputStr.add(argStr);
            dfOutputStr.add(outStr);
            String name = df.getOwnName() == null ? df.opName() : df.getOwnName();
            maxOpNameLength = Math.max(maxOpNameLength, name.length());
            maxDfClassNameLength = Math.max(maxDfClassNameLength, df.getClass().getSimpleName().length());
        }
        maxOpNameLength += 2;
        maxDfClassNameLength += 2;
        maxInLength += 2;
        maxOutLength += 2;
        format = "%-5s%-" + maxOpNameLength + "s%-" + maxDfClassNameLength + "s%-" + maxInLength + "s%-" + maxOutLength + "s";
        sb.append(String.format(format, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -")).append("\n");
        for (int i = 0; i < functions.length; ++i) {
            DifferentialFunction df = functions[i];
            String fnName = df.getOwnName() == null ? df.opName() : df.getOwnName();
            sb.append(String.format(format, String.valueOf(i), fnName, df.getClass().getSimpleName(), dfInputStr.get(i), dfOutputStr.get(i))).append("\n");
        }

        if (this.sameDiffFunctionInstances.size() > 0) {
            sb.append("\n\n--- SameDiff Defined Functions ---\n");
            format = "%-20s%-15s%-15s%-15s";
            sb.append(String.format(format, "- Name -", "- Variables -", "- Functions -", "- Fn Defs -")).append("\n");
            for (Map.Entry<String, SameDiff> e : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff sd = e.getValue();
                int vars = sd.variableMap().size();
                int fns = sd.ops() == null ? 0 : sd.ops().length;
                int defFns = sd.definedFunctionNames().size();
                sb.append(String.format(format, e.getKey(), String.valueOf(vars), String.valueOf(fns), String.valueOf(defFns))).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Reserves a block name that has not been handed out before by this instance.
     *
     * <p>Returns {@code baseName} itself if it is unused. Otherwise it returns the first unused
     * {@code baseName_1}, {@code baseName_2}, .... The returned name is recorded so it will not
     * be returned again.
     *
     * @param baseName the preferred name; may be null
     * @return the reserved unique name, or null if {@code baseName} is null
     */
    public String newBlockName(String baseName) {
        if (baseName == null) {
            return null;
        }
        String candidate = baseName;
        int suffix = 0;
        while (this.blockNames.contains(candidate)) {
            ++suffix;
            candidate = baseName + "_" + suffix;
        }
        this.blockNames.add(candidate);
        return candidate;
    }

    /**
     * Imports a frozen TensorFlow graph stored in a file by delegating to
     * {@code TFGraphMapper.importGraph(File)}.
     *
     * @param graphFile the file holding the frozen graph
     * @return the imported SameDiff instance
     */
    public static SameDiff importFrozenTF(File graphFile) {
        SameDiff imported = TFGraphMapper.importGraph(graphFile);
        return imported;
    }

    /**
     * Imports an already-parsed frozen TensorFlow {@code GraphDef} by delegating to
     * {@code TFGraphMapper.importGraph(GraphDef)}.
     *
     * @param graphDef the TensorFlow graph definition
     * @return the imported SameDiff instance
     */
    public static SameDiff importFrozenTF(GraphDef graphDef) {
        SameDiff imported = TFGraphMapper.importGraph(graphDef);
        return imported;
    }

    /**
     * Imports a frozen TensorFlow graph read from a stream by delegating to
     * {@code TFGraphMapper.importGraph(InputStream)}.
     *
     * @param graph stream positioned at the serialized graph (closing responsibility is not
     *              visible here; confirm in TFGraphMapper)
     * @return the imported SameDiff instance
     */
    public static SameDiff importFrozenTF(InputStream graph) {
        SameDiff imported = TFGraphMapper.importGraph(graph);
        return imported;
    }

    /** Matches a name ending in an underscore-separated numeric suffix, e.g. {@code "matmul_3"}. */
    private static final Pattern OP_NAME_NUMERIC_SUFFIX = Pattern.compile("(.*)_(\\d+)");

    /**
     * Derives an op name from {@code base}, first applying {@link #nameWithScope(String)}.
     *
     * <p>With {@code force == true}, the scoped name is returned as-is. An exception is thrown
     * if an op already uses it.
     *
     * <p>Otherwise a free name is searched for. If the scoped name ends in {@code _N}, that
     * suffix is stripped: the bare prefix is tried first, then {@code prefix_N},
     * {@code prefix_N+1}, .... Without such a suffix the candidates are {@code base},
     * {@code base_1}, {@code base_2}, .... A candidate is rejected if an op already has it, or
     * if any variable is named exactly that or starts with {@code candidate + ":"}.
     *
     * @param base  the requested base name
     * @param force whether the exact (scoped) name must be used
     * @return the chosen op name
     * @throws IllegalArgumentException if {@code force} is true and the name is already taken by an op
     */
    public String getOpName(String base, boolean force) {
        base = this.nameWithScope(base);
        if (force) {
            if (this.ops.containsKey(base)) {
                throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
            }
            return base;
        }
        int start = 1;
        Matcher num = OP_NAME_NUMERIC_SUFFIX.matcher(base);
        if (num.matches()) {
            try {
                start = Integer.parseInt(num.group(2));
                base = num.group(1);
            } catch (NumberFormatException suffixTooLarge) {
                // The numeric suffix does not fit in an int. Treat the whole string as the base
                // name instead of failing with an exception.
            }
        }
        String name = base;
        int i = start;
        while (this.ops.containsKey(name) || this.hasVariableNamedLike(name)) {
            name = base + "_" + i;
            ++i;
        }
        return name;
    }

    /**
     * Returns true if some variable is named exactly {@code name}, or looks like one of its
     * outputs ({@code name:<k>}).
     */
    private boolean hasVariableNamedLike(String name) {
        String outputPrefix = name + ":";
        for (String varName : this.variables.keySet()) {
            if (varName.equals(name) || varName.startsWith(outputPrefix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Derives a unique op name from {@code base} without forcing the exact name.
     * Equivalent to {@code getOpName(base, false)}.
     *
     * @param base the requested base name
     * @return the chosen op name
     */
    public String getOpName(String base) {
        final boolean force = false;
        return this.getOpName(base, force);
    }

    /**
     * Generates a new variable name from {@code base}, first applying {@link #nameWithScope(String)}.
     *
     * <p>If {@code argIndex > 0} and the scoped name contains {@code "<prefix>:<k>"}, the prefix
     * becomes the base and the index becomes {@code k + 1}. Unless {@code existingOp} is true,
     * the base is then made unique via {@link #getOpName(String)}. Finally
     * {@code ":" + argIndex} is appended when {@code argIndex > 0}.
     *
     * @param base       the requested base name
     * @param argIndex   output index of the variable (0 means no suffix)
     * @param existingOp whether the base already names an existing op (skips uniquification)
     * @return the generated variable name
     * @throws IllegalArgumentException if a variable with the resulting name already exists
     */
    public String generateNewVarName(String base, int argIndex, boolean existingOp) {
        String varName = this.nameWithScope(base);
        int index = argIndex;
        if (index > 0 && varName.contains(":")) {
            Matcher indexed = Pattern.compile("(.*):(\\d+)").matcher(varName);
            if (indexed.find()) {
                index = Integer.parseInt(indexed.group(2)) + 1;
                varName = indexed.group(1);
            }
        }
        if (!existingOp) {
            varName = this.getOpName(varName);
        }
        if (index > 0) {
            varName = varName + ":" + index;
        }
        if (!this.variables.containsKey(varName)) {
            return varName;
        }
        throw new IllegalArgumentException("Variable with name \"" + varName + "\" already exists");
    }

    /**
     * Generates a new variable name for an output of an existing op.
     * Equivalent to {@code generateNewVarName(base, argIndex, true)}.
     *
     * @param base     the requested base name
     * @param argIndex output index of the variable (0 means no suffix)
     * @return the generated variable name
     */
    public String generateNewVarName(String base, int argIndex) {
        final boolean existingOp = true;
        return this.generateNewVarName(base, argIndex, existingOp);
    }

    /**
     * Returns {@code base} if no variable has that name. Otherwise returns the first of
     * {@code base_1}, {@code base_2}, ... that is unused. Nothing is reserved; the name is
     * only computed.
     *
     * @param base the preferred variable name
     * @return a variable name not currently present in this graph
     */
    public String generateDistinctCustomVariableName(String base) {
        String candidate = base;
        for (int inc = 1; this.variables.containsKey(candidate); ++inc) {
            candidate = base + "_" + inc;
        }
        return candidate;
    }

    /**
     * Returns a short description of this graph with its variable and op counts.
     *
     * @return text of the form {@code SameDiff(nVars=<n>,nOps=<m>)}
     */
    @Override
    public String toString() {
        return new StringBuilder("SameDiff(nVars=")
                .append(this.variables.size())
                .append(",nOps=")
                .append(this.ops.size())
                .append(')')
                .toString();
    }

    /**
     * Builds an if/else block with a generated block name and output name.
     * Equivalent to {@code ifCond(null, null, cond, trueBody, falseBody)}.
     *
     * @param cond      defines the boolean predicate
     * @param trueBody  defines the value produced when the predicate is true
     * @param falseBody defines the value produced when the predicate is false
     * @return the merged output of the two branches
     * @throws NullPointerException if any lambda is null
     */
    public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody) {
        if (cond == null || trueBody == null || falseBody == null) {
            String missing = cond == null ? "cond" : (trueBody == null ? "trueBody" : "falseBody");
            throw new NullPointerException(missing + " is marked non-null but is null");
        }
        return this.ifCond(null, null, cond, trueBody, falseBody);
    }

    /**
     * Builds an if/else block with the given block name and a generated output name.
     * Equivalent to {@code ifCond(null, ifName, cond, trueBody, falseBody)}.
     *
     * @param ifName    the block name, or null for a generated one
     * @param cond      defines the boolean predicate
     * @param trueBody  defines the value produced when the predicate is true
     * @param falseBody defines the value produced when the predicate is false
     * @return the merged output of the two branches
     * @throws NullPointerException if any lambda is null
     */
    public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody) {
        if (cond == null || trueBody == null || falseBody == null) {
            String missing = cond == null ? "cond" : (trueBody == null ? "trueBody" : "falseBody");
            throw new NullPointerException(missing + " is marked non-null but is null");
        }
        return this.ifCond(null, ifName, cond, trueBody, falseBody);
    }

    /**
     * Builds an if/else block from three lambdas.
     *
     * <p>The predicate is defined under the {@code <ifName>/cond} scope and must be BOOL. While
     * each branch is defined, an argument interceptor routes every pre-existing variable the
     * branch uses through a {@code switchOp} on the predicate:
     * <ul>
     *   <li>the true branch uses output {@code [1]} of the switch;</li>
     *   <li>the false branch uses output {@code [0]}.</li>
     * </ul>
     * A branch that returns a pre-existing variable is switched the same way. The two branch
     * outputs are combined with {@code merge}.
     *
     * <p>If the predicate is not BOOL, the variables and ops created under the if scope are
     * removed, the if scope is closed, and an exception is thrown.
     *
     * @param outputName name for the returned variable, or null to keep the generated one
     * @param ifName     block name, or null to use {@code "if"} (made unique via newBlockName)
     * @param cond       defines the predicate
     * @param trueBody   defines the true-branch value
     * @param falseBody  defines the false-branch value
     * @return the merged output of the two branches
     * @throws NullPointerException  if any lambda is null
     * @throws IllegalStateException if the predicate is not of type BOOL
     */
    public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody) {
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (trueBody == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (falseBody == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }
        ifName = this.newBlockName(ifName == null ? "if" : ifName);
        NameScope ifScope = this.sd.withNameScope(ifName);
        NameScope condScope = this.withNameScope("cond");
        final SDVariable pred = cond.define(this);
        condScope.close();
        if (pred.dataType() != DataType.BOOL) {
            // Roll back everything the predicate lambda created under the if scope
            for (SDVariable v : this.getVariablesInScope(ifScope)) {
                this.getVariables().remove(v.name());
            }
            for (SameDiffOp op : this.getOpsInScope(ifScope)) {
                List<String> opInputs = op.getInputsToOp();
                if (opInputs != null) {
                    // Iterate over a copy: removeArgFromOp may modify the op's input list
                    // (NOTE(review): confirm against removeArgFromOp)
                    for (String in : new ArrayList<String>(opInputs)) {
                        this.removeArgFromOp(in, op.getOp());
                    }
                }
                this.getOps().remove(op.getName());
            }
            // Close the if scope so ops created after this failure are not named under it
            ifScope.close();
            throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean.");
        }
        final Map<String, SDVariable[]> switches = new HashMap<String, SDVariable[]>();
        final HashSet<String> declared = Sets.newHashSet(this.variableMap().keySet());
        this.addArgumentInterceptor(new ArgumentInterceptor(){

            @Override
            public SDVariable intercept(SDVariable argument) {
                // Variables created inside the branch are used directly
                if (!declared.contains(argument.name())) {
                    return argument;
                }
                if (switches.containsKey(argument.name())) {
                    return switches.get(argument.name())[1];
                }
                SDVariable[] s = SameDiff.this.switchOp(argument, pred);
                switches.put(argument.name(), s);
                return s[1];
            }
        });
        NameScope trueScope = this.withNameScope("trueBody");
        SDVariable trueOut = trueBody.define(this);
        this.removeArgumentInterceptor();
        if (declared.contains(trueOut.name())) {
            SDVariable[] s = this.switchOp(trueOut, pred);
            switches.put(trueOut.name(), s);
            trueOut = s[1];
        }
        trueScope.close();
        final HashSet<String> declared2 = Sets.newHashSet(this.variableMap().keySet());
        this.sd.addArgumentInterceptor(new ArgumentInterceptor(){

            @Override
            public SDVariable intercept(SDVariable argument) {
                if (!declared2.contains(argument.name())) {
                    return argument;
                }
                if (switches.containsKey(argument.name())) {
                    return switches.get(argument.name())[0];
                }
                SDVariable[] s = SameDiff.this.switchOp(argument, pred);
                switches.put(argument.name(), s);
                return s[0];
            }
        });
        NameScope falseScope = this.withNameScope("falseBody");
        SDVariable falseOut = falseBody.define(this);
        this.removeArgumentInterceptor();
        if (declared2.contains(falseOut.name())) {
            SDVariable[] s = this.switchOp(falseOut, pred);
            switches.put(falseOut.name(), s);
            falseOut = s[0];
        }
        falseScope.close();
        SDVariable output = this.merge(trueOut, falseOut);
        ifScope.close();
        return this.updateVariableNameAndReference(output, outputName);
    }

    /**
     * Builds a while loop with a generated frame name and output names.
     * Equivalent to {@code whileLoop(null, null, loopVars, cond, body)}.
     *
     * @param loopVars initial values of the loop variables
     * @param cond     defines the boolean loop condition from the current loop variables
     * @param body     defines the next loop-variable values from the current ones
     * @return the loop variables' values after the loop exits
     * @throws NullPointerException if any argument is null
     */
    public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body) {
        if (loopVars == null || cond == null || body == null) {
            String missing = loopVars == null ? "loopVars" : (cond == null ? "cond" : "body");
            throw new NullPointerException(missing + " is marked non-null but is null");
        }
        return this.whileLoop(null, null, loopVars, cond, body);
    }

    /**
     * Builds a while loop with the given frame name and generated output names.
     * Equivalent to {@code whileLoop(null, loopName, loopVars, cond, body)}.
     *
     * @param loopName frame name, or null for a generated one
     * @param loopVars initial values of the loop variables
     * @param cond     defines the boolean loop condition from the current loop variables
     * @param body     defines the next loop-variable values from the current ones
     * @return the loop variables' values after the loop exits
     * @throws NullPointerException if any of loopVars, cond or body is null
     */
    public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body) {
        if (loopVars == null || cond == null || body == null) {
            String missing = loopVars == null ? "loopVars" : (cond == null ? "cond" : "body");
            throw new NullPointerException(missing + " is marked non-null but is null");
        }
        return this.whileLoop(null, loopName, loopVars, cond, body);
    }

    /**
     * Builds a while loop from Enter/Merge/Switch/Exit/NextIteration ops.
     *
     * <p>Construction steps:
     * <ol>
     *   <li>Each loop variable enters the frame and is merged (initially with itself).</li>
     *   <li>The condition is defined on the merged values under {@code <frame>/cond} and must
     *       be BOOL.</li>
     *   <li>Each merged value is switched on the condition. Output {@code [0]} feeds an Exit op;
     *       output {@code [1]} is passed to the body.</li>
     *   <li>While the body is defined, pre-existing variables it uses (other than the switch
     *       outputs) are routed through an Enter op into the frame, once per variable.</li>
     *   <li>Each body output goes through NextIteration and replaces the second argument of the
     *       corresponding Merge.</li>
     * </ol>
     *
     * <p>If the condition is not BOOL, or the body returns fewer outputs than there are loop
     * variables, the loop scope is closed and an exception is thrown.
     *
     * @param outputNames names for the exit variables, or null to keep generated names
     * @param loopName    frame name, or null to use {@code "while"} (made unique via newBlockName)
     * @param loopVars    initial values of the loop variables
     * @param cond        defines the loop condition from the current loop variables
     * @param body        defines the next loop-variable values from the current ones
     * @return the exit variables (loop-variable values after the loop ends)
     * @throws NullPointerException  if any of loopVars, cond or body is null
     * @throws IllegalStateException if the condition is not BOOL or the body returns too few outputs
     */
    public SDVariable[] whileLoop(String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body) {
        if (loopVars == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (body == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }
        final String frameName = this.newBlockName(loopName == null ? "while" : loopName);
        NameScope loopScope = this.withNameScope(frameName);
        SDVariable[] entered = new SDVariable[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            entered[i] = new Enter(this, frameName, loopVars[i]).outputVariable();
        }
        SDVariable[] merged = new SDVariable[loopVars.length];
        Merge[] mergeOps = new Merge[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            // Second argument is a placeholder; it is replaced by NextIteration below
            mergeOps[i] = new Merge(this, entered[i], entered[i]);
            merged[i] = mergeOps[i].outputVariable();
        }
        NameScope condScope = this.withNameScope("cond");
        SDVariable condResult = cond.define(this, merged);
        condScope.close();
        if (condResult.dataType() != DataType.BOOL) {
            // Close the loop scope so ops created after this failure are not named under it
            loopScope.close();
            throw new IllegalStateException("Can not use " + condResult.name() + " as the condition of an While loop, the condition must be a boolean.");
        }
        final HashSet<String> alreadyEntered = Sets.newHashSet();
        SDVariable[] trueSwitches = new SDVariable[loopVars.length];
        SDVariable[] exits = new SDVariable[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            SDVariable[] s = this.switchOp(merged[i], condResult);
            trueSwitches[i] = s[1];
            alreadyEntered.add(s[1].name());
            exits[i] = new Exit(this, s[0]).outputVariable();
        }
        final HashSet<String> declared = Sets.newHashSet(this.variableMap().keySet());
        final Map<String, SDVariable> done = new HashMap<String, SDVariable>();
        this.addArgumentInterceptor(new ArgumentInterceptor(){

            @Override
            public SDVariable intercept(SDVariable argument) {
                // Variables created inside the body, and the switch outputs, are used directly
                if (!declared.contains(argument.name())) {
                    return argument;
                }
                if (alreadyEntered.contains(argument.name())) {
                    return argument;
                }
                if (done.containsKey(argument.name())) {
                    return done.get(argument.name());
                }
                SDVariable e = new Enter(SameDiff.this, frameName, argument, true).outputVariable();
                done.put(argument.name(), e);
                return e;
            }
        });
        NameScope bodyScope = this.withNameScope("body");
        SDVariable[] outs = body.define(this, trueSwitches);
        bodyScope.close();
        this.removeArgumentInterceptor();
        int returned = outs == null ? 0 : outs.length;
        if (returned < loopVars.length) {
            loopScope.close();
            throw new IllegalStateException("While loop body must return one output per loop variable: expected "
                    + loopVars.length + " but got " + returned);
        }
        for (int i = 0; i < loopVars.length; ++i) {
            SDVariable n = new NextIteration(this, outs[i]).outputVariable();
            mergeOps[i].replaceArg(1, n);
        }
        loopScope.close();
        return this.updateVariableNamesAndReferences(exits, outputNames);
    }

    /** @return the backing map of variable name to {@link Variable} metadata (not a copy) */
    public Map<String, Variable> getVariables() {
        return variables;
    }

    /** @return the backing map of op name to {@link SameDiffOp} (not a copy) */
    public Map<String, SameDiffOp> getOps() {
        return ops;
    }

    /** @return the inference-session map (key meaning not visible here; presumably a thread id, confirm) */
    public Map<Long, InferenceSession> getSessions() {
        return sessions;
    }

    /** @return the training configuration, or null if none has been set */
    public TrainingConfig getTrainingConfig() {
        return trainingConfig;
    }

    /** @return whether training state (such as the updater map) has been initialized */
    public boolean isInitializedTraining() {
        return initializedTraining;
    }

    /** @return the map of parameter name to its gradient updater */
    public Map<String, GradientUpdater> getUpdaterMap() {
        return updaterMap;
    }

    /** @return whether debug mode is enabled */
    public boolean isDebugMode() {
        return debugMode;
    }

    /** @return the stack of currently active argument interceptors */
    public Stack<ArgumentInterceptor> getArgumentInterceptors() {
        return argumentInterceptors;
    }

    /** @return the set of argument interceptors that are currently paused */
    public Set<ArgumentInterceptor> getPausedArgumentInterceptors() {
        return pausedArgumentInterceptors;
    }

    /** @return whether execution logging is enabled */
    public boolean isLogExecution() {
        return logExecution;
    }

    /**
     * Enables or disables execution logging.
     *
     * @param logExecution the new setting
     */
    public void setLogExecution(boolean logExecution) {
        this.logExecution = logExecution;
    }

    /** @return the parent SameDiff instance, or null if none */
    public SameDiff getParent() {
        return parent;
    }

    /** @return the child SameDiff instance, or null if none */
    public SameDiff getChild() {
        return child;
    }
}

