This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 94e3d512dd [Unity][CUTLASS] Fixed stacked attention offload when QKV reshape uses the same shape expression (#14728)
94e3d512dd is described below

commit 94e3d512dd85de85428276d8e9e885aa230ddc8a
Author: masahi <[email protected]>
AuthorDate: Thu Apr 27 09:05:32 2023 +0900

    [Unity][CUTLASS] Fixed stacked attention offload when QKV reshape uses the same shape expression (#14728)
    
    * fixed stacked attention offload when QKV have the same shape
    
    * add test
---
 python/tvm/contrib/cutlass/build.py        | 11 ++++--
 tests/python/relax/test_codegen_cutlass.py | 54 +++++++++++++++++++-----------
 2 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index e943aec6b1..bf55e8d15b 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -790,8 +790,15 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             arg["arg0_shape"] = signature["arg0_shape"]
             arg["arg0_dtype"] = signature["arg0_dtype"]
             arg["arg1_shape"] = q_shape = signature["arg1_shape"]
-            arg["arg2_shape"] = k_shape = signature["arg2_shape"]
-            arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+
+            if "arg2_shape" not in signature:
+                arg["arg2_shape"] = k_shape = signature["arg1_shape"]
+                arg["arg3_shape"] = v_shape = signature["arg1_shape"]
+            else:
+                assert "arg3_shape" in signature
+                arg["arg2_shape"] = k_shape = signature["arg2_shape"]
+                arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+
             if "arg4_dtype" in signature:
                 arg["bias_dtype"] = signature["arg4_dtype"]
             if "arg4_shape" in signature:
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index a5fbf0f642..3d0cc3a54c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -660,7 +660,9 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q
     return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
-def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None):
+def get_relax_stacked_attention_module(
+    qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False
+):
     dtype = str(qkv.dtype)
 
     from tvm.script.ir_builder import IRBuilder
@@ -669,6 +671,13 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
     if qk_scale is not None:
         qk_scale = T.FloatImm("float32", qk_scale)
 
+    if single_shape:
+        qk_shape = R.shape([b, s, n, h])
+        v_shape = qk_shape
+    else:
+        qk_shape = [b, s, n, h]
+        v_shape = [b, s, n, h_v]
+
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
@@ -678,17 +687,14 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
             with R.dataflow() as frame:
                 if op == "split":
                     qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
-                    q = R.reshape(qkv_tuple[0], [b, s, n, h])
-                    k = R.reshape(qkv_tuple[1], [b, s, n, h])
-                    v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                    q = R.reshape(qkv_tuple[0], qk_shape)
+                    k = R.reshape(qkv_tuple[1], qk_shape)
+                    v = R.reshape(qkv_tuple[2], v_shape)
                 elif op == "strided_slice":
-                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h])
-                    k = R.reshape(
-                        R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h]
-                    )
+                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), qk_shape)
+                    k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), qk_shape)
                     v = R.reshape(
-                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]),
-                        [b, s, n, h_v],
+                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), v_shape
                     )
                 else:
                     raise NotImplementedError()
@@ -703,9 +709,10 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_s
 
 @pytest.fixture(
     params=[
-        # B, S, N, H, bias_shape, bias_reshape, scale
-        (4, 8, 32, (64, 32), "none", "none", "none"),
-        (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5),
+        # B, S, N, H, bias_shape, bias_reshape, scale, single_shape
+        (4, 8, 32, (64, 32), "none", "none", "none", False),
+        (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5, False),
+        (4, 8, 32, (64, 64), "none", "none", "none", True),
     ]
 )
 def stacked_attention_size(request):
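
The new third case above drives exactly this path: with H == H_v, one [B, S, N, H] expression can describe the reshape target of Q, K, and V alike. A quick arithmetic check of that case's numbers (a sketch, values copied from the fixture):

    # Fixture case (4, 8, 32, (64, 64), "none", "none", "none", True):
    b, s, n, (h, h_v) = 4, 8, 32, (64, 64)
    assert h == h_v                    # prerequisite for a shared Q/K/V shape
    qkv_width = n * h * 2 + n * h_v    # 6144, the last axis of the stacked input
    split_points = [n * h, n * h * 2]  # [2048, 4096], as passed to R.split
    shared_shape = [b, s, n, h]        # one R.shape(...) serves Q, K and V
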
@@ -713,14 +720,19 @@ def stacked_attention_size(request):
 
 
 def test_stacked_attention_split_offload(stacked_attention_size):
-    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale, single_shape = stacked_attention_size
     qkv, bias, ref = get_numpy_stacked_attention_ref(
         b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
     )
     if scale == "none":
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias)
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape
+        )
     else:
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, scale)
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape
+        )
+
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv)
     else:
@@ -729,14 +741,18 @@ def test_stacked_attention_split_offload(stacked_attention_size):
 
 
 def test_stacked_attention_strided_slice_offload(stacked_attention_size):
-    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale, single_shape = stacked_attention_size
     qkv, bias, ref = get_numpy_stacked_attention_ref(
         b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
     )
     if scale == "none":
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias)
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "strided_slice", bias, single_shape=single_shape
+        )
     else:
-        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale)
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape
+        )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv)
     else:
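
To try the updated tests locally, running pytest on the file with a keyword filter, e.g. pytest tests/python/relax/test_codegen_cutlass.py -k stacked_attention, should cover both the split and strided_slice variants, assuming a TVM build with CUTLASS enabled and an NVIDIA GPU available.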
