tqchen commented on code in PR #16815:
URL: https://github.com/apache/tvm/pull/16815#discussion_r1544107187


##########
tests/python/relax/test_transform_rewrite_cuda_graph.py:
##########
@@ -757,5 +757,118 @@ def main() -> R.Tuple:
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_capture():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",), 
"float32"):
+            R.func_attr(
+                {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"], 
"relax.force_pure": True}
+            )
+            m = T.int64()
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), 0, "global", "float32"
+            )  # assume m is upper-bounded
+            alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(x, alloc1)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0, 
"global", "float32")
+            alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(alloc1, alloc2)
+            alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+                R.shape([m]), "float32", 0, "global"
+            )
+            _ = Before.add_one(alloc2, alloc3)
+            return alloc3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            # with T.block("root"):
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.spatial(m, i)
+                    T.reads(x[vi])
+                    T.writes(y[vi])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage, storage1
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(
+            alloc1: R.Tensor(("m",), dtype="float32"),
+            alloc2: R.Tensor(("m",), dtype="float32"),
+            shape_expr: R.Shape(["m"]),
+        ):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            cls.add_one(alloc1, alloc2)
+            gv = R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), 
dtype="float32"):
+            m = T.int64()
+            R.func_attr(
+                {"relax.force_pure": True, 
"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+            )
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage: R.Object = gv[0]
+            alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            cls.add_one(x, alloc1)
+            storage1: R.Object = gv[1]
+            alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(

Review Comment:
   Consider explicitly passing the shape to the function (so that the `shape_expr` parameter of `cuda_graph_capture` receives it), namely:
   `(cls.cuda_graph_capture, (alloc1, alloc2, R.shape([m])), R.prim_value(0), R.shape([m])),`



##########
tests/python/relax/test_transform_rewrite_cuda_graph.py:
##########
@@ -757,5 +757,118 @@ def main() -> R.Tuple:
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_capture():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,), "float32")
+            y = T.match_buffer(y_handle, (m,), "float32")
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.remap("S", [i])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function
+        def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",), 
"float32"):
+            R.func_attr(
+                {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"], 
"relax.force_pure": True}
+            )
+            m = T.int64()
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), 0, "global", "float32"
+            )  # assume m is upper-bounded
+            alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(x, alloc1)
+            storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0, 
"global", "float32")
+            alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+                storage1, 0, R.shape([m]), "float32"
+            )
+            _ = Before.add_one(alloc1, alloc2)
+            alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+                R.shape([m]), "float32", 0, "global"
+            )
+            _ = Before.add_one(alloc2, alloc3)
+            return alloc3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add_one(x_handle: T.handle, y_handle: T.handle):
+            m = T.int64()
+            x = T.match_buffer(x_handle, (m,))
+            y = T.match_buffer(y_handle, (m,))
+            # with T.block("root"):
+            for i in range(m):
+                with T.block("add"):
+                    vi = T.axis.spatial(m, i)
+                    T.reads(x[vi])
+                    T.writes(y[vi])
+                    y[vi] = x[vi] + T.float32(1)
+
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([16]), R.prim_value(0), R.str("global"), 
R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage, storage1
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(
+            alloc1: R.Tensor(("m",), dtype="float32"),
+            alloc2: R.Tensor(("m",), dtype="float32"),
+            shape_expr: R.Shape(["m"]),
+        ):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            cls.add_one(alloc1, alloc2)
+            gv = R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), 
dtype="float32"):
+            m = T.int64()
+            R.func_attr(
+                {"relax.force_pure": True, 
"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+            )
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage: R.Object = gv[0]
+            alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            cls.add_one(x, alloc1)
+            storage1: R.Object = gv[1]
+            alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+            )
+            R.call_builtin_with_ctx(

Review Comment:
   Consider explicitly passing the shape to the function (so that the `shape_expr` parameter of `cuda_graph_capture` receives it), namely:
   `(cls.cuda_graph_capture, (alloc1, alloc2, R.shape([m])), R.prim_value(0), R.shape([m])),`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to