tqchen commented on code in PR #16815:
URL: https://github.com/apache/tvm/pull/16815#discussion_r1544107187
##########
tests/python/relax/test_transform_rewrite_cuda_graph.py:
##########
@@ -757,5 +757,118 @@ def main() -> R.Tuple:
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_capture():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,), "float32")
+ y = T.match_buffer(y_handle, (m,), "float32")
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.remap("S", [i])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function
+ def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",),
"float32"):
+ R.func_attr(
+ {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"],
"relax.force_pure": True}
+ )
+ m = T.int64()
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), 0, "global", "float32"
+ ) # assume m is upper-bounded
+ alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(x, alloc1)
+ storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0,
"global", "float32")
+ alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage1, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(alloc1, alloc2)
+ alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+ R.shape([m]), "float32", 0, "global"
+ )
+ _ = Before.add_one(alloc2, alloc3)
+ return alloc3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,))
+ y = T.match_buffer(y_handle, (m,))
+ # with T.block("root"):
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.spatial(m, i)
+ T.reads(x[vi])
+ T.writes(y[vi])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object, R.Object) = storage, storage1
+ return gv
+
+ @R.function(private=True)
+ def cuda_graph_capture(
+ alloc1: R.Tensor(("m",), dtype="float32"),
+ alloc2: R.Tensor(("m",), dtype="float32"),
+ shape_expr: R.Shape(["m"]),
+ ):
+ m = T.int64()
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ cls.add_one(alloc1, alloc2)
+ gv = R.tuple()
+ return R.tuple()
+
+ @R.function
+ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",),
dtype="float32"):
+ m = T.int64()
+ R.func_attr(
+ {"relax.force_pure": True,
"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+ )
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object),),
+ )
+ storage: R.Object = gv[0]
+ alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ cls.add_one(x, alloc1)
+ storage1: R.Object = gv[1]
+ alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
Review Comment:
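Consider explicitly passing the shape to the function as well, so that the
`shape_expr: R.Shape(["m"])` parameter declared by `cuda_graph_capture` is supplied, namely:
`(cls.cuda_graph_capture, (alloc1, alloc2, R.shape([m])), R.prim_value(0),
R.shape([m])),`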
##########
tests/python/relax/test_transform_rewrite_cuda_graph.py:
##########
@@ -757,5 +757,118 @@ def main() -> R.Tuple:
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_capture():
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,), "float32")
+ y = T.match_buffer(y_handle, (m,), "float32")
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.remap("S", [i])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function
+ def main(x: R.Tensor(("m",), "float32")) -> R.Tensor(("m",),
"float32"):
+ R.func_attr(
+ {"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"],
"relax.force_pure": True}
+ )
+ m = T.int64()
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), 0, "global", "float32"
+ ) # assume m is upper-bounded
+ alloc1: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(x, alloc1)
+ storage1: R.Object = R.memory.alloc_storage(R.shape([16]), 0,
"global", "float32")
+ alloc2: R.Tensor((m,), "float32") = R.memory.alloc_tensor(
+ storage1, 0, R.shape([m]), "float32"
+ )
+ _ = Before.add_one(alloc1, alloc2)
+ alloc3: R.Tensor((m,), "float32") = R.builtin.alloc_tensor(
+ R.shape([m]), "float32", 0, "global"
+ )
+ _ = Before.add_one(alloc2, alloc3)
+ return alloc3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add_one(x_handle: T.handle, y_handle: T.handle):
+ m = T.int64()
+ x = T.match_buffer(x_handle, (m,))
+ y = T.match_buffer(y_handle, (m,))
+ # with T.block("root"):
+ for i in range(m):
+ with T.block("add"):
+ vi = T.axis.spatial(m, i)
+ T.reads(x[vi])
+ T.writes(y[vi])
+ y[vi] = x[vi] + T.float32(1)
+
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([16]), R.prim_value(0), R.str("global"),
R.dtype("float32")
+ )
+ gv: R.Tuple(R.Object, R.Object) = storage, storage1
+ return gv
+
+ @R.function(private=True)
+ def cuda_graph_capture(
+ alloc1: R.Tensor(("m",), dtype="float32"),
+ alloc2: R.Tensor(("m",), dtype="float32"),
+ shape_expr: R.Shape(["m"]),
+ ):
+ m = T.int64()
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ cls.add_one(alloc1, alloc2)
+ gv = R.tuple()
+ return R.tuple()
+
+ @R.function
+ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",),
dtype="float32"):
+ m = T.int64()
+ R.func_attr(
+ {"relax.force_pure": True,
"relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"]}
+ )
+ cls = Expected
+ gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object),),
+ )
+ storage: R.Object = gv[0]
+ alloc1: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ cls.add_one(x, alloc1)
+ storage1: R.Object = gv[1]
+ alloc2: R.Tensor((m,), dtype="float32") = R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([m]), R.dtype("float32")
+ )
+ R.call_builtin_with_ctx(
Review Comment:
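Consider explicitly passing the shape to the function as well, so that the
`shape_expr: R.Shape(["m"])` parameter declared by `cuda_graph_capture` is supplied, namely:
`(cls.cuda_graph_capture, (alloc1, alloc2, R.shape([m])), R.prim_value(0),
R.shape([m])),`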
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]