This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c0b8953a1f [Unity] Dynamic-shape param support in LazyTransformParams (#15713)
c0b8953a1f is described below

commit c0b8953a1f4b79f8290217a4732365d0baa4c419
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Sep 9 13:46:27 2023 -0400

    [Unity] Dynamic-shape param support in LazyTransformParams (#15713)
    
    This PR adds support for dynamic-shape parameters to the
    LazyTransformParams pass.
    
    Prior to this PR, the symbolic variables appearing in the shapes of
    dynamic-shape parameters were not properly bound. This PR emits each
    `get_item` call as its own binding and then applies a MatchCast to the
    parameter's original struct info. The MatchCast always binds the symbolic
    variables, which makes dynamic-shape parameters work.
    
    This PR also fixes and re-enables a previously skipped, failing test
    (test_output_with_use_site).
---
 .../tvm/relax/transform/lazy_transform_params.py   |  13 ++-
 .../relax/test_transform_lazy_transform_params.py  | 120 +++++++++++++++++++--
 2 files changed, 118 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/transform/lazy_transform_params.py 
b/python/tvm/relax/transform/lazy_transform_params.py
index 69f724067c..90e56c8dbb 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -164,12 +164,15 @@ class LazyTransformParamsMutator(PyExprMutator):
         # rewrite get item
         tuple_get_item = super().visit_tuple_getitem_(op)
         if tuple_get_item.tuple_value == self.input_tuple_param:
-            return relax.Call(
-                relax.ExternFunc("get_item"),
-                [relax.PrimValue(tuple_get_item.index)],
-                None,
-                [relax.ObjectStructInfo()],
+            get_item_result = self.builder_.emit(
+                relax.Call(
+                    relax.ExternFunc("get_item"),
+                    [relax.PrimValue(tuple_get_item.index)],
+                    None,
+                    [relax.ObjectStructInfo()],
+                )
             )
+            return self.builder_.match_cast(get_item_result, op.struct_info)
         else:
             return tuple_get_item
 
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py 
b/tests/python/relax/test_transform_lazy_transform_params.py
index 478580ff8d..94f2181daf 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -79,15 +79,23 @@ def test_lazy_transform_params():
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             lv: R.Object = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=(R.Object,))
-            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, 
sinfo_args=(R.Object,))
-            _1: R.Tuple = R.vm.kill_object(lv)
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                lv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv_m: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv_m, 
sinfo_args=(R.Object,))
+            _1: R.Tuple = R.vm.kill_object(lv_m)
             lv1: R.Object = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=(R.Object,))
+            gv3: R.Tensor((3, 16, 3, 3), dtype="float32") = R.match_cast(
+                lv1, R.Tensor((3, 16, 3, 3), dtype="float32")
+            )
+            lv1_m: R.Tensor((3, 16, 3, 3), dtype="float32") = gv3
             lv2 = R.call_tir(
                 cls.transform_layout_IOHW_to_OIHW,
-                (lv1,),
+                (lv1_m,),
                 out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"),
             )
-            _2: R.Tuple = R.vm.kill_object(lv1)
+            _2: R.Tuple = R.vm.kill_object(lv1_m)
             _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, 
sinfo_args=(R.Object,))
             gv: R.Tuple = R.tuple()
             return gv
@@ -146,13 +154,17 @@ def test_lazy_transform_params_with_symbolic_vars():
             slice_index = T.int64()
 
             param = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=(R.Object,))
+            gv: R.Tensor((16, 16), dtype="float32") = R.match_cast(
+                param, R.Tensor((16, 16), dtype="float32")
+            )
+            param_m: R.Tensor((16, 16), dtype="float32") = gv
             transformed = R.call_tir(
                 cls.slice_buffer,
-                (param,),
+                (param_m,),
                 tir_vars=[slice_index],
                 out_sinfo=R.Tensor((16,), dtype="float32"),
             )
-            unused_1_ = R.vm.kill_object(param)
+            unused_1_ = R.vm.kill_object(param_m)
             unused_2_ = R.call_packed(
                 "set_item", R.prim_value(0), transformed, 
sinfo_args=(R.Object,)
             )
@@ -175,14 +187,100 @@ def test_lazy_transform_params_with_symbolic_vars():
     tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
 
 
-# TODO(tvm-team): remove once regression get fixed
[email protected]("temp disable, minor regression on read/write region in zero 
dim buffer")
+def test_param_shape_symbolic():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+            ic = T.int32()
+            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((3, "ic", 3, 3), dtype="float32"),
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 
3), dtype="float32")
+        ):
+            ic = T.int64()
+            # we expect ToNonDataflow and RemovePurityTracking to be invoked 
first
+            R.func_attr({"relax.force_pure": True})
+            cls = Before
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1]
+            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = params[0]
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            )
+            gv: R.Tuple(
+                R.Tensor((16, 16, 3, 3), dtype="float32"),
+                R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            ) = (lv, lv2)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def transform_layout_IOHW_to_OIHW(var_w1: T.handle, var_out: T.handle):
+            ic = T.int32()
+            w1 = T.match_buffer(var_w1, (ic, 16, 3, 3), "float32")
+            out = T.match_buffer(var_out, (16, ic, 3, 3), "float32")
+            for ax0, ax1, ax2, ax3 in T.grid(16, ic, 3, 3):
+                with T.block("layout_transform"):
+                    o, i, h, w = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(w1[i, o, h, w])
+                    T.writes(out[o, i, h, w])
+                    out[o, i, h, w] = w1[i, o, h, w]
+
+        @R.function
+        def main_transform_params() -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            ic = T.int64()
+            cls = Expected
+            gv: R.Object = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=(R.Object,))
+            gv1: R.Tensor((16, 16, 3, 3), dtype="float32") = R.match_cast(
+                gv, R.Tensor((16, 16, 3, 3), dtype="float32")
+            )
+            lv: R.Tensor((16, 16, 3, 3), dtype="float32") = gv1
+            _: R.Object = R.call_packed("set_item", R.prim_value(0), lv, 
sinfo_args=(R.Object,))
+            _1: R.Tuple = R.vm.kill_object(lv)
+            gv2: R.Object = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=(R.Object,))
+            gv3: R.Tensor((3, ic, 3, 3), dtype="float32") = R.match_cast(
+                gv2, R.Tensor((3, ic, 3, 3), dtype="float32")
+            )
+            lv1: R.Tensor((3, ic, 3, 3), dtype="float32") = gv3
+            lv2 = R.call_tir(
+                cls.transform_layout_IOHW_to_OIHW,
+                (lv1,),
+                out_sinfo=R.Tensor((ic, 3, 3, 3), dtype="float32"),
+            )
+            _2: R.Tuple = R.vm.kill_object(lv1)
+            _3: R.Object = R.call_packed("set_item", R.prim_value(1), lv2, 
sinfo_args=(R.Object,))
+            gv4: R.Tuple = R.tuple()
+            return gv4
+
+    after = LazyTransformParams()(Before)
+    tvm.ir.assert_structural_equal(after, Expected, map_free_vars=True)
+
+
 def test_output_with_use_site():
     @I.ir_module
     class Module:
         @T.prim_func
         def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")):
             with T.block("block"):
+                T.reads(x[()])
+                T.writes(y[()])
                 y[()] = x[()]
 
         @R.function
@@ -212,8 +310,10 @@ def test_output_with_use_site():
             R.func_attr({"relax.force_pure": True})
             cls = Expected
             x: R.Object = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=(R.Object,))
-            y = R.call_tir(cls.copy, (x,), out_sinfo=R.Tensor((), 
dtype="float32"))
-            _: R.Tuple = R.vm.kill_object(x)
+            gv: R.Tensor((), dtype="float32") = R.match_cast(x, R.Tensor((), 
dtype="float32"))
+            x_m: R.Tensor((), dtype="float32") = gv
+            y = R.call_tir(cls.copy, (x_m,), out_sinfo=R.Tensor((), 
dtype="float32"))
+            _: R.Tuple = R.vm.kill_object(x_m)
             z = R.call_tir(cls.copy, (y,), out_sinfo=R.Tensor((), 
dtype="float32"))
             _1: R.Object = R.call_packed("set_item", R.prim_value(0), y, 
sinfo_args=(R.Object,))
             _2: R.Object = R.call_packed("set_item", R.prim_value(1), z, 
sinfo_args=(R.Object,))

Reply via email to