This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 38e2b886dd [Unity][BYOC] Integrate Flash attention v2 kernel into CUTLASS BYOC (#15467)
38e2b886dd is described below
commit 38e2b886dd97cfaa0031b204441c427ac598a856
Author: masahi <[email protected]>
AuthorDate: Tue Aug 8 07:11:02 2023 +0900
[Unity][BYOC] Integrate Flash attention v2 kernel into CUTLASS BYOC (#15467)
* Integrate flash attention into cutlass BYOC
* update note on causal mask
* disable flash v2 for sm < 80
---
CMakeLists.txt | 4 +
cmake/modules/contrib/CUTLASS.cmake | 1 +
python/tvm/contrib/cutlass/attention_operation.py | 97 ++++++++++++++++
python/tvm/contrib/cutlass/build.py | 2 +
python/tvm/contrib/cutlass/gen_tensor_op.py | 132 +++++++++++++---------
tests/python/relax/test_codegen_cutlass.py | 2 +-
6 files changed, 184 insertions(+), 54 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 389879a883..47d57d56bd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -843,4 +843,8 @@ if(USE_CUDA AND USE_CUTLASS)
install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
target_link_libraries(tvm PRIVATE fpA_intB_gemm)
target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm)
+
+ install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
+ target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
+ target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
endif()
diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index f8aaa2f40d..bd3e3b1166 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -21,6 +21,7 @@ if(USE_CUDA AND USE_CUTLASS)
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
+ add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
message(STATUS "Build with CUTLASS")
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index b6a9517f80..9b4fa78127 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -159,3 +159,100 @@ def instantiate_attention_template(attrs):
)
return substitute_template(template, attrs)
+
+
+def instantiate_flash_attention_template(attrs):
+ """Return host code for flash attention."""
+
+ template = """
+ int q_head_stride = ${head_dim};
+ int k_head_stride = ${head_dim};
+ int v_head_stride = ${head_dim};
+ int o_head_stride = ${head_dim};
+ int q_row_stride = q_head_stride * ${num_heads};
+ int k_row_stride = k_head_stride * ${num_heads};
+ int v_row_stride = v_head_stride * ${num_heads};
+ int o_row_stride = o_head_stride * ${num_heads};
+ int q_batch_stride = q_row_stride * ${num_queries};
+ int k_batch_stride = k_row_stride * ${num_keys};
+ int v_batch_stride = v_row_stride * ${num_keys};
+ int o_batch_stride = o_row_stride * ${num_queries};
+
+ flash_attn::flash_attention_forward(
+ static_cast<const cutlass::half_t*>(${query}->data),
+ static_cast<const cutlass::half_t*>(${key}->data),
+ static_cast<const cutlass::half_t*>(${value}->data),
+ static_cast<cutlass::half_t*>(out0->data),
+ ${num_batches},
+ ${num_queries},
+ ${num_keys},
+ ${num_heads},
+ ${num_heads},
+ ${head_dim},
+ q_batch_stride,
+ k_batch_stride,
+ v_batch_stride,
+ o_batch_stride,
+ q_head_stride,
+ k_head_stride,
+ v_head_stride,
+ o_head_stride,
+ q_row_stride,
+ k_row_stride,
+ v_row_stride,
+ o_row_stride,
+ ${scale},
+ ${is_causal},
+ nullptr);
+ """
+
+ template_stacked = """
+ int q_head_stride = ${head_dim};
+ int k_head_stride = ${head_dim};
+ int v_head_stride = ${head_dim};
+ int o_head_stride = ${head_dim};
+ int row_stride = q_head_stride * ${num_heads} +
+ k_head_stride * ${num_heads} +
+ v_head_stride * ${num_heads};
+ int q_row_stride = row_stride;
+ int k_row_stride = row_stride;
+ int v_row_stride = row_stride;
+ int o_row_stride = o_head_stride * ${num_heads};
+
+ int q_batch_stride = q_row_stride * ${num_queries};
+ int k_batch_stride = k_row_stride * ${num_keys};
+ int v_batch_stride = v_row_stride * ${num_keys};
+ int o_batch_stride = o_row_stride * ${num_queries};
+
+ flash_attn::flash_attention_forward(
+ static_cast<const cutlass::half_t*>(${qkv}->data),
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
+ static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+ static_cast<cutlass::half_t*>(out0->data),
+ ${num_batches},
+ ${num_queries},
+ ${num_keys},
+ ${num_heads},
+ ${num_heads},
+ ${head_dim},
+ q_batch_stride,
+ k_batch_stride,
+ v_batch_stride,
+ o_batch_stride,
+ q_head_stride,
+ k_head_stride,
+ v_head_stride,
+ o_head_stride,
+ q_row_stride,
+ k_row_stride,
+ v_row_stride,
+ o_row_stride,
+ ${scale},
+ ${is_causal},
+ nullptr);
+ """
+
+ if "qkv" in attrs:
+ return substitute_template(template_stacked, attrs)
+
+ return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 1b8b88bb1d..7b1ab67172 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -59,6 +59,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention")
cutlass_fpA_intB_gemm_include = os.path.join(cutlass_root, "../cutlass_fpA_intB_gemm")
+ flash_attn_include = os.path.join(cutlass_root, "../libflash_attn/include")
kwargs = {}
kwargs["cc"] = "nvcc"
@@ -77,6 +78,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
f"-I{cutlass_util_include}",
f"-I{cutlass_attention_include}",
f"-I{cutlass_fpA_intB_gemm_include}",
+ f"-I{flash_attn_include}",
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index bf02d8f7b8..7133193c1e 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -29,7 +29,10 @@ from tvm.runtime import Object
from tvm.tir import IntImm
from . import _ffi_api as ffi
-from .attention_operation import instantiate_attention_template
+from .attention_operation import (
+ instantiate_attention_template,
+ instantiate_flash_attention_template,
+)
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
from .layer_norm_operation import instantiate_layer_norm_template
@@ -712,7 +715,6 @@ def instantiate_template(func_name, annotations, func_args):
return CodegenResult(code, headers)
elif "attention" in func_name:
- headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]
attrs["qkv_layout"] = annotations["qkv_layout"]
@@ -739,62 +741,86 @@ def instantiate_template(func_name, annotations, func_args):
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
-
- data_type_size = DataTypeSize[data_type]
- if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
- attrs["kIsAligned"] = True
- elif (h % 4 == 0) and (h_v % 4 == 0):
- attrs["kIsAligned"] = False
- else:
- raise NotImplementedError()
- if h_v > 64:
- attrs["kQueriesPerBlock"] = 32
- attrs["kKeysPerBlock"] = 128
- attrs["kSingleValueIteration"] = h_v <= 128
- else:
- attrs["kQueriesPerBlock"] = 64
- attrs["kKeysPerBlock"] = 64
- attrs["kSingleValueIteration"] = True
- attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
- attrs["custom_mask_type"] = annotations["custom_mask_type"]
-
- assert (
- attrs["scale"] > 0 or attrs["scale"] < 0
- ), "Cutlass may generate nan occasionally when scale == 0.0"
- attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
- attrs["kSupportsDropout"] = False
-
- for arg in func_args:
- if "workspace" in arg:
- attrs["workspace"] = arg
- if "bias" in attrs:
- attrs["kSupportsBias"] = True
- if len(annotations["bias_shape"]) == 4:
- strides = "p.num_keys"
- if annotations["bias_shape"][2] == 1:
- attrs["bias_strideM"] = 0
- else:
- attrs["bias_strideM"] = strides
- strides = f"p.num_queries * {strides}"
- if annotations["bias_shape"][1] == 1:
- attrs["bias_strideH"] = 0
- else:
- attrs["bias_strideH"] = strides
- strides = f"p.num_heads * {strides}"
- if annotations["bias_shape"][0] == 1:
- attrs["bias_strideB"] = 0
- else:
- attrs["bias_strideB"] = strides
+
+ use_flash = (
+ annotations["ret_dtype"] == "float16"
+ and "bias" not in attrs
+ and int(attrs["head_dim"]) <= 256
+ and int(attrs["head_dim"]) % 8 == 0
+ and int(attrs["head_dim"]) == int(attrs["head_dim_value"])
+ # We have not thoroughly validated flash with causal mask yet, so for now we support
+ # only non-causal cases.
+ and int(annotations["custom_mask_type"]) == 0
+ # Flash v2 is currently not supported for sm < 80
+ and int(annotations["arch"]) >= 80
+ )
+
+ if use_flash:
+ headers.append("flash.h")
+ attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
+ code = instantiate_flash_attention_template(attrs)
+ else:
+ headers.append("kernel_forward.h")
+
+ data_type_size = DataTypeSize[data_type]
+ if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
+ attrs["kIsAligned"] = True
+ elif (h % 4 == 0) and (h_v % 4 == 0):
+ attrs["kIsAligned"] = False
else:
raise NotImplementedError()
- else:
- # To support negative scale in current Cutlass implementation,
- # kSupportsBias should be set true, or there are nan's as result.
- attrs["kSupportsBias"] = attrs["scale"] < 0
- code = instantiate_attention_template(attrs)
+ if h_v > 64:
+ attrs["kQueriesPerBlock"] = 32
+ attrs["kKeysPerBlock"] = 128
+ attrs["kSingleValueIteration"] = h_v <= 128
+ else:
+ attrs["kQueriesPerBlock"] = 64
+ attrs["kKeysPerBlock"] = 64
+ attrs["kSingleValueIteration"] = True
+
+ assert (
+ attrs["scale"] > 0 or attrs["scale"] < 0
+ ), "Cutlass may generate nan occasionally when scale == 0.0"
+ attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
+ attrs["kSupportsDropout"] = False
+
+ attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
+
+ attrs["custom_mask_type"] = annotations["custom_mask_type"]
+
+ for arg in func_args:
+ if "workspace" in arg:
+ attrs["workspace"] = arg
+ if "bias" in attrs:
+ attrs["kSupportsBias"] = True
+ if len(annotations["bias_shape"]) == 4:
+ strides = "p.num_keys"
+ if annotations["bias_shape"][2] == 1:
+ attrs["bias_strideM"] = 0
+ else:
+ attrs["bias_strideM"] = strides
+ strides = f"p.num_queries * {strides}"
+ if annotations["bias_shape"][1] == 1:
+ attrs["bias_strideH"] = 0
+ else:
+ attrs["bias_strideH"] = strides
+ strides = f"p.num_heads * {strides}"
+ if annotations["bias_shape"][0] == 1:
+ attrs["bias_strideB"] = 0
+ else:
+ attrs["bias_strideB"] = strides
+ else:
+ raise NotImplementedError()
+ else:
+ # To support negative scale in current Cutlass implementation,
+ # kSupportsBias should be set true, or there are nan's as result.
+ attrs["kSupportsBias"] = attrs["scale"] < 0
+
+ code = instantiate_attention_template(attrs)
+
return CodegenResult(code, headers)
elif "layer_norm" in func_name:
headers.append("cutlass/util/device_layernorm.h")
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 30286a597a..d19189ff34 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -854,7 +854,7 @@ def stacked_attention_size(request):
def test_stacked_attention_split_offload(stacked_attention_size):
b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size
- qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32")
+ qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16")
if scale == "none":
mod = get_relax_stacked_attention_module(
qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape