This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new be8607d47f [Relax][Bugfix] Infer TIR values from shapes inside a tuple (#17312)
be8607d47f is described below

commit be8607d47fa418f6bf77671b81093e0ffd7fdc4d
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 28 17:43:54 2024 -0500

    [Relax][Bugfix] Infer TIR values from shapes inside a tuple (#17312)
    
    If a Relax function contains an `R.match_cast` that defines a symbolic
    shape, and the value provided to the `R.match_cast` has a known static
    shape, the `relax.transform.CanonicalizeBindings()` pass can in-line
    the known static shape.  However, these known TIR values were only
    collected if the expression used in `R.match_cast` was an
    `R.Tensor`, `R.Shape`, or `R.Prim` (Relax types which may contain
    symbolic TIR values); they were not collected if the `R.match_cast`
    expression was an `R.Tuple`.
    
    For example, while using `R.match_cast` to convert from
    `R.Tensor([16])` to `R.Tensor([batch_size])` would identify that
    `batch_size` must be `16`, using `R.match_cast` to convert from
    `R.Tuple(R.Tensor([16]))` to `R.Tuple(R.Tensor([batch_size]))` would
    not.
    
    This commit updates `InferSymbolicVarMap` to collect all symbolic
    shapes, even if they occur within an `R.Tuple`.
---
 src/relax/utils.cc                                 | 27 ++++++++++++++---
 .../relax/test_transform_canonicalize_bindings.py  | 34 ++++++++++++++++++++++
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 77416dc92b..96fd5578e4 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -159,13 +159,32 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
                     GetStructInfo(expr_tensor->shape.value()));
   };
 
+  std::function<void(const StructInfo&, const StructInfo&)> 
bind_from_struct_info = nullptr;
+  auto bind_from_tuple = [&bind_from_struct_info](const StructInfo& var, const 
StructInfo& expr) {
+    auto var_tuple = var.as<TupleStructInfoNode>();
+    if (!var_tuple) return;
+
+    auto expr_tuple = expr.as<TupleStructInfoNode>();
+    if (!expr_tuple) return;
+
+    if (var_tuple->fields.size() != expr_tuple->fields.size()) return;
+
+    for (size_t i = 0; i < var_tuple->fields.size(); i++) {
+      bind_from_struct_info(var_tuple->fields[i], expr_tuple->fields[i]);
+    }
+  };
+
+  bind_from_struct_info = [&](const StructInfo& var, const StructInfo& expr) {
+    bind_from_tensor(var, expr);
+    bind_from_shape(var, expr);
+    bind_from_prim_value(var, expr);
+    bind_from_tuple(var, expr);
+  };
+
   for (const auto& [relax_var, relax_expr] : relax_var_remap) {
     auto var_sinfo = GetStructInfo(relax_var);
     auto expr_sinfo = GetStructInfo(relax_expr);
-
-    bind_from_tensor(var_sinfo, expr_sinfo);
-    bind_from_shape(var_sinfo, expr_sinfo);
-    bind_from_prim_value(var_sinfo, expr_sinfo);
+    bind_from_struct_info(var_sinfo, expr_sinfo);
   }
 
   return tir_var_remap;
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py
index ea3b1c249b..a7ff8cdc32 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -253,6 +253,40 @@ def test_replace_symbolic_variable_and_remove_match_cast():
     verify(TestChangeShape, Expected)
 
 
+def test_replace_symbolic_variable_and_remove_match_cast_of_tuple():
+    """Symbolic variables may be defined in R.match_cast of tuple
+
+    This test is similar to
+    `test_replace_symbolic_variable_and_remove_match_cast`, except
+    that the MatchCast is performed on a Relax tuple.
+
+    This is a regression test.  Earlier implementations only inferred
+    TIR variables from `R.match_cast` of tensors, shapes, and prim
+    values, but omitted tuples.
+
+    """
+
+    @I.ir_module
+    class Before:  # input module: aliases plus a match_cast on a tuple
+        @R.function
+        def main(x: R.Tuple(R.Tensor(("m", "n")))):  # tensor inside x has symbolic shape (m, n)
+            y = x  # trivial alias of x
+            o, p = T.int64(), T.int64()  # new symbolic vars, defined only by the match_cast below
+            z = R.match_cast(x, R.Tuple(R.Tensor((o, p))))  # tuple match_cast: o, p should be inferred as m, n
+            w = z  # trivial alias of the match_cast result
+            q = R.add(w[0], y[0])  # both operands trace back to x[0]
+            return R.add(q, w[0])
+
+    @I.ir_module
+    class Expected:  # aliases and the tuple match_cast are gone; only x[0] is used
+        @R.function
+        def main(x: R.Tuple(R.Tensor(("m", "n")))):
+            q = R.add(x[0], x[0])
+            return R.add(q, x[0])
+
+    verify(Before, Expected)
+
+
 def test_unwrap_tuple():
     @I.ir_module
     class Before:

Reply via email to