This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 63f9cd6523 [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor
(#17110)
63f9cd6523 is described below
commit 63f9cd6523bd827ea297c22cbbb74eaef9def931
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jun 26 08:43:12 2024 -0700
[Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)
* [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor
This makes the allocation go through memory planning and makes it
compatible with CUDA graph.
* lint
* lint
---
python/tvm/relax/testing/matmul.py | 3 ++-
src/relax/op/op_common.h | 3 +++
src/relax/transform/allocate_workspace.cc | 3 +--
tests/python/relax/test_codegen_cutlass.py | 31 +++++++++++-----------
.../relax/test_transform_allocate_workspace.py | 10 +++----
5 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/python/tvm/relax/testing/matmul.py
b/python/tvm/relax/testing/matmul.py
index 0ce1225e7d..760ad1bdef 100644
--- a/python/tvm/relax/testing/matmul.py
+++ b/python/tvm/relax/testing/matmul.py
@@ -25,7 +25,7 @@ def get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
- out_dtype,
+ out_dtype=None,
transposed_y=False,
bias_shape=None,
activation=None,
@@ -33,6 +33,7 @@ def get_relax_matmul_module(
residual_activation=None,
):
"""Create a matmul op followd by epilogue operations."""
+ out_dtype = out_dtype if out_dtype is not None else in_dtype
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 94474ce784..ed6725e270 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -558,6 +558,9 @@ Expr MakeVMAllocStorage(Expr size, PrimValue
runtime_device_index, DataTypeImm d
StringImm storage_scope = StringImm("global"));
Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm
dtype);
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue
runtime_device_index,
+ StringImm storage_scope = StringImm("global"));
+
/**
* \brief Return the argument of the call.
* Note: If this is a call_tir, return the arguments passed to the TIR
func
diff --git a/src/relax/transform/allocate_workspace.cc
b/src/relax/transform/allocate_workspace.cc
index 4b26b590ef..fcfbf18771 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -144,8 +144,7 @@ class WorkspaceProvider : ExprMutator {
if (!workspace_var_main_.defined()) {
auto shape = ShapeExpr({Integer(max_workspace_size_)});
auto ty = DataTypeImm(DataType::UInt(8));
- auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
- auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape,
ty);
+ auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0));
workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
}
for (const auto& binding : block_node->bindings) {
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 57f47ca6e6..969651f72f 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -104,7 +104,9 @@ def build_cutlass(mod, assert_all_bindings_fused=True,
num_final_bindings=1):
mod = partition_for_cutlass(mod)
if assert_all_bindings_fused:
- assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+ assert (
+ len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+ ), "Not all bindings are fused. " + str(mod["main"])
codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}})
mod = codegen_pass(mod)
@@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype):
v_shape = (b, s_kv, n, h_v)
mod = get_relax_attention_module(q_shape, k_shape, v_shape,
dtype=attention_dtype)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size):
mod = get_relax_attention_module(
q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size,
attention_scale):
q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape,
qk_scale=attention_scale
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=2)
else:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size,
attention_causal):
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=2)
else:
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -932,9 +934,9 @@ def
test_stacked_attention_split_offload(stacked_attention_size):
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=2)
else:
- out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -950,9 +952,9 @@ def
test_stacked_attention_strided_slice_offload(stacked_attention_size):
qkv, b, s, n, h, h_v, "strided_slice", bias, scale,
single_shape=single_shape
)
if bias is None:
- out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, qkv,
num_final_bindings=2)
else:
- out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, qkv, bias,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -1311,9 +1313,8 @@ def test_attention_rewrite_fp16():
R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
- lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0),
R.dtype("uint8"))
- workspace_main = R.vm.alloc_tensor(
- lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+ workspace_main = R.builtin.alloc_tensor(
+ R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
)
lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))
@@ -2419,7 +2420,7 @@ def test_sliding_window():
1, 64, 64, 16, 8, 8, "none", "none", causal, "float16",
window_size=window_size
)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=2)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
diff --git a/tests/python/relax/test_transform_allocate_workspace.py
b/tests/python/relax/test_transform_allocate_workspace.py
index aca6ea2fe8..1198642d3f 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -126,9 +126,8 @@ class Expected:
) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
cls = Expected
with R.dataflow():
- lv: R.Object = R.vm.alloc_storage(R.shape([65536]),
R.prim_value(0), R.dtype("uint8"))
- workspace_main: R.Tensor((65536,), dtype="uint8") =
R.vm.alloc_tensor(
- lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+ workspace_main: R.Tensor((65536,), dtype="uint8") =
R.builtin.alloc_tensor(
+ R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
)
gv: R.Tensor((32, 8, 16, 8), dtype="float16") =
cls.fused_relax_nn_attention_cutlass1(
q, k, v, workspace_main
@@ -144,9 +143,8 @@ class Expected:
) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
cls = Expected
with R.dataflow():
- lv: R.Object = R.vm.alloc_storage(R.shape([65536]),
R.prim_value(0), R.dtype("uint8"))
- workspace_main: R.Tensor((65536,), dtype="uint8") =
R.vm.alloc_tensor(
- lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+ workspace_main: R.Tensor((65536,), dtype="uint8") =
R.builtin.alloc_tensor(
+ R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
)
gv: R.Tensor((32, 8, 16, 8), dtype="float16") =
cls.fused_relax_nn_attention_cutlass1(
q, k, v, workspace_main