This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 63f9cd6523 [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)
63f9cd6523 is described below

commit 63f9cd6523bd827ea297c22cbbb74eaef9def931
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jun 26 08:43:12 2024 -0700

    [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)
    
    * [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor
    
    This routes the allocation through memory planning and makes it
    compatible with CUDA graphs.
    
    * lint
    
    * lint
---
 python/tvm/relax/testing/matmul.py                 |  3 ++-
 src/relax/op/op_common.h                           |  3 +++
 src/relax/transform/allocate_workspace.cc          |  3 +--
 tests/python/relax/test_codegen_cutlass.py         | 31 +++++++++++-----------
 .../relax/test_transform_allocate_workspace.py     | 10 +++----
 5 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/testing/matmul.py b/python/tvm/relax/testing/matmul.py
index 0ce1225e7d..760ad1bdef 100644
--- a/python/tvm/relax/testing/matmul.py
+++ b/python/tvm/relax/testing/matmul.py
@@ -25,7 +25,7 @@ def get_relax_matmul_module(
     x_shape,
     y_shape,
     in_dtype,
-    out_dtype,
+    out_dtype=None,
     transposed_y=False,
     bias_shape=None,
     activation=None,
@@ -33,6 +33,7 @@ def get_relax_matmul_module(
     residual_activation=None,
 ):
     """Create a matmul op followd by epilogue operations."""
+    out_dtype = out_dtype if out_dtype is not None else in_dtype
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 94474ce784..ed6725e270 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -558,6 +558,9 @@ Expr MakeVMAllocStorage(Expr size, PrimValue 
runtime_device_index, DataTypeImm d
                         StringImm storage_scope = StringImm("global"));
 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm 
dtype);
 
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue 
runtime_device_index,
+                     StringImm storage_scope = StringImm("global"));
+
 /**
  * \brief Return the argument of the call.
  *        Note: If this is a call_tir, return the arguments passed to the TIR 
func
diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
index 4b26b590ef..fcfbf18771 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -144,8 +144,7 @@ class WorkspaceProvider : ExprMutator {
     if (!workspace_var_main_.defined()) {
       auto shape = ShapeExpr({Integer(max_workspace_size_)});
       auto ty = DataTypeImm(DataType::UInt(8));
-      auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
-      auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, 
ty);
+      auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0));
       workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
     }
     for (const auto& binding : block_node->bindings) {
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 57f47ca6e6..969651f72f 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -104,7 +104,9 @@ def build_cutlass(mod, assert_all_bindings_fused=True, 
num_final_bindings=1):
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
-        assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        assert (
+            len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        ), "Not all bindings are fused. " + str(mod["main"])
 
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}})
     mod = codegen_pass(mod)
@@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype):
     v_shape = (b, s_kv, n, h_v)
 
     mod = get_relax_attention_module(q_shape, k_shape, v_shape, 
dtype=attention_dtype)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size):
     mod = get_relax_attention_module(
         q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
     )
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size, 
attention_scale):
         q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, 
qk_scale=attention_scale
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size, 
attention_causal):
     )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -932,9 +934,9 @@ def 
test_stacked_attention_split_offload(stacked_attention_size):
         )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -950,9 +952,9 @@ def 
test_stacked_attention_strided_slice_offload(stacked_attention_size):
             qkv, b, s, n, h, h_v, "strided_slice", bias, scale, 
single_shape=single_shape
         )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -1311,9 +1313,8 @@ def test_attention_rewrite_fp16():
             R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), 
R.dtype("uint8"))
-                workspace_main = R.vm.alloc_tensor(
-                    lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+                workspace_main = R.builtin.alloc_tensor(
+                    R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
                 )
                 lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
                 lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
@@ -2419,7 +2420,7 @@ def test_sliding_window():
         1, 64, 64, 16, 8, 8, "none", "none", causal, "float16", 
window_size=window_size
     )
 
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
index aca6ea2fe8..1198642d3f 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -126,9 +126,8 @@ class Expected:
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), 
R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = 
R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = 
R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = 
cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
@@ -144,9 +143,8 @@ class Expected:
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), 
R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = 
R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = 
R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = 
cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main

Reply via email to