This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c448c50cd1 [Unity][CUTLASS] Support more residual input shape (#14968)
c448c50cd1 is described below

commit c448c50cd19cc892b77bbcb68dc20e995fd67f40
Author: masahi <[email protected]>
AuthorDate: Sun May 28 11:06:12 2023 +0900

    [Unity][CUTLASS] Support more residual input shape (#14968)
    
    * improved cutlass residual fusion
    
    * update cutlass residual test
    
    * update residual check
    
    * fix residual check
    
    * clean
    
    * fix
    
    * minor
---
 python/tvm/contrib/cutlass/build.py            | 48 ++++++++++-------
 python/tvm/contrib/cutlass/conv2d_operation.py | 25 ++++++++-
 python/tvm/contrib/cutlass/gen_tensor_op.py    |  3 ++
 python/tvm/relax/backend/contrib/cutlass.py    | 11 +++-
 tests/python/relax/test_codegen_cutlass.py     | 74 +++++++++++++++++---------
 5 files changed, 111 insertions(+), 50 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 1d105bbe82..07583a4851 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -668,26 +668,34 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             use_multiprocessing=use_multiprocessing,
         )
 
-        return f.with_attrs(
-            {
-                "op_type": op_type,
-                "data_arg_idx": arg_idx["lhs"],
-                "weight_arg_idx": arg_idx["rhs"],
-                "bias_arg_idx": arg_idx.get("bias"),
-                "residual_arg_idx": arg_idx.get("residual"),
-                "arg0_dtype": data_dtype,
-                "arg1_dtype": weight_dtype,
-                "ret_dtype": out_dtype,
-                "arg0_shape": d_shape,
-                "arg1_shape": w_shape,
-                "ret_shape": out_shape,
-                "strides": strides,
-                "padding": padding,
-                "dilation": dilation,
-                "cutlass_op_name": op_name,
-                "cutlass_op_def": op_def,
-            }
-        )
+        attrs = {
+            "op_type": op_type,
+            "data_arg_idx": arg_idx["lhs"],
+            "weight_arg_idx": arg_idx["rhs"],
+            "bias_arg_idx": arg_idx.get("bias"),
+            "residual_arg_idx": arg_idx.get("residual"),
+            "arg0_dtype": data_dtype,
+            "arg1_dtype": weight_dtype,
+            "ret_dtype": out_dtype,
+            "arg0_shape": d_shape,
+            "arg1_shape": w_shape,
+            "ret_shape": out_shape,
+            "strides": strides,
+            "padding": padding,
+            "dilation": dilation,
+            "cutlass_op_name": op_name,
+            "cutlass_op_def": op_def,
+        }
+
+        residual_arg = arg_idx.get("residual")
+
+        if residual_arg:
+            residual_shape = signature[f"arg{residual_arg}_shape"]
+            attrs["residual_shape"] = residual_shape
+        elif "residual" in op_type:
+            attrs["residual_shape"] = d_shape
+
+        return f.with_attrs(attrs)
 
     def handle_matmul(self, f, op_type):
         """Tune and annotate a dense op."""
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index f2d2f01276..8ae9b1414d 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -394,11 +394,12 @@ def instantiate_conv2d_template(attrs):
   auto activation_shape = TensorNHWC::packed(cutlass::make_Coord(N, H, W, C));
   auto weight_shape = TensorNHWC::packed(cutlass::make_Coord(K, R, S, C));
   auto output_shape = TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K));
+  ${residual_shape_decl}
 
   TensorNHWC layout_A(${A_shape});
   TensorNHWC layout_B(${B_shape});
   TensorNHWC layout_C(${C_shape});
-  TensorNHWC layout_D(${C_shape});
+  TensorNHWC layout_D(${D_shape});
 
   using ElementOutput = ${ElementOutput};
  cutlass::TensorRef<ElementOutput, TensorNHWC> tensor_c{static_cast<ElementOutput*>(${tensor_c}), ${tensor_c_layout}};
@@ -506,18 +507,38 @@ def instantiate_conv2d_template(attrs):
     else:
         aux_map["additional_args"] = ""
 
+    aux_map["residual_shape_decl"] = ""
+
     if is_wgrad:
         aux_map["A_shape"] = "output_shape"
         aux_map["B_shape"] = "activation_shape"
         aux_map["C_shape"] = "weight_shape"
+        aux_map["D_shape"] = "weight_shape"
     elif is_dgrad:
         aux_map["A_shape"] = "output_shape"
         aux_map["B_shape"] = "weight_shape"
         aux_map["C_shape"] = "activation_shape"
+        aux_map["D_shape"] = "activation_shape"
     else:
         aux_map["A_shape"] = "activation_shape"
         aux_map["B_shape"] = "weight_shape"
-        aux_map["C_shape"] = "output_shape"
+        aux_map["D_shape"] = "output_shape"
+
+        if has_residual_block:
+            res_shape = list(attrs.pop("residual_shape"))
+            shape_str = f"cutlass::make_Coord({res_shape[0]}, {res_shape[1]}, {res_shape[2]}, K)"
+            aux_map[
+                "residual_shape_decl"
+            ] = f"auto residual_shape = TensorNHWC::packed({shape_str});"
+            aux_map["C_shape"] = "residual_shape"
+
+            if res_shape == [int(attrs[c]) for c in ["N", "H", "W", "K"]]:
+                aux_map["tensor_c_layout"] = "layout_C"
+            else:
+                # bias-like residual input
+                aux_map["tensor_c_layout"] = "cutlass::layout::TensorNHWC::Stride(0)"
+        else:
+            aux_map["C_shape"] = "output_shape"
 
     if use_split_k:
         aux_map["ElementOutput"] = "EpilogueOutputOp::ElementOutput"
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index abc13d3570..0caf62043e 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -689,6 +689,9 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["split_k_mode"] = "kSerial"
             attrs["split_k_slices"] = 1
 
+        if "residual_shape" in annotations:
+            attrs["residual_shape"] = annotations["residual_shape"]
+
         code = instantiate_conv2d_template(attrs)
         return CodegenResult(code, headers)
 
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index dd2739fd98..dffd7c401c 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -47,6 +47,10 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
     )
 
 
+def _shape_1d(shape):
+    return reduce(operator.mul, shape, 1)
+
+
 def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
     """
     Check whether intermediate variables in the region to be fused are used outside
@@ -104,7 +108,10 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
         shape1 = [int(s) for s in root_var.struct_info.shape]
         shape2 = [int(s) for s in residual.struct_info.shape]
 
-        if shape1 != shape2:
+        out_channel = shape1[-1]
+        is_bias_like = lambda shape: (shape[-1] == out_channel and _shape_1d(shape) == out_channel)
+
+        if shape1 != shape2 and not is_bias_like(shape2):
             return False
 
     return True
@@ -402,7 +409,7 @@ class WorkspaceAnnotator(PyExprMutator):
         if "attention" in f.attrs["Composite"]:
             # Workspace is needed only for larger head sizes, but for simplicity we always allocate.
             out_dtype = f.ret_struct_info.dtype
-            out_size_1d = reduce(operator.mul, f.ret_struct_info.shape, 1)
+            out_size_1d = _shape_1d(f.ret_struct_info.shape)
             # This needs to be in sync with the actual value that the kernel expects.
             workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
             return f.with_attr("WorkspaceSize", workspace_size_bytes)
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 339b22ca9b..8a1675ad35 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1011,35 +1011,57 @@ def test_attention_rewrite_offload(attention_rewrite_size):
         tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
 
 
-def test_invalid_residual():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(
-            x: R.Tensor((2, 64, 64, 8), dtype="float16"),
-            w: R.Tensor((8, 3, 3, 8), dtype="float16"),
-            bias: R.Tensor((1, 1, 8), dtype="float16"),
-            residual: R.Tensor((2, 1, 1, 8), dtype="float16"),
-        ) -> R.Tensor((1, 256, 64, 64), dtype="float16"):
-            with R.dataflow():
-                conv = R.nn.conv2d(
-                    x,
-                    w,
-                    padding=[1, 1, 1, 1],
-                    out_dtype="float16",
-                    data_layout="NHWC",
-                    kernel_layout="OHWI",
+def test_conv2d_residual_broadcast():
+    data_shape = (2, 64, 64, 8)
+    weight_shape = (8, 3, 3, 8)
+    dtype = "float16"
+
+    def get_mod(residual_batch):
+        with IRBuilder() as builder:
+            with relax_builder.function():
+                R.func_name("main")
+                data = R.arg("data", R.Tensor(data_shape, dtype))
+                weight = R.arg("weight", R.Tensor(weight_shape, dtype))
+                bias = R.arg("bias", R.Tensor((1, 1, weight_shape[0]), dtype))
+                residual = R.arg(
+                    "residual", R.Tensor((residual_batch, 1, 1, weight_shape[0]), dtype)
                 )
-                bias_out = R.add(conv, bias)
-                out = R.add(bias_out, residual)
-                R.output(out)
-            return out
 
-    rewritten = partition_for_cutlass(Module)
-    func_names = [gv.name_hint for gv in rewritten.functions.keys()]
+                with R.dataflow() as frame:
+                    output = R.emit(
+                        R.nn.conv2d(
+                            data,
+                            weight,
+                            out_dtype=dtype,
+                            padding=(1, 1),
+                            data_layout="NHWC",
+                            kernel_layout="OHWI",
+                        )
+                    )
+                    output = R.emit(output + bias)
+                    output = R.emit(R.nn.relu(output))
+                    output = R.emit(R.add(output, residual))
+                    R.output(output)
+
+                R.func_ret_value(frame.output_vars[0])
+
+        func = builder.get()
+        return tvm.IRModule({"main": func})
 
-    assert "fused_relax_nn_conv2d_relax_add_relax_add_cutlass" not in func_names
-    assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+    low = -1
+    high = 1
+
+    residual_batch = 1
+    mod = get_mod(residual_batch)
+    data = np.random.randint(low, high, size=data_shape).astype(dtype)
+    weight = np.random.randint(low, high, size=weight_shape).astype(dtype)
+    bias = np.random.randint(low, high, size=(1, 1, weight_shape[0])).astype(dtype)
+    bias2 = np.random.randint(low, high, size=(residual_batch, 1, 1, weight_shape[0])).astype(dtype)
+
+    args = [data, weight, bias, bias2]
+    out = get_result_with_relax_cutlass_offload(mod, *args)
+    ref = build_and_run(mod, args, "llvm")
+    tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
 @pytest.mark.parametrize(

Reply via email to