This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d5725a8443 [Unity][CUTLASS] Support batched matmul + residual fusion (#14613)
d5725a8443 is described below
commit d5725a84430c139e2992ba5cbee557fddcead724
Author: masahi <[email protected]>
AuthorDate: Sat Apr 15 05:02:08 2023 +0900
[Unity][CUTLASS] Support batched matmul + residual fusion (#14613)
support batched matmul + residual fusion
---
python/tvm/contrib/cutlass/gemm_operation.py | 21 ++++++++++-----------
tests/python/relax/test_codegen_cutlass.py | 2 ++
2 files changed, 12 insertions(+), 11 deletions(-)
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index b820ead016..60ee106919 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -297,12 +297,11 @@ def instantiate_gemm_template(attrs):
"""
# See cutlass/gemm/kernel/gemm_with_fused_epilogue.h
- # Batched GEMM + residual fusion is not supported for now.
argument_template_residual = """
typename ${kernel}::Arguments arguments{
- cutlass::gemm::GemmUniversalMode::kGemm,
+ cutlass::gemm::GemmUniversalMode::${gemm_universal_mode},
problem_size,
- 1, // batch_count,
+ ${split_k_slices_or_batch}, // batch_count
{${alpha_beta}},
static_cast<ElementInputA*>(ptr_a),
static_cast<ElementInputB*>(ptr_b),
@@ -310,10 +309,10 @@ def instantiate_gemm_template(attrs):
static_cast<ElementOutput*>(ptr_out),
static_cast<ElementOutput*>(ptr_bias),
nullptr, // ptr_Tensor
- 0, // batch_stride_A,
- 0, // batch_stride_B,
- 0, // batch_stride_C,
- 0, // batch_stride_D,
+ ${batch_stride_A}
+ ${batch_stride_B}
+ ${batch_stride_C}
+ ${batch_stride_D}
0, // batch_stride_Vector,
0, // batch_stride_Tensor,
${lda},
@@ -388,13 +387,13 @@ def instantiate_gemm_template(attrs):
aux_map["alpha_beta"] = "alpha, beta"
for key in ["batch_stride_A", "batch_stride_B", "batch_stride_C"]:
- if not batched:
+ if not batched and not has_residual_block:
aux_map[key] = ""
else:
- aux_map[key] = attrs[key] + ","
+ aux_map[key] = attrs.get(key, "0") + ","
aux_map["batch_stride_D"] = aux_map["batch_stride_C"]
- if has_bias and batched:
+ if has_bias and batched and not has_residual_block:
aux_map["batch_stride_C"] = "0,"
if batched:
@@ -403,9 +402,9 @@ def instantiate_gemm_template(attrs):
attrs["split_k_slices_or_batch"] = 1
if has_residual_block:
- assert not batched, "Residual fusion is supported only for non-batched
GEMM for now."
template = substitute_template(template, {"argument":
argument_template_residual})
aux_map["residual_decl"] = "void* ptr_residual =
(void*)(${residual_arg}->data);\n"
+ aux_map["gemm_universal_mode"] = "kBatched" if batched else "kGemm"
else:
template = substitute_template(template, {"argument":
argument_template_default})
aux_map["residual_decl"] = ""
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 9288db3eb5..4309627bf0 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -317,6 +317,8 @@ def test_cutlass_partition_conv2d_residual_blocked():
# Residual
((32, 8), (8, 8), False, "bias", "add"),
((4, 16), (16, 16), True, "relu", "add_relu"),
+ ((8, 32, 8), (8, 8, 8), False, "bias", "add"),
+ ((5, 3, 32, 8), (8, 8), True, "relu", "add"),
# Residual fusion without bias - this is supported via the matmul + bias pattern
# where bias == residual input
((4, 16), (16, 16), False, "none", "add"),