This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 108a4e15b3 [Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313)
108a4e15b3 is described below

commit 108a4e15b3c68fea2f803dc13b1b45291b00f15b
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 28 18:29:18 2024 -0500

    [Relax] Identify tuple unpack/repack in CanonicalizeBindings (#17313)
    
    Prior to this commit, the `CanonicalizeBindings` pass could identify
    and simplify a value that had been packed into a tuple, then
    extracted from it.  (e.g. Simplifying `tup = (x,y); z = tup[0]` into
    `z = x`.)  However, it could not identify a value that had been
    expanded from a tuple, and then re-bundled.  (e.g. Simplifying
    `new_tuple = (tup[0], tup[1])` into `new_tuple = tup`.)
    
    This commit updates `CanonicalizeBindings` to identify and remove
    unnecessary tuple unpacking/repacking.
---
 src/relax/transform/canonicalize_bindings.cc       | 112 +++++++++++++++++----
 .../relax/test_transform_canonicalize_bindings.py  |  51 ++++++++++
 2 files changed, 143 insertions(+), 20 deletions(-)

diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index d1a9f97337..807914075e 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -262,33 +262,105 @@ class CanonicalizePlanner : public ExprVisitor {
     current_block_ = Optional<BindingBlock>();
   }
 
-  void VisitBinding(const Binding& binding) override {
-    bool has_same_struct_info = true;
-    Expr value;
-    if (auto ptr = binding.as<VarBindingNode>()) {
-      value = ptr->value;
-    } else if (auto ptr = binding.as<MatchCastNode>()) {
-      has_same_struct_info =
-          StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
-      value = ptr->value;
-    } else {
-      LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
-    }
+  Optional<Expr> UnwrapKnownValue(Expr expr) {
+    // If the expression is a variable, then it can be unwrapped into
+    // its known value.
+    auto unwrap_var = [this](Expr expr) -> Expr {
+      if (auto var = expr.as<Var>()) {
+        if (auto opt = known_bindings_.Get(var.value())) {
+          return opt.value();
+        }
+      }
+      return expr;
+    };
 
-    // Unwrap TupleGetItem, if the Tuple being accessed is known.
-    if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
-      Expr tuple = tuple_get_item->tuple;
-      while (auto tuple_var = tuple.as<Var>()) {
-        if (auto opt = known_bindings_.Get(tuple_var.value())) {
-          tuple = opt.value();
+    auto recursively_unwrap_var = [&unwrap_var](Expr expr) -> Expr {
+      while (true) {
+        auto new_expr = unwrap_var(expr);
+        if (new_expr.same_as(expr)) {
+          return expr;
         } else {
-          break;
+          expr = new_expr;
         }
       }
+    };
 
+    // If the expression is a TupleGetItem, which accesses a field of
+    // a known tuple, then it can be unwrapped into a direct access of
+    // that field.
+    if (auto tuple_get_item = expr.as<TupleGetItemNode>()) {
+      Expr tuple = recursively_unwrap_var(tuple_get_item->tuple);
       if (auto ptr = tuple.as<TupleNode>()) {
-        value = ptr->fields[tuple_get_item->index];
+        return ptr->fields[tuple_get_item->index];
+      }
+    }
+
+    // If the expression is a Tuple, and each element is
+    // `TupleGetItem(earlier_tuple, i)`, then this is just a copy of
+    // `earlier_tuple`.
+    auto earlier_tuple = [&]() -> Optional<Expr> {
+      auto expr_tuple = expr.as<TupleNode>();
+      if (!expr_tuple) {
+        return NullOpt;
+      }
+
+      if (expr_tuple->fields.empty()) {
+        return NullOpt;
+      }
+
+      auto first_element = recursively_unwrap_var(expr_tuple->fields[0]).as<TupleGetItemNode>();
+      if (!first_element) {
+        return NullOpt;
+      }
+
+      auto earlier_tuple_size =
+          Downcast<TupleStructInfo>(GetStructInfo(first_element->tuple))->fields.size();
+      if (earlier_tuple_size != expr_tuple->fields.size()) {
+        return NullOpt;
       }
+
+      Expr earlier_tuple = recursively_unwrap_var(first_element->tuple);
+
+      for (size_t i = 0; i < expr_tuple->fields.size(); i++) {
+        auto element = recursively_unwrap_var(expr_tuple->fields[i]).as<TupleGetItemNode>();
+        if (!element) {
+          return NullOpt;
+        }
+        if (static_cast<size_t>(element->index) != i) {
+          return NullOpt;
+        }
+
+        auto source_of_element = recursively_unwrap_var(element->tuple);
+
+        if (!earlier_tuple.same_as(source_of_element)) {
+          return NullOpt;
+        }
+      }
+
+      return earlier_tuple;
+    }();
+    if (earlier_tuple) {
+      return earlier_tuple.value();
+    }
+
+    return NullOpt;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    bool has_same_struct_info = [&]() {
+      if (binding.as<VarBindingNode>()) {
+        return true;
+      } else if (auto match_cast = binding.as<MatchCastNode>()) {
+        return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value));
+      } else {
+        LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
+      }
+    }();
+
+    Expr value = GetBoundValue(binding);
+
+    if (auto unwrapped = UnwrapKnownValue(value)) {
+      value = unwrapped.value();
     }
 
     if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index a7ff8cdc32..1d982b0972 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -1294,5 +1294,56 @@ def test_trivial_binding_of_replaced_non_dataflow_var():
     assert after_names == expected_names
 
 
+def test_trace_tuple_through_round_trip():
+    """Canonicalize to the original tuple, without unwrap/rewrap."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+            with R.dataflow():
+                A = param_tuple[0]
+                B = param_tuple[1]
+                C = param_tuple[2]
+                output = (A, B, C)
+                R.output(output)
+            return output
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+            with R.dataflow():
+                A = param_tuple[0]
+                B = param_tuple[1]
+                C = param_tuple[2]
+                R.output()
+
+            return param_tuple
+
+    After = CanonicalizeBindings()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_trace_partial_tuple_through_round_trip():
+    """Do not canonicalize when only part of the original tuple is repacked."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(param_tuple: R.Tuple([R.Tensor, R.Tensor, R.Tensor])):
+            with R.dataflow():
+                A = param_tuple[0]
+                B = param_tuple[1]
+                output = (A, B)
+                R.output(output)
+            return output
+
+    Expected = Before
+
+    After = CanonicalizeBindings()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to