This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7f3648 [CUTLASS] Residual connection fusion (#9820)
e7f3648 is described below
commit e7f36487dfdb6c4b7b544be155d3869002d7281b
Author: masahi <[email protected]>
AuthorDate: Tue Jan 4 02:36:42 2022 +0900
[CUTLASS] Residual connection fusion (#9820)
* [CUTLASS] Support residual block fusion for conv2d
commit d4a78a3e13530974e852b4c0480b7c8d0f792e68
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:33:41 2021 +0900
fixed residual block check condition
commit 6ee5a3913333e8ba2d5d0ed6842a58fe37baa547
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:25:04 2021 +0900
minor fix
commit 8af8b3078f11ee293d2e22d9e37e715c617ffb75
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:18:50 2021 +0900
remove SimplifyExpr pass
commit 20ae2d874917c69fabc6fcf03a3d47aff98eee91
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:16:46 2021 +0900
fix bad merge
commit 17eed222c5e69e7863c95563b638e5390c634b1b
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:13:53 2021 +0900
black
commit fda151b524cb28581256befa74575bbfa23efa4c
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 16:09:45 2021 +0900
Support residual block fusion
commit ce9d52fd629d6119abdd471b00ff6a79223d6752
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 15:56:32 2021 +0900
Remove SimplifyExpr pass from the pipeline (makes DETR result nan)
commit d3b681d95977b6fc0965a0a3ec8af3f866bd9e91
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 15:47:07 2021 +0900
fix no_beta_scaling values
commit 87b36dbbb11adb582ffb628fc6ad62668dcdee7e
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 14:59:40 2021 +0900
fill in TODO doc
commit fd67595831c7b8741f30577bc91488bcce34a76a
Author: Masahiro Masuda <[email protected]>
Date: Thu Dec 23 14:31:06 2021 +0900
Refactor cutlass kernel generation and selection
* do not try to support broadcast binary op
* add comments
* remove residual input shape check
---
3rdparty/cutlass | 2 +-
python/tvm/contrib/cutlass/conv2d_operation.py | 45 +++++++++++----
python/tvm/contrib/cutlass/gen_conv2d.py | 31 +++++++++-
python/tvm/contrib/cutlass/library.py | 2 +
python/tvm/relay/op/contrib/cutlass.py | 48 +++++++++++++++-
src/relay/backend/contrib/cutlass/codegen.cc | 79 +++++++++++++++++++++++---
src/relay/backend/utils.h | 19 ++++++-
tests/python/contrib/test_cutlass.py | 55 ++++++++++++++++--
8 files changed, 252 insertions(+), 29 deletions(-)
diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index dceabd4..c2ee13a 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit dceabd4c5a2aa8cb29ce5a05311a57519baadddc
+Subproject commit c2ee13a0fe99241b0e798ce647acf98e237f1d0c
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py
b/python/tvm/contrib/cutlass/conv2d_operation.py
index 1c7f9a3..5318cc7 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -150,6 +150,7 @@ class EmitConv2dInstance:
${element_accumulator},
${element_epilogue}
>"""
+
self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
@@ -159,10 +160,22 @@ class EmitConv2dInstance:
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""
+ self.epilogue_residual_block = """
+ ${epilogue_functor}<
+ ${element_c},
+ ${element_accumulator},
+ ${element_epilogue},
+ ${element_c},
+ ${epilogue_vector_length},
+ ${activation},
+ ${binary_op},
+ ${unary_op}
+ >"""
+
self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance
"${operation_name}"
using ${operation_name} =
- typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+ typename
cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
${element_a},
${layout_a},
${element_b},
@@ -186,7 +199,7 @@ class EmitConv2dInstance:
>::Kernel;
"""
- def emit(self, operation, no_beta_scaling=False):
+ def emit(self, operation, no_beta_scaling=False,
residual_block_info=False):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -246,14 +259,26 @@ class EmitConv2dInstance:
],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
+ "conv_kernel_postfix": "",
}
- template = substitute_template(
- self.template,
- {
- "epilogue": self.epilogue_no_beta_scaling
- if no_beta_scaling
- else self.epilogue_default
- },
- )
+ if residual_block_info:
+ template = substitute_template(
+ self.template, {"epilogue": self.epilogue_residual_block}
+ )
+ values.update(
+ {
+ "unary_op": residual_block_info["unary_op"],
+ "binary_op": residual_block_info["binary_op"],
+ "activation": residual_block_info["activation"],
+ "conv_kernel_postfix": "WithBroadcast",
+ }
+ )
+ elif no_beta_scaling:
+ template = substitute_template(
+ self.template, {"epilogue": self.epilogue_no_beta_scaling}
+ )
+ else:
+ template = substitute_template(self.template, {"epilogue":
self.epilogue_default})
+
return substitute_template(template, values)
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py
b/python/tvm/contrib/cutlass/gen_conv2d.py
index 4e4a7b2..39db9fd 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
Instantiate a cutlass kernel from the given configuration,
along with the epilogue functor
"""
- epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
+ if "residual" in op_type:
+ activation_map = {
+ "cutlass.conv2d_bias_hardswish":
"cutlass::epilogue::thread::HardSwish",
+ "cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
+ "cutlass.conv2d_bias_sigmoid":
"cutlass::epilogue::thread::Sigmoid",
+ "cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
+ "cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
+ }
+ prefix = op_type[: op_type.find("_residual")]
+ activation = activation_map[prefix]
+ binary_op = "cutlass::multiplies" if "residual_multiply" in op_type
else "cutlass::plus"
+ unary_op = (
+ "cutlass::epilogue::thread::ReLu"
+ if op_type.endswith("relu")
+ else "cutlass::epilogue::thread::Identity"
+ )
+ residual_block_info = {
+ "activation": activation,
+ "binary_op": binary_op,
+ "unary_op": unary_op,
+ }
+ epilogue = EpilogueFunctor.LinearCombinationResidualBlock
+ no_beta_scaling = False
+ else:
+ residual_block_info = None
+ epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
element_a, element_b, element_c, element_epilogue = data_type
@@ -62,7 +87,9 @@ def create_conv2d_operator_with_epilogue(
)
name = op.procedural_name()
- opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
+ opdef = EmitConv2dInstance().emit(
+ op, no_beta_scaling=no_beta_scaling,
residual_block_info=residual_block_info
+ )
return name, opdef
diff --git a/python/tvm/contrib/cutlass/library.py
b/python/tvm/contrib/cutlass/library.py
index efc5dd5..08cdb32 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()
+ LinearCombinationResidualBlock = enum_auto()
EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ EpilogueFunctorTag = {
EpilogueFunctor.LinearCombinationSigmoid:
"cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu:
"cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish:
"cutlass::epilogue::thread::LinearCombinationHardSwish",
+ EpilogueFunctor.LinearCombinationResidualBlock:
"cutlass::epilogue::thread::LinearCombinationResidualBlock",
}
diff --git a/python/tvm/relay/op/contrib/cutlass.py
b/python/tvm/relay/op/contrib/cutlass.py
index cbbc45a..31f0408 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
+from functools import partial
from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return conv2d_out
+def make_residual_block_pattern(tensor_op_out, binary_op="add",
with_act="relu"):
+ """Add pattern for residual blocks."""
+ residual_input = wildcard()
+ binary_out = is_op(binary_op)(tensor_op_out, residual_input) |
is_op(binary_op)(
+ residual_input, tensor_op_out
+ )
+
+ if with_act is not None and with_act == "relu":
+ return is_op("nn.relu")(binary_out)
+
+ return binary_out
+
+
def check_dtype(lhs, rhs):
"""Check if dtypes in the given workload are supported by CUTLASS."""
# Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)
+def check_conv2d_residual(call, binary_op):
+ """Check if the given conv2d workload can be offloaded to CUTLASS."""
+ conv2d = get_root_call(call, "nn.conv2d")
+ if not check_conv2d(call):
+ return False
+
+ residual_binop = get_root_call(call, binary_op)
+ lhs = residual_binop.args[0]
+ rhs = residual_binop.args[1]
+
+ # residual_input is pattern-matched as a wildcard. Make sure it does not
sit between
+ # residual binary op and the root conv2d of this pattern.
+ # If the root conv2d is the parent of both lhs and rhs, we should reject
this pattern.
+ if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs,
"nn.conv2d") == conv2d:
+ return False
+
+ return all(x == y for (x, y) in zip(lhs.checked_type.shape,
rhs.checked_type.shape))
+
+
def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -189,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]
- cutlass_patterns = dense_patterns + conv2d_patterns
+ residual_block_patterns = []
+
+ for with_act, postfix in [("relu", "_relu"), (None, "")]:
+ for name, pat, _ in conv2d_patterns[:-1]:
+ for bin_op in ["add", "multiply"]:
+ residual_block_patterns.append(
+ (
+ name + "_residual_" + bin_op + postfix,
+ make_residual_block_pattern(pat, bin_op,
with_act=with_act),
+ partial(check_conv2d_residual, binary_op=bin_op),
+ )
+ )
+
+ cutlass_patterns = residual_block_patterns + dense_patterns +
conv2d_patterns
if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc
b/src/relay/backend/contrib/cutlass/codegen.cc
index a87ba2f..dc03eea 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -258,7 +258,7 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
}
std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
- const std::vector<std::string>& func_args) {
+ const std::vector<std::string>& func_args, bool
has_residual_block = false) {
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid"
&&
attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
@@ -268,8 +268,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap&
attrs,
CutlassPrint(conv2d_decl, "using ElementInputA = " +
attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " +
attrs.at("ElementInputB") + ";\n");
CutlassPrint(conv2d_decl, "using ElementOutput = " +
attrs.at("ElementOutput") + ";\n");
-
CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " +
attrs.at("ElementOutput") + ";\n");
+
CutlassPrint(conv2d_decl, attrs.at("op_def"));
CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
" =
cutlass::conv::device::ImplicitGemmConvolution<" +
@@ -308,14 +308,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap&
attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] +
"->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] +
"->data);\n");
- if (has_bias) {
+ if (has_residual_block) {
+ ICHECK(func_args.size() >= 4);
+ CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] +
"->data);\n");
+ CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] +
"->data);\n");
+ } else if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] +
"->data);\n");
}
CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha =
ElementComputeEpilogue(1);\n");
- if (has_bias && no_bias_scaling) {
+ if (has_bias && no_bias_scaling && !has_residual_block) {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta =
ElementComputeEpilogue(0);\n");
} else {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta =
ElementComputeEpilogue(1);\n");
@@ -326,24 +330,38 @@ std::string Conv2dOp(std::string id, const Str2StrMap&
attrs,
CutlassPrint(conv2d_decl,
"TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K,
R, S, C)));\n");
CutlassPrint(conv2d_decl,
- "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N,
P, Q, K)));\n");
+ "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N,
P, Q, K)));\n\n");
+ CutlassPrint(conv2d_decl,
+ "TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N,
P, Q, K)));\n\n");
+
CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a),
layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b),
layout_B},\n");
- if (has_bias) {
+
+ if (has_residual_block) {
+ CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual),
layout_C},\n");
+ } else if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias),
cutlass::layout::TensorNHWC::Stride(0)},\n");
} else {
- CutlassPrint(conv2d_decl, "
{static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+ CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),
layout_C},\n");
}
- CutlassPrint(conv2d_decl, "
{static_cast<ElementOutput*>(ptr_out),layout_C},\n");
- if (has_bias && no_bias_scaling) {
+
+ CutlassPrint(conv2d_decl, "
{static_cast<ElementOutput*>(ptr_out),layout_D},\n");
+
+ if (has_residual_block) {
+ CutlassPrint(conv2d_decl, "{alpha, beta},\n");
+ CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); //
split_k_slices
+ CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
+ CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
+ } else if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}
+
CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
CutlassPrint(conv2d_decl, "size_t workspace_size =
conv2d_op.get_workspace_size(arguments);\n");
@@ -432,6 +450,21 @@ class CodegenCutlass : public
MemoizedExprTranslator<std::vector<Output>>, publi
return arg_names;
}
+ bool IsConv2dResidualBlock(const std::string& func_name) {
+ return func_name.find("conv2d") != std::string::npos &&
+ func_name.find("residual") != std::string::npos;
+ }
+
+ // Is node `x` an ancestor of `y`?
+ bool IsAncestor(const CallNode* x, const CallNode* y) {
+ if (x == y) return true;
+ for (auto arg : y->args) {
+ const CallNode* arg_ptr = arg.as<CallNode>();
+ if (arg_ptr && IsAncestor(x, arg_ptr)) return true;
+ }
+ return false;
+ }
+
GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
const CallNode* caller) {
const auto pattern_name =
callee->GetAttr<runtime::String>(attr::kComposite);
@@ -515,6 +548,30 @@ class CodegenCutlass : public
MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d",
add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish",
GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
+ } else if (IsConv2dResidualBlock(pattern_name.value())) {
+ const CallNode* current_call = callee->body.as<CallNode>();
+ bool has_relu = current_call->args.size() == 1;
+ const CallNode* binop = has_relu ? current_call->args[0].as<CallNode>()
: current_call;
+ ICHECK(binop->args.size() == 2);
+ // Figure out which of the first or second argument corresponds to the
residual input
+ // The root conv2d call can be reached via the other input of the binary
op
+ int residual_index;
+ if (binop->args[1].as<VarNode>()) {
+ residual_index = 1;
+ } else if (binop->args[0].as<VarNode>()) {
+ residual_index = 0;
+ } else {
+ const CallNode* lhs = binop->args[0].as<CallNode>();
+ const CallNode* rhs = binop->args[1].as<CallNode>();
+ ICHECK(lhs && rhs);
+ // The residual input should be an ancestor of the non-residual input
+ residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
+ }
+ const auto* non_residual_input =
binop->args[!residual_index].as<CallNode>();
+ const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
+ ICHECK(conv2d_call);
+ return GenerateBody(conv2d_call, pattern_name.value(),
GetArgumentNames(caller),
+ Conv2dArgs(std::ref(attrs_)));
}
LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -560,6 +617,8 @@ class CodegenCutlass : public
MemoizedExprTranslator<std::vector<Output>>, publi
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+ } else if (IsConv2dResidualBlock(func_name)) {
+ ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true);
} else if (func_name.find("conv2d") != std::string::npos) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
@@ -623,6 +682,8 @@ class CutlassModuleCodegen : public
CSourceModuleCodegenBase {
code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_silu.h>\n";
code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
+ code_stream_ << "#include
<cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
+ code_stream_ << "#include
<cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";
ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index df25a86..658283b 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -389,7 +389,6 @@ inline bool IsOp(const CallNode* call, const std::string&
op_name) {
* "nn.relu"}
* \return A CallNode corresponding to the root op, whose name is
expected_op_names[0]
*/
-
inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
const std::vector<std::string>&
expected_op_names) {
ICHECK(current_call && depth >= 0 && static_cast<size_t>(depth) <
expected_op_names.size() &&
@@ -406,6 +405,24 @@ inline const CallNode* GetRootCall(const CallNode*
current_call, int depth,
}
/*!
+ * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in
relu(add(conv2d))
+ * Unlike the previous definition, it does not verify operator names of
intermediate nodes. Instead,
it recursively visits child nodes until it finds a call node with the given
op_name.
+ * \param call A Relay call node.
+ * \param op_name The name of an op to look for, such as "nn.conv2d".
+ * \return A CallNode corresponding to the root op with the given op_name
+ */
+inline const CallNode* GetRootCall(const CallNode* current_call, const
std::string& op_name) {
+ if (current_call == nullptr) return nullptr;
+ if (IsOp(current_call, op_name)) return current_call;
+
+ ICHECK_GT(current_call->args.size(), 0);
+
+ const auto* next_call = current_call->args[0].as<CallNode>();
+ return GetRootCall(next_call, op_name);
+}
+
+/*!
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
diff --git a/tests/python/contrib/test_cutlass.py
b/tests/python/contrib/test_cutlass.py
index b2bdb8c..54738dd 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -130,6 +130,17 @@ def get_conv2d_nchw_bias(d_shape, w_shape, padding,
out_dtype="float16"):
return relay.nn.bias_add(conv2d, bias)
+def silu(x):
+ return x * relay.sigmoid(x)
+
+
+def hardswish(x, out_dtype="float16"):
+ return x * (
+ relay.clip(x + relay.const(3, dtype=out_dtype), a_min=0, a_max=6)
+ / relay.const(6, dtype=out_dtype)
+ )
+
+
def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding,
out_dtype=out_dtype))
@@ -140,15 +151,29 @@ def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape,
padding, out_dtype="float16")
def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding,
out_dtype=out_dtype)
- return conv_out * relay.sigmoid(conv_out)
+ return silu(conv_out)
def get_conv2d_nchw_bias_hardswish(d_shape, w_shape, padding,
out_dtype="float16"):
- conv2d_out = get_conv2d_nchw_bias(d_shape, w_shape, padding,
out_dtype=out_dtype)
- return conv2d_out * (
- relay.clip(conv2d_out + relay.const(3, dtype=out_dtype), a_min=0,
a_max=6)
- / relay.const(6, dtype=out_dtype)
+ conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding,
out_dtype=out_dtype)
+ return hardswish(conv_out, out_dtype)
+
+
+def get_conv2d_nchw_bias_residual(d_shape, w_shape, padding,
out_dtype="float16"):
+ data = relay.var("data", shape=d_shape, dtype="float16")
+ weight = relay.var("weight", shape=w_shape, dtype="float16")
+ bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
+ out_channel = w_shape[0]
+ conv2d = relay.nn.conv2d(
+ data=data,
+ weight=weight,
+ kernel_size=w_shape[2:],
+ channels=out_channel,
+ padding=padding,
+ out_dtype=out_dtype,
)
+ bias_add = relay.nn.bias_add(conv2d, bias)
+ return bias_add, data
def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so",
use_fast_math=False):
@@ -492,5 +517,25 @@ def test_conv2d_fusion():
)
+def test_conv2d_residual_block():
+ d_shape = (16, 16, 32, 32)
+ w_shape = (16, 16, 3, 3)
+ padding = (1, 1)
+
+ bias_add, residual_input = get_conv2d_nchw_bias_residual(d_shape, w_shape,
padding)
+
+ for func, tol in [
+ (relay.nn.relu(bias_add + residual_input), 1e-5),
+ (relay.nn.relu(bias_add) + residual_input, 1e-5),
+ (relay.sigmoid(bias_add) * residual_input, 1e-5),
+ (relay.nn.relu(silu(bias_add) * residual_input), 1e-5),
+ # HardSwish requires higher tolerance since vectorizing the residual
block epilogue
+ # in cutlass.
+ # TODO(masahi): Investigate this issue
+ (relay.nn.relu(hardswish(bias_add) + residual_input), 1e-3),
+ ]:
+ verify_conv2d(func, func, d_shape, w_shape, sm=80, atol=tol, rtol=tol,
run_benchmark=False)
+
+
if __name__ == "__main__":
pytest.main([__file__])