This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d5edd2fd8 [Relax] Integrate cuDNN attention (#17157)
5d5edd2fd8 is described below

commit 5d5edd2fd8b891bb74681f83095d606739cadfcb
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Jul 22 12:36:06 2024 -0700

    [Relax] Integrate cuDNN attention (#17157)
    
    * [Relax] Integrate cuDNN attention
    
    * update cmake
    
    * lint
    
    * lint
    
    * cudnn frontend
    
    * lint
    
    * lint
    
    * fix test
    
    * skip test
---
 cmake/config.cmake                                 |   7 +
 cmake/modules/CUDA.cmake                           |  16 ++
 python/tvm/contrib/cutlass/build.py                |  32 +--
 python/tvm/contrib/cutlass/gen_tensor_op.py        |   4 +-
 python/tvm/relax/backend/contrib/cudnn.py          |  99 +++++++-
 python/tvm/relax/backend/contrib/cutlass.py        |  18 +-
 python/tvm/relax/backend/patterns.py               |  32 ++-
 python/tvm/relax/frontend/nn/op.py                 |   9 +-
 python/tvm/relax/testing/__init__.py               |   1 +
 python/tvm/relax/testing/attention.py              | 148 ++++++++++++
 python/tvm/topi/testing/__init__.py                |   1 +
 python/tvm/topi/testing/attention_python.py        | 122 ++++++++++
 src/relax/backend/contrib/cudnn/codegen.cc         |  47 ++++
 src/relax/transform/allocate_workspace.cc          |   9 +-
 src/relax/transform/fuse_ops.cc                    |  19 +-
 .../contrib/cudnn/cudnn_frontend/attention.cc      | 124 ++++++++++
 .../contrib/cudnn/cudnn_frontend/attention.h       |  83 +++++++
 src/runtime/contrib/cudnn/cudnn_json_runtime.cc    | 267 ++++++++++++---------
 tests/python/relax/test_codegen_cudnn.py           |  65 ++++-
 tests/python/relax/test_codegen_cutlass.py         | 213 +++++-----------
 .../relax/test_transform_allocate_workspace.py     |   3 +-
 .../test_transform_merge_composite_functions.py    |   5 +-
 22 files changed, 1010 insertions(+), 314 deletions(-)

diff --git a/cmake/config.cmake b/cmake/config.cmake
index 416eec0dcb..26d50630f7 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -245,6 +245,13 @@ set(USE_EDGETPU OFF)
 # - /path/to/cudnn: use specific path to cuDNN path
 set(USE_CUDNN OFF)
 
+# Whether use cuDNN frontend
+# Possible values:
+# - ON: enable cuDNN frontend
+# - /path/to/cudnn_frontend: use specific path to cuDNN frontend
+# - OFF: disable cuDNN frontend
+set(USE_CUDNN_FRONTEND OFF)
+
 # Whether use cuBLAS
 set(USE_CUBLAS OFF)
 
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index b7b405f822..ad83ebe26b 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -77,6 +77,22 @@ if(USE_CUDA)
     list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY})
   endif(USE_CUDNN)
 
+  if (USE_CUDNN_FRONTEND)
+    message(STATUS "Build with cuDNN Frontend support")
+    if (IS_DIRECTORY ${USE_CUDNN_FRONTEND})
+      find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h HINTS ${USE_CUDNN_FRONTEND}/include)
+      include_directories(SYSTEM ${USE_CUDNN_FRONTEND}/include)
+    else()
+      find_file(CUDNN_FRONTEND_HEADER cudnn_frontend.h)
+    endif()
+    if (NOT CUDNN_FRONTEND_HEADER)
+      message(FATAL_ERROR "Cannot find cudnn_frontend.h, please set USE_CUDNN_FRONTEND to the path of the cuDNN frontend header")
+    endif()
+    tvm_file_glob(GLOB CONTRIB_CUDNN_FRONTEND_SRCS src/runtime/contrib/cudnn/cudnn_frontend/*.cc)
+    set_property(SOURCE ${CONTRIB_CUDNN_SRCS} APPEND PROPERTY COMPILE_DEFINITIONS TVM_USE_CUDNN_FRONTEND=1)
+    list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_FRONTEND_SRCS})
+  endif(USE_CUDNN_FRONTEND)
+
   if(USE_CUBLAS)
     message(STATUS "Build with cuBLAS support")
     tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 1c0a30c62d..5c09c79bd9 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -868,34 +868,26 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         signature = _extract_relax_function_signature(f)
 
         if _get_call_node(f.body, "relax.nn.attention") is not None:
-            op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+            attention_node = _get_call_node(f.body, "relax.nn.attention")
+            op_attrs = attention_node.attrs
         elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
-            op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+            attention_node = _get_call_node(f.body, "relax.nn.attention_bias")
+            op_attrs = attention_node.attrs
         elif _get_call_node(f.body, "relax.nn.attention_var_len") is not None:
-            op_attrs = _get_call_node(f.body, "relax.nn.attention_var_len").attrs
+            attention_node = _get_call_node(f.body, "relax.nn.attention_var_len")
+            op_attrs = attention_node.attrs
         else:
             raise ValueError("Cannot find call node for attention")
         arg = {}
 
         if "stacked_attention" in op_type:
-            arg["arg0_shape"] = signature["arg0_shape"]
             arg["arg0_dtype"] = signature["arg0_dtype"]
-            arg["arg1_shape"] = q_shape = signature["arg1_shape"]
-
-            if "arg3_shape" not in signature:
-                # arg0: qkv, arg1: shape, arg2: workspace
-                arg["arg2_shape"] = k_shape = signature["arg1_shape"]
-                arg["arg3_shape"] = v_shape = signature["arg1_shape"]
-            else:
-                # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace
-                arg["arg2_shape"] = k_shape = signature["arg2_shape"]
-                arg["arg3_shape"] = v_shape = signature["arg3_shape"]
-
-            if "arg5_dtype" in signature:
-                # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace
-                arg["bias_dtype"] = signature["arg4_dtype"]
-            if "arg5_shape" in signature:
-                arg["bias_shape"] = signature["arg4_shape"]
+            q_shape = get_const_tuple(attention_node.args[0].struct_info.shape)
+            k_shape = get_const_tuple(attention_node.args[1].struct_info.shape)
+            v_shape = get_const_tuple(attention_node.args[2].struct_info.shape)
+            if len(attention_node.args) == 4:
+                arg["bias_shape"] = get_const_tuple(attention_node.args[3].struct_info.shape)
+                arg["bias_dtype"] = attention_node.args[3].struct_info.dtype
 
             qkv_layout = "qkv_stacked"
         else:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 2f21a1d313..5d04cf13e6 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,8 +745,8 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["qkv"] = func_args[0]
             attrs["num_queries"] = s = annotations["num_queries"]
             attrs["num_keys"] = annotations["num_keys"]
-            if len(func_args) > 5 and not is_var_len:  # +1 for workspace, the last arg
-                attrs["bias"] = func_args[4]
+            if len(func_args) > 2 and not is_var_len:  # +1 for workspace, the last arg
+                attrs["bias"] = func_args[1]
         else:
             raise NotImplementedError()
 
diff --git a/python/tvm/relax/backend/contrib/cudnn.py b/python/tvm/relax/backend/contrib/cudnn.py
index f730d4e5be..2f15e3a4fd 100644
--- a/python/tvm/relax/backend/contrib/cudnn.py
+++ b/python/tvm/relax/backend/contrib/cudnn.py
@@ -16,11 +16,16 @@
 # under the License.
 
 """Pattern table for cuDNN backend"""
-from tvm.relax import transform
+import operator
+from functools import partial, reduce
+
+import tvm
+from tvm import relax
+from tvm.relax import PyExprMutator, expr_functor, transform
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_conv2d_pattern
+from ..patterns import make_conv2d_pattern, make_stacked_attention_pattern
 from ..utils import has_leaking_intermediate_variables
 
 
@@ -60,6 +65,29 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
     return True
 
 
+def _check_stacked_attention(context: PatternCheckContext, layout: str) -> bool:
+    """Check if the given stacked attention workload can be offloaded to cuDNN."""
+    if has_leaking_intermediate_variables(context):
+        return False
+    if layout == "BS3NH":
+        if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
+            return False
+        if "split" in context.annotated_expr:
+            split_op = context.annotated_expr["split"]
+            if not split_op.attrs.axis == 2:
+                return False
+    elif layout == "SBN3H":
+        if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 4:
+            return False
+        if "split" in context.annotated_expr:
+            split_op = context.annotated_expr["split"]
+            if not split_op.attrs.axis == 3:
+                return False
+    else:
+        raise NotImplementedError(f"Unsupported layout: {layout}")
+    return True
+
+
 register_patterns(
     [
         (
@@ -84,6 +112,16 @@ register_patterns(
             ),
             _check_conv2d,
         ),
+        (
+            "cudnn.attention.BS3NH",
+            *make_stacked_attention_pattern(start_op="split", layout="BS3NH"),
+            partial(_check_stacked_attention, layout="BS3NH"),
+        ),
+        (
+            "cudnn.attention.SBN3H",
+            *make_stacked_attention_pattern(start_op="split", layout="SBN3H"),
+            partial(_check_stacked_attention, layout="SBN3H"),
+        ),
     ]
 )
 
@@ -105,4 +143,59 @@ def partition_for_cudnn(mod):
     """
 
     patterns = get_patterns_with_prefix("cudnn")
-    return transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod)
+    return tvm.transform.Sequential(
+        [
+            transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True),
+            annotate_workspace,
+            transform.AllocateWorkspace(),
+        ]
+    )(mod)
+
+
+def _shape_1d(shape):
+    return reduce(operator.mul, shape, 1)
+
+
+@expr_functor.mutator
+class WorkspaceAnnotator(PyExprMutator):
+    """Annotate a workspace requirement for each cuDNN-offloaded function."""
+
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def visit_function_(self, f):
+        if "Composite" not in f.attrs:
+            body = super().visit_expr(f.body)
+            new_f = relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)
+
+            if "global_symbol" in f.attrs and "cudnn" in f.attrs["global_symbol"]:
+                composite_func = body.blocks[0].bindings[0].value
+                if "WorkspaceSize" in composite_func.attrs:
+                    return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])
+
+            return new_f
+
+        if "attention" in f.attrs["Composite"] and "cudnn" in f.attrs["Composite"]:
+            # Workspace is needed only for larger head sizes, but for simplicity we always allocate.
+            out_dtype = f.ret_struct_info.dtype
+            out_size_1d = _shape_1d(f.ret_struct_info.shape)
+            # This needs to be in sync with the actual value that the kernel expects.
+            workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
+            if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)):
+                # Tempororay workaround for dynamic shape workload. Will be removed when
+                # workspace for dynamic shape workload is implemented.
+                workspace_size_bytes = 8
+            return f.with_attr("WorkspaceSize", workspace_size_bytes)
+
+        return f
+
+
+@tvm.transform.module_pass(opt_level=0)
+def annotate_workspace(mod, _):
+    """Pass to annotate a workspace requirement for each cuDNN-offloaded function."""
+    annotator = WorkspaceAnnotator(mod)
+    for name, f in mod.functions_items():
+        if isinstance(f, relax.Function):
+            new_f = annotator.visit_expr(f)
+            mod.update_func(name, new_f)
+    return mod
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 0d9f4ff8e9..80979bbe7e 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -383,19 +383,25 @@ def _check_stacked_attention(context: PatternCheckContext) -> bool:
         if not split_op.attrs.axis == 2:
             return False
     else:
+        get_const_int_list = lambda tup: [int(e.value) for e in tup]
         last_end = 0
         for name in ["query", "key", "value"]:
             assert f"strided_slice_{name}" in context.annotated_expr
             strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
-            if list(strided_slice_op.attrs.axes) != [2]:
+            axes = get_const_int_list(strided_slice_op.args[1])
+            begins = get_const_int_list(strided_slice_op.args[2])
+            ends = get_const_int_list(strided_slice_op.args[3])
+            strides = get_const_int_list(strided_slice_op.args[4])
+
+            if axes != [2]:
                 return False
-            if list(strided_slice_op.attrs.begin) != [last_end]:
+            if begins != [last_end]:
                 return False
-            if not len(strided_slice_op.attrs.end) == 1:
+            if not len(ends) == 1:
                 return False
-            last_end = strided_slice_op.attrs.end[0]
-            if list(strided_slice_op.attrs.strides) != [1]:
+            if strides != [1]:
                 return False
+            last_end = ends[0]
     return True
 
 
@@ -537,7 +543,7 @@ class WorkspaceAnnotator(PyExprMutator):
 
             return new_f
 
-        if "attention" in f.attrs["Composite"]:
+        if "attention" in f.attrs["Composite"] and "cutlass" in f.attrs["Composite"]:
             # Workspace is needed only for larger head sizes, but for simplicity we always allocate.
             out_dtype = f.ret_struct_info.dtype
             out_size_1d = _shape_1d(f.ret_struct_info.shape)
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 8ec43f1f27..1faef9cceb 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -260,7 +260,7 @@ def make_attention_pattern(with_bias: bool = False, var_len: bool = False):
     return out, annotations
 
 
-def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
+def make_stacked_attention_pattern(start_op: str, with_bias: bool = False, layout="BS3NH"):
     """
     Create pattern for fused multi head attention with stacked input.
 
@@ -272,6 +272,9 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
     with_bias: bool
         Whether or not to include bias addition
 
+    layout: str
+        The layout of the stacked input tensor.
+
     Returns
     -------
     pattern: DFPattern
@@ -290,17 +293,28 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
         key_raw = is_tuple_get_item(qkv_tuple, 1)
         value_raw = is_tuple_get_item(qkv_tuple, 2)
     elif start_op == "strided_slice":
-        ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(stacked_qkv)
-        ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(stacked_qkv)
-        ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(stacked_qkv)
+        ops["strided_slice_query"] = query_raw = is_op("relax.strided_slice")(
+            stacked_qkv, varg_default_wildcard=True
+        )
+        ops["strided_slice_key"] = key_raw = is_op("relax.strided_slice")(
+            stacked_qkv, varg_default_wildcard=True
+        )
+        ops["strided_slice_value"] = value_raw = is_op("relax.strided_slice")(
+            stacked_qkv, varg_default_wildcard=True
+        )
     else:
         raise NotImplementedError()
     query_reshape_list = wildcard()
     key_reshape_list = wildcard()
     value_reshape_list = wildcard()
-    query = is_op("relax.reshape")(query_raw, query_reshape_list)
-    key = is_op("relax.reshape")(key_raw, key_reshape_list)
-    value = is_op("relax.reshape")(value_raw, value_reshape_list)
+    if layout == "BS3NH":
+        query = is_op("relax.reshape")(query_raw, query_reshape_list)
+        key = is_op("relax.reshape")(key_raw, key_reshape_list)
+        value = is_op("relax.reshape")(value_raw, value_reshape_list)
+    elif layout == "SBN3H":
+        ops["q_transpose"] = query = is_op("relax.permute_dims")(query_raw)
+        ops["k_transpose"] = key = is_op("relax.permute_dims")(key_raw)
+        ops["v_transpose"] = value = is_op("relax.permute_dims")(value_raw)
     annotations = {
         "stacked_qkv": stacked_qkv,
         "query_reshape_list": query_reshape_list,
@@ -314,6 +328,10 @@ def make_stacked_attention_pattern(start_op: str, with_bias: bool = False):
         out = is_op("relax.nn.attention_bias")(query, key, value, bias)
     else:
         out = is_op("relax.nn.attention")(query, key, value)
+
+    if layout == "SBN3H":
+        out = is_op("relax.permute_dims")(out)
+
     return out, annotations
 
 
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 725a930fd6..ec072f663c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1568,11 +1568,14 @@ def scaled_dot_product_attention(
     Parameters
     ----------
     query : Tensor
-        Tensor representing current attention lookup.
+        Tensor representing current attention lookup of shape
+        [batch, seq_len, num_heads, head_size].
     key : Tensor
-        Tensor representing cross attention mapping.
+        Tensor representing cross attention mapping of shape
+        [batch, seq_len_kv, num_heads_kv, head_size].
     value : Tensor
-        Tensor representing embedded attention values.
+        Tensor representing embedded attention values of shape
+        [batch, seq_len_kv, num_heads_kv, head_size_value].
     attn_mask : Optional[Tensor]
         Optional mask for attention, not yet supported.
     is_causal : Optional[bool]
diff --git a/python/tvm/relax/testing/__init__.py b/python/tvm/relax/testing/__init__.py
index 4256ebc3be..dc43d6c1f8 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -21,3 +21,4 @@ from .nn import *
 from .relay_translator import *
 from .ast_printer import dump_ast
 from .matmul import *
+from .attention import *
diff --git a/python/tvm/relax/testing/attention.py 
b/python/tvm/relax/testing/attention.py
new file mode 100644
index 0000000000..a00674394b
--- /dev/null
+++ b/python/tvm/relax/testing/attention.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Relax script for attention module."""
+import tvm
+from tvm.script import relax as R, tir as T
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_relax_attention_module(
+    q_shape,
+    k_shape,
+    v_shape,
+    *,
+    dtype,
+    bias_shape=None,
+    qk_scale=None,
+    causal_mask=None,
+    window_size=None,
+):  # pylint: disable=too-many-arguments, too-many-locals, invalid-name
+    """Get a relax module for attention."""
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
+
+    if window_size is not None:
+        window_size = T.IntImm("int32", window_size)
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            q = R.arg("q", R.Tensor(q_shape, dtype))
+            k = R.arg("k", R.Tensor(k_shape, dtype))
+            v = R.arg("v", R.Tensor(v_shape, dtype))
+            bias = None
+            if bias_shape is not None and bias_shape != "none":
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+
+            with R.dataflow() as frame:
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
+def get_relax_stacked_attention_module(
+    qkv,
+    b,
+    s,
+    n,
+    h,
+    h_v,
+    op,
+    bias=None,
+    qk_scale=None,
+    single_shape=False,
+    layout="BS3NH",
+):  # pylint: disable=too-many-arguments, too-many-locals, too-many-branches, invalid-name
+    # pylint: disable=too-many-statements
+    """Get a relax module for stacked attention."""
+    dtype = str(qkv.dtype)
+    assert layout in ["BS3NH", "SBN3H"]
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
+
+    if single_shape:
+        if layout == "BS3NH":
+            qk_shape = R.shape([b, s, n, h])
+        elif layout == "SBN3H":
+            qk_shape = R.shape([b, s, n, h])
+        v_shape = qk_shape
+    else:
+        if layout == "BS3NH":
+            qk_shape = [b, s, n, h]
+            v_shape = [b, s, n, h_v]
+        elif layout == "SBN3H":
+            qk_shape = [s, b, n, h]
+            v_shape = [s, b, n, h_v]
+
+    if layout == "BS3NH":
+        split_axis = 2
+        split_sections = [n * h, n * h * 2]
+    elif layout == "SBN3H":
+        split_axis = 3
+        split_sections = [h, h * 2]
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
+            if bias is not None:
+                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            with R.dataflow() as frame:
+                if op == "split":
+                    qkv_tuple = R.split(qkv, split_sections, axis=split_axis)
+                    q = qkv_tuple[0]
+                    k = qkv_tuple[1]
+                    v = qkv_tuple[2]
+                elif op == "strided_slice":
+                    q = R.strided_slice(qkv, [split_axis], [0], [split_sections[0]], [1])
+                    k = R.strided_slice(
+                        qkv, [split_axis], [split_sections[0]], [split_sections[1]], [1]
+                    )
+                    v = R.strided_slice(
+                        qkv,
+                        [split_axis],
+                        [split_sections[1]],
+                        [int(qkv.struct_info.shape[split_axis])],
+                        [1],
+                    )
+                else:
+                    raise NotImplementedError()
+                if layout == "BS3NH":
+                    q = R.reshape(q, qk_shape)
+                    k = R.reshape(k, qk_shape)
+                    v = R.reshape(v, v_shape)
+                elif layout == "SBN3H":
+                    q = R.permute_dims(q, [1, 0, 2, 3])
+                    k = R.permute_dims(k, [1, 0, 2, 3])
+                    v = R.permute_dims(v, [1, 0, 2, 3])
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
+                if layout == "SBN3H":
+                    result = R.emit(R.permute_dims(result, [1, 0, 2, 3]))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 72a7cedc49..1486e9986e 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -84,3 +84,4 @@ from .dense import dense
 from .searchsorted import searchsorted_ref
 from .conv2d_backcward_weight_python import conv2d_backward_weight_python
 from .lstm_python import lstm_python
+from .attention_python import attention_python
diff --git a/python/tvm/topi/testing/attention_python.py b/python/tvm/topi/testing/attention_python.py
new file mode 100644
index 0000000000..856667aedd
--- /dev/null
+++ b/python/tvm/topi/testing/attention_python.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Attention operator in python"""
+from typing import Optional
+import numpy as np
+from .softmax_python import softmax_python
+
+
+def attention_python(
+    q: np.ndarray,
+    k: np.ndarray,
+    v: np.ndarray,
+    bias: Optional[np.ndarray],
+    qk_scale: float,
+    causal: str,
+    window_size: Optional[int] = None,
+    layout: str = "BSNH",
+):  # pylint: disable=too-many-arguments, too-many-locals, invalid-name
+    """Attention operator in python
+
+    Parameters
+    ----------
+    q : np.ndarray
+        Query tensor with shape [batch, seq_length, num_heads, head_dim] in the layout specified by
+        `layout`.
+    k : np.ndarray
+        Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim] in the layout specified
+        by `layout`.
+    v : np.ndarray
+        Value tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim_v] in the layout
+        specified by `layout`.
+    bias : np.ndarray
+        Bias tensor with shape [batch, num_heads, seq_length, seq_length]
+    qk_scale : float
+        Scale factor for the query-key product.
+    causal : str
+        The type of causal mask to apply. Can be "none", "TopLeft", or "BottomRight".
+    window_size : Optional[int]
+        The window size for the causal mask.
+    layout : str
+        The layout of the input tensors, e.g. "BSNH" or "BNSH".
+
+    Returns
+    -------
+    np.ndarray
+        The output tensor with shape [batch, seq_length, num_heads, head_dim_v] in the layout
+        specified by `layout`.
+    """
+    assert layout in ["BSNH", "BNSH", "SBNH"]
+
+    dim_b = layout.find("B")
+    dim_s = layout.find("S")
+    dim_n = layout.find("N")
+    dim_h = layout.find("H")
+
+    q = q.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s, h
+    k = k.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s_kv, h
+    kt = k.transpose(0, 1, 3, 2)  # b, n, h, s_kv
+    v = v.transpose(dim_b, dim_n, dim_s, dim_h)
+
+    num_heads = q.shape[1]
+    num_kv_heads = k.shape[1]
+    s = q.shape[2]
+    s_kv = k.shape[2]
+
+    if num_heads != num_kv_heads:
+        assert num_heads % num_kv_heads == 0
+        factor = num_heads // num_kv_heads
+        kt = np.repeat(kt, factor, axis=1)
+        v = np.repeat(v, factor, axis=1)
+
+    if not qk_scale == "none":
+        score = q @ kt * qk_scale  # b, n, s, s_kv
+    else:
+        score = q @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    if bias is not None:
+        score = score + bias  # b, n, s, s_kv
+    if causal == "none":
+        attn = softmax_python(score, -1)
+    else:
+        if causal == "TopLeft":
+            offset = 0
+        elif causal == "BottomRight":
+            offset = abs(s - s_kv)
+        else:
+            raise ValueError(f"Unsupported causal type: {causal}")
+        score_masked = np.tril(score, k=offset)
+
+        if window_size:
+            score_masked = np.triu(
+                score_masked, -window_size + 1  # pylint: disable=invalid-unary-operand-type
+            )
+
+        score_masked_exp = np.tril(
+            np.exp(score_masked - np.max(score_masked, axis=-1, keepdims=True)), k=offset
+        )
+
+        if window_size:
+            score_masked_exp = np.triu(
+                score_masked_exp, -window_size + 1  # pylint: disable=invalid-unary-operand-type
+            )
+
+        score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
+        attn = np.divide(score_masked_exp, score_masked_sum)
+
+    out = attn @ v  # b, n, s, h_v
+    return out.transpose(*np.argsort([dim_b, dim_n, dim_s, dim_h]).tolist())
diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc
index 812016b8ea..d8ca5f4e97 100644
--- a/src/relax/backend/contrib/cudnn/codegen.cc
+++ b/src/relax/backend/contrib/cudnn/codegen.cc
@@ -55,6 +55,17 @@ class cuDNNJSONSerializer : public JSONSerializer {
 
     std::string composite_name = composite_opt.value();
 
+    if (composite_name.find("cudnn.conv2d") != std::string::npos) {
+      return HandleConv2D(call_node, fn, composite_name);
+    } else if (composite_name.find("cudnn.attention") != std::string::npos) {
+      return HandleAttention(call_node, fn, composite_name);
+    } else {
+      LOG(FATAL) << "Unsupported composite function: " << composite_name;
+    }
+  }
+
+  NodeEntries HandleConv2D(const CallNode* call_node, const Function& fn,
+                           const std::string& composite_name) {
     NodeEntries inputs_tmp;
     for (const auto& arg : call_node->args) {
       auto res = VisitExpr(arg);
@@ -80,6 +91,42 @@ class cuDNNJSONSerializer : public JSONSerializer {
     return AddNode(node, GetRef<Expr>(call_node));
   }
 
+  NodeEntries HandleAttention(const CallNode* call_node, const Function& fn,
+                              const std::string& composite_name) {
+    std::string layout = composite_name.substr(composite_name.find_last_of(".") + 1);
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    ICHECK_EQ(inputs.size(), 2);
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+    const CallNode* root_call = backend::GetOpInFunction(fn, "relax.nn.attention");
+    auto q_shape = Downcast<ShapeExpr>(
+        Downcast<TensorStructInfo>(root_call->args[0]->struct_info_.value())->shape.value());
+    auto k_shape = Downcast<ShapeExpr>(
+        Downcast<TensorStructInfo>(root_call->args[1]->struct_info_.value())->shape.value());
+    auto v_shape = Downcast<ShapeExpr>(
+        Downcast<TensorStructInfo>(root_call->args[2]->struct_info_.value())->shape.value());
+    int num_heads = q_shape->values[2].as<IntImmNode>()->value;
+    int num_kv_heads = k_shape->values[2].as<IntImmNode>()->value;
+    int head_size = q_shape->values[3].as<IntImmNode>()->value;
+    int head_size_v = v_shape->values[3].as<IntImmNode>()->value;
+    SetCallNodeAttribute(node, root_call);
+
+    auto to_str_array = [](int val) {
+      return std::vector<dmlc::any>{std::vector<std::string>{std::to_string(val)}};
+    };
+    node->SetAttr("num_heads", to_str_array(num_heads));
+    node->SetAttr("num_kv_heads", to_str_array(num_kv_heads));
+    node->SetAttr("head_size", to_str_array(head_size));
+    node->SetAttr("head_size_v", to_str_array(head_size_v));
+    node->SetAttr("layout", std::vector<dmlc::any>{std::vector<std::string>{layout}});
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
  private:
   /*! \brief The bindings to look up composite functions. */
   Map<Var, Expr> bindings_;
diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
index 1d4a017712..05aa8ce552 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -66,8 +66,10 @@ class ExternFunctionRewriter : ExprMutator {
       }
 
       new_params.push_back(workspace_param);
+      auto new_attrs = func_node->attrs;
+      new_attrs.CopyOnWrite()->dict.erase(attr::kWorkspaceSize);
       return Function(new_params, VisitExpr(func_node->body), 
func_node->ret_struct_info,
-                      func_node->is_pure, func_node->attrs);
+                      func_node->is_pure, new_attrs);
     }
     return ExprMutator::VisitExpr_(func_node);
   }
@@ -122,6 +124,7 @@ class WorkspaceProvider : ExprMutator {
       builder_->UpdateFunction(new_gvar,
                                WithAttr(f, tvm::attr::kGlobalSymbol, 
new_gvar->name_hint));
       gvar_map_[gvar] = new_gvar;
+      new_gvars_.insert(new_gvar);
       builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
     }
 
@@ -164,8 +167,7 @@ class WorkspaceProvider : ExprMutator {
     auto new_op = VisitExpr(call_node->op);
 
     if (auto gv = new_op.as<GlobalVar>()) {
-      auto callee = builder_->GetContextIRModule()->Lookup(gv.value());
-      if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) {
+      if (new_gvars_.count(gv.value())) {
         auto new_args = call_node->args;
         ICHECK(workspace_var_main_.defined());
         new_args.push_back(workspace_var_main_);
@@ -185,6 +187,7 @@ class WorkspaceProvider : ExprMutator {
    * the new ones that are transformed to take an additional workspace 
parameter. This is only
    * needed since the struct info of the global variables changes between 
transformation. */
   std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+  std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> new_gvars_;
 };
 
 }  // namespace relax
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 2be7ad41f3..6030a28d93 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/expr_functor.h>
 #include <tvm/tir/function.h>
 
@@ -595,8 +596,7 @@ class FunctionCreator : public ExprMutator {
       }
 
       StructInfo param_sinfo = GetStructInfo(expr);
-      // Exclude PrimValues from arg/params to make composite functions contain PrimValues.
-      if (!expr->IsInstance<PrimValueNode>()) {
+      if (!IsInlinableConstants(expr)) {
         Var param(std::move(name), GetStructInfo(expr));
         arguments_.push_back(expr);
         params_.push_back(param);
@@ -621,6 +621,21 @@ class FunctionCreator : public ExprMutator {
     return ExprMutator::VisitExpr(expr);
   }
 
+  // Check whether `expr` is a PrimValue, ShapeExpr, or (nested) tuple of them with no free
+  // TIR variables; such constants are inlined in composite functions, not passed as params.
+  bool IsInlinableConstants(const Expr& expr) {
+    if (const auto* tuple = expr.as<TupleNode>()) {
+      return std::all_of(tuple->fields.begin(), tuple->fields.end(),
+                         [this](const Expr& e) { return IsInlinableConstants(e); });
+    } else if (const auto* prim_value = expr.as<PrimValueNode>()) {
+      return tvm::tir::UndefinedVars(prim_value->value).empty();
+    } else if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
+      return std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
+                         [](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); });
+    }
+    return false;
+  }
+
  private:
   /*! \brief The variables defined in this function */
   std::unordered_set<const VarNode*> defined_vars_;
diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
new file mode 100644
index 0000000000..f8b170fe20
--- /dev/null
+++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
+ * \brief cuDNN scaled dot-product attention implementation
+ */
+
+#include "./attention.h"
+
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../../cuda/cuda_common.h"
+#include "../cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+// Build the cuDNN frontend SDPA graph for one problem size and packed-QKV layout.
+void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads,
+                               int64_t num_kv_heads, int64_t head_size, int64_t head_size_v,
+                               double scale, const DLDataType& data_type,
+                               const std::string& layout) {
+  graph_ = std::make_unique<cudnn_frontend::graph::Graph>();
+
+  CHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16)
+      << "Only float16 is supported";
+
+  graph_->set_io_data_type(cudnn_frontend::DataType_t::HALF)
+      .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT)
+      .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT);
+
+  auto q_desc = cudnn_frontend::graph::Tensor_attributes().set_name("Q").set_uid(kTensorIDQ);
+  auto k_desc = cudnn_frontend::graph::Tensor_attributes().set_name("K").set_uid(kTensorIDK);
+  auto v_desc = cudnn_frontend::graph::Tensor_attributes().set_name("V").set_uid(kTensorIDV);
+
+  std::vector<int64_t> q_stride, k_stride, v_stride,
+      o_stride;  // stride in the order of (batch, num_heads, seq_len, head_size)
+
+  if (layout == "BS3NH") {
+    // Each (batch, seq) row packs all Q heads, then K heads, then V heads back to back.
+    int64_t stride_H = 1;
+    int64_t q_stride_N = head_size;
+    int64_t k_stride_N = head_size;
+    int64_t v_stride_N = head_size_v;
+    int64_t stride_S =
+        num_heads * q_stride_N + num_kv_heads * k_stride_N + num_kv_heads * v_stride_N;
+    int64_t stride_B = stride_S * seq_len;
+    q_stride = {stride_B, q_stride_N, stride_S, stride_H};
+    k_stride = {stride_B, k_stride_N, stride_S, stride_H};
+    v_stride = {stride_B, v_stride_N, stride_S, stride_H};
+    o_stride = {seq_len * num_heads * head_size_v, head_size_v, num_heads * head_size_v, 1};
+    offset_k_ = num_heads * head_size;
+    offset_v_ = offset_k_ + num_kv_heads * head_size;
+  } else if (layout == "SBN3H") {
+    CHECK_EQ(num_kv_heads, num_heads);
+    // Each (seq, batch, head) row packs the Q, K and V vectors back to back.
+    int64_t stride_H = 1;
+    int64_t stride_N = head_size + head_size + head_size_v;
+    int64_t stride_B = num_heads * stride_N;
+    int64_t stride_S = stride_B * batch;
+    q_stride = k_stride = v_stride = {stride_B, stride_N, stride_S, stride_H};
+    o_stride = {num_heads * head_size_v, head_size_v, num_heads * head_size_v * batch, 1};
+    offset_k_ = head_size;
+    offset_v_ = offset_k_ * 2;
+  } else {
+    LOG(FATAL) << "Unsupported layout: " << layout;
+  }
+
+  q_desc = q_desc.set_dim({batch, num_heads, seq_len, head_size}).set_stride(q_stride);
+  k_desc = k_desc.set_dim({batch, num_kv_heads, seq_len, head_size}).set_stride(k_stride);
+  v_desc = v_desc.set_dim({batch, num_kv_heads, seq_len, head_size_v}).set_stride(v_stride);
+  auto sdpa_options = cudnn_frontend::graph::SDPA_attributes()
+                          .set_name("flash_attention")
+                          .set_is_inference(true)
+                          .set_alibi_mask(false)
+                          .set_causal_mask(false)
+                          .set_attn_scale(scale);
+
+  auto q = graph_->tensor(q_desc);
+  auto k = graph_->tensor(k_desc);
+  auto v = graph_->tensor(v_desc);
+  auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options);
+  CHECK(stats == nullptr);
+  // Give the output the uid kTensorIDOut, which Run() uses as its key in the variant pack.
+  o->set_output(true).set_name("Out").set_uid(kTensorIDOut);
+  o->set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride);
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A}));
+}
+
+// Execute the built graph; Q/K/V offsets are in fp16 (uint16_t) elements of qkv.
+void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CUDNN_CALL(cudnnSetStream(entry_ptr->handle, tvm::runtime::GetCUDAStream()));
+  auto* qkv_base = reinterpret_cast<uint8_t*>(qkv->data) + qkv->byte_offset;
+  auto* q_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_q_;
+  auto* k_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_k_;
+  auto* v_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_v_;
+  auto* out_ptr = reinterpret_cast<uint8_t*>(out->data) + out->byte_offset;
+
+  size_t workspace_size = graph_->get_workspace_size();
+  CHECK_LE(static_cast<int64_t>(workspace_size), workspace->shape[0])
+      << "Workspace size too small";
+  std::unordered_map<cudnn_frontend::graph::Tensor_attributes::uid_t, void*> inputs = {
+      {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}};
+
+  CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, inputs, workspace->data));
+}
+
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h
new file mode 100644
index 0000000000..4d0309fb3b
--- /dev/null
+++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_frontend/attention.h
+ * \brief cuDNN scaled dot-product attention implementation
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
+#define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
+
+#include <cudnn_frontend.h>
+#include <tvm/runtime/registry.h>
+
+#include <memory>
+#include <string>
+
+/*! \brief Check that a cudnn_frontend call succeeded; otherwise fail with its message. */
+#define CUDNN_FRONTEND_CALL(func)                    \
+  do {                                               \
+    auto status = (func);                            \
+    CHECK(status.is_good()) << status.get_message(); \
+  } while (0)
+
+namespace tvm {
+namespace contrib {
+
+/*! \brief Runs cuDNN frontend scaled dot-product attention on a packed QKV tensor. */
+class CuDNNSDPARunnerNode : public tvm::runtime::Object {
+ public:
+  CuDNNSDPARunnerNode() {}
+
+  ~CuDNNSDPARunnerNode() {}
+
+  static constexpr const char* _type_key = "contrib.cudnn.SDPARunner";
+
+  /*!
+   * \brief Build the attention graph for the given problem size and QKV layout.
+   * \param layout Packed QKV layout, either "BS3NH" or "SBN3H".
+   */
+  void Init(int64_t batch, int64_t seq_len, int64_t num_heads, int64_t num_kv_heads,
+            int64_t head_size, int64_t head_size_v, double scale, const DLDataType& data_type,
+            const std::string& layout);
+
+  /*! \brief Execute the built graph, reading from qkv and writing the result to out. */
+  void Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out);
+
+  /*! \brief Tensor uids used to bind device pointers in the execution variant pack. */
+  static constexpr int kTensorIDQ = 0;
+  static constexpr int kTensorIDK = 1;
+  static constexpr int kTensorIDV = 2;
+  static constexpr int kTensorIDOut = 4;
+
+ private:
+  /*! \brief The cuDNN frontend graph built by Init(). */
+  std::unique_ptr<cudnn_frontend::graph::Graph> graph_{nullptr};
+  /*! \brief Element offsets of Q, K and V inside the packed QKV buffer. */
+  int64_t offset_q_{0};
+  int64_t offset_k_{0};
+  int64_t offset_v_{0};
+};
+
+/*! \brief Reference to a CuDNNSDPARunnerNode. */
+class CuDNNSDPARunner : public tvm::runtime::ObjectRef {
+ public:
+  /*! \brief Create a runner whose graph is not built yet; call Init() before Run(). */
+  static CuDNNSDPARunner Create() {
+    auto n = tvm::runtime::make_object<CuDNNSDPARunnerNode>();
+    return CuDNNSDPARunner(n);
+  }
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CuDNNSDPARunner, tvm::runtime::ObjectRef,
+                                        CuDNNSDPARunnerNode);
+};
+
+}  // namespace contrib
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_FRONTEND_ATTENTION_H_
diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
index 7d701396d0..3f4b659275 100644
--- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
+++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
@@ -31,6 +31,10 @@
 
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
+
+#ifdef TVM_USE_CUDNN_FRONTEND
+#include "./cudnn_frontend/attention.h"
+#endif
 #include "cudnn_utils.h"
 
 namespace tvm {
@@ -47,78 +51,19 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
       : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
 
   void Init(const Array<NDArray>& consts) override {
-    auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
-    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
-    ICHECK(func != nullptr);
-    stream = static_cast<cudaStream_t>((*func)().operator void*());
-
-    auto attr_in_name = [](const std::string& op_name, const std::string& attr_name) {
-      return op_name.find(attr_name) != std::string::npos;
-    };
-
-    auto vstr2vint = [](const JSONGraphNode& node, const std::string& attrStr) {
-      auto string_to_int = [](const std::string& str) { return std::stoi(str); };
-      auto string_vec = node.GetAttr<std::vector<std::string>>(attrStr);
-      std::vector<int> int_vec(string_vec.size());
-      std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int);
-      return int_vec;
-    };
+    op_execs_.resize(nodes_.size());
     // get some config from the graph
     for (size_t i = 0; i < nodes_.size(); ++i) {
       const auto& node = nodes_[i];
       if (node.GetOpType() == "kernel") {
-        op_name = node.GetOpName();
-        std::vector<int> input_dims, kernel_dims, output_dims;
-        auto input_node = nodes_[0];
-        auto input_shapes = input_node.GetOpShape()[0];
-        auto kernel_node = nodes_[1];
-        auto kernel_shapes = kernel_node.GetOpShape()[0];
-        auto output_shapes = node.GetOpShape()[0];
-        for (const auto& _i : input_shapes) {
-          input_dims.emplace_back(static_cast<int>(_i));
-        }
-        for (const auto& _i : kernel_shapes) {
-          kernel_dims.emplace_back(static_cast<int>(_i));
+        std::string op_name = node.GetOpName();
+        if (op_name.find("conv2d") != std::string::npos) {
+          op_execs_[i] = GetConv2DExec(node);
+        } else if (op_name.find("attention") != std::string::npos) {
+          op_execs_[i] = GetAttentionExec(node);
+        } else {
+          LOG(FATAL) << "Unsupported op: " << op_name;
         }
-        for (const auto& _i : output_shapes) {
-          output_dims.emplace_back(static_cast<int>(_i));
-        }
-        has_bias = attr_in_name(op_name, "bias");
-        groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
-        padding = vstr2vint(node, "padding");
-        strides = vstr2vint(node, "strides");
-        dilation = vstr2vint(node, "dilation");
-        conv_dtype = node.GetAttr<std::vector<std::string>>("out_dtype")[0];
-        std::string layout = node.GetAttr<std::vector<std::string>>("out_layout")[0];
-        dims = layout.size() - 2;  // remove O and I dims
-
-        if (layout == "NCHW")
-          format = CUDNN_TENSOR_NCHW;
-        else if (layout == "NHWC")
-          format = CUDNN_TENSOR_NHWC;
-        else
-          LOG(FATAL) << "Unsupported layout: " << layout;
-
-        if (attr_in_name(op_name, "relu")) {
-          act = CUDNN_ACTIVATION_RELU;
-        } else if (attr_in_name(op_name, "relu6")) {
-          act = CUDNN_ACTIVATION_CLIPPED_RELU;
-          coef = 6.0;
-        } else if (attr_in_name(op_name, "leaky_relu")) {
-          act = CUDNN_ACTIVATION_RELU;
-          coef = 0.1;
-        }
-        this->handle = entry_ptr->handle;
-        this->kernel_node = node;
-
-        // find best algo
-        TVMRetValue best_algo;
-
-        tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(),
-                               dilation.data(), input_dims.data(), kernel_dims.data(),
-                               output_dims.data(), conv_dtype, conv_dtype, false, &best_algo);
-
-        this->algo = best_algo.operator int();
       }
     }
   }
@@ -126,27 +71,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
   const char* type_key() const override { return "cudnn_json"; }  // May be overridden
 
   void Run() override {
-    auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
-      const DLTensor* bias = nullptr;
-      if (has_bias) {
-        bias = GetInput(node, 2);
+    for (const auto& f : op_execs_) {
+      if (f != nullptr) {
+        f();
       }
-      return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
-    };
-
-    auto [a_ptr, b_ptr, bias_ptr] = get_inputs(kernel_node, has_bias);
-    uint32_t output_eid = EntryID(outputs_[0]);
-    auto out_ptr = data_entry_[output_eid];
-
-    if (this->has_bias) {
-      tvm::contrib::ConvolutionBiasActivationForward(
-          this->mode, this->format, this->algo, this->dims, this->groups, this->act, this->coef,
-          this->padding.data(), this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr,
-          bias_ptr, this->conv_dtype);
-    } else {
-      tvm::contrib::ConvolutionForward(
-          this->mode, this->format, this->algo, this->dims, this->groups, this->padding.data(),
-          this->strides.data(), this->dilation.data(), a_ptr, b_ptr, out_ptr, this->conv_dtype);
     }
   }
 
@@ -157,27 +85,157 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
     ICHECK(eid < data_entry_.size());
     return data_entry_[eid];
   }
-  /*conv op name*/
-  std::string op_name;
-  /*conv mode: CUDNN_CROSS_CORRELATION by default*/
-  int mode = CUDNN_CROSS_CORRELATION;
-  /*algo: by default we select the implicit gemm algo, will be tuned in the initial pass.*/
-  int algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-  /*if has bias*/
-  bool has_bias = false;
-  /*args for function call*/
-  int act = CUDNN_ACTIVATION_IDENTITY;
-  double coef = 1.0;
-  int format = CUDNN_TENSOR_NHWC;
-  int dims = 2;
-  int groups = 1;
-  std::vector<int> padding;
-  std::vector<int> strides;
-  std::vector<int> dilation;
-  std::string conv_dtype;
-  cudaStream_t stream;
-  cudnnHandle_t handle;
-  tvm::runtime::json::JSONGraphNode kernel_node;
+
+  /*! \brief Return true if \p attr_name occurs as a substring of \p op_name. */
+  bool attr_in_name(const std::string& op_name, const std::string& attr_name) {
+    return op_name.find(attr_name) != std::string::npos;
+  }
+
+  /*! \brief Parse the string-vector attribute \p attrStr of \p node into ints. */
+  std::vector<int> vstr2vint(const JSONGraphNode& node, const std::string& attrStr) {
+    auto string_to_int = [](const std::string& str) { return std::stoi(str); };
+    auto string_vec = node.GetAttr<std::vector<std::string>>(attrStr);
+    std::vector<int> int_vec(string_vec.size());
+    std::transform(string_vec.begin(), string_vec.end(), int_vec.begin(), string_to_int);
+    return int_vec;
+  }
+
+  /*! \brief Pick a conv algorithm once via FindAlgo and return a closure running the conv. */
+  std::function<void()> GetConv2DExec(const JSONGraphNode& node) {
+    auto op_name = node.GetOpName();
+
+    std::vector<int> input_dims, kernel_dims, output_dims;
+    const auto& input_node = nodes_[node.GetInputs()[0].id_];
+    auto input_shapes = input_node.GetOpShape()[0];
+    auto kernel_shapes = nodes_[node.GetInputs()[1].id_].GetOpShape()[0];
+    auto output_shapes = node.GetOpShape()[0];
+    for (const auto& _i : input_shapes) {
+      input_dims.emplace_back(static_cast<int>(_i));
+    }
+    for (const auto& _i : kernel_shapes) {
+      kernel_dims.emplace_back(static_cast<int>(_i));
+    }
+    for (const auto& _i : output_shapes) {
+      output_dims.emplace_back(static_cast<int>(_i));
+    }
+    bool has_bias = attr_in_name(op_name, "bias");
+    int groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
+    std::vector<int> padding = vstr2vint(node, "padding");
+    std::vector<int> strides = vstr2vint(node, "strides");
+    std::vector<int> dilation = vstr2vint(node, "dilation");
+    auto conv_dtype = node.GetAttr<std::vector<std::string>>("out_dtype")[0];
+    std::string layout = node.GetAttr<std::vector<std::string>>("out_layout")[0];
+    int dims = layout.size() - 2;  // remove O and I dims
+
+    int format = CUDNN_TENSOR_NHWC;
+    if (layout == "NCHW") {
+      format = CUDNN_TENSOR_NCHW;
+    } else if (layout == "NHWC") {
+      format = CUDNN_TENSOR_NHWC;
+    } else {
+      LOG(FATAL) << "Unsupported layout: " << layout;
+    }
+
+    int act = CUDNN_ACTIVATION_IDENTITY;
+    double coef = 1.0;
+    // Test the longer names first: "relu6" and "leaky_relu" both contain "relu".
+    if (attr_in_name(op_name, "relu6")) {
+      act = CUDNN_ACTIVATION_CLIPPED_RELU;
+      coef = 6.0;
+    } else if (attr_in_name(op_name, "leaky_relu")) {
+      act = CUDNN_ACTIVATION_RELU;
+      coef = 0.1;
+    } else if (attr_in_name(op_name, "relu")) {
+      act = CUDNN_ACTIVATION_RELU;
+    }
+
+    /*conv mode: CUDNN_CROSS_CORRELATION by default*/
+    int mode = CUDNN_CROSS_CORRELATION;
+
+    // find best algo
+    TVMRetValue best_algo;
+
+    tvm::contrib::FindAlgo(format, dims, groups, padding.data(), strides.data(), dilation.data(),
+                           input_dims.data(), kernel_dims.data(), output_dims.data(), conv_dtype,
+                           conv_dtype, false, &best_algo);
+
+    int algo = best_algo.operator int();
+    std::function<void()> op_exec = [=]() {
+      // Fetch the handle when the closure runs so the stream is set on that thread's handle.
+      auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
+      auto stream = static_cast<cudaStream_t>(GetCUDAStream());
+      CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream));
+
+      auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
+        const DLTensor* bias = nullptr;
+        if (has_bias) {
+          bias = GetInput(node, 2);
+        }
+        return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
+      };
+
+      auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, has_bias);
+      uint32_t output_eid = EntryID(outputs_[0]);
+      auto out_ptr = data_entry_[output_eid];
+      if (has_bias) {
+        tvm::contrib::ConvolutionBiasActivationForward(
+            mode, format, algo, dims, groups, act, coef, padding.data(), strides.data(),
+            dilation.data(), a_ptr, b_ptr, out_ptr, bias_ptr, conv_dtype);
+      } else {
+        tvm::contrib::ConvolutionForward(mode, format, algo, dims, groups, padding.data(),
+                                         strides.data(), dilation.data(), a_ptr, b_ptr, out_ptr,
+                                         conv_dtype);
+      }
+    };
+    return op_exec;
+  }
+
+  /*! \brief Build a cuDNN SDPA runner for \p node and return a closure that executes it. */
+  std::function<void()> GetAttentionExec(const JSONGraphNode& node) {
+#ifdef TVM_USE_CUDNN_FRONTEND
+    auto dtype = node.GetOpDataType()[0];
+    int num_heads = vstr2vint(node, "num_heads")[0];
+    int num_kv_heads = vstr2vint(node, "num_kv_heads")[0];
+    int head_size = vstr2vint(node, "head_size")[0];
+    int head_size_v = vstr2vint(node, "head_size_v")[0];
+    std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
+    const auto& input_qkv_node = nodes_[node.GetInputs()[0].id_];
+    auto qkv_shapes = input_qkv_node.GetOpShape()[0];
+
+    int64_t batch = 0, seq_len = 0;
+    if (layout == "BS3NH") {
+      ICHECK_EQ(qkv_shapes.size(), 3);
+      batch = qkv_shapes[0];
+      seq_len = qkv_shapes[1];
+    } else if (layout == "SBN3H") {
+      ICHECK_EQ(qkv_shapes.size(), 4);
+      batch = qkv_shapes[1];
+      seq_len = qkv_shapes[0];
+    } else {
+      LOG(FATAL) << "Unsupported layout: " << layout;
+    }
+    double scale = 1 / std::sqrt(head_size);
+    std::string scale_attr = node.GetAttr<std::vector<std::string>>("scale")[0];
+    if (scale_attr.size()) {
+      scale = std::stod(scale_attr);
+    }
+
+    auto runner = tvm::contrib::CuDNNSDPARunner::Create();
+    runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype,
+                 layout);
+    return [=]() {
+      auto qkv = GetInput(node, 0);
+      auto workspace = const_cast<DLTensor*>(GetInput(node, 1));
+      auto out = const_cast<DLTensor*>(data_entry_[EntryID(outputs_[0])]);
+      runner->Run(qkv, workspace, out);
+    };
+#else
+    LOG(FATAL) << "Please build with CUDNN frontend to use attention op";
+#endif
+  }
+
+  /*! \brief Per-node executors, indexed like nodes_; empty for non-kernel nodes. */
+  std::vector<std::function<void()>> op_execs_;
 };
 
 runtime::Module cuDNNJSONRuntimeCreate(String symbol_name, String graph_json,
diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py
index 0f911905f8..59f49bfde8 100644
--- a/tests/python/relax/test_codegen_cudnn.py
+++ b/tests/python/relax/test_codegen_cudnn.py
@@ -22,7 +22,8 @@ import tvm.testing
 import tvm.topi.testing
 from tvm import relax
 from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
-from tvm.relax.testing import get_relax_matmul_module
+from tvm.relax.testing import get_relax_matmul_module, get_relax_stacked_attention_module
+from tvm.contrib.pickle_memoize import memoize
 from tvm.script import relax as R
 
 from tvm.script.ir_builder import IRBuilder
@@ -99,7 +100,7 @@ def get_relax_conv2d_module(
 def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):
     mod = partition_for_cudnn(mod)
     mod = relax.transform.RunCodegen()(mod)
-    return build_and_run(mod, np_inputs, "cuda", cuda_graph)
+    return build_and_run(mod, np_inputs, "cuda", cuda_graph=cuda_graph)
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
@@ -244,5 +245,66 @@ def test_conv2d_nchw_oihw_offload(data_shape, weight_shape, dtype, with_bias, ac
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+@memoize("topi.tests.test_codegen_cudnn.test_stacked_attention_offload")
+def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype, layout):
+    """Return a random packed QKV array, an optional random bias, and the reference output."""
+    if layout == "BS3NH":
+        qkv = np.random.randn(b, s, n * h * 2 + n * h_v).astype(dtype)
+        split_qkv = np.split(qkv, [n * h, n * h * 2], axis=2)
+        q = split_qkv[0].reshape(b, s, n, h)
+        k = split_qkv[1].reshape(b, s, n, h)
+        v = split_qkv[2].reshape(b, s, n, h_v)
+        layout = "BSNH"
+    elif layout == "SBN3H":
+        qkv = np.random.randn(s, b, n, h * 2 + h_v).astype(dtype)
+        q, k, v = np.split(qkv, [h, h * 2], axis=3)
+        layout = "SBNH"
+    else:
+        raise ValueError("Unsupported layout: {}".format(layout))
+    if bias_shape == "none":
+        bias = None
+    else:
+        # The bias is passed to attention_python below; no score is formed here.
+        bias = np.random.randn(*bias_shape).astype(dtype)
+    ref = tvm.topi.testing.attention_python(q, k, v, bias, qk_scale, "none", None, layout)
+    return qkv, bias, ref
+
+
+@pytest.fixture(
+    params=[
+        # B, S, N, (H, H_V), bias_shape, scale, single_shape, layout
+        (4, 8, 32, (64, 32), "none", 1.0, False, "BS3NH"),
+        (4, 8, 32, (64, 64), "none", "none", True, "BS3NH"),
+        (4, 8, 32, (64, 32), "none", 1.0, False, "SBN3H"),
+        (4, 8, 32, (64, 64), "none", "none", True, "SBN3H"),
+    ]
+)
+def stacked_attention_size(request):
+    return request.param
+
+
+@pytest.mark.skip(reason="require cudnn frontend")
+def test_stacked_attention_split_offload(stacked_attention_size):
+    b, s, n, (h, h_v), bias_shape, scale, single_shape, layout = stacked_attention_size
+    qkv, bias, ref = get_numpy_stacked_attention_ref(
+        b, s, n, h, h_v, bias_shape, scale, "float16", layout
+    )
+    if scale == "none":
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape, layout=layout
+        )
+        scale = 1.0 / np.sqrt(h)
+    else:
+        mod = get_relax_stacked_attention_module(
+            qkv, b, s, n, h, h_v, "split", bias, scale, single_shape=single_shape, layout=layout
+        )
+
+    if bias is None:
+        out = get_result_with_relax_cudnn_offload(mod, [qkv])
+    else:
+        out = get_result_with_relax_cudnn_offload(mod, [qkv, bias])
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=2e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 969651f72f..3fa3f2d914 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -24,7 +24,11 @@ from tvm import relax
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
-from tvm.relax.testing import get_relax_matmul_module
+from tvm.relax.testing import (
+    get_relax_matmul_module,
+    get_relax_attention_module,
+    get_relax_stacked_attention_module,
+)
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
@@ -594,47 +598,6 @@ def attention_size(request):
     return request.param
 
 
-def get_relax_attention_module(
-    q_shape,
-    k_shape,
-    v_shape,
-    *,
-    dtype,
-    bias_shape=None,
-    qk_scale=None,
-    causal_mask=None,
-    window_size=None,
-):
-    from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import relax as relax_builder
-    from tvm.script.ir_builder import tir as T
-
-    if qk_scale is not None:
-        qk_scale = T.FloatImm("float32", qk_scale)
-
-    if window_size is not None:
-        window_size = T.IntImm("int32", window_size)
-
-    with IRBuilder() as builder:
-        with relax_builder.function():
-            R.func_name("main")
-            q = R.arg("q", R.Tensor(q_shape, dtype))
-            k = R.arg("k", R.Tensor(k_shape, dtype))
-            v = R.arg("v", R.Tensor(v_shape, dtype))
-            bias = None
-            if bias_shape is not None and bias_shape != "none":
-                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
-
-            with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask, window_size))
-                R.output(result)
-
-            R.func_ret_value(frame.output_vars[0])
-
-    func = builder.get()
-    return tvm.IRModule({"main": func})
-
-
 def get_numpy_attention_ref(
     b,
     s,
@@ -649,59 +612,20 @@ def get_numpy_attention_ref(
     window_size=None,
     num_kv_head=None,
 ):
-    if num_kv_head is None:
-        num_kv_head = n
-
+    num_kv_head = num_kv_head or n
     q = np.random.randn(b, s, n, h).astype(dtype)
-    k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
-    v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
-
-    if num_kv_head is None:
-        k = k_orig
-        v = v_orig
-    else:
-        factor = n // num_kv_head
-        k = np.repeat(k_orig, factor, axis=2)
-        v = np.repeat(v_orig, factor, axis=2)
-
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    if not qk_scale == "none":
-        score = qt @ kt * qk_scale  # b, n, s, s_kv
-    else:
-        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    if not bias_shape == "none":
-        bias = np.random.randn(*bias_shape).astype(dtype)
-        score = score + bias  # b, n, s, s_kv
-    else:
+    k = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
+    v = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
+    if bias_shape == "none":
         bias = None
-    if causal == "none":
-        attn = tvm.topi.testing.softmax_python(score, -1)
     else:
-        if causal == "TopLeft":
-            offset = 0
-        elif causal == "BottomRight":
-            offset = abs(s - s_kv)
-        else:
-            raise NotImplementedError()
-        score_masked = np.tril(score, k=offset)
-
-        if window_size:
-            score_masked = np.triu(score_masked, -window_size + 1)
-
-        score_masked_exp = np.tril(
-            np.exp(score_masked - np.max(score_masked, axis=-1, 
keepdims=True)), k=offset
-        )
-
-        if window_size:
-            score_masked_exp = np.triu(score_masked_exp, -window_size + 1)
+        bias = np.random.randn(*bias_shape).astype(dtype)
 
-        score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
-        attn = np.divide(score_masked_exp, score_masked_sum)
+    ref = tvm.topi.testing.attention_python(
+        q, k, v, bias, qk_scale, causal=causal, window_size=window_size, 
layout="BSNH"
+    )
 
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+    return q, k, v, bias, ref
 
 
 def test_attention_offload(attention_size, attention_dtype):
@@ -844,69 +768,14 @@ def get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, qk_scale, dtype
     q = np.reshape(split_qkv[0], (b, s, n, h))
     k = np.reshape(split_qkv[1], (b, s, n, h))
     v = np.reshape(split_qkv[2], (b, s, n, h_v))
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s
-    if not qk_scale == "none":
-        score = qt @ kt * qk_scale  # b, n, s, s
-    else:
-        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s
     if not bias_shape == "none":
         bias = np.random.randn(*bias_shape).astype(dtype)
-        score = score + bias  # b, n, s, s
     else:
         bias = None
-    attn = tvm.topi.testing.softmax_python(score, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return qkv, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
-
-
-def get_relax_stacked_attention_module(
-    qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None, single_shape=False
-):
-    dtype = str(qkv.dtype)
-
-    from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import relax as relax_builder
-    from tvm.script.ir_builder import tir as T
-
-    if qk_scale is not None:
-        qk_scale = T.FloatImm("float32", qk_scale)
-
-    if single_shape:
-        qk_shape = R.shape([b, s, n, h])
-        v_shape = qk_shape
-    else:
-        qk_shape = [b, s, n, h]
-        v_shape = [b, s, n, h_v]
-
-    with IRBuilder() as builder:
-        with relax_builder.function():
-            R.func_name("main")
-            qkv = R.arg("qkv", R.Tensor(qkv.shape, dtype))
-            if bias is not None:
-                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
-            with R.dataflow() as frame:
-                if op == "split":
-                    qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2)
-                    q = R.reshape(qkv_tuple[0], qk_shape)
-                    k = R.reshape(qkv_tuple[1], qk_shape)
-                    v = R.reshape(qkv_tuple[2], v_shape)
-                elif op == "strided_slice":
-                    q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], 
[1]), qk_shape)
-                    k = R.reshape(R.strided_slice(qkv, [2], [n * h], [n * h * 
2], [1]), qk_shape)
-                    v = R.reshape(
-                        R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n 
* h_v], [1]), v_shape
-                    )
-                else:
-                    raise NotImplementedError()
-                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
-                R.output(result)
-
-            R.func_ret_value(frame.output_vars[0])
-
-    func = builder.get()
-    return tvm.IRModule({"main": func})
+    ref = tvm.topi.testing.attention_python(
+        q, k, v, bias, qk_scale, causal="none", window_size=None, layout="BSNH"
+    )
+    return qkv, bias, ref
 
 
 @pytest.fixture(
@@ -926,11 +795,30 @@ def test_stacked_attention_split_offload(stacked_attention_size):
     qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, 
bias_shape, scale, "float16")
     if scale == "none":
         mod = get_relax_stacked_attention_module(
-            qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape
+            qkv,
+            b,
+            s,
+            n,
+            h,
+            h_v,
+            "split",
+            bias,
+            single_shape=single_shape,
+            layout="BS3NH",
         )
     else:
         mod = get_relax_stacked_attention_module(
-            qkv, b, s, n, h, h_v, "split", bias, scale, 
single_shape=single_shape
+            qkv,
+            b,
+            s,
+            n,
+            h,
+            h_v,
+            "split",
+            bias,
+            scale,
+            single_shape=single_shape,
+            layout="BS3NH",
         )
 
     if bias is None:
@@ -945,11 +833,30 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
     qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, 
bias_shape, scale, "float32")
     if scale == "none":
         mod = get_relax_stacked_attention_module(
-            qkv, b, s, n, h, h_v, "strided_slice", bias, 
single_shape=single_shape
+            qkv,
+            b,
+            s,
+            n,
+            h,
+            h_v,
+            "strided_slice",
+            bias,
+            single_shape=single_shape,
+            layout="BS3NH",
         )
     else:
         mod = get_relax_stacked_attention_module(
-            qkv, b, s, n, h, h_v, "strided_slice", bias, scale, 
single_shape=single_shape
+            qkv,
+            b,
+            s,
+            n,
+            h,
+            h_v,
+            "strided_slice",
+            bias,
+            scale,
+            single_shape=single_shape,
+            layout="BS3NH",
         )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=2)
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
index 1198642d3f..248d195d65 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -95,7 +95,6 @@ class Expected:
         R.func_attr(
             {
                 "Codegen": "cutlass",
-                "WorkspaceSize": 65536,
                 "global_symbol": "fused_relax_nn_attention_cutlass1",
             }
         )
@@ -107,7 +106,7 @@ class Expected:
             v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
             workspace_1: R.Tensor((65536,), dtype="uint8"),
         ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
-            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, 
"WorkspaceSize": 65536})
+            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1})
             with R.dataflow():
                 gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = 
R.nn.attention(
                     q_1, k_1, v_1, scale=None
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py
index 6a36314a74..cff832a21f 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -1053,7 +1053,6 @@ def test_reshape():
         @R.function
         def fused_relax_reshape_relax_matmul_tensorrt(
             inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"),
-            param_0: R.Shape([1, 784]),
             lv1: R.Tensor((784, 512), dtype="float32"),
         ) -> R.Tensor((1, 512), dtype="float32"):
             R.func_attr({"Codegen": "tensorrt"})
@@ -1069,7 +1068,7 @@ def test_reshape():
                     R.output(gv)
                 return gv
 
-            lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+            lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, 
R.shape([1, 784]))
 
             @R.function
             def lv1_1_1(
@@ -1100,7 +1099,7 @@ def test_reshape():
                 )
                 gv: R.Tensor(
                     (1, 512), dtype="float32"
-                ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, 
R.shape([1, 784]), lv1)
+                ) = cls.fused_relax_reshape_relax_matmul_tensorrt(inp_0, lv1)
                 R.output(gv)
             return gv
 

Reply via email to