This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e7f3648  [CUTLASS] Residual connection fusion (#9820)
e7f3648 is described below

commit e7f36487dfdb6c4b7b544be155d3869002d7281b
Author: masahi <[email protected]>
AuthorDate: Tue Jan 4 02:36:42 2022 +0900

    [CUTLASS] Residual connection fusion (#9820)
    
    * [CUTLASS] Support residual block fusion for conv2d
    
    commit d4a78a3e13530974e852b4c0480b7c8d0f792e68
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:33:41 2021 +0900
    
        fixed residual block check condition
    
    commit 6ee5a3913333e8ba2d5d0ed6842a58fe37baa547
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:25:04 2021 +0900
    
        minor fix
    
    commit 8af8b3078f11ee293d2e22d9e37e715c617ffb75
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:18:50 2021 +0900
    
        remove SimplifyExpr pass
    
    commit 20ae2d874917c69fabc6fcf03a3d47aff98eee91
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:16:46 2021 +0900
    
        fix bad merge
    
    commit 17eed222c5e69e7863c95563b638e5390c634b1b
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:13:53 2021 +0900
    
        black
    
    commit fda151b524cb28581256befa74575bbfa23efa4c
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 16:09:45 2021 +0900
    
        Support residual block fusion
    
    commit ce9d52fd629d6119abdd471b00ff6a79223d6752
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 15:56:32 2021 +0900
    
        Remove SimplifyExpr pass from the pipeline (makes DETR result nan)
    
    commit d3b681d95977b6fc0965a0a3ec8af3f866bd9e91
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 15:47:07 2021 +0900
    
        fix no_beta_scaling values
    
    commit 87b36dbbb11adb582ffb628fc6ad62668dcdee7e
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 14:59:40 2021 +0900
    
        fill in TODO doc
    
    commit fd67595831c7b8741f30577bc91488bcce34a76a
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Dec 23 14:31:06 2021 +0900
    
        Refactor cutlass kernel generation and selection
    
    * do not try to support broadcast binary op
    
    * add comments
    
    * remove residual input shape check
---
 3rdparty/cutlass                               |  2 +-
 python/tvm/contrib/cutlass/conv2d_operation.py | 45 +++++++++++----
 python/tvm/contrib/cutlass/gen_conv2d.py       | 31 +++++++++-
 python/tvm/contrib/cutlass/library.py          |  2 +
 python/tvm/relay/op/contrib/cutlass.py         | 48 +++++++++++++++-
 src/relay/backend/contrib/cutlass/codegen.cc   | 79 +++++++++++++++++++++++---
 src/relay/backend/utils.h                      | 19 ++++++-
 tests/python/contrib/test_cutlass.py           | 55 ++++++++++++++++--
 8 files changed, 252 insertions(+), 29 deletions(-)
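
A minimal Relay sketch of the pattern this commit fuses into a single CUTLASS
kernel (shapes are borrowed from the new test in test_cutlass.py below; the
residual input here simply reuses `data`):

    from tvm import relay

    data = relay.var("data", shape=(16, 16, 32, 32), dtype="float16")
    weight = relay.var("weight", shape=(16, 16, 3, 3), dtype="float16")
    bias = relay.var("bias", shape=(16,), dtype="float16")

    conv2d = relay.nn.conv2d(
        data, weight, kernel_size=(3, 3), channels=16, padding=(1, 1), out_dtype="float16"
    )
    bias_add = relay.nn.bias_add(conv2d, bias)
    # conv2d -> bias_add -> add(residual) -> relu: one of the forms that
    # partition_for_cutlass() can now offload as a single fused kernel.
    out = relay.nn.relu(bias_add + data)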

diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index dceabd4..c2ee13a 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit dceabd4c5a2aa8cb29ce5a05311a57519baadddc
+Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 1c7f9a3..5318cc7 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -150,6 +150,7 @@ class EmitConv2dInstance:
       ${element_accumulator},
       ${element_epilogue}
     >"""
+
         self.epilogue_no_beta_scaling = """
     ${epilogue_functor}<
       ${element_c},
@@ -159,10 +160,22 @@ class EmitConv2dInstance:
       cutlass::epilogue::thread::ScaleType::NoBetaScaling
     >"""
 
+        self.epilogue_residual_block = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${element_accumulator},
+      ${element_epilogue},
+      ${element_c},
+      ${epilogue_vector_length},
+      ${activation},
+      ${binary_op},
+      ${unary_op}
+    >"""
+
         self.template = """
  // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
   using ${operation_name} =
-  typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+  typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
     ${element_a},
     ${layout_a},
     ${element_b},
@@ -186,7 +199,7 @@ class EmitConv2dInstance:
   >::Kernel;
 """
 
-    def emit(self, operation, no_beta_scaling=False):
+    def emit(self, operation, no_beta_scaling=False, residual_block_info=None):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
@@ -246,14 +259,26 @@ class EmitConv2dInstance:
             ],
             "align_a": str(operation.A.alignment),
             "align_b": str(operation.B.alignment),
+            "conv_kernel_postfix": "",
         }
 
-        template = substitute_template(
-            self.template,
-            {
-                "epilogue": self.epilogue_no_beta_scaling
-                if no_beta_scaling
-                else self.epilogue_default
-            },
-        )
+        if residual_block_info:
+            template = substitute_template(
+                self.template, {"epilogue": self.epilogue_residual_block}
+            )
+            values.update(
+                {
+                    "unary_op": residual_block_info["unary_op"],
+                    "binary_op": residual_block_info["binary_op"],
+                    "activation": residual_block_info["activation"],
+                    "conv_kernel_postfix": "WithBroadcast",
+                }
+            )
+        elif no_beta_scaling:
+            template = substitute_template(
+                self.template, {"epilogue": self.epilogue_no_beta_scaling}
+            )
+        else:
+            template = substitute_template(self.template, {"epilogue": self.epilogue_default})
+
         return substitute_template(template, values)
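
For reference, substitute_template (defined alongside this file in
python/tvm/contrib/cutlass) fills in the ${...} placeholders used by the
templates above; a minimal sketch of its behavior, not the exact
implementation:

    import re

    def substitute_template(template, values):
        # Replace each ${key} placeholder in `template` with values[key].
        text = template
        for key, value in values.items():
            text = re.sub(r"\$\{%s\}" % key, value, text)
        return text
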
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index 4e4a7b2..39db9fd 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
     Instantiate a cutlass kernel from the given configuration,
     along with the epilogue functor.
     """
-    epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+    if "residual" in op_type:
+        activation_map = {
+            "cutlass.conv2d_bias_hardswish": 
"cutlass::epilogue::thread::HardSwish",
+            "cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
+            "cutlass.conv2d_bias_sigmoid": 
"cutlass::epilogue::thread::Sigmoid",
+            "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
+            "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
+        }
+        prefix = op_type[: op_type.find("_residual")]
+        activation = activation_map[prefix]
+        binary_op = "cutlass::multiplies" if "residual_multiply" in op_type 
else "cutlass::plus"
+        unary_op = (
+            "cutlass::epilogue::thread::ReLu"
+            if op_type.endswith("relu")
+            else "cutlass::epilogue::thread::Identity"
+        )
+        residual_block_info = {
+            "activation": activation,
+            "binary_op": binary_op,
+            "unary_op": unary_op,
+        }
+        epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+        no_beta_scaling = False
+    else:
+        residual_block_info = None
+        epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
 
     element_a, element_b, element_c, element_epilogue = data_type
 
@@ -62,7 +87,9 @@ def create_conv2d_operator_with_epilogue(
     )
 
     name = op.procedural_name()
-    opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
+    opdef = EmitConv2dInstance().emit(
+        op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
+    )
 
     return name, opdef
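
A worked example of the string-driven dispatch above, for the generated
pattern name "cutlass.conv2d_bias_silu_residual_multiply_relu":

    op_type = "cutlass.conv2d_bias_silu_residual_multiply_relu"
    prefix = op_type[: op_type.find("_residual")]  # "cutlass.conv2d_bias_silu"
    # activation -> "cutlass::epilogue::thread::SiLu"  (activation_map[prefix])
    # binary_op  -> "cutlass::multiplies"              ("residual_multiply" in op_type)
    # unary_op   -> "cutlass::epilogue::thread::ReLu"  (op_type ends with "relu")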
 
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index efc5dd5..08cdb32 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationSigmoid = enum_auto()
     LinearCombinationSilu = enum_auto()
     LinearCombinationHardSwish = enum_auto()
+    LinearCombinationResidualBlock = enum_auto()
 
 
 EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ EpilogueFunctorTag = {
     EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
     EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
     EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
+    EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
 }
 
 
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index cbbc45a..31f0408 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
+from functools import partial
 from tvm import relay
 from tvm.ir.transform import Sequential, PassContext
 from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
     return conv2d_out
 
 
+def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
+    """Add pattern for residual blocks."""
+    residual_input = wildcard()
+    binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
+        residual_input, tensor_op_out
+    )
+
+    if with_act is not None and with_act == "relu":
+        return is_op("nn.relu")(binary_out)
+
+    return binary_out
+
+
 def check_dtype(lhs, rhs):
     """Check if dtypes in the given workload are supported by CUTLASS."""
     # Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
     return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
 
 
+def check_conv2d_residual(call, binary_op):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    conv2d = get_root_call(call, "nn.conv2d")
+    if not check_conv2d(call):
+        return False
+
+    residual_binop = get_root_call(call, binary_op)
+    lhs = residual_binop.args[0]
+    rhs = residual_binop.args[1]
+
+    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
+    # residual binary op and the root conv2d of this pattern.
+    # If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
+    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
+        return False
+
+    return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape))
+
+
 def partition_for_cutlass(mod, params=None):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -189,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
 
-    cutlass_patterns = dense_patterns + conv2d_patterns
+    residual_block_patterns = []
+
+    for with_act, postfix in [("relu", "_relu"), (None, "")]:
+        for name, pat, _ in conv2d_patterns[:-1]:
+            for bin_op in ["add", "multiply"]:
+                residual_block_patterns.append(
+                    (
+                        name + "_residual_" + bin_op + postfix,
+                        make_residual_block_pattern(pat, bin_op, with_act=with_act),
+                        partial(check_conv2d_residual, binary_op=bin_op),
+                    )
+                )
+
+    cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns
 
     if params is not None:
         mod["main"] = bind_params_by_name(mod["main"], params)
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index a87ba2f..dc03eea 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -258,7 +258,7 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
 }
 
 std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
-                     const std::vector<std::string>& func_args) {
+                     const std::vector<std::string>& func_args, bool has_residual_block = false) {
   bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
   bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
                          attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
@@ -268,8 +268,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   CutlassPrint(conv2d_decl, "using ElementInputA = " + 
attrs.at("ElementInputA") + ";\n");
   CutlassPrint(conv2d_decl, "using ElementInputB = " + 
attrs.at("ElementInputB") + ";\n");
   CutlassPrint(conv2d_decl, "using ElementOutput = " + 
attrs.at("ElementOutput") + ";\n");
-
   CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + 
attrs.at("ElementOutput") + ";\n");
+
   CutlassPrint(conv2d_decl, attrs.at("op_def"));
   CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
                                 " = 
cutlass::conv::device::ImplicitGemmConvolution<" +
@@ -308,14 +308,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   ICHECK(func_args.size() >= 2);
   CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + 
"->data);\n");
   CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + 
"->data);\n");
-  if (has_bias) {
+  if (has_residual_block) {
+    ICHECK(func_args.size() >= 4);
+    CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + 
"->data);\n");
+    CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + 
"->data);\n");
+  } else if (has_bias) {
     ICHECK(func_args.size() >= 3);
     CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + 
"->data);\n");
   }
 
   CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
   CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = 
ElementComputeEpilogue(1);\n");
-  if (has_bias && no_bias_scaling) {
+  if (has_bias && no_bias_scaling && !has_residual_block) {
     CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = 
ElementComputeEpilogue(0);\n");
   } else {
     CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = 
ElementComputeEpilogue(1);\n");
@@ -326,24 +330,38 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   CutlassPrint(conv2d_decl,
                "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, 
R, S, C)));\n");
   CutlassPrint(conv2d_decl,
-               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, 
P, Q, K)));\n");
+               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, 
P, Q, K)));\n\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, 
P, Q, K)));\n\n");
+
   CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
   CutlassPrint(conv2d_decl, " problem_size,\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), 
layout_A},\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), 
layout_B},\n");
-  if (has_bias) {
+
+  if (has_residual_block) {
+    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual), 
layout_C},\n");
+  } else if (has_bias) {
     CutlassPrint(
         conv2d_decl,
         " {static_cast<ElementOutput*>(ptr_c_bias), 
cutlass::layout::TensorNHWC::Stride(0)},\n");
   } else {
-    CutlassPrint(conv2d_decl, " 
{static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out), 
layout_C},\n");
   }
-  CutlassPrint(conv2d_decl, " 
{static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  if (has_bias && no_bias_scaling) {
+
+  CutlassPrint(conv2d_decl, " 
{static_cast<ElementOutput*>(ptr_out),layout_D},\n");
+
+  if (has_residual_block) {
+    CutlassPrint(conv2d_decl, "{alpha, beta},\n");
+    CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n");  // 
split_k_slices
+    CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
+    CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
+  } else if (has_bias && no_bias_scaling) {
     CutlassPrint(conv2d_decl, " {alpha}\n};\n");
   } else {
     CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
   }
+
   CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
 
   CutlassPrint(conv2d_decl, "size_t workspace_size = 
conv2d_op.get_workspace_size(arguments);\n");
@@ -432,6 +450,21 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
     return arg_names;
   }
 
+  bool IsConv2dResidualBlock(const std::string& func_name) {
+    return func_name.find("conv2d") != std::string::npos &&
+           func_name.find("residual") != std::string::npos;
+  }
+
+  // Is node `x` an ancestor of `y`?
+  bool IsAncestor(const CallNode* x, const CallNode* y) {
+    if (x == y) return true;
+    for (auto arg : y->args) {
+      const CallNode* arg_ptr = arg.as<CallNode>();
+      if (arg_ptr && IsAncestor(x, arg_ptr)) return true;
+    }
+    return false;
+  }
+
   GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
                                                    const CallNode* caller) {
     const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
@@ -515,6 +548,30 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
           GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (IsConv2dResidualBlock(pattern_name.value())) {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      bool has_relu = current_call->args.size() == 1;
+      const CallNode* binop = has_relu ? current_call->args[0].as<CallNode>() : current_call;
+      ICHECK(binop->args.size() == 2);
+      // Figure out which of the first or second argument corresponds to the residual input
+      // The root conv2d call can be reached via the other input of the binary op
+      int residual_index;
+      if (binop->args[1].as<VarNode>()) {
+        residual_index = 1;
+      } else if (binop->args[0].as<VarNode>()) {
+        residual_index = 0;
+      } else {
+        const CallNode* lhs = binop->args[0].as<CallNode>();
+        const CallNode* rhs = binop->args[1].as<CallNode>();
+        ICHECK(lhs && rhs);
+        // The residual input should be an ancestor of the non-residual input
+        residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
+      }
+      const auto* non_residual_input = binop->args[!residual_index].as<CallNode>();
+      const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
+      ICHECK(conv2d_call);
+      return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }
 
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -560,6 +617,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+    } else if (IsConv2dResidualBlock(func_name)) {
+      ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true);
     } else if (func_name.find("conv2d") != std::string::npos) {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }
@@ -623,6 +682,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
     code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_silu.h>\n";
     code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
+    code_stream_ << "#include 
<cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
+    code_stream_ << "#include 
<cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";
 
     ICHECK(ref->IsInstance<FunctionNode>());
     auto res = GenCutlassFunc(Downcast<Function>(ref));
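
The residual-input disambiguation in GenerateCompositeFunctionCall boils down
to the IsAncestor check; a Python sketch of the same idea (illustrative only,
the authoritative logic is the C++ above):

    def is_ancestor(x, y):
        # x is an ancestor of y if x can be reached by recursively walking
        # y's call arguments, i.e. the output of x feeds into y.
        if x is y:
            return True
        return any(is_ancestor(x, arg) for arg in getattr(y, "args", []))

When neither argument of the binary op is a plain variable, the argument that
is an ancestor of the other one must be the residual input, since the residual
input feeds into the conv2d branch and never the other way around.
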
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index df25a86..658283b 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -389,7 +389,6 @@ inline bool IsOp(const CallNode* call, const std::string& op_name) {
  * "nn.relu"}
  * \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
  */
-
 inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
                                    const std::vector<std::string>& expected_op_names) {
   ICHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
@@ -406,6 +405,24 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
 }
 
 /*!
+ * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d)).
+ * Unlike the previous definition, it does not verify operator names of intermediate nodes.
+ * Instead, it recursively visits child nodes until it finds a call node with the given op_name.
+ * \param current_call A Relay call node.
+ * \param op_name The name of an op to look for, such as "nn.conv2d".
+ * \return A CallNode corresponding to the root op with the given op_name
+ */
+inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) {
+  if (current_call == nullptr) return nullptr;
+  if (IsOp(current_call, op_name)) return current_call;
+
+  ICHECK_GT(current_call->args.size(), 0);
+
+  const auto* next_call = current_call->args[0].as<CallNode>();
+  return GetRootCall(next_call, op_name);
+}
+
+/*!
  * \brief Get the external symbol of the Relay function name.
  *
  * \param func The provided function.
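
A Python sketch of what this new overload does (illustrative; it loosely
mirrors the C++ above rather than reproducing its checks):

    import tvm
    from tvm import relay

    def get_root_call(call, op_name):
        # Follow the first argument down a fused call chain until a call to
        # `op_name` is found, e.g. the conv2d in relu(add(conv2d(...), x)).
        if call is None:
            return None
        if isinstance(call.op, tvm.ir.Op) and call.op.name == op_name:
            return call
        arg0 = call.args[0] if call.args else None
        return get_root_call(arg0 if isinstance(arg0, relay.Call) else None, op_name)
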
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index b2bdb8c..54738dd 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -130,6 +130,17 @@ def get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype="float16"):
     return relay.nn.bias_add(conv2d, bias)
 
 
+def silu(x):
+    return x * relay.sigmoid(x)
+
+
+def hardswish(x, out_dtype="float16"):
+    return x * (
+        relay.clip(x + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
+        / relay.const(6, dtype=out_dtype)
+    )
+
+
 def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
     return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
 
@@ -140,15 +151,29 @@ def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
 
 def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
     conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
-    return conv_out * relay.sigmoid(conv_out)
+    return silu(conv_out)
 
 
 def get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding, out_dtype="float16"):
-    conv2d_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
-    return conv2d_out * (
-        relay.clip(conv2d_out + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
-        / relay.const(6, dtype=out_dtype)
+    conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
+    return hardswish(conv_out, out_dtype)
+
+
+def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding, out_dtype="float16"):
+    data = relay.var("data", shape=d_shape, dtype="float16")
+    weight = relay.var("weight", shape=w_shape, dtype="float16")
+    bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
+    out_channel = w_shape[0]
+    conv2d = relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=w_shape[2:],
+        channels=out_channel,
+        padding=padding,
+        out_dtype=out_dtype,
     )
+    bias_add = relay.nn.bias_add(conv2d, bias)
+    return bias_add, data
 
 
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False):
@@ -492,5 +517,25 @@ def test_conv2d_fusion():
     )
 
 
+def test_conv2d_residual_block():
+    d_shape = (16, 16, 32, 32)
+    w_shape = (16, 16, 3, 3)
+    padding = (1, 1)
+
+    bias_add, residual_input = get_conv2d_nchw_bias_residual(d_shape, w_shape, padding)
+
+    for func, tol in [
+        (relay.nn.relu(bias_add + residual_input), 1e-5),
+        (relay.nn.relu(bias_add) + residual_input, 1e-5),
+        (relay.sigmoid(bias_add) * residual_input, 1e-5),
+        (relay.nn.relu(silu(bias_add) * residual_input), 1e-5),
+        # HardSwish requires a higher tolerance, apparently due to how the
+        # residual block epilogue is vectorized in cutlass.
+        # TODO(masahi): Investigate this issue
+        (relay.nn.relu(hardswish(bias_add) + residual_input), 1e-3),
+    ]:
+        verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol, run_benchmark=False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
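
Assuming a CUTLASS-enabled TVM build and an sm_80 (Ampere) GPU, per the sm=80
argument above, the new test can also be run on its own with:

    pytest tests/python/contrib/test_cutlass.py::test_conv2d_residual_block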
