This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new be8607d47f [Relax][Bugfix] Infer TIR values from shapes inside a tuple
(#17312)
be8607d47f is described below
commit be8607d47fa418f6bf77671b81093e0ffd7fdc4d
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 28 17:43:54 2024 -0500
[Relax][Bugfix] Infer TIR values from shapes inside a tuple (#17312)
If a Relax function contains an `R.match_cast` that defines a symbolic
shape, and the value provided to the `R.match_cast` has a known static
shape, the `relax.transform.CanonicalizeBindings()` pass can inline
the known static shape. However, these known TIR values were only
collected when the expression used in `R.match_cast` was an
`R.Tensor`, `R.Shape`, or `R.Prim` (the Relax types which may contain
symbolic TIR values); they were not collected when the `R.match_cast`
expression was an `R.Tuple`, even if the tuple's fields were of those
types.
For example, while using `R.match_cast` to convert from
`R.Tensor([16])` to `R.Tensor([batch_size])` would identify that
`batch_size` must be `16`, using `R.match_cast` to convert from
`R.Tuple(R.Tensor([16]))` to `R.Tuple(R.Tensor([batch_size]))` would
not.
This commit updates `InferSymbolicVarMap` to collect all symbolic
shapes, including those nested (recursively) within an `R.Tuple`.
---
src/relax/utils.cc | 27 ++++++++++++++---
.../relax/test_transform_canonicalize_bindings.py | 34 ++++++++++++++++++++++
2 files changed, 57 insertions(+), 4 deletions(-)
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 77416dc92b..96fd5578e4 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -159,13 +159,32 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
GetStructInfo(expr_tensor->shape.value()));
};
+ std::function<void(const StructInfo&, const StructInfo&)>
bind_from_struct_info = nullptr;
+ auto bind_from_tuple = [&bind_from_struct_info](const StructInfo& var, const
StructInfo& expr) {
+ auto var_tuple = var.as<TupleStructInfoNode>();
+ if (!var_tuple) return;
+
+ auto expr_tuple = expr.as<TupleStructInfoNode>();
+ if (!expr_tuple) return;
+
+ if (var_tuple->fields.size() != expr_tuple->fields.size()) return;
+
+ for (size_t i = 0; i < var_tuple->fields.size(); i++) {
+ bind_from_struct_info(var_tuple->fields[i], expr_tuple->fields[i]);
+ }
+ };
+
+ bind_from_struct_info = [&](const StructInfo& var, const StructInfo& expr) {
+ bind_from_tensor(var, expr);
+ bind_from_shape(var, expr);
+ bind_from_prim_value(var, expr);
+ bind_from_tuple(var, expr);
+ };
+
for (const auto& [relax_var, relax_expr] : relax_var_remap) {
auto var_sinfo = GetStructInfo(relax_var);
auto expr_sinfo = GetStructInfo(relax_expr);
-
- bind_from_tensor(var_sinfo, expr_sinfo);
- bind_from_shape(var_sinfo, expr_sinfo);
- bind_from_prim_value(var_sinfo, expr_sinfo);
+ bind_from_struct_info(var_sinfo, expr_sinfo);
}
return tir_var_remap;
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py
b/tests/python/relax/test_transform_canonicalize_bindings.py
index ea3b1c249b..a7ff8cdc32 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -253,6 +253,40 @@ def test_replace_symbolic_variable_and_remove_match_cast():
verify(TestChangeShape, Expected)
+def test_replace_symbolic_variable_and_remove_match_cast_of_tuple():
+ """Symbolic variables may be defined in R.match_cast of tuple
+
+ This test is similar to
+ `test_replace_symbolic_variable_and_remove_match_cast`, except
+ that the MatchCast is performed on a Relax tuple.
+
+ This is a regression test. Earlier implementations only inferred
+ TIR variables from `R.match_cast` of tensors, shapes, and prim
+ values, but omitted tuples.
+
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tuple(R.Tensor(("m", "n")))):
+ y = x
+ o, p = T.int64(), T.int64()
+ z = R.match_cast(x, R.Tuple(R.Tensor((o, p))))
+ w = z
+ q = R.add(w[0], y[0])
+ return R.add(q, w[0])
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tuple(R.Tensor(("m", "n")))):
+ q = R.add(x[0], x[0])
+ return R.add(q, x[0])
+
+ verify(Before, Expected)
+
+
def test_unwrap_tuple():
@I.ir_module
class Before: