This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d56fa44b93 [Unity][BYOC] Support matmul + residual block fusion in CUTLASS BYOC (#14317)
d56fa44b93 is described below
commit d56fa44b9378acab57d8150175dcd7d803af38ce
Author: masahi <[email protected]>
AuthorDate: Fri Mar 17 08:20:02 2023 +0900
[Unity][BYOC] Support matmul + residual block fusion in CUTLASS BYOC (#14317)
* enable residual fusion support for matmul
* disallow residual fusion without bias
* support conv2d or matmul + residual add without bias via the existing conv2d or matmul + bias pattern (the residual input is treated as the bias)
---
python/tvm/contrib/cutlass/build.py | 1 +
python/tvm/contrib/cutlass/conv2d_operation.py | 14 +--
python/tvm/contrib/cutlass/gemm_operation.py | 132 ++++++++++++++++++-------
python/tvm/contrib/cutlass/gen_gemm.py | 38 ++++++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 22 +++--
python/tvm/relax/backend/contrib/cutlass.py | 30 +++---
tests/python/relax/test_codegen_cutlass.py | 67 ++++++++-----
7 files changed, 223 insertions(+), 81 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index 416c780fec..7e92e6a887 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -765,6 +765,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"op_type": op_type,
"lhs_arg_idx": arg_idx["lhs"],
"rhs_arg_idx": arg_idx["rhs"],
+ "residual_arg_idx": arg_idx.get("residual"),
"bias_arg_idx": arg_idx.get("bias"),
"arg0_dtype": signature["arg0_dtype"],
"arg1_dtype": signature["arg1_dtype"],
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py
b/python/tvm/contrib/cutlass/conv2d_operation.py
index 5996d50d88..f2d2f01276 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -466,7 +466,7 @@ def instantiate_conv2d_template(attrs):
use_split_k = "splitk" in attrs["cutlass_op_name"]
is_wgrad = "backward_weight" in op_type
is_dgrad = "conv2d_transpose" in op_type
- has_residual_blcok = "residual" in op_type
+ has_residual_block = "residual" in op_type
no_bias_scaling = op_type not in [
"cutlass.conv2d_bias_sigmoid",
"cutlass.conv2d_bias_silu",
@@ -475,12 +475,12 @@ def instantiate_conv2d_template(attrs):
aux_map = {}
- if (not has_bias or no_bias_scaling) and not has_residual_blcok:
- aux_map["beta"] = "0"
+ if (not has_bias or no_bias_scaling) and not has_residual_block:
+ aux_map["beta"] = 0
else:
- aux_map["beta"] = "1"
+ aux_map["beta"] = 1
- if has_residual_blcok:
+ if has_residual_block:
aux_map["bias_decl"] = "void* ptr_bias = (void*)(${bias_arg}->data);\n"
aux_map["residual_decl"] = "void* ptr_residual =
(void*)(${residual_arg}->data);"
aux_map["tensor_c"] = "ptr_residual"
@@ -496,12 +496,12 @@ def instantiate_conv2d_template(attrs):
aux_map["tensor_c"] = "ptr_out"
aux_map["tensor_c_layout"] = "layout_C"
- if has_bias and no_bias_scaling and not has_residual_blcok:
+ if has_bias and no_bias_scaling and not has_residual_block:
aux_map["alpha_beta"] = "alpha"
else:
aux_map["alpha_beta"] = "alpha, beta"
- if has_residual_blcok:
+ if has_residual_block:
aux_map["additional_args"] = ", static_cast<ElementOutput*>(ptr_bias),
nullptr, 0, K"
else:
aux_map["additional_args"] = ""
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py
b/python/tvm/contrib/cutlass/gemm_operation.py
index 1675e1f035..eb9f92dad3 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -164,6 +164,7 @@ class EmitGemmInstance:
${element_accumulator},
${element_epilogue}
>"""
+
self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
@@ -172,6 +173,19 @@ class EmitGemmInstance:
${element_epilogue},
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""
+
+ self.epilogue_residual_block = """
+ ${epilogue_functor}<
+ ${element_c},
+ ${element_accumulator},
+ ${element_epilogue},
+ ${element_c},
+ ${epilogue_vector_length},
+ ${activation},
+ ${binary_op},
+ ${unary_op}
+ >"""
+
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
@@ -188,13 +202,11 @@ class EmitGemmInstance:
${swizzling_functor},
${stages},
${align_a},
- ${align_b},
- ${split_k_serial}
- ${math_operation}
+ ${align_b}
>;
"""
- def emit(self, operation, no_beta_scaling=False, batched=False):
+ def emit(self, operation, no_beta_scaling=False, batched=False,
residual_block_info=False):
"""Instantiate a GEMM kernel from given `operation`."""
warp_shape = [
operation.tile_description.threadblock_shape[idx]
@@ -246,22 +258,73 @@ class EmitGemmInstance:
}
values["kernel_name"] = "GemmBatched" if batched else "Gemm"
- values["split_k_serial"] = "" if batched else "false,"
- gemm_template = substitute_template(
- self.gemm_template,
- {
- "epilogue": self.epilogue_no_beta_scaling
- if no_beta_scaling
- else self.epilogue_default
- },
- )
- return substitute_template(gemm_template, values)
+ if residual_block_info:
+ values["kernel_name"] = "GemmUniversalWithBroadcast"
+ template = substitute_template(
+ self.gemm_template, {"epilogue": self.epilogue_residual_block}
+ )
+ values.update(
+ {
+ "unary_op": residual_block_info["unary_op"],
+ "binary_op": residual_block_info["binary_op"],
+ "activation": residual_block_info["activation"],
+ }
+ )
+ elif no_beta_scaling:
+ template = substitute_template(
+ self.gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
+ )
+ else:
+ template = substitute_template(self.gemm_template, {"epilogue":
self.epilogue_default})
+
+ return substitute_template(template, values)
def instantiate_gemm_template(attrs):
"""Return CUTLASS host code for GEMM based on a template and the provided
attribute map."""
+ argument_template_default = """
+ typename ${kernel}::Arguments arguments{
+ problem_size,
+ {static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
+ {static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
+ {static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
+ {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
+ {${alpha_beta}},
+ ${split_k_slices_or_batch}
+ };
+ """
+
+ # See cutlass/gemm/kernel/gemm_with_fused_epilogue.h
+ # Batched GEMM + residual fusion is not supported for now.
+ argument_template_residual = """
+ typename ${kernel}::Arguments arguments{
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ 1, // batch_count,
+ {${alpha_beta}},
+ static_cast<ElementInputA*>(ptr_a),
+ static_cast<ElementInputB*>(ptr_b),
+ static_cast<ElementOutput*>(ptr_residual),
+ static_cast<ElementOutput*>(ptr_out),
+ static_cast<ElementOutput*>(ptr_bias),
+ nullptr, // ptr_Tensor
+ 0, // batch_stride_A,
+ 0, // batch_stride_B,
+ 0, // batch_stride_C,
+ 0, // batch_stride_D,
+ 0, // batch_stride_Vector,
+ 0, // batch_stride_Tensor,
+ ${lda},
+ ${ldb},
+ ${ldc},
+ ${ldc},
+ 0, // ldv, the stride for bias
+ 0, // ldt
+ };
+ """
+
template = """
using ElementInputA = ${ElementInputA};
using ElementInputB = ${ElementInputB};
@@ -280,17 +343,10 @@ def instantiate_gemm_template(attrs):
void* ptr_a = (void*)(${lhs_arg}->data);
void* ptr_b = (void*)(${rhs_arg}->data);
${bias_decl}
+ ${residual_decl}
void* ptr_out = (void*)(out0->data);
- typename ${kernel}::Arguments arguments{
- problem_size,
- {static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
- {static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
- {static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
- {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
- {${alpha_beta}},
- ${split_k_slices_or_batch}
- };
+ ${argument}
size_t workspace_size = ${kernel}::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
${kernel} gemm_op;
@@ -301,29 +357,31 @@ def instantiate_gemm_template(attrs):
status = gemm_op();
CHECK(status == cutlass::Status::kSuccess);
"""
- has_bias = "bias" in attrs["op_type"]
- is_gelu = "gelu" in attrs["op_type"]
+ op_type = attrs["op_type"]
+ has_bias = "bias" in op_type
+ is_gelu = "gelu" in op_type
batched = "batch" in attrs
+ has_residual_block = "residual" in op_type
aux_map = {"kernel": "Gemm"}
if has_bias:
aux_map.update(
{
- "bias_decl": "void* ptr_c_bias =
(void*)(${bias_arg}->data);\n",
- "ptr_c": "ptr_c_bias",
- "c_stride": "0",
+ "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n",
+ "ptr_c": "ptr_bias",
+ "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"],
}
)
else:
aux_map.update({"bias_decl": "", "ptr_c": "ptr_out", "c_stride":
attrs["ldc"]})
- if is_gelu:
+ if is_gelu or has_residual_block:
# GeLU epilogue does not compile with NoBetaScaling, so we explicitly
specify the scale.
- aux_map["beta"] = "1"
+ aux_map["beta"] = 1
else:
- aux_map["beta"] = "0"
+ aux_map["beta"] = 0
- if has_bias and not is_gelu:
+ if has_bias and not is_gelu and not has_residual_block:
aux_map["alpha_beta"] = "alpha"
else:
aux_map["alpha_beta"] = "alpha, beta"
@@ -341,7 +399,15 @@ def instantiate_gemm_template(attrs):
if batched:
attrs["split_k_slices_or_batch"] = attrs["batch"]
else:
- attrs["split_k_slices_or_batch"] = "1"
+ attrs["split_k_slices_or_batch"] = 1
+
+ if has_residual_block:
+ assert not batched, "Residual fusion is supported only for non-batched
GEMM for now."
+ template = substitute_template(template, {"argument":
argument_template_residual})
+ aux_map["residual_decl"] = "void* ptr_residual =
(void*)(${residual_arg}->data);\n"
+ else:
+ template = substitute_template(template, {"argument":
argument_template_default})
+ aux_map["residual_decl"] = ""
template = substitute_template(template, aux_map)
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py
b/python/tvm/contrib/cutlass/gen_gemm.py
index 6aa4c51221..f5f160a400 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -53,7 +53,36 @@ def create_gemm_operator_with_epilogue(
if batched:
swizzling_functor = SwizzlingFunctor.Batched
- epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+ if "residual" in op_type:
+ if "hardswish" in op_type:
+ activation = "cutlass::epilogue::thread::HardSwish"
+ elif "silu" in op_type:
+ activation = "cutlass::epilogue::thread::SiLu"
+ elif "sigmoid" in op_type:
+ activation = "cutlass::epilogue::thread::Sigmoid"
+ elif "gelu" in op_type:
+ activation = "cutlass::epilogue::thread::GELU"
+ elif "relu" in op_type:
+ activation = "cutlass::epilogue::thread::ReLu"
+ else:
+ activation = "cutlass::epilogue::thread::Identity"
+
+ binary_op = "cutlass::multiplies" if "residual_multiply" in op_type
else "cutlass::plus"
+ unary_op = (
+ "cutlass::epilogue::thread::ReLu"
+ if op_type.endswith("relu")
+ else "cutlass::epilogue::thread::Identity"
+ )
+ residual_block_info = {
+ "activation": activation,
+ "binary_op": binary_op,
+ "unary_op": unary_op,
+ }
+ epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+ no_beta_scaling = False
+ else:
+ residual_block_info = None
+ epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
op = GemmOperation(
tile_description.minimum_compute_capability,
@@ -68,7 +97,12 @@ def create_gemm_operator_with_epilogue(
return (
op.procedural_name(),
- EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling,
batched=batched),
+ EmitGemmInstance().emit(
+ op,
+ no_beta_scaling=no_beta_scaling,
+ batched=batched,
+ residual_block_info=residual_block_info,
+ ),
)
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 9833600be0..6b2587a0b0 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -535,7 +535,9 @@ def instantiate_template(func_name, annotations, func_args):
transposed = "transposed" in func_name
lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx",
0)
rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx",
1)
- bias_arg_idx = _get_optional_int_annotation(annotations,
"bias_arg_idx", 2)
+ bias_arg_idx = _get_optional_int_annotation(annotations,
"bias_arg_idx", None)
+ residual_arg_idx = _get_optional_int_annotation(annotations,
"residual_arg_idx", None)
+
lhs_arg = func_args[lhs_arg_idx]
rhs_arg = func_args[rhs_arg_idx]
lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
@@ -545,8 +547,12 @@ def instantiate_template(func_name, annotations,
func_args):
attrs["lhs_arg"] = lhs_arg
attrs["rhs_arg"] = rhs_arg
- if len(func_args) > 2:
+
+ if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
+ if residual_arg_idx is not None:
+ attrs["residual_arg"] = func_args[residual_arg_idx]
+
attrs["ElementInputA"] =
DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]]
attrs["ElementInputB"] =
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
attrs["ElementOutput"] =
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
@@ -610,20 +616,24 @@ def instantiate_template(func_name, annotations,
func_args):
else:
headers.append("cutlass/gemm/device/gemm.h")
+ if "residual" in func_name:
+
headers.append("cutlass/gemm/device/gemm_universal_with_broadcast.h")
+
code = instantiate_gemm_template(attrs)
return CodegenResult(code, headers)
elif "conv2d" in func_name:
data_arg_idx = _get_optional_int_annotation(annotations,
"data_arg_idx", 0)
weight_arg_idx = _get_optional_int_annotation(annotations,
"weight_arg_idx", 1)
- bias_arg_idx = _get_optional_int_annotation(annotations,
"bias_arg_idx", 2)
- residual_arg_idx = _get_optional_int_annotation(annotations,
"residual_arg_idx", 3)
+ bias_arg_idx = _get_optional_int_annotation(annotations,
"bias_arg_idx", None)
+ residual_arg_idx = _get_optional_int_annotation(annotations,
"residual_arg_idx", None)
attrs["data_arg"] = func_args[data_arg_idx]
attrs["weight_arg"] = func_args[weight_arg_idx]
- if len(func_args) > bias_arg_idx:
+
+ if bias_arg_idx is not None:
attrs["bias_arg"] = func_args[bias_arg_idx]
- if len(func_args) > residual_arg_idx:
+ if residual_arg_idx is not None:
attrs["residual_arg"] = func_args[residual_arg_idx]
activation_shape = annotations[f"arg{data_arg_idx}_shape"]
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 29026b194d..e1b9226d68 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -195,17 +195,25 @@ def residual_block_patterns():
patterns = []
for activation, name_postfix in [(None, ""), ("relax.nn.relu", "_relu")]:
- for name, pat, arg_pat, _ in conv2d_patterns()[1:]:
- for bin_op in ["relax.add", "relax.multiply"]:
- patterns.append(
- (
- name + "_residual_" + bin_op.split(".")[-1] +
name_postfix,
- *make_residual_block_pattern(
- (pat, arg_pat), binary_op=bin_op,
activation=activation
- ),
- _check_conv2d,
- )
- )
+ for check, base_patterns in [
+ (_check_conv2d, conv2d_patterns()),
+ (_check_matmul, matmul_patterns()),
+ ]:
+ for name, pat, arg_pat, _ in base_patterns:
+ # Append residual patterns only to those base patterns with
bias add,
+ # since conv2d or matmul + residual add without bias is
already supported
+ # via conv2d or matmul + bias patterns (the residual input is
treated as "bias").
+ if "bias" in name:
+ for bin_op in ["relax.add", "relax.multiply"]:
+ patterns.append(
+ (
+ name + "_residual_" + bin_op.split(".")[-1] +
name_postfix,
+ *make_residual_block_pattern(
+ (pat, arg_pat), binary_op=bin_op,
activation=activation
+ ),
+ check,
+ )
+ )
return patterns
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 02802211a9..de15f7083a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -101,6 +101,7 @@ def get_result_with_relax_cutlass_offload(mod, *args,
assert_all_bindings_fused=
assert len(patterns) != 0, "Cannot find cutlass patterns"
mod = partition_for_cutlass(mod)
+
if assert_all_bindings_fused:
assert len(mod["main"].body.blocks[0].bindings) == 1
@@ -158,8 +159,8 @@ def get_relax_conv2d_module(
output = R.emit(activation(output))
if residual_bin_op is not None:
output = R.emit(residual_bin_op(output, data))
- if residual_activation is not None:
- output = R.emit(residual_activation(output))
+ if residual_activation is not None:
+ output = R.emit(residual_activation(output))
R.output(output)
R.func_ret_value(frame.output_vars[0])
@@ -169,7 +170,14 @@ def get_relax_conv2d_module(
def get_relax_matmul_module(
- x_shape, y_shape, dtype, transposed_y=False, with_bias=False,
activation=None
+ x_shape,
+ y_shape,
+ dtype,
+ transposed_y=False,
+ with_bias=False,
+ activation=None,
+ residual_bin_op=None,
+ residual_activation=None,
):
if transposed_y:
n = y_shape[-2]
@@ -193,6 +201,10 @@ def get_relax_matmul_module(
result = R.emit(result + bias)
if activation is not None:
result = R.emit(activation(result))
+ if residual_bin_op is not None:
+ result = R.emit(residual_bin_op(result, x))
+ if residual_activation is not None:
+ result = R.emit(residual_activation(result))
R.output(result)
R.func_ret_value(frame.output_vars[0])
@@ -285,34 +297,40 @@ def test_conv2d_offload(data_shape, weight_shape, dtype,
epilogue, residual_bloc
@pytest.mark.parametrize(
- "x_shape, y_shape, transpose_y, epilogue",
+ "x_shape, y_shape, transpose_y, epilogue, residual_block",
[
# Regular
- ((32, 6), (6, 16), False, "none"),
- ((_vars["a"], 6), (6, 16), False, "bias"),
+ ((32, 6), (6, 16), False, "none", "none"),
+ ((_vars["a"], 6), (6, 16), False, "bias", "none"),
# Transposed
- ((4, 16), (16, 128), True, "relu"),
- ((35, 8), (8, 8), True, "gelu"),
+ ((4, 16), (16, 128), True, "relu", "none"),
+ ((35, 8), (8, 8), True, "gelu", "none"),
# 3D x 3D
- ((6, 32, 8), (6, 8, 10), False, "bias"),
- ((6, 32, 8), (6, 8, 10), True, "none"),
- ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
+ ((6, 32, 8), (6, 8, 10), False, "bias", "none"),
+ ((6, 32, 8), (6, 8, 10), True, "none", "none"),
+ ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
# 3D x 2D
- ((6, 32, 8), (8, 10), False, "none"),
- ((_vars["a"], 32, 8), (8, 10), False, "bias"),
- ((10, 16, 8), (8, 10), True, "relu"),
+ ((6, 32, 8), (8, 10), False, "none", "none"),
+ ((_vars["a"], 32, 8), (8, 10), False, "bias", "none"),
+ ((10, 16, 8), (8, 10), True, "relu", "none"),
# 2D x 3D
- ((32, 8), (10, 8, 10), False, "relu"),
- ((32, 8), (_vars["a"], 8, 10), True, "gelu"),
+ ((32, 8), (10, 8, 10), False, "relu", "none"),
+ ((32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
# ND x 2D
- ((3, 6, 32, 8), (8, 10), False, "bias"),
- ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"),
+ ((3, 6, 32, 8), (8, 10), False, "bias", "none"),
+ ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none", "none"),
# 2D x ND
- ((32, 8), (5, 3, 8, 10), False, "gelu"),
+ ((32, 8), (5, 3, 8, 10), False, "gelu", "none"),
# ND x ND
- ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
- ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"),
- ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"),
+ ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu", "none"),
+ ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu", "none"),
+ ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none", "none"),
+ # Residual
+ ((32, 8), (8, 8), False, "bias", "add"),
+ ((4, 16), (16, 16), True, "relu", "add_relu"),
+ # Residual fusion without bias - this is supported via the matmul +
bias pattern
+ # where bias == residual input
+ ((4, 16), (16, 16), False, "none", "add"),
],
)
@pytest.mark.parametrize(
@@ -326,6 +344,7 @@ def test_matmul_offload(
y_shape,
transpose_y,
epilogue,
+ residual_block,
dtype,
):
with_bias, activation = _epilogue_table[epilogue]
@@ -346,6 +365,8 @@ def test_matmul_offload(
bias = None
args = (x, y)
+ residual_bin_op, residual_activation =
_residual_block_table[residual_block]
+
mod = get_relax_matmul_module(
x_shape,
y_shape,
@@ -353,6 +374,8 @@ def test_matmul_offload(
with_bias=with_bias,
transposed_y=transpose_y,
activation=activation,
+ residual_bin_op=residual_bin_op,
+ residual_activation=residual_activation,
)
out = get_result_with_relax_cutlass_offload(mod, *args)
ref = build_and_run(mod, args, "llvm", legalize=True)