This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d56fa44b93 [Unity][BYOC] Support matmul + residual block fusion in CUTLASS BYOC (#14317)
d56fa44b93 is described below

commit d56fa44b9378acab57d8150175dcd7d803af38ce
Author: masahi <[email protected]>
AuthorDate: Fri Mar 17 08:20:02 2023 +0900

    [Unity][BYOC] Support matmul + residual block fusion in CUTLASS BYOC (#14317)
    
    * enable residual fusion support for matmul
    
    * disallow residual fusion without bias
    
    * support conv2d + residual add without bias via conv2d + bias pattern
---
 python/tvm/contrib/cutlass/build.py            |   1 +
 python/tvm/contrib/cutlass/conv2d_operation.py |  14 +--
 python/tvm/contrib/cutlass/gemm_operation.py   | 132 ++++++++++++++++++-------
 python/tvm/contrib/cutlass/gen_gemm.py         |  38 ++++++-
 python/tvm/contrib/cutlass/gen_tensor_op.py    |  22 +++--
 python/tvm/relax/backend/contrib/cutlass.py    |  30 +++---
 tests/python/relax/test_codegen_cutlass.py     |  67 ++++++++-----
 7 files changed, 223 insertions(+), 81 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 416c780fec..7e92e6a887 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -765,6 +765,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 "op_type": op_type,
                 "lhs_arg_idx": arg_idx["lhs"],
                 "rhs_arg_idx": arg_idx["rhs"],
+                "residual_arg_idx": arg_idx.get("residual"),
                 "bias_arg_idx": arg_idx.get("bias"),
                 "arg0_dtype": signature["arg0_dtype"],
                 "arg1_dtype": signature["arg1_dtype"],
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 5996d50d88..f2d2f01276 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -466,7 +466,7 @@ def instantiate_conv2d_template(attrs):
     use_split_k = "splitk" in attrs["cutlass_op_name"]
     is_wgrad = "backward_weight" in op_type
     is_dgrad = "conv2d_transpose" in op_type
-    has_residual_blcok = "residual" in op_type
+    has_residual_block = "residual" in op_type
     no_bias_scaling = op_type not in [
         "cutlass.conv2d_bias_sigmoid",
         "cutlass.conv2d_bias_silu",
@@ -475,12 +475,12 @@ def instantiate_conv2d_template(attrs):
 
     aux_map = {}
 
-    if (not has_bias or no_bias_scaling) and not has_residual_blcok:
-        aux_map["beta"] = "0"
+    if (not has_bias or no_bias_scaling) and not has_residual_block:
+        aux_map["beta"] = 0
     else:
-        aux_map["beta"] = "1"
+        aux_map["beta"] = 1
 
-    if has_residual_blcok:
+    if has_residual_block:
         aux_map["bias_decl"] = "void* ptr_bias = (void*)(${bias_arg}->data);\n"
         aux_map["residual_decl"] = "void* ptr_residual = 
(void*)(${residual_arg}->data);"
         aux_map["tensor_c"] = "ptr_residual"
@@ -496,12 +496,12 @@ def instantiate_conv2d_template(attrs):
         aux_map["tensor_c"] = "ptr_out"
         aux_map["tensor_c_layout"] = "layout_C"
 
-    if has_bias and no_bias_scaling and not has_residual_blcok:
+    if has_bias and no_bias_scaling and not has_residual_block:
         aux_map["alpha_beta"] = "alpha"
     else:
         aux_map["alpha_beta"] = "alpha, beta"
 
-    if has_residual_blcok:
+    if has_residual_block:
         aux_map["additional_args"] = ", static_cast<ElementOutput*>(ptr_bias), 
nullptr, 0, K"
     else:
         aux_map["additional_args"] = ""
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 1675e1f035..eb9f92dad3 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -164,6 +164,7 @@ class EmitGemmInstance:
       ${element_accumulator},
       ${element_epilogue}
     >"""
+
         self.epilogue_no_beta_scaling = """
     ${epilogue_functor}<
       ${element_c},
@@ -172,6 +173,19 @@ class EmitGemmInstance:
       ${element_epilogue},
       cutlass::epilogue::thread::ScaleType::NoBetaScaling
     >"""
+
+        self.epilogue_residual_block = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${element_accumulator},
+      ${element_epilogue},
+      ${element_c},
+      ${epilogue_vector_length},
+      ${activation},
+      ${binary_op},
+      ${unary_op}
+    >"""
+
         self.gemm_template = """
   // Gemm operator ${operation_name}
   using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
@@ -188,13 +202,11 @@ class EmitGemmInstance:
     ${swizzling_functor},
     ${stages},
     ${align_a},
-    ${align_b},
-    ${split_k_serial}
-    ${math_operation}
+    ${align_b}
   >;
 """
 
-    def emit(self, operation, no_beta_scaling=False, batched=False):
+    def emit(self, operation, no_beta_scaling=False, batched=False, 
residual_block_info=False):
         """Instantiate a GEMM kernel from given `operation`."""
         warp_shape = [
             operation.tile_description.threadblock_shape[idx]
@@ -246,22 +258,73 @@ class EmitGemmInstance:
         }
 
         values["kernel_name"] = "GemmBatched" if batched else "Gemm"
-        values["split_k_serial"] = "" if batched else "false,"
 
-        gemm_template = substitute_template(
-            self.gemm_template,
-            {
-                "epilogue": self.epilogue_no_beta_scaling
-                if no_beta_scaling
-                else self.epilogue_default
-            },
-        )
-        return substitute_template(gemm_template, values)
+        if residual_block_info:
+            values["kernel_name"] = "GemmUniversalWithBroadcast"
+            template = substitute_template(
+                self.gemm_template, {"epilogue": self.epilogue_residual_block}
+            )
+            values.update(
+                {
+                    "unary_op": residual_block_info["unary_op"],
+                    "binary_op": residual_block_info["binary_op"],
+                    "activation": residual_block_info["activation"],
+                }
+            )
+        elif no_beta_scaling:
+            template = substitute_template(
+                self.gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
+            )
+        else:
+            template = substitute_template(self.gemm_template, {"epilogue": 
self.epilogue_default})
+
+        return substitute_template(template, values)
 
 
 def instantiate_gemm_template(attrs):
     """Return CUTLASS host code for GEMM based on a template and the provided 
attribute map."""
 
+    argument_template_default = """
+  typename ${kernel}::Arguments arguments{
+   problem_size,
+   {static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
+   {static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
+   {static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
+   {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
+   {${alpha_beta}},
+   ${split_k_slices_or_batch}
+  };
+    """
+
+    # See cutlass/gemm/kernel/gemm_with_fused_epilogue.h
+    # Batched GEMM + residual fusion is not supported for now.
+    argument_template_residual = """
+  typename ${kernel}::Arguments arguments{
+    cutlass::gemm::GemmUniversalMode::kGemm,
+    problem_size,
+    1, // batch_count,
+    {${alpha_beta}},
+    static_cast<ElementInputA*>(ptr_a),
+    static_cast<ElementInputB*>(ptr_b),
+    static_cast<ElementOutput*>(ptr_residual),
+    static_cast<ElementOutput*>(ptr_out),
+    static_cast<ElementOutput*>(ptr_bias),
+    nullptr, // ptr_Tensor
+    0, // batch_stride_A,
+    0, // batch_stride_B,
+    0, // batch_stride_C,
+    0, // batch_stride_D,
+    0, // batch_stride_Vector,
+    0, // batch_stride_Tensor,
+    ${lda},
+    ${ldb},
+    ${ldc},
+    ${ldc},
+    0, // ldv, the stride for bias
+    0, // ldt
+  };
+    """
+
     template = """
   using ElementInputA = ${ElementInputA};
   using ElementInputB = ${ElementInputB};
@@ -280,17 +343,10 @@ def instantiate_gemm_template(attrs):
   void* ptr_a = (void*)(${lhs_arg}->data);
   void* ptr_b = (void*)(${rhs_arg}->data);
   ${bias_decl}
+  ${residual_decl}
   void* ptr_out = (void*)(out0->data);
 
-  typename ${kernel}::Arguments arguments{
-   problem_size,
-   {static_cast<ElementInputA*>(ptr_a), ${lda}}, ${batch_stride_A}
-   {static_cast<ElementInputB*>(ptr_b), ${ldb}}, ${batch_stride_B}
-   {static_cast<ElementOutput*>(${ptr_c}), ${c_stride}}, ${batch_stride_C}
-   {static_cast<ElementOutput*>(ptr_out), ${ldc}}, ${batch_stride_D}
-   {${alpha_beta}},
-   ${split_k_slices_or_batch}
-  };
+  ${argument}
   size_t workspace_size = ${kernel}::get_workspace_size(arguments);
   cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
   ${kernel} gemm_op;
@@ -301,29 +357,31 @@ def instantiate_gemm_template(attrs):
   status = gemm_op();
   CHECK(status == cutlass::Status::kSuccess);
 """
-    has_bias = "bias" in attrs["op_type"]
-    is_gelu = "gelu" in attrs["op_type"]
+    op_type = attrs["op_type"]
+    has_bias = "bias" in op_type
+    is_gelu = "gelu" in op_type
     batched = "batch" in attrs
+    has_residual_block = "residual" in op_type
     aux_map = {"kernel": "Gemm"}
 
     if has_bias:
         aux_map.update(
             {
-                "bias_decl": "void* ptr_c_bias = 
(void*)(${bias_arg}->data);\n",
-                "ptr_c": "ptr_c_bias",
-                "c_stride": "0",
+                "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n",
+                "ptr_c": "ptr_bias",
+                "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"],
             }
         )
     else:
         aux_map.update({"bias_decl": "", "ptr_c": "ptr_out", "c_stride": 
attrs["ldc"]})
 
-    if is_gelu:
+    if is_gelu or has_residual_block:
         # GeLU epilogue does not compile with NoBetaScaling, so we explicitly 
specify the scale.
-        aux_map["beta"] = "1"
+        aux_map["beta"] = 1
     else:
-        aux_map["beta"] = "0"
+        aux_map["beta"] = 0
 
-    if has_bias and not is_gelu:
+    if has_bias and not is_gelu and not has_residual_block:
         aux_map["alpha_beta"] = "alpha"
     else:
         aux_map["alpha_beta"] = "alpha, beta"
@@ -341,7 +399,15 @@ def instantiate_gemm_template(attrs):
     if batched:
         attrs["split_k_slices_or_batch"] = attrs["batch"]
     else:
-        attrs["split_k_slices_or_batch"] = "1"
+        attrs["split_k_slices_or_batch"] = 1
+
+    if has_residual_block:
+        assert not batched, "Residual fusion is supported only for non-batched 
GEMM for now."
+        template = substitute_template(template, {"argument": 
argument_template_residual})
+        aux_map["residual_decl"] = "void* ptr_residual = 
(void*)(${residual_arg}->data);\n"
+    else:
+        template = substitute_template(template, {"argument": 
argument_template_default})
+        aux_map["residual_decl"] = ""
 
     template = substitute_template(template, aux_map)
 
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 6aa4c51221..f5f160a400 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -53,7 +53,36 @@ def create_gemm_operator_with_epilogue(
     if batched:
         swizzling_functor = SwizzlingFunctor.Batched
 
-    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+    if "residual" in op_type:
+        if "hardswish" in op_type:
+            activation = "cutlass::epilogue::thread::HardSwish"
+        elif "silu" in op_type:
+            activation = "cutlass::epilogue::thread::SiLu"
+        elif "sigmoid" in op_type:
+            activation = "cutlass::epilogue::thread::Sigmoid"
+        elif "gelu" in op_type:
+            activation = "cutlass::epilogue::thread::GELU"
+        elif "relu" in op_type:
+            activation = "cutlass::epilogue::thread::ReLu"
+        else:
+            activation = "cutlass::epilogue::thread::Identity"
+
+        binary_op = "cutlass::multiplies" if "residual_multiply" in op_type 
else "cutlass::plus"
+        unary_op = (
+            "cutlass::epilogue::thread::ReLu"
+            if op_type.endswith("relu")
+            else "cutlass::epilogue::thread::Identity"
+        )
+        residual_block_info = {
+            "activation": activation,
+            "binary_op": binary_op,
+            "unary_op": unary_op,
+        }
+        epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+        no_beta_scaling = False
+    else:
+        residual_block_info = None
+        epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
 
     op = GemmOperation(
         tile_description.minimum_compute_capability,
@@ -68,7 +97,12 @@ def create_gemm_operator_with_epilogue(
 
     return (
         op.procedural_name(),
-        EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, 
batched=batched),
+        EmitGemmInstance().emit(
+            op,
+            no_beta_scaling=no_beta_scaling,
+            batched=batched,
+            residual_block_info=residual_block_info,
+        ),
     )
 
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 9833600be0..6b2587a0b0 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -535,7 +535,9 @@ def instantiate_template(func_name, annotations, func_args):
         transposed = "transposed" in func_name
         lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 
0)
         rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 
1)
-        bias_arg_idx = _get_optional_int_annotation(annotations, 
"bias_arg_idx", 2)
+        bias_arg_idx = _get_optional_int_annotation(annotations, 
"bias_arg_idx", None)
+        residual_arg_idx = _get_optional_int_annotation(annotations, 
"residual_arg_idx", None)
+
         lhs_arg = func_args[lhs_arg_idx]
         rhs_arg = func_args[rhs_arg_idx]
         lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
@@ -545,8 +547,12 @@ def instantiate_template(func_name, annotations, 
func_args):
 
         attrs["lhs_arg"] = lhs_arg
         attrs["rhs_arg"] = rhs_arg
-        if len(func_args) > 2:
+
+        if bias_arg_idx is not None:
             attrs["bias_arg"] = func_args[bias_arg_idx]
+        if residual_arg_idx is not None:
+            attrs["residual_arg"] = func_args[residual_arg_idx]
+
         attrs["ElementInputA"] = 
DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]]
         attrs["ElementInputB"] = 
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
         attrs["ElementOutput"] = 
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
@@ -610,20 +616,24 @@ def instantiate_template(func_name, annotations, 
func_args):
         else:
             headers.append("cutlass/gemm/device/gemm.h")
 
+        if "residual" in func_name:
+            
headers.append("cutlass/gemm/device/gemm_universal_with_broadcast.h")
+
         code = instantiate_gemm_template(attrs)
         return CodegenResult(code, headers)
 
     elif "conv2d" in func_name:
         data_arg_idx = _get_optional_int_annotation(annotations, 
"data_arg_idx", 0)
         weight_arg_idx = _get_optional_int_annotation(annotations, 
"weight_arg_idx", 1)
-        bias_arg_idx = _get_optional_int_annotation(annotations, 
"bias_arg_idx", 2)
-        residual_arg_idx = _get_optional_int_annotation(annotations, 
"residual_arg_idx", 3)
+        bias_arg_idx = _get_optional_int_annotation(annotations, 
"bias_arg_idx", None)
+        residual_arg_idx = _get_optional_int_annotation(annotations, 
"residual_arg_idx", None)
 
         attrs["data_arg"] = func_args[data_arg_idx]
         attrs["weight_arg"] = func_args[weight_arg_idx]
-        if len(func_args) > bias_arg_idx:
+
+        if bias_arg_idx is not None:
             attrs["bias_arg"] = func_args[bias_arg_idx]
-        if len(func_args) > residual_arg_idx:
+        if residual_arg_idx is not None:
             attrs["residual_arg"] = func_args[residual_arg_idx]
 
         activation_shape = annotations[f"arg{data_arg_idx}_shape"]
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 29026b194d..e1b9226d68 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -195,17 +195,25 @@ def residual_block_patterns():
     patterns = []
 
     for activation, name_postfix in [(None, ""), ("relax.nn.relu", "_relu")]:
-        for name, pat, arg_pat, _ in conv2d_patterns()[1:]:
-            for bin_op in ["relax.add", "relax.multiply"]:
-                patterns.append(
-                    (
-                        name + "_residual_" + bin_op.split(".")[-1] + 
name_postfix,
-                        *make_residual_block_pattern(
-                            (pat, arg_pat), binary_op=bin_op, 
activation=activation
-                        ),
-                        _check_conv2d,
-                    )
-                )
+        for check, base_patterns in [
+            (_check_conv2d, conv2d_patterns()),
+            (_check_matmul, matmul_patterns()),
+        ]:
+            for name, pat, arg_pat, _ in base_patterns:
+                # Append residual patterns only to those base patterns with 
bias add,
+                # since conv2d or matmul + residual add without bias is 
already supported
+                # via conv2d or matmul + bias patterns (the residual input is 
treated as "bias").
+                if "bias" in name:
+                    for bin_op in ["relax.add", "relax.multiply"]:
+                        patterns.append(
+                            (
+                                name + "_residual_" + bin_op.split(".")[-1] + 
name_postfix,
+                                *make_residual_block_pattern(
+                                    (pat, arg_pat), binary_op=bin_op, 
activation=activation
+                                ),
+                                check,
+                            )
+                        )
 
     return patterns
 
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 02802211a9..de15f7083a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -101,6 +101,7 @@ def get_result_with_relax_cutlass_offload(mod, *args, 
assert_all_bindings_fused=
     assert len(patterns) != 0, "Cannot find cutlass patterns"
 
     mod = partition_for_cutlass(mod)
+
     if assert_all_bindings_fused:
         assert len(mod["main"].body.blocks[0].bindings) == 1
 
@@ -158,8 +159,8 @@ def get_relax_conv2d_module(
                     output = R.emit(activation(output))
                 if residual_bin_op is not None:
                     output = R.emit(residual_bin_op(output, data))
-                if residual_activation is not None:
-                    output = R.emit(residual_activation(output))
+                    if residual_activation is not None:
+                        output = R.emit(residual_activation(output))
                 R.output(output)
 
             R.func_ret_value(frame.output_vars[0])
@@ -169,7 +170,14 @@ def get_relax_conv2d_module(
 
 
 def get_relax_matmul_module(
-    x_shape, y_shape, dtype, transposed_y=False, with_bias=False, 
activation=None
+    x_shape,
+    y_shape,
+    dtype,
+    transposed_y=False,
+    with_bias=False,
+    activation=None,
+    residual_bin_op=None,
+    residual_activation=None,
 ):
     if transposed_y:
         n = y_shape[-2]
@@ -193,6 +201,10 @@ def get_relax_matmul_module(
                     result = R.emit(result + bias)
                 if activation is not None:
                     result = R.emit(activation(result))
+                if residual_bin_op is not None:
+                    result = R.emit(residual_bin_op(result, x))
+                    if residual_activation is not None:
+                        result = R.emit(residual_activation(result))
                 R.output(result)
 
             R.func_ret_value(frame.output_vars[0])
@@ -285,34 +297,40 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, 
epilogue, residual_bloc
 
 
 @pytest.mark.parametrize(
-    "x_shape, y_shape, transpose_y, epilogue",
+    "x_shape, y_shape, transpose_y, epilogue, residual_block",
     [
         # Regular
-        ((32, 6), (6, 16), False, "none"),
-        ((_vars["a"], 6), (6, 16), False, "bias"),
+        ((32, 6), (6, 16), False, "none", "none"),
+        ((_vars["a"], 6), (6, 16), False, "bias", "none"),
         # Transposed
-        ((4, 16), (16, 128), True, "relu"),
-        ((35, 8), (8, 8), True, "gelu"),
+        ((4, 16), (16, 128), True, "relu", "none"),
+        ((35, 8), (8, 8), True, "gelu", "none"),
         # 3D x 3D
-        ((6, 32, 8), (6, 8, 10), False, "bias"),
-        ((6, 32, 8), (6, 8, 10), True, "none"),
-        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
+        ((6, 32, 8), (6, 8, 10), False, "bias", "none"),
+        ((6, 32, 8), (6, 8, 10), True, "none", "none"),
+        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
         # 3D x 2D
-        ((6, 32, 8), (8, 10), False, "none"),
-        ((_vars["a"], 32, 8), (8, 10), False, "bias"),
-        ((10, 16, 8), (8, 10), True, "relu"),
+        ((6, 32, 8), (8, 10), False, "none", "none"),
+        ((_vars["a"], 32, 8), (8, 10), False, "bias", "none"),
+        ((10, 16, 8), (8, 10), True, "relu", "none"),
         # 2D x 3D
-        ((32, 8), (10, 8, 10), False, "relu"),
-        ((32, 8), (_vars["a"], 8, 10), True, "gelu"),
+        ((32, 8), (10, 8, 10), False, "relu", "none"),
+        ((32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
         # ND x 2D
-        ((3, 6, 32, 8), (8, 10), False, "bias"),
-        ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none"),
+        ((3, 6, 32, 8), (8, 10), False, "bias", "none"),
+        ((_vars["a"], _vars["b"], 6, 32, 8), (8, 10), False, "none", "none"),
         # 2D x ND
-        ((32, 8), (5, 3, 8, 10), False, "gelu"),
+        ((32, 8), (5, 3, 8, 10), False, "gelu", "none"),
         # ND x ND
-        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
-        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu"),
-        ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none"),
+        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu", "none"),
+        ((3, 2, 4, 16, 15), (1, 1, 15, 2), True, "gelu", "none"),
+        ((1, 1, 16, 15), (3, 2, _vars["a"], 15, 2), False, "none", "none"),
+        # Residual
+        ((32, 8), (8, 8), False, "bias", "add"),
+        ((4, 16), (16, 16), True, "relu", "add_relu"),
+        # Residual fusion without bias - this is supported via the matmul + 
bias pattern
+        # where bias == residual input
+        ((4, 16), (16, 16), False, "none", "add"),
     ],
 )
 @pytest.mark.parametrize(
@@ -326,6 +344,7 @@ def test_matmul_offload(
     y_shape,
     transpose_y,
     epilogue,
+    residual_block,
     dtype,
 ):
     with_bias, activation = _epilogue_table[epilogue]
@@ -346,6 +365,8 @@ def test_matmul_offload(
         bias = None
         args = (x, y)
 
+    residual_bin_op, residual_activation = 
_residual_block_table[residual_block]
+
     mod = get_relax_matmul_module(
         x_shape,
         y_shape,
@@ -353,6 +374,8 @@ def test_matmul_offload(
         with_bias=with_bias,
         transposed_y=transpose_y,
         activation=activation,
+        residual_bin_op=residual_bin_op,
+        residual_activation=residual_activation,
     )
     out = get_result_with_relax_cutlass_offload(mod, *args)
     ref = build_and_run(mod, args, "llvm", legalize=True)

Reply via email to