This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 460f6f1d3e [QoL][Relax] Infer StructInfo for relax::Tuple on
construction (#16860)
460f6f1d3e is described below
commit 460f6f1d3e1625882df701252234350f83aa6da1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Apr 16 16:28:00 2024 -0500
[QoL][Relax] Infer StructInfo for relax::Tuple on construction (#16860)
Prior to this commit, the `relax::Tuple` constructor left the
`struct_info_` field undefined. This is inconsistent with other Relax
leaf nodes, such as `relax::PrimValue`, `relax::Constant`, and
`relax::ExternFunc`, which initialize their struct info on
construction.
This commit updates the `relax::Tuple` constructor to define
`struct_info_` as `TupleStructInfo`, if all fields have a known struct
info. If any field does not have a known struct info, the current
behavior is kept, where `struct_info_` is constructed as `NullOpt`,
and is later populated by the `relax::BlockBuilder`.
---
src/relax/ir/expr.cc | 16 ++++++++++++++++
tests/python/relax/test_expr.py | 19 +++++++++++++++++++
2 files changed, 35 insertions(+)
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 0530bb770b..dd0f68dca4 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -137,9 +137,25 @@ TVM_REGISTER_GLOBAL("relax.If")
});
Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
+ Optional<StructInfo> tuple_sinfo = [&]() -> Optional<StructInfo> {
+ Array<StructInfo> field_sinfo;
+ for (const auto& field : fields) {
+ if (field->struct_info_.defined()) {
+ field_sinfo.push_back(GetStructInfo(field));
+ } else {
+ return NullOpt;
+ }
+ }
+ return TupleStructInfo(field_sinfo);
+ }();
+
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
n->span = std::move(span);
+ if (tuple_sinfo) {
+ n->checked_type_ = GetStaticType(tuple_sinfo.value());
+ }
+ n->struct_info_ = tuple_sinfo;
data_ = std::move(n);
}
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index af1bc851be..b20c9ef2d9 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -86,6 +86,25 @@ def test_tuple() -> None:
t[-3]
+def test_tuple_sinfo_inferred_on_construction():
+ v0 = rx.Var("v0", rx.ObjectStructInfo())
+ v1 = rx.Var("v1", rx.ObjectStructInfo())
+ tup = rx.Tuple((v0, v1))
+
+ assert tup.struct_info_ is not None
+ tvm.ir.assert_structural_equal(
+ tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()])
+ )
+
+
+def test_tuple_sinfo_requires_fields_with_known_sinfo():
+ v0 = rx.Var("v0", rx.ObjectStructInfo())
+ v1 = rx.Var("v1")
+ tup = rx.Tuple((v0, v1))
+
+ assert tup.struct_info_ is None
+
+
def test_match_cast() -> None:
# match_cast([16, 8], [m, n])
m = tir.Var("m", dtype="int64")