This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 520f2c594b [Relay][Frontend] Span filling common API (#13402)
520f2c594b is described below

commit 520f2c594b3a4dde60ebdb10de448112130ab38a
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Dec 28 02:37:51 2022 +0800

    [Relay][Frontend] Span filling common API (#13402)
    
    - Expose the span attribute of Expr-derived types from C++ to Python and add span parameters to their constructors
    - Add a common API for span filling
    - Add test cases for span filling
    - Add a PassContext config option, "relay.frontend.fill_span", to control whether spans are filled
    - Modify pretty-printing so that spans are printed
    
    Co-authored-by: Joey Tsai <[email protected]>
---
 python/tvm/relay/expr.py             | 202 +++++++++++++++++++++++++++++++----
 python/tvm/relay/frontend/common.py  | 165 +++++++++++++++++++++++++++-
 python/tvm/relay/function.py         |   7 +-
 python/tvm/relay/loops.py            |   2 +-
 python/tvm/testing/utils.py          |  22 ++++
 src/ir/span.cc                       |   4 +
 src/relay/ir/expr.cc                 |  88 ++++++++++++---
 src/relay/ir/function.cc             |   4 +-
 tests/python/frontend/test_common.py | 194 ++++++++++++++++++++++++++++++++-
 tests/python/relay/utils/tag_span.py | 108 +++++++++++++++++++
 10 files changed, 750 insertions(+), 46 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index fefc285723..88b84bbe7e 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -171,10 +171,28 @@ class Constant(ExprWithOp):
     ----------
     data : tvm.nd.NDArray
         The data content of the constant expression.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, data):
-        self.__init_handle_by_constructor__(_ffi_api.Constant, data)
+    def __init__(self, data, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)
+
+
+@tvm._ffi.register_func("relay.ConstantWithFields")
+def ConstantWithFields(
+    constant,
+    data=None,
+    virtual_device=None,
+    span=None,
+):
+    """
+    Returns constant with the given properties. A None property denotes 'no 
change'.
+    Returns constant if all properties are unchanged. Otherwise, returns a 
copy with the new
+    fields.
+    """
+    return _ffi_api.ConstantWithFields(constant, data, virtual_device, span)
 
 
 @tvm._ffi.register_object("relay.Tuple")
@@ -187,7 +205,7 @@ class Tuple(ExprWithOp):
         The fields in the tuple.
 
     span: Optional[tvm.relay.Span]
-        Span that points to original source code
+        Span that points to original source code.
     """
 
     def __init__(self, fields, span=None):
@@ -205,6 +223,16 @@ class Tuple(ExprWithOp):
         raise TypeError("astype cannot be used on tuple")
 
 
+@tvm._ffi.register_func("relay.TupleWithFields")
+def TupleWithFields(tup, fields=None, virtual_device=None, span=None):
+    """
+    Returns tuple with the given properties. A None property denotes 'no 
change'.
+    Returns tuple if all properties are unchanged. Otherwise, returns a copy 
with the new
+    fields.
+    """
+    return _ffi_api.TupleWithFields(tup, fields, virtual_device, span)
+
+
 @tvm._ffi.register_object("relay.Var")
 class Var(ExprWithOp):
     """A local variable in Relay.
@@ -221,10 +249,13 @@ class Var(ExprWithOp):
 
     type_annotation: tvm.relay.Type, optional
         The type annotation on the variable.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, name_hint, type_annotation=None):
-        self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, 
type_annotation)
+    def __init__(self, name_hint, type_annotation=None, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, 
type_annotation, span)
 
     @property
     def name_hint(self):
@@ -233,6 +264,16 @@ class Var(ExprWithOp):
         return name
 
 
+@tvm._ffi.register_func("relay.VarWithFields")
+def VarWithFields(variable, vid=None, type_annotation=None, 
virtual_device=None, span=None):
+    """
+    Returns var with the given properties. A None property denotes 'no change'.
+    Returns var if all properties are unchanged. Otherwise, returns a copy 
with the new
+    fields.
+    """
+    return _ffi_api.VarWithFields(variable, vid, type_annotation, 
virtual_device, span)
+
+
 @tvm._ffi.register_object("relay.Call")
 class Call(ExprWithOp):
     """Function call node in Relay.
@@ -256,7 +297,7 @@ class Call(ExprWithOp):
         used in advanced usecase of template functions.
 
     span: Optional[tvm.relay.Span]
-        Span that points to original source code
+        Span that points to original source code.
     """
 
     def __init__(self, op, args, attrs=None, type_args=None, span=None):
@@ -265,6 +306,18 @@ class Call(ExprWithOp):
         self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs, 
type_args, span)
 
 
+@tvm._ffi.register_func("relay.CallWithFields")
+def CallWithFields(
+    call, op=None, args=None, attrs=None, type_args=None, virtual_device=None, 
span=None
+):
+    """
+    Returns call with the given properties. A None property denotes 'no 
change'.
+    Returns call if all properties are unchanged. Otherwise, returns a copy 
with the new
+    fields.
+    """
+    return _ffi_api.CallWithFields(call, op, args, attrs, type_args, 
virtual_device, span)
+
+
 @tvm._ffi.register_object("relay.Let")
 class Let(ExprWithOp):
     """Let variable binding expression.
@@ -279,10 +332,23 @@ class Let(ExprWithOp):
 
     body: tvm.relay.Expr
         The body of the let binding.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, variable, value, body):
-        self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, 
body)
+    def __init__(self, variable, value, body, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.Let, variable, value, 
body, span)
+
+
+@tvm._ffi.register_func("relay.LetWithFields")
+def LetWithFields(let, variable=None, value=None, body=None, 
virtual_device=None, span=None):
+    """
+    Returns let with the given properties. A None property denotes 'no change'.
+    Returns let if all properties are unchanged. Otherwise, returns a copy 
with the new
+    fields.
+    """
+    return _ffi_api.LetWithFields(let, variable, value, body, virtual_device, 
span)
 
 
 @tvm._ffi.register_object("relay.If")
@@ -299,10 +365,25 @@ class If(ExprWithOp):
 
     false_branch: tvm.relay.Expr
         The expression evaluated when condition is false.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, cond, true_branch, false_branch):
-        self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, 
false_branch)
+    def __init__(self, cond, true_branch, false_branch, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch, 
false_branch, span)
+
+
+@tvm._ffi.register_func("relay.IfWithFields")
+def IfWithFields(
+    if_expr, cond=None, true_branch=None, false_branch=None, 
virtual_device=None, span=None
+):
+    """
+    Returns if with the given properties. A None property denotes 'no change'.
+    Returns if if all properties are unchanged. Otherwise, returns a copy with 
the new
+    fields.
+    """
+    return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch, 
virtual_device, span)
 
 
 @tvm._ffi.register_object("relay.TupleGetItem")
@@ -316,10 +397,25 @@ class TupleGetItem(ExprWithOp):
 
     index: int
         The index.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, tuple_value, index):
-        self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, 
tuple_value, index)
+    def __init__(self, tuple_value, index, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, 
tuple_value, index, span)
+
+
+@tvm._ffi.register_func("relay.TupleGetItemWithFields")
+def TupleGetItemWithFields(
+    tuple_get_item, tuple_value=None, index=None, virtual_device=None, 
span=None
+):
+    """
+    Returns tuple_get_item with the given properties. A None property denotes 
'no change'.
+    Returns tuple_get_item if all properties are unchanged. Otherwise, returns 
a copy with the new
+    fields.
+    """
+    return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index, 
virtual_device, span)
 
 
 @tvm._ffi.register_object("relay.RefCreate")
@@ -329,10 +425,28 @@ class RefCreate(ExprWithOp):
     ----------
     value: tvm.relay.Expr
        The initial value.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, value):
-        self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
+    def __init__(self, value, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)
+
+
+@tvm._ffi.register_func("relay.RefCreateWithFields")
+def RefCreateWithFields(
+    ref_create,
+    value=None,
+    virtual_device=None,
+    span=None,
+):
+    """
+    Returns ref_create with the given properties. A None property denotes 'no 
change'.
+    Returns ref_create if all properties are unchanged. Otherwise, returns a 
copy with the new
+    fields.
+    """
+    return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device, 
span)
 
 
 @tvm._ffi.register_object("relay.RefRead")
@@ -342,10 +456,28 @@ class RefRead(ExprWithOp):
     ----------
     ref: tvm.relay.Expr
          The reference.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, ref):
-        self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
+    def __init__(self, ref, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)
+
+
+@tvm._ffi.register_func("relay.RefReadWithFields")
+def RefReadWithFields(
+    ref_read,
+    ref=None,
+    virtual_device=None,
+    span=None,
+):
+    """
+    Returns ref_read with the given properties. A None property denotes 'no 
change'.
+    Returns ref_read if all properties are unchanged. Otherwise, returns a 
copy with the new
+    fields.
+    """
+    return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span)
 
 
 @tvm._ffi.register_object("relay.RefWrite")
@@ -357,12 +489,32 @@ class RefWrite(ExprWithOp):
     ----------
     ref: tvm.relay.Expr
         The reference.
+
     value: tvm.relay.Expr
         The new value.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, ref, value):
-        self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
+    def __init__(self, ref, value, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value, 
span)
+
+
+@tvm._ffi.register_func("relay.RefWriteWithFields")
+def RefWriteWithFields(
+    ref_write,
+    ref=None,
+    value=None,
+    virtual_device=None,
+    span=None,
+):
+    """
+    Returns ref_write with the given properties. A None property denotes 'no 
change'.
+    Returns ref_write if all properties are unchanged. Otherwise, returns a 
copy with the new
+    fields.
+    """
+    return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device, 
span)
 
 
 class TempExpr(ExprWithOp):
@@ -433,7 +585,7 @@ class TupleWrapper(object):
         raise TypeError("astype cannot be used on tuple")
 
 
-def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
+def var(name_hint, type_annotation=None, shape=None, dtype="float32", 
span=None):
     """Create a new tvm.relay.Var.
 
     This is a simple wrapper function that allows specify
@@ -456,6 +608,9 @@ def var(name_hint, type_annotation=None, shape=None, 
dtype="float32"):
     dtype: str, optional
         The data type of the tensor.
 
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
+
     Examples
     --------
     .. code-block:: python
@@ -476,10 +631,10 @@ def var(name_hint, type_annotation=None, shape=None, 
dtype="float32"):
         type_annotation = _ty.TensorType(shape, dtype)
     elif isinstance(type_annotation, str):
         type_annotation = _ty.TensorType((), type_annotation)
-    return Var(name_hint, type_annotation)
+    return Var(name_hint, type_annotation, span)
 
 
-def const(value, dtype=None):
+def const(value, dtype=None, span=None):
     """Create a constant value.
 
     Parameters
@@ -490,6 +645,9 @@ def const(value, dtype=None):
     dtype: str, optional
         The data type of the resulting constant.
 
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
+
     Note
     ----
     When dtype is None, we use the following rule:
@@ -516,7 +674,7 @@ def const(value, dtype=None):
     if not isinstance(value, _nd.NDArray):
         raise ValueError("value has to be scalar or NDArray")
 
-    return Constant(value)
+    return Constant(value, span)
 
 
 def bind(expr, binds):
diff --git a/python/tvm/relay/frontend/common.py 
b/python/tvm/relay/frontend/common.py
index 660426fb4a..5d3b0a3345 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -24,6 +24,7 @@ import tvm
 from tvm.ir import IRModule
 from tvm.topi.utils import get_const_tuple
 
+from ..expr_functor import ExprMutator
 from .. import expr as _expr
 from .. import function as _function
 from .. import transform as _transform
@@ -304,13 +305,16 @@ class ExprTable(object):
         self.const_ctr = 1
         self.in_padding = False
 
-    def new_const(self, value, shape=None, dtype="float32"):
+    def new_const(self, value, shape=None, dtype="float32", source_name=None):
+        """Construct a new var expr and add to exprs dictionary"""
         name = "_param_%d" % (self.const_ctr)
         if hasattr(value, "shape"):
             shape = value.shape
         self.const_ctr += 1
         self.params[name] = value
         self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
+        if source_name:
+            self.exprs[name] = set_span(self.exprs[name], source_name)
         return self.exprs[name]
 
     def get_expr(self, name):
@@ -1048,3 +1052,162 @@ def try_resolve_var_to_const(x, graph_params):
         return _op.const(value, dtype)
 
     return x
+
+
+class _SpanFiller(ExprMutator):
+    """SpanFiller"""
+
+    def __init__(self, span):
+        ExprMutator.__init__(self)
+        if isinstance(span, tvm.relay.Span):
+            self._span = span
+        elif isinstance(span, str):
+            self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0)
+        elif isinstance(span, bytes):
+            self._span = 
tvm.relay.Span(tvm.relay.SourceName(span.decode("utf-8")), 0, 0, 0, 0)
+        else:
+            assert False, f"unsupported span type: {type(span)}"
+
+    def visit(self, expr):
+        if hasattr(expr, "span") and expr.span:
+            return expr
+
+        return super().visit(expr)
+
+    def visit_function(self, fn):
+        new_params = [self.visit(x) for x in fn.params]
+        new_body = self.visit(fn.body)
+        return _function.FunctionWithFields(
+            fn, list(new_params), new_body, fn.ret_type, fn.type_params, 
fn.attrs, None, self._span
+        )
+
+    def visit_let(self, let):
+        new_variable = self.visit(let.var)
+        new_value = self.visit(let.value)
+        new_body = self.visit(let.body)
+        return _expr.LetWithFields(let, new_variable, new_value, new_body, 
None, self._span)
+
+    def visit_call(self, call):
+        new_args = [self.visit(arg) for arg in call.args]
+        # call.op might be RelayExpr or Op type
+        # ExprMutator will return directly if subject belongs to Op type
+        new_op = self.visit(call.op)
+        return _expr.CallWithFields(
+            call, new_op, new_args, call.attrs, call.type_args, None, 
self._span
+        )
+
+    def visit_var(self, var):
+        return _expr.VarWithFields(var, var.vid, var.type_annotation, None, 
self._span)
+
+    def visit_if(self, ite):
+        return _expr.IfWithFields(
+            ite,
+            self.visit(ite.cond),
+            self.visit(ite.true_branch),
+            self.visit(ite.false_branch),
+            None,
+            self._span,
+        )
+
+    def visit_tuple(self, tup):
+        return _expr.TupleWithFields(
+            tup, [self.visit(field) for field in tup.fields], None, self._span
+        )
+
+    def visit_tuple_getitem(self, op):
+        return _expr.TupleGetItemWithFields(
+            op, self.visit(op.tuple_value), op.index, None, self._span
+        )
+
+    def visit_constant(self, const):
+        return _expr.ConstantWithFields(const, const.data, None, self._span)
+
+    # TODO: Frontend model translation could not use following relay 
expressions so far,
+    #       enable them when new models/impls leverage these kinds of relay 
expressions.
+    def visit_ref_create(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_write(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_read(self, _):
+        raise NotImplementedError()
+
+    def visit_match(self, _):
+        raise NotImplementedError()
+
+    def fill(self, sym):
+        """Fill span to sym when it is an expr, or return it without change
+
+        Parameters
+        ----------
+        sym :
+            A symbol which is generated from the conversion of a frontend 
operator.
+
+        Returns
+        -------
+        sym:
+            A expr with span-filled or the original sym.
+        """
+        if isinstance(sym, _expr.TupleWrapper):
+            return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
+        elif isinstance(sym, _expr.RelayExpr):
+            return self.visit(sym)
+        elif isinstance(sym, list):
+            assert all(
+                isinstance(expr, _expr.RelayExpr) for expr in sym
+            ), f"unexpected relay expressions in {sym}"
+            return [self.visit(expr) for expr in sym]
+        elif isinstance(sym, tuple):
+            # some op conversion may return dummy elements
+            # e.g. op in frontend/pytorch.py: min_max_common
+            assert all(
+                isinstance(expr, (_expr.RelayExpr, type(None))) for expr in sym
+            ), f"unexpected relay expressions in {sym}"
+            return tuple(self.visit(expr) if expr else None for expr in sym)
+        elif isinstance(sym, (float, int)):
+            return sym
+        elif isinstance(sym, np.ndarray):
+            return sym
+
+        raise RuntimeError(f"unsupported type {type(sym)}")
+
+
+def set_span(sym, span):
+    """
+    Recursively tag the span to the symbol. Stop when it encounters a 
span-tagged expr. Disabled
+    when setting the "relay.frontend.fill_span" as False to the config of 
PassContext
+
+    Parameters
+    ----------
+    sym :
+        A symbol is generated from the conversion of a frontend operator. 
Raise an error when the
+        type of the symbol is not supported.
+
+    span : String, Span, or bytes
+        The source information of the corresponding symbol.
+
+    Returns
+    -------
+    result :
+        The symbol tagged with span.
+
+    Examples
+    --------
+    .. code-block:: python
+
+      x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+      w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+      y = set_span(
+          relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 
1)), "conv2d"
+      )
+      print(relay.Function([x], y))
+
+      #fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {
+      #  nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, ...) /* 
span=conv2d:0:0 */
+      #}
+    """
+
+    if 
tvm.transform.PassContext.current().config.get("relay.frontend.fill_span", 
True):
+        return _SpanFiller(span).fill(sym)
+    return sym
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index 6b3513cb5e..68d8953900 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -44,14 +44,17 @@ class Function(BaseFunc):
     type_params: Optional[List[tvm.relay.TypeParam]]
         The additional type parameters, this is only
         used in advanced usecase of template functions.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to original source code.
     """
 
-    def __init__(self, params, body, ret_type=None, type_params=None, 
attrs=None):
+    def __init__(self, params, body, ret_type=None, type_params=None, 
attrs=None, span=None):
         if type_params is None:
             type_params = convert([])
 
         self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, type_params, attrs
+            _ffi_api.Function, params, body, ret_type, type_params, attrs, span
         )
 
     def __call__(self, *args):
diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py
index 6c2ab2e23d..d46e34860f 100644
--- a/python/tvm/relay/loops.py
+++ b/python/tvm/relay/loops.py
@@ -54,7 +54,7 @@ def while_loop(cond, loop_vars, loop_bodies):
 
     for i, loop_var in enumerate(loop_vars):
         name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else 
"arg{}".format(i)
-        new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
+        new_var = _expr.var(name, type_annotation=sb.type_of(loop_var), 
span=loop_var.span)
         fresh_vars.append(new_var)
 
     with sb.if_scope(cond(*fresh_vars)):
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 74ca326bca..899b054403 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -2081,3 +2081,25 @@ class CompareBeforeAfter:
                 f"or an instance of `tvm.tir.PrimFunc`.  "
                 f"Instead, received {type(expected)}."
             )
+
+
+class _control_span_filling:
+    def __init__(self, on=True):
+        self._on = on
+        self._pass_ctx = 
tvm.transform.PassContext(config={"relay.frontend.fill_span": self._on})
+
+    def __enter__(self):
+        self._pass_ctx.__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._pass_ctx.__exit__(exc_type, exc_val, exc_tb)
+
+
+class enable_span_filling(_control_span_filling):
+    def __init__(self):
+        super().__init__()
+
+
+class disable_span_filling(_control_span_filling):
+    def __init__(self):
+        super().__init__(on=False)
diff --git a/src/ir/span.cc b/src/ir/span.cc
index e19bef4cb8..39f0044d16 100644
--- a/src/ir/span.cc
+++ b/src/ir/span.cc
@@ -20,13 +20,17 @@
  * \file span.cc
  * \brief The span data structure.
  */
+#include <tvm/ir/expr.h>
 #include <tvm/ir/span.h>
+#include <tvm/ir/transform.h>
 #include <tvm/runtime/registry.h>
 
 #include <algorithm>
 
 namespace tvm {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.frontend.fill_span", Bool);
+
 ObjectPtr<Object> GetSourceNameNode(const String& name) {
   // always return pointer as the reference can change as map re-allocate.
   // or use another level of indirection by creating a unique_ptr
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 5c85b3b29d..062d9206cf 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -72,9 +72,14 @@ Constant::Constant(runtime::NDArray data, Span span) {
 
 TVM_REGISTER_NODE_TYPE(ConstantNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray 
data) {
-  return Constant(data);
+TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray 
data, Span span) {
+  return Constant(data, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields")
+    .set_body_typed([](Constant constant, Optional<runtime::NDArray> opt_data,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(constant, opt_data, opt_virtual_device, opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -129,6 +134,11 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
 
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> 
fields, Span span) {
   return Tuple(fields, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields")
+    .set_body_typed([](Tuple tuple, Optional<Array<Expr>> opt_fields,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(tuple, opt_fields, opt_virtual_device, opt_span);
+    });
 
 Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
                  Optional<VirtualDevice> opt_virtual_device, Optional<Span> 
opt_span) {
@@ -200,9 +210,14 @@ Var WithFields(Var var, Optional<Id> opt_vid, 
Optional<Type> opt_type_annotation
 
 TVM_REGISTER_NODE_TYPE(VarNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type 
type_annotation) {
-  return Var(str, type_annotation);
+TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type 
type_annotation, Span span) {
+  return Var(str, type_annotation, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.VarWithFields")
+    .set_body_typed([](Var var, Optional<Id> opt_vid, Optional<Type> 
opt_type_annotation,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device, 
opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -278,6 +293,13 @@ TVM_REGISTER_GLOBAL("relay.ir.Call")
     .set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> 
type_args, Span span) {
       return Call(op, args, attrs, type_args, span);
     });
+TVM_REGISTER_GLOBAL("relay.ir.CallWithFields")
+    .set_body_typed([](Call call, Optional<Expr> opt_op, Optional<Array<Expr>> 
opt_args,
+                       Optional<Attrs> opt_attrs, Optional<Array<Type>> 
opt_type_args,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args, 
opt_virtual_device,
+                        opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -320,9 +342,15 @@ Let WithFields(Let let, Optional<Var> opt_var, 
Optional<Expr> opt_value, Optiona
 
 TVM_REGISTER_NODE_TYPE(LetNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, 
Expr body) {
-  return Let(var, value, body);
+TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, 
Expr body, Span span) {
+  return Let(var, value, body, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.LetWithFields")
+    .set_body_typed([](Let let, Optional<Var> opt_var, Optional<Expr> 
opt_value,
+                       Optional<Expr> opt_body, Optional<VirtualDevice> 
opt_virtual_device,
+                       Optional<Span> opt_span) {
+      return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device, 
opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -367,8 +395,15 @@ If WithFields(If if_expr, Optional<Expr> opt_cond, 
Optional<Expr> opt_true_branc
 TVM_REGISTER_NODE_TYPE(IfNode);
 
 TVM_REGISTER_GLOBAL("relay.ir.If")
-    .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) {
-      return If(cond, true_branch, false_branch);
+    .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span 
span) {
+      return If(cond, true_branch, false_branch, span);
+    });
+TVM_REGISTER_GLOBAL("relay.ir.IfWithFields")
+    .set_body_typed([](If if_expr, Optional<Expr> opt_cond, Optional<Expr> 
opt_true_branch,
+                       Optional<Expr> opt_false_branch, 
Optional<VirtualDevice> opt_virtual_device,
+                       Optional<Span> opt_span) {
+      return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch, 
opt_virtual_device,
+                        opt_span);
     });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -410,9 +445,15 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, 
Optional<Expr> opt_tuple,
 
 TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int 
index) {
-  return TupleGetItem(tuple, index);
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int 
index, Span span) {
+  return TupleGetItem(tuple, index, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields")
+    .set_body_typed([](TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
+                       Optional<Integer> opt_index, Optional<VirtualDevice> 
opt_virtual_device,
+                       Optional<Span> opt_span) {
+      return WithFields(tuple_get_item, opt_tuple, opt_index, 
opt_virtual_device, opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -448,9 +489,14 @@ RefCreate WithFields(RefCreate ref_create, Optional<Expr> 
opt_value,
 
 TVM_REGISTER_NODE_TYPE(RefCreateNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) {
-  return RefCreate(value);
+TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span 
span) {
+  return RefCreate(value, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields")
+    .set_body_typed([](RefCreate ref_create, Optional<Expr> opt_value,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(ref_create, opt_value, opt_virtual_device, opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -486,7 +532,14 @@ RefRead WithFields(RefRead ref_read, Optional<Expr> 
opt_ref,
 
 TVM_REGISTER_NODE_TYPE(RefReadNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return 
RefRead(ref); });
+TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span) 
{
+  return RefRead(ref, span);
+});
+TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields")
+    .set_body_typed([](RefRead ref_read, Optional<Expr> opt_ref,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -525,9 +578,14 @@ RefWrite WithFields(RefWrite ref_write, Optional<Expr> 
opt_ref, Optional<Expr> o
 
 TVM_REGISTER_NODE_TYPE(RefWriteNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr 
value) {
-  return RefWrite(ref, value);
+TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr 
value, Span span) {
+  return RefWrite(ref, value, span);
 });
+TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields")
+    .set_body_typed([](RefWrite ref_write, Optional<Expr> opt_ref, 
Optional<Expr> opt_value,
+                       Optional<VirtualDevice> opt_virtual_device, 
Optional<Span> opt_span) {
+      return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device, 
opt_span);
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 1a3db9974f..07cfb27b1d 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -124,8 +124,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
 
+// FFI constructor "relay.ir.Function"; the trailing Span argument is new and is
+// forwarded as the last argument of Function(...).
 TVM_REGISTER_GLOBAL("relay.ir.Function")
     .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
-                       tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
-      return Function(params, body, ret_type, ty_params, attrs);
+                       tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs, 
Span span) {
+      return Function(params, body, ret_type, ty_params, attrs, span);
     });
 TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields")
     .set_body_typed([](Function function, Optional<Array<Var>> opt_params, 
Optional<Expr> opt_body,
diff --git a/tests/python/frontend/test_common.py b/tests/python/frontend/test_common.py
index e706f2af30..2b35ae71f2 100644
--- a/tests/python/frontend/test_common.py
+++ b/tests/python/frontend/test_common.py
@@ -14,7 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from tvm.relay.frontend.common import StrAttrsDict
+
+import numpy as np
+
+from tvm import relay, testing, transform
+from tvm.relay.frontend.common import StrAttrsDict, set_span
+from relay.utils.tag_span import _set_span, _create_span, 
_verify_structural_equal_with_span
 
 
 def test_key_is_present():
@@ -27,6 +32,189 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
+class TestSetSpan:
+    """Tests for ``tvm.relay.frontend.common.set_span``.
+
+    Each test builds exprs with ``set_span`` and compares them, spans included,
+    against expected exprs whose spans are set explicitly (or left unset).
+    """
+
+    def test_pass_ctx_switch(self):
+        """set_span attaches a span only while span filling is enabled."""
+        def _res(should_fill):
+            if should_fill:
+                with testing.enable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), 
"x_var")
+            else:
+                with testing.disable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), 
"x_var")
+
+        disable = relay.var("x", shape=(1, 64, 56, 56))
+        enable = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+
+        _verify_structural_equal_with_span(_res(False), disable)
+        _verify_structural_equal_with_span(_res(True), enable)
+
+    # Should tag all exprs without span, and stop when expr is span-tagged
+    def test_builtin_tuple(self):
+        """A Python tuple: untagged elements get the span, tagged ones keep theirs."""
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            return set_span(tuple([a, b]), "tuple")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", 
span=_create_span("tuple"))
+            return tuple([a, b])
+
+        res_tuple, golden_tuple = _res(), _golden()
+        assert len(res_tuple) == len(golden_tuple)
+        for i in range(len(res_tuple)):
+            _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
+
+    def test_builtin_list(self):
+        """A Python list: the span also reaches the untagged Tuple and constant."""
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            t_a = relay.TupleGetItem(t, 0)
+            t_b = relay.TupleGetItem(t, 1)
+            return set_span([t_a, t_b], "list")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", 
span=_create_span("list"))
+            t = relay.Tuple([a, b], span=_create_span("list"))
+            t_a = relay.TupleGetItem(t, 0, span=_create_span("list"))
+            t_b = relay.TupleGetItem(t, 1, span=_create_span("list"))
+            return [t_a, t_b]
+
+        res_list, golden_list = _res(), _golden()
+        assert len(res_list) == len(golden_list)
+        for i in range(len(res_list)):
+            _verify_structural_equal_with_span(res_list[i], golden_list[i])
+
+    def test_var(self):
+        """A bare Var receives the span."""
+        x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+        x_expected = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+        _verify_structural_equal_with_span(x, x_expected)
+
+    def test_constant(self):
+        """A bare Constant receives the span."""
+        c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), 
"const_c")
+        c_expected = relay.const(
+            np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("const_c")
+        )
+        _verify_structural_equal_with_span(c, c_expected)
+
+    def test_call(self):
+        """Tagging a Call also tags its untagged weight, not the tagged input var."""
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("conv2d"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_tuple(self):
+        """Tagging a relay.Tuple also tags its untagged field, not the tagged one."""
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = set_span(relay.Tuple([a, b]), "t")
+            return relay.Function([], t)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("t"))
+            t = relay.Tuple([a, b], span=_create_span("t"))
+            return relay.Function([], t)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_tuple_getitem(self):
+        """The TupleGetItem span reaches the inner Tuple and its untagged field."""
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            i = set_span(relay.TupleGetItem(t, 0), "i")
+            return relay.Function([], i)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("i"))
+            t = relay.Tuple([a, b], span=_create_span("i"))
+            i = relay.TupleGetItem(t, 0, span=_create_span("i"))
+            return relay.Function([], i)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_let(self):
+        """The Let span reaches its value and body; a later outer set_span keeps it."""
+        def _res():
+            x = set_span(relay.Var("x"), "x_var")
+            c_1 = relay.const(np.ones(10))
+            add = relay.add(x, x)
+            body = set_span(relay.Let(x, c_1, add), "let")
+
+            c_2 = set_span(relay.const(np.zeros(10)), "zeros")
+            y = set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.Var("x", span=_create_span("x_var"))
+            c_1 = relay.const(np.ones(10), span=_create_span("let"))
+            add = _set_span(relay.add(x, x), "let")
+            body = relay.Let(x, c_1, add, span=_create_span("let"))
+
+            c_2 = relay.const(np.zeros(10), span=_create_span("zeros"))
+            y = _set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_if(self):
+        """The If span reaches the condition and the untagged false branch."""
+        def _res():
+            x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
+            y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
+            eq = relay.equal(x, y)
+
+            true_branch = set_span(relay.add(x, y), "true_branch")
+            false_branch = relay.subtract(x, y)
+            ife = set_span(relay.If(eq, true_branch, false_branch), "if")
+            return relay.Function([x, y], ife)
+
+        def _golden():
+            x = relay.var("x", shape=[], dtype="float32", 
span=_create_span("x_var"))
+            y = relay.var("y", shape=[], dtype="float32", 
span=_create_span("y_var"))
+            eq = _set_span(relay.equal(x, y), "if")
+
+            true_branch = _set_span(relay.add(x, y), "true_branch")
+            false_branch = _set_span(relay.subtract(x, y), "if")
+            ife = relay.If(eq, true_branch, false_branch, 
span=_create_span("if"))
+            return relay.Function([x, y], ife)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_fn(self):
+        """The Function span reaches the untagged call and weight in its body."""
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1))
+            f = set_span(relay.Function([x], y), "func")
+            return f
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("func"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "func"
+            )
+            f = relay.Function([x], y, span=_create_span("func"))
+            return f
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+
 if __name__ == "__main__":
-    test_key_is_present()
-    test_key_is_present()
+    # The old manual calls ran test_key_is_present twice and never ran
+    # test_key_is_not_present; tvm.testing.main() presumably hands the file to
+    # pytest so every test (including TestSetSpan) is collected.
+    testing.main()
diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py
new file mode 100644
index 0000000000..77042be602
--- /dev/null
+++ b/tests/python/relay/utils/tag_span.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay, tir
+from tvm.relay import expr as _expr
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def _set_span(expr, src):
+    """Return a copy of ``expr`` whose own span is replaced by ``_create_span(src)``.
+
+    Only the top-level node is re-tagged: child fields (``args``, ``fields``,
+    ``tuple_value`` ...) are passed through unchanged. Supports Call, Var,
+    TupleGetItem, Constant, Tuple and TupleWrapper; other types fail the assertion.
+    """
+    # The ``None`` passed just before the span is presumably the optional
+    # virtual_device argument (cf. the C++ *WithFields registrations, which take
+    # opt_virtual_device before opt_span). TODO confirm.
+    if isinstance(expr, _expr.Call):
+        return _expr.CallWithFields(
+            expr, expr.op, expr.args, expr.attrs, expr.type_args, None, 
_create_span(src)
+        )
+    elif isinstance(expr, _expr.Var):
+        return _expr.VarWithFields(expr, expr.vid, expr.type_annotation, None, 
_create_span(src))
+    elif isinstance(expr, _expr.TupleGetItem):
+        return _expr.TupleGetItemWithFields(
+            expr, expr.tuple_value, expr.index, None, _create_span(src)
+        )
+    elif isinstance(expr, _expr.Constant):
+        return _expr.ConstantWithFields(expr, expr.data, None, 
_create_span(src))
+    elif isinstance(expr, _expr.Tuple):
+        return _expr.TupleWithFields(expr, expr.fields, None, 
_create_span(src))
+    elif isinstance(expr, _expr.TupleWrapper):
+        # Re-tag the wrapped tuple, then re-wrap it with the same size.
+        return _expr.TupleWrapper(_set_span(expr.tuple_value, src), expr.size)
+
+    # NOTE: asserts are stripped under ``python -O``; the function would then
+    # silently return None for unsupported types.
+    assert False, f"unsupported type {type(expr)}"
+
+
+def _create_span(src):
+    """Build a span for tests from ``src``.
+
+    Any non-list ``src`` (normally a string) yields
+    ``relay.Span(relay.SourceName(src), 0, 0, 0, 0)``, so only the source name
+    carries information. A list yields a ``relay.SequentialSpan``: strings are
+    converted, Span items and None are kept as-is, and SequentialSpan items are
+    flattened into their ``spans``; any other item type fails the assertion.
+    """
+    if isinstance(src, list):
+        tmp_list = []
+        for s in src:
+            if isinstance(s, str):
+                tmp_list.append(_create_span(s))
+            elif isinstance(s, relay.Span):
+                tmp_list.append(s)
+            elif isinstance(s, relay.SequentialSpan):
+                # Flatten nested sequential spans into a single level.
+                tmp_list.extend(s.spans)
+            elif s is None:
+                tmp_list.append(s)
+            else:
+                assert False, f"unsupported type {type(s)}"
+        return relay.SequentialSpan(tmp_list)
+    return relay.Span(relay.SourceName(src), 0, 0, 0, 0)
+
+
+def _collect_spans(objref):
+    """Return the ``span`` of every node visited in post-order under ``objref``.
+
+    Relay exprs are walked with ``relay.analysis.post_order_visit``; TIR
+    statements and ``PrimExprWithOp`` exprs with ``tir.stmt_functor.post_order_visit``.
+    Visited nodes without a ``span`` attribute are skipped. Other root types fail
+    the assertion.
+    """
+    class Collector:
+        def __init__(self):
+            self._spans = []
+
+        def collect(self, objref):
+            if hasattr(objref, "span"):
+                self._spans.append(objref.span)
+
+        @property
+        def get_spans(self):
+            # A property despite the method-like name: read as ``c.get_spans``.
+            return self._spans
+
+    pov = None
+    if isinstance(objref, relay.Expr):
+        pov = relay.analysis.post_order_visit
+    elif isinstance(objref, (tir.Stmt, tir.expr.PrimExprWithOp)):
+        pov = tir.stmt_functor.post_order_visit
+    else:
+        assert False, f"unsupported type {type(objref)}"
+
+    c = Collector()
+    pov(objref, c.collect)
+    return c.get_spans
+
+
+def _verify_span(lhs, rhs):
+    """Assert that ``lhs`` and ``rhs`` carry structurally equal spans, node by node.
+
+    Spans are collected in post-order visit order, so both trees must yield the
+    same number of spans, which are then compared positionally.
+    """
+    lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs)
+
+    assert len(lhs_spans) == len(rhs_spans)
+
+    for i in range(len(lhs_spans)):
+        assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i])
+
+
+def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, 
map_free_vars=False):
+    """Assert that ``lhs`` and ``rhs`` are structurally equal and have matching spans.
+
+    ``assert_mode=True`` uses ``tvm.ir.assert_structural_equal`` instead of a
+    plain ``assert`` on ``tvm.ir.structural_equal``; ``map_free_vars`` is
+    forwarded to either. When both sides are ``relay.Var`` only the spans are
+    compared, and both options are ignored.
+    """
+    if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var):
+        # SEqualReduce compares the vid of Var type. Therefore we only compare 
span here.
+        _verify_span(lhs, rhs)
+        return
+
+    if assert_mode:
+        tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars)
+    else:
+        assert tvm.ir.structural_equal(lhs, rhs, map_free_vars)
+
+    _verify_span(lhs, rhs)


Reply via email to