anirudh2290 closed pull request #11778: [MXNET-483] C++ tests for mkldnn convolution/deconvolution operator
URL: https://github.com/apache/incubator-mxnet/pull/11778
 
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 9cf1b71880a..985a9655b10 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -353,9 +353,18 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const std::vector<OpReqType> &req,
                                          const std::vector<NDArray> &out_data) {
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
-  NDArray weight = in_data[conv::kWeight];
+
+  auto data = in_data[conv::kData];
+  if (data.IsView() && data.IsMKLDNNData())
+    data = data.Reorder2Default();
+
+  auto weight = in_data[conv::kWeight];
+  if (weight.IsView() && weight.IsMKLDNNData())
+    weight = weight.Reorder2Default();
+
   bool no_bias = param.conv_param.no_bias && !param.mkldnn_param.with_bn;
-  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(
+
+  auto data_mem = data.GetMKLDNNDataReorder(
       fwd->fwd_pd.src_primitive_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
@@ -577,19 +586,32 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
   MKLDNNConvFullParam full_param;
   full_param.conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
   full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+
+  auto data = inputs[conv::kData + 1];
+  if (data.IsView() && data.IsMKLDNNData())
+    data = data.Reorder2Default();
+
+  auto weight = inputs[conv::kWeight + 1];
+  if (weight.IsView() && weight.IsMKLDNNData())
+    weight = weight.Reorder2Default();
+
+  const NDArray* bias = full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1];
+
+  auto out_grad = inputs[conv::kOut];
+  if (out_grad.IsView() && out_grad.IsMKLDNNData())
+    out_grad = out_grad.Reorder2Default();
+
   mkldnn::convolution_forward::primitive_desc fwd_pd = GetConvFwdImpl(
-      full_param, ctx.is_train, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-      full_param.conv_param.no_bias ? nullptr : &inputs[conv::kBias + 1],
-      inputs[conv::kOut]);
+      full_param, ctx.is_train, data, weight, bias, out_grad);
   const ConvolutionParam &param = full_param.conv_param;
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
-             inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
-  auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
+  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, data,
+      weight, bias, out_grad, fwd_pd);
+  auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
       convBwd.bwdData_pd.diff_dst_primitive_desc());
   if (req[conv::kData]) {
-    auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
+    auto weight_mem = GetWeights(weight,
         convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
     auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
         convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
@@ -598,14 +620,13 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight]) {
-    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
-             inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
-             inputs[conv::kOut], fwd_pd);
+    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, data,
+        weight, bias, out_grad, fwd_pd);
     if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
         convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
-      out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
+      out_grad_mem = out_grad.GetMKLDNNDataReorder(
           convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
+    auto data_mem = data.GetMKLDNNDataReorder(
         convBwdWeight.bwdWeights_pd.src_primitive_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(
         in_grad[conv::kWeight],
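
A note on the recurring guard in the convolution changes above: every input that is both a view and held in an MKLDNN-specific layout is reordered back to the default layout before any primitive is built, presumably because a view cannot be handed to MKLDNN primitives while its parent's internal format is in effect. A minimal sketch of the pattern as a standalone helper (the name EnsureDefaultIfView is hypothetical; IsView(), IsMKLDNNData(), and Reorder2Default() are the NDArray methods the patch uses):

    // Hypothetical helper: return a default-layout copy when the array is a
    // view over MKLDNN-formatted data; otherwise pass it through unchanged.
    // This is the guard applied to data, weight, and out_grad above.
    inline NDArray EnsureDefaultIfView(const NDArray &arr) {
      if (arr.IsView() && arr.IsMKLDNNData())
        return arr.Reorder2Default();
      return arr;
    }

    // Usage mirroring the patched forward pass:
    //   NDArray data   = EnsureDefaultIfView(in_data[conv::kData]);
    //   NDArray weight = EnsureDefaultIfView(in_data[conv::kWeight]);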
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 93032f7c92d..577fae0d716 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -212,7 +212,8 @@ class MKLDNNDeconvForward {
                       const NDArray &output);
   void SetDataHandle(const DeconvolutionParam& param,
                      const OpContext &ctx,
-                     const std::vector<NDArray> &in_data,
+                     const NDArray &in_data,
+                     const NDArray &weight,
                      const std::vector<OpReqType> &req,
                      const std::vector<NDArray> &out_data);
 
@@ -243,32 +244,30 @@ MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,
 
 void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
                                         const OpContext &ctx,
-                                        const std::vector<NDArray> &in_data,
+                                        const NDArray &in_data,
+                                        const NDArray &weight,
                                         const std::vector<OpReqType> &req,
                                         const std::vector<NDArray> &out_data) {
-  auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
+  auto data_mem = in_data.GetMKLDNNDataReorder(
       fwd_pd.diff_dst_primitive_desc());
-  NDArray weight = in_data[deconv::kWeight];
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
     if (weight.IsMKLDNNData())
       // This asks the engine to reorder data after the weight array is used.
-      weight.Reorder2DefaultAsync();
+      const_cast<NDArray&>(weight).Reorder2DefaultAsync();
     weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
   } else {
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
       // We also need to modify the layout on the original weight array. The
       // data conversion happens after the weight array is used.
-      weight.MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
-    } else {
-      weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
+      const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
     }
+    weight_mem = weight.GetMKLDNNData();
+    CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
   }
   auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
       fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
@@ -287,19 +286,19 @@ void MKLDNNDeconvForward::Execute(const 
std::vector<NDArray> &out_data) {
 
 static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
                                            const OpContext &ctx,
-                                           const std::vector<NDArray> &in_data,
+                                           const NDArray &bias,
                                            const std::vector<NDArray> &out_data) {
   // add bias, broadcast bias to dim 1: channel
   if (!param.no_bias) {
     // MKLDNN only supports float right now.
     typedef float DType;
     Stream<cpu> *s = ctx.get_stream<cpu>();
-    Tensor<cpu, 1, DType> bias = in_data[deconv::kBias].data().get<cpu, 1, DType>(s);
+    Tensor<cpu, 1, DType> b = bias.data().get<cpu, 1, DType>(s);
     // If the output data is stored in a special MKLDNN format, data()
     // automatically converts its format to the default format.
     // Unfortunately, MKLDNN doesn't support broadcast.
     Tensor<cpu, 4, DType> out_cpu = out_data[deconv::kOut].data().get<cpu, 4, DType>(s);
-    out_cpu += mshadow::expr::broadcast<1>(bias, out_cpu.shape_);
+    out_cpu += mshadow::expr::broadcast<1>(b, out_cpu.shape_);
   }
 }
 
@@ -344,15 +343,24 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
   TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
   const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
 
+  auto data = in_data[deconv::kData];
+  if (data.IsView() && data.IsMKLDNNData())
+    data = data.Reorder2Default();
+
+  auto weight = in_data[deconv::kWeight];
+  if (weight.IsView() && weight.IsMKLDNNData())
+    weight = weight.Reorder2Default();
+
+  const NDArray* bias = param.no_bias ? nullptr : &in_data[deconv::kBias];
+
   MKLDNNDeconvForward &deconvFwd = GetDeconvFwd(
-      attrs, in_data[deconv::kData], in_data[deconv::kWeight],
-      param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]);
+      attrs, data, weight, bias, out_data[deconv::kOut]);
 
-  deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data);
+  deconvFwd.SetDataHandle(param, ctx, data, weight, req, out_data);
 
   deconvFwd.Execute(out_data);
 
-  MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
+  MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data);
 }
 
 class MKLDNNDeconvBackwardData {
@@ -506,17 +514,24 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
   TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
   const std::vector<NDArray> &in_grad = outputs;
   const DeconvolutionParam &param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+
+  auto data = inputs[deconv::kData + 1];
+  if (data.IsView() && data.IsMKLDNNData())
+    data = data.Reorder2Default();
+
+  auto weight = inputs[deconv::kWeight + 1];
+  if (weight.IsView() && weight.IsMKLDNNData())
+    weight = weight.Reorder2Default();
+
   CHECK_NE(req[deconv::kWeight], kWriteInplace)
       << "cannot write weight inplace";
   MKLDNNDeconvBackwardData &bwd_data =
-      GetDeconvBwdData(param, inputs[deconv::kData + 1],
-                       inputs[deconv::kWeight + 1], inputs[deconv::kOut]);
+      GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]);
   auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
       bwd_data.pd.src_primitive_desc());
   if (req[deconv::kData]) {
     auto weight_mem =
-        GetWeights(inputs[deconv::kWeight + 1],
-                   bwd_data.pd.weights_primitive_desc(), param.num_group);
+        GetWeights(weight, bwd_data.pd.weights_primitive_desc(), param.num_group);
     auto in_grad_mem =
         CreateMKLDNNMem(in_grad[deconv::kData],
                         bwd_data.pd.dst_primitive_desc(), req[deconv::kData]);
@@ -526,12 +541,12 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs &attrs,
   }
   if (req[deconv::kWeight]) {
     MKLDNNDeconvBackwardWeights &bwd_weights = GetDeconvBwdWeights(
-        param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1],
+        param, data, weight,
         inputs[deconv::kOut], bwd_data.pd);
     if (bwd_data.pd.src_primitive_desc() != bwd_weights.pd.src_primitive_desc())
       out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
           bwd_weights.pd.src_primitive_desc());
-    auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
+    auto data_mem = data.GetMKLDNNDataReorder(
         bwd_weights.pd.diff_dst_primitive_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(
         in_grad[deconv::kWeight], bwd_weights.pd.diff_weights_primitive_desc(),
diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index 31fe5c7d7bb..c4218490068 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -159,7 +159,7 @@ inline static std::vector<mkldnn::memory::format> GetMKLDNNFormat(size_t num_dim
   }
 }
 
-inline static TestArrayShapes GetTestArrayShapes() {
-inline static TestArrayShapes GetTestArrayShapes() {
+inline static TestArrayShapes GetTestArrayShapes(bool spatial_data_format = false) {
   int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
   std::vector<TShape> shapes;
   std::vector<mkldnn::memory::primitive_desc> pds;
@@ -198,7 +198,9 @@ inline static TestArrayShapes GetTestArrayShapes() {
     pds.push_back(GetMemPD(s2, dtype, mkldnn::memory::format::oihw));
 
     std::vector<mkldnn::memory::format> formats = GetMKLDNNFormat(4, dtype);
-    pds.push_back(GetMemPD(s1, dtype, formats[0]));
+    if (!spatial_data_format) {
+      pds.push_back(GetMemPD(s1, dtype, formats[0]));
+    }
   }
   {
     // 5D
@@ -208,7 +210,9 @@ inline static TestArrayShapes GetTestArrayShapes() {
     pds.push_back(GetMemPD(s, dtype, mkldnn::memory::format::goihw));
 
     std::vector<mkldnn::memory::format> formats = GetMKLDNNFormat(5, dtype);
-    pds.push_back(GetMemPD(s, dtype, formats[0]));
+    if (!spatial_data_format) {
+      pds.push_back(GetMemPD(s, dtype, formats[0]));
+    }
   }
 
   TestArrayShapes ret;
@@ -250,6 +254,38 @@ enum ArrayTypes {
   All = 8191,
 };
 
+
+inline NDArray CreateKernelNDArray(TShape kernel, int num_filters, TShape input,
+    bool is_deconv = false) {
+  CHECK_EQ(kernel.ndim(), 2) << "mkldnn only supports 2d filters on 4d inputs";
+  TShape target_shape(4);
+  target_shape[0] = is_deconv ? input[1] : num_filters;
+  target_shape[1] = is_deconv ? num_filters : input[1];
+  target_shape[2] = kernel[0];
+  target_shape[3] = kernel[1];
+  int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
+  NDArray arr(target_shape, Context());
+  auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::nchw);
+  InitMKLDNNArray(&arr, pd);
+  return arr;
+}
+
+inline NDArray CreateBiasNDArray(TShape target_shape) {
+  int dtype = mshadow::DataType<mshadow::default_real_t>::kFlag;
+  NDArray arr(target_shape, Context());
+  auto pd = GetMemPD(target_shape, dtype, mkldnn::memory::format::x);
+  InitMKLDNNArray(&arr, pd);
+  return arr;
+}
+
+inline int CalculateWidthConvOutput(int width, int kernel, int padding, int stride) {
+  return (width - kernel + 2 * padding) / stride  + 1;
+}
+
+inline int CalculateWidthDeconvOutput(int width, int kernel, int padding, int stride) {
+  return stride * (width - 1) + kernel - 2 * padding;
+}
+
 inline std::string CreateShapeString(int value, int dim) {
   std::stringstream ss;
   ss << "(";
@@ -293,21 +329,21 @@ inline void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
  */
 inline std::vector<NDArrayAttrs> GetTestInputArrays(
     int types = ArrayTypes::All, bool rand = false,
-    int num_inputs = 1, int dim = 0) {
-  TestArrayShapes tas = GetTestArrayShapes();
+    std::vector<float> scale = {1}, bool spatial_data_format = false) {
+  TestArrayShapes tas = GetTestArrayShapes(spatial_data_format);
   std::vector<nnvm::TShape> shapes = tas.shapes;
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
 
   std::vector<NDArrayAttrs> in_arrs;
   std::string desc;
 
-  int slice_amount = 1;
-  if (dim == 0)
-    slice_amount = num_inputs;
+  int slice_amount = scale[0];
   for (auto shape : shapes) {
-    if (dim >= shape.ndim())
+    if (scale.size() > shape.ndim())
       continue;
-    shape[dim] = shape[dim] * num_inputs;
+
+    for (size_t dim = 0; dim < scale.size(); ++dim)
+      shape[dim] = static_cast<int>(round(shape[dim] * scale[dim]));
 
     // Type 1.
     NDArray arr(shape, Context());
@@ -326,12 +362,12 @@ inline std::vector<NDArrayAttrs> GetTestInputArrays(
 
 
     for (auto pd : pds) {
-      if (num_inputs > 1) {
+      for (size_t dim = 0; dim < scale.size(); ++dim) {
         // preserve if matching layout else just expand on 0 dim
         if (shape.ndim() == pd.desc().data.ndims)
-          pd = GetExpandedMemPD(pd, num_inputs, dim);
+          pd = GetExpandedMemPD(pd, scale[dim], dim);
         else
-          pd = GetExpandedMemPD(pd, num_inputs);
+          pd = GetExpandedMemPD(pd, scale[dim]);
       }
 
       if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
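
The two width helpers added to test_mkldnn.h above encode the standard convolution output-size arithmetic and its transpose. A quick sanity check with hypothetical values (width 8, kernel 3, padding 1, stride 1) shows the two acting as inverses:

    // CalculateWidthConvOutput:   (8 - 3 + 2*1) / 1 + 1 = 8
    // CalculateWidthDeconvOutput: 1 * (8 - 1) + 3 - 2*1 = 8
    assert(CalculateWidthConvOutput(8, 3, 1, 1) == 8);
    assert(CalculateWidthDeconvOutput(8, 3, 1, 1) == 8);

The round trip is only exact when the stride divides (width - kernel + 2 * padding); otherwise the integer division in the conv formula discards a remainder that the deconv formula cannot recover.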
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 9e30cd8fa62..61fa1b0a743 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -35,6 +35,8 @@
 #include "../../src/operator/nn/mkldnn/mkldnn_ops-inl.h"
 #include "../../src/operator/nn/mkldnn/mkldnn_pooling-inl.h"
 #include "../../src/operator/nn/pooling-inl.h"
+#include "../../src/operator/nn/convolution-inl.h"
+#include "../../src/operator/nn/deconvolution-inl.h"
 #include "../include/test_mkldnn.h"
 
 using namespace mxnet;
@@ -243,6 +245,88 @@ OpAttrs GetFullyConnectedBackwardsOp() {
   return attrs;
 }
 
+OpAttrs GetConvOp(int kernel, int num_filters, int dim, int stride, int pad) {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("Convolution");
+  attrs.num_inputs = 3;
+  attrs.num_outputs = 1;
+  attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)});
+  attrs.attrs.dict.insert({"num_filter" , std::to_string(num_filters)});
+  attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)});
+  attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.input_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped |
+      ArrayTypes::NormalReused |
+      ArrayTypes::MKLDNNReused |
+      ArrayTypes::NormalReshapedReused;
+  attrs.output_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped |
+      ArrayTypes::NormalReused |
+      ArrayTypes::MKLDNNReused |
+      ArrayTypes::NormalReshapedReused |
+      ArrayTypes::NormalReusedDiffDtype;
+  return attrs;
+}
+
+OpAttrs GetConvBackwardOp(int kernel, int num_filters, int dim, int stride, int pad) {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("_backward_Convolution");
+  attrs.num_inputs = 4;
+  attrs.num_outputs = 3;
+  attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)});
+  attrs.attrs.dict.insert({"num_filter" , std::to_string(num_filters)});
+  attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)});
+  attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  return attrs;
+}
+
+OpAttrs GetDeconvOp(int kernel, int num_filters, int dim, int stride, int pad) {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("Deconvolution");
+  attrs.num_inputs = 2;
+  attrs.num_outputs = 1;
+  attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)});
+  attrs.attrs.dict.insert({"num_filter" , std::to_string(num_filters)});
+  attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)});
+  attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.input_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped |
+      ArrayTypes::NormalReused |
+      ArrayTypes::MKLDNNReused |
+      ArrayTypes::NormalReshapedReused;
+  attrs.output_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped |
+      ArrayTypes::MKLDNNReshaped |
+      ArrayTypes::NormalReused |
+      ArrayTypes::MKLDNNReused |
+      ArrayTypes::NormalReshapedReused |
+      ArrayTypes::NormalReusedDiffDtype;
+  return attrs;
+}
+
+OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, int pad) {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("_backward_Deconvolution");
+  attrs.num_inputs = 3;
+  attrs.num_outputs = 2;
+  attrs.attrs.dict.insert({"kernel" , CreateShapeString(kernel, dim)});
+  attrs.attrs.dict.insert({"num_filter" , std::to_string(num_filters)});
+  attrs.attrs.dict.insert({"stride" , CreateShapeString(stride, dim)});
+  attrs.attrs.dict.insert({"pad" , CreateShapeString(pad, dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  return attrs;
+}
+
 void AssertEqual(const std::vector<NDArray *> &in_arrs,
                  const std::vector<NDArray *> &out_arrs,
                  float rtol = 1e-5, float atol = 1e-8) {
@@ -459,7 +543,11 @@ void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn,
   if (backwards) {
     std::string str_dim = const_cast<OpAttrs&>(attrs).attrs.dict["dim"];
     int dim = std::stoi(str_dim);
-    in_arrs = GetTestInputArrays(ArrayTypes::All, false, attrs.num_outputs, dim);
+    std::vector<float> scale_vector(dim+1);
+    for (size_t i = 0; i < dim+1; ++i)
+      scale_vector[i] = 1;
+    scale_vector[dim] = attrs.num_outputs;
+    in_arrs = GetTestInputArrays(ArrayTypes::All, false, scale_vector);
   }
 
   for (auto &in_arr : in_arrs) {
@@ -706,6 +794,134 @@ void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards
   }
 }
 
+template<typename P>
+void TestConvOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs,
+                bool is_deconv = false) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+
+  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+
+
+  std::vector<OpReqType> req(forward_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+  std::vector<DispatchMode> dispatches = forward_attrs.dispatches;
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  P param;
+  param.Init(forward_attrs.attrs.dict);
+  TShape kernel = param.kernel;
+  TShape padding = param.pad;
+  TShape stride = param.stride;
+  int num_filter = param.num_filter;
+
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(
+      forward_attrs.input_types, true, {1}, true);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
+
+  for (size_t i1 = 0; i1 < in_arrs.size(); ++i1) {
+    auto in_arr = in_arrs[i1];
+
+    // can only conv only 4D inputs
+    TShape input_shape = in_arr.arr.shape();
+    if (input_shape.ndim() != kernel.ndim() + 2)
+      continue;
+
+    float scale = CalculateWidthConvOutput(input_shape[2], kernel[0], padding[0], stride[0])
+        / static_cast<float>(input_shape[2]);
+
+    if (is_deconv) {
+      scale = CalculateWidthDeconvOutput(input_shape[2], kernel[0], padding[0], stride[0])
+        / static_cast<float>(input_shape[2]);
+    }
+    std::vector<float> scale_vector(in_arr.arr.shape().ndim());
+    scale_vector[0] = 1;
+    scale_vector[1] = static_cast<float>(num_filter) / input_shape[1];
+    scale_vector[2] = scale;
+    scale_vector[3] = scale;
+
+    for (size_t i = 0; i < forward_attrs.num_outputs; ++i) {
+      out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds,
+                                        scale_vector, true, forward_attrs.output_types);
+      ex_out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds,
+                                           scale_vector, true, forward_attrs.output_types);
+    }
+    NDArray ndkernel = CreateKernelNDArray(kernel, num_filter, in_arr.arr.shape(), is_deconv);
+    TShape bias_shape = {num_filter};
+    NDArray ndbias = CreateBiasNDArray(bias_shape);
+    inputs[0] = &in_arr.arr;
+    inputs[1] = &ndkernel;
+
+    if (!param.no_bias) {
+      inputs[2] = &ndbias;
+    }
+
+    for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+      for (size_t i = 0; i < forward_attrs.num_outputs; ++i) {
+        req[i] = kWriteTo;
+        outputs[i] = &out_arrs[i][output_i].arr;
+        ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+      }
+      Imperative::Get()->set_is_training(true);
+
+      PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+      Imperative::Get()->InvokeOp(Context(), forward_attrs.attrs, inputs,
+                                  outputs, req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+      Imperative::Get()->InvokeOp(Context(), forward_attrs.attrs, inputs,
+                                  ex_outputs, req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+      Engine::Get()->WaitForAll();
+      VerifyCopyResult(outputs, ex_outputs);
+
+      // backwards test performed same time since output needed
+      backwards_input[0] = outputs[0];  // output grad
+      backwards_input[1] = inputs[0];  // input
+      backwards_input[2] = inputs[1];  // kernel
+
+      if (!param.no_bias) {
+        backwards_input[3] = inputs[2];  // bias
+      }
+
+      auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true, {1}, true)[i1];
+      NDArray tmp_kernel = CreateKernelNDArray(kernel, num_filter, in_arr.arr.shape(), is_deconv);
+      NDArray tmp_bias = CreateBiasNDArray(bias_shape);
+      backwards_outputs[0] = &tmp_output.arr;
+      backwards_outputs[1] = &tmp_kernel;
+      if (!param.no_bias) {
+        backwards_outputs[2] = &tmp_bias;
+      }
+
+      auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true, {1}, true)[i1];
+      NDArray tmp_kernel2 = CreateKernelNDArray(kernel, num_filter, in_arr.arr.shape(), is_deconv);
+      NDArray tmp_bias2 = CreateBiasNDArray(bias_shape);
+      backwards_ex_outputs[0] = &tmp_output2.arr;
+      backwards_ex_outputs[1] = &tmp_kernel2;
+      if (!param.no_bias) {
+        backwards_ex_outputs[2] = &tmp_bias2;
+      }
+
+      for (size_t i = 0; i < backwards_attrs.num_outputs; ++i)
+        back_req[i] = kWriteTo;
+
+      std::cout << "Backwards: ";
+      PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
+      Imperative::Get()->InvokeOp(
+          Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+          back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+      Imperative::Get()->InvokeOp(
+          Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+          back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+      Engine::Get()->WaitForAll();
+      VerifyCopyResult(backwards_outputs, backwards_ex_outputs);
+    }
+  }
+}
+
 void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   std::vector<NDArray*> inputs(forward_attrs.num_inputs);
   std::vector<NDArray*> outputs(forward_attrs.num_outputs);
@@ -888,4 +1104,38 @@ TEST(IMPERATIVE, PoolingOp) {
   }
 }
 
+TEST(IMPERATIVE, ConvOp) {
+  int dim = 2;  // MKLDNN conv only supports 2d kernels
+  for (size_t num_filters = 2; num_filters < 3; ++num_filters) {
+    for (size_t kernel = 1; kernel < 4; ++kernel) {
+      for (size_t stride = 1; stride < 3; ++stride) {
+        for (size_t pad = 0; pad < 2; ++pad) {
+          if (kernel / 2. < pad)
+            continue;
+          OpAttrs forward_attrs = GetConvOp(kernel, num_filters, dim, stride, pad);
+          OpAttrs backwards_attrs = GetConvBackwardOp(kernel, num_filters, dim, stride, pad);
+          TestConvOp<mxnet::op::ConvolutionParam>(forward_attrs, backwards_attrs);
+        }
+      }
+    }
+  }
+}
+
+TEST(IMPERATIVE, DeconvOp) {
+  int dim = 2;  // MKLDNN deconv only supports 2d kernels
+  for (size_t num_filters = 2; num_filters < 3; ++num_filters) {
+    for (size_t kernel = 1; kernel < 3; ++kernel) {
+      for (size_t stride = 1; stride < 3; ++stride) {
+        for (size_t pad = 0; pad < 2; ++pad) {
+          if (kernel / 2. < pad)
+            continue;
+          OpAttrs forward_attrs = GetDeconvOp(kernel, num_filters, dim, stride, pad);
+          OpAttrs backwards_attrs = GetDeconvBackwardOp(kernel, num_filters, dim, stride, pad);
+          TestConvOp<mxnet::op::DeconvolutionParam>(forward_attrs, backwards_attrs, true);
+        }
+      }
+    }
+  }
+}
+
 #endif
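
For reference, TestConvOp above sizes its expected outputs per axis. With a hypothetical NCHW input of shape (1, 3, 8, 8) under GetConvOp(3, 2, 2, 1, 1), i.e. kernel 3, num_filter 2, 2-D, stride 1, pad 1, the output shape is (1, 2, 8, 8), so the scale vector built in the loop works out to:

    // N unchanged; C scaled by num_filter / input channels; H and W by w_out / w_in.
    float scale = CalculateWidthConvOutput(8, 3, 1, 1) / 8.0f;  // (8-3+2)/1+1 = 8 -> 1.0
    std::vector<float> scale_vector = {1.0f,         // N
                                       2.0f / 3.0f,  // C: num_filter / 3
                                       scale,        // H
                                       scale};       // W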
diff --git a/tests/cpp/operator/mkldnn_test.cc b/tests/cpp/operator/mkldnn_test.cc
index 5e7e8d8b205..31e762f2172 100644
--- a/tests/cpp/operator/mkldnn_test.cc
+++ b/tests/cpp/operator/mkldnn_test.cc
@@ -351,8 +351,12 @@ TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
   auto in_arrs = GetTestInputArrays();
   for (int dim = 0; dim < 5; dim++) {
     for (int num_inputs = 2; num_inputs < 5; num_inputs++) {
+      std::vector<float> scale_vector(dim + 1);
+      for (size_t i = 0; i < dim + 1; ++i)
+        scale_vector[i] = 1;
+      scale_vector[dim] = num_inputs;
       std::vector<NDArrayAttrs> expanded_arrs = GetTestInputArrays(
-          ArrayTypes::All, false, num_inputs, dim);
+          ArrayTypes::All, false, scale_vector);
       int i = 0;
       for (auto &arr : in_arrs) {
         if (dim >= arr.arr.shape().ndim())
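
Both test files migrate GetTestInputArrays from the old (num_inputs, dim) pair to a per-dimension scale vector; the old behavior of expanding a single axis by num_inputs is recovered by passing a vector of ones that carries num_inputs at index dim, exactly as the concat tests above now do. A minimal sketch of the equivalence, assuming the signatures in this diff:

    // Old call: expand axis `dim` of every test shape by a factor of num_inputs.
    //   GetTestInputArrays(ArrayTypes::All, false, num_inputs, dim);
    // New equivalent: per-axis scale factors, ones everywhere except index `dim`.
    std::vector<float> scale_vector(dim + 1, 1.0f);
    scale_vector[dim] = num_inputs;
    auto expanded_arrs = GetTestInputArrays(ArrayTypes::All, false, scale_vector);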