This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e9cf04e0e4 [Relay][Frontend] Span Filling PyTorch (#14050)
e9cf04e0e4 is described below

commit e9cf04e0e4ec325ca665fe2d262b481985c8cf42
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Mar 1 20:02:00 2023 +0800

    [Relay][Frontend] Span Filling PyTorch (#14050)
    
    * [Relay][Frontend] Span Filling PyTorch
    
    - Construct a debug name for each C graph instruction and use it as the
      span's source name for PyTorch models.
    - Add a function to export the renamed C graph after conversion, so the
      renamed nodes can be used as a reference.
    - Add structural_equal comparisons with and without set_span to the 
existing test cases.
    - Add span test cases for frequent conversions.
    - Add span test case for exporting model parameter.
    
    * [SpanFillingPyTorch]
    
    - Return TupleGetItem expr from TupleWrapper with the span of its Tuple.
    - Handle None symbols in set_span, since some conversions return None.
    - Add a current_op member variable to PyTorchOpConverter to track which op
      is currently being converted in the PyTorch frontend.
    
    * [SpanFillingPyTorch]
    
    - Fix the error where quantized params could not be found after renaming
      the debug names of the C graph.
    
    ---------
    
    Co-authored-by: Joey Tsai <[email protected]>
---
 python/tvm/relay/expr.py                           |   2 +-
 python/tvm/relay/frontend/common.py                |   4 +
 python/tvm/relay/frontend/pytorch.py               | 221 +++++++++++++---
 python/tvm/relay/frontend/qnn_torch.py             |   4 +-
 tests/python/frontend/pytorch/qnn_test.py          |  24 +-
 tests/python/frontend/pytorch/test_forward.py      | 284 ++++++++++++++++++++-
 tests/python/frontend/pytorch/test_fx_quant.py     |   7 +-
 tests/python/frontend/pytorch/test_lstm.py         |   6 +-
 .../frontend/pytorch/test_object_detection.py      |   6 +-
 tests/python/frontend/pytorch/test_rnns.py         |  16 +-
 10 files changed, 522 insertions(+), 52 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index cb14552ac1..d8bca5c4a4 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -605,7 +605,7 @@ class TupleWrapper(object):
     def __getitem__(self, index):
         if index >= len(self):
             raise IndexError("Tuple index out of range")
-        return TupleGetItem(self.tuple_value, index)
+        return TupleGetItem(self.tuple_value, index, 
span=self.tuple_value.span)
 
     def __len__(self):
         return self.size
diff --git a/python/tvm/relay/frontend/common.py 
b/python/tvm/relay/frontend/common.py
index 5d3b0a3345..39e17b27da 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -1169,6 +1169,10 @@ class _SpanFiller(ExprMutator):
             return sym
         elif isinstance(sym, np.ndarray):
             return sym
+        elif not sym:
+            # some op conversion may return None
+            # e.g. op in frontend/pytorch.py: prim::device
+            return sym
 
         raise RuntimeError(f"unsupported type {type(sym)}")
 
diff --git a/python/tvm/relay/frontend/pytorch.py 
b/python/tvm/relay/frontend/pytorch.py
index 3cdfc5cb4e..89464face7 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -22,6 +22,7 @@
 import functools
 import itertools
 import math
+import re
 import sys
 
 import numpy as np
@@ -44,6 +45,7 @@ from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import lstm_cell, try_infer_value, unbind, fold_constant
+from .common import set_span
 from .pytorch_utils import is_version_greater_than, getattr_attr_name
 
 __all__ = ["from_pytorch"]
@@ -135,11 +137,15 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""
 
-    def __init__(self, prelude, default_dtype):
+    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
         self.types = {}  # map from nodes to (Relay) type annotations
+        self.source_map = {}  # map from graph node to its source name
+        self.op_type_dict = {}  # map from op type to its presenting order
+        self.current_op = []  # stack for recording current processing op
+        self.use_parser_friendly_name = use_parser_friendly_name
 
     # this incrementally infers the type, see the comments on the type visitor
     # above.
@@ -344,7 +350,10 @@ class PyTorchOpConverter:
         def _get_value(val, dtype):
             # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                inp = _op.cast(val, dtype)
+                # since "arange" op will fill expr into its attribute
+                # invoke set_span here to prevent expr-rewritten occurrs in 
span-filling stage
+                source_name = self.source_map[self.current_op[-1]]
+                inp = set_span(_op.cast(val, dtype), source_name)
                 ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, 
dtype))
             else:
                 ret = _create_typed_const(val, dtype)
@@ -2405,11 +2414,16 @@ class PyTorchOpConverter:
         iou_threshold = inputs[2]
 
         # TVM NMS assumes score > 0
-        scores = scores - _op.min(scores) + _op.const(1.0)
+        # - since there exists multi-comsumers for "scores", "num_boxes"
+        # - invoke set_span here to prevent expr-rewritten occurrs in 
span-filling stage
+        source_name = self.source_map[self.current_op[-1]]
+        scores = set_span(scores - _op.min(scores) + _op.const(1.0), 
source_name)
 
-        num_boxes = _op.shape_of(scores)
+        num_boxes = set_span(_op.shape_of(scores), source_name)
         # PyTorch NMS doesn't have score_threshold, so no need to run 
get_valid_count
-        indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
+        # - since "arange" op will fill expr into its attribute
+        # - invoke set_span here to prevent expr-rewritten occurrs in 
span-filling stage
+        indices = _op.transform.arange(set_span(_op.squeeze(num_boxes), 
source_name), dtype="int32")
         indices = _op.expand_dims(indices, 0, 1)
 
         # Generate data with shape (1, num_anchors, 5)
@@ -4008,7 +4022,12 @@ class PyTorchOpConverter:
 
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
-        ops = _get_operator_nodes(block.nodes())
+        ops = _get_operator_nodes(
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+        )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
 
@@ -4079,13 +4098,19 @@ class PyTorchOpConverter:
                             actual_shape.append(Any())
                         else:
                             actual_shape.append(dim)
-                    return _expr.var(name, shape=actual_shape, 
dtype=checked_type.dtype)
+                    expr = _expr.var(name, shape=actual_shape, 
dtype=checked_type.dtype)
                 else:
-                    return _expr.var(name, type_annotation=checked_type)
+                    expr = _expr.var(name, type_annotation=checked_type)
+                return set_span(expr, val.span) if val.span else expr
             return _expr.var(name)
 
-        loop_iter_var = _expr.var(block_input_names[0], shape=(), 
dtype=loop_iter_dtype)
-        loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+        source_name = self.source_map[loop_node]
+        loop_iter_var = set_span(
+            _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype), 
span=source_name
+        )
+        loop_vars = set_span(
+            [get_var(name, val) for name, val in name_val_pairs[1:]], 
span=source_name
+        )
 
         # Add non constant free variables to loop variables to prevent code 
blow up
         # Without this, if there are two for loops in a row, which often 
happens
@@ -4108,7 +4133,7 @@ class PyTorchOpConverter:
             prev_output = outputs[name]
             new_loop_var = get_var(name, prev_output)
             prev_outputs[name] = prev_output
-            outputs[name] = new_loop_var
+            outputs[name] = set_span(new_loop_var, source_name)
             loop_vars.append(new_loop_var)
             init_vals.append(prev_output)
 
@@ -4156,11 +4181,17 @@ class PyTorchOpConverter:
         for node_name, op_node in operators:
             operator = op_node.kind()
             inputs = _get_op_inputs(op_node, outputs)
+            # we need to record what current operator is to provide correct 
source name
+            # for operators needed to be taken care with (e.g. nms / arange 
...)
+            self.current_op.append(op_node)
 
             if operator == "prim::Constant":
                 outputs[node_name] = _get_constant(op_node)
             elif operator == "prim::ListConstruct" and 
_should_construct_dynamic_list(op_node):
-                outputs[node_name] = self.convert_to_list_adt(inputs)
+                outputs[node_name] = set_span(
+                    self.convert_to_list_adt(inputs),
+                    self.source_map[op_node],
+                )
             elif operator == "prim::ListConstruct":
                 # This assumes that no more elements will be appended to this 
list
                 # In this case, we keep the Python list
@@ -4177,25 +4208,30 @@ class PyTorchOpConverter:
                             inputs_list.append(inputs[i])
                     return _expr.Tuple(inputs_list)
 
-                outputs[node_name] = _handel_nested_input(inputs)
+                outputs[node_name] = set_span(
+                    _handel_nested_input(inputs),
+                    self.source_map[op_node],
+                )
             elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
                 assert len(inputs) == 1
                 if isinstance(inputs[0], (list, _expr.TupleWrapper)):
                     unpacked = inputs[0]
                 else:
                     unpacked = _unpack_tuple(inputs[0])
-                outputs.update(zip(_get_output_names(op_node), unpacked))
+                outputs.update(
+                    zip(_get_output_names(op_node), set_span(unpacked, 
self.source_map[op_node]))
+                )
             elif operator == "prim::prim::RaiseException":
                 logger.warning("raising exceptions is ignored")
                 outputs[node_name] = None
             elif operator == "prim::If":
                 if_out = self.convert_if(op_node, outputs)
-                outputs[node_name] = if_out
+                outputs[node_name] = set_span(if_out, self.source_map[op_node])
             elif operator == "prim::Loop":
                 loop_out = self.convert_loop(op_node, outputs)
                 unpacked_names = _get_output_names(op_node)
                 assert len(loop_out) == len(unpacked_names)
-                outputs.update(zip(unpacked_names, loop_out))
+                outputs.update(zip(unpacked_names, set_span(loop_out, 
self.source_map[op_node])))
             else:
                 if operator not in self.convert_map:
                     # At this point, the only possible ops that are not in 
convert_map are
@@ -4210,9 +4246,14 @@ class PyTorchOpConverter:
                 else:
                     relay_op = self.convert_map[operator]
 
+                self._set_parameter_source_name(op_node, outputs)
                 relay_out = relay_op(
-                    inputs, _get_input_types(op_node, outputs, 
default_dtype=self.default_dtype)
+                    # since the elements in "outputs" may change due to 
span-filling process
+                    # we have to call "_get_op_inputs" again rather than use 
"inputs" directly
+                    _get_op_inputs(op_node, outputs),
+                    _get_input_types(op_node, outputs, 
default_dtype=self.default_dtype),
                 )
+                relay_out = set_span(relay_out, self.source_map[op_node])
                 self.record_output_type(relay_out)
 
                 if isinstance(relay_out, tuple):
@@ -4224,8 +4265,28 @@ class PyTorchOpConverter:
                     assert op_node.outputsSize() == 1
                     outputs[node_name] = relay_out
 
+            self.current_op.pop()
+
         return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
 
+    def _set_parameter_source_name(self, op_node, outputs):
+        """A helper function to rewrite source_name of parameter."""
+        for name in _get_input_names(op_node):
+            expr = outputs[name]
+            if isinstance(expr, (_expr.Var, _expr.Constant)):
+                name_sep = "_" if self.use_parser_friendly_name else "."
+                source_name = [self.source_map[op_node]]
+                if isinstance(expr, _expr.Var):
+                    # variable name should have contained node source name
+                    # for op with attributes in convert_params stage
+                    # e.g. "aten::batch_norm_5.running_mean"
+                    if expr.name_hint.startswith(source_name[0]):
+                        source_name[0] = expr.name_hint
+                    else:
+                        source_name.append(expr.name_hint)
+                new_expr = set_span(expr, name_sep.join(source_name))
+                outputs[name] = new_expr
+
 
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
@@ -4493,13 +4554,67 @@ def _get_constant(node):
         return None
 
 
-def _get_operator_nodes(nodes):
+def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
+    """Rewrite debug name of node outputs with its operator type"""
+
+    def _get_source_name(op_type):
+        op_idx = 0
+        if op_type in op_type_dict:
+            op_idx = op_type_dict[op_type] + 1
+        op_type_dict[op_type] = op_idx
+        return "_".join([op_type, str(op_idx)])
+
+    # get source name of operator and rename all of its outputs
+    # e.g. node.kind(): aten::adaptive_max_pool2d
+    # node_src_name -> aten::adaptive_max_pool2d_x
+    # output_1 -> aten::adaptive_max_pool2d_x_0
+    # output_2 -> aten::adaptive_max_pool2d_x_1
+    if node.kind() != "prim::GetAttr":
+        node_src_name = _get_source_name(node.kind())
+        for index, output in enumerate(node.outputs()):
+            output.setDebugName("_".join([node_src_name, str(index)]))
+        # update source map
+        # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> 
prim__Constant_0
+        if use_parser_friendly_name:
+            node_src_name = re.sub(r":|\.", "_", node_src_name)
+        source_map[node] = node_src_name
+
+
+def _debug_rename(graph, use_parser_friendly_name):
+    """Returns map between node and source name"""
+    source_map, op_type_dict = {}, {}
+    prim_with_blocks = ["prim::If", "prim::Loop"]
+
+    def _traverse_graph(nodes):
+        for node in nodes:
+            if node.outputsSize() == 0:
+                continue
+            if node.kind() in prim_with_blocks:
+                for block in node.blocks():
+                    _traverse_graph(block.nodes())
+            _rename_outputs(node, source_map, op_type_dict, 
use_parser_friendly_name)
+
+    _traverse_graph(graph.nodes())
+    return source_map
+
+
+def _get_operator_nodes(
+    nodes,
+    source_map=None,
+    op_type_dict=None,
+    use_parser_friendly_name=False,
+):
     """Returns torch IR nodes that need conversion to Relay"""
-    ops = []
+    ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None
+
     # Traverse nodes and add to graph
     for node in nodes:
         if node.outputsSize() == 0:
             continue
+
+        if should_rename_graph:
+            _rename_outputs(node, source_map, op_type_dict, 
use_parser_friendly_name)
+
         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
         else:
@@ -4670,7 +4785,7 @@ def get_attr_chains(root_getattr_node):
     return get_use_chains(root_getattr_node, terminate)
 
 
-def convert_params(graph, state_dict, use_parser_friendly_name=False):
+def convert_params(graph, state_dict, source_map, 
use_parser_friendly_name=False):
     """
     Return Relay vars and TVM NDArrays for input parameters
     A chain of prim::GetAttr nodes is processed one at a time
@@ -4679,6 +4794,7 @@ def convert_params(graph, state_dict, 
use_parser_friendly_name=False):
     params = {}
     param_tensors = {}
     packed_param_map = {}
+    param_debug_name_map = {}
     vars_by_name = {}
     seen = set()
     attr_name_sep = "_" if use_parser_friendly_name else "."
@@ -4692,20 +4808,30 @@ def convert_params(graph, state_dict, 
use_parser_friendly_name=False):
 
             full_attr = _getattr_full_name(getattrs, attr_name_sep)
             full_attr_node_name = _get_output_name(getattrs[-1])
+            # set variable name by concatenating first consumer's name with 
full attribute
+            # e.g. "aten::batch_norm_5.running_mean"
+            var_name = attr_name_sep.join(
+                [
+                    source_map[_get_users(getattrs[-1])[0]],
+                    full_attr.split(attr_name_sep)[-1],
+                ]
+            )
 
             if full_attr.endswith("_packed_params"):  # for quantized models
                 packed_param_map[full_attr_node_name] = full_attr
             elif full_attr in state_dict:
-                if full_attr in vars_by_name:
-                    var = vars_by_name[full_attr]
+                if var_name in vars_by_name:
+                    var = vars_by_name[var_name]
                 else:
                     torch_tensor = state_dict[full_attr]
-                    tensor, var = _get_tensor_and_var(torch_tensor, full_attr)
-                    param_tensors[full_attr] = tensor
-                    vars_by_name[full_attr] = var
+                    tensor, var = _get_tensor_and_var(torch_tensor, var_name)
+                    param_tensors[var_name] = tensor
+                    # for quantized parameters to be correctly located
+                    param_debug_name_map[full_attr_node_name] = var_name
+                    vars_by_name[var_name] = var
                 params[full_attr_node_name] = var
 
-    return params, param_tensors, packed_param_map
+    return params, param_tensors, packed_param_map, param_debug_name_map
 
 
 def get_all_op_names(graph):
@@ -4720,6 +4846,19 @@ def get_all_op_names(graph):
     return set(node.kind() for node in nodes)
 
 
+def export_c_graph(location, graph):
+    """Convert the graph to an onnx model and export it to the location."""
+    import datetime
+    import os
+
+    if not os.path.exists(location):
+        os.makedirs(location)
+    time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
+    fname = os.path.join(location, 
"tvm_exported_c_graph_{}.txt".format(time_stamp))
+    with open(f"{fname}", "w") as f:
+        f.write(str(graph))
+
+
 def from_pytorch(
     script_module,
     input_infos,
@@ -4727,6 +4866,7 @@ def from_pytorch(
     default_dtype="float32",
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
+    export_renamed_c_graph_path=None,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert 
into relay.
     The companion parameters will be handled automatically.
@@ -4769,6 +4909,11 @@ def from_pytorch(
         we quantize weights in the frontend using a function that is 
equivalent to
         qnn.op.quantize(...) operating on Numpy arrays.
 
+    export_renamed_c_graph_path : str, optional
+        Export the renamed torch._C.Graph to the path.
+        During the conversion, variable names in torch._C.Graph will be 
assigned based on their op
+        types. The exported text file can be the reference to spans.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -4783,7 +4928,7 @@ def from_pytorch(
     prelude = Prelude(mod)
     enable_lower_all_tuples = True
 
-    converter = PyTorchOpConverter(prelude, default_dtype)
+    converter = PyTorchOpConverter(prelude, default_dtype, 
use_parser_friendly_name)
 
     graph = script_module.graph.copy()
 
@@ -4812,12 +4957,16 @@ def from_pytorch(
         new_names = [key.replace(".", "_") for key in params.keys()]
         params = dict(zip(new_names, params.values()))
 
-    param_vars, tensors, packed_param_map = convert_params(graph, params, 
use_parser_friendly_name)
+    # rename _C.Graph here for constructing meaningful source name of graph 
nodes
+    # by doing so, we could Use source_map as the reference to rename model 
parameters
+    source_map = _debug_rename(graph, use_parser_friendly_name)
+    param_vars, tensors, packed_param_map, param_debug_name_map = 
convert_params(
+        graph, params, source_map, use_parser_friendly_name
+    )
 
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
     outputs.update(param_vars)
-    ret_name = _get_input_names(graph.return_node())
 
     # For quantized models
     quantized_ops = set(["aten::quantize_per_tensor", 
"quantized::linear_dynamic"])
@@ -4825,7 +4974,7 @@ def from_pytorch(
         weight_quant_params = qnn_torch.get_weight_quant_params(
             script_module, packed_param_map.values()
         )
-        qnn_torch.inline_input_quant_params_for_fx(graph, tensors)
+        qnn_torch.inline_input_quant_params_for_fx(graph, tensors, 
param_debug_name_map)
         input_scales_for_bias = 
qnn_torch.add_input_quant_params_to_op_inputs(graph)
         qnn_torch.add_quant_params_to_outputs(
             outputs,
@@ -4837,7 +4986,14 @@ def from_pytorch(
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         converter.update_convert_map(qnn_torch.convert_map)
 
-    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), 
outputs, ret_name)
+    operator_nodes = _get_operator_nodes(
+        graph.nodes(),
+        converter.source_map,
+        converter.op_type_dict,
+        use_parser_friendly_name,
+    )
+    ret_name = _get_input_names(graph.return_node())
+    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
 
     # ListConstruct kept original python list. Convert to tuple.
     outputs = [_expr.Tuple(output) if isinstance(output, list) else output for 
output in outputs]
@@ -4859,4 +5015,7 @@ def from_pytorch(
 
     mod["main"] = tvm.relay.Function(func_args, ret)
 
+    if export_renamed_c_graph_path:
+        export_c_graph(export_renamed_c_graph_path, graph)
+
     return transform.RemoveUnusedFunctions()(mod), tvm_params
diff --git a/python/tvm/relay/frontend/qnn_torch.py 
b/python/tvm/relay/frontend/qnn_torch.py
index a4eb56c104..131a471fd5 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -534,7 +534,7 @@ def add_quant_params(params, quant_params):
             params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
 
 
-def inline_input_quant_params_for_fx(graph, params):
+def inline_input_quant_params_for_fx(graph, params, param_debug_name_map):
     """
     Canonicalize input scale and zero point access for FX-quantized graphs.
     We expect input qparams to aten::quantize_per_tensor to be prim::Constant, 
but that's
@@ -568,7 +568,7 @@ def inline_input_quant_params_for_fx(graph, params):
         out_name = node.output().debugName()
 
         if "_scale" in out_name or "_zero_point" in out_name:
-            full_attr = get_full_attr_name(node)
+            full_attr = param_debug_name_map[get_full_attr_name(node)]
             assert full_attr in params, "%s not found in param dict." % 
full_attr
             param_np = params[full_attr].numpy()
             new_const_node = graph.create("prim::Constant")
diff --git a/tests/python/frontend/pytorch/qnn_test.py 
b/tests/python/frontend/pytorch/qnn_test.py
index e9fbe12e97..beaeeb9999 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -45,9 +45,15 @@ def torch_version_check():
 
 def get_tvm_runtime(script_module, input_name, ishape, 
keep_quantized_weight=False, target="llvm"):
     input_shapes = [(input_name, ishape)]
-    mod, params = relay.frontend.from_pytorch(
-        script_module, input_shapes, 
keep_quantized_weight=keep_quantized_weight
-    )
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(
+            script_module, input_shapes, 
keep_quantized_weight=keep_quantized_weight
+        )
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(
+            script_module, input_shapes, 
keep_quantized_weight=keep_quantized_weight
+        )
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
 
     if keep_quantized_weight:
         for p in params.values():
@@ -629,7 +635,11 @@ def pattern_table():
 
 def run_qnn_mergecomposite(script_module, input_name, ishape):
     input_shapes = [(input_name, ishape)]
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(script_module, 
input_shapes)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
     pattern_table = get_pattern_table("test_table")
     with tvm.transform.PassContext(opt_level=3):
         pass_list = [
@@ -778,7 +788,11 @@ def test_tuple_lowered():
     script_module = torch.jit.trace(model_int8, fp32_input).eval()
 
     input_infos = [("input", (fp32_input.shape, "float32"))]
-    mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+    with tvm.testing.disable_span_filling():
+        mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(script_module, 
input_infos)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
     output = mod["main"].body
 
     assert isinstance(output, relay.Tuple) and len(output) == 2
diff --git a/tests/python/frontend/pytorch/test_forward.py 
b/tests/python/frontend/pytorch/test_forward.py
index 807c44a364..b5fcaaecae 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -29,7 +29,8 @@ import tvm.testing
 from tvm import relay
 from tvm.contrib import graph_executor
 from tvm.contrib.nvcc import have_fp16
-from tvm.contrib import cudnn
+from tvm.contrib import cudnn, utils
+from relay.utils.tag_span import _create_span, _set_span, 
_verify_structural_equal_with_span
 
 import torch
 from torch.nn import Module
@@ -135,6 +136,7 @@ def verify_model(
     kind="graph",
     check_correctness=True,
     cpu_only=False,
+    validate_structural_equal=True,
 ):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
@@ -175,7 +177,13 @@ def verify_model(
 
     input_names = [f"input{idx}" for idx, _ in enumerate(baseline_input)]
     input_shapes = list(zip(input_names, [inp.shape for inp in 
baseline_input]))
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes, 
custom_convert_map)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(trace, input_shapes, 
custom_convert_map)
+    if validate_structural_equal:
+        with tvm.testing.enable_span_filling():
+            mod_with_span, _ = relay.frontend.from_pytorch(trace, 
input_shapes, custom_convert_map)
+        assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
+
     for arg in mod["main"].params[: len(input_names)]:
         assert arg.name_hint in input_names
     compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp 
in baseline_input]))
@@ -231,6 +239,7 @@ def verify_model_with_input(
     rtol=1e-5,
     atol=1e-5,
     assert_shape_only=False,
+    validate_structural_equal=True,
 ):
     """Generic function to generate and compare Pytorch and TVM output"""
     input_dict = input_dict or {}
@@ -239,7 +248,13 @@ def verify_model_with_input(
     trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
     input_names = [f"input{idx}" for idx, _ in enumerate(input_data)]
     input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes, 
custom_convert_map)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(trace, input_shapes, 
custom_convert_map)
+    if validate_structural_equal:
+        with tvm.testing.enable_span_filling():
+            mod_with_span, _ = relay.frontend.from_pytorch(trace, 
input_shapes, custom_convert_map)
+        assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
+
     with tvm.transform.PassContext(opt_level=3):
         for target in ["llvm", "cuda"]:
             if not tvm.runtime.enabled(target):
@@ -257,6 +272,20 @@ def verify_model_with_input(
                 tvm.testing.assert_allclose(baseline_outputs, compiled_output, 
rtol=rtol, atol=atol)
 
 
+def gen_ir_module(model, inputs, use_parser_friendly_name=False):
+    """Helper function to generate IRModule with meaningful source 
information"""
+
+    trace = torch.jit.trace(model, inputs)
+    input_names = ["input{}".format(idx) for idx, _ in enumerate(inputs)]
+    input_shapes = list(zip(input_names, [inp.shape for inp in inputs]))
+    mod, _ = relay.frontend.from_pytorch(
+        trace,
+        input_shapes,
+        use_parser_friendly_name=use_parser_friendly_name,
+    )
+    return mod
+
+
 # Single operator tests
 @tvm.testing.uses_gpu
 def test_forward_pixel_shuffle():
@@ -2596,7 +2625,11 @@ def verify_model_vm(input_model, ishapes, idtype=None, 
idata=None, targets=None)
             input_data = [torch.randn(shape, dtype=idtype) for shape in 
ishapes]
 
     # Compile via VM
-    mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(input_model, 
input_shapes)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
 
     for tgt in targets:
         if not tvm.testing.device_enabled(tgt):
@@ -3951,7 +3984,8 @@ def test_forward_dtypes():
 def test_weight_names():
     tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
     _, params = relay.frontend.from_pytorch(tm, [("input", (2, 3))])
-    assert set(params.keys()) == set(n for n, _ in tm.named_parameters())
+    keys = [key.split(".")[-1] for key in params.keys()]
+    assert set(keys) == set(n for n, p in tm.named_parameters())
 
 
 @tvm.testing.uses_gpu
@@ -4355,12 +4389,12 @@ def test_randn():
     def test_func():
         return torch.randn([1, 3, 10, 10])
 
-    verify_model_with_input(test_func, [], assert_shape_only=True)
+    verify_model_with_input(test_func, [], assert_shape_only=True, 
validate_structural_equal=False)
 
     def test_func1():
         return torch.randn(1, 3, 10, 10)
 
-    verify_model_with_input(test_func1, [], assert_shape_only=True)
+    verify_model_with_input(test_func1, [], assert_shape_only=True, 
validate_structural_equal=False)
 
 
 def test_forward_pretrained_bert_base_uncased():
@@ -5137,18 +5171,25 @@ def test_trilu():
 
 
 def test_multinomial():
+    """test_multinomial"""
+
     def _test_multinomial(num_samples):
         return lambda inp: torch.multinomial(inp, num_samples=num_samples, 
replacement=True)
 
     # Dont check output since it's random. Instead we'll just make sure shapes 
are right.
     verify_model(
-        _test_multinomial(2), [torch.rand(size=[3]).float()], cpu_only=True, 
check_correctness=False
+        _test_multinomial(2),
+        [torch.rand(size=[3]).float()],
+        cpu_only=True,
+        check_correctness=False,
+        validate_structural_equal=False,
     )
     verify_model(
         _test_multinomial(1),
         [torch.rand(size=[4, 5]).float()],
         cpu_only=True,
         check_correctness=False,
+        validate_structural_equal=False,
     )
 
 
@@ -5190,5 +5231,232 @@ def test_baddbmm():
     verify_model(test_fn(0.5, 1.0), [M, batch1, batch2])
 
 
+def test_exporting_renamed_c_graph():
+    """test exporting model when export_renamed_c_graph_path is set"""
+
+    # model definition
+    class Conv2D(Module):
+        def __init__(self):
+            super(Conv2D, self).__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+
+        def forward(self, *args):
+            return self.conv(args[0])
+
+    input_name, input_shape = "input", [1, 3, 10, 10]
+    shape_list = [(input_name, input_shape)]
+    temp_dir = utils.tempdir().path
+    script_module = torch.jit.trace(Conv2D(), [torch.rand(input_shape)])
+    _, _ = relay.frontend.from_pytorch(
+        script_module, shape_list, export_renamed_c_graph_path=temp_dir
+    )
+
+    exported_c_graph_name = os.listdir(temp_dir)[0]
+    assert "tvm_exported_c_graph_" in exported_c_graph_name
+
+    # make sure the renamed output variable is present in the restored _C.Graph
+    with open(f"{temp_dir}/{exported_c_graph_name}", "r") as f:
+        graph = f.read()
+        assert "%aten::_convolution_0" in graph
+
+
+class TestSetSpan:
+    """test structural equality between translated and hand-crafted relay IR with 
spans tagged."""
+
+    def _verify(self, res_fptr, golden_fptr):
+        with tvm.testing.enable_span_filling():
+            with_span = res_fptr()
+        with tvm.testing.disable_span_filling():
+            without_span = res_fptr()
+        assert tvm.ir.structural_equal(with_span, without_span)
+        _verify_structural_equal_with_span(with_span, golden_fptr())
+
+    def test_conv2d_bias_add(self):
+        ker_sz, in_chs, out_chs = 7, 3, 6
+        input_shape = [1, 3, 10, 10]
+
+        def _res():
+            # model definition
+            class Conv2D(Module):
+                def __init__(self):
+                    super(Conv2D, self).__init__()
+                    self.conv = torch.nn.Conv2d(in_chs, out_chs, ker_sz, 
bias=True)
+
+                def forward(self, *args):
+                    return self.conv(args[0])
+
+            # get frontend model
+            mod = gen_ir_module(Conv2D(), [torch.rand(input_shape)])
+            return mod["main"]
+
+        def _golden():
+            conv_si = "aten::_convolution_0"
+            input_name = "input0"
+            input_0 = relay.var(
+                input_name,
+                shape=tuple(input_shape),
+                span=_create_span(f"{conv_si}.{input_name}"),
+            )
+            weight_name = f"{conv_si}.weight"
+            conv_weight = relay.var(
+                weight_name,
+                shape=(out_chs, in_chs, ker_sz, ker_sz),
+                span=_create_span(weight_name),
+            )
+            bias_name = f"{conv_si}.bias"
+            conv_bias = relay.var(
+                bias_name,
+                shape=(out_chs,),
+                span=_create_span(bias_name),
+            )
+            conv_out = _set_span(
+                relay.nn.conv2d(
+                    input_0,
+                    conv_weight,
+                    padding=[0] * 4,
+                    channels=out_chs,
+                    kernel_size=[ker_sz] * 2,
+                ),
+                conv_si,
+            )
+            bias_out = _set_span(relay.nn.bias_add(conv_out, conv_bias), 
conv_si)
+            return relay.Function([input_0, conv_weight, conv_bias], bias_out)
+
+        self._verify(_res, _golden)
+
+    def test_batchnorm_span(self):
+        features = 16
+        input_shape = [1, 16, 10, 10]
+
+        def _res():
+            # model definition
+            bn_2d = torch.nn.BatchNorm2d(features)
+
+            # get frontend model
+            mod = gen_ir_module(bn_2d, [torch.rand(input_shape)])
+            return mod["main"]
+
+        def _golden():
+            bn_si = "aten::batch_norm_0"
+            input_name = "input0"
+            input_0 = relay.var(
+                input_name,
+                shape=tuple(input_shape),
+                span=_create_span(f"{bn_si}.{input_name}"),
+            )
+            weight_name = f"{bn_si}.weight"
+            bn_weight = relay.var(
+                weight_name,
+                shape=(features,),
+                span=_create_span(weight_name),
+            )
+            bias_name = f"{bn_si}.bias"
+            bn_bias = relay.var(
+                bias_name,
+                shape=(features,),
+                span=_create_span(bias_name),
+            )
+            rm_name = f"{bn_si}.running_mean"
+            bn_rm = relay.var(
+                rm_name,
+                shape=(features,),
+                span=_create_span(rm_name),
+            )
+            rv_name = f"{bn_si}.running_var"
+            bn_rv = relay.var(
+                rv_name,
+                shape=(features,),
+                span=_create_span(rv_name),
+            )
+            bn_out = _set_span(
+                relay.nn.batch_norm(input_0, bn_weight, bn_bias, bn_rm, bn_rv),
+                bn_si,
+            )
+            bn_tuple_get_item = 
_set_span(relay.TupleGetItem(bn_out.tuple_value, 0), bn_si)
+            return relay.Function([input_0, bn_weight, bn_bias, bn_rm, bn_rv], 
bn_tuple_get_item)
+
+        self._verify(_res, _golden)
+
+    def test_reshape_span(self):
+        input_shape = [2, 1, 10, 1, 10]
+        new_shape = [2, 1, 10, 10]
+
+        def _res():
+            # model definition
+            class Reshape(Module):
+                def forward(self, *args):
+                    return args[0].reshape(new_shape)
+
+            # get frontend model
+            mod = gen_ir_module(Reshape(), [torch.rand(input_shape)])
+            return mod["main"]
+
+        def _golden():
+            reshape_si = "aten::reshape_0"
+            input_name = "input0"
+            input_0 = relay.var(
+                input_name,
+                shape=tuple(input_shape),
+                span=_create_span(f"{reshape_si}.{input_name}"),
+            )
+            reshape_out = _set_span(
+                relay.reshape(input_0, newshape=new_shape),
+                reshape_si,
+            )
+            return relay.Function([input_0], reshape_out)
+
+        self._verify(_res, _golden)
+
+    def test_dense_bias_add(self):
+        in_f, out_f = 10, 7
+        input_shape = [in_f, in_f]
+
+        def _res():
+            # model definition
+            class Dense(Module):
+                def __init__(self):
+                    super(Dense, self).__init__()
+                    self.linear = torch.nn.Linear(in_f, out_f, bias=True)
+
+                def forward(self, *args):
+                    return self.linear(args[0])
+
+            # get frontend model
+            mod = gen_ir_module(Dense(), [torch.rand(input_shape)])
+            return mod["main"]
+
+        def _golden():
+            dense_si = "aten::linear_0"
+            input_name = "input0"
+            input_0 = relay.var(
+                input_name,
+                shape=tuple(input_shape),
+                span=_create_span(f"{dense_si}.{input_name}"),
+            )
+            weight_name = f"{dense_si}.weight"
+            dense_weight = relay.var(
+                weight_name,
+                shape=(out_f, in_f),
+                span=_create_span(weight_name),
+            )
+            bias_name = f"{dense_si}.bias"
+            dense_bias = relay.var(
+                bias_name,
+                shape=(out_f,),
+                span=_create_span(bias_name),
+            )
+            dense_out = _set_span(
+                relay.nn.dense(input_0, dense_weight),
+                dense_si,
+            )
+            bias_out = _set_span(
+                relay.nn.bias_add(dense_out, dense_bias, axis=-1),
+                dense_si,
+            )
+            return relay.Function([input_0, dense_weight, dense_bias], 
bias_out)
+
+        self._verify(_res, _golden)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/frontend/pytorch/test_fx_quant.py 
b/tests/python/frontend/pytorch/test_fx_quant.py
index f35094a831..564900cbf2 100644
--- a/tests/python/frontend/pytorch/test_fx_quant.py
+++ b/tests/python/frontend/pytorch/test_fx_quant.py
@@ -23,6 +23,7 @@ from torch.quantization.quantize_fx import prepare_fx, 
convert_fx
 from torchvision.models.efficientnet import efficientnet_b4
 from torchvision.models.resnet import resnet50
 from tvm import relay
+import tvm.testing
 
 
 def quantize(model):
@@ -38,7 +39,11 @@ def quantize_and_build(model, in_size):
 
     with torch.no_grad():
         script_module = torch.jit.trace(qmodel, inp)
-        mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, 
inp.shape)])
+        with tvm.testing.disable_span_filling():
+            mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, 
inp.shape)])
+        with tvm.testing.enable_span_filling():
+            mod_with_span, _ = relay.frontend.from_pytorch(script_module, 
[(input_name, inp.shape)])
+        assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
         mod = relay.transform.InferType()(mod)
 
         # Make sure that the model is quantized
diff --git a/tests/python/frontend/pytorch/test_lstm.py 
b/tests/python/frontend/pytorch/test_lstm.py
index 25d4563ee6..e9dd2b380c 100644
--- a/tests/python/frontend/pytorch/test_lstm.py
+++ b/tests/python/frontend/pytorch/test_lstm.py
@@ -337,7 +337,11 @@ def test_custom_lstm():
 
     for (name, raw_model, states, input_shapes) in models:
         script_module = torch.jit.script(raw_model)
-        mod, params = from_pytorch(script_module, input_shapes)
+        with tvm.testing.disable_span_filling():
+            mod, params = from_pytorch(script_module, input_shapes)
+        with tvm.testing.enable_span_filling():
+            mod_with_span, _ = from_pytorch(script_module, input_shapes)
+        assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
 
         with torch.no_grad():
             pt_result = raw_model(inp.clone(), states)
diff --git a/tests/python/frontend/pytorch/test_object_detection.py 
b/tests/python/frontend/pytorch/test_object_detection.py
index 83b13f686b..25e784b00a 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -104,7 +104,11 @@ def test_detection_models():
     shape_list = [(input_name, input_shape)]
 
     scripted_model = generate_jit_model(1)
-    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+    with tvm.testing.disable_span_filling():
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+    with tvm.testing.enable_span_filling():
+        mod_with_span, _ = relay.frontend.from_pytorch(scripted_model, 
shape_list)
+    assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
 
     data = process_image(img)
     data_np = data.detach().numpy()
diff --git a/tests/python/frontend/pytorch/test_rnns.py 
b/tests/python/frontend/pytorch/test_rnns.py
index fba55b9c4c..3ea4232500 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -456,7 +456,15 @@ def check_rnn(rnn_type, rnn_mod, 
target=tvm.target.Target("llvm -mcpu=core-avx2"
                         traced_script_module = torch.jit.trace(model, 
dummy_inputs[0]).eval()
 
                         # Import model to Relay
-                        mod, params = 
relay.frontend.from_pytorch(traced_script_module, shape_desc)
+                        with tvm.testing.disable_span_filling():
+                            mod, params = relay.frontend.from_pytorch(
+                                traced_script_module, shape_desc
+                            )
+                        with tvm.testing.enable_span_filling():
+                            mod_with_span, _ = relay.frontend.from_pytorch(
+                                traced_script_module, shape_desc
+                            )
+                        assert tvm.ir.structural_equal(mod, mod_with_span, 
map_free_vars=True)
                     elif format == "onnx":
                         try:
                             onnx_model = get_onnx_model(model)
@@ -468,7 +476,11 @@ def check_rnn(rnn_type, rnn_mod, 
target=tvm.target.Target("llvm -mcpu=core-avx2"
                             continue
 
                         # Import model to Relay
-                        mod, params = relay.frontend.from_onnx(onnx_model, 
shape_desc)
+                        with tvm.testing.disable_span_filling():
+                            mod, params = relay.frontend.from_onnx(onnx_model, 
shape_desc)
+                        with tvm.testing.enable_span_filling():
+                            mod_with_span, _ = 
relay.frontend.from_onnx(onnx_model, shape_desc)
+                        assert tvm.ir.structural_equal(mod, mod_with_span, 
map_free_vars=True)
 
                     # Model compilation by tvm
                     with tvm.transform.PassContext(opt_level=3):


Reply via email to