This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 94e3d512dd [Unity][CUTLASS] Fixed stacked attention offload when QKV reshape uses the same shape expression (#14728)
94e3d512dd is described below
commit 94e3d512dd85de85428276d8e9e885aa230ddc8a
Author: masahi <[email protected]>
AuthorDate: Thu Apr 27 09:05:32 2023 +0900
[Unity][CUTLASS] Fixed stacked attention offload when QKV reshape uses the same shape expression (#14728)
* fixed stacked attention offload when QKV have the same shape
* add test
---
python/tvm/contrib/cutlass/build.py | 11 ++++--
tests/python/relax/test_codegen_cutlass.py | 54 +++++++++++++++++++-----------
2 files changed, 44 insertions(+), 21 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index e943aec6b1..bf55e8d15b 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -790,8 +790,15 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
arg["arg0_shape"] = signature["arg0_shape"]
arg["arg0_dtype"] = signature["arg0_dtype"]
arg["arg1_shape"] = q_shape = signature["arg1_shape"]
- arg["arg2_shape"] = k_shape = signature["arg2_shape"]
- arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+
+ if "arg2_shape" not in signature:
+ arg["arg2_shape"] = k_shape = signature["arg1_shape"]
+ arg["arg3_shape"] = v_shape = signature["arg1_shape"]
+ else:
+ assert "arg3_shape" in signature
+ arg["arg2_shape"] = k_shape = signature["arg2_shape"]
+ arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+
if "arg4_dtype" in signature:
arg["bias_dtype"] = signature["arg4_dtype"]
if "arg4_shape" in signature:
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index a5fbf0f642..3d0cc3a54c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -660,7 +660,9 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q
return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
-def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None):
+def get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False
+):
dtype = str(qkv.dtype)
from tvm.script.ir_builder import IRBuilder
@@ -669,6 +671,13 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
if qk_scale is not None:
qk_scale = T.FloatImm("float32", qk_scale)
+ if single_shape:
+ qk_shape = R.shape([b, s, n, h])
+ v_shape = qk_shape
+ else:
+ qk_shape = [b, s, n, h]
+ v_shape = [b, s, n, h_v]
+
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
@@ -678,17 +687,14 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
with R.dataflow() as frame:
if op == "split":
qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
- q = R.reshape(qkv_tuple[0], [b, s, n, h])
- k = R.reshape(qkv_tuple[1], [b, s, n, h])
- v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+ q = R.reshape(qkv_tuple[0], qk_shape)
+ k = R.reshape(qkv_tuple[1], qk_shape)
+ v = R.reshape(qkv_tuple[2], v_shape)
elif op == "strided_slice":
- q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
- k = R.reshape(
R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
- )
+ q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), qk_shape)
+ k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), qk_shape)
v = R.reshape(
- R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]),
- [b, s, n, h_v],
+ R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), v_shape
)
else:
raise NotImplementedError()
@@ -703,9 +709,10 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
@pytest.fixture(
params=[
- # B, S, N, H, bias_shape, bias_reshape, scale
- (4, 8, 32, (64, 32), "none", "none", "none"),
- (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5),
+ # B, S, N, H, bias_shape, bias_reshape, scale, single_shape
+ (4, 8, 32, (64, 32), "none", "none", "none", False),
+ (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5, False),
+ (4, 8, 32, (64, 64), "none", "none", "none", True),
]
)
def stacked_attention_size(request):
@@ -713,14 +720,19 @@ def stacked_attention_size(request):
def test_stacked_attention_split_offload(stacked_attention_size):
- b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+ b, s, n, (h, h_v), bias_shape, bias_reshape, scale, single_shape = stacked_attention_size
qkv, bias, ref = get_numpy_stacked_attention_ref(
b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
)
if scale == "none":
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias)
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape
+ )
else:
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale)
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape
+ )
+
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, qkv)
else:
@@ -729,14 +741,18 @@ def test_stacked_attention_split_offload(stacked_attention_size):
def test_stacked_attention_strided_slice_offload(stacked_attention_size):
- b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+ b, s, n, (h, h_v), bias_shape, bias_reshape, scale, single_shape = stacked_attention_size
qkv, bias, ref = get_numpy_stacked_attention_ref(
b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
)
if scale == "none":
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "strided_slice", bias, single_shape=single_shape
+ )
else:
- mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
+ mod = get_relax_stacked_attention_module(
+ qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape
+ )
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, qkv)
else: