This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new b748c45  [v1.6.x] Backport [MKL-DNN] Integrate Conv3d and Pool3d/1d (#17884) and Fix Sanity pipeline in 1.6.x (#18206)
b748c45 is described below

commit b748c45d0ca83eebcd629efb500d3cbc0b6dce7c
Author: Chaitanya Prakash Bapat <[email protected]>
AuthorDate: Wed May 6 11:58:04 2020 -0700

    [v1.6.x] Backport [MKL-DNN] Integrate Conv3d and Pool3d/1d (#17884) and Fix Sanity pipeline in 1.6.x (#18206)
    
    * [MKL-DNN] Integrate Conv3d and Pool3d/1d (#17884)
    
    * Integrate MKL-DNN conv3d and pool3d/1d
    
    * fix UT & address comments
    
    * clean code
    
    * rebase against latest master
    
    * pylint astroid sanity issue
    
    * astroid and pylint versions only supported in py3
    
    * remove kBFloat16 as it's not supported in 1.6
    
    * added missing definition GetPaddingSizeFull
    
    * Remove dilation restriction for conv3d (#17491)
    
    * Remove conv3d dilation restriction
    
    * Remove comment
    
    * fix unix-gpu test for num_outputs and inputs
    
    Co-authored-by: Wuxun Zhang <[email protected]>
    Co-authored-by: reminisce <[email protected]>
---
 ci/docker/install/requirements                     |   1 +
 src/operator/nn/convolution.cc                     |   4 -
 src/operator/nn/mkldnn/mkldnn_act.cc               |  12 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |  49 ++--
 src/operator/nn/mkldnn/mkldnn_base.cc              |  47 ++--
 src/operator/nn/mkldnn/mkldnn_convolution.cc       |  60 ++++-
 src/operator/nn/mkldnn/mkldnn_pooling-inl.h        |  65 +++--
 src/operator/nn/mkldnn/mkldnn_pooling.cc           | 277 +++++++++++----------
 src/operator/nn/pooling.cc                         |   7 +-
 .../mkldnn/mkldnn_quantized_pooling.cc             |   4 +-
 src/operator/quantization/quantized_conv.cc        |  97 ++++++--
 src/operator/quantization/quantized_pooling.cc     | 100 +++++---
 src/operator/subgraph/mkldnn/mkldnn_conv.cc        |   9 +-
 .../subgraph/mkldnn/mkldnn_conv_property.h         |   3 +-
 .../subgraph/mkldnn/mkldnn_subgraph_base-inl.h     |   3 +-
 tests/cpp/operator/mkldnn_operator_test.cc         |   4 +-
 tests/python/mkl/test_mkldnn.py                    |  11 +-
 tests/python/quantization/test_quantization.py     |  54 ++--
 tests/python/unittest/test_operator.py             |   4 +
 19 files changed, 515 insertions(+), 296 deletions(-)
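
Taken together, the patch extends MXNet's MKL-DNN CPU path from 2-D convolution and pooling to 3-D convolution and 1-D/3-D pooling. A minimal sketch of the operator calls this enables (not part of the patch; shapes and values are illustrative):

    import mxnet as mx

    data5d = mx.nd.random.uniform(shape=(1, 3, 8, 8, 8))        # NCDHW layout
    weight = mx.nd.random.uniform(shape=(4, 3, 3, 3, 3))
    conv3d = mx.nd.Convolution(data=data5d, weight=weight, no_bias=True,
                               kernel=(3, 3, 3), num_filter=4)  # conv3d
    pool3d = mx.nd.Pooling(data=data5d, kernel=(2, 2, 2), stride=(2, 2, 2),
                           pool_type='max')                     # pool3d
    data3d = mx.nd.random.uniform(shape=(1, 3, 16))             # NCW layout
    pool1d = mx.nd.Pooling(data=data3d, kernel=(2,), stride=(2,),
                           pool_type='avg')                     # pool1d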

diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements
index 61c9ef8..5f9f28c 100644
--- a/ci/docker/install/requirements
+++ b/ci/docker/install/requirements
@@ -29,6 +29,7 @@ nose==1.3.7
 nose-timer==0.7.3
 numpy>1.16.0,<1.18.0
 pylint==2.3.1; python_version >= '3.0'
+astroid==2.3.3; python_version >= '3.0'
 requests<2.19.0,>=2.18.4
 scipy==1.2.1
 six==1.11.0
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 6d9f84f..6c8ab3a 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -223,8 +223,6 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
       SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter));
     }
 
-    // Note: 3D dilation currently not supported.
-    // Calculations below done to preserve symmetry with 1D/2D code.
     const index_t dilated_ksize_d = param_.DilatedKernelSize(0);
     const index_t dilated_ksize_y = param_.DilatedKernelSize(1);
     const index_t dilated_ksize_x = param_.DilatedKernelSize(2);
@@ -239,8 +237,6 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
       << "incorrect stride size: " << param_.stride;
     CHECK_GT(param_.dilate.Size(), 0U) \
       << "incorrect dilate size: " << param_.dilate;
-    CHECK_EQ(param_.dilate.Size(), 1U)
-      << "Dilate is not supported in 3d convolution";
     Shape<5> oshape;
     oshape[0] = dshape[0];
     oshape[1] = param_.num_filter;
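
With the check removed above, ConvolutionShape no longer rejects dilated 3-D kernels. The effective extent of each kernel dimension is dilate * (k - 1) + 1, mirroring DilatedKernelSize(). A hedged sketch (illustrative shapes, not from the patch):

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(1, 3, 10, 10, 10))
    w = mx.nd.random.uniform(shape=(2, 3, 3, 3, 3))
    # dilate=(2, 2, 2) gives a dilated kernel size of 2 * (3 - 1) + 1 = 5
    y = mx.nd.Convolution(data=x, weight=w, no_bias=True, kernel=(3, 3, 3),
                          dilate=(2, 2, 2), num_filter=2)
    print(y.shape)  # (1, 2, 6, 6, 6): (10 - 5) / 1 + 1 = 6 per spatial dim
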
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index f3966e6..08e9f4f 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -48,10 +48,10 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
 }
 
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
-  // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
   if ((input.shape().ndim() < 1) ||
-      (input.shape().ndim() > 4) ||
-      (input.dtype() != mshadow::kFloat32))
+      (input.shape().ndim() > 5) ||
+      !(input.dtype() == mshadow::kFloat32))
     return false;
   return SupportMKLDNNAct(param);
 }
@@ -62,10 +62,10 @@ bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param) {
 }
 
 bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input) {
-  // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
   if ((input.shape().ndim() < 1) ||
-      (input.shape().ndim() > 4) ||
-      (input.dtype() != mshadow::kFloat32))
+      (input.shape().ndim() > 5) ||
+      !(input.dtype() == mshadow::kFloat32))
     return false;
   return SupportMKLDNNLeakyRelu(param);
 }
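
The relaxed checks above admit 5-d (NCDHW) float32 tensors, the shape that conv3d and pool3d produce. A quick sketch of inputs that now stay on the MKL-DNN path (illustrative, not from the patch):

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 3, 4, 5, 6))      # 5-d input
    y = mx.nd.Activation(x, act_type='relu')
    z = mx.nd.LeakyReLU(x, act_type='leaky', slope=0.1)
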
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 9763c42..b7dc54c 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -129,15 +129,8 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
     // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
     return false;
   }
-  return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
-}
-
-static inline bool SupportMKLDNNRnn(const NDArray &input) {
-  if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
-      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
-    return true;
-  }
-  return false;
+  return (dtype == mshadow::kFloat32) &&
+                    (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
 static inline bool SupportMKLDNNQuantize(int dtype) {
@@ -302,20 +295,32 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
   if (num_groups == 1) {
     return GetMemDesc(arr, dtype);
   } else {
-    auto ndim = arr.shape().ndim();
-    CHECK((ndim == 3) || (ndim == 4))
-        << "MKL-DNN weight currectly supports 3d and 4d layout";
+    const auto ndim = arr.shape().ndim();
+    CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
+        << "MKL-DNN weight currently supports 3d or 4d or 5d layout";
     auto tz = mkldnn::memory::dims{0};
-    const int N = 0, H = 2, W = 3, C = 1;
-    if (ndim == 3) {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
-    } else {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
-          static_cast<int>(arr.shape()[W])};
+    int N = 0, C = 1, H = 2, W = 3;
+    int D = -1;
+    if (ndim == 5) {
+      D = 2;
+      H = 3;
+      W = 4;
+    }
+    switch (ndim) {
+      case 3:
+        tz = mkldnn::memory::dims{
+                num_groups, arr.shape()[N] / num_groups,
+                arr.shape()[C], arr.shape()[H]};
+        break;
+      case 4:
+        tz = mkldnn::memory::dims{
+                num_groups, arr.shape()[N] / num_groups,
+                arr.shape()[C], arr.shape()[H], arr.shape()[W]};
+        break;
+      case 5:
+        tz = mkldnn::memory::dims{
+                num_groups, arr.shape()[N] / num_groups,
+                arr.shape()[C], arr.shape()[D], arr.shape()[H], arr.shape()[W]};
     }
    return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
   }
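
For grouped weights, GetWeightDesc splits the output-channel axis into (groups, O/groups); the new 5-d case yields (g, O/g, I, D, H, W). A NumPy sketch of the regrouping (hypothetical sizes, not from the patch):

    import numpy as np

    O, I, D, H, W, g = 8, 4, 3, 3, 3, 2
    w = np.zeros((O, I, D, H, W), dtype=np.float32)  # conv3d weight, OIDHW
    grouped = w.reshape(g, O // g, I, D, H, W)       # the dims handed to MKL-DNN
    print(grouped.shape)                             # (2, 4, 4, 3, 3, 3)
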
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 1b147c6..bd361ae 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -240,31 +240,44 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   auto tz = mkldnn::memory::dims{0};
   auto format_tag = mkldnn::memory::format_tag::undef;
   auto engine = CpuEngine::Get()->get_engine();
-  const int O = 0, I = 1, H = 2, W = 3;
-  if (arr.shape().ndim() == 2) {
-    tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I])};
+  const int ndim = arr.shape().ndim();
+  int O = 0, I = 1, H = 2, W = 3;
+  int D = -1;
+  if (ndim == 5) {
+    D = 2;
+    H = 3;
+    W = 4;
+  }
+  if (ndim == 2) {
+    tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]};
     format_tag = mkldnn::memory::format_tag::oi;
-  } else if (arr.shape().ndim() == 3) {
+  } else if (ndim == 3) {
     tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])}
-             : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])};
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H]}
+             : mkldnn::memory::dims{arr.shape()[O],
+                                    arr.shape()[I], arr.shape()[H]};
     format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
                                 : mkldnn::memory::format_tag::oiw;
-  } else if (arr.shape().ndim() == 4) {
+  } else if (ndim == 4) {
     tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H]),
-                                    static_cast<int>(arr.shape()[W])}
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H],
+                                    arr.shape()[W]}
              : mkldnn::memory::dims{
-                   static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
-                   static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
+                   arr.shape()[O], arr.shape()[I],  arr.shape()[H], arr.shape()[W]};
     format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
                                 : mkldnn::memory::format_tag::oihw;
+  } else if (ndim == 5) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[D],
+                                    arr.shape()[H], arr.shape()[W]}
+             : mkldnn::memory::dims{
+                   arr.shape()[O], arr.shape()[I], arr.shape()[D],
+                   arr.shape()[H], arr.shape()[W]};
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goidhw
+                                : mkldnn::memory::format_tag::oidhw;
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
   }
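
GetWeights now selects the MKL-DNN weight format tag from the weight rank and grouping; the 5-d case added here maps to oidhw/goidhw. The full mapping, sketched as a lookup in Python:

    # (ndim, grouped) -> MKL-DNN weight format tag
    format_tags = {
        (2, False): 'oi',
        (3, False): 'oiw',    (3, True): 'goiw',
        (4, False): 'oihw',   (4, True): 'goihw',
        (5, False): 'oidhw',  (5, True): 'goidhw',  # added by this patch
    }
    print(format_tags[(5, True)])  # goidhw
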
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index ada42a2..42cbb72 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -37,11 +37,13 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam);
 
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
   if ((params.kernel.ndim() != 1) &&
-      (params.kernel.ndim() != 2))
+      (params.kernel.ndim() != 2) &&
+      (params.kernel.ndim() != 3))
     return false;
   return SupportMKLDNNQuantize(input.dtype()) &&
          ((input.shape().ndim() == 3) ||
-          (input.shape().ndim() == 4));
+          (input.shape().ndim() == 4) ||
+          (input.shape().ndim() == 5));
 }
 
 std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
@@ -77,9 +79,19 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
     strides[1] = param.conv_param.stride[1];
     padding[0] = param.conv_param.pad[0];
     padding[1] = param.conv_param.pad[1];
+  } else if (param.conv_param.kernel.ndim() == 3) {
+    CHECK_GE(param.conv_param.stride.ndim(), 3);
+    CHECK_GE(param.conv_param.pad.ndim(), 3);
+    CHECK_GE(param.conv_param.dilate.ndim(), 3);
+    strides[0] = param.conv_param.stride[0];
+    strides[1] = param.conv_param.stride[1];
+    strides[2] = param.conv_param.stride[2];
+    padding[0] = param.conv_param.pad[0];
+    padding[1] = param.conv_param.pad[1];
+    padding[2] = param.conv_param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size "
-               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2.";
+               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2 or 3.";
   }
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
@@ -141,9 +153,13 @@ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> GetConvFwdImpl(
     } else if (param.conv_param.dilate.ndim() == 2) {
       dilates[0] = param.conv_param.dilate[0] - 1;
       dilates[1] = param.conv_param.dilate[1] - 1;
+    } else if (param.conv_param.dilate.ndim() == 3) {
+      dilates[0] = param.conv_param.dilate[0] - 1;
+      dilates[1] = param.conv_param.dilate[1] - 1;
+      dilates[2] = param.conv_param.dilate[2] - 1;
     } else {
       LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << 
param.conv_param.dilate.ndim()
-                 << ", supporting only 1 or 2.";
+                 << ", supporting only 1 or 2 or 3.";
     }
     if (bias_md_ptr == nullptr) {
      mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
@@ -181,9 +197,19 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   auto GetConvBwdDataPd = [&data, &weight, &output,
@@ -216,9 +242,13 @@ static std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> GetCon
     } else if (param.dilate.ndim() == 2) {
       dilates[0] = param.dilate[0] - 1;
       dilates[1] = param.dilate[1] - 1;
+    } else if (param.dilate.ndim() == 3) {
+      dilates[0] = param.dilate[0] - 1;
+      dilates[1] = param.dilate[1] - 1;
+      dilates[2] = param.dilate[2] - 1;
     } else {
       LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-                 << param.dilate.ndim() << ", supporting only 1 or 2.";
+                 << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
     }
    mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct, data_md,
                                                 weight_md, out_md, strides, dilates, padding,
@@ -250,9 +280,19 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   auto GetConvBwdWeightsPd = [&data, &weight, &output,
@@ -291,9 +331,13 @@ static std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> Get
     } else if (param.dilate.ndim() == 2) {
       dilates[0] = param.dilate[0] - 1;
       dilates[1] = param.dilate[1] - 1;
+    } else if (param.dilate.ndim() == 3) {
+      dilates[0] = param.dilate[0] - 1;
+      dilates[1] = param.dilate[1] - 1;
+      dilates[2] = param.dilate[2] - 1;
     } else {
       LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-                 << param.dilate.ndim() << ", supporting only 1 or 2.";
+                 << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
     }
     if (bias == nullptr) {
      mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
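
Note the recurring `dilate - 1` above: MXNet treats a dilation of 1 as "no dilation", while MKL-DNN counts the extra gap between kernel taps, where 0 means none. A one-line sketch of the conversion:

    mxnet_dilate = (2, 2, 2)                             # MXNet convention
    mkldnn_dilates = tuple(d - 1 for d in mxnet_dilate)  # MKL-DNN convention
    assert mkldnn_dilates == (1, 1, 1)
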
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
index 22e9abd..9858ad2 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -38,17 +38,15 @@ class MKLDNNPoolingFwd {
  public:
   MKLDNNPoolingFwd(const mxnet::NDArray &input,
                    const mxnet::NDArray &output,
-                   const int kernel_h, const int kernel_w,
-                   const int stride_h, const int stride_w,
-                   const int padding_t, const int padding_b,
-                   const int padding_l, const int padding_r,
+                   const mkldnn::memory::dims &kernel,
+                   const mkldnn::memory::dims &strides,
+                   const mkldnn::memory::dims &pad_l,
+                   const mkldnn::memory::dims &pad_r,
                    const mkldnn::algorithm alg_kind,
                    const bool with_workspace, const bool is_train):
                    with_workspace_(with_workspace),
                    fwd_(nullptr) {
-    Init(input, output,
-         kernel_h, kernel_w, stride_h, stride_w,
-         padding_t, padding_b, padding_l, padding_r,
+    Init(input, output, kernel, strides, pad_l, pad_r,
          is_train, alg_kind);
   }
 
@@ -67,10 +65,10 @@ class MKLDNNPoolingFwd {
  private:
   void Init(const mxnet::NDArray &input,
             const mxnet::NDArray &output,
-            const int kernel_h, const int kernel_w,
-            const int stride_h, const int stride_w,
-            const int padding_t, const int padding_b,
-            const int padding_l, const int padding_r,
+            const mkldnn::memory::dims &kernel,
+            const mkldnn::memory::dims &strides,
+            const mkldnn::memory::dims &pad_l,
+            const mkldnn::memory::dims &pad_r,
             const bool is_train, const mkldnn::algorithm alg_kind);
 };
 
@@ -89,23 +87,56 @@ class MKLDNNPoolingBwd {
   const mkldnn::pooling_backward::primitive_desc &GetPd();
 };
 
+inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
+  if ((x + padl + padr - k) % s != 0) {
+    return (padr + s - ((x + padl + padr - k) % s));
+  } else {
+    return padr;
+  }
+}
+
 inline bool SupportMKLDNNPooling(const PoolingParam &param) {
-  return param.kernel.ndim() == 2 &&
+  return (param.kernel.ndim() == 1 || param.kernel.ndim() == 2 ||
+          param.kernel.ndim() == 3) &&
          (param.pool_type == pool_enum::kMaxPooling ||
           param.pool_type == pool_enum::kAvgPooling) &&
-         (!param.layout.has_value() || param.layout.value() == mshadow::kNCHW);
+         (!param.layout.has_value() ||
+         (param.layout.value() == mshadow::kNCW || param.layout.value() == mshadow::kNCHW ||
+          param.layout.value() == mshadow::kNCDHW));
 }
 
 inline bool SupportMKLDNNPooling(const PoolingParam &param,
-                                 const mxnet::TShape &dshape) {
-  bool ret = SupportMKLDNNPooling(param);
-  if (!ret)
+                                 const NDArray &input) {
+  const auto dshape = input.shape();
+  const auto ndim = dshape.ndim();
+  const auto dtype = input.dtype();
+
+  if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
+       (dtype == mshadow::kFloat32)))
+    return false;
+
+  if (!SupportMKLDNNPooling(param))
     return false;
 
   if (param.pooling_convention == pool_enum::kValid) {
     return true;
   } else {
-    // currently, only max-pooling is supported for full convention
+    if (param.pool_type == pool_enum::kAvgPooling) {
+      // mkldnn works differently when padding is asymmetric, so let's skip this case.
+      bool is_symmetric = true;
+      switch (ndim) {
+        case 5:
+          is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4],
+                                param.pad[2], param.pad[2], param.kernel[2], param.stride[2]));
+        case 4:
+          is_symmetric = is_symmetric && (param.pad[1] == GetPaddingSizeFull(dshape[3],
+                                param.pad[1], param.pad[1], param.kernel[1], param.stride[1]));
+        case 3:
+          is_symmetric = is_symmetric && (param.pad[0] == GetPaddingSizeFull(dshape[2],
+                                param.pad[0], param.pad[0], param.kernel[0], param.stride[0]));
+      }
+      return is_symmetric;
+    }
     return param.pool_type == pool_enum::kMaxPooling;
   }
 }
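
GetPaddingSizeFull, restated in Python for clarity: it grows the right-hand padding until (x + padl + padr - k) is divisible by the stride, which is what the 'full' pooling convention needs. When the grown value differs from the configured pad, the padding becomes asymmetric and the check above skips MKL-DNN for average pooling:

    def get_padding_size_full(x, padl, padr, k, s):
        rem = (x + padl + padr - k) % s
        return padr + s - rem if rem != 0 else padr

    # x=10, k=3, s=2, symmetric pad 0: the right pad grows to 1, i.e. becomes
    # asymmetric, so an avg pool under the 'full' convention falls back to the
    # native CPU implementation.
    print(get_padding_size_full(10, 0, 0, 3, 2))  # 1
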
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 6eda2aa..bb1a75e 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -31,19 +31,13 @@ namespace mxnet {
 namespace op {
 
 void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
-                            const int kernel_h,  const int kernel_w,
-                            const int stride_h,  const int stride_w,
-                            const int padding_t, const int padding_b,
-                            const int padding_l, const int padding_r,
+                            const mkldnn::memory::dims &kernel,
+                            const mkldnn::memory::dims &strides,
+                            const mkldnn::memory::dims &pad_l,
+                            const mkldnn::memory::dims &pad_r,
                             const bool is_train, const mkldnn::algorithm alg_kind) {
-  auto src_md = input.GetMKLDNNData()->get_desc();
-  mkldnn::memory::dims dims = {src_md.data.dims[0],
-                               src_md.data.dims[1],
-                               static_cast<int>(output.shape()[2]),
-                               static_cast<int>(output.shape()[3])};
-  auto dst_md = mkldnn::memory::desc({dims},
-                                     static_cast<mkldnn::memory::data_type>(src_md.data.data_type),
-                                     mkldnn::memory::format_tag::any);
+  const auto src_md = input.GetMKLDNNData()->get_desc();
+  const auto dst_md = GetMemDesc(output);
   const mkldnn::engine engine = CpuEngine::Get()->get_engine();
   if (alg_kind != mkldnn::algorithm::pooling_max &&
       alg_kind != mkldnn::algorithm::pooling_avg &&
@@ -60,11 +54,6 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o
     LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring";
   }
 
-  const mkldnn::memory::dims strides = {stride_h,  stride_w  };
-  const mkldnn::memory::dims pad_l   = {padding_t, padding_l };
-  const mkldnn::memory::dims pad_r   = {padding_b, padding_r };
-  const mkldnn::memory::dims kernel  = {kernel_h,  kernel_w  };
-  // mkldnn::pooling_forward::desc
   const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md,
                                                       strides, kernel, pad_l, pad_r);
   this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine));
@@ -127,52 +116,129 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
   }
 }
 
-static inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
-  if ((x + padl + padr - k) % s != 0) {
-    return (padr + s - ((x + padl + padr - k) % s));
-  } else {
-    return padr;
-  }
-}
+void InitPoolingPrimitiveParams(const PoolingParam &param,
+                                const mkldnn::memory::desc &data_md,
+                                const mkldnn::memory::dims &new_kernel,
+                                const mkldnn::memory::dims &new_strides,
+                                const mkldnn::memory::dims &new_pad_l,
+                                const mkldnn::memory::dims &new_pad_r) {
+  const int kernel_ndims = param.kernel.ndim();
+  mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
+  mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
+  mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
+  mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
+  if (kernel_ndims == 1) {
+    CHECK_GE(param.pad.ndim(), 1);
+    CHECK_GE(param.stride.ndim(), 1);
+    kernel[0] = param.kernel[0];
+    pad_l[0] = param.pad[0];
+    pad_r[0] = param.pad[0];
+    strides[0] = param.stride[0];
 
-mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
-    const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
-    const mkldnn::memory::desc &out_md) {
-  CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
-  int kernel_h_, kernel_w_;
-  if (param.global_pool) {
-    kernel_h_ = data_md.data.dims[2];
-    kernel_w_ = data_md.data.dims[3];
-  } else {
-    kernel_h_ = param.kernel[0];
-    kernel_w_ = param.kernel[1];
-  }
+    if (param.pooling_convention == pool_enum::kFull) {
+      pad_r[0] =
+        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
+    }
 
-  CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-  CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
+    if (param.global_pool) {
+      kernel[0] = data_md.data.dims[2];
+      strides[0] = 1;
+      pad_l[0] = pad_r[0] = 0;
+    }
 
-  int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
-  int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
-  int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
+  } else if (kernel_ndims == 2) {
+    CHECK_GE(param.pad.ndim(), 2);
+    CHECK_GE(param.stride.ndim(), 2);
+    kernel[0] = param.kernel[0];
+    kernel[1] = param.kernel[1];
+    pad_l[0] = param.pad[0];
+    pad_l[1] = param.pad[1];
+    pad_r[0] = param.pad[0];
+    pad_r[1] = param.pad[1];
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
 
-  if (param.pooling_convention == pool_enum::kFull) {
-    pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
-    pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
-  }
+    if (param.pooling_convention == pool_enum::kFull) {
+      pad_r[0] =
+        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
+      pad_r[1] =
+        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
+    }
 
-  const mkldnn::engine engine = CpuEngine::Get()->get_engine();
-  if (param.global_pool) {
-    pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-    stride_h_ = stride_w_ = 1;
+    if (param.global_pool) {
+      kernel[0] = data_md.data.dims[2];
+      kernel[1] = data_md.data.dims[3];
+      strides[0] = strides[1] = 1;
+      pad_l[0] = pad_l[1] = pad_r[0] = pad_r[1] = 0;
+    }
+
+    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
+    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
+  } else {
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.stride.ndim(), 3);
+    kernel[0] = param.kernel[0];
+    kernel[1] = param.kernel[1];
+    kernel[2] = param.kernel[2];
+    pad_l[0] = param.pad[0];
+    pad_l[1] = param.pad[1];
+    pad_l[2] = param.pad[2];
+    pad_r[0] = param.pad[0];
+    pad_r[1] = param.pad[1];
+    pad_r[2] = param.pad[2];
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+
+    if (param.pooling_convention == pool_enum::kFull) {
+      pad_r[0] =
+        GetPaddingSizeFull(data_md.data.dims[2], pad_l[0], pad_r[0], kernel[0], strides[0]);
+      pad_r[1] =
+        GetPaddingSizeFull(data_md.data.dims[3], pad_l[1], pad_r[1], kernel[1], strides[1]);
+      pad_r[2] =
+        GetPaddingSizeFull(data_md.data.dims[4], pad_l[2], pad_r[2], kernel[2], strides[2]);
+    }
+
+    if (param.global_pool) {
+      kernel[0] = data_md.data.dims[2];
+      kernel[1] = data_md.data.dims[3];
+      kernel[2] = data_md.data.dims[4];
+      strides[0] = strides[1] = strides[2] = 1;
+      pad_l[0] = pad_l[1] = pad_l[2] = pad_r[0] = pad_r[1] = pad_r[2] = 0;
+    }
+
+    CHECK_GT(kernel[0], 0) << "Filter dimensions cannot be zero.";
+    CHECK_GT(kernel[1], 0) << "Filter dimensions cannot be zero.";
+    CHECK_GT(kernel[2], 0) << "Filter dimensions cannot be zero.";
   }
 
-  if (pad_t_ != 0 || pad_l_ != 0) {
+  if (pad_l[0] != 0 || (kernel_ndims == 2 && pad_l[1] != 0) ||
+     (kernel_ndims == 3 && pad_l[2] != 0)) {
     CHECK(param.pool_type == pool_enum::kAvgPooling ||
           param.pool_type == pool_enum::kMaxPooling)
         << "Padding implemented only for average and max pooling.";
-    CHECK_LT(pad_l_, kernel_w_);
-    CHECK_LT(pad_t_, kernel_h_);
+    CHECK_LT(pad_l[0], kernel[0]);
+    if (kernel_ndims > 1)
+      CHECK_LT(pad_l[1], kernel[1]);
+    if (kernel_ndims > 2)
+      CHECK_LT(pad_l[2], kernel[2]);
   }
+}
+
+mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
+    const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
+    const mkldnn::memory::desc &out_md) {
+  CHECK(param.kernel.ndim() == 1 || param.kernel.ndim() == 2 || param.kernel.ndim() == 3)
+        << "Not Implemented";
+
+  const int kernel_ndims = param.kernel.ndim();
+  mkldnn::memory::dims kernel(kernel_ndims);
+  mkldnn::memory::dims strides(kernel_ndims);
+  mkldnn::memory::dims pad_l(kernel_ndims);
+  mkldnn::memory::dims pad_r(kernel_ndims);
+
+  InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);
 
   const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
   mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
@@ -180,15 +246,9 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
     kind = mkldnn::prop_kind::forward_training;
   }
 
-  const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md,
-                                              {static_cast<int>(stride_h_),
-                                               static_cast<int>(stride_w_)},
-                                              {kernel_h_, kernel_w_},
-                                              {static_cast<int>(pad_t_),
-                                               static_cast<int>(pad_l_)},
-                                              {static_cast<int>(pad_b_),
-                                               static_cast<int>(pad_r_)});
-  return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine);
+  const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, strides,
+                                                      kernel, pad_l, pad_r);
+  return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, CpuEngine::Get()->get_engine());
 }
 
 MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
@@ -214,45 +274,20 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
 
   auto it = pooling_fwds.find(key);
   if (it == pooling_fwds.end()) {
-    CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented";
+    CHECK(param.kernel.ndim() == 1 || param.kernel.ndim() == 2 || param.kernel.ndim() == 3)
+          << "Not Implemented";
     auto data_md = data.GetMKLDNNData()->get_desc();
-    int kernel_h_, kernel_w_;
-    if (param.global_pool) {
-      kernel_h_ = data_md.data.dims[2];
-      kernel_w_ = data_md.data.dims[3];
-    } else {
-      kernel_h_ = param.kernel[0];
-      kernel_w_ = param.kernel[1];
-    }
 
-    CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
-
-    int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
-    int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
-    int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
-
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
-      pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
-    }
-
-    if (param.global_pool) {
-      pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-      stride_h_ = stride_w_ = 1;
-    }
-
-    if (pad_t_ != 0 || pad_l_ != 0) {
-      CHECK(param.pool_type == pool_enum::kAvgPooling ||
-            param.pool_type == pool_enum::kMaxPooling)
-            << "Padding implemented only for average and max pooling.";
-      CHECK_LT(pad_l_, kernel_w_);
-      CHECK_LT(pad_t_, kernel_h_);
-    }
+    const auto kernel_ndims = param.kernel.ndim();
+    mkldnn::memory::dims kernel(kernel_ndims);
+    mkldnn::memory::dims strides(kernel_ndims);
+    mkldnn::memory::dims pad_l(kernel_ndims);
+    mkldnn::memory::dims pad_r(kernel_ndims);
+    InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);
 
     const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
-    MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_,
-                         pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train);
+    MKLDNNPoolingFwd fwd(data, output, kernel, strides,
+                         pad_l, pad_r, alg, with_workspace, is_train);
     it = AddToCache(&pooling_fwds, key, fwd);
   }
   return it->second;
@@ -304,50 +339,24 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
     auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
     auto input_mem = in_data.GetMKLDNNData();
     const mkldnn::memory::desc data_md = input_mem->get_desc();
-    const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1],
-                               static_cast<int>(out_grad.shape()[2]),
-                               static_cast<int>(out_grad.shape()[3])};
-    const mkldnn::memory::desc out_md(
-        {dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type),
-        mkldnn::memory::format_tag::any);
+    const mkldnn::memory::desc out_md = GetMemDesc(out_grad);
     auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
-    const mkldnn::memory::desc diff_md =
-        diff_dst_mem->get_desc();
-    const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1],
-                                static_cast<int>(in_grad.shape()[2]),
-                                static_cast<int>(in_grad.shape()[3])};
-    const mkldnn::memory::desc diff_in_md(
-        {dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type),
-        mkldnn::memory::format_tag::any);
-    const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();;
-    const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
+    const mkldnn::memory::desc diff_md = diff_dst_mem->get_desc();
 
-    int kernel_h_, kernel_w_;
-    if (param.global_pool) {
-      kernel_h_ = data_md.data.dims[2];
-      kernel_w_ = data_md.data.dims[3];
-    } else {
-      kernel_h_ = param.kernel[0];
-      kernel_w_ = param.kernel[1];
-    }
-
-    int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
-    int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
-    int stride_h_ = param.stride[0], stride_w_ = param.stride[1];
+    const mkldnn::memory::desc diff_in_md = GetMemDesc(in_grad);
+    const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();
+    const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
 
-    if (param.pooling_convention == pool_enum::kFull) {
-      pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
-      pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
-    }
+    const int kernel_ndims = param.kernel.ndim();
+    mkldnn::memory::dims kernel(kernel_ndims);
+    mkldnn::memory::dims strides(kernel_ndims);
+    mkldnn::memory::dims pad_l(kernel_ndims);
+    mkldnn::memory::dims pad_r(kernel_ndims);
 
-    if (param.global_pool) {
-      pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
-      stride_h_ = stride_w_ = 1;
-    }
+    InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);
 
     const mkldnn::pooling_backward::desc desc(
-        alg, diff_in_md, diff_md, {stride_h_, stride_w_},
-        {kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_});
+                alg, diff_in_md, diff_md, strides, kernel, pad_l, pad_r);
     const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
     it = AddToCache(&pooling_bwds, key, bwd);
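
For global pooling, InitPoolingPrimitiveParams collapses every branch above to the same primitive setup: the kernel spans the whole spatial extent, stride 1, zero padding. Sketched for the 3-d case (illustrative shape):

    data_shape = (2, 3, 8, 9, 10)    # NCDHW
    kernel = list(data_shape[2:])    # [8, 9, 10]
    strides = [1, 1, 1]
    pad_l = pad_r = [0, 0, 0]
    # each output spatial size is (x - k) / 1 + 1 == 1
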
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 485fc13..943cac0 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -278,9 +278,7 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     return;
   }
 
-
-  if (SupportMKLDNN(inputs[0]) &&
-      SupportMKLDNNPooling(param, inputs[0].shape())) {
+  if (SupportMKLDNNPooling(param, inputs[0])) {
     if (MKLDNNRequireWorkspace(param)) {
       CHECK_GT(outputs.size(), 1U);
       workspace = &outputs[1];
@@ -306,8 +304,7 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   }
 
 
-  if (SupportMKLDNN(inputs[0])
-      && SupportMKLDNNPooling(param, inputs[0].shape())) {
+  if (SupportMKLDNNPooling(param, inputs[0])) {
     const NDArray &out_grad = inputs[0];
     const NDArray *workspace = nullptr;
     const NDArray *in_data = nullptr;
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
index 190dfed..740c5f9 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_pooling.cc
@@ -35,8 +35,8 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op
                                           const std::vector<OpReqType> &req,
                                           const std::vector<NDArray> &out_data) {
   CHECK(in_data[0].dtype() == mshadow::kUint8
-    || in_data[0].dtype() == mshadow::kInt8)
-    << "mkldnn_quantized_pooling op only supports uint8 and int8 as input 
type";
+        || in_data[0].dtype() == mshadow::kInt8)
+        << "mkldnn_quantized_pooling op only supports uint8 and int8 as input 
type";
   const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
   MKLDNNPoolingCompute(ctx, param, in_data[0], req[0], out_data[0], nullptr);
   out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc
index 9d774dd..412e315 100644
--- a/src/operator/quantization/quantized_conv.cc
+++ b/src/operator/quantization/quantized_conv.cc
@@ -40,27 +40,88 @@ bool QuantizedConvShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), param.no_bias? 6U : 9U);
   CHECK_EQ(out_shape->size(), 3U);
   if (param.layout.has_value()) {
+#if MXNET_USE_MKLDNN == 1
+    CHECK(param.layout.value() == mshadow::kNCHW || param.layout.value() == mshadow::kNCDHW)
+          << "mkldnn quantized_conv now supports NCHW or NCDHW for now";
+#else
     CHECK_EQ(param.layout.value(), mshadow::kNCHW) << "quantized_conv only supports NCHW for now";
+#endif
   }
-  CHECK_EQ(param.kernel.ndim(), 2U) << "quantized_conv only supports 2D convolution for now";
-  CHECK(param.dilate.ndim() == 0U || param.dilate.Size() == 1U)
-    << "quantized_conv only supports dilation=1 for all dimensions";
+
   const mxnet::TShape& dshape =  in_shape->at(0);
-  CHECK_EQ(dshape.ndim(), 4U);
-  if (dshape.ndim() == 0U) return false;
+  const int data_ndims = dshape.ndim();
+  const int kernel_ndims = param.kernel.ndim();
+  if (data_ndims == 0U) return false;
 
-  const int N = 0, H = 2, W = 3, C = 1;
-  CHECK_EQ(dshape[C] % 4,  0U)
+#if MXNET_USE_MKLDNN == 1
+  CHECK(kernel_ndims == 2U || kernel_ndims == 3U)
+        << "mkldnn quantized_conv only supports 2d or 3d kernel for now";
+  CHECK(data_ndims == 4U || data_ndims == 5U)
+        << "mkldnn quantized_conv only supports 4d or 5d layout for now";
+#else
+  CHECK_EQ(kernel_ndims, 2U) << "quantized_conv only supports 2D convolution for now";
+  CHECK(param.dilate.ndim() == 0U || param.dilate.Size() == 1U)
+        << "quantized_conv only supports dilation=1 for all dimensions";
+  CHECK_EQ(data_ndims, 4U);
+  CHECK_EQ(dshape[1] % 4,  0U)
     << "for 8bit cudnn conv, the number of channel must be multiple of 4";
   CHECK_EQ(param.num_filter % 4, 0U)
     << "for 8bit cudnn conv, the number of channel must be multiple of 4";
+#endif
+
+  auto AddPad = [](index_t dsize, index_t pad) { return dsize + 2 * pad; };
+  const int D = (data_ndims == 5) ? 2 : 1;
+  const int N = 0, H = D + 1, W = D + 2, C = 1;
+
+if (data_ndims == 4) {
+    // conv 2d
+    mxnet::TShape wshape(data_ndims, 0);
+    wshape[N] = param.num_filter;
+    wshape[H] = param.kernel[0];
+    wshape[W] = param.kernel[1];
+    wshape[C] = dshape[C];
+    SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
+
+    mxnet::TShape oshape{1, 1, 1, 1};
+    oshape[N] = dshape[N];
+    oshape[C] = wshape[N];
+
+    const index_t dilated_ksize_y = param.DilatedKernelSize(0);
+    const index_t dilated_ksize_x = param.DilatedKernelSize(1);
+    oshape[H] = (AddPad(dshape[H], param.pad[0]) - dilated_ksize_y) / param.stride[0] + 1;
+    oshape[W] = (AddPad(dshape[W], param.pad[1]) - dilated_ksize_x) / param.stride[1] + 1;
+
+    SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
+    SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
+    SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
+#if MXNET_USE_MKLDNN == 1
+  } else {
+    // conv 3d
+    mxnet::TShape wshape(data_ndims, 0);
+    wshape[N] = param.num_filter;
+    wshape[D] = param.kernel[0];
+    wshape[H] = param.kernel[1];
+    wshape[W] = param.kernel[2];
+    wshape[C] = dshape[C];
+    SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
+
+    mxnet::TShape oshape{1, 1, 1, 1, 1};
+    oshape[N] = dshape[N];
+    oshape[C] = wshape[N];
+
+    const index_t dilated_ksize_d = param.DilatedKernelSize(0);
+    const index_t dilated_ksize_y = param.DilatedKernelSize(1);
+    const index_t dilated_ksize_x = param.DilatedKernelSize(2);
+    oshape[D] = (AddPad(dshape[D], param.pad[0]) - dilated_ksize_d) / param.stride[0] + 1;
+    oshape[H] = (AddPad(dshape[H], param.pad[1]) - dilated_ksize_y) / param.stride[1] + 1;
+    oshape[W] = (AddPad(dshape[W], param.pad[2]) - dilated_ksize_x) / param.stride[2] + 1;
+
+    SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
+    SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
+    SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
+#endif
+  }
 
-  mxnet::TShape wshape{0, 0, 0, 0};
-  wshape[N] = param.num_filter;
-  wshape[H] = param.kernel[0];
-  wshape[W] = param.kernel[1];
-  wshape[C] = dshape[C];
-  SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
   const int start = param.no_bias? 2 : 3;
   const int end = param.no_bias? 6 : 9;
   for (int i = start; i < end; ++i) {
@@ -70,16 +131,6 @@ bool QuantizedConvShape(const nnvm::NodeAttrs& attrs,
     SHAPE_ASSIGN_CHECK(*in_shape, 2, Shape1(param.num_filter));
   }
 
-  auto AddPad = [](index_t dsize, index_t pad) { return dsize + 2 * pad; };
-  mxnet::TShape oshape{1, 1, 1, 1};
-  oshape[N] = dshape[N];
-  oshape[C] = wshape[N];
-  oshape[H] = (AddPad(dshape[H], param.pad[0]) - wshape[H]) / param.stride[0] + 1;
-  oshape[W] = (AddPad(dshape[W], param.pad[1]) - wshape[W]) / param.stride[1] + 1;
-
-  SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
-  SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
-  SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
   return true;
 }
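
The output shape computed above follows the usual convolution formula, oshape = (dshape + 2 * pad - dilated_ksize) / stride + 1 per spatial dimension. A small numeric check (illustrative values):

    d, pad, k, dilate, stride = 10, 1, 3, 1, 2
    dilated_ksize = dilate * (k - 1) + 1  # what param.DilatedKernelSize(i) returns
    out = (d + 2 * pad - dilated_ksize) // stride + 1
    print(out)  # 5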
 
diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc
index eeb2ac4..ce4a48c 100644
--- a/src/operator/quantization/quantized_pooling.cc
+++ b/src/operator/quantization/quantized_pooling.cc
@@ -37,47 +37,89 @@ bool QuantizedPoolingShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 3U);
   if (!shape_is_known(in_shape->at(0))) return false;
   const mxnet::TShape &dshape = (*in_shape)[0];
-  CHECK_EQ(dshape.ndim(), 4U)
-      << "quantized_pooling: Input data should be 4D in "
-      << "(batch, channel, y, x)";
-  int layout = param.GetLayout(dshape.ndim());
+
+  const int data_ndims = dshape.ndim();
+  const int kernel_ndims = param.kernel.ndim();
+  const int layout = param.GetLayout(data_ndims);
+
+#if MXNET_USE_MKLDNN == 1
+  CHECK(data_ndims == 4U || data_ndims == 5U)
+        << "MKL-DNN QuantizedPoolingOp only supports 4D/5D layout yet, input 
should be 4D in"
+        << "(batch, channel, y, x) or 5D in (batch, channel, d, y, x)";
+  CHECK(layout == mshadow::kNCHW || layout == mshadow::kNCDHW)
+        << "MKL-DNN QuantizedPoolingOp only supports NCHW/NCDHW layout for 
now, saw " << layout;
+  CHECK(kernel_ndims == 2U || kernel_ndims == 3U)
+        << "MKL-DNN QuantizedPoolingOp only supports 2D/3D pooling for now, 
saw" << kernel_ndims;
+#else
+  CHECK_EQ(data_ndims, 4U)
+           << "quantized_pooling: Input data should be 4D in "
+           << "(batch, channel, y, x)";
   CHECK_EQ(layout, mshadow::kNCHW)
-      << "QuantizedPoolingOp only supports NCHW layout for now, saw " << 
layout;
-  // NCHW layout
-  const int N = 0, H = 2, W = 3, C = 1;
-  mxnet::TShape oshape(4, -1);
-  CHECK_EQ(param.kernel.ndim(), 2) << "QuantizedPoolingOp only supports 2D pooling for now";
-  CHECK(param.kernel[0] <= dshape[H] + 2 * param.pad[0])
-      << "kernel size (" << param.kernel[0]
+           << "QuantizedPoolingOp only supports NCHW layout for now, saw " << 
layout;
+  CHECK_EQ(kernel_ndims, 2U)
+           << "QuantizedPoolingOp only supports 2D pooling for now";
+#endif
+
+  const int D = (data_ndims == 5) ? 2 : 1;
+  const int N = 0, H = D + 1, W = D + 2, C = 1;
+  mxnet::TShape oshape(data_ndims, -1);
+
+  int idx = 0;
+  if (kernel_ndims == 3) {
+    CHECK(param.kernel[idx] <= dshape[D] + 2 * param.pad[idx])
+          << "kernel size (" << param.kernel[0]
+          << ") exceeds input (" << dshape[D]
+          << " padded to " << (dshape[D] + 2 * param.pad[idx]) << ")";
+    ++idx;
+  }
+  CHECK(param.kernel[idx] <= dshape[H] + 2 * param.pad[idx])
+      << "kernel size (" << param.kernel[idx]
       << ") exceeds input (" << dshape[H]
-      << " padded to " << (dshape[H] + 2*param.pad[0]) << ")";
-  CHECK(param.kernel[1] <= dshape[W] + 2 * param.pad[1])
-      << "kernel size (" << param.kernel[1]
+      << " padded to " << (dshape[H] + 2 * param.pad[idx]) << ")";
+  ++idx;
+  CHECK(param.kernel[idx] <= dshape[W] + 2 * param.pad[idx])
+      << "kernel size (" << param.kernel[idx]
       << ") exceeds input (" << dshape[W]
-      << " padded to " << (dshape[W] + 2*param.pad[1]) << ")";
+      << " padded to " << (dshape[W] + 2 * param.pad[idx]) << ")";
+
+#define OUTPUT_SHAPE_VALID_ASSIGN(spatial_dim, idx)                                            \
+{                                                                                              \
+  oshape[spatial_dim] = 1 + (dshape[spatial_dim] + 2 * param.pad[idx] - param.kernel[idx]) /   \
+                            param.stride[idx];                                                 \
+}
+#define OUTPUT_SHAPE_FULL_ASSIGN(spatial_dim, idx)                                             \
+{                                                                                              \
+  oshape[spatial_dim] = 1 + static_cast<int>(std::ceil(                                        \
+                              static_cast<float>(dshape[spatial_dim] + 2 * param.pad[idx] -    \
+                            param.kernel[idx]) / param.stride[idx]));                          \
+}
 
   oshape[N] = dshape[N];
   oshape[C] = dshape[C];
   if (param.global_pool) {
+    if (data_ndims == 5)
+      oshape[D] = 1;
     oshape[H] = 1;
     oshape[W] = 1;
   } else {
     if (param.pooling_convention == pool_enum::kValid) {
-      oshape[H] = 1 +
-                  (dshape[H] + 2 * param.pad[0] - param.kernel[0]) /
-                      param.stride[0];
-      oshape[W] = 1 +
-                  (dshape[W] + 2 * param.pad[1] - param.kernel[1]) /
-                      param.stride[1];
+      int idx = 0;
+      if (data_ndims == 5) {
+        OUTPUT_SHAPE_VALID_ASSIGN(D, idx);
+        ++idx;
+      }
+      OUTPUT_SHAPE_VALID_ASSIGN(H, idx);
+      ++idx;
+      OUTPUT_SHAPE_VALID_ASSIGN(W, idx);
     } else {
-      oshape[H] = 1 + static_cast<int>(std::ceil(
-                          static_cast<float>(dshape[H] + 2 * param.pad[0] -
-                                             param.kernel[0]) /
-                          param.stride[0]));
-      oshape[W] = 1 + static_cast<int>(std::ceil(
-                          static_cast<float>(dshape[W] + 2 * param.pad[1] -
-                                             param.kernel[1]) /
-                          param.stride[1]));
+      int idx = 0;
+      if (data_ndims == 5) {
+        OUTPUT_SHAPE_FULL_ASSIGN(D, idx);
+        ++idx;
+      }
+      OUTPUT_SHAPE_FULL_ASSIGN(H, idx);
+      ++idx;
+      OUTPUT_SHAPE_FULL_ASSIGN(W, idx);
     }
   }
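
The two macros above differ only in rounding: the 'valid' convention truncates (integer division) while 'full' rounds up, so 'full' can emit one extra output element per dimension. In Python terms (illustrative values):

    import math

    x, pad, k, s = 10, 0, 3, 2
    valid = 1 + (x + 2 * pad - k) // s           # floor -> 4
    full = 1 + math.ceil((x + 2 * pad - k) / s)  # ceil  -> 5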
 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index e1f9174..b35337f 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -41,7 +41,6 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias,
                                  const NDArray &gamma, const NDArray &beta,
                                  const NDArray &mean, const NDArray &variance,
                                  const BatchNormParam *param) {
-  // TODO(Zhennan): Handle the case weight is not in dims 4.
   NDArray update_weight = NDArray(weight->storage_type(), weight->shape(),
                                   weight->ctx(), true, weight->dtype());
   NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(),
@@ -55,7 +54,8 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias,
   DType *update_weight_ptr = update_weight.data().dptr<DType>();
   DType *update_bias_ptr = update_bias.data().dptr<DType>();
   size_t channel = gamma.shape()[0];
-  size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3];
+  const auto wshape = weight->shape();
+  size_t offset = wshape.ProdShape(1, wshape.ndim());
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
   for (int c = 0; c < static_cast<int>(channel); ++c) {
     const DType *p1 = weight_ptr + c * offset;
@@ -732,8 +732,9 @@ nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) {
   auto const &param = nnvm::get<MKLDNNConvFusionParam>(attrs.parsed);
   nnvm::NodePtr node = nnvm::Node::Create();
   node->attrs.op = Op::Get("_sg_mkldnn_conv");
-  CHECK_EQ(param.full_conv_param.conv_param.kernel.ndim(), 2U)
-      << "Quantized Convolution of MKL-DNN only supports 2D kernel currently."
+  const int k_ndims = param.full_conv_param.conv_param.kernel.ndim();
+  CHECK(k_ndims == 2U || k_ndims == 3U)
+      << "Quantized Convolution of MKL-DNN supports 2D/3D kernel currently."
       <<  "Please exclude this layer from the quantized model.";
   node->attrs.name = "quantized_" + attrs.name;
   node->attrs.dict = attrs.dict;
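
The new `offset` is the number of weight elements per output channel, so the batch-norm folding loop above now works for any weight rank rather than only 4-d. A NumPy sketch (hypothetical 5-d weight):

    import numpy as np

    w = np.zeros((8, 4, 3, 3, 3))       # (O, I, D, H, W) conv3d weight
    offset = int(np.prod(w.shape[1:]))  # == wshape.ProdShape(1, wshape.ndim())
    print(offset)                       # 4 * 3 * 3 * 3 = 108
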
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
index ff6589e..a5bceb9 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h
@@ -65,7 +65,8 @@ class SgMKLDNNConvSelector : public SubgraphSelector {
   bool Select(const nnvm::Node& n, const std::shared_ptr<NodeAttr>& node_attr) override {
     if (n.op() && n.op()->name == "Convolution") {
       const auto &param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
-      if (param.kernel.ndim() == 2 && SupportMKLDNNAttr(node_attr)) {
+      if ((param.kernel.ndim() == 2 || param.kernel.ndim() == 3) &&
+           SupportMKLDNNAttr(node_attr)) {
         status_ = disable_all_ ? kSuccess : kStart;
         matched_list_.clear();
         matched_list_.push_back(&n);
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
index 4c8a7ab..05a407b 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
@@ -29,7 +29,8 @@ static inline bool SupportMKLDNNAttr(const std::shared_ptr<NodeAttr>& node_attr)
   if (node_attr) {
     int ndim = node_attr->ishape[0].ndim();
     return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
-           (node_attr->itype[0] == mshadow::kFloat32) && (ndim == 1 || ndim == 2 || ndim == 4);
+           (node_attr->itype[0] == mshadow::kFloat32) &&
+           (ndim == 1 || ndim == 2 || ndim == 4 || ndim == 5);
   } else {
     return true;
   }
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 8ae1db6..d3b5cf1 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -161,7 +161,7 @@ OpAttrs GetPoolingOp(int kernel, int dim, int stride, int pad) {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("Pooling");
   attrs.num_inputs = 1;
-  attrs.num_outputs = dim == 2 ? 2 : 1;
+  attrs.num_outputs = (dim == 2 || dim == 3) ? 2 : 1;
   attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)});
   attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)});
   attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)});
@@ -173,7 +173,7 @@ OpAttrs GetPoolingOp(int kernel, int dim, int stride, int pad) {
 OpAttrs GetPoolingBackwardsOp(int kernel, int dim, int stride, int pad) {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("_backward_Pooling");
-  attrs.num_inputs = dim == 2 ? 5 : 3;
+  attrs.num_inputs = (dim == 2 || dim == 3) ? 5 : 3;
   attrs.num_outputs = 1;
   attrs.attrs.dict.insert({"kernel", CreateShapeString(kernel, dim)});
   attrs.attrs.dict.insert({"stride", CreateShapeString(stride, dim)});
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index e43daf1..8f71499 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -315,15 +315,17 @@ def test_softmax():
 @with_seed()
 def test_pooling():
     def check_pooling_training(stype):
-        for shape in [(3, 3, 10), (3, 3, 20, 20)]:
+        for shape in [(3, 3, 10), (3, 3, 20, 20), (3, 3, 10, 20, 20)]:
             data_tmp = np.random.normal(-0.1, 0.1, size=shape)
             data = mx.symbol.Variable('data', stype=stype)
             in_location = [mx.nd.array(data_tmp).tostype(stype)]
 
             if np.array(shape).shape[0] == 3:
-                test = mx.symbol.Pooling(data=data, kernel=(3,), stride=(2), pool_type='avg')
+                test = mx.symbol.Pooling(data=data, kernel=(3), stride=(2), pool_type='avg')
             elif np.array(shape).shape[0] == 4:
                 test = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='avg')
+            elif np.array(shape).shape[0] == 5:
+                test = mx.symbol.Pooling(data=data, kernel=(3, 3, 3), stride=(2, 2, 2), pool_type='avg')
             else:
                 return 0
             check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
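
Taken in isolation, the new 5-d branch of this pooling test amounts to the following standalone sketch (the gradient check is omitted; shapes match the ones added above):

    import mxnet as mx
    import numpy as np

    shape = (3, 3, 10, 20, 20)                     # N, C, D, H, W
    data = mx.sym.Variable('data')
    pool3d = mx.sym.Pooling(data=data, kernel=(3, 3, 3), stride=(2, 2, 2),
                            pool_type='avg')
    exe = pool3d.simple_bind(ctx=mx.cpu(), data=shape, grad_req='null')
    exe.arg_dict['data'][:] = mx.nd.array(np.random.normal(-0.1, 0.1, size=shape))
    print(exe.forward()[0].shape)                  # (3, 3, 4, 9, 9)
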
@@ -357,7 +359,7 @@ def test_activation():
 @with_seed()
 def test_convolution():
     def check_convolution_training(stype):
-        for shape in [(3, 3, 10), (3, 3, 10, 10)]:
+        for shape in [(3, 3, 10), (3, 3, 10, 10), (3, 3, 10, 10, 10)]:
             data_tmp = np.random.normal(-0.1, 1, size=shape)
             data = mx.symbol.Variable('data', stype=stype)
 
@@ -367,6 +369,9 @@ def test_convolution():
             elif np.array(shape).shape[0] == 4:
                 test = mx.symbol.Convolution(data=data, kernel=(3, 3), stride=(2, 2), num_filter=4)
                 weight_tmp = np.random.normal(-0.1, 0.1, size=(4, 3, 3, 3))
+            elif np.array(shape).shape[0] == 5:
+                test = mx.symbol.Convolution(data=data, kernel=(3, 3, 3), stride=(2, 2, 2), num_filter=4)
+                weight_tmp = np.random.normal(-0.1, 0.1, size=(4, 3, 3, 3, 3))
             else:
                 return 0
             bias_tmp = np.random.normal(0.1, 0.1, size=(4,))
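
Likewise for convolution: the added 5-d case binds a Conv3d whose weight has shape (num_filter, C, kD, kH, kW). A standalone sketch of just the forward pass:

    import mxnet as mx

    shape = (3, 3, 10, 10, 10)                     # N, C, D, H, W
    data = mx.sym.Variable('data')
    conv3d = mx.sym.Convolution(data=data, kernel=(3, 3, 3), stride=(2, 2, 2),
                                num_filter=4)
    exe = conv3d.simple_bind(ctx=mx.cpu(), data=shape, grad_req='null')
    print(exe.forward()[0].shape)                  # (3, 4, 4, 4, 4)
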
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index bbe3008..7804f6d 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -196,7 +196,7 @@ def test_requantize_int32_to_int8():
 
 @with_seed()
 def test_quantized_conv():
-    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, no_bias, qdtype):
+    def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no_bias, qdtype):
         if is_test_for_native_cpu():
            print('skipped testing quantized_conv for native cpu since it is not supported yet')
             return
@@ -206,14 +206,17 @@ def test_quantized_conv():
         elif qdtype == 'uint8' and is_test_for_gpu():
            print('skipped testing quantized_conv for gpu uint8 since it is not supported yet')
             return
+        elif is_test_for_gpu() and len(data_shape) != 4:
+            print('skipped testing quantized_conv for gpu 5d layout since it is not supported yet')
+            return
 
         # run fp32 conv
         data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
-        conv2d = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
-                                    no_bias=no_bias, cudnn_off=False, name='conv2d')
-        arg_shapes, _, _ = conv2d.infer_shape(data=data_shape)
-        arg_names = conv2d.list_arguments()
-        conv_exe_fp32 = conv2d.simple_bind(ctx=mx.current_context(), grad_req='null')
+        conv = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride,
+                                  dilate=dilate, no_bias=no_bias, cudnn_off=False, name='conv')
+        arg_shapes, _, _ = conv.infer_shape(data=data_shape)
+        arg_names = conv.list_arguments()
+        conv_exe_fp32 = conv.simple_bind(ctx=mx.current_context(), grad_req='null')
         if qdtype == 'uint8':
             data_low = 0.0
             data_high = 127.0
@@ -221,12 +224,12 @@ def test_quantized_conv():
             data_low = -127.0
             data_high = 127.0
        conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
-                                                                        shape=data_shape).astype('int32')
+                                                                       shape=data_shape).astype('int32')
        conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                        shape=arg_shapes[1]).astype('int32')
+                                                                       shape=arg_shapes[1]).astype('int32')
        if not no_bias:
            conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
-                                                                            shape=arg_shapes[2]).astype('int32')
+                                                                           shape=arg_shapes[2]).astype('int32')
         output = conv_exe_fp32.forward()[0]
 
         # run quantized conv
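
One detail worth calling out in the fp32 reference run above: inputs are drawn as integer-valued floats (uniform draws cast through 'int32' before landing in float32 arrays), so the fp32 and int8 paths compute over exactly representable values and can be compared tightly. Schematically (a sketch, not the test code):

    import mxnet as mx

    # Integer-valued data is exact in both float32 and int8, so quantizing
    # the inputs introduces no rounding error of its own.
    raw = mx.nd.random.uniform(low=-127.0, high=127.0, shape=(2, 3)).astype('int32')
    data_fp32 = raw.astype('float32')
    data_int8 = raw.astype('int8')
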
@@ -236,16 +239,16 @@ def test_quantized_conv():
         max_data = mx.sym.Variable(name='max_data')
         min_weight = mx.sym.Variable(name='min_weight')
         max_weight = mx.sym.Variable(name='max_weight')
-        quantized_conv2d = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
-                                                            max_data=max_data, min_weight=min_weight,
-                                                            max_weight=max_weight, kernel=kernel,
-                                                            num_filter=num_filter, pad=pad, stride=stride,
-                                                            no_bias=no_bias)
-        qarg_names = quantized_conv2d.list_arguments()
+        quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data,
+                                                       max_data=max_data, min_weight=min_weight,
+                                                       max_weight=max_weight, kernel=kernel,
+                                                       num_filter=num_filter, pad=pad, stride=stride,
+                                                       dilate=dilate, no_bias=no_bias)
+        qarg_names = quantized_conv.list_arguments()
         type_dict = None
         if not no_bias:
             type_dict = {qarg_names[2]: 'int8'}
-        conv_exe_int8 = quantized_conv2d.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
+        conv_exe_int8 = quantized_conv.simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null')
        conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype)
        conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8')
         quantized_range = 127.0
@@ -273,8 +276,12 @@ def test_quantized_conv():
             assert cond == 0
 
     for qdtype in ['int8', 'uint8']:
-        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), True, qdtype)
-        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), False, qdtype)
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), True, qdtype)
+        check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), True, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), False, qdtype)
+        check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), True, qdtype)
 
 
 @with_seed()
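
The driver now threads a dilate argument through both the fp32 and quantized symbols, and the new 5-d cases include dilation (2, 2, 2). The effective extent of a dilated kernel along each axis follows the usual formula, dilate * (kernel - 1) + 1. A small arithmetic sketch (helper names are mine, not from the patch):

    def dilated_ksize(kernel, dilate):
        # Effective span of a dilated kernel along each axis.
        return tuple(d * (k - 1) + 1 for k, d in zip(kernel, dilate))

    assert dilated_ksize((1, 3, 3), (2, 2, 2)) == (1, 5, 5)
    assert dilated_ksize((3, 3), (1, 1)) == (3, 3)   # dilate (1, 1) is a no-op
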
@@ -350,6 +357,9 @@ def test_quantized_pooling():
         elif qdtype == 'uint8' and is_test_for_gpu():
            print('skipped testing quantized_pooling for gpu uint8 since it is not supported yet')
             return
+        elif is_test_for_gpu() and len(data_shape) != 4:
+            print('skipped testing quantized_pooling for gpu 5d layout since it is not supported yet')
+            return
 
         data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
        pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
@@ -396,11 +406,19 @@ def test_quantized_pooling():
        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype)
        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype)
        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype)
+        check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), False, qdtype)
+        check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), True, qdtype)
+        check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), False, qdtype)
+        check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), True, qdtype)
 
        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'full')
        check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype, 'full')
        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'full')
        check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'full')
+        check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), False, qdtype, 'full')
+        check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), True, qdtype, 'full')
+        check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), False, qdtype, 'full')
+        check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), True, qdtype, 'full')
 
 
 @with_seed()
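
The added 5-d pooling cases keep a kernel of 1 along the depth axis, so they reduce H and W exactly like the existing 4-d cases. Per-axis output sizes follow MXNet's documented pooling conventions, floor for 'valid' and ceil for 'full'; a small sketch of that arithmetic, assuming those conventions (helper is mine, not from the test):

    import math

    def pool_out_dim(x, k, s, pad=0, convention='valid'):
        # One axis of the Pooling output shape.
        span = x + 2 * pad - k
        steps = math.floor(span / s) if convention == 'valid' else math.ceil(span / s)
        return steps + 1

    # The added case (3, 4, 3, 56, 56) with kernel (1, 3, 3), stride (1, 2, 2):
    assert [pool_out_dim(*t) for t in [(3, 1, 1), (56, 3, 2), (56, 3, 2)]] == [3, 27, 27]
    assert pool_out_dim(56, 3, 2, convention='full') == 28
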
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 9ae35f1..37f7376 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2600,6 +2600,10 @@ def test_convolution_dilated_impulse_response():
     for dil in [ (1,1), (2,2), (3,3) ]:
         for ks in [ (3,3), (4,4), (2,3), (3,2), (1,1) ]:
            test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks)
+    # 3D
+    for dil in [ (1,1,1), (2,2,2), (3,3,3) ]:
+        for ks in [ (3,3,3), (4,4,4), (2,3,4), (3,2,4), (1,1,1) ]:
+            test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks)
 
 
 @with_seed()
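
The impulse-response check is duplicated above for 3D kernels: convolving a unit impulse recovers the dilated kernel, whose taps sit on a lattice with spacing equal to the dilation. A one-axis numpy illustration of what dilation does to a kernel (illustrative of the test's premise, not its code):

    import numpy as np

    def dilate_1d(kernel, dilate):
        # Insert (dilate - 1) zeros between neighbouring taps.
        out = np.zeros((len(kernel) - 1) * dilate + 1)
        out[::dilate] = kernel
        return out

    print(dilate_1d(np.array([1., 2., 3.]), 3))    # [1. 0. 0. 2. 0. 0. 3.]
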
