This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 94e83f27cc [Unity][VM] Recursively visit match bindings in 
VMShapeLowerMutator (#16583)
94e83f27cc is described below

commit 94e83f27ccc8221655a7e049405a5828586d6ebc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Feb 17 08:21:50 2024 -0600

    [Unity][VM] Recursively visit match bindings in VMShapeLowerMutator (#16583)
    
    Prior to this commit, the `MatchBinding` visitor in
    `VMShapeLowerMutator` did not recursively visit the value bound by
    the match-cast.  If the RHS of the `MatchBinding` is a `ShapeExpr`
    that uses symbolic variables, that RHS must be visited in order to
    have the symbolic variables updated.
---
 src/relax/backend/vm/vm_shape_lower.cc             |  2 +-
 .../relax/test_backend_transform_shape_lower.py    | 78 ++++++++++++++++++++++
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/src/relax/backend/vm/vm_shape_lower.cc 
b/src/relax/backend/vm/vm_shape_lower.cc
index 41b27ea625..5875ad5562 100644
--- a/src/relax/backend/vm/vm_shape_lower.cc
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -419,7 +419,7 @@ class VMShapeLowerMutator
 
     // These checks are emitted as extra, in codegen
     // match-cast is simply ignored and treated as a normal binding.
-    builder_->EmitNormalized(GetRef<MatchCast>(binding));
+    ExprMutator::VisitBinding_(binding);
   }
 
   // Do not override shape in struct info fields
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py 
b/tests/python/relax/test_backend_transform_shape_lower.py
index b9a3537630..31eb4b26be 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -731,5 +731,83 @@ def test_check_weights_with_dynamic_shape():
     assert_structural_equal(after, expected)
 
 
+def test_update_symbolic_vars_in_match_cast_rhs():
+    """Symbolic variables may be used on the RHS of match_cast"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            arg_prim_value: R.Prim(value="n"),
+        ):
+            R.func_attr({"relax.force_pure": 1})
+            n = T.int64()
+            shape = R.shape([n])
+            m = T.int64()
+            _ = R.match_cast(shape, R.Shape([m]))
+            return R.prim_value(m)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
+            R.func_attr({"relax.force_pure": 1})
+            n = T.int64()
+
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [2],
+                sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_prim_value_info",
+                arg_prim_value,
+                R.dtype("int64"),
+                "",
+                sinfo_args=[R.Tuple],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_prim_value",
+                arg_prim_value,
+                shape_heap,
+                MatchShapeCode.STORE_TO_HEAP,
+                0,
+                "",
+                sinfo_args=[R.Tuple],
+            )
+            shape = R.call_packed(
+                "vm.builtin.make_shape",
+                shape_heap,
+                1,
+                MakeShapeCode.LOAD_SHAPE,
+                0,
+                sinfo_args=[R.Shape(ndim=1)],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                shape,
+                shape_heap,
+                1,
+                MatchShapeCode.STORE_TO_HEAP,
+                1,
+                "",
+                sinfo_args=[R.Tuple],
+            )
+
+            m = T.int64()
+            _ = R.match_cast(shape, R.Shape([m]))
+            gv = R.call_packed(
+                "vm.builtin.make_prim_value",
+                shape_heap,
+                MakeShapeCode.LOAD_SHAPE,
+                1,
+                sinfo_args=[R.Prim(value=m)],
+            )
+            return gv
+
+    After = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
+    assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to