This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 108a4e15b3 [Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313)
108a4e15b3 is described below
commit 108a4e15b3c68fea2f803dc13b1b45291b00f15b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 28 18:29:18 2024 -0500
[Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313)
Prior to this commit, the `CanonicalizeBindings` pass could identify
and simplify a value that had been packed into a tuple, then
extracted from it. (e.g. Simplifying `tup = (x,y); z = tup[0]` into
`z = x`.) However, it could not identify a value that had been
expanded from a tuple, and then re-bundled. (e.g. Simplifying
`new_tuple = (tup[0], tup[1])` into `new_tuple = tup`.)
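This commit updates `CanonicalizeBindings` to identify and remove
unnecessary tuple unpacking/repacking. A re-packed tuple is recognized
as a copy of the original only when it contains every field of the
original tuple, in order; partial or reordered re-packs are left unchanged.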
---
src/relax/transform/canonicalize_bindings.cc | 112 +++++++++++++++++----
.../relax/test_transform_canonicalize_bindings.py | 51 ++++++++++
2 files changed, 143 insertions(+), 20 deletions(-)
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index d1a9f97337..807914075e 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor {
current_block_ = Optional<BindingBlock>();
}
- void VisitBinding(const Binding& binding) override {
- bool has_same_struct_info = true;
- Expr value;
- if (auto ptr = binding.as<VarBindingNode>()) {
- value = ptr->value;
- } else if (auto ptr = binding.as<MatchCastNode>()) {
- has_same_struct_info =
- StructuralEqual()(GetStructInfo(binding->var),
GetStructInfo(ptr->value));
- value = ptr->value;
- } else {
- LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
- }
+ // Attempt to simplify `expr` using the values recorded in
+ // `known_bindings_`.  Two patterns are recognized:
+ //
+ //   1. `TupleGetItem(t, i)`, where `t` (after following known
+ //      bindings) is a literal `Tuple`: returns the i-th field.
+ //   2. `Tuple(TupleGetItem(t, 0), ..., TupleGetItem(t, N-1))`, where
+ //      every field (after following known bindings) reads the same
+ //      `t`, in order, and `t` has exactly N fields: returns `t`.
+ //
+ // Returns NullOpt if neither pattern applies.  Note that a bare `Var`
+ // passed as `expr` is NOT replaced by its known value here; the
+ // var-unwrapping helpers below are only applied to the operands of
+ // the TupleGetItem/Tuple being inspected.
+ Optional<Expr> UnwrapKnownValue(Expr expr) {
+ // Helper: if `expr` is a `Var` with an entry in `known_bindings_`,
+ // return the bound value; otherwise return `expr` unchanged.
+ auto unwrap_var = [this](Expr expr) -> Expr {
+ if (auto var = expr.as<Var>()) {
+ if (auto opt = known_bindings_.Get(var.value())) {
+ return opt.value();
+ }
+ }
+ return expr;
+ };
- // Unwrap TupleGetItem, if the Tuple being accessed is known.
- if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
- Expr tuple = tuple_get_item->tuple;
- while (auto tuple_var = tuple.as<Var>()) {
- if (auto opt = known_bindings_.Get(tuple_var.value())) {
- tuple = opt.value();
+ // Helper: repeatedly apply `unwrap_var` until the result no longer
+ // changes (compared by object identity), following chains of
+ // var-to-var entries in `known_bindings_`.
+ // NOTE(review): this loop only terminates if `known_bindings_`
+ // contains no cycle of var-to-var entries -- confirm that the
+ // planner guarantees this.
+ auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr {
+ while (true) {
+ auto new_expr = unwrap_var(expr);
+ if (new_expr.same_as(expr)) {
+ return expr;
} else {
- break;
+ expr = new_expr;
}
}
+ };
+ // If the expression is a TupleGetItem, which accesses a field of
+ // a known tuple, then it can be unwrapped into a direct access of
+ // that field.
+ // If the accessed tuple is not a literal Tuple, fall through; the
+ // repack check below never matches a TupleGetItem, so NullOpt results.
+ if (auto tuple_get_item = expr.as<TupleGetItemNode>()) {
+ Expr tuple = recursively_unwrap_var(tuple_get_item->tuple);
if (auto ptr = tuple.as<TupleNode>()) {
- value = ptr->fields[tuple_get_item->index];
+ return ptr->fields[tuple_get_item->index];
+ }
+ }
+
+ // If the expression is a Tuple, and each element is
+ // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of
+ // `earlier_tuple`.
+ // Only a complete, in-order repack qualifies: a tuple that keeps a
+ // subset of the fields, or permutes them, is a different value.
+ auto earlier_tuple = [&]() -> Optional<Expr> {
+ auto expr_tuple = expr.as<TupleNode>();
+ if (!expr_tuple) {
+ return NullOpt;
+ }
+
+ if (expr_tuple->fields.empty()) {
+ return NullOpt;
+ }
+
+ // NOTE(review): `.as<TupleGetItemNode>()` yields a raw node pointer
+ // taken from a temporary `Expr`.  It stays valid because `unwrap_var`
+ // only ever returns its argument (owned by `expr_tuple->fields`) or a
+ // value stored in `known_bindings_`, both of which keep the node alive
+ // for the duration of this call.
+ auto first_element =
recursively_unwrap_var(expr_tuple->fields[0]).as<TupleGetItemNode>();
+ if (!first_element) {
+ return NullOpt;
+ }
+
+ // The re-packed tuple must have exactly as many fields as the tuple
+ // it reads from; otherwise it is only a partial copy.
+ // NOTE(review): `Downcast<TupleStructInfo>` assumes the operand of a
+ // TupleGetItem always carries TupleStructInfo and would fail a runtime
+ // check otherwise -- confirm this holds for values taken from
+ // `known_bindings_`.
+ auto earlier_tuple_size =
+
Downcast<TupleStructInfo>(GetStructInfo(first_element->tuple))->fields.size();
+ if (earlier_tuple_size != expr_tuple->fields.size()) {
+ return NullOpt;
}
+
+ // The candidate result is the fully unwrapped source tuple.  This may be
+ // a non-Var expression if `known_bindings_` maps the tuple's var to one.
+ Expr earlier_tuple = recursively_unwrap_var(first_element->tuple);
+
+ for (size_t i = 0; i < expr_tuple->fields.size(); i++) {
+ auto element =
recursively_unwrap_var(expr_tuple->fields[i]).as<TupleGetItemNode>();
+ if (!element) {
+ return NullOpt;
+ }
+ // Field i must read index i, so reordered re-packs are rejected.
+ if (static_cast<size_t>(element->index) != i) {
+ return NullOpt;
+ }
+
+ auto source_of_element = recursively_unwrap_var(element->tuple);
+
+ // `same_as` compares object identity, not structural equality: every
+ // field must read from the very same (unwrapped) tuple expression.
+ if (!earlier_tuple.same_as(source_of_element)) {
+ return NullOpt;
+ }
+ }
+
+ return earlier_tuple;
+ }();
+ if (earlier_tuple) {
+ return earlier_tuple.value();
+ }
+
+ return NullOpt;
+ }
+
+ void VisitBinding(const Binding& binding) override {
+ bool has_same_struct_info = [&]() {
+ if (binding.as<VarBindingNode>()) {
+ return true;
+ } else if (auto match_cast = binding.as<MatchCastNode>()) {
+ return StructuralEqual()(GetStructInfo(binding->var),
GetStructInfo(match_cast->value));
+ } else {
+ LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
+ }
+ }();
+
+ Expr value = GetBoundValue(binding);
+
+ if (auto unwrapped = UnwrapKnownValue(value)) {
+ value = unwrapped.value();
}
if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index a7ff8cdc32..1d982b0972 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -1294,5 +1294,56 @@ def test_trivial_binding_of_replaced_non_dataflow_var():
assert after_names == expected_names
+def test_trace_tuple_through_round_trip():
+ """Canonicalize a tuple round-trip back to the original tuple.
+
+ `output = (param_tuple[0], param_tuple[1], param_tuple[2])` re-packs
+ every field of `param_tuple` in order, so the unwrap/rewrap is removed
+ and the function returns `param_tuple` directly.  The now-unused
+ `A`, `B`, `C` bindings are left in place, and `output` is no longer
+ exposed from the dataflow block.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+ with R.dataflow():
+ A = param_tuple[0]
+ B = param_tuple[1]
+ C = param_tuple[2]
+ output = (A, B, C)
+ R.output(output)
+ return output
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+ with R.dataflow():
+ A = param_tuple[0]
+ B = param_tuple[1]
+ C = param_tuple[2]
+ R.output()
+
+ return param_tuple
+
+ After = CanonicalizeBindings()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_trace_partial_tuple_through_round_trip():
+ """Do not canonicalize a partial re-pack of a tuple.
+
+ `output = (param_tuple[0], param_tuple[1])` uses only two of the three
+ fields of `param_tuple`, so it is not equivalent to `param_tuple` and
+ the module must be left unchanged.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+ with R.dataflow():
+ A = param_tuple[0]
+ B = param_tuple[1]
+ output = (A, B)
+ R.output(output)
+ return output
+
+ Expected = Before
+
+ After = CanonicalizeBindings()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()