azai91 closed pull request #13603: [WIP] Test/mkldnn batch norm op 2
URL: https://github.com/apache/incubator-mxnet/pull/13603
 
 
   

This is a pull request from a forked repository. Because GitHub hides the original diff for such a foreign pull request once it is closed, the diff is reproduced below for the sake of provenance:

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 6254a1e1866..e7108241257 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -380,12 +380,20 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 }
 
 #if MXNET_USE_MKLDNN == 1
-static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
+static inline bool SupportMKLDNNBN(const std::vector<NDArray> &inputs,
+    const BatchNormParam &param) {
+  TShape shape = inputs[0].shape();
+  bool params_valid = shape.ndim() == 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
       && shape[param.axis] % 8 == 0
       && !mxnet::op::batchnorm::disable_mkl;
+  bool inputs_valid = SupportMKLDNN(inputs[0]);
+  for (size_t i = 1; i < inputs.size(); i++) {
+    if (inputs[i].IsMKLDNNData()) {
+      inputs_valid = false;
+    }
+  }
+  return  params_valid && inputs_valid;
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
@@ -396,7 +404,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+  if (SupportMKLDNNBN(inputs, param) && inputs[0].IsMKLDNNData()) {
     std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
     std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
 
@@ -420,7 +428,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
 
   TShape shape = inputs[0].shape();
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param)
+  if (SupportMKLDNNBN(inputs, param)
       && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
     std::vector<NDArray> out_grad(1);
     std::vector<NDArray> out_data(3);
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 403baaa94ab..7638e8bcf52 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -216,14 +216,24 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
   const NDArray &out  = out_data[batchnorm::kOut];
 
+  auto gamma_buffer = in_data[batchnorm::kGamma];
+  if (gamma_buffer.IsMKLDNNData()) {
+    gamma_buffer = gamma_buffer.Reorder2Default();
+  }
+
+  auto beta_buffer = in_data[batchnorm::kBeta];
+  if (beta_buffer.IsMKLDNNData()) {
+    beta_buffer = beta_buffer.Reorder2Default();
+  }
+
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
   if (flags & use_scale_shift) {
-    const NDArray &gamma    = in_data[batchnorm::kGamma];
-    const NDArray &beta     = in_data[batchnorm::kBeta];
+    const NDArray &gamma    = gamma_buffer;
+    const NDArray &beta     = beta_buffer;
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
     CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
 
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index a500d4c2df6..92b78a7367e 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
   return attrs;
 }
 
+OpAttrs GetBNOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("BatchNorm");
+  attrs.num_inputs = 5;
+  attrs.num_outputs = 3;
+  attrs.accept_dims.insert(4);
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.input_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN;
+  attrs.output_types = ArrayTypes::Normal |
+      ArrayTypes::MKLDNN;
+  return attrs;
+}
+
+OpAttrs GetBNBackwardOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("_backward_BatchNorm");
+  attrs.num_inputs = 8;
+  attrs.num_outputs = 3;
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.requests.insert(OpReqType::kWriteTo);
+  return attrs;
+}
+
 void AssertEqual(const std::vector<NDArray *> &in_arrs,
                  const std::vector<NDArray *> &out_arrs,
                  float rtol = 1e-5, float atol = 1e-8) {
@@ -614,19 +639,42 @@ void TestOpExBackward(const OpAttrs &forward_attrs,
                       const NDArrayAttrs &in_arr,
                       const NDArrayAttrs &out_arr) {
   std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+
+  std::vector<NDArray> backwards_buffer(backwards_attrs.num_outputs);
+  std::vector<NDArray> backwards_buffer2(backwards_attrs.num_outputs);
+
   std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
   std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
   std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
 
   if (req == kWriteTo) {
     // backwards test performed same time since output needed
-    backwards_input[0] = outputs[0];  // output grad
-    backwards_input[1] = inputs[0];  // input
-    backwards_input[2] = outputs[1];  // out norm
 
-    auto tmp_output = in_arr.arr;
-    backwards_outputs[0] = &tmp_output;
-    backwards_ex_outputs[0] = &tmp_output;
+    if (forward_attrs.attrs.op->name.compare("BatchNorm") == 0) {
+      backwards_input[0] = outputs[0];  // output grad
+      backwards_input[1] = outputs[1];  // mean
+      backwards_input[2] = outputs[2];  // var
+      backwards_input[3] = inputs[0];  // data
+      backwards_input[4] = inputs[1];  // gamma
+      backwards_input[5] = inputs[2];  // beta
+      backwards_input[6] = inputs[3];  // moving mean
+      backwards_input[7] = inputs[4];  // moving var
+    } else {
+      backwards_input[0] = outputs[0];  // output grad
+      backwards_input[1] = inputs[0];  // input
+      backwards_input[2] = outputs[1];  // out norm
+    }
+
+    for (size_t i = 0; i < backwards_attrs.num_outputs; i++) {
+      auto tmp_output = in_arr.arr;
+      backwards_buffer.emplace_back(tmp_output.Copy(Context()));
+      backwards_buffer2.emplace_back(tmp_output.Copy(Context()));
+      backwards_buffer.back().CopyFrom(*tmp_output.GetMKLDNNData());
+      backwards_buffer2.back().CopyFrom(*tmp_output.GetMKLDNNData());
+      backwards_outputs[i] = &backwards_buffer.back();
+      backwards_ex_outputs[i] = &backwards_buffer.back();
+    }
+
 
     for (int i = 0; i < backwards_attrs.num_outputs; i++)
       back_req[i] = kWriteTo;
@@ -648,6 +696,11 @@ void TestOpExBackward(const OpAttrs &forward_attrs,
 // compares output of fcompute with fcomputex
 void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> inputs2(forward_attrs.num_inputs);
+  std::vector<NDArray> inputs_buffer(forward_attrs.num_inputs);
+  std::vector<NDArray> inputs2_buffer(forward_attrs.num_inputs);
+  std::vector<const mkldnn::memory*> inputs_mem(forward_attrs.num_inputs);
+  std::vector<const mkldnn::memory*> inputs2_mem(forward_attrs.num_inputs);
   std::vector<NDArray*> outputs(forward_attrs.num_outputs);
   std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
   std::vector<OpReqType> req(forward_attrs.num_outputs);
@@ -655,7 +708,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
 
-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
   std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
   std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
 
@@ -670,15 +723,33 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
 
       for (int i = 0; i < forward_attrs.num_outputs; i++) {
         out_arrs[i] =
-            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, forward_attrs.output_types);
+            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
         ex_out_arrs[i] =
-            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, forward_attrs.output_types);
+            GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
       }
 
-      for (int i = 0; i < forward_attrs.num_inputs; i++)
-        inputs[i] = &in_arr.arr;
-
       for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        inputs_buffer.clear();
+        inputs2_buffer.clear();
+        inputs_mem.clear();
+        inputs2_mem.clear();
+
+        for (int i = 0; i < forward_attrs.num_inputs; i++) {
+          inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
+          inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
+
+          if (in_arr.arr.IsMKLDNNData()) {
+            inputs_mem.emplace_back(in_arr.arr.GetMKLDNNData());
+            inputs2_mem.emplace_back(in_arr.arr.GetMKLDNNData());
+            inputs_buffer.back().CopyFrom(*inputs_mem.back());
+            inputs2_buffer.back().CopyFrom(*inputs2_mem.back());
+          }
+          Engine::Get()->WaitForAll();
+          inputs[i] = &inputs_buffer.back();
+          inputs2[i] = &inputs2_buffer.back();
+        }
+
+
         for (int i = 0; i < forward_attrs.num_outputs; i++) {
           req[i] = kWriteTo;
           outputs[i] = &out_arrs[i][output_i].arr;
@@ -691,7 +762,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
             Context(), forward_attrs.attrs, inputs, outputs, req,
             DispatchMode::kFCompute, mxnet::OpStatePtr());
         Imperative::Get()->InvokeOp(
-            Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+            Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
             DispatchMode::kFComputeEx, mxnet::OpStatePtr());
         Engine::Get()->WaitForAll();
         AssertEqual(outputs, ex_outputs);
@@ -1204,4 +1275,10 @@ TEST(IMPERATIVE, DeconvOp) {
   }
 }
 
+TEST(IMPERATIVE, BNOp) {
+  OpAttrs forward_attrs = GetBNOp();
+  OpAttrs backwards_attrs = GetBNBackwardOp();
+  TestOpEx(forward_attrs, backwards_attrs);
+}
+
 #endif
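
The key behavioral change in the diff above is the new SupportMKLDNNBN signature: the whole input vector is inspected, and the MKLDNN path is rejected if any non-data input (gamma, beta, moving mean, moving var) is stored in the MKLDNN layout. Below is a minimal standalone C++ sketch of that gating logic; FakeArray, FakeBNParam and SupportsMKLDNNBatchNormSketch are hypothetical stand-ins for MXNet's NDArray, BatchNormParam and SupportMKLDNNBN, and the sketch omits the SupportMKLDNN(inputs[0]) and disable_mkl checks that the real helper also applies.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-ins for MXNet's NDArray and BatchNormParam, kept only
    // to illustrate the control flow of the revised SupportMKLDNNBN above.
    struct FakeArray {
      std::vector<int64_t> shape;   // tensor dimensions
      bool is_mkldnn_layout;        // analogue of NDArray::IsMKLDNNData()
    };

    struct FakeBNParam {
      int axis = 1;                 // channel axis (DEFAULT_AXIS in the diff)
    };

    // Mirrors the new signature: the whole input vector is inspected, not just inputs[0].
    bool SupportsMKLDNNBatchNormSketch(const std::vector<FakeArray> &inputs,
                                       const FakeBNParam &param) {
      const std::vector<int64_t> &shape = inputs[0].shape;
      // Parameter checks from the diff: 4-D data, default axis, channels divisible by 8.
      bool params_valid = shape.size() == 4
          && param.axis == 1
          && shape[param.axis] % 8 == 0;
      // Input checks: the data array may use the MKLDNN layout, but gamma, beta and
      // the moving statistics (inputs[1..]) must not, otherwise the path is skipped.
      bool inputs_valid = true;
      for (size_t i = 1; i < inputs.size(); i++) {
        if (inputs[i].is_mkldnn_layout)
          inputs_valid = false;
      }
      return params_valid && inputs_valid;
    }

The forward-pass change in mkldnn_batch_norm-inl.h is complementary: gamma and beta are reordered to the default layout before being read as the scale/shift buffer.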


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
