/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.graphmapper.tf;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers;
import org.nd4j.imports.tensorflow.TFImportOverride;
import org.nd4j.imports.tensorflow.TFOpImportFilter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.shade.guava.primitives.Floats;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

/**
 * Converts TensorFlow {@link GraphDef} protobufs into {@link SameDiff} instances.
 * <p>
 * All functionality is exposed through static methods; {@link #getInstance()} is kept only for
 * backward compatibility.
 */
public class TFGraphMapper {
    private static final Logger log = LoggerFactory.getLogger(TFGraphMapper.class);

    /** Matches a tensor name that ends with an output index, e.g. {@code "myOp:1"}. */
    private static final Pattern OUTPUT_INDEX_SUFFIX = Pattern.compile(".*:\\d+");

    /**
     * @deprecated all methods of this class are static; call them directly instead.
     */
    @Deprecated
    public static TFGraphMapper getInstance() {
        return new TFGraphMapper();
    }

    /**
     * Imports a binary-serialized TensorFlow {@link GraphDef} from a file.
     *
     * @param f protobuf file to read; must exist
     * @return the imported graph
     */
    public static SameDiff importGraph(@NonNull File f) {
        if (f == null) {
            throw new NullPointerException("f is marked @NonNull but is null");
        }
        return importGraph(f, null, null);
    }

    /**
     * Imports a binary-serialized TensorFlow {@link GraphDef} from a file.
     *
     * @param f              protobuf file to read; must exist
     * @param importOverride optional node-name to custom-importer map; may be null
     * @param opFilter       optional filter deciding which nodes to skip; may be null
     * @return the imported graph
     * @throws IllegalStateException if the file does not exist
     * @throws RuntimeException      wrapping any {@link IOException} raised while reading
     */
    public static SameDiff importGraph(@NonNull File f, Map<String, TFImportOverride> importOverride, TFOpImportFilter opFilter) {
        if (f == null) {
            throw new NullPointerException("f is marked @NonNull but is null");
        }
        Preconditions.checkState(f.exists(), "File does not exist: %s", f);
        try (InputStream is = new BufferedInputStream(new FileInputStream(f))) {
            return importGraph(is, importOverride, opFilter);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Imports a binary-serialized TensorFlow {@link GraphDef} from a stream. The stream is not closed.
     *
     * @param is stream containing the binary protobuf
     * @return the imported graph
     */
    public static SameDiff importGraph(@NonNull InputStream is) {
        if (is == null) {
            throw new NullPointerException("is is marked @NonNull but is null");
        }
        return importGraph(is, null, null);
    }

    /**
     * Imports a TensorFlow {@link GraphDef} in protobuf text format (UTF-8) from a stream.
     * The stream is not closed.
     *
     * @param is             stream containing the text-format protobuf
     * @param importOverride optional node-name to custom-importer map; may be null
     * @param opFilter       optional filter deciding which nodes to skip; may be null
     * @return the imported graph
     * @throws RuntimeException wrapping any {@link IOException} (including parse errors)
     */
    public static SameDiff importGraphTxt(@NonNull InputStream is, Map<String, TFImportOverride> importOverride, TFOpImportFilter opFilter) {
        if (is == null) {
            throw new NullPointerException("is is marked @NonNull but is null");
        }
        GraphDef tfGraph;
        try {
            GraphDef.Builder builder = GraphDef.newBuilder();
            String content = IOUtils.toString(is, StandardCharsets.UTF_8);
            TextFormat.getParser().merge(content, builder);
            tfGraph = builder.build();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return importGraph(tfGraph, importOverride, opFilter);
    }

    /**
     * Imports a binary-serialized TensorFlow {@link GraphDef} from a stream. The stream is not closed.
     *
     * @param is             stream containing the binary protobuf
     * @param importOverride optional node-name to custom-importer map; may be null
     * @param opFilter       optional filter deciding which nodes to skip; may be null
     * @return the imported graph
     * @throws RuntimeException wrapping any {@link IOException} raised while parsing
     */
    public static SameDiff importGraph(@NonNull InputStream is, Map<String, TFImportOverride> importOverride, TFOpImportFilter opFilter) {
        if (is == null) {
            throw new NullPointerException("is is marked @NonNull but is null");
        }
        GraphDef tfGraph;
        try {
            tfGraph = GraphDef.parseFrom(is);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return importGraph(tfGraph, importOverride, opFilter);
    }

    /**
     * Imports an already-parsed TensorFlow graph.
     *
     * @param tfGraph graph to import
     * @return the imported graph
     */
    public static SameDiff importGraph(@NonNull GraphDef tfGraph) {
        if (tfGraph == null) {
            throw new NullPointerException("tfGraph is marked @NonNull but is null");
        }
        return importGraph(tfGraph, null, null);
    }

    /**
     * Imports a TensorFlow graph into a new {@link SameDiff} instance.
     * <p>
     * Nodes are imported in dependency order. Nodes without inputs, and all Const and Placeholder
     * nodes, are processed first. Every other node is queued once all of its inputs exist in the
     * graph. Merge nodes are the exception: they are queued as soon as a data input is available, to
     * break the cycles that TF while-loops introduce. Their missing inputs are wired up after the
     * main loop.
     *
     * @param tfGraph        graph to import
     * @param importOverride optional node-name to custom-importer map; may be null
     * @param opFilter       optional filter deciding which nodes to skip; may be null
     * @return the imported graph
     * @throws IllegalStateException if an op type is unsupported, an input is missing, or some nodes
     *                               could never be imported
     */
    public static SameDiff importGraph(@NonNull GraphDef tfGraph, Map<String, TFImportOverride> importOverride, TFOpImportFilter opFilter) {
        if (tfGraph == null) {
            throw new NullPointerException("tfGraph is marked @NonNull but is null");
        }
        SameDiff sd = SameDiff.create();

        // FIFO queue of nodes ready to import, plus a set mirroring it for O(1) membership tests
        Deque<NodeDef> availableToAdd = new ArrayDeque<>();
        Set<String> availableToAddSet = new HashSet<>();
        // Nodes that are not yet processed, keyed by node name
        Map<String, NodeDef> remainingNodes = new HashMap<>();
        // Input op name (control marker and output index stripped) -> names of nodes consuming it
        Map<String, Set<String>> nodeInputTo = new HashMap<>();

        for (int i = 0; i < tfGraph.getNodeCount(); i++) {
            NodeDef nd = tfGraph.getNode(i);
            String op = nd.getOp();
            String name = nd.getName();
            int nInputs = nd.getInputCount();
            if ("Const".equals(op) || "Placeholder".equals(op) || nInputs == 0) {
                availableToAdd.add(nd);
                availableToAddSet.add(name);
                continue;
            }
            remainingNodes.put(name, nd);
            for (int in = 0; in < nInputs; in++) {
                String inOpName = stripVarSuffix(stripControl(nd.getInput(in)));
                nodeInputTo.computeIfAbsent(inOpName, k -> new HashSet<>()).add(name);
            }
        }

        // Merge op name -> inputs that did not exist yet when the Merge was imported
        Map<String, List<String>> mergeOpsPostProcess = new HashMap<>();
        // Constant name -> names of the ops it has control dependencies on
        Map<String, List<String>> constControlDeps = new HashMap<>();

        while (!availableToAdd.isEmpty()) {
            NodeDef nd = availableToAdd.remove();
            String name = nd.getName();
            String opName = nd.getOp();
            availableToAddSet.remove(name);
            log.trace("Adding operation to graph: {} (name={})", opName, name);

            boolean skipCase = false;
            if (opFilter != null && opFilter.skipOp(nd, sd, nd.getAttrMap(), tfGraph)) {
                log.debug("Skipping op {} of type {} due to op filter", name, opName);
                skipCase = true;
            } else if (importOverride != null && importOverride.containsKey(name)) {
                importWithOverride(sd, nd, importOverride.get(name), tfGraph);
            } else if ("Const".equals(opName)) {
                importConstant(sd, nd, constControlDeps);
            } else if ("Placeholder".equals(opName) || "PlaceholderWithDefault".equals(opName)) {
                importPlaceholder(sd, nd);
            } else {
                importOp(sd, nd, tfGraph, mergeOpsPostProcess);
            }

            // Queue every consumer of this node whose inputs are now satisfied
            Set<String> consumers = nodeInputTo.get(name);
            if (consumers != null) {
                for (String nextOp : consumers) {
                    NodeDef nextOpDef = remainingNodes.get(nextOp);
                    // Absent from remainingNodes => already processed (imported, overridden or skipped)
                    if (nextOpDef == null || availableToAddSet.contains(nextOp)) {
                        continue;
                    }
                    if (isReadyToImport(sd, nextOpDef, skipCase)) {
                        availableToAdd.add(nextOpDef);
                        availableToAddSet.add(nextOp);
                        log.trace("Added to processing queue: {} (name={})", nextOpDef.getOp(), nextOp);
                    }
                }
            }
            remainingNodes.remove(name);
        }

        Preconditions.checkState(remainingNodes.isEmpty(), "%s Unprocessed nodes: %s", remainingNodes.size(), remainingNodes.keySet());
        applyConstantControlDependencies(sd, constControlDeps);
        applyDeferredMergeInputs(sd, mergeOpsPostProcess);
        return sd;
    }

    /**
     * Decides whether {@code nodeDef} can be queued for import.
     * <p>
     * A node is ready when every input already exists in {@code sd}. A Merge node is also ready when
     * at least one data (non-control) input was found before the first missing input.
     *
     * @param ignoreMissingInputs when true, missing inputs are ignored. This is used when the node that
     *                            triggered the check was skipped by the op filter.
     */
    private static boolean isReadyToImport(SameDiff sd, NodeDef nodeDef, boolean ignoreMissingInputs) {
        int nonControlSeenCount = 0;
        for (int i = 0; i < nodeDef.getInputCount(); i++) {
            String input = nodeDef.getInput(i);
            if (!ignoreMissingInputs && !sd.hasVariable(stripControl(input))) {
                return nonControlSeenCount > 0 && "Merge".equals(nodeDef.getOp());
            }
            if (!isControlDep(input)) {
                nonControlSeenCount++;
            }
        }
        return true;
    }

    /**
     * Imports a node through a user-supplied {@link TFImportOverride}. All of the node's inputs are
     * collected first, and the override is then invoked exactly once.
     */
    private static void importWithOverride(SameDiff sd, NodeDef nd, TFImportOverride override, GraphDef tfGraph) {
        log.debug("Importing op {} using override {}", nd.getOp(), override);
        int nIn = nd.getInputCount();
        List<SDVariable> inputs = new ArrayList<>(nIn);
        List<SDVariable> controlDeps = null;
        for (int i = 0; i < nIn; i++) {
            String inName = nd.getInput(i);
            SDVariable v = sd.getVariable(stripControl(inName));
            if (isControlDep(inName)) {
                if (controlDeps == null) {
                    controlDeps = new ArrayList<>();
                }
                controlDeps.add(v);
            } else {
                inputs.add(v);
            }
        }
        override.initFromTensorFlow(inputs, controlDeps, nd, sd, nd.getAttrMap(), tfGraph);
    }

    /**
     * Imports a Const node as a SameDiff constant. Its control-dependency inputs are recorded in
     * {@code constControlDeps}, because the ops they refer to may not be imported yet.
     *
     * @throws IllegalStateException if the constant has a non-control input
     */
    private static void importConstant(SameDiff sd, NodeDef nd, Map<String, List<String>> constControlDeps) {
        String name = nd.getName();
        INDArray arr = mapTensorProto(nd.getAttrOrThrow("value").getTensor());
        sd.constant(name, arr);

        int inputCount = nd.getInputCount();
        if (inputCount == 0) {
            return;
        }
        List<String> deps = new ArrayList<>(inputCount);
        for (int i = 0; i < inputCount; i++) {
            String input = nd.getInput(i);
            if (!isControlDep(input)) {
                throw new IllegalStateException("Found non-control dependency input \"" + input + "\" for constant \"" + name + "\"");
            }
            deps.add(stripControl(input));
        }
        constControlDeps.put(name, deps);
    }

    /**
     * Imports a Placeholder or PlaceholderWithDefault node as a SameDiff placeholder. A missing
     * "shape" attribute, or one flagged {@code unknown_rank}, gives a {@code null} (unknown) shape.
     *
     * @throws IllegalStateException if the node has no "dtype" attribute
     */
    private static void importPlaceholder(SameDiff sd, NodeDef nd) {
        Map<String, AttrValue> attrMap = nd.getAttrMap();
        long[] shape = null;
        AttrValue shapeAttr = attrMap.get("shape");
        if (shapeAttr != null && !shapeAttr.getShape().getUnknownRank()) {
            shape = shapeFromShapeProto(shapeAttr.getShape());
        }
        AttrValue dtypeAttr = attrMap.get("dtype");
        Preconditions.checkState(dtypeAttr != null, "Placeholder \"%s\" has no \"dtype\" attribute", nd.getName());
        sd.placeHolder(nd.getName(), convertType(dtypeAttr.getType()), shape);
    }

    /**
     * Imports a regular op. The steps are:
     * <ol>
     *     <li>instantiate the mapped {@link DifferentialFunction} and wire up its inputs and control
     *     dependencies;</li>
     *     <li>let it initialize from the TF attributes;</li>
     *     <li>create its output variables, named {@code name}, {@code name:1}, ...</li>
     * </ol>
     * Missing inputs of Merge ops are recorded in {@code mergeOpsPostProcess} and wired up later.
     */
    private static void importOp(SameDiff sd, NodeDef nd, GraphDef tfGraph, Map<String, List<String>> mergeOpsPostProcess) {
        String name = nd.getName();
        String opName = nd.getOp();
        int nIn = nd.getInputCount();

        DifferentialFunction dfInstance = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
        Preconditions.checkState(dfInstance != null, "Could not find class for TF Ops: %s", opName);
        DifferentialFunction df;
        try {
            df = dfInstance.getClass().getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Could not instantiate " + dfInstance.getClass().getName() + " for TF op " + opName, e);
        }
        df.setSameDiff(sd);
        df.setOwnName(name);

        List<String> inNames = new ArrayList<>(nIn);
        List<String> controlDeps = null;
        for (int i = 0; i < nIn; i++) {
            String origInName = nd.getInput(i);
            String inName = stripControl(origInName);
            boolean isControlDep = isControlDep(origInName);
            if (isControlDep) {
                if (controlDeps == null) {
                    controlDeps = new ArrayList<>();
                }
                controlDeps.add(inName);
            } else {
                inNames.add(inName);
            }

            Variable v = sd.getVariables().get(inName);
            if (v == null) {
                if (df instanceof Merge) {
                    // Loop back-edge: this input is imported later, so defer the bookkeeping
                    mergeOpsPostProcess.computeIfAbsent(df.getOwnName(), k -> new ArrayList<>()).add(inName);
                    continue;
                }
                throw new IllegalStateException("Input \"" + inName + "\" of op \"" + name + "\" (" + opName + ") has not been imported");
            }

            if (isControlDep) {
                if (v.getControlDepsForOp() == null) {
                    v.setControlDepsForOp(new ArrayList<>());
                }
                if (!v.getControlDepsForOp().contains(name)) {
                    v.getControlDepsForOp().add(name);
                }
            } else {
                if (v.getInputsForOp() == null) {
                    v.setInputsForOp(new ArrayList<>());
                }
                if (!v.getInputsForOp().contains(name)) {
                    v.getInputsForOp().add(name);
                }
            }
        }

        SameDiffOp op = SameDiffOp.builder().name(name).op(df).inputsToOp(inNames).controlDeps(controlDeps).build();
        sd.getOps().put(name, op);

        df.initFromTensorFlow(nd, sd, nd.getAttrMap(), tfGraph);

        // initFromTensorFlow may rewrite the op's inputs, so re-read them before computing dtypes
        List<String> newInNames = sd.getOps().get(name).getInputsToOp();
        List<DataType> newInDtypes = new ArrayList<>(newInNames.size());
        if (df instanceof Merge) {
            // A Merge may have a not-yet-imported input; assume it shares the other input's dtype.
            // NOTE(review): only the first two inputs are considered, as in the original import logic.
            SDVariable v1 = sd.getVariable(newInNames.get(0));
            SDVariable v2 = sd.getVariable(newInNames.get(1));
            newInDtypes.add(v1 == null ? v2.dataType() : v1.dataType());
            newInDtypes.add(v2 == null ? v1.dataType() : v2.dataType());
        } else {
            for (String s : newInNames) {
                newInDtypes.add(sd.getVariable(s).dataType());
            }
        }

        List<DataType> outDTypes = df.calculateOutputDataTypes(newInDtypes);
        List<String> outNames = new ArrayList<>(outDTypes.size());
        for (int i = 0; i < outDTypes.size(); i++) {
            String varName = i == 0 ? name : name + ":" + i;
            SDVariable outSDVar = sd.var(varName, VariableType.ARRAY, null, outDTypes.get(i), (long[]) null);
            Variable outVar = Variable.builder().name(varName).variable(outSDVar).inputsForOp(null).controlDepsForOp(null).controlDepsForVar(null).outputOfOp(name).build();
            sd.getVariables().put(varName, outVar);
            outNames.add(varName);
            log.trace("Added variable to graph: {} (output of op {})", varName, name);
        }
        sd.getOps().get(name).setOutputsOfOp(outNames);
        log.trace("Imported op: {} (name={})", opName, name);
    }

    /**
     * Records the control dependencies of constants. Each constant gets its list of control-dep
     * ops, and each such op gets the constant in its "control dep for" list.
     *
     * @throws IllegalStateException if a control dependency does not refer to an imported op
     */
    private static void applyConstantControlDependencies(SameDiff sd, Map<String, List<String>> constControlDeps) {
        for (Map.Entry<String, List<String>> e : constControlDeps.entrySet()) {
            String constName = e.getKey();
            List<String> cdOpNames = e.getValue();
            sd.getVariables().get(constName).setControlDeps(cdOpNames);
            for (String cdOpName : cdOpNames) {
                SameDiffOp sdo = sd.getOps().get(cdOpName);
                Preconditions.checkState(sdo != null, "Control dependency \"%s\" of constant \"%s\" is not an imported op", cdOpName, constName);
                if (sdo.getControlDepFor() == null) {
                    sdo.setControlDepFor(new ArrayList<>());
                }
                List<String> controlDepFor = sdo.getControlDepFor();
                if (!controlDepFor.contains(constName)) {
                    controlDepFor.add(constName);
                }
            }
        }
    }

    /**
     * Registers each Merge op as a consumer of the inputs that did not exist yet when the Merge was
     * imported.
     *
     * @throws IllegalStateException if a deferred input was never imported
     */
    private static void applyDeferredMergeInputs(SameDiff sd, Map<String, List<String>> mergeOpsPostProcess) {
        for (Map.Entry<String, List<String>> e : mergeOpsPostProcess.entrySet()) {
            String mergeName = e.getKey();
            for (String inName : e.getValue()) {
                Variable v = sd.getVariables().get(inName);
                Preconditions.checkState(v != null, "Input \"%s\" of Merge op \"%s\" was never imported", inName, mergeName);
                if (v.getInputsForOp() == null) {
                    v.setInputsForOp(new ArrayList<>());
                }
                if (!v.getInputsForOp().contains(mergeName)) {
                    v.getInputsForOp().add(mergeName);
                }
            }
        }
    }

    /** Converts a TF shape into an array of dimension sizes. Unknown dimensions keep TF's value (-1). */
    private static long[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
        long[] shape = new long[tensorShapeProto.getDimCount()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = tensorShapeProto.getDim(i).getSize();
        }
        return shape;
    }

    /**
     * Maps a TensorFlow data type to the equivalent ND4J {@link DataType}.
     *
     * @param tfType TensorFlow data type
     * @return the matching ND4J type, or {@link DataType#UNKNOWN} if there is no mapping
     */
    public static DataType convertType(org.tensorflow.framework.DataType tfType) {
        switch (tfType) {
            case DT_DOUBLE:
                return DataType.DOUBLE;
            case DT_FLOAT:
                return DataType.FLOAT;
            case DT_HALF:
                return DataType.HALF;
            case DT_BFLOAT16:
                return DataType.BFLOAT16;
            case DT_INT8:
                return DataType.BYTE;
            case DT_INT16:
                return DataType.SHORT;
            case DT_INT32:
                return DataType.INT;
            case DT_INT64:
                return DataType.LONG;
            case DT_UINT8:
                return DataType.UBYTE;
            case DT_STRING:
                return DataType.UTF8;
            case DT_BOOL:
                return DataType.BOOL;
            default:
                return DataType.UNKNOWN;
        }
    }

    /** Returns true if {@code name} is a TF control-dependency input (prefixed with {@code ^}). */
    protected static boolean isControlDep(String name) {
        return name.startsWith("^");
    }

    /** Removes a leading control-dependency marker ({@code ^}) from {@code name}, if present. */
    protected static String stripControl(String name) {
        return name.startsWith("^") ? name.substring(1) : name;
    }

    /**
     * Removes a trailing output-index suffix ({@code ":<digits>"}) from a tensor name.
     * For example {@code "op:1"} becomes {@code "op"}; names without a suffix are returned unchanged.
     */
    protected static String stripVarSuffix(String varName) {
        if (OUTPUT_INDEX_SUFFIX.matcher(varName).matches()) {
            return varName.substring(0, varName.lastIndexOf(':'));
        }
        return varName;
    }

    /**
     * Converts the "value" tensor attribute of a node to an {@link INDArray}.
     *
     * @return the array, or {@code null} if the node has no "value" attribute
     */
    public static INDArray getNDArrayFromTensor(NodeDef node) {
        if (!node.getAttrMap().containsKey("value")) {
            return null;
        }
        return mapTensorProto(node.getAttrOrThrow("value").getTensor());
    }

    /**
     * Converts a TF tensor protobuf to an {@link INDArray}.
     *
     * @throws RuntimeException if the tensor's data type is not supported
     */
    public static INDArray mapTensorProto(TensorProto tfTensor) {
        TFTensorMapper<?, ?> m = TFTensorMappers.newMapper(tfTensor);
        if (m == null) {
            throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype());
        }
        return m.toNDArray();
    }

    /**
     * Linear search for a node by name.
     *
     * @return the first node with the given name, or {@code null} if none exists
     */
    @Deprecated
    public static NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
        for (NodeDef node : graph.getNodeList()) {
            if (node.getName().equals(name)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Returns the "value" tensor of {@code nodeDef}. The {@code graph} argument is unused.
     *
     * @return the array, or {@code null} if {@code nodeDef} is null or has no "value" attribute
     */
    @Deprecated
    public static INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
        if (nodeDef == null) {
            return null;
        }
        return getNDArrayFromTensor(nodeDef);
    }

    /**
     * Populates the fields of {@code on} from a TF node, using the op's property mappings for
     * {@code mappedTfName}.
     * <p>
     * Mappings with a TF attribute name are set from {@code attributesForNode}. Mappings with a TF
     * input position are set from that input's constant value, or from the array currently held by
     * SameDiff under that name. Properties handled by an {@link AttributeAdapter} are applied after
     * the plain ones.
     */
    @Deprecated
    public static void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
        Map<String, Map<String, PropertyMapping>> properties = on.mappingsForFunction();
        Map<String, PropertyMapping> tfProperties = properties == null ? null : properties.get(mappedTfName);
        if (tfProperties == null) {
            return;
        }
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Map<String, Map<String, AttributeAdapter>> attributeAdapters = on.attributeAdaptersForFunction();
        Map<String, AttributeAdapter> adapters = attributeAdapters == null ? null : attributeAdapters.get(mappedTfName);

        Map<String, PropertyMapping> orderedProperties;
        if (adapters == null) {
            orderedProperties = tfProperties;
        } else {
            // Plain properties first, then adapter-backed ones (adapters may depend on plain values)
            orderedProperties = new LinkedHashMap<>();
            for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
                if (!adapters.containsKey(e.getKey())) {
                    orderedProperties.put(e.getKey(), e.getValue());
                }
            }
            for (Map.Entry<String, PropertyMapping> e : tfProperties.entrySet()) {
                if (adapters.containsKey(e.getKey())) {
                    orderedProperties.put(e.getKey(), e.getValue());
                }
            }
        }

        for (Map.Entry<String, PropertyMapping> entry : orderedProperties.entrySet()) {
            PropertyMapping mapping = entry.getValue();
            Field currentField = fields.get(entry.getKey());
            AttributeAdapter adapter = adapters == null ? null : adapters.get(entry.getKey());

            String tfAttrName = mapping.getTfAttrName();
            if (tfAttrName != null) {
                if (currentField != null && attributesForNode.containsKey(tfAttrName)) {
                    applyAttribute(attributesForNode.get(tfAttrName), currentField, adapter, on);
                }
                continue;
            }

            Integer tfInputPosition = mapping.getTfInputPosition();
            if (tfInputPosition == null) {
                continue;
            }
            int position = tfInputPosition;
            if (position < 0) {
                // Negative positions count from the end of the input list
                position += node.getInputCount();
            }
            String inputName = node.getInput(position);
            NodeDef inputFromNode = getNodeWithNameFromGraph(graph, inputName);
            INDArray tensor = inputFromNode == null ? null : getNDArrayFromTensor(inputFromNode);
            if (tensor == null) {
                tensor = on.getSameDiff().getArrForVarName(getNodeName(inputName));
            }
            if (tensor == null) {
                continue;
            }
            if (adapter != null) {
                adapter.mapAttributeFor(tensor, currentField, on);
            } else if (currentField != null) {
                setFieldFromTensor(tensor, currentField, on);
            }
        }
    }

    /**
     * Applies a single TF attribute value to {@code field}. Boolean and dtype attributes are only
     * applied through an adapter; float, func and placeholder attributes are not supported.
     */
    private static void applyAttribute(AttrValue attr, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        switch (attr.getValueCase()) {
            case B:
                if (adapter != null) {
                    adapter.mapAttributeFor(attr.getB(), field, on);
                }
                break;
            case S:
                setOrAdapt(attr.getS().toStringUtf8(), field, adapter, on);
                break;
            case I:
                setOrAdapt((int) attr.getI(), field, adapter, on);
                break;
            case SHAPE: {
                List<TensorShapeProto.Dim> dims = attr.getShape().getDimList();
                int[] dimsToSet = new int[dims.size()];
                for (int i = 0; i < dimsToSet.length; i++) {
                    dimsToSet[i] = (int) dims.get(i).getSize();
                }
                setOrAdapt(dimsToSet, field, adapter, on);
                break;
            }
            case LIST: {
                AttrValue.ListValue list = attr.getList();
                if (!list.getIList().isEmpty()) {
                    setOrAdapt(Ints.toArray(list.getIList()), field, adapter, on);
                } else if (list.getBList().isEmpty() && !list.getFList().isEmpty()) {
                    setOrAdapt(Floats.toArray(list.getFList()), field, adapter, on);
                }
                break;
            }
            case TENSOR:
                setOrAdapt(mapTensorProto(attr.getTensor()), field, adapter, on);
                break;
            case TYPE:
                if (adapter != null) {
                    adapter.mapAttributeFor(attr.getType(), field, on);
                }
                break;
            default:
                // F, FUNC, PLACEHOLDER, VALUE_NOT_SET: not mapped
                break;
        }
    }

    /** Passes {@code value} to {@code adapter} if there is one, otherwise assigns it directly to {@code field}. */
    private static void setOrAdapt(Object value, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        if (adapter != null) {
            adapter.mapAttributeFor(value, field, on);
        } else {
            on.setValueFor(field, value);
        }
    }

    /**
     * Assigns {@code tensor} to {@code field}, converted to the field's type. Supported types are
     * {@code int[]}, {@code double[]}, {@code float[]} and {@link INDArray}. For scalar fields
     * ({@code int}, {@code double}, {@code float}) the first element is used. Other field types are
     * left untouched.
     */
    private static void setFieldFromTensor(INDArray tensor, Field field, DifferentialFunction on) {
        Class<?> type = field.getType();
        if (type == int[].class) {
            on.setValueFor(field, tensor.data().asInt());
        } else if (type == double[].class) {
            on.setValueFor(field, tensor.data().asDouble());
        } else if (type == float[].class) {
            on.setValueFor(field, tensor.data().asFloat());
        } else if (type == INDArray.class) {
            on.setValueFor(field, tensor);
        } else if (type == int.class) {
            on.setValueFor(field, tensor.getInt(0));
        } else if (type == double.class) {
            on.setValueFor(field, tensor.getDouble(0L));
        } else if (type == float.class) {
            on.setValueFor(field, tensor.getFloat(0L));
        }
    }

    /**
     * Normalizes a TF input name to a node/variable name. It removes a leading {@code ^}, a trailing
     * {@code /read} and a trailing {@code :0}. Only suffixes are removed; inner occurrences are kept.
     */
    @Deprecated
    public static String getNodeName(String name) {
        String ret = stripControl(name);
        if (ret.endsWith("/read")) {
            ret = ret.substring(0, ret.length() - "/read".length());
        }
        if (ret.endsWith(":0")) {
            ret = ret.substring(0, ret.length() - 2);
        }
        return ret;
    }

    /** Returns true for TF variable nodes (op {@code VariableV*}) and constants (op {@code const}, any case). */
    public static boolean isVariableNode(NodeDef nodeDef) {
        String op = nodeDef.getOp();
        return op.startsWith("VariableV") || op.equalsIgnoreCase("const");
    }

    /** Returns true for Placeholder-family nodes (op name starting with {@code Placeholder}). */
    public static boolean isPlaceHolder(NodeDef nodeDef) {
        return nodeDef.getOp().startsWith("Placeholder");
    }
}

