This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 520f2c594b [Relay][Frontend] Span filling common API (#13402)
520f2c594b is described below
commit 520f2c594b3a4dde60ebdb10de448112130ab38a
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Dec 28 02:37:51 2022 +0800
[Relay][Frontend] Span filling common API (#13402)
- Expose and add span attribute of Expr-derived types from C++ to Python
- Add common API of span filling
- Add test cases of span filling
- Add a PassContext config option ("relay.frontend.fill_span") and testing helpers to control whether spans are filled
- Modify pretty-printing so that spans are printed
Co-authored-by: Joey Tsai <[email protected]>
---
python/tvm/relay/expr.py | 202 +++++++++++++++++++++++++++++++----
python/tvm/relay/frontend/common.py | 165 +++++++++++++++++++++++++++-
python/tvm/relay/function.py | 7 +-
python/tvm/relay/loops.py | 2 +-
python/tvm/testing/utils.py | 22 ++++
src/ir/span.cc | 4 +
src/relay/ir/expr.cc | 88 ++++++++++++---
src/relay/ir/function.cc | 4 +-
tests/python/frontend/test_common.py | 194 ++++++++++++++++++++++++++++++++-
tests/python/relay/utils/tag_span.py | 108 +++++++++++++++++++
10 files changed, 750 insertions(+), 46 deletions(-)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index fefc285723..88b84bbe7e 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -171,10 +171,28 @@ class Constant(ExprWithOp):
----------
data : tvm.nd.NDArray
The data content of the constant expression.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, data):
- self.__init_handle_by_constructor__(_ffi_api.Constant, data)
+ def __init__(self, data, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.Constant, data, span)
+
+
+@tvm._ffi.register_func("relay.ConstantWithFields")
+def ConstantWithFields(
+ constant,
+ data=None,
+ virtual_device=None,
+ span=None,
+):
+ """
+ Returns constant with the given properties. A None property denotes 'no
change'.
+ Returns constant if all properties are unchanged. Otherwise, returns a
copy with the new
+ fields.
+ """
+ return _ffi_api.ConstantWithFields(constant, data, virtual_device, span)
@tvm._ffi.register_object("relay.Tuple")
@@ -187,7 +205,7 @@ class Tuple(ExprWithOp):
The fields in the tuple.
span: Optional[tvm.relay.Span]
- Span that points to original source code
+ Span that points to original source code.
"""
def __init__(self, fields, span=None):
@@ -205,6 +223,16 @@ class Tuple(ExprWithOp):
raise TypeError("astype cannot be used on tuple")
+@tvm._ffi.register_func("relay.TupleWithFields")
+def TupleWithFields(tup, fields=None, virtual_device=None, span=None):
+ """
+ Returns tuple with the given properties. A None property denotes 'no
change'.
+ Returns tuple if all properties are unchanged. Otherwise, returns a copy
with the new
+ fields.
+ """
+ return _ffi_api.TupleWithFields(tup, fields, virtual_device, span)
+
+
@tvm._ffi.register_object("relay.Var")
class Var(ExprWithOp):
"""A local variable in Relay.
@@ -221,10 +249,13 @@ class Var(ExprWithOp):
type_annotation: tvm.relay.Type, optional
The type annotation on the variable.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, name_hint, type_annotation=None):
- self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
type_annotation)
+ def __init__(self, name_hint, type_annotation=None, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
type_annotation, span)
@property
def name_hint(self):
@@ -233,6 +264,16 @@ class Var(ExprWithOp):
return name
+@tvm._ffi.register_func("relay.VarWithFields")
+def VarWithFields(variable, vid=None, type_annotation=None,
virtual_device=None, span=None):
+ """
+ Returns var with the given properties. A None property denotes 'no change'.
+ Returns var if all properties are unchanged. Otherwise, returns a copy
with the new
+ fields.
+ """
+ return _ffi_api.VarWithFields(variable, vid, type_annotation,
virtual_device, span)
+
+
@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
@@ -256,7 +297,7 @@ class Call(ExprWithOp):
used in advanced usecase of template functions.
span: Optional[tvm.relay.Span]
- Span that points to original source code
+ Span that points to original source code.
"""
def __init__(self, op, args, attrs=None, type_args=None, span=None):
@@ -265,6 +306,18 @@ class Call(ExprWithOp):
self.__init_handle_by_constructor__(_ffi_api.Call, op, args, attrs,
type_args, span)
+@tvm._ffi.register_func("relay.CallWithFields")
+def CallWithFields(
+ call, op=None, args=None, attrs=None, type_args=None, virtual_device=None,
span=None
+):
+ """
+ Returns call with the given properties. A None property denotes 'no
change'.
+ Returns call if all properties are unchanged. Otherwise, returns a copy
with the new
+ fields.
+ """
+ return _ffi_api.CallWithFields(call, op, args, attrs, type_args,
virtual_device, span)
+
+
@tvm._ffi.register_object("relay.Let")
class Let(ExprWithOp):
"""Let variable binding expression.
@@ -279,10 +332,23 @@ class Let(ExprWithOp):
body: tvm.relay.Expr
The body of the let binding.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, variable, value, body):
- self.__init_handle_by_constructor__(_ffi_api.Let, variable, value,
body)
+ def __init__(self, variable, value, body, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.Let, variable, value,
body, span)
+
+
+@tvm._ffi.register_func("relay.LetWithFields")
+def LetWithFields(let, variable=None, value=None, body=None,
virtual_device=None, span=None):
+ """
+ Returns let with the given properties. A None property denotes 'no change'.
+ Returns let if all properties are unchanged. Otherwise, returns a copy
with the new
+ fields.
+ """
+ return _ffi_api.LetWithFields(let, variable, value, body, virtual_device,
span)
@tvm._ffi.register_object("relay.If")
@@ -299,10 +365,25 @@ class If(ExprWithOp):
false_branch: tvm.relay.Expr
The expression evaluated when condition is false.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, cond, true_branch, false_branch):
- self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch,
false_branch)
+ def __init__(self, cond, true_branch, false_branch, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.If, cond, true_branch,
false_branch, span)
+
+
+@tvm._ffi.register_func("relay.IfWithFields")
+def IfWithFields(
+ if_expr, cond=None, true_branch=None, false_branch=None,
virtual_device=None, span=None
+):
+ """
+ Returns if with the given properties. A None property denotes 'no change'.
+ Returns if if all properties are unchanged. Otherwise, returns a copy with
the new
+ fields.
+ """
+ return _ffi_api.IfWithFields(if_expr, cond, true_branch, false_branch,
virtual_device, span)
@tvm._ffi.register_object("relay.TupleGetItem")
@@ -316,10 +397,25 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, tuple_value, index):
- self.__init_handle_by_constructor__(_ffi_api.TupleGetItem,
tuple_value, index)
+ def __init__(self, tuple_value, index, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.TupleGetItem,
tuple_value, index, span)
+
+
+@tvm._ffi.register_func("relay.TupleGetItemWithFields")
+def TupleGetItemWithFields(
+ tuple_get_item, tuple_value=None, index=None, virtual_device=None,
span=None
+):
+ """
+ Returns tuple_get_item with the given properties. A None property denotes
'no change'.
+ Returns tuple_get_item if all properties are unchanged. Otherwise, returns
a copy with the new
+ fields.
+ """
+ return _ffi_api.TupleGetItemWithFields(tuple_get_item, tuple_value, index,
virtual_device, span)
@tvm._ffi.register_object("relay.RefCreate")
@@ -329,10 +425,28 @@ class RefCreate(ExprWithOp):
----------
value: tvm.relay.Expr
The initial value.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, value):
- self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
+ def __init__(self, value, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.RefCreate, value, span)
+
+
+@tvm._ffi.register_func("relay.RefCreateWithFields")
+def RefCreateWithFields(
+ ref_create,
+ value=None,
+ virtual_device=None,
+ span=None,
+):
+ """
+ Returns ref_create with the given properties. A None property denotes 'no
change'.
+ Returns ref_create if all properties are unchanged. Otherwise, returns a
copy with the new
+ fields.
+ """
+ return _ffi_api.RefCreateWithFields(ref_create, value, virtual_device,
span)
@tvm._ffi.register_object("relay.RefRead")
@@ -342,10 +456,28 @@ class RefRead(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, ref):
- self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
+ def __init__(self, ref, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.RefRead, ref, span)
+
+
+@tvm._ffi.register_func("relay.RefReadWithFields")
+def RefReadWithFields(
+ ref_read,
+ ref=None,
+ virtual_device=None,
+ span=None,
+):
+ """
+ Returns ref_read with the given properties. A None property denotes 'no
change'.
+ Returns ref_read if all properties are unchanged. Otherwise, returns a
copy with the new
+ fields.
+ """
+ return _ffi_api.RefReadWithFields(ref_read, ref, virtual_device, span)
@tvm._ffi.register_object("relay.RefWrite")
@@ -357,12 +489,32 @@ class RefWrite(ExprWithOp):
----------
ref: tvm.relay.Expr
The reference.
+
value: tvm.relay.Expr
The new value.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, ref, value):
- self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
+ def __init__(self, ref, value, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value,
span)
+
+
+@tvm._ffi.register_func("relay.RefWriteWithFields")
+def RefWriteWithFields(
+ ref_write,
+ ref=None,
+ value=None,
+ virtual_device=None,
+ span=None,
+):
+ """
+ Returns ref_write with the given properties. A None property denotes 'no
change'.
+ Returns ref_write if all properties are unchanged. Otherwise, returns a
copy with the new
+ fields.
+ """
+ return _ffi_api.RefWriteWithFields(ref_write, ref, value, virtual_device,
span)
class TempExpr(ExprWithOp):
@@ -433,7 +585,7 @@ class TupleWrapper(object):
raise TypeError("astype cannot be used on tuple")
-def var(name_hint, type_annotation=None, shape=None, dtype="float32"):
+def var(name_hint, type_annotation=None, shape=None, dtype="float32",
span=None):
"""Create a new tvm.relay.Var.
This is a simple wrapper function that allows specify
@@ -456,6 +608,9 @@ def var(name_hint, type_annotation=None, shape=None,
dtype="float32"):
dtype: str, optional
The data type of the tensor.
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
+
Examples
--------
.. code-block:: python
@@ -476,10 +631,10 @@ def var(name_hint, type_annotation=None, shape=None,
dtype="float32"):
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
- return Var(name_hint, type_annotation)
+ return Var(name_hint, type_annotation, span)
-def const(value, dtype=None):
+def const(value, dtype=None, span=None):
"""Create a constant value.
Parameters
@@ -490,6 +645,9 @@ def const(value, dtype=None):
dtype: str, optional
The data type of the resulting constant.
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
+
Note
----
When dtype is None, we use the following rule:
@@ -516,7 +674,7 @@ def const(value, dtype=None):
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
- return Constant(value)
+ return Constant(value, span)
def bind(expr, binds):
diff --git a/python/tvm/relay/frontend/common.py
b/python/tvm/relay/frontend/common.py
index 660426fb4a..5d3b0a3345 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -24,6 +24,7 @@ import tvm
from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple
+from ..expr_functor import ExprMutator
from .. import expr as _expr
from .. import function as _function
from .. import transform as _transform
@@ -304,13 +305,16 @@ class ExprTable(object):
self.const_ctr = 1
self.in_padding = False
- def new_const(self, value, shape=None, dtype="float32"):
+ def new_const(self, value, shape=None, dtype="float32", source_name=None):
+ """Construct a new var expr and add to exprs dictionary"""
name = "_param_%d" % (self.const_ctr)
if hasattr(value, "shape"):
shape = value.shape
self.const_ctr += 1
self.params[name] = value
self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
+ if source_name:
+ self.exprs[name] = set_span(self.exprs[name], source_name)
return self.exprs[name]
def get_expr(self, name):
@@ -1048,3 +1052,162 @@ def try_resolve_var_to_const(x, graph_params):
return _op.const(value, dtype)
return x
+
+
+class _SpanFiller(ExprMutator):
+ """SpanFiller"""
+
+ def __init__(self, span):
+ ExprMutator.__init__(self)
+ if isinstance(span, tvm.relay.Span):
+ self._span = span
+ elif isinstance(span, str):
+ self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0)
+ elif isinstance(span, bytes):
+ self._span =
tvm.relay.Span(tvm.relay.SourceName(span.decode("utf-8")), 0, 0, 0, 0)
+ else:
+ assert False, f"unsupported span type: {type(span)}"
+
+ def visit(self, expr):
+ if hasattr(expr, "span") and expr.span:
+ return expr
+
+ return super().visit(expr)
+
+ def visit_function(self, fn):
+ new_params = [self.visit(x) for x in fn.params]
+ new_body = self.visit(fn.body)
+ return _function.FunctionWithFields(
+ fn, list(new_params), new_body, fn.ret_type, fn.type_params,
fn.attrs, None, self._span
+ )
+
+ def visit_let(self, let):
+ new_variable = self.visit(let.var)
+ new_value = self.visit(let.value)
+ new_body = self.visit(let.body)
+ return _expr.LetWithFields(let, new_variable, new_value, new_body,
None, self._span)
+
+ def visit_call(self, call):
+ new_args = [self.visit(arg) for arg in call.args]
+ # call.op might be RelayExpr or Op type
+ # ExprMutator will return directly if subject belongs to Op type
+ new_op = self.visit(call.op)
+ return _expr.CallWithFields(
+ call, new_op, new_args, call.attrs, call.type_args, None,
self._span
+ )
+
+ def visit_var(self, var):
+ return _expr.VarWithFields(var, var.vid, var.type_annotation, None,
self._span)
+
+ def visit_if(self, ite):
+ return _expr.IfWithFields(
+ ite,
+ self.visit(ite.cond),
+ self.visit(ite.true_branch),
+ self.visit(ite.false_branch),
+ None,
+ self._span,
+ )
+
+ def visit_tuple(self, tup):
+ return _expr.TupleWithFields(
+ tup, [self.visit(field) for field in tup.fields], None, self._span
+ )
+
+ def visit_tuple_getitem(self, op):
+ return _expr.TupleGetItemWithFields(
+ op, self.visit(op.tuple_value), op.index, None, self._span
+ )
+
+ def visit_constant(self, const):
+ return _expr.ConstantWithFields(const, const.data, None, self._span)
+
+ # TODO: Frontend model translation could not use following relay
expressions so far,
+ # enable them when new models/impls leverage these kinds of relay
expressions.
+ def visit_ref_create(self, _):
+ raise NotImplementedError()
+
+ def visit_ref_write(self, _):
+ raise NotImplementedError()
+
+ def visit_ref_read(self, _):
+ raise NotImplementedError()
+
+ def visit_match(self, _):
+ raise NotImplementedError()
+
+ def fill(self, sym):
+ """Fill span to sym when it is an expr, or return it without change
+
+ Parameters
+ ----------
+ sym :
+ A symbol which is generated from the conversion of a frontend
operator.
+
+ Returns
+ -------
+ sym:
+ A expr with span-filled or the original sym.
+ """
+ if isinstance(sym, _expr.TupleWrapper):
+ return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
+ elif isinstance(sym, _expr.RelayExpr):
+ return self.visit(sym)
+ elif isinstance(sym, list):
+ assert all(
+ isinstance(expr, _expr.RelayExpr) for expr in sym
+ ), f"unexpected relay expressions in {sym}"
+ return [self.visit(expr) for expr in sym]
+ elif isinstance(sym, tuple):
+ # some op conversion may return dummy elements
+ # e.g. op in frontend/pytorch.py: min_max_common
+ assert all(
+ isinstance(expr, (_expr.RelayExpr, type(None))) for expr in sym
+ ), f"unexpected relay expressions in {sym}"
+ return tuple(self.visit(expr) if expr else None for expr in sym)
+ elif isinstance(sym, (float, int)):
+ return sym
+ elif isinstance(sym, np.ndarray):
+ return sym
+
+ raise RuntimeError(f"unsupported type {type(sym)}")
+
+
+def set_span(sym, span):
+ """
+ Recursively tag the span to the symbol. Stop when it encounters a
span-tagged expr. Disabled
+ when setting the "relay.frontend.fill_span" as False to the config of
PassContext
+
+ Parameters
+ ----------
+ sym :
+ A symbol is generated from the conversion of a frontend operator.
Raise an error when the
+ type of the symbol is not supported.
+
+ span : String, Span, or bytes
+ The source information of the corresponding symbol.
+
+ Returns
+ -------
+ result :
+ The symbol tagged with span.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+ w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+ y = set_span(
+ relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1,
1)), "conv2d"
+ )
+ print(relay.Function([x], y))
+
+ #fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {
+ # nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, ...) /*
span=conv2d:0:0 */
+ #}
+ """
+
+ if
tvm.transform.PassContext.current().config.get("relay.frontend.fill_span",
True):
+ return _SpanFiller(span).fill(sym)
+ return sym
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index 6b3513cb5e..68d8953900 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -44,14 +44,17 @@ class Function(BaseFunc):
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code.
"""
- def __init__(self, params, body, ret_type=None, type_params=None,
attrs=None):
+ def __init__(self, params, body, ret_type=None, type_params=None,
attrs=None, span=None):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
- _ffi_api.Function, params, body, ret_type, type_params, attrs
+ _ffi_api.Function, params, body, ret_type, type_params, attrs, span
)
def __call__(self, *args):
diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py
index 6c2ab2e23d..d46e34860f 100644
--- a/python/tvm/relay/loops.py
+++ b/python/tvm/relay/loops.py
@@ -54,7 +54,7 @@ def while_loop(cond, loop_vars, loop_bodies):
for i, loop_var in enumerate(loop_vars):
name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else
"arg{}".format(i)
- new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
+ new_var = _expr.var(name, type_annotation=sb.type_of(loop_var),
span=loop_var.span)
fresh_vars.append(new_var)
with sb.if_scope(cond(*fresh_vars)):
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 74ca326bca..899b054403 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -2081,3 +2081,25 @@ class CompareBeforeAfter:
f"or an instance of `tvm.tir.PrimFunc`. "
f"Instead, received {type(expected)}."
)
+
+
+class _control_span_filling:
+ def __init__(self, on=True):
+ self._on = on
+ self._pass_ctx =
tvm.transform.PassContext(config={"relay.frontend.fill_span": self._on})
+
+ def __enter__(self):
+ self._pass_ctx.__enter__()
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._pass_ctx.__exit__(exc_type, exc_val, exc_tb)
+
+
+class enable_span_filling(_control_span_filling):
+ def __init__(self):
+ super().__init__()
+
+
+class disable_span_filling(_control_span_filling):
+ def __init__(self):
+ super().__init__(on=False)
diff --git a/src/ir/span.cc b/src/ir/span.cc
index e19bef4cb8..39f0044d16 100644
--- a/src/ir/span.cc
+++ b/src/ir/span.cc
@@ -20,13 +20,17 @@
* \file span.cc
* \brief The span data structure.
*/
+#include <tvm/ir/expr.h>
#include <tvm/ir/span.h>
+#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
namespace tvm {
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.frontend.fill_span", Bool);
+
ObjectPtr<Object> GetSourceNameNode(const String& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 5c85b3b29d..062d9206cf 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -72,9 +72,14 @@ Constant::Constant(runtime::NDArray data, Span span) {
TVM_REGISTER_NODE_TYPE(ConstantNode);
-TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray
data) {
- return Constant(data);
+TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray
data, Span span) {
+ return Constant(data, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.ConstantWithFields")
+ .set_body_typed([](Constant constant, Optional<runtime::NDArray> opt_data,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(constant, opt_data, opt_virtual_device, opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -129,6 +134,11 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr>
fields, Span span) {
return Tuple(fields, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.TupleWithFields")
+ .set_body_typed([](Tuple tuple, Optional<Array<Expr>> opt_fields,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(tuple, opt_fields, opt_virtual_device, opt_span);
+ });
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
Optional<VirtualDevice> opt_virtual_device, Optional<Span>
opt_span) {
@@ -200,9 +210,14 @@ Var WithFields(Var var, Optional<Id> opt_vid,
Optional<Type> opt_type_annotation
TVM_REGISTER_NODE_TYPE(VarNode);
-TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type
type_annotation) {
- return Var(str, type_annotation);
+TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type
type_annotation, Span span) {
+ return Var(str, type_annotation, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.VarWithFields")
+ .set_body_typed([](Var var, Optional<Id> opt_vid, Optional<Type>
opt_type_annotation,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(var, opt_vid, opt_type_annotation, opt_virtual_device,
opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -278,6 +293,13 @@ TVM_REGISTER_GLOBAL("relay.ir.Call")
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type>
type_args, Span span) {
return Call(op, args, attrs, type_args, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.CallWithFields")
+ .set_body_typed([](Call call, Optional<Expr> opt_op, Optional<Array<Expr>>
opt_args,
+ Optional<Attrs> opt_attrs, Optional<Array<Type>>
opt_type_args,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(call, opt_op, opt_args, opt_attrs, opt_type_args,
opt_virtual_device,
+ opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -320,9 +342,15 @@ Let WithFields(Let let, Optional<Var> opt_var,
Optional<Expr> opt_value, Optiona
TVM_REGISTER_NODE_TYPE(LetNode);
-TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value,
Expr body) {
- return Let(var, value, body);
+TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value,
Expr body, Span span) {
+ return Let(var, value, body, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.LetWithFields")
+ .set_body_typed([](Let let, Optional<Var> opt_var, Optional<Expr>
opt_value,
+ Optional<Expr> opt_body, Optional<VirtualDevice>
opt_virtual_device,
+ Optional<Span> opt_span) {
+ return WithFields(let, opt_var, opt_value, opt_body, opt_virtual_device,
opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -367,8 +395,15 @@ If WithFields(If if_expr, Optional<Expr> opt_cond,
Optional<Expr> opt_true_branc
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay.ir.If")
- .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) {
- return If(cond, true_branch, false_branch);
+ .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch, Span
span) {
+ return If(cond, true_branch, false_branch, span);
+ });
+TVM_REGISTER_GLOBAL("relay.ir.IfWithFields")
+ .set_body_typed([](If if_expr, Optional<Expr> opt_cond, Optional<Expr>
opt_true_branch,
+ Optional<Expr> opt_false_branch,
Optional<VirtualDevice> opt_virtual_device,
+ Optional<Span> opt_span) {
+ return WithFields(if_expr, opt_cond, opt_true_branch, opt_false_branch,
opt_virtual_device,
+ opt_span);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -410,9 +445,15 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item,
Optional<Expr> opt_tuple,
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int
index) {
- return TupleGetItem(tuple, index);
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int
index, Span span) {
+ return TupleGetItem(tuple, index, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItemWithFields")
+ .set_body_typed([](TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
+ Optional<Integer> opt_index, Optional<VirtualDevice>
opt_virtual_device,
+ Optional<Span> opt_span) {
+ return WithFields(tuple_get_item, opt_tuple, opt_index,
opt_virtual_device, opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -448,9 +489,14 @@ RefCreate WithFields(RefCreate ref_create, Optional<Expr>
opt_value,
TVM_REGISTER_NODE_TYPE(RefCreateNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) {
- return RefCreate(value);
+TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value, Span
span) {
+ return RefCreate(value, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.RefCreateWithFields")
+ .set_body_typed([](RefCreate ref_create, Optional<Expr> opt_value,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(ref_create, opt_value, opt_virtual_device, opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -486,7 +532,14 @@ RefRead WithFields(RefRead ref_read, Optional<Expr>
opt_ref,
TVM_REGISTER_NODE_TYPE(RefReadNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return
RefRead(ref); });
+TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref, Span span)
{
+ return RefRead(ref, span);
+});
+TVM_REGISTER_GLOBAL("relay.ir.RefReadWithFields")
+ .set_body_typed([](RefRead ref_read, Optional<Expr> opt_ref,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(ref_read, opt_ref, opt_virtual_device, opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
@@ -525,9 +578,14 @@ RefWrite WithFields(RefWrite ref_write, Optional<Expr>
opt_ref, Optional<Expr> o
TVM_REGISTER_NODE_TYPE(RefWriteNode);
-TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr
value) {
- return RefWrite(ref, value);
+TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr
value, Span span) {
+ return RefWrite(ref, value, span);
});
+TVM_REGISTER_GLOBAL("relay.ir.RefWriteWithFields")
+ .set_body_typed([](RefWrite ref_write, Optional<Expr> opt_ref,
Optional<Expr> opt_value,
+ Optional<VirtualDevice> opt_virtual_device,
Optional<Span> opt_span) {
+ return WithFields(ref_write, opt_ref, opt_value, opt_virtual_device,
opt_span);
+ });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 1a3db9974f..07cfb27b1d 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -124,8 +124,8 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay.ir.Function")
.set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
- tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
- return Function(params, body, ret_type, ty_params, attrs);
+ tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs,
Span span) {
+ return Function(params, body, ret_type, ty_params, attrs, span);
});
TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields")
.set_body_typed([](Function function, Optional<Array<Var>> opt_params,
Optional<Expr> opt_body,
diff --git a/tests/python/frontend/test_common.py
b/tests/python/frontend/test_common.py
index e706f2af30..2b35ae71f2 100644
--- a/tests/python/frontend/test_common.py
+++ b/tests/python/frontend/test_common.py
@@ -14,7 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from tvm.relay.frontend.common import StrAttrsDict
+
+import numpy as np
+
+from tvm import relay, testing, transform
+from tvm.relay.frontend.common import StrAttrsDict, set_span
+from relay.utils.tag_span import _set_span, _create_span,
_verify_structural_equal_with_span
def test_key_is_present():
@@ -27,6 +32,189 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
+class TestSetSpan:
+    def test_pass_ctx_switch(self):
+        def _res(should_fill):
+            if should_fill:
+                with testing.enable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            else:
+                with testing.disable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+
+        disable = relay.var("x", shape=(1, 64, 56, 56))
+        enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+
+        _verify_structural_equal_with_span(_res(False), disable)
+        _verify_structural_equal_with_span(_res(True), enable)
+
+    # Should tag all exprs without span, and stop when expr is span-tagged
+    def test_builtin_tuple(self):
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            return set_span(tuple([a, b]), "tuple")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("tuple"))
+            return tuple([a, b])
+
+        res_tuple, golden_tuple = _res(), _golden()
+        assert len(res_tuple) == len(golden_tuple)
+        for i in range(len(res_tuple)):
+            _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
+
+    def test_builtin_list(self):
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            t_a = relay.TupleGetItem(t, 0)
+            t_b = relay.TupleGetItem(t, 1)
+            return set_span([t_a, t_b], "list")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", span=_create_span("list"))
+            t = relay.Tuple([a, b], span=_create_span("list"))
+            t_a = relay.TupleGetItem(t, 0, span=_create_span("list"))
+            t_b = relay.TupleGetItem(t, 1, span=_create_span("list"))
+            return [t_a, t_b]
+
+        res_list, golden_list = _res(), _golden()
+        assert len(res_list) == len(golden_list)
+        for i in range(len(res_list)):
+            _verify_structural_equal_with_span(res_list[i], golden_list[i])
+
+    def test_var(self):
+        x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+        x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+        _verify_structural_equal_with_span(x, x_expected)
+
+    def test_constant(self):
+        c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c")
+        c_expected = relay.const(
+            np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c")
+        )
+        _verify_structural_equal_with_span(c, c_expected)
+
+    def test_call(self):
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("conv2d"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_tuple(self):
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = set_span(relay.Tuple([a, b]), "t")
+            return relay.Function([], t)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("t"))
+            t = relay.Tuple([a, b], span=_create_span("t"))
+            return relay.Function([], t)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_tuple_getitem(self):
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            i = set_span(relay.TupleGetItem(t, 0), "i")
+            return relay.Function([], i)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("i"))
+            t = relay.Tuple([a, b], span=_create_span("i"))
+            i = relay.TupleGetItem(t, 0, span=_create_span("i"))
+            return relay.Function([], i)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_let(self):
+        def _res():
+            x = set_span(relay.Var("x"), "x_var")
+            c_1 = relay.const(np.ones(10))
+            add = relay.add(x, x)
+            body = set_span(relay.Let(x, c_1, add), "let")
+
+            c_2 = set_span(relay.const(np.zeros(10)), "zeros")
+            y = set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.Var("x", span=_create_span("x_var"))
+            c_1 = relay.const(np.ones(10), span=_create_span("let"))
+            add = _set_span(relay.add(x, x), "let")
+            body = relay.Let(x, c_1, add, span=_create_span("let"))
+
+            c_2 = relay.const(np.zeros(10), span=_create_span("zeros"))
+            y = _set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_if(self):
+        def _res():
+            x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
+            y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
+            eq = relay.equal(x, y)
+
+            true_branch = set_span(relay.add(x, y), "true_branch")
+            false_branch = relay.subtract(x, y)
+            ife = set_span(relay.If(eq, true_branch, false_branch), "if")
+            return relay.Function([x, y], ife)
+
+        def _golden():
+            x = relay.var("x", shape=[], dtype="float32", span=_create_span("x_var"))
+            y = relay.var("y", shape=[], dtype="float32", span=_create_span("y_var"))
+            eq = _set_span(relay.equal(x, y), "if")
+
+            true_branch = _set_span(relay.add(x, y), "true_branch")
+            false_branch = _set_span(relay.subtract(x, y), "if")
+            ife = relay.If(eq, true_branch, false_branch, span=_create_span("if"))
+            return relay.Function([x, y], ife)
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+    def test_fn(self):
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
+            f = set_span(relay.Function([x], y), "func")
+            return f
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("func"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "func"
+            )
+            f = relay.Function([x], y, span=_create_span("func"))
+            return f
+
+        _verify_structural_equal_with_span(_res(), _golden())
+
+
if __name__ == "__main__":
- test_key_is_present()
- test_key_is_present()
+ testing.main()
diff --git a/tests/python/relay/utils/tag_span.py b/tests/python/relay/utils/tag_span.py
new file mode 100644
index 0000000000..77042be602
--- /dev/null
+++ b/tests/python/relay/utils/tag_span.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay, tir
+from tvm.relay import expr as _expr
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def _set_span(expr, src):
+    if isinstance(expr, _expr.Call):
+        return _expr.CallWithFields(
+            expr, expr.op, expr.args, expr.attrs, expr.type_args, None, _create_span(src)
+        )
+    elif isinstance(expr, _expr.Var):
+        return _expr.VarWithFields(expr, expr.vid, expr.type_annotation, None, _create_span(src))
+    elif isinstance(expr, _expr.TupleGetItem):
+        return _expr.TupleGetItemWithFields(
+            expr, expr.tuple_value, expr.index, None, _create_span(src)
+        )
+    elif isinstance(expr, _expr.Constant):
+        return _expr.ConstantWithFields(expr, expr.data, None, _create_span(src))
+    elif isinstance(expr, _expr.Tuple):
+        return _expr.TupleWithFields(expr, expr.fields, None, _create_span(src))
+    elif isinstance(expr, _expr.TupleWrapper):
+        return _expr.TupleWrapper(_set_span(expr.tuple_value, src), expr.size)
+
+    assert False, f"unsupported type {type(expr)}"
+
+
+def _create_span(src):
+    if isinstance(src, list):
+        tmp_list = []
+        for s in src:
+            if isinstance(s, str):
+                tmp_list.append(_create_span(s))
+            elif isinstance(s, relay.Span):
+                tmp_list.append(s)
+            elif isinstance(s, relay.SequentialSpan):
+                tmp_list.extend(s.spans)
+            elif s is None:
+                tmp_list.append(s)
+            else:
+                assert False, f"unsupported type {type(s)}"
+        return relay.SequentialSpan(tmp_list)
+    return relay.Span(relay.SourceName(src), 0, 0, 0, 0)
+
+
+def _collect_spans(objref):
+    class Collector:
+        def __init__(self):
+            self._spans = []
+
+        def collect(self, objref):
+            if hasattr(objref, "span"):
+                self._spans.append(objref.span)
+
+        @property
+        def get_spans(self):
+            return self._spans
+
+    pov = None
+    if isinstance(objref, relay.Expr):
+        pov = relay.analysis.post_order_visit
+    elif isinstance(objref, (tir.Stmt, tir.expr.PrimExprWithOp)):
+        pov = tir.stmt_functor.post_order_visit
+    else:
+        assert False, f"unsupported type {type(objref)}"
+
+    c = Collector()
+    pov(objref, c.collect)
+    return c.get_spans
+
+
+def _verify_span(lhs, rhs):
+    lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs)
+
+    assert len(lhs_spans) == len(rhs_spans)
+
+    for i in range(len(lhs_spans)):
+        assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i])
+
+
+def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False):
+    if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var):
+        # SEqualReduce compares the vid of Var type. Threrfore we only compare span here.
+        _verify_span(lhs, rhs)
+        return
+
+    if assert_mode:
+        tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars)
+    else:
+        assert tvm.ir.structural_equal(lhs, rhs, map_free_vars)
+
+    _verify_span(lhs, rhs)