anijain2305 commented on a change in pull request #4944: [Relay, Torch] Clean up and refactor PyTorch frontend
URL: https://github.com/apache/incubator-tvm/pull/4944#discussion_r384674817
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -718,293 +732,241 @@ def _convert_elemwise_input(data, input_type):
"aten::sqrt" : _sqrt()
}
-# Internal graph for parsing
-class Graph(object):
- """ A helper class for parsing PyTorch model to Relay graph."""
+def run_jit_passes(graph):
+ """ The inline pass is necessary to unwrap prim::CallMethod """
+ import torch
+ if version.parse(torch.__version__) >= version.parse("1.4.0"):
+ torch._C._jit_pass_inline(graph)
- def __init__(self, script_module, input_shapes):
- self._script_module = script_module
- self._graph = script_module.graph.copy()
+def is_int_seq(seq):
+ return len(seq) > 0 and all([isinstance(i, int) for i in seq])
- # TODO: Temporary fix to remove prim::CallMethod node introduced in PT 1.4
- import torch
- from packaging import version
- if version.parse(torch.__version__) >= version.parse("1.4.0"):
- torch._C._jit_pass_inline(self._graph)
-
- self._inputs_r = {}
- self._params = {}
- self._param_tensors = {}
- self._consts = {}
- self._ops = {}
- self._op_inputs_r = {}
- self._op_inputs_types = {}
- self._input_shapes = input_shapes if input_shapes else {}
- self._parsed_node_names = {}
-
- def from_pytorch(self):
- """ Construct relay nodes from PyTorch graph
-
- Currently only supports traced PyTorch format which means no control flow.
- User must perform torch.jit.trace on a model and pass this in.
- Future support should include support scripted models (torch.jit.script) which
- preserves control flow.
-
- Returns
- -------
- mod : tvm.relay.Module
- The module that optimizations will be performed on.
-
- params : dict of str to tvm.runtime
- Dict of converted parameters stored in tvm.runtime format
- """
- # Check for missing ops
- missing_operators = self._parse_import_prerequisites()
-
- if missing_operators:
- raise tvm.error.OpNotImplemented( \
- "The following operators are not implemented: {}".format(missing_operators))
-
- # Translate PyTorch graph to by decorating Graph with state dict and inputs into each op
- self._parse_inputs()
- self._parse_params()
- self._parse_ops()
-
- outputs = []
- nid = 0
-
- for op_name, op_node in self._ops.items():
- if op_node.kind() == "prim::ListConstruct":
- if any(inp.debugName() in self._parsed_node_names.keys() \
- for inp in op_node.inputs()):
- list_constr = []
- for i in op_node.inputs():
- if i.debugName() in self._parsed_node_names.keys():
- list_constr.append( \
- outputs[self._parsed_node_names[i.debugName()]])
- elif i.node().kind() == "prim::Constant":
- list_constr.append(int(self._consts[i.debugName()]))
- elif i.debugName() in self._inputs_r.keys():
- list_constr.append(int(self._inputs_r[i.debugName()]))
-
- # Unwrap for tensors
- if len(list_constr) == 1:
- list_constr = list_constr[0]
-
- outputs.append(list_constr)
- self._parsed_node_names[op_name] = nid
- nid = nid+1
- elif op_node.kind() != "prim::Constant":
- for i in op_node.inputs():
- if i.debugName() in self._parsed_node_names.keys():
- for cnt in range(0, len(self._op_inputs_r[op_name])):
- if isinstance(self._op_inputs_r[op_name][cnt], str):
- if "call/var" in self._op_inputs_r[op_name][cnt]:
- self._op_inputs_r[op_name][cnt] = \
- outputs[self._parsed_node_names[i.debugName()]]
- break
-
- call = _convert_map[op_node.kind()](self._op_inputs_r[op_name],
- self._op_inputs_types[op_name])
-
- outputs.append(call)
- self._parsed_node_names[op_name] = nid
- nid = nid+1
-
- func = tvm.relay.Function(_analysis.free_vars(outputs[-1]), outputs[-1])
-
- param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()}
-
- return _module.IRModule.from_expr(func), param
-
- def _parse_inputs(self):
- """ Map inputs to parser and inputs to graph. """
- # Get names and objects of inputs for IR
- ir_inputs = [i for i in self._graph.inputs()]
-
- # Create corresponding shape and add to input
- for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]):
- input_shape = self._input_shapes[input_name]
- ir_input.setDebugName(input_name)
-
- ir_dtype = _convert_data_type(ir_input.type().scalarType().lower())
- self._inputs_r[input_name] = _expr.var(input_name,
- shape=self._input_shapes[input_name],
- dtype=ir_dtype)
-
- # Add self (first input of a PyTorch graph) to inputs, the value doesn't matter here
- input_name = ir_inputs[0].debugName()
- self._inputs_r[input_name] = "self"
-
- def _parse_params(self):
- """ Map state dictionary values to corresponding prim::GetAttr op node. """
- # Grab weights, biases, etc. from graph
- state_dict = self._script_module.state_dict()
- param_names = []
- for key, value in state_dict.items():
- param_str = str(key)
- param_name = param_str.split(".")[-1]
- param_names.append(param_name)
-
- # Get names of all inputs
- input_names = [i for i in self._inputs_r.keys()]
-
- # Iterate through graph for getAttr nodes and match full state_dict name to nodes
- node_weight_map = {}
- for node in self._graph.nodes():
- if node.kind() == "prim::GetAttr":
-
- attribute_names = node.attributeNames()
- assert len(attribute_names) == 1
- node_getattr_name = node.s(attribute_names[0])
- node_arg = node.input().debugName()
-
- if node.outputsSize() == 1:
- node_name = node.output().debugName()
- else:
- node_name = [output.debugName() for output in node.outputs()][0]
-
- if node_arg in input_names:
- node_weight_map[node_name] = node_getattr_name
- else:
- previous_map = node_weight_map[node_arg[:]]
- node_weight_map[node_name] = previous_map+"."+node_getattr_name
-
- if node_getattr_name in param_names:
-
- value = state_dict[node_weight_map[node_name]]
- tensor = tvm.nd.array(value.cpu().numpy())
- shape = tensor.shape
- self._param_tensors[node_name] = tensor
-
- self._params[node_name] = _expr.var(node_name,
- shape=shape,
- dtype=_convert_data_type(str(value.dtype)))
-
- def _parse_ops(self):
- """ Iterate through nodes and decorate graph with constants, operators,
- and the inputs to each operator. """
- # Traverse nodes and add to graph
- for node in self._graph.nodes():
-
- if node.outputsSize() == 1:
- node_name = node.output().debugName()
- else:
- node_name = [output.debugName() for output in node.outputs()][0]
-
- if node.kind() == "prim::Constant":
- if node.hasAttributes():
- attribute_names = node.attributeNames()
- attr_name = attribute_names[0]
- ty = node.output().type().kind()
-
- if ty in ["IntType", "BoolType"]:
- self._consts[node_name] = node.i(attr_name)
- elif ty in ["FloatType", "LongType"]:
- self._consts[node_name] = node.f(attr_name)
- elif ty in ["TensorType", "CompleteTensorType"]:
- self._consts[node_name] = node.output().toIValue()
- else:
- self._consts[node_name] = "0"
- else:
- self._consts[node_name] = "0"
- elif node.kind() == "prim::ListConstruct":
- list_shape = []
- for input_node in node.inputs():
- if input_node.debugName() in self._inputs_r.keys():
- c = self._inputs_r[input_node.debugName()]
- assert isinstance(c, int)
- list_shape.append(c)
- elif input_node.debugName() in self._consts.keys():
- c = self._consts[input_node.debugName()]
- assert isinstance(c, int)
- list_shape.append(c)
- self._inputs_r[node_name] = _expr.var(node_name, shape=list_shape)
-
- if node.kind() != "prim::GetAttr":
- self._add_op(node_name, node)
-
- # Graph Helper Functions
-
- def _add_op(self, node_id, op_node):
- """ Add an operator and its operators inputs to the graph and insert placeholders
- where an input is a call node.
-
- Parameters
- ----------
- node_id : string
- The ID of the op node
-
- op_node : PyTorch Node object
- The full Node object for the op node
-
- """
- self._ops[(node_id)] = op_node
- input_list_r = []
- input_list_types = []
- for input_value in op_node.inputs():
-
- inode_id = input_value.debugName()
- inode = input_value.node()
-
- if inode_id in self._inputs_r.keys():
- input_list_r.append(self._inputs_r[inode_id])
- elif inode_id in self._params.keys():
- input_list_r.append(self._params[inode_id])
- elif inode.kind() == "prim::Constant":
- input_list_r.append(self._consts[inode_id])
+
+def get_tensor_and_var(torch_tensor, name):
+ tensor = tvm.nd.array(torch_tensor.cpu().numpy())
+ var = _expr.var(name, shape=tensor.shape)
+ return tensor, var
+
+
+def get_output_name(node):
+ assert node.outputsSize() == 1
+ return node.output().debugName()
+
+
+def get_output_names(node):
+ return [output.debugName() for output in node.outputs()]
+
+
+def get_input_names(node_or_graph):
+ return [inp.debugName() for inp in node_or_graph.inputs()]
+
+
+def get_op_inputs(op_node, outputs, output_index_map):
+ input_names = [output_index_map[name]
+ for name in get_input_names(op_node)]
+ return [outputs[name] for name in input_names]
+
+
+def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
+ for output_name, output in name_output_pairs:
+ output_index_map[output_name] = len(outputs)
+ outputs.append(output)
+
+
+def get_all_op_names(graph):
+ nodes = list(graph.nodes())
+ return set(node.kind() for node in nodes)
+
+
+def report_missing_conversion(op_names):
+ """ Check if all ops in an input graph are supported by TVM """
+ known_ops = ["prim::Constant", "prim::GetAttr",
+ "prim::ListConstruct", "prim::ListUnpack",
+ "prim::TupleConstruct", "prim::TupleUnpack"]
+ known_ops += list(_convert_map.keys())
+
+ missing = [op_name for op_name in op_names
+ if op_name not in known_ops]
+
+ if missing:
+ msg = "The following operators are not implemented: {}".format(missing)
+ raise NotImplementedError(msg)
+
+
+def getattr_attr_name(node):
+ attribute_names = node.attributeNames()
+ assert len(attribute_names) == 1
+ attr_name = node.s(attribute_names[0])
+ return attr_name
+
+
+def get_full_attr_name(getattrs):
+ return ".".join([getattr_attr_name(node) for node in getattrs])
+
+
+def get_use_chains(root_node, terminate=lambda _: False):
+ """
+ Track a chain of users of this node forward, returning a list of chains
+ See get_attr_chains below for its usage
+ """
+ def concat_lists(lists):
+ return itertools.chain.from_iterable(lists)
+
+ def inner(current, accum):
+ users = []
+ for output in current.outputs():
+ users += [use.user for use in output.uses()]
+
+ if not users or terminate(users):
+ return [accum]
+
+ return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
+
+ return inner(root_node, [root_node])
+
+
+def get_attr_chains(root_getattr_node):
+ """ Returns chains of attribute access starting from root_getattr_node
+
+ For example, given attribute "block", as in "self.block" when "self" points
+ to the top level torch.nn.Module, it returns lists of attribute "chains",
+ e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
+
+ These sets of attributes form full attribute accessors. For example,
+ "self.block.1", "self.block.2" will return the second and third submodule,
+ and "self.block.0._packed_params" will return the parameters of the first
+ submodule.
+ """
+ def terminate(users):
+ next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
+ return len(next_attrs) == 0
+
+ return get_use_chains(root_getattr_node, terminate)
+
+
+def get_input_types(op_node):
+ """ Returns a torch type for each input nodes """
+ input_list_types = []
+ for input_node in op_node.inputs():
+ in_ty = input_node.type()
+ input_node_kind = in_ty.kind()
+ if input_node_kind == 'TensorType':
+ if in_ty.scalarType() is None:
+ input_list_types.append('float')
else:
- input_list_r.append("call/var."+inode_id)
-
- # If the inputs of a ListConstruct op is a call or var, remove it from inputs
- if op_node.kind() == "prim::ListConstruct":
- if node_id in self._inputs_r.keys():
- self._inputs_r.pop(node_id)
-
- try:
- input_value_kind = input_value.type().kind()
- if input_value_kind in ["TensorType", "CompleteTensorType"]:
- if input_value.type().scalarType() is None:
- input_list_types.append("float")
- else:
- input_list_types.append(input_value.type().scalarType().lower())
- elif input_value_kind == "ListType":
- input_list_types.append(str(input_value.type().getElementType()).lower())
- elif input_value_kind in ["IntType", "FloatType", "BoolType", "StringType",
- "OptionalType"]:
- input_list_types.append(str(input_value.type()).lower())
- else:
- input_list_types.append("UnsupportedType")
- print("UnsupportedType "+str(input_value.type())+" and "+str(input_value_kind))
- except Exception as e:
- print("Internal PyTorch error. Failed to grab type.")
-
- if op_node.kind() in ["aten::ones", "aten::zeros"]:
- node_type = op_node.output().type().scalarType()
- input_list_types[0] = node_type.lower()
-
- self._op_inputs_r[node_id] = input_list_r
- self._op_inputs_types[node_id] = input_list_types
-
- def _parse_import_prerequisites(self):
- """ Calculate the named preconditions from PyTorch graph.
-
- Returns
- -------
- missing_operators : set object
- Set of operator names which don't have their mapping in TVM
- i.e. which are not supported
-
- """
- missing_operators = set()
- for node in self._graph.nodes():
- if not node.kind() in ["prim::Constant", "prim::ListConstruct", "prim::GetAttr"] \
- and not node.kind() in _convert_map:
- missing_operators.add(node.kind())
-
- return missing_operators
+ input_list_types.append(in_ty.scalarType().lower())
+ elif input_node_kind == 'ListType':
+ input_list_types.append(str(in_ty.getElementType()).lower())
+ elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
+ 'StringType', 'OptionalType']:
+ input_list_types.append(str(in_ty).lower())
+ else:
+ input_list_types.append('UnsupportedType')
+
+ if op_node.kind() in ['aten::ones', 'aten::zeros']:
+ node_type = op_node.output().type()
+ scalar_type = node_type.scalarType()
+ if scalar_type:
+ input_list_types[0] = scalar_type.lower()
+
+ return input_list_types
+
+
+def get_constant(node):
+ """ Retrive a constant associated with this prim::Constant node """
Review comment:
Typo: "Retrive" should be "Retrieve".
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services