jackwish commented on a change in pull request #4351: [QNN] Lowering for
Depthwise Convolution.
URL: https://github.com/apache/incubator-tvm/pull/4351#discussion_r347756284
##########
File path: src/relay/qnn/op/convolution.cc
##########
@@ -417,23 +565,33 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const
Array<Expr>& new_args,
param->kernel_layout == "HWOI")
<< "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
- int batch_size, in_channels, out_channels, kernel_h, kernel_w;
- std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w) =
+ int batch_size, in_channels, out_channels, kernel_h, kernel_w,
channel_multiplier;
+ std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w,
channel_multiplier) =
GetWorkload(arg_types, param);
- // Fallback to int32 conv if there is dilation or depthwise conv2d
+ // Fallback to int32 conv if there is dilation or grouped conv2d
+
CHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D
dilation";
auto dilation_h = get_const_int(param->dilation[0]);
auto dilation_w = get_const_int(param->dilation[1]);
- if (dilation_h != 1 || dilation_w != 1 || param->groups != 1) {
+ if (dilation_h != 1 || dilation_w != 1 || (param->groups != 1 &&
!is_depthwise(param))) {
return Conv2DFallBack(data, weight, param);
+ } else if (is_depthwise(param)) {
+ CHECK_NE(channel_multiplier, -1);
+ auto padded_data = Conv2DPadInput(data, param);
+ auto term1 = Conv2DFirstTerm(padded_data, weight, param);
+ auto term2 =
+ DepthwiseConv2DSecondTerm(padded_data, param, kernel_h, kernel_w,
channel_multiplier);
+ auto term3 = DepthwiseConv2DThirdTerm(weight, param, out_channels,
channel_multiplier);
+ auto term4 = DepthwiseConv2DFourthTerm(param, kernel_h, kernel_w);
+ return Conv2DCombineTerms(term1, term2, term3, term4, param);
}
auto padded_data = Conv2DPadInput(data, param);
auto term1 = Conv2DFirstTerm(padded_data, weight, param);
auto term2 = Conv2DSecondTerm(padded_data, param, kernel_h, kernel_w,
out_channels);
- auto term3 = Conv2DThirdTerm(weight, param, batch_size, out_channels);
- auto term4 = Conv2DFourthTerm(param, batch_size, in_channels, kernel_h,
kernel_w);
+ auto term3 = Conv2DThirdTerm(weight, param, out_channels);
+ auto term4 = Conv2DFourthTerm(param, in_channels, kernel_h, kernel_w);
Review comment:
That is interesting. Why remove the batch semantics?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services