GRAPH_TYPE - the proto type for the graph
NODE_TYPE - the proto type for the node
ATTR_TYPE - the proto type for the attribute
TENSOR_TYPE - the proto type for the tensor
public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE>
SameDiff instances| Modifier and Type | Method and Description |
|---|---|
boolean |
alreadySeen(NODE_TYPE nodeType) |
DataBuffer.Type |
dataTypeForTensor(TENSOR_TYPE tensorType) |
void |
dumpBinaryProtoAsText(File inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
void |
dumpBinaryProtoAsText(InputStream inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
INDArray |
getArrayFrom(NODE_TYPE nodeType,
GRAPH_TYPE graph) |
Map<String,ATTR_TYPE> |
getAttrMap(NODE_TYPE nodeType)
Get the attribute
map for given node
|
String |
getAttrValueFromNode(NODE_TYPE nodeType,
String key) |
String |
getInputFromNode(NODE_TYPE node,
int index)
Get the input node for the given node
|
DifferentialFunction |
getMappedOp(String name)
Get the mapped op for the given op name
(a TensorFlow or ONNX op name),
relative to the type of node being mapped.
|
String |
getName(NODE_TYPE nodeType)
Get the name of the node
|
INDArray |
getNDArrayFromTensor(String tensorName,
TENSOR_TYPE tensorType,
GRAPH_TYPE graph) |
com.github.os72.protobuf351.Message.Builder |
getNewGraphBuilder()
Returns a graph builder for initial definition and parsing.
|
List<NODE_TYPE> |
getNodeList(GRAPH_TYPE graphType) |
NODE_TYPE |
getNodeWithNameFromGraph(GRAPH_TYPE graph,
String name)
Get the node from the graph
|
String |
getOpType(NODE_TYPE nodeType) |
int[] |
getShape(NODE_TYPE nodeType) |
int[] |
getShapeFromAttr(ATTR_TYPE attr)
Get the shape of the attribute value
|
int[] |
getShapeFromAttribute(ATTR_TYPE attrType) |
int[] |
getShapeFromTensor(TENSOR_TYPE tensorType)
Get the shape for the given tensor type
|
String |
getTargetMappingForOp(DifferentialFunction function,
NODE_TYPE node)
Get the target mapping key (usually based on the node name)
for the given function
|
boolean |
hasShape(NODE_TYPE nodeType) |
SameDiff |
importGraph(File graphFile)
Import a graph as a SameDiff instance
from the given file
|
SameDiff |
importGraph(GRAPH_TYPE tfGraph)
This method converts the given TF graph into a SameDiff instance
|
SameDiff |
importGraph(InputStream graphFile)
Import a graph as a SameDiff instance
from the given input stream
|
boolean |
isOpIgnoreException(NODE_TYPE node)
Returns true if this node is a special case
(maybe because of name or other scenarios)
that should override
opsToIgnore()
in certain circumstances |
boolean |
isPlaceHolder(TENSOR_TYPE nodeType)
Returns true if the given node is a place holder type
(think a yet-to-be-determined shape)
|
boolean |
isPlaceHolderNode(TENSOR_TYPE node)
Returns true if the given node is a place holder
|
boolean |
isVariableNode(NODE_TYPE nodeType) |
void |
mapNodeType(NODE_TYPE tfNode,
ImportState<GRAPH_TYPE,TENSOR_TYPE> importState)
Map a node into the import state covering
the
SameDiff instance |
void |
mapProperties(DifferentialFunction on,
NODE_TYPE node,
GRAPH_TYPE graph,
SameDiff sameDiff,
Map<String,Map<String,PropertyMapping>> propertyMappings) |
void |
mapProperty(String name,
DifferentialFunction on,
NODE_TYPE node,
GRAPH_TYPE graph,
SameDiff sameDiff,
Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction) |
Map<String,NODE_TYPE> |
nameIndexForGraph(GRAPH_TYPE graph) |
Map<String,NODE_TYPE> |
nodesByName(GRAPH_TYPE graph)
Get the nodes sorted by name
from a given graph
|
int |
numInputsFor(NODE_TYPE nodeType)
Get the number of inputs for a node.
|
Set<String> |
opsToIgnore()
Ops to ignore for mapping
|
Op.Type |
opTypeForNode(NODE_TYPE nodeType)
Returns an op type for the given input node
|
GRAPH_TYPE |
parseGraphFrom(byte[] inputStream)
Parse a graph from a byte array
|
GRAPH_TYPE |
parseGraphFrom(InputStream inputStream)
Parse a graph from an input stream
|
boolean |
shouldSkip(NODE_TYPE opType) |
String |
translateToSameDiffName(String name,
NODE_TYPE node) |
boolean |
validTensorDataType(TENSOR_TYPE tensorType)
Whether the data type for the tensor is valid
for creating an
INDArray |
Map<String,TENSOR_TYPE> |
variablesForGraph(GRAPH_TYPE graphType)
Get the variables for the given graph
|
boolean isOpIgnoreException(NODE_TYPE node)
opsToIgnore()
in certain circumstances
node - the node to check

Map<String,NODE_TYPE> nodesByName(GRAPH_TYPE graph)
graph - the graph to get the nodes for

String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node)
function - the function
node - the node to derive the target mapping from

void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappings)
on -
node -
graph -
sameDiff -
propertyMappings -

void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction)
name -
on -
node -
graph -
sameDiff -
propertyMappingsForFunction -

NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph, String name)
graph - the graph to get the node from
name - the name of the node to get from the graph

boolean isPlaceHolderNode(TENSOR_TYPE node)
node - the node to check

void dumpBinaryProtoAsText(File inputFile, File outputFile)
inputFile -
outputFile -

void dumpBinaryProtoAsText(InputStream inputFile, File outputFile)
inputFile -
outputFile -

DifferentialFunction getMappedOp(String name)
name - the TensorFlow or ONNX name
DifferentialFunctionClassHolder

Map<String,TENSOR_TYPE> variablesForGraph(GRAPH_TYPE graphType)
graphType - the graph to get the variables for

String translateToSameDiffName(String name, NODE_TYPE node)
name -
node -

Map<String,NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph)
graph -

Op.Type opTypeForNode(NODE_TYPE nodeType)
nodeType - the node to use

com.github.os72.protobuf351.Message.Builder getNewGraphBuilder()
GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException
inputStream - the serialized graph bytes to load from
Throws: IOException

GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException
inputStream - the input stream to load from
Throws: IOException

void mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState)
SameDiff instance
tfNode - the node to map
importState - the current import state

DataBuffer.Type dataTypeForTensor(TENSOR_TYPE tensorType)
tensorType -

String getAttrValueFromNode(NODE_TYPE nodeType, String key)
nodeType -
key -

int[] getShapeFromAttribute(ATTR_TYPE attrType)
attrType -

boolean isPlaceHolder(TENSOR_TYPE nodeType)
nodeType -

INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph)
tensorName -
tensorType -
graph -

int[] getShapeFromTensor(TENSOR_TYPE tensorType)
tensorType -

String getInputFromNode(NODE_TYPE node, int index)
node - the node
index - the index

int numInputsFor(NODE_TYPE nodeType)
nodeType - the node to get the number of inputs for

boolean validTensorDataType(TENSOR_TYPE tensorType)
INDArray
tensorType - the tensor proto to test

int[] getShapeFromAttr(ATTR_TYPE attr)
attr - the attribute value

Map<String,ATTR_TYPE> getAttrMap(NODE_TYPE nodeType)
nodeType - the node

String getName(NODE_TYPE nodeType)
nodeType - the node to get the name for

boolean alreadySeen(NODE_TYPE nodeType)
nodeType -

boolean isVariableNode(NODE_TYPE nodeType)
nodeType -

boolean shouldSkip(NODE_TYPE opType)
opType -

boolean hasShape(NODE_TYPE nodeType)
nodeType -

int[] getShape(NODE_TYPE nodeType)
nodeType -

INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph)
nodeType -
graph -

List<NODE_TYPE> getNodeList(GRAPH_TYPE graphType)
graphType -

SameDiff importGraph(InputStream graphFile)
graphFile -

SameDiff importGraph(File graphFile)
graphFile -

SameDiff importGraph(GRAPH_TYPE tfGraph)
tfGraph -

Copyright © 2018. All rights reserved.