This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cf6ae59 adding unittest for MKLDNN Softmax operator (#12884)
cf6ae59 is described below
commit cf6ae59058435fd2bfa0cbe5ceb1f9ce82e92b70
Author: Manu Seth <[email protected]>
AuthorDate: Mon Nov 19 22:26:36 2018 -0800
adding unittest for MKLDNN Softmax operator (#12884)
* adding unit test for MKLDNN FullyConnected operator
* adding unit test for MKLDNN Softmax operator
* merging TestSoftmaxOp with TestOpEx
* adding unit test for MKLDNN Softmax operator
* adding accepted-dimension support (OpAttrs::accept_dims)
* changing std::set to std::unordered_set for accept_dims
* adding check for non-empty accept_dims
---
tests/cpp/include/test_mkldnn.h | 1 +
tests/cpp/operator/mkldnn_operator_test.cc | 132 +++++++++++++++++++++--------
2 files changed, 100 insertions(+), 33 deletions(-)
diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index c421849..c705a60 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -231,6 +231,7 @@ struct OpAttrs {
nnvm::NodeAttrs attrs;
std::vector<DispatchMode> dispatches;
std::set<OpReqType> requests;
+ std::unordered_set<int> accept_dims;
int num_inputs;
int num_outputs;
int input_types;
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 61fa1b0..a500d4c 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -190,7 +190,7 @@ OpAttrs GetLRNOp() {
attrs.num_outputs = 2;
attrs.attrs.dict.insert({"nsize" , "3"});
attrs.attrs.op->attr_parser(&attrs.attrs);
- attrs.dispatches.resize(2);
+ attrs.accept_dims.insert(4);
attrs.requests.insert(OpReqType::kWriteTo);
attrs.input_types = ArrayTypes::Normal |
ArrayTypes::MKLDNN |
@@ -215,6 +215,26 @@ OpAttrs GetLRNBackwardsOp() {
return attrs;
}
+OpAttrs GetSoftmaxOp() {
+ OpAttrs attrs;
+ attrs.attrs.op = Op::Get("softmax");
+ attrs.num_inputs = 1;
+ attrs.num_outputs = 1;
+ attrs.attrs.op->attr_parser(&attrs.attrs);
+ attrs.accept_dims.insert({1, 2, 3, 4, 5});
+ attrs.requests.insert(OpReqType::kWriteTo);
+ attrs.requests.insert(OpReqType::kWriteInplace);
+ attrs.input_types = ArrayTypes::Normal |
+ ArrayTypes::MKLDNN |
+ ArrayTypes::NormalReshaped |
+ ArrayTypes::MKLDNNReshaped;
+ attrs.output_types = ArrayTypes::Normal |
+ ArrayTypes::MKLDNN |
+ ArrayTypes::NormalReshaped |
+ ArrayTypes::MKLDNNReshaped;
+ return attrs;
+}
+
OpAttrs GetFullyConnectedOp() {
OpAttrs attrs;
attrs.attrs.op = Op::Get("FullyConnected");
@@ -586,19 +606,51 @@ void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn,
}
}
-// compares output of fcompute with fcomputex
-void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
- std::vector<NDArray*> inputs(forward_attrs.num_inputs);
- std::vector<NDArray*> outputs(forward_attrs.num_outputs);
- std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
-
+void TestOpExBackward(const OpAttrs &forward_attrs,
+ const OpAttrs &backwards_attrs,
+ const OpReqType &req,
+ const std::vector<NDArray*> &inputs,
+ const std::vector<NDArray*> &outputs,
+ const NDArrayAttrs &in_arr,
+ const NDArrayAttrs &out_arr) {
std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+ std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+ if (req == kWriteTo) {
+ // backwards test performed same time since output needed
+ backwards_input[0] = outputs[0]; // output grad
+ backwards_input[1] = inputs[0]; // input
+ backwards_input[2] = outputs[1]; // out norm
+
+ auto tmp_output = in_arr.arr;
+ backwards_outputs[0] = &tmp_output;
+ backwards_ex_outputs[0] = &tmp_output;
+
+ for (int i = 0; i < backwards_attrs.num_outputs; i++)
+ back_req[i] = kWriteTo;
+
+ std::cout << "Backwards: ";
+ PrintVerifyMsg(out_arr, in_arr);
+ Imperative::Get()->InvokeOp(
+ Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+ back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+ Imperative::Get()->InvokeOp(
+ Context(), backwards_attrs.attrs, backwards_input,
backwards_ex_outputs,
+ back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+ Engine::Get()->WaitForAll();
+ AssertEqual(backwards_outputs, backwards_ex_outputs);
+ }
+}
+// compares output of fcompute with fcomputex
+void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+ std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+ std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+ std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
std::vector<OpReqType> req(forward_attrs.num_outputs);
- std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
@@ -611,8 +663,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
for (int i1 = 0; i1 < in_arrs.size(); i1++) {
auto in_arr = in_arrs[i1];
- // TODO(alex): (MXNET-845) Remove when MKLDNN supports other dims
- if (in_arr.arr.shape().ndim() != 4)
+ CHECK_NE(forward_attrs.accept_dims.size(), 0);
+ if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
+ forward_attrs.accept_dims.end())
continue;
for (int i = 0; i < forward_attrs.num_outputs; i++) {
@@ -626,9 +679,6 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
inputs[i] = &in_arr.arr;
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
- if (out_arrs[0][output_i].arr.IsMKLDNNData())
- continue;
-
for (int i = 0; i < forward_attrs.num_outputs; i++) {
req[i] = kWriteTo;
outputs[i] = &out_arrs[i][output_i].arr;
@@ -646,31 +696,41 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
Engine::Get()->WaitForAll();
AssertEqual(outputs, ex_outputs);
- // backwards test performed same time since output needed
- backwards_input[0] = outputs[0]; // output grad
- backwards_input[1] = inputs[0]; // input
- backwards_input[2] = outputs[1]; // out norm
+ if (!backwards_attrs.requests.empty()) {
+ TestOpExBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
+ inputs, outputs, in_arr, out_arrs[0][output_i]);
+ }
+ }
+ }
+ }
- auto tmp_output = GetTestInputArrays(forward_attrs.input_types,
true)[i1];
- backwards_outputs[0] = &tmp_output.arr;
+ if (forward_attrs.requests.find(OpReqType::kWriteInplace) !=
forward_attrs.requests.end()) {
+ for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+ auto in_arr = in_arrs[i1];
- auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types,
true)[i1];
- backwards_ex_outputs[0] = &tmp_output2.arr;
+ // If the array is a view, we shouldn't write data to it.
+ if (in_arr.arr.IsView())
+ continue;
- for (int i = 0; i < backwards_attrs.num_outputs; i++)
- back_req[i] = kWriteTo;
+ NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
+ for (int i = 0; i < forward_attrs.num_inputs; i++)
+ inputs[i] = &in_arr.arr;
- std::cout << "Backwards: ";
- PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
- Imperative::Get()->InvokeOp(
- Context(), backwards_attrs.attrs, backwards_input,
backwards_outputs,
- back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
- Imperative::Get()->InvokeOp(
- Context(), backwards_attrs.attrs, backwards_input,
backwards_ex_outputs,
- back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
- Engine::Get()->WaitForAll();
- AssertEqual(backwards_outputs, backwards_ex_outputs);
+ for (int i = 0; i < forward_attrs.num_outputs; i++) {
+ req[i] = kWriteInplace;
+ outputs[i] = &in_arr.arr;
+ ex_outputs[i] = &in_arr.arr;
}
+ Imperative::Get()->set_is_training(true);
+ PrintVerifyMsg(orig, in_arr);
+ Imperative::Get()->InvokeOp(
+ Context(), forward_attrs.attrs, inputs, outputs, req,
+ DispatchMode::kFCompute, mxnet::OpStatePtr());
+ Imperative::Get()->InvokeOp(
+ Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+ DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+ Engine::Get()->WaitForAll();
+ AssertEqual(outputs, ex_outputs);
}
}
}
@@ -1082,6 +1142,12 @@ TEST(IMPERATIVE, LRNOp) {
TestOpEx(forward_attrs, backwards_attrs);
}
+TEST(IMPERATIVE, SoftmaxOp) {
+ OpAttrs forward_attrs = GetSoftmaxOp();
+ OpAttrs backwards_attrs;
+ TestOpEx(forward_attrs, backwards_attrs);
+}
+
TEST(IMPERATIVE, FullyConnectedOp) {
OpAttrs forward_attrs = GetFullyConnectedOp();
OpAttrs backwards_attrs = GetFullyConnectedBackwardsOp();