mozga-intel commented on a change in pull request #20202:
URL: https://github.com/apache/incubator-mxnet/pull/20202#discussion_r637940730
##########
File path: src/operator/nn/mkldnn/mkldnn_pooling.cc
##########
@@ -106,127 +106,72 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
switch (param.pool_type) {
case pool_enum::kMaxPooling:
return mkldnn::algorithm::pooling_max;
- break;
case pool_enum::kAvgPooling:
- if (param.count_include_pad.has_value() && !param.count_include_pad.value()) {
+ if (param.count_include_pad.has_value() &&
+ !param.count_include_pad.value()) {
return mkldnn::algorithm::pooling_avg_exclude_padding;
} else {
return mkldnn::algorithm::pooling_avg_include_padding;
}
- break;
default:
LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method.";
return mkldnn::algorithm::pooling_max;
}
}
+void prepareKernels(mkldnn::memory::dims &kernel, mkldnn::memory::dims &strides,
+ mkldnn::memory::dims &pad_l, mkldnn::memory::dims &pad_r,
+ const PoolingParam &param, const mkldnn::memory::desc &data_md, int kernel_ndims) {
+ CHECK_GE(param.pad.ndim(), kernel_ndims);
+ CHECK_GE(param.stride.ndim(), kernel_ndims);
+
+ for (int idx = 0; idx < kernel_ndims; ++idx) {
+ kernel[idx] = param.kernel[idx];
+ pad_l[idx] = param.pad[idx];
+ pad_r[idx] = param.pad[idx];
+ strides[idx] = param.stride[idx];
+ }
+ if (param.pooling_convention == pool_enum::kFull) {
+ for (int idx = 0; idx < kernel_ndims; ++idx) {
+ pad_r[idx] = GetPaddingSizeFull(data_md.data.dims[idx + 2], pad_l[idx],
+ pad_r[idx], kernel[idx], strides[idx]);
+ }
+ }
+ if (param.global_pool) {
+ for (int idx = 0; idx < kernel_ndims; ++idx) {
+ kernel[idx] = data_md.data.dims[idx + 2];
+ strides[idx] = 1;
+ pad_l[idx] = pad_r[idx] = 0;
+ }
+ }
+ for (int idx = 0; idx < kernel_ndims; ++idx) {
+ CHECK_GT(kernel[idx], 0) << "Filter dimensions cannot be zero.";
+ }
+}
+
void InitPoolingPrimitiveParams(const PoolingParam &param,
const mkldnn::memory::desc &data_md,
const mkldnn::memory::dims &new_kernel,
const mkldnn::memory::dims &new_strides,
const mkldnn::memory::dims &new_pad_l,
const mkldnn::memory::dims &new_pad_r) {
const int kernel_ndims = param.kernel.ndim();
- mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
- mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
- mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
- mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
- if (kernel_ndims == 1) {
- CHECK_GE(param.pad.ndim(), 1);
- CHECK_GE(param.stride.ndim(), 1);
- kernel[0] = param.kernel[0];
- pad_l[0] = param.pad[0];
- pad_r[0] = param.pad[0];
- strides[0] = param.stride[0];
-
- if (param.pooling_convention == pool_enum::kFull) {
- pad_r[0] =
- GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
- }
-
- if (param.global_pool) {
- kernel[0] = data_md.data.dims[2];
- strides[0] = 1;
- pad_l[0] = pad_r[0] = 0;
- }
+ mkldnn::memory::dims &kernel = const_cast<mkldnn::memory::dims &>(new_kernel);
Review comment:
This suggestion will be addressed in the next pull request; we need to
look at the wider structure of this file.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org