This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e9cf04e0e4 [Relay][Frontend] Span Filling PyTorch (#14050)
e9cf04e0e4 is described below
commit e9cf04e0e4ec325ca665fe2d262b481985c8cf42
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Mar 1 20:02:00 2023 +0800
[Relay][Frontend] Span Filling PyTorch (#14050)
* [Relay][Frontend] Span Filling PyTorch
- Construct the debug name of each _C.Graph instruction and use it as the
source name of the span for PyTorch models.
- Add a function to export the converted _C.Graph after conversion, to
serve as a reference for the renamed nodes.
- Add structural_equal comparisons with and without set_span to the
existing test cases.
- Add span test cases for frequently used conversions.
- Add a span test case for exporting model parameters.
* [SpanFillingPyTorch]
- Return the TupleGetItem expr from TupleWrapper with the span of its Tuple.
- Handle None-type symbols in set_span for certain conversions.
- Add a current_op member variable to PyTorchOpConverter to track which op
is being converted in the PyTorch frontend.
* [SpanFillingPyTorch]
- Fix the error caused by quantized params not being found after renaming
the debug names of the _C.Graph.
---------
Co-authored-by: Joey Tsai <[email protected]>
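
For a quick picture of the user-facing change, here is a minimal sketch of the new
span-filling flow (the from_pytorch keyword and the tvm.testing toggles come from this
commit; the model, shapes, and output path are illustrative):

    import torch
    import tvm
    from tvm import relay

    trace = torch.jit.trace(torch.nn.Conv2d(3, 6, 3).eval(), torch.rand(1, 3, 10, 10))
    shape_list = [("input0", (1, 3, 10, 10))]

    # spans are derived from the renamed _C.Graph node names; optionally dump
    # the renamed graph as a text file to cross-reference the spans
    mod, params = relay.frontend.from_pytorch(
        trace, shape_list, export_renamed_c_graph_path="/tmp/c_graph"
    )

    # span filling can be toggled; structural equality should hold either way,
    # since spans do not participate in structural comparison
    with tvm.testing.disable_span_filling():
        mod_no_span, _ = relay.frontend.from_pytorch(trace, shape_list)
    assert tvm.ir.structural_equal(mod, mod_no_span, map_free_vars=True)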
---
python/tvm/relay/expr.py | 2 +-
python/tvm/relay/frontend/common.py | 4 +
python/tvm/relay/frontend/pytorch.py | 221 +++++++++++++---
python/tvm/relay/frontend/qnn_torch.py | 4 +-
tests/python/frontend/pytorch/qnn_test.py | 24 +-
tests/python/frontend/pytorch/test_forward.py | 284 ++++++++++++++++++++-
tests/python/frontend/pytorch/test_fx_quant.py | 7 +-
tests/python/frontend/pytorch/test_lstm.py | 6 +-
.../frontend/pytorch/test_object_detection.py | 6 +-
tests/python/frontend/pytorch/test_rnns.py | 16 +-
10 files changed, 522 insertions(+), 52 deletions(-)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index cb14552ac1..d8bca5c4a4 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -605,7 +605,7 @@ class TupleWrapper(object):
def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
- return TupleGetItem(self.tuple_value, index)
+ return TupleGetItem(self.tuple_value, index, span=self.tuple_value.span)
def __len__(self):
return self.size
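
The expr.py change above means indexing a TupleWrapper now propagates the span of the
wrapped tuple to the produced TupleGetItem. A small sketch of the effect (shapes are
illustrative):

    from tvm import relay

    x = relay.var("x", shape=(1, 16, 10, 10))
    # relay.nn.batch_norm returns a TupleWrapper over (out, moving_mean, moving_var)
    bn = relay.nn.batch_norm(
        x,
        relay.var("gamma", shape=(16,)),
        relay.var("beta", shape=(16,)),
        relay.var("mean", shape=(16,)),
        relay.var("var", shape=(16,)),
    )
    # bn[0] is now built as TupleGetItem(bn.tuple_value, 0, span=bn.tuple_value.span)
    out = bn[0]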
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 5d3b0a3345..39e17b27da 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -1169,6 +1169,10 @@ class _SpanFiller(ExprMutator):
return sym
elif isinstance(sym, np.ndarray):
return sym
+ elif not sym:
+ # some op conversion may return None
+ # e.g. op in frontend/pytorch.py: prim::device
+ return sym
raise RuntimeError(f"unsupported type {type(sym)}")
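
The common.py change lets _SpanFiller pass None through, since some op conversions
(e.g. prim::device) legitimately return nothing. Roughly, under this patch:

    from tvm.relay.frontend.common import set_span

    # converters that produce no Relay expr return None; the span filler
    # now hands it back instead of raising "unsupported type"
    assert set_span(None, "prim::device_0") is None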
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3cdfc5cb4e..89464face7 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -22,6 +22,7 @@
import functools
import itertools
import math
+import re
import sys
import numpy as np
@@ -44,6 +45,7 @@ from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind, fold_constant
+from .common import set_span
from .pytorch_utils import is_version_greater_than, getattr_attr_name
__all__ = ["from_pytorch"]
@@ -135,11 +137,15 @@ def _is_int_seq(seq):
class PyTorchOpConverter:
"""A helper class for holding PyTorch op converters."""
- def __init__(self, prelude, default_dtype):
+ def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
self.prelude = prelude
self.default_dtype = default_dtype
self.create_convert_map()
self.types = {} # map from nodes to (Relay) type annotations
+ self.source_map = {} # map from graph node to its source name
+ self.op_type_dict = {} # map from op type to its order of appearance
+ self.current_op = [] # stack recording the op currently being processed
+ self.use_parser_friendly_name = use_parser_friendly_name
# this incrementally infers the type, see the comments on the type visitor
# above.
@@ -344,7 +350,10 @@ class PyTorchOpConverter:
def _get_value(val, dtype):
# dtype is a tvm dtype
if isinstance(val, _expr.Expr):
- inp = _op.cast(val, dtype)
+ # since "arange" op will fill expr into its attribute
+ # invoke set_span here to prevent expr-rewritten occurrs in
span-filling stage
+ source_name = self.source_map[self.current_op[-1]]
+ inp = set_span(_op.cast(val, dtype), source_name)
ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret,
dtype))
else:
ret = _create_typed_const(val, dtype)
@@ -2405,11 +2414,16 @@ class PyTorchOpConverter:
iou_threshold = inputs[2]
# TVM NMS assumes score > 0
- scores = scores - _op.min(scores) + _op.const(1.0)
+ # - since there are multiple consumers of "scores" and "num_boxes",
+ # - invoke set_span here to prevent the exprs from being rewritten in the span-filling stage
+ source_name = self.source_map[self.current_op[-1]]
+ scores = set_span(scores - _op.min(scores) + _op.const(1.0), source_name)
- num_boxes = _op.shape_of(scores)
+ num_boxes = set_span(_op.shape_of(scores), source_name)
# PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
- indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32")
+ # - since the "arange" op will fill exprs into its attributes,
+ # - invoke set_span here to prevent the expr from being rewritten in the span-filling stage
+ indices = _op.transform.arange(set_span(_op.squeeze(num_boxes), source_name), dtype="int32")
indices = _op.expand_dims(indices, 0, 1)
# Generate data with shape (1, num_anchors, 5)
@@ -4008,7 +4022,12 @@ class PyTorchOpConverter:
def convert_block(self, block, outputs):
"""Translate Torch "Block", used for prim::If and prim::Loop"""
- ops = _get_operator_nodes(block.nodes())
+ ops = _get_operator_nodes(
+ block.nodes(),
+ self.source_map,
+ self.op_type_dict,
+ self.use_parser_friendly_name,
+ )
ret_names = _get_input_names(block.returnNode())
return self.convert_operators(ops, outputs, ret_names)
@@ -4079,13 +4098,19 @@ class PyTorchOpConverter:
actual_shape.append(Any())
else:
actual_shape.append(dim)
- return _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
+ expr = _expr.var(name, shape=actual_shape, dtype=checked_type.dtype)
else:
- return _expr.var(name, type_annotation=checked_type)
+ expr = _expr.var(name, type_annotation=checked_type)
+ return set_span(expr, val.span) if val.span else expr
return _expr.var(name)
- loop_iter_var = _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype)
- loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+ source_name = self.source_map[loop_node]
+ loop_iter_var = set_span(
+ _expr.var(block_input_names[0], shape=(), dtype=loop_iter_dtype), span=source_name
+ )
+ loop_vars = set_span(
+ [get_var(name, val) for name, val in name_val_pairs[1:]], span=source_name
+ )
# Add non constant free variables to loop variables to prevent code blow up
# Without this, if there are two for loops in a row, which often happens
@@ -4108,7 +4133,7 @@ class PyTorchOpConverter:
prev_output = outputs[name]
new_loop_var = get_var(name, prev_output)
prev_outputs[name] = prev_output
- outputs[name] = new_loop_var
+ outputs[name] = set_span(new_loop_var, source_name)
loop_vars.append(new_loop_var)
init_vals.append(prev_output)
@@ -4156,11 +4181,17 @@ class PyTorchOpConverter:
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
+ # record which operator is currently being converted so that ops needing
+ # special handling (e.g. nms / arange ...) can obtain the correct source name
+ self.current_op.append(op_node)
if operator == "prim::Constant":
outputs[node_name] = _get_constant(op_node)
elif operator == "prim::ListConstruct" and _should_construct_dynamic_list(op_node):
- outputs[node_name] = self.convert_to_list_adt(inputs)
+ outputs[node_name] = set_span(
+ self.convert_to_list_adt(inputs),
+ self.source_map[op_node],
+ )
elif operator == "prim::ListConstruct":
# This assumes that no more elements will be appended to this list
# In this case, we keep the Python list
@@ -4177,25 +4208,30 @@ class PyTorchOpConverter:
inputs_list.append(inputs[i])
return _expr.Tuple(inputs_list)
- outputs[node_name] = _handel_nested_input(inputs)
+ outputs[node_name] = set_span(
+ _handel_nested_input(inputs),
+ self.source_map[op_node],
+ )
elif operator in ["prim::ListUnpack", "prim::TupleUnpack"]:
assert len(inputs) == 1
if isinstance(inputs[0], (list, _expr.TupleWrapper)):
unpacked = inputs[0]
else:
unpacked = _unpack_tuple(inputs[0])
- outputs.update(zip(_get_output_names(op_node), unpacked))
+ outputs.update(
+ zip(_get_output_names(op_node), set_span(unpacked, self.source_map[op_node]))
+ )
elif operator == "prim::prim::RaiseException":
logger.warning("raising exceptions is ignored")
outputs[node_name] = None
elif operator == "prim::If":
if_out = self.convert_if(op_node, outputs)
- outputs[node_name] = if_out
+ outputs[node_name] = set_span(if_out, self.source_map[op_node])
elif operator == "prim::Loop":
loop_out = self.convert_loop(op_node, outputs)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
- outputs.update(zip(unpacked_names, loop_out))
+ outputs.update(zip(unpacked_names, set_span(loop_out, self.source_map[op_node])))
else:
if operator not in self.convert_map:
# At this point, the only possible ops that are not in convert_map are
@@ -4210,9 +4246,14 @@ class PyTorchOpConverter:
else:
relay_op = self.convert_map[operator]
+ self._set_parameter_source_name(op_node, outputs)
relay_out = relay_op(
- inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
+ # since the elements in "outputs" may change due to the span-filling process,
+ # we have to call "_get_op_inputs" again rather than use "inputs" directly
+ _get_op_inputs(op_node, outputs),
+ _get_input_types(op_node, outputs, default_dtype=self.default_dtype),
)
+ relay_out = set_span(relay_out, self.source_map[op_node])
self.record_output_type(relay_out)
if isinstance(relay_out, tuple):
@@ -4224,8 +4265,28 @@ class PyTorchOpConverter:
assert op_node.outputsSize() == 1
outputs[node_name] = relay_out
+ self.current_op.pop()
+
return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
+ def _set_parameter_source_name(self, op_node, outputs):
+ """A helper function to rewrite source_name of parameter."""
+ for name in _get_input_names(op_node):
+ expr = outputs[name]
+ if isinstance(expr, (_expr.Var, _expr.Constant)):
+ name_sep = "_" if self.use_parser_friendly_name else "."
+ source_name = [self.source_map[op_node]]
+ if isinstance(expr, _expr.Var):
+ # for ops with attributes handled in the convert_params stage,
+ # the variable name already contains the node source name,
+ # e.g. "aten::batch_norm_5.running_mean"
+ if expr.name_hint.startswith(source_name[0]):
+ source_name[0] = expr.name_hint
+ else:
+ source_name.append(expr.name_hint)
+ new_expr = set_span(expr, name_sep.join(source_name))
+ outputs[name] = new_expr
+
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
@@ -4493,13 +4554,67 @@ def _get_constant(node):
return None
-def _get_operator_nodes(nodes):
+def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
+ """Rewrite debug name of node outputs with its operator type"""
+
+ def _get_source_name(op_type):
+ op_idx = 0
+ if op_type in op_type_dict:
+ op_idx = op_type_dict[op_type] + 1
+ op_type_dict[op_type] = op_idx
+ return "_".join([op_type, str(op_idx)])
+
+ # get source name of operator and rename all of its outputs
+ # e.g. node.kind(): aten::adaptive_max_pool2d
+ # node_src_name -> aten::adaptive_max_pool2d_x
+ # output_1 -> aten::adaptive_max_pool2d_x_0
+ # output_2 -> aten::adaptive_max_pool2d_x_1
+ if node.kind() != "prim::GetAttr":
+ node_src_name = _get_source_name(node.kind())
+ for index, output in enumerate(node.outputs()):
+ output.setDebugName("_".join([node_src_name, str(index)]))
+ # update source map
+ # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
+ if use_parser_friendly_name:
+ node_src_name = re.sub(r":|\.", "_", node_src_name)
+ source_map[node] = node_src_name
+
+
+def _debug_rename(graph, use_parser_friendly_name):
+ """Returns map between node and source name"""
+ source_map, op_type_dict = {}, {}
+ prim_with_blocks = ["prim::If", "prim::Loop"]
+
+ def _traverse_graph(nodes):
+ for node in nodes:
+ if node.outputsSize() == 0:
+ continue
+ if node.kind() in prim_with_blocks:
+ for block in node.blocks():
+ _traverse_graph(block.nodes())
+ _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+
+ _traverse_graph(graph.nodes())
+ return source_map
+
+
+def _get_operator_nodes(
+ nodes,
+ source_map=None,
+ op_type_dict=None,
+ use_parser_friendly_name=False,
+):
"""Returns torch IR nodes that need conversion to Relay"""
- ops = []
+ ops, should_rename_graph = [], source_map is not None and op_type_dict is not None
+
# Traverse nodes and add to graph
for node in nodes:
if node.outputsSize() == 0:
continue
+
+ if should_rename_graph:
+ _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+
if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
else:
@@ -4670,7 +4785,7 @@ def get_attr_chains(root_getattr_node):
return get_use_chains(root_getattr_node, terminate)
-def convert_params(graph, state_dict, use_parser_friendly_name=False):
+def convert_params(graph, state_dict, source_map, use_parser_friendly_name=False):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
@@ -4679,6 +4794,7 @@ def convert_params(graph, state_dict, use_parser_friendly_name=False):
params = {}
param_tensors = {}
packed_param_map = {}
+ param_debug_name_map = {}
vars_by_name = {}
seen = set()
attr_name_sep = "_" if use_parser_friendly_name else "."
@@ -4692,20 +4808,30 @@ def convert_params(graph, state_dict, use_parser_friendly_name=False):
full_attr = _getattr_full_name(getattrs, attr_name_sep)
full_attr_node_name = _get_output_name(getattrs[-1])
+ # set variable name by concatenating first consumer's name with full attribute
+ # e.g. "aten::batch_norm_5.running_mean"
+ var_name = attr_name_sep.join(
+ [
+ source_map[_get_users(getattrs[-1])[0]],
+ full_attr.split(attr_name_sep)[-1],
+ ]
+ )
if full_attr.endswith("_packed_params"): # for quantized models
packed_param_map[full_attr_node_name] = full_attr
elif full_attr in state_dict:
- if full_attr in vars_by_name:
- var = vars_by_name[full_attr]
+ if var_name in vars_by_name:
+ var = vars_by_name[var_name]
else:
torch_tensor = state_dict[full_attr]
- tensor, var = _get_tensor_and_var(torch_tensor, full_attr)
- param_tensors[full_attr] = tensor
- vars_by_name[full_attr] = var
+ tensor, var = _get_tensor_and_var(torch_tensor, var_name)
+ param_tensors[var_name] = tensor
+ # for quantized parameters to be correctly located
+ param_debug_name_map[full_attr_node_name] = var_name
+ vars_by_name[var_name] = var
params[full_attr_node_name] = var
- return params, param_tensors, packed_param_map
+ return params, param_tensors, packed_param_map, param_debug_name_map
def get_all_op_names(graph):
@@ -4720,6 +4846,19 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
+def export_c_graph(location, graph):
+ """Convert the graph to an onnx model and export it to the location."""
+ import datetime
+ import os
+
+ if not os.path.exists(location):
+ os.makedirs(location)
+ time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
+ fname = os.path.join(location, "tvm_exported_c_graph_{}.txt".format(time_stamp))
+ with open(f"{fname}", "w") as f:
+ f.write(str(graph))
+
+
def from_pytorch(
script_module,
input_infos,
@@ -4727,6 +4866,7 @@ def from_pytorch(
default_dtype="float32",
use_parser_friendly_name=False,
keep_quantized_weight=False,
+ export_renamed_c_graph_path=None,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert
into relay.
The companion parameters will be handled automatically.
@@ -4769,6 +4909,11 @@ def from_pytorch(
we quantize weights in the frontend using a function that is equivalent to
qnn.op.quantize(...) operating on Numpy arrays.
+ export_renamed_c_graph_path : str, optional
+ Export the renamed torch._C.Graph to the path.
+ During the conversion, variable names in torch._C.Graph will be assigned based
+ on their op types. The exported text file can serve as a reference for spans.
+
Returns
-------
mod : tvm.IRModule
@@ -4783,7 +4928,7 @@ def from_pytorch(
prelude = Prelude(mod)
enable_lower_all_tuples = True
- converter = PyTorchOpConverter(prelude, default_dtype)
+ converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
graph = script_module.graph.copy()
@@ -4812,12 +4957,16 @@ def from_pytorch(
new_names = [key.replace(".", "_") for key in params.keys()]
params = dict(zip(new_names, params.values()))
- param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name)
+ # rename the _C.Graph here to construct meaningful source names for graph nodes
+ # by doing so, we can use source_map as the reference to rename model parameters
+ source_map = _debug_rename(graph, use_parser_friendly_name)
+ param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
+ graph, params, source_map, use_parser_friendly_name
+ )
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
outputs.update(param_vars)
- ret_name = _get_input_names(graph.return_node())
# For quantized models
quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"])
@@ -4825,7 +4974,7 @@ def from_pytorch(
weight_quant_params = qnn_torch.get_weight_quant_params(
script_module, packed_param_map.values()
)
- qnn_torch.inline_input_quant_params_for_fx(graph, tensors)
+ qnn_torch.inline_input_quant_params_for_fx(graph, tensors, param_debug_name_map)
input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(
outputs,
@@ -4837,7 +4986,14 @@ def from_pytorch(
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
converter.update_convert_map(qnn_torch.convert_map)
- outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
+ operator_nodes = _get_operator_nodes(
+ graph.nodes(),
+ converter.source_map,
+ converter.op_type_dict,
+ use_parser_friendly_name,
+ )
+ ret_name = _get_input_names(graph.return_node())
+ outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
# ListConstruct kept original python list. Convert to tuple.
outputs = [_expr.Tuple(output) if isinstance(output, list) else output for output in outputs]
@@ -4859,4 +5015,7 @@ def from_pytorch(
mod["main"] = tvm.relay.Function(func_args, ret)
+ if export_renamed_c_graph_path:
+ export_c_graph(export_renamed_c_graph_path, graph)
+
return transform.RemoveUnusedFunctions()(mod), tvm_params
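
The renaming scheme introduced above gives every node a source name of the form
<op_kind>_<occurrence>, with ":" and "." mapped to "_" when use_parser_friendly_name
is set. A standalone sketch of the counter, mirroring _get_source_name/_rename_outputs
in this patch:

    import re

    def source_name(op_type, op_type_dict, use_parser_friendly_name=False):
        # each op kind gets its own running index, starting at 0
        op_idx = op_type_dict.get(op_type, -1) + 1
        op_type_dict[op_type] = op_idx
        name = "_".join([op_type, str(op_idx)])
        return re.sub(r":|\.", "_", name) if use_parser_friendly_name else name

    seen = {}
    assert source_name("aten::relu", seen) == "aten::relu_0"
    assert source_name("aten::relu", seen) == "aten::relu_1"
    assert source_name("aten::relu", seen, use_parser_friendly_name=True) == "aten__relu_2"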
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index a4eb56c104..131a471fd5 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -534,7 +534,7 @@ def add_quant_params(params, quant_params):
params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
-def inline_input_quant_params_for_fx(graph, params):
+def inline_input_quant_params_for_fx(graph, params, param_debug_name_map):
"""
Canonicalize input scale and zero point access for FX-quantized graphs.
We expect input qparams to aten::quantize_per_tensor to be prim::Constant, but that's
@@ -568,7 +568,7 @@ def inline_input_quant_params_for_fx(graph, params):
out_name = node.output().debugName()
if "_scale" in out_name or "_zero_point" in out_name:
- full_attr = get_full_attr_name(node)
+ full_attr = param_debug_name_map[get_full_attr_name(node)]
assert full_attr in params, "%s not found in param dict." % full_attr
param_np = params[full_attr].numpy()
new_const_node = graph.create("prim::Constant")
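
The qnn_torch.py change relies on param_debug_name_map, built in convert_params, to
translate a prim::GetAttr output name into the renamed parameter name before the
lookup. A rough illustration (both names below are hypothetical):

    import numpy as np

    # getattr output name -> renamed variable name (hypothetical entries)
    param_debug_name_map = {"input_scale_0": "aten::quantize_per_tensor_0.input_scale_0"}
    params = {"aten::quantize_per_tensor_0.input_scale_0": np.float32(0.25)}

    full_attr = param_debug_name_map["input_scale_0"]
    assert full_attr in params, "%s not found in param dict." % full_attr
    param_np = params[full_attr]  # the real dict holds tvm.nd.NDArray values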
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index e9fbe12e97..beaeeb9999 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -45,9 +45,15 @@ def torch_version_check():
def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False, target="llvm"):
input_shapes = [(input_name, ishape)]
- mod, params = relay.frontend.from_pytorch(
- script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
- )
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(
+ script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
+ )
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(
+ script_module, input_shapes, keep_quantized_weight=keep_quantized_weight
+ )
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
if keep_quantized_weight:
for p in params.values():
@@ -629,7 +635,11 @@ def pattern_table():
def run_qnn_mergecomposite(script_module, input_name, ishape):
input_shapes = [(input_name, ishape)]
- mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_shapes)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
pattern_table = get_pattern_table("test_table")
with tvm.transform.PassContext(opt_level=3):
pass_list = [
@@ -778,7 +788,11 @@ def test_tuple_lowered():
script_module = torch.jit.trace(model_int8, fp32_input).eval()
input_infos = [("input", (fp32_input.shape, "float32"))]
- mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+ with tvm.testing.disable_span_filling():
+ mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(script_module, input_infos)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
output = mod["main"].body
assert isinstance(output, relay.Tuple) and len(output) == 2
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 807c44a364..b5fcaaecae 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -29,7 +29,8 @@ import tvm.testing
from tvm import relay
from tvm.contrib import graph_executor
from tvm.contrib.nvcc import have_fp16
-from tvm.contrib import cudnn
+from tvm.contrib import cudnn, utils
+from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
import torch
from torch.nn import Module
@@ -135,6 +136,7 @@ def verify_model(
kind="graph",
check_correctness=True,
cpu_only=False,
+ validate_structural_equal=True,
):
"""Assert that the output of a compiled model matches with that of its
baseline."""
@@ -175,7 +177,13 @@ def verify_model(
input_names = [f"input{idx}" for idx, _ in enumerate(baseline_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
- mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ if validate_structural_equal:
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
+
for arg in mod["main"].params[: len(input_names)]:
assert arg.name_hint in input_names
compiled_input = dict(zip(input_names, [inp.clone().cpu().numpy() for inp in baseline_input]))
@@ -231,6 +239,7 @@ def verify_model_with_input(
rtol=1e-5,
atol=1e-5,
assert_shape_only=False,
+ validate_structural_equal=True,
):
"""Generic function to generate and compare Pytorch and TVM output"""
input_dict = input_dict or {}
@@ -239,7 +248,13 @@ def verify_model_with_input(
trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
input_names = [f"input{idx}" for idx, _ in enumerate(input_data)]
input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
- mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ if validate_structural_equal:
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
+
with tvm.transform.PassContext(opt_level=3):
for target in ["llvm", "cuda"]:
if not tvm.runtime.enabled(target):
@@ -257,6 +272,20 @@ def verify_model_with_input(
tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=rtol, atol=atol)
+def gen_ir_module(model, inputs, use_parser_friendly_name=False):
+ """Helper function to generate IRModule with meaningful source
information"""
+
+ trace = torch.jit.trace(model, inputs)
+ input_names = ["input{}".format(idx) for idx, _ in enumerate(inputs)]
+ input_shapes = list(zip(input_names, [inp.shape for inp in inputs]))
+ mod, _ = relay.frontend.from_pytorch(
+ trace,
+ input_shapes,
+ use_parser_friendly_name=use_parser_friendly_name,
+ )
+ return mod
+
+
# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
@@ -2596,7 +2625,11 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=None)
input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes]
# Compile via VM
- mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(input_model, input_shapes)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
for tgt in targets:
if not tvm.testing.device_enabled(tgt):
@@ -3951,7 +3984,8 @@ def test_forward_dtypes():
def test_weight_names():
tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
_, params = relay.frontend.from_pytorch(tm, [("input", (2, 3))])
- assert set(params.keys()) == set(n for n, _ in tm.named_parameters())
+ keys = [key.split(".")[-1] for key in params.keys()]
+ assert set(keys) == set(n for n, p in tm.named_parameters())
@tvm.testing.uses_gpu
@@ -4355,12 +4389,12 @@ def test_randn():
def test_func():
return torch.randn([1, 3, 10, 10])
- verify_model_with_input(test_func, [], assert_shape_only=True)
+ verify_model_with_input(test_func, [], assert_shape_only=True, validate_structural_equal=False)
def test_func1():
return torch.randn(1, 3, 10, 10)
- verify_model_with_input(test_func1, [], assert_shape_only=True)
+ verify_model_with_input(test_func1, [], assert_shape_only=True, validate_structural_equal=False)
def test_forward_pretrained_bert_base_uncased():
@@ -5137,18 +5171,25 @@ def test_trilu():
def test_multinomial():
+ """test_multinomial"""
+
def _test_multinomial(num_samples):
return lambda inp: torch.multinomial(inp, num_samples=num_samples, replacement=True)
# Don't check output since it's random. Instead we'll just make sure shapes are right.
verify_model(
- _test_multinomial(2), [torch.rand(size=[3]).float()], cpu_only=True, check_correctness=False
+ _test_multinomial(2),
+ [torch.rand(size=[3]).float()],
+ cpu_only=True,
+ check_correctness=False,
+ validate_structural_equal=False,
)
verify_model(
_test_multinomial(1),
[torch.rand(size=[4, 5]).float()],
cpu_only=True,
check_correctness=False,
+ validate_structural_equal=False,
)
@@ -5190,5 +5231,232 @@ def test_baddbmm():
verify_model(test_fn(0.5, 1.0), [M, batch1, batch2])
+def test_exporting_renamed_c_graph():
+ """test exproting model when export_renamed_model is set"""
+
+ # model definition
+ class Conv2D(Module):
+ def __init__(self):
+ super(Conv2D, self).__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 3, bias=True)
+
+ def forward(self, *args):
+ return self.conv(args[0])
+
+ input_name, input_shape = "input", [1, 3, 10, 10]
+ shape_list = [(input_name, input_shape)]
+ temp_dir = utils.tempdir().path
+ script_module = torch.jit.trace(Conv2D(), [torch.rand(input_shape)])
+ _, _ = relay.frontend.from_pytorch(
+ script_module, shape_list, export_renamed_c_graph_path=temp_dir
+ )
+
+ exported_c_graph_name = os.listdir(temp_dir)[0]
+ assert "tvm_exported_c_graph_" in exported_c_graph_name
+
+ # make sure the renamed output variable is present in the exported _C.Graph
+ with open(f"{temp_dir}/{exported_c_graph_name}", "r") as f:
+ graph = f.read()
+ assert "%aten::_convolution_0" in graph
+
+
+class TestSetSpan:
+ """test structural equal between translated / hand-crafted relay IR with
span tagged."""
+
+ def _verify(self, res_fptr, golden_fptr):
+ with tvm.testing.enable_span_filling():
+ with_span = res_fptr()
+ with tvm.testing.disable_span_filling():
+ without_span = res_fptr()
+ assert tvm.ir.structural_equal(with_span, without_span)
+ _verify_structural_equal_with_span(with_span, golden_fptr())
+
+ def test_conv2d_bias_add(self):
+ ker_sz, in_chs, out_chs = 7, 3, 6
+ input_shape = [1, 3, 10, 10]
+
+ def _res():
+ # model definition
+ class Conv2D(Module):
+ def __init__(self):
+ super(Conv2D, self).__init__()
+ self.conv = torch.nn.Conv2d(in_chs, out_chs, ker_sz, bias=True)
+
+ def forward(self, *args):
+ return self.conv(args[0])
+
+ # get frontend model
+ mod = gen_ir_module(Conv2D(), [torch.rand(input_shape)])
+ return mod["main"]
+
+ def _golden():
+ conv_si = "aten::_convolution_0"
+ input_name = "input0"
+ input_0 = relay.var(
+ input_name,
+ shape=tuple(input_shape),
+ span=_create_span(f"{conv_si}.{input_name}"),
+ )
+ weight_name = f"{conv_si}.weight"
+ conv_weight = relay.var(
+ weight_name,
+ shape=(out_chs, in_chs, ker_sz, ker_sz),
+ span=_create_span(weight_name),
+ )
+ bias_name = f"{conv_si}.bias"
+ conv_bias = relay.var(
+ bias_name,
+ shape=(out_chs,),
+ span=_create_span(bias_name),
+ )
+ conv_out = _set_span(
+ relay.nn.conv2d(
+ input_0,
+ conv_weight,
+ padding=[0] * 4,
+ channels=out_chs,
+ kernel_size=[ker_sz] * 2,
+ ),
+ conv_si,
+ )
+ bias_out = _set_span(relay.nn.bias_add(conv_out, conv_bias), conv_si)
+ return relay.Function([input_0, conv_weight, conv_bias], bias_out)
+
+ self._verify(_res, _golden)
+
+ def test_batchnorm_span(self):
+ features = 16
+ input_shape = [1, 16, 10, 10]
+
+ def _res():
+ # model definition
+ bn_2d = torch.nn.BatchNorm2d(features)
+
+ # get frontend model
+ mod = gen_ir_module(bn_2d, [torch.rand(input_shape)])
+ return mod["main"]
+
+ def _golden():
+ bn_si = "aten::batch_norm_0"
+ input_name = "input0"
+ input_0 = relay.var(
+ input_name,
+ shape=tuple(input_shape),
+ span=_create_span(f"{bn_si}.{input_name}"),
+ )
+ weight_name = f"{bn_si}.weight"
+ bn_weight = relay.var(
+ weight_name,
+ shape=(features,),
+ span=_create_span(weight_name),
+ )
+ bias_name = f"{bn_si}.bias"
+ bn_bias = relay.var(
+ bias_name,
+ shape=(features,),
+ span=_create_span(bias_name),
+ )
+ rm_name = f"{bn_si}.running_mean"
+ bn_rm = relay.var(
+ rm_name,
+ shape=(features,),
+ span=_create_span(rm_name),
+ )
+ rv_name = f"{bn_si}.running_var"
+ bn_rv = relay.var(
+ rv_name,
+ shape=(features,),
+ span=_create_span(rv_name),
+ )
+ bn_out = _set_span(
+ relay.nn.batch_norm(input_0, bn_weight, bn_bias, bn_rm, bn_rv),
+ bn_si,
+ )
+ bn_tuple_get_item = _set_span(relay.TupleGetItem(bn_out.tuple_value, 0), bn_si)
+ return relay.Function([input_0, bn_weight, bn_bias, bn_rm, bn_rv], bn_tuple_get_item)
+
+ self._verify(_res, _golden)
+
+ def test_reshape_span(self):
+ input_shape = [2, 1, 10, 1, 10]
+ new_shape = [2, 1, 10, 10]
+
+ def _res():
+ # model definition
+ class Reshape(Module):
+ def forward(self, *args):
+ return args[0].reshape(new_shape)
+
+ # get frontend model
+ mod = gen_ir_module(Reshape(), [torch.rand(input_shape)])
+ return mod["main"]
+
+ def _golden():
+ reshape_si = "aten::reshape_0"
+ input_name = "input0"
+ input_0 = relay.var(
+ input_name,
+ shape=tuple(input_shape),
+ span=_create_span(f"{reshape_si}.{input_name}"),
+ )
+ reshape_out = _set_span(
+ relay.reshape(input_0, newshape=new_shape),
+ reshape_si,
+ )
+ return relay.Function([input_0], reshape_out)
+
+ self._verify(_res, _golden)
+
+ def test_dense_bias_add(self):
+ in_f, out_f = 10, 7
+ input_shape = [in_f, in_f]
+
+ def _res():
+ # model definition
+ class Dense(Module):
+ def __init__(self):
+ super(Dense, self).__init__()
+ self.linear = torch.nn.Linear(in_f, out_f, bias=True)
+
+ def forward(self, *args):
+ return self.linear(args[0])
+
+ # get frontend model
+ mod = gen_ir_module(Dense(), [torch.rand(input_shape)])
+ return mod["main"]
+
+ def _golden():
+ dense_si = "aten::linear_0"
+ input_name = "input0"
+ input_0 = relay.var(
+ input_name,
+ shape=tuple(input_shape),
+ span=_create_span(f"{dense_si}.{input_name}"),
+ )
+ weight_name = f"{dense_si}.weight"
+ dense_weight = relay.var(
+ weight_name,
+ shape=(out_f, in_f),
+ span=_create_span(weight_name),
+ )
+ bias_name = f"{dense_si}.bias"
+ dense_bias = relay.var(
+ bias_name,
+ shape=(out_f,),
+ span=_create_span(bias_name),
+ )
+ dense_out = _set_span(
+ relay.nn.dense(input_0, dense_weight),
+ dense_si,
+ )
+ bias_out = _set_span(
+ relay.nn.bias_add(dense_out, dense_bias, axis=-1),
+ dense_si,
+ )
+ return relay.Function([input_0, dense_weight, dense_bias], bias_out)
+
+ self._verify(_res, _golden)
+
+
if __name__ == "__main__":
tvm.testing.main()
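
The TestSetSpan cases above build golden IR by hand and tag it with the span helpers
used throughout the new tests. A compact sketch of that tagging pattern (the
relay.utils.tag_span import resolves under the tests/python tree, as in test_forward.py):

    from tvm import relay
    from relay.utils.tag_span import _create_span, _set_span

    si = "aten::relu_0"  # source name, following the <op_kind>_<index> scheme
    x = relay.var("input0", shape=(1, 3, 10, 10), span=_create_span(f"{si}.input0"))
    out = _set_span(relay.nn.relu(x), si)
    func = relay.Function([x], out)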
diff --git a/tests/python/frontend/pytorch/test_fx_quant.py b/tests/python/frontend/pytorch/test_fx_quant.py
index f35094a831..564900cbf2 100644
--- a/tests/python/frontend/pytorch/test_fx_quant.py
+++ b/tests/python/frontend/pytorch/test_fx_quant.py
@@ -23,6 +23,7 @@ from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torchvision.models.efficientnet import efficientnet_b4
from torchvision.models.resnet import resnet50
from tvm import relay
+import tvm.testing
def quantize(model):
@@ -38,7 +39,11 @@ def quantize_and_build(model, in_size):
with torch.no_grad():
script_module = torch.jit.trace(qmodel, inp)
- mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)])
+ with tvm.testing.disable_span_filling():
+ mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)])
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)])
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
mod = relay.transform.InferType()(mod)
# Make sure that the model is quantized
diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py
index 25d4563ee6..e9dd2b380c 100644
--- a/tests/python/frontend/pytorch/test_lstm.py
+++ b/tests/python/frontend/pytorch/test_lstm.py
@@ -337,7 +337,11 @@ def test_custom_lstm():
for (name, raw_model, states, input_shapes) in models:
script_module = torch.jit.script(raw_model)
- mod, params = from_pytorch(script_module, input_shapes)
+ with tvm.testing.disable_span_filling():
+ mod, params = from_pytorch(script_module, input_shapes)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = from_pytorch(script_module, input_shapes)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
with torch.no_grad():
pt_result = raw_model(inp.clone(), states)
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index 83b13f686b..25e784b00a 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -104,7 +104,11 @@ def test_detection_models():
shape_list = [(input_name, input_shape)]
scripted_model = generate_jit_model(1)
- mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(scripted_model, shape_list)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
data = process_image(img)
data_np = data.detach().numpy()
diff --git a/tests/python/frontend/pytorch/test_rnns.py b/tests/python/frontend/pytorch/test_rnns.py
index fba55b9c4c..3ea4232500 100644
--- a/tests/python/frontend/pytorch/test_rnns.py
+++ b/tests/python/frontend/pytorch/test_rnns.py
@@ -456,7 +456,15 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
traced_script_module = torch.jit.trace(model, dummy_inputs[0]).eval()
# Import model to Relay
- mod, params = relay.frontend.from_pytorch(traced_script_module, shape_desc)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_pytorch(
+ traced_script_module, shape_desc
+ )
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_pytorch(
+ traced_script_module, shape_desc
+ )
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
elif format == "onnx":
try:
onnx_model = get_onnx_model(model)
@@ -468,7 +476,11 @@ def check_rnn(rnn_type, rnn_mod, target=tvm.target.Target("llvm -mcpu=core-avx2"
continue
# Import model to Relay
- mod, params = relay.frontend.from_onnx(onnx_model, shape_desc)
+ with tvm.testing.disable_span_filling():
+ mod, params = relay.frontend.from_onnx(onnx_model, shape_desc)
+ with tvm.testing.enable_span_filling():
+ mod_with_span, _ = relay.frontend.from_onnx(onnx_model, shape_desc)
+ assert tvm.ir.structural_equal(mod, mod_with_span, map_free_vars=True)
# Model compilation by tvm
with tvm.transform.PassContext(opt_level=3):