This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 77b35e80d8 [Unity][BYOC] Add fused patterns for stacked attention (#14608)
77b35e80d8 is described below

commit 77b35e80d81c6318dcc272937778f10406fe5aa7
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Apr 12 18:20:22 2023 -0700

    [Unity][BYOC] Add fused patterns for stacked attention (#14608)
    
    * [Unity][BYOC] Add fused patterns for stacked attention
    
    In some models, the Q, K and V inputs of the attention op come from a
    single stacked tensor, which is then split and reshaped before being
    passed to the attention op, like
    
    stacked_qkv -> split -> reshape -> attention.
    
    However, the split and reshape ops can be skipped entirely by adjusting
    the layout parameters (pointer offsets and strides) in codegen.
    
    This PR adds such fused patterns for stacked attention in BYOC, so that
    code can be generated directly from stacked_qkv.
    
    * fix lint
    
    * fix lint
---
 python/tvm/contrib/cutlass/attention_operation.py | 92 +++++++++++++++--------
 python/tvm/contrib/cutlass/build.py               | 56 ++++++++------
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 35 +++++----
 python/tvm/relax/backend/contrib/cutlass.py       |  9 +++
 python/tvm/relax/backend/patterns.py              | 49 +++++++++++-
 src/relax/backend/contrib/cutlass/codegen.cc      |  2 +
 tests/python/relax/test_codegen_cutlass.py        | 88 +++++++++++++++++++++-
 7 files changed, 259 insertions(+), 72 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py 
b/python/tvm/contrib/cutlass/attention_operation.py
index f7dee4e3b8..7c5b7048a2 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -19,45 +19,87 @@
 from .library import *
 
 
-def instantiate_attention_template(attrs, func_args):
+def instantiate_attention_template(attrs):
     """Return CUTLASS host code for fused multi head attention
     based on a template and the provided attribute map."""
 
     bias_template = {
         "B11S'": """
-  CHECK(${arg3}->ndim == 2); // B, 1, 1, S'
+  CHECK(${bias}->ndim == 2); // B, 1, 1, S'
 
-  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
   p.bias_strideM = 0; // 0
   p.bias_strideH = 0; // 0
   p.bias_strideB = p.num_keys; // S'
 """,
         "B1SS'": """
-  CHECK(${arg3}->ndim == 3); // B, 1, S, S'
+  CHECK(${bias}->ndim == 3); // B, 1, S, S'
 
-  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
   p.bias_strideM = p.num_keys; // S'
   p.bias_strideH = 0; // 0
   p.bias_strideB = p.bias_strideM * p.num_queries; // S' * S
 """,
         "BNSS'": """
-  CHECK(${arg3}->ndim == 4); // B, N, S, S'
+  CHECK(${bias}->ndim == 4); // B, N, S, S'
 
-  p.attn_bias_ptr = reinterpret_cast<T *>(${arg3}->data);
+  p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
   p.bias_strideM = p.num_keys; // S'
   p.bias_strideH = p.bias_strideM * p.num_queries; // S' * S
   p.bias_strideB = p.bias_strideH * p.num_heads; // S' * S * N
 """,
     }
 
+    qkv_template = {
+        "default": """
+  p.query_ptr = reinterpret_cast<T *>(${query}->data);
+  p.key_ptr = reinterpret_cast<T *>(${key}->data);
+  p.value_ptr = reinterpret_cast<T *>(${value}->data);
+  CHECK(${query}->ndim == 4); // B, S, N, H
+  CHECK(${key}->ndim == 4); // B, S', N, H
+  CHECK(${value}->ndim == 4); // B, S', N, H'
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.q_strideH * p.num_heads; // H * N
+  p.k_strideM = p.k_strideH * p.num_heads; // H * N
+  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
+
+  // stride for B
+  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
+  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
+  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
+""",
+        "qkv_stacked": """
+  p.query_ptr = reinterpret_cast<T *>(${qkv}->data);
+  p.key_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads;
+  p.value_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads 
* 2;
+  CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH'
+
+  // stride for N
+  p.q_strideH = p.head_dim; // H
+  p.k_strideH = p.head_dim; // H
+  p.v_strideH = p.head_dim_value; // H'
+
+  // stride for S
+  p.q_strideM = p.k_strideM = p.v_strideM =
+    p.q_strideH * p.num_heads +
+    p.k_strideH * p.num_heads +
+    p.v_strideH * p.num_heads; // H * N + H * N + H * N'
+
+  // stride for B
+  p.q_strideB = p.k_strideB = p.v_strideB =
+    p.q_strideM * p.num_queries; // (H * N + H * N + H * N') * S
+""",
+    }
+
     template = """
   using T = ${data_type};
 
-  CHECK(${arg0}->ndim == 4); // B, S, N, H
-  CHECK(${arg1}->ndim == 4); // B, S', N, H
-  CHECK(${arg2}->ndim == 4); // B, S', N, H'
-  CHECK(out0->ndim == 4); // B, S, N, H'
-
   using Attention =
       AttentionKernel<T,
                       /*ArchTag=*/${arch},
@@ -70,10 +112,6 @@ def instantiate_attention_template(attrs, func_args):
       >;
 
   typename Attention::Params p;
-
-  p.query_ptr = reinterpret_cast<T *>(${arg0}->data);
-  p.key_ptr = reinterpret_cast<T *>(${arg1}->data);
-  p.value_ptr = reinterpret_cast<T *>(${arg2}->data);
   p.logsumexp_ptr = nullptr;
   p.output_ptr = reinterpret_cast<T *>(out0->data);
   p.output_accum_ptr = nullptr;
@@ -92,22 +130,11 @@ def instantiate_attention_template(attrs, func_args):
   p.num_keys = ${num_keys}; // S'
   p.scale = ${scale};
 
-  // stride for N
-  p.q_strideH = p.head_dim; // H
-  p.k_strideH = p.head_dim; // H
-  p.v_strideH = p.head_dim_value; // H'
 
-  // stride for S
-  p.q_strideM = p.q_strideH * p.num_heads; // H * N
-  p.k_strideM = p.k_strideH * p.num_heads; // H * N
-  p.v_strideM = p.v_strideH * p.num_heads; // H' * N
   p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
+  CHECK(out0->ndim == 4); // B, S, N, H'
 
-  // stride for B
-  p.q_strideB = p.q_strideM * p.num_queries; // H * N * S
-  p.k_strideB = p.k_strideM * p.num_keys; // H * N * S'
-  p.v_strideB = p.v_strideM * p.num_keys; // H'* N * S'
-
+  ${qkv_template}
   ${bias_template}
 
   constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
@@ -126,9 +153,10 @@ def instantiate_attention_template(attrs, func_args):
 
     template = substitute_template(
         template,
-        {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" 
in attrs else ""},
+        {
+            "qkv_template": qkv_template[attrs["qkv_layout"]],
+            "bias_template": bias_template[attrs["bias_layout"]] if 
"bias_layout" in attrs else "",
+        },
     )
 
-    for i, arg in enumerate(func_args):
-        attrs["arg{}".format(i)] = arg
     return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index 43494991a0..e943aec6b1 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -546,12 +546,17 @@ def _extract_relax_function_signature(f):
 
     for i, arg in enumerate(f.params):
         sinfo = arg.struct_info
-        signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
-        signature["arg%d_dtype" % i] = sinfo.dtype
+        if isinstance(sinfo, relax.TensorStructInfo):
+            signature["arg%d_shape" % i] = get_const_tuple(sinfo.shape)
+            signature["arg%d_dtype" % i] = sinfo.dtype
+        elif isinstance(sinfo, relax.ShapeStructInfo):
+            signature["arg%d_shape" % i] = get_const_tuple(sinfo.values)
+        else:
+            raise NotImplementedError()
 
     ret_sinfo = f.ret_struct_info
     if ret_sinfo.shape is not None:
-        signature["ret_shape"] = list(ret_sinfo.shape)
+        signature["ret_shape"] = get_const_tuple(ret_sinfo.shape)
     else:
         signature["ret_shape"] = None
     signature["ret_dtype"] = ret_sinfo.dtype
@@ -779,34 +784,42 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
         else:
             raise ValueError(f"Cannot find call node for attention")
-        q_shape = signature["arg0_shape"]
-        k_shape = signature["arg1_shape"]
-        v_shape = signature["arg2_shape"]
+        arg = {}
+
+        if "stacked_attention" in op_type:
+            arg["arg0_shape"] = signature["arg0_shape"]
+            arg["arg0_dtype"] = signature["arg0_dtype"]
+            arg["arg1_shape"] = q_shape = signature["arg1_shape"]
+            arg["arg2_shape"] = k_shape = signature["arg2_shape"]
+            arg["arg3_shape"] = v_shape = signature["arg3_shape"]
+            if "arg4_dtype" in signature:
+                arg["bias_dtype"] = signature["arg4_dtype"]
+            if "arg4_shape" in signature:
+                arg["bias_shape"] = signature["arg4_shape"]
+            qkv_layout = "qkv_stacked"
+        else:
+            arg["arg0_shape"] = q_shape = signature["arg0_shape"]
+            arg["arg1_shape"] = k_shape = signature["arg1_shape"]
+            arg["arg2_shape"] = v_shape = signature["arg2_shape"]
+            arg["arg0_dtype"] = signature["arg0_dtype"]
+            arg["arg1_dtype"] = signature["arg1_dtype"]
+            arg["arg2_dtype"] = signature["arg2_dtype"]
+            if "arg3_dtype" in signature:
+                arg["bias_dtype"] = signature["arg3_dtype"]
+            if "arg3_shape" in signature:
+                arg["bias_shape"] = signature["arg3_shape"]
+            qkv_layout = "default"
         out_shape = signature["ret_shape"]
-        q_dtype = signature["arg0_dtype"]
-        k_dtype = signature["arg1_dtype"]
-        v_dtype = signature["arg2_dtype"]
         out_dtype = signature["ret_dtype"]
         num_batches, num_queries, num_heads, head_dim = q_shape
         _, num_keys, _, _ = k_shape
         _, _, _, head_dim_value = v_shape
         scale = op_attrs.scale
-        bias = {}
-        if "arg3_dtype" in signature:
-            bias["arg3_dtype"] = signature["arg3_dtype"]
-        if "arg3_shape" in signature:
-            bias["arg3_shape"] = signature["arg3_shape"]
 
         return f.with_attrs(
             {
                 "op_type": op_type,
-                "arg0_dtype": q_dtype,
-                "arg1_dtype": k_dtype,
-                "arg2_dtype": v_dtype,
                 "ret_dtype": out_dtype,
-                "arg0_shape": q_shape,
-                "arg1_shape": k_shape,
-                "arg2_shape": v_shape,
                 "ret_shape": out_shape,
                 "num_batches": num_batches,
                 "num_queries": num_queries,
@@ -816,7 +829,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 "head_dim_value": head_dim_value,
                 "scale": scale,
                 "arch": self.options["sm"],
-                **bias,
+                "qkv_layout": qkv_layout,
+                **arg,
             }
         )
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 61c88c657f..bb4d224329 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -500,12 +500,6 @@ def instantiate_template(func_name, annotations, 
func_args):
         if k in annotations:
             attrs[k] = annotations[k]
 
-    arg0_shape = annotations["arg0_shape"]
-    arg1_shape = annotations["arg1_shape"]
-    attrs["ElementInputA"] = DataTypeTag[dtype_map[annotations["arg0_dtype"]]]
-    attrs["ElementInputB"] = DataTypeTag[dtype_map[annotations["arg1_dtype"]]]
-    attrs["ElementOutput"] = DataTypeTag[dtype_map[annotations["ret_dtype"]]]
-
     headers = []
 
     if "relu" in func_name:
@@ -649,12 +643,12 @@ def instantiate_template(func_name, annotations, 
func_args):
         if "conv2d_transpose" in func_name:
             headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h")
             activation_shape = output_shape
-            output_shape = arg0_shape
+            output_shape = annotations["arg0_shape"]
         elif "backward" in func_name:
             headers.append("cutlass/conv/kernel/default_conv2d_wgrad.h")
-            activation_shape = arg1_shape
+            activation_shape = annotations["arg1_shape"]
             weight_shape = output_shape
-            output_shape = arg0_shape
+            output_shape = annotations["arg0_shape"]
         elif "residual" in func_name:
             
headers.append("cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h")
         else:
@@ -731,13 +725,26 @@ def instantiate_template(func_name, annotations, 
func_args):
         ), "Cutlass may generate nan occasionally when scale == 0.0"
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
-        if len(func_args) > 3:
+        attrs["qkv_layout"] = annotations["qkv_layout"]
+        if attrs["qkv_layout"] == "default":
+            attrs["query"] = func_args[0]
+            attrs["key"] = func_args[1]
+            attrs["value"] = func_args[2]
+            if len(func_args) > 3:
+                attrs["bias"] = func_args[3]
+        elif attrs["qkv_layout"] == "qkv_stacked":
+            attrs["qkv"] = func_args[0]
+            if len(func_args) > 4:
+                attrs["bias"] = func_args[4]
+        else:
+            raise NotImplementedError()
+        if "bias" in attrs:
             attrs["kSupportsBias"] = True
-            if len(annotations["arg3_shape"]) == 4:
+            if len(annotations["bias_shape"]) == 4:
                 attrs["bias_layout"] = "BNSS'"
-            elif len(annotations["arg3_shape"]) == 3:
+            elif len(annotations["bias_shape"]) == 3:
                 attrs["bias_layout"] = "B1SS'"
-            elif len(annotations["arg3_shape"]) == 2:
+            elif len(annotations["bias_shape"]) == 2:
                 attrs["bias_layout"] = "B11S'"
             else:
                 raise NotImplementedError()
@@ -745,7 +752,7 @@ def instantiate_template(func_name, annotations, func_args):
             # To support negative scale in current Cutlass implementation,
             # kSupportsBias should be set true, or there are nan's as result.
             attrs["kSupportsBias"] = attrs["scale"] < 0
-        code = instantiate_attention_template(attrs, func_args)
+        code = instantiate_attention_template(attrs)
         return CodegenResult(code, headers)
 
     raise ValueError("Do not have a template for {}".format(func_name))
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 856cd4d787..4515118f58 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -29,6 +29,7 @@ from ..patterns import (
     make_fused_bias_activation_pattern,
     make_matmul_pattern,
     make_residual_block_pattern,
+    make_stacked_attention_pattern,
 )
 
 
@@ -244,6 +245,14 @@ def attention_patterns():
             "cutlass.attention_bias",
             *make_attention_pattern(with_bias=True),
         ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(),
+        ),
+        (
+            "cutlass.stacked_attention",
+            *make_stacked_attention_pattern(with_bias=True),
+        ),
     ]
 
 
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index e27b91b3ea..9e34b0c964 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -19,7 +19,7 @@
 
 from typing import Dict, Mapping, Tuple, Union
 
-from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
+from tvm.relax.dpl.pattern import DFPattern, is_op, is_tuple_get_item, wildcard
 
 
 def _with_bias_activation_pattern(
@@ -168,6 +168,11 @@ def make_attention_pattern(with_bias: bool = False):
     """
     Create pattern for fused multi head attention.
 
+    Parameters
+    ----------
+    with_bias: bool
+        Whether or not to include bias addition
+
     Returns
     -------
     pattern: DFPattern
@@ -190,3 +195,45 @@ def make_attention_pattern(with_bias: bool = False):
         out = is_op("relax.nn.attention")(query, key, value)
 
     return out, annotations
+
+
+def make_stacked_attention_pattern(with_bias: bool = False):
+    """
+    Create pattern for fused multi head attention with stacked input.
+
+    Parameters
+    ----------
+    with_bias: bool
+        Whether or not to include bias addition
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a fused multi head attention.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract
+        important expressions from match result, to power the partition
+        check function and codegen.
+    """
+    stacked_qkv = wildcard()
+    qkv_tuple = is_op("relax.split")(stacked_qkv)
+    query_reshape_list = wildcard()
+    key_reshape_list = wildcard()
+    value_reshape_list = wildcard()
+    query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), 
query_reshape_list)
+    key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), 
key_reshape_list)
+    value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), 
value_reshape_list)
+    annotations = {
+        "stacked_qkv": stacked_qkv,
+        "query_reshape_list": query_reshape_list,
+        "key_reshape_list": key_reshape_list,
+        "value_reshape_list": value_reshape_list,
+    }
+    if with_bias:
+        bias = wildcard()
+        annotations["bias"] = bias
+        out = is_op("relax.nn.attention_bias")(query, key, value, bias)
+    else:
+        out = is_op("relax.nn.attention")(query, key, value)
+    return out, annotations
diff --git a/src/relax/backend/contrib/cutlass/codegen.cc 
b/src/relax/backend/contrib/cutlass/codegen.cc
index 8ef68baf68..730d098510 100644
--- a/src/relax/backend/contrib/cutlass/codegen.cc
+++ b/src/relax/backend/contrib/cutlass/codegen.cc
@@ -59,6 +59,8 @@ class CodegenCutlass : public 
relax::MemoizedExprTranslator<OutputType>,
       auto sinfo = GetStructInfo(arg);
       if (const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
         arg_types.emplace_back(backend::DType2String(tensor_sinfo->dtype));
+      } else if (const auto* shape_sinfo = sinfo.as<ShapeStructInfoNode>()) {
+        
arg_types.emplace_back(backend::DType2String(shape_sinfo->values.value()[0]->dtype));
       } else {
         LOG(FATAL) << "Unimplemented";
       }
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index b9ba4f4dc9..9288db3eb5 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -592,10 +592,10 @@ def attention_bias_size(request):
     return request.param
 
 
-def test_attention_bias_offload(attention_bias_size, attention_dtype):
+def test_attention_bias_offload(attention_bias_size):
     b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", 
attention_dtype
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", "float32"
     )
 
     mod = get_relax_attention_module(q, k, v, bias)
@@ -620,10 +620,10 @@ def attention_scale(request):
     return request.param
 
 
-def test_attention_scale_offload(attention_scale_size, attention_scale, 
attention_dtype):
+def test_attention_scale_offload(attention_scale_size, attention_scale):
     b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, 
attention_dtype
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, 
"float32"
     )
 
     mod = get_relax_attention_module(q, k, v, bias, attention_scale)
@@ -634,5 +634,85 @@ def test_attention_scale_offload(attention_scale_size, 
attention_scale, attentio
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+@memoize("topi.tests.test_codegen_cutlass.test_stacked_attention_offload")
+def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, 
qk_scale, dtype):
+    qkv = np.random.randn(b, s, n * h + n * h + n * h_v).astype(dtype)
+    split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2)
+    q = np.reshape(split_qkv[0], (b, s, n, h))
+    k = np.reshape(split_qkv[1], (b, s, n, h))
+    v = np.reshape(split_qkv[2], (b, s, n, h_v))
+    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
+    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s
+    if not qk_scale == "none":
+        score = qt @ kt * qk_scale  # b, n, s, s
+    else:
+        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s
+    if not bias_shape == "none":
+        bias = np.random.randn(*bias_shape).astype(dtype)
+        score = score + bias.reshape(*bias_reshape)  # b, n, s, s
+    else:
+        bias = None
+    attn = tvm.topi.testing.softmax_python(score, -1)
+    vt = v.transpose(0, 2, 1, 3)  # b, n, s, h_v
+    ref = attn @ vt  # b, n, s, h_v
+    return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+
+
+def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, 
qk_scale=None):
+    dtype = str(qkv.dtype)
+
+    from tvm.script.ir_builder import IRBuilder
+    from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
+            if bias is not None:
+                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            with R.dataflow() as frame:
+                qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
+                q = R.reshape(qkv_tuple[0], [b, s, n, h])
+                k = R.reshape(qkv_tuple[1], [b, s, n, h])
+                v = R.reshape(qkv_tuple[2], [b, s, n, h_v])
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
[email protected](
+    params=[
+        # B, S, N, H, bias_shape, bias_reshape, scale
+        (4, 8, 32, (64, 32), "none", "none", "none"),
+        (4, 8, 32, (64, 32), (4, 32, 8, 8), (4, 32, 8, 8), 0.5),
+    ]
+)
+def stacked_attention_size(request):
+    return request.param
+
+
+def test_stacked_attention_offload(stacked_attention_size):
+    b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size
+    qkv, bias, ref = get_numpy_stacked_attention_ref(
+        b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32"
+    )
+    if scale == "none":
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias)
+    else:
+        mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, 
scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, qkv)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to