This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 38e2b886dd [Unity][BYOC] Integrate Flash attention v2 kernel into CUTLASS BYOC (#15467)
38e2b886dd is described below

commit 38e2b886dd97cfaa0031b204441c427ac598a856
Author: masahi <[email protected]>
AuthorDate: Tue Aug 8 07:11:02 2023 +0900

    [Unity][BYOC] Integrate Flash attention v2 kernel into CUTLASS BYOC (#15467)
    
    * Integrate flash attention into cutlass BYOC
    
    * update note on causal mask
    
    * disable flash v2 for sm < 80
---
 CMakeLists.txt                                    |   4 +
 cmake/modules/contrib/CUTLASS.cmake               |   1 +
 python/tvm/contrib/cutlass/attention_operation.py |  97 ++++++++++++++++
 python/tvm/contrib/cutlass/build.py               |   2 +
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 132 +++++++++++++---------
 tests/python/relax/test_codegen_cutlass.py        |   2 +-
 6 files changed, 184 insertions(+), 54 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 389879a883..47d57d56bd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -843,4 +843,8 @@ if(USE_CUDA AND USE_CUTLASS)
   install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
   target_link_libraries(tvm PRIVATE fpA_intB_gemm)
   target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm)
+
+  install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
+  target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
+  target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
 endif()
diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index f8aaa2f40d..bd3e3b1166 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -21,6 +21,7 @@ if(USE_CUDA AND USE_CUTLASS)
 
   set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
+  add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
   list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
 
   message(STATUS "Build with CUTLASS")
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index b6a9517f80..9b4fa78127 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -159,3 +159,100 @@ def instantiate_attention_template(attrs):
     )
 
     return substitute_template(template, attrs)
+
+
+def instantiate_flash_attention_template(attrs):
+    """Return host code for flash attention."""
+
+    template = """
+    int q_head_stride = ${head_dim};
+    int k_head_stride = ${head_dim};
+    int v_head_stride = ${head_dim};
+    int o_head_stride = ${head_dim};
+    int q_row_stride = q_head_stride * ${num_heads};
+    int k_row_stride = k_head_stride * ${num_heads};
+    int v_row_stride = v_head_stride * ${num_heads};
+    int o_row_stride = o_head_stride * ${num_heads};
+    int q_batch_stride = q_row_stride * ${num_queries};
+    int k_batch_stride = k_row_stride * ${num_keys};
+    int v_batch_stride = v_row_stride * ${num_keys};
+    int o_batch_stride = o_row_stride * ${num_queries};
+
+    flash_attn::flash_attention_forward(
+                            static_cast<const cutlass::half_t*>(${query}->data),
+                           static_cast<const cutlass::half_t*>(${key}->data),
+                           static_cast<const cutlass::half_t*>(${value}->data),
+                           static_cast<cutlass::half_t*>(out0->data),
+                           ${num_batches},
+                           ${num_queries},
+                           ${num_keys},
+                           ${num_heads},
+                           ${num_heads},
+                           ${head_dim},
+                           q_batch_stride,
+                           k_batch_stride,
+                           v_batch_stride,
+                           o_batch_stride,
+                           q_head_stride,
+                           k_head_stride,
+                           v_head_stride,
+                           o_head_stride,
+                           q_row_stride,
+                           k_row_stride,
+                           v_row_stride,
+                           o_row_stride,
+                           ${scale},
+                           ${is_causal},
+                           nullptr);
+    """
+
+    template_stacked = """
+    int q_head_stride = ${head_dim};
+    int k_head_stride = ${head_dim};
+    int v_head_stride = ${head_dim};
+    int o_head_stride = ${head_dim};
+    int row_stride = q_head_stride * ${num_heads} +
+                     k_head_stride * ${num_heads} +
+                     v_head_stride * ${num_heads};
+    int q_row_stride = row_stride;
+    int k_row_stride = row_stride;
+    int v_row_stride = row_stride;
+    int o_row_stride = o_head_stride * ${num_heads};
+
+    int q_batch_stride = q_row_stride * ${num_queries};
+    int k_batch_stride = k_row_stride * ${num_keys};
+    int v_batch_stride = v_row_stride * ${num_keys};
+    int o_batch_stride = o_row_stride * ${num_queries};
+
+    flash_attn::flash_attention_forward(
+    static_cast<const cutlass::half_t*>(${qkv}->data),
+                           static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
+                           static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+                           static_cast<cutlass::half_t*>(out0->data),
+                           ${num_batches},
+                           ${num_queries},
+                           ${num_keys},
+                           ${num_heads},
+                           ${num_heads},
+                           ${head_dim},
+                           q_batch_stride,
+                           k_batch_stride,
+                           v_batch_stride,
+                           o_batch_stride,
+                           q_head_stride,
+                           k_head_stride,
+                           v_head_stride,
+                           o_head_stride,
+                           q_row_stride,
+                           k_row_stride,
+                           v_row_stride,
+                           o_row_stride,
+                           ${scale},
+                           ${is_causal},
+                           nullptr);
+    """
+
+    if "qkv" in attrs:
+        return substitute_template(template_stacked, attrs)
+
+    return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 1b8b88bb1d..7b1ab67172 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -59,6 +59,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
     cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
     cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
     cutlass_fpA_intB_gemm_include = os.path.join(cutlass_root, "../cutlass_fpA_intB_gemm")
+    flash_attn_include = os.path.join(cutlass_root, "../libflash_attn/include")
 
     kwargs = {}
     kwargs["cc"] = "nvcc"
@@ -77,6 +78,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
         f"-I{cutlass_util_include}",
         f"-I{cutlass_attention_include}",
         f"-I{cutlass_fpA_intB_gemm_include}",
+        f"-I{flash_attn_include}",
     ]
     if use_fast_math:
         kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index bf02d8f7b8..7133193c1e 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -29,7 +29,10 @@ from tvm.runtime import Object
 from tvm.tir import IntImm
 
 from . import _ffi_api as ffi
-from .attention_operation import instantiate_attention_template
+from .attention_operation import (
+    instantiate_attention_template,
+    instantiate_flash_attention_template,
+)
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
 from .layer_norm_operation import instantiate_layer_norm_template
@@ -712,7 +715,6 @@ def instantiate_template(func_name, annotations, func_args):
         return CodegenResult(code, headers)
 
     elif "attention" in func_name:
-        headers.append("kernel_forward.h")
         data_type = dtype_map[annotations["arg0_dtype"]]
 
         attrs["qkv_layout"] = annotations["qkv_layout"]
@@ -739,62 +741,86 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
         attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
-
-        data_type_size = DataTypeSize[data_type]
-        if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
-            attrs["kIsAligned"] = True
-        elif (h % 4 == 0) and (h_v % 4 == 0):
-            attrs["kIsAligned"] = False
-        else:
-            raise NotImplementedError()
-        if h_v > 64:
-            attrs["kQueriesPerBlock"] = 32
-            attrs["kKeysPerBlock"] = 128
-            attrs["kSingleValueIteration"] = h_v <= 128
-        else:
-            attrs["kQueriesPerBlock"] = 64
-            attrs["kKeysPerBlock"] = 64
-            attrs["kSingleValueIteration"] = True
-        attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
         attrs["scale"] = (
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
         )
-        attrs["custom_mask_type"] = annotations["custom_mask_type"]
-
-        assert (
-            attrs["scale"] > 0 or attrs["scale"] < 0
-        ), "Cutlass may generate nan occasionally when scale == 0.0"
-        attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
-        attrs["kSupportsDropout"] = False
-
-        for arg in func_args:
-            if "workspace" in arg:
-                attrs["workspace"] = arg
-        if "bias" in attrs:
-            attrs["kSupportsBias"] = True
-            if len(annotations["bias_shape"]) == 4:
-                strides = "p.num_keys"
-                if annotations["bias_shape"][2] == 1:
-                    attrs["bias_strideM"] = 0
-                else:
-                    attrs["bias_strideM"] = strides
-                    strides = f"p.num_queries * {strides}"
-                if annotations["bias_shape"][1] == 1:
-                    attrs["bias_strideH"] = 0
-                else:
-                    attrs["bias_strideH"] = strides
-                    strides = f"p.num_heads * {strides}"
-                if annotations["bias_shape"][0] == 1:
-                    attrs["bias_strideB"] = 0
-                else:
-                    attrs["bias_strideB"] = strides
+
+        use_flash = (
+            annotations["ret_dtype"] == "float16"
+            and "bias" not in attrs
+            and int(attrs["head_dim"]) <= 256
+            and int(attrs["head_dim"]) % 8 == 0
+            and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
+            # We have not thoroughly validated flash with causal mask yet, so for now we support
+            # only non-causal cases.
+            and int(annotations["custom_mask_type"]) == 0
+            # Flash v2 is currently not supported for sm < 80
+            and int(annotations["arch"]) >= 80
+        )
+
+        if use_flash:
+            headers.append("flash.h")
+            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
+            code = instantiate_flash_attention_template(attrs)
+        else:
+            headers.append("kernel_forward.h")
+
+            data_type_size = DataTypeSize[data_type]
+            if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+                attrs["kIsAligned"] = True
+            elif (h % 4 == 0) and (h_v % 4 == 0):
+                attrs["kIsAligned"] = False
             else:
                 raise NotImplementedError()
-        else:
-            # To support negative scale in current Cutlass implementation,
-            # kSupportsBias should be set true, or there are nan's as result.
-            attrs["kSupportsBias"] = attrs["scale"] < 0
-        code = instantiate_attention_template(attrs)
+            if h_v > 64:
+                attrs["kQueriesPerBlock"] = 32
+                attrs["kKeysPerBlock"] = 128
+                attrs["kSingleValueIteration"] = h_v <= 128
+            else:
+                attrs["kQueriesPerBlock"] = 64
+                attrs["kKeysPerBlock"] = 64
+                attrs["kSingleValueIteration"] = True
+
+            assert (
+                attrs["scale"] > 0 or attrs["scale"] < 0
+            ), "Cutlass may generate nan occasionally when scale == 0.0"
+            attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
+            attrs["kSupportsDropout"] = False
+
+            attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
+
+            attrs["custom_mask_type"] = annotations["custom_mask_type"]
+
+            for arg in func_args:
+                if "workspace" in arg:
+                    attrs["workspace"] = arg
+            if "bias" in attrs:
+                attrs["kSupportsBias"] = True
+                if len(annotations["bias_shape"]) == 4:
+                    strides = "p.num_keys"
+                    if annotations["bias_shape"][2] == 1:
+                        attrs["bias_strideM"] = 0
+                    else:
+                        attrs["bias_strideM"] = strides
+                        strides = f"p.num_queries * {strides}"
+                    if annotations["bias_shape"][1] == 1:
+                        attrs["bias_strideH"] = 0
+                    else:
+                        attrs["bias_strideH"] = strides
+                        strides = f"p.num_heads * {strides}"
+                    if annotations["bias_shape"][0] == 1:
+                        attrs["bias_strideB"] = 0
+                    else:
+                        attrs["bias_strideB"] = strides
+                else:
+                    raise NotImplementedError()
+            else:
+                # To support negative scale in current Cutlass implementation,
+                # kSupportsBias should be set true, or there are nan's as result.
+                attrs["kSupportsBias"] = attrs["scale"] < 0
+
+            code = instantiate_attention_template(attrs)
+
         return CodegenResult(code, headers)
     elif "layer_norm" in func_name:
         headers.append("cutlass/util/device_layernorm.h")
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 30286a597a..d19189ff34 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -854,7 +854,7 @@ def stacked_attention_size(request):
 
 def test_stacked_attention_split_offload(stacked_attention_size):
     b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
-    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32")
+    qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16")
     if scale == "none":
         mod = get_relax_stacked_attention_module(
             qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape

Reply via email to