DickJC123 commented on a change in pull request #13749: Add NHWC layout support to Pooling (cpu, gpu cuda, gpu cuDNN)
URL: https://github.com/apache/incubator-mxnet/pull/13749#discussion_r254456778
 
 

 ##########
 File path: src/operator/nn/cudnn/cudnn_pooling-inl.h
 ##########
 @@ -167,102 +169,215 @@ class CuDNNPoolingOp {
     }
   }
 
+/*!
+ * \brief Returns whether the cuDNN library version supports the pooling operation
+ * described by `param`: for example, cuDNN v5 and earlier does not support 3D pooling,
+ * and the cuDNN v7.1.4 backprop kernel doesn't support window sizes of 9 and above.
+ */
+  static bool Supports(const PoolingParam &param, const TBlob& input) {
+    using namespace mshadow;
+    static bool sum_pooling_warning_issued = false;
+    static bool lp_pooling_warning_issued = false;
+    static bool unsupported_dim_warning_issued = false;
+    int layout = param.GetLayout(input.ndim());
+
+    switch (param.pool_type) {
+      case pool_enum::kMaxPooling:
+      case pool_enum::kAvgPooling:
+        break;
+      case pool_enum::kSumPooling:
+        if (!sum_pooling_warning_issued) {
+          sum_pooling_warning_issued = true;
+          LOG(WARNING) << "Sum pooling is not supported by cudnn, MXNet sum pooling is applied.";
+        }
+        return false;
+      case pool_enum::kLpPooling:
+        if (!lp_pooling_warning_issued) {
+          lp_pooling_warning_issued = true;
+          LOG(WARNING) << "Lp pooling is not supported by cudnn, MXNet Lp pooling is applied.";
+        }
+        return false;
+      default:
+        return false;
+    }
+
+    if (param.kernel.ndim() == 2) {
+      // 2d pooling
+      if (!(layout == mshadow::kNCHW || layout == mshadow::kNHWC))
+        return false;
+#if CUDNN_VERSION == 7104
+      // CuDNN v7.1.4 backprop kernel doesn't support window sizes 9 and above.
+      // Perform shape calculations in a standard (NCHW) layout space
+      mshadow::Shape<4> input_shape = input.shape_.get<4>();
+      mshadow::Shape<4> dshape_nchw = (layout == mshadow::kNHWC) ?
+                                      ConvertLayout(input_shape, mshadow::kNHWC, mshadow::kNCHW) :
+                                      input_shape;
+      int window_height = param.global_pool ? dshape_nchw[2] : param.kernel[0];
+      int window_width = param.global_pool ? dshape_nchw[3] : param.kernel[1];
+      if (window_height > 8 || window_width > 8)
+        return false;
+#endif
+#if CUDNN_VERSION >= 7105 && CUDNN_VERSION < 7500
+      // Avoid strided NHWC max pooling for some configs
+      if (layout == mshadow::kNHWC &&
+          param.pool_type == pool_enum::kMaxPooling && !param.global_pool) {
+        if (param.stride[0] >= 3 ||
+            param.stride[0] == 2 && param.kernel[0] % 2 == 0 && param.kernel[0] != 2)
+          return false;
+        if (param.stride[1] >= 3 ||
+            param.stride[1] == 2 && param.kernel[1] % 2 == 0 && param.kernel[1] != 2)
+          return false;
+      }
+#endif
+    } else if (param.kernel.ndim() == 3) {
+      // 3d pooling
+#if CUDNN_MAJOR < 5
+      LogUnsupportedDim(&unsupported_dim_warning_issued, param.kernel.ndim());
+      return false;
+#endif
+      if (!(layout == mshadow::kNCDHW || layout == mshadow::kNDHWC))
+        return false;
+    } else {
+      // Unsupported kernel dim
+      LogUnsupportedDim(&unsupported_dim_warning_issued, param.kernel.ndim());
+      return false;
+    }
+
+    return true;
+  }
+
  private:
-  inline void Init(mshadow::Stream<gpu> *s, const TBlob &in_data,
+  // Return boolean saying whether pooling configuration is supported
+  inline bool Init(mshadow::Stream<gpu> *s, const TBlob &in_data,
       const TBlob &out_data) {
     using namespace mshadow;
+    bool is_supported = true;
     #if CUDNN_MAJOR >= 5
     nan_prop_ = CUDNN_NOT_PROPAGATE_NAN;
     #endif
+    int layout = param_.GetLayout(in_data.ndim());
     if (param_.kernel.ndim() == 2) {
-      // 2d conv
+      // 2d pooling
+      CHECK(layout == mshadow::kNCHW || layout == mshadow::kNHWC) << "Need 2D layout NCHW or NHWC.";
+      cudnnTensorFormat_t cudnn_layout = (layout == mshadow::kNCHW) ? CUDNN_TENSOR_NCHW
+                                                                    : CUDNN_TENSOR_NHWC;
       Tensor<gpu, 4, DType> data = in_data.get<gpu, 4, DType>(s);
       Tensor<gpu, 4, DType> out = out_data.get<gpu, 4, DType>(s);
-      mshadow::Shape<4> dshape = data.shape_;
+      // Perform shape calculations in a standard (NCHW) layout space
+      mshadow::Shape<4> dshape_nchw = (layout == mshadow::kNHWC) ?
+                                      ConvertLayout(data.shape_, mshadow::kNHWC, mshadow::kNCHW) :
+                                      data.shape_;
+      mshadow::Shape<4> oshape_nchw = (layout == mshadow::kNHWC) ?
+                                      ConvertLayout(out.shape_, mshadow::kNHWC, mshadow::kNCHW) :
+                                      out.shape_;
       CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_,
-                                            CUDNN_TENSOR_NCHW,
+                                            cudnn_layout,
                                             dtype_,
-                                            data.shape_[0],
-                                            data.shape_[1],
-                                            data.shape_[2],
-                                            data.shape_[3]));
+                                            dshape_nchw[0],
+                                            dshape_nchw[1],
+                                            dshape_nchw[2],
+                                            dshape_nchw[3]));
       CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_,
-                                            CUDNN_TENSOR_NCHW,
+                                            cudnn_layout,
                                             dtype_,
-                                            out.shape_[0],
-                                            out.shape_[1],
-                                            out.shape_[2],
-                                            out.shape_[3]));
+                                            oshape_nchw[0],
+                                            oshape_nchw[1],
+                                            oshape_nchw[2],
+                                            oshape_nchw[3]));
+      int window_height = param_.global_pool ? dshape_nchw[2] : param_.kernel[0];
 
 Review comment:
   Done.
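
For readers skimming the diff, here is a minimal sketch of how the new Supports() check is meant to be used: probe the cuDNN path first and fall back to MXNet's native pooling kernels when the configuration is rejected. The dispatch function and the two compute helpers below are hypothetical and purely illustrative; they are not the PR's actual call sites.

    // Illustrative sketch only -- not code from this PR.
    // Assumes CuDNNPoolingOp<DType>::Supports(param, input) as added above;
    // CuDNNPoolingCompute / MXNetPoolingCompute are hypothetical helpers.
    template <typename DType>
    void PoolingForwardDispatch(mshadow::Stream<gpu>* s, const PoolingParam& param,
                                const TBlob& input, const TBlob& output) {
      if (CuDNNPoolingOp<DType>::Supports(param, input)) {
        // cuDNN accepts this layout / pool type / window / stride combination.
        CuDNNPoolingCompute<DType>(s, param, input, output);
      } else {
        // e.g. sum/Lp pooling, or window sizes cuDNN can't handle: use MXNet's own kernels.
        MXNetPoolingCompute<DType>(s, param, input, output);
      }
    }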
