This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 94e83f27cc [Unity][VM] Recursively visit match bindings in
VMShapeLowerMutator (#16583)
94e83f27cc is described below
commit 94e83f27ccc8221655a7e049405a5828586d6ebc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Feb 17 08:21:50 2024 -0600
[Unity][VM] Recursively visit match bindings in VMShapeLowerMutator (#16583)
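Prior to this commit, the `MatchCast` binding visitor in
`VMShapeLowerMutator` emitted the binding unchanged, without visiting
its RHS. If the RHS of the `MatchCast` is a `ShapeExpr` that uses
symbolic variables, that RHS must be visited in order to have its
symbolic variables updated. This commit delegates to
`ExprMutator::VisitBinding_` so that the RHS is visited recursively.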
---
src/relax/backend/vm/vm_shape_lower.cc | 2 +-
.../relax/test_backend_transform_shape_lower.py | 78 ++++++++++++++++++++++
2 files changed, 79 insertions(+), 1 deletion(-)
diff --git a/src/relax/backend/vm/vm_shape_lower.cc
b/src/relax/backend/vm/vm_shape_lower.cc
index 41b27ea625..5875ad5562 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -419,7 +419,7 @@ class VMShapeLowerMutator
// These checks are emitted as extra, in codegen
// match-cast is simply ignored and treated as a normal binding.
- builder_->EmitNormalized(GetRef<MatchCast>(binding));
+ ExprMutator::VisitBinding_(binding);
}
// Do not override shape in struct info fields
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py
b/tests/python/relax/test_backend_transform_shape_lower.py
index b9a3537630..31eb4b26be 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -731,5 +731,83 @@ def test_check_weights_with_dynamic_shape():
assert_structural_equal(after, expected)
+def test_update_symbolic_vars_in_match_cast_rhs():
+ """Symbolic variables may be used on the RHS of match_cast"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ arg_prim_value: R.Prim(value="n"),
+ ):
+ R.func_attr({"relax.force_pure": 1})
+ n = T.int64()
+ shape = R.shape([n])
+ m = T.int64()
+ _ = R.match_cast(shape, R.Shape([m]))
+ return R.prim_value(m)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
+ R.func_attr({"relax.force_pure": 1})
+ n = T.int64()
+
+ shape_heap = R.call_builtin_with_ctx(
+ "vm.builtin.alloc_shape_heap",
+ [2],
+ sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
+ )
+ _ = R.call_packed(
+ "vm.builtin.check_prim_value_info",
+ arg_prim_value,
+ R.dtype("int64"),
+ "",
+ sinfo_args=[R.Tuple],
+ )
+ _ = R.call_packed(
+ "vm.builtin.match_prim_value",
+ arg_prim_value,
+ shape_heap,
+ MatchShapeCode.STORE_TO_HEAP,
+ 0,
+ "",
+ sinfo_args=[R.Tuple],
+ )
+ shape = R.call_packed(
+ "vm.builtin.make_shape",
+ shape_heap,
+ 1,
+ MakeShapeCode.LOAD_SHAPE,
+ 0,
+ sinfo_args=[R.Shape(ndim=1)],
+ )
+ _ = R.call_packed(
+ "vm.builtin.match_shape",
+ shape,
+ shape_heap,
+ 1,
+ MatchShapeCode.STORE_TO_HEAP,
+ 1,
+ "",
+ sinfo_args=[R.Tuple],
+ )
+
+ m = T.int64()
+ _ = R.match_cast(shape, R.Shape([m]))
+ gv = R.call_packed(
+ "vm.builtin.make_prim_value",
+ shape_heap,
+ MakeShapeCode.LOAD_SHAPE,
+ 1,
+ sinfo_args=[R.Prim(value=m)],
+ )
+ return gv
+
+ After = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
+ assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()