This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cf6ae59  adding unit test for MKLDNN Softmax operator (#12884)
cf6ae59 is described below

commit cf6ae59058435fd2bfa0cbe5ceb1f9ce82e92b70
Author: Manu Seth <[email protected]>
AuthorDate: Mon Nov 19 22:26:36 2018 -0800

    adding unit test for MKLDNN Softmax operator (#12884)
    
    * adding unit test for MKLDNN FullyConnected operator
    
    * adding unit test for MKLDNN Softmax operator
    
    * merging TestSoftmaxOp with TestOpEx
    
    * adding unit test for MKLDNN Softmax operator
    
    * adding accepted dimension support
    
    * changing set to unordered set
    
    * adding check for non-empty accept_dims
---
 tests/cpp/include/test_mkldnn.h            |   1 +
 tests/cpp/operator/mkldnn_operator_test.cc | 132 +++++++++++++++++++++--------
 2 files changed, 100 insertions(+), 33 deletions(-)
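
In short, the patch replaces TestOpEx's hard-coded "4-D inputs only"
workaround (the TODO removed below) with a per-operator accept_dims set,
so the softmax test can run on 1-D through 5-D inputs. A minimal sketch
of that pattern, with OpAttrsSketch and ShapeAccepted as hypothetical
stand-ins for the real OpAttrs struct and the inline check in TestOpEx:

    #include <unordered_set>

    struct OpAttrsSketch {                   // stand-in for OpAttrs
      std::unordered_set<int> accept_dims;   // input ndims the op supports
    };

    bool ShapeAccepted(const OpAttrsSketch &attrs, int ndim) {
      // TestOpEx skips any test input whose dimensionality is not listed;
      // an empty set is treated as a test-setup error (CHECK_NE below).
      return !attrs.accept_dims.empty() &&
             attrs.accept_dims.count(ndim) != 0;
    }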

diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index c421849..c705a60 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -231,6 +231,7 @@ struct OpAttrs {
   nnvm::NodeAttrs attrs;
   std::vector<DispatchMode> dispatches;
   std::set<OpReqType> requests;
+  std::unordered_set<int> accept_dims;
   int num_inputs;
   int num_outputs;
   int input_types;
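
A note on the new field: the revision history ("changing set to unordered
set") settled on std::unordered_set<int>, which is a reasonable fit here
since the set is only ever queried for membership and needs no ordering.
A small sketch of how it is populated and queried (values taken from
GetSoftmaxOp further down):

    #include <unordered_set>

    int main() {
      std::unordered_set<int> accept_dims;
      accept_dims.insert({1, 2, 3, 4, 5});   // softmax: 1-D through 5-D
      bool four_d_ok = accept_dims.find(4) != accept_dims.end();  // true
      return four_d_ok ? 0 : 1;
    }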
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 61fa1b0..a500d4c 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -190,7 +190,7 @@ OpAttrs GetLRNOp() {
   attrs.num_outputs = 2;
   attrs.attrs.dict.insert({"nsize" , "3"});
   attrs.attrs.op->attr_parser(&attrs.attrs);
-  attrs.dispatches.resize(2);
+  attrs.accept_dims.insert(4);
   attrs.requests.insert(OpReqType::kWriteTo);
   attrs.input_types = ArrayTypes::Normal |
       ArrayTypes::MKLDNN |
@@ -215,6 +215,26 @@ OpAttrs GetLRNBackwardsOp() {
   return attrs;
 }
 
+OpAttrs GetSoftmaxOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("softmax");
+  attrs.num_inputs = 1;
+  attrs.num_outputs = 1;
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.accept_dims.insert({1, 2, 3, 4, 5});
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.input_types = ArrayTypes::Normal |
+    ArrayTypes::MKLDNN |
+    ArrayTypes::NormalReshaped |
+    ArrayTypes::MKLDNNReshaped;
+  attrs.output_types = ArrayTypes::Normal |
+    ArrayTypes::MKLDNN |
+    ArrayTypes::NormalReshaped |
+    ArrayTypes::MKLDNNReshaped;
+  return attrs;
+}
+
 OpAttrs GetFullyConnectedOp() {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("FullyConnected");
@@ -586,19 +606,51 @@ void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn,
   }
 }
 
-// compares output of fcompute with fcomputex
-void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
-  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
-  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
-  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
-
+void TestOpExBackward(const OpAttrs &forward_attrs,
+                      const OpAttrs &backwards_attrs,
+                      const OpReqType &req,
+                      const std::vector<NDArray*> &inputs,
+                      const std::vector<NDArray*> &outputs,
+                      const NDArrayAttrs &in_arr,
+                      const NDArrayAttrs &out_arr) {
   std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
   std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
   std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  if (req == kWriteTo) {
+    // backwards test performed at the same time since the forward output is needed
+    backwards_input[0] = outputs[0];  // output grad
+    backwards_input[1] = inputs[0];  // input
+    backwards_input[2] = outputs[1];  // out norm
+
+    auto tmp_output = in_arr.arr;
+    backwards_outputs[0] = &tmp_output;
+    backwards_ex_outputs[0] = &tmp_output;
+
+    for (int i = 0; i < backwards_attrs.num_outputs; i++)
+      back_req[i] = kWriteTo;
+
+    std::cout << "Backwards: ";
+    PrintVerifyMsg(out_arr, in_arr);
+    Imperative::Get()->InvokeOp(
+        Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+        back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+    Imperative::Get()->InvokeOp(
+        Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+        back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+    Engine::Get()->WaitForAll();
+    AssertEqual(backwards_outputs, backwards_ex_outputs);
+  }
+}
 
 
+// compares output of fcompute with fcomputex
+void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
   std::vector<OpReqType> req(forward_attrs.num_outputs);
-  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
 
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
@@ -611,8 +663,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
     for (int i1 = 0; i1 < in_arrs.size(); i1++) {
       auto in_arr = in_arrs[i1];
 
-      // TODO(alex): (MXNET-845) Remove when MKLDNN supports other dims
-      if (in_arr.arr.shape().ndim() != 4)
+      CHECK_NE(forward_attrs.accept_dims.size(), 0);
+      if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
+          forward_attrs.accept_dims.end())
         continue;
 
       for (int i = 0; i < forward_attrs.num_outputs; i++) {
@@ -626,9 +679,6 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
         inputs[i] = &in_arr.arr;
 
       for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
-        if (out_arrs[0][output_i].arr.IsMKLDNNData())
-          continue;
-
         for (int i = 0; i < forward_attrs.num_outputs; i++) {
           req[i] = kWriteTo;
           outputs[i] = &out_arrs[i][output_i].arr;
@@ -646,31 +696,41 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
         Engine::Get()->WaitForAll();
         AssertEqual(outputs, ex_outputs);
 
-        // backwards test performed same time since output needed
-        backwards_input[0] = outputs[0];  // output grad
-        backwards_input[1] = inputs[0];  // input
-        backwards_input[2] = outputs[1];  // out norm
+        if (!backwards_attrs.requests.empty()) {
+          TestOpExBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
+                           inputs, outputs, in_arr, out_arrs[0][output_i]);
+        }
+      }
+    }
+  }
 
-        auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
-        backwards_outputs[0] = &tmp_output.arr;
+  if (forward_attrs.requests.find(OpReqType::kWriteInplace) != forward_attrs.requests.end()) {
+    for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+      auto in_arr = in_arrs[i1];
 
-        auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
-        backwards_ex_outputs[0] = &tmp_output2.arr;
+      // If the array is a view, we shouldn't write data to it.
+      if (in_arr.arr.IsView())
+          continue;
 
-        for (int i = 0; i < backwards_attrs.num_outputs; i++)
-          back_req[i] = kWriteTo;
+      NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
+      for (int i = 0; i < forward_attrs.num_inputs; i++)
+        inputs[i] = &in_arr.arr;
 
-        std::cout << "Backwards: ";
-        PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
-        Imperative::Get()->InvokeOp(
-            Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
-            back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
-        Imperative::Get()->InvokeOp(
-            Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
-            back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
-        Engine::Get()->WaitForAll();
-        AssertEqual(backwards_outputs, backwards_ex_outputs);
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        req[i] = kWriteInplace;
+        outputs[i] = &in_arr.arr;
+        ex_outputs[i] = &in_arr.arr;
       }
+      Imperative::Get()->set_is_training(true);
+      PrintVerifyMsg(orig, in_arr);
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs, outputs, req,
+          DispatchMode::kFCompute, mxnet::OpStatePtr());
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+          DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+      Engine::Get()->WaitForAll();
+      AssertEqual(outputs, ex_outputs);
     }
   }
 }
@@ -1082,6 +1142,12 @@ TEST(IMPERATIVE, LRNOp) {
   TestOpEx(forward_attrs, backwards_attrs);
 }
 
+TEST(IMPERATIVE, SoftmaxOp) {
+  OpAttrs forward_attrs = GetSoftmaxOp();
+  OpAttrs backwards_attrs;
+  TestOpEx(forward_attrs, backwards_attrs);
+}
+
 TEST(IMPERATIVE, FullyConnectedOp) {
   OpAttrs forward_attrs = GetFullyConnectedOp();
   OpAttrs backwards_attrs = GetFullyConnectedBackwardsOp();
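
The pattern generalizes: covering another MKLDNN operator with TestOpEx
only takes a helper mirroring GetSoftmaxOp plus a TEST case, and, as the
SoftmaxOp case above shows, passing a default-constructed OpAttrs as
backwards_attrs skips the backward comparison, because TestOpEx only
calls TestOpExBackward when backwards_attrs.requests is non-empty. A
hedged sketch for a hypothetical operator (the op name "my_op" and its
accepted dims are illustrative, not part of this patch; the snippet
assumes it lives in mkldnn_operator_test.cc, where OpAttrs and
ArrayTypes are in scope):

    // Hypothetical helper in the style of GetSoftmaxOp above.
    OpAttrs GetMyOp() {
      OpAttrs attrs;
      attrs.attrs.op = Op::Get("my_op");        // illustrative op name
      attrs.num_inputs = 1;
      attrs.num_outputs = 1;
      attrs.attrs.op->attr_parser(&attrs.attrs);
      attrs.accept_dims.insert({2, 4});         // dims MKLDNN supports here
      attrs.requests.insert(OpReqType::kWriteTo);
      attrs.input_types = ArrayTypes::Normal | ArrayTypes::MKLDNN;
      attrs.output_types = ArrayTypes::Normal | ArrayTypes::MKLDNN;
      return attrs;
    }

    TEST(IMPERATIVE, MyOp) {
      OpAttrs forward_attrs = GetMyOp();
      OpAttrs backwards_attrs;   // empty requests => backward pass skipped
      TestOpEx(forward_attrs, backwards_attrs);
    }

Being Google Test cases, an individual one can be selected with the
standard --gtest_filter flag, e.g. --gtest_filter=IMPERATIVE.SoftmaxOp
passed to whatever binary the C++ tests build into.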
