This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch v1.2.0 in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 62a47a7ecf203138b0b0b953bcc7c6fceb1ba0e1 Author: Da Zheng <zhengda1...@gmail.com> AuthorDate: Fri May 25 10:11:45 2018 -0700 Fix bugs in MKLDNN. (#10979) * Fix bugs in MKLDNN. * add more test cases. * Fix CopyFrom when it's the view of an NDArray. * add test. * check same shape correctly. * add unit test for CopyFrom. * Fix warning. * Add test sum. * fix sum. * Fix fallback. * Fix fallback of sum. * add tests. * Update mkldnn.cc --- src/ndarray/ndarray.cc | 111 +++++++++------- src/operator/nn/mkldnn/mkldnn_base.cc | 5 +- src/operator/nn/mkldnn/mkldnn_sum.cc | 22 +++- src/operator/tensor/elemwise_binary_op_basic.cc | 12 +- tests/cpp/operator/mkldnn.cc | 165 +++++++++++++++++++++--- 5 files changed, 235 insertions(+), 80 deletions(-) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 6a8bc9d..fc01c75 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -200,6 +200,7 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { ret.ptr_->delay_alloc = false; ret.ptr_->static_data = true; ret.byte_offset_ = byte_offset_; + ret.reuse_ = false; return ret; } } @@ -217,6 +218,7 @@ NDArray NDArray::Reshape(const TShape &shape) const { // Otherwise, reshape only works on the default layout. CHECK_EQ(storage_type(), kDefaultStorage); ret.shape_ = shape; + ret.reuse_ = false; return ret; } @@ -249,6 +251,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); }); + ret.reuse_ = false; ret.shape_[0] = end - begin; return ret; } @@ -554,6 +557,7 @@ NDArray NDArray::Reorder2Default() const { // reshape as needed ret.shape_ = shape_; ret.byte_offset_ = byte_offset_; + ret.reuse_ = false; return ret; } @@ -583,39 +587,39 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc) const mkldnn::memory *NDArray::GetMKLDNNData() const { CHECK(storage_type() == kDefaultStorage); + bool is_view = IsView(); if (IsMKLDNNData()) { // If this array uses MKLDNN layout, we have to make sure it's not a view. // Otherwise, we'll have to change the layout inside the array. - CHECK(!IsView()); + CHECK(!is_view); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); // If this array uses MKLDNN format, we should return now. Otherwise, // SetMKLMem may mess up mkl_mem_. return ptr_->mkl_mem_->GetRaw(); - } - ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); - if (IsView()) { - mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(); - // Sliced array must use the default layout. - CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format); - void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle()) - + byte_offset_; - + } else if (is_view) { + // If this is a view, we can't create a MKLDNN memory for the chunk + // because we don't have the complete data type and shape information for + // the chunk. + void *off_addr = static_cast<char *>(ptr_->shandle.dptr) + byte_offset_; // Create the primitive desc for the new mkldnn memory. mkldnn::memory::dims dims(shape().ndim()); for (size_t i = 0; i < dims.size(); i++) dims[i] = shape()[i]; mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>( GetDefaultFormat(shape().ndim())); - mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>( - pd.desc().data.data_type); + mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_); mkldnn::memory::desc data_md(dims, cpp_type, cpp_format); - mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine()); + mkldnn::memory::primitive_desc new_pd(data_md, + CpuEngine::Get()->get_engine()); std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr)); MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } else { + // If this isn't a view, we can create a MKLDNN memory and store it in the + // chunk. + ptr_->SetMKLMem(shape_, dtype_); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return ptr_->mkl_mem_->GetRaw(); } } @@ -630,20 +634,23 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { MKLDNNStream *stream = MKLDNNStream::Get(); // If this array uses MKLDNN layout, we have to make sure it's not a view. // Otherwise, we'll have to change the layout inside the array. - if (IsMKLDNNData()) - CHECK(!IsView()); - ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, - dtype_); - stream->RegisterMem(ptr_->mkl_mem_->GetMem()); - mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc(); - mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); + + if (IsMKLDNNData() && IsView()) + ptr_->Reorder2Default(); + + const mkldnn::memory *this_mem = GetMKLDNNData(); + mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc(); + mkldnn::memory::desc from_desc = from_pd.desc(); + mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc(); + mkldnn::memory::desc this_desc = this_pd.desc(); mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc); + mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc); if (IsView()) { // Sliced array must use the default layout. CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format); } // It's possible that the memory and the NDArray don't have the same shape. - if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims) + if (!same_shape(this_desc, from_desc) // If the source memory uses the default layout, we can reshape directly. && from_def_format == from_desc.data.format) { // In this case, we can simply create a new MKLDNN memory for the required @@ -653,15 +660,14 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type); auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc)); mkldnn::memory::desc data_md(dims, this_dtype, this_format); - mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine()); + mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine()); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); - } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) { + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); + } else if (!same_shape(this_desc, from_desc)) { // In this case, the source memory stores data in a customized layout. We // need to reorganize the data in memory before we can reshape. - mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), - from_def_format); + mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format); mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd); stream->RegisterPrim(mkldnn::reorder(mem, *def_mem)); // Now we can reshape it @@ -670,45 +676,40 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type); auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc)); mkldnn::memory::desc data_md(dims, this_dtype, this_format); - mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine()); + mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine()); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); - } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) { + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); + } else if (from_pd == this_pd) { // If the layout is the same, we can just copy data. - stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw())); + stream->RegisterPrim(mkldnn::reorder(mem, *this_mem)); } else { - mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc()); - mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat(); // If both are not using the default layouts. There isn't much we can do, // other than reorder data layout directly. - if (dst_def != ptr_->mkl_mem_->GetFormat() - && src_def != mem.get_primitive_desc().desc().data.format) { - stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw())); - } else if (dst_def == ptr_->mkl_mem_->GetFormat()) { + if (this_def_format != this_desc.data.format + && from_def_format != from_desc.data.format) { + stream->RegisterPrim(mkldnn::reorder(mem, *this_mem)); + } else if (this_def_format == this_desc.data.format) { // If the dest mem uses the default memory layout, we can simply use // the default format of the source memory to improve perf of reorder. - mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle())); + mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd, + from_def_format); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle())); stream->RegisterMem(tmp_mem); stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem)); } else { // If the src mem uses the default memory layout, we can use // the default format of the source memory to improve perf. - mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def); + mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd, + this_def_format); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem)); } } } -mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, - mkldnn_memory_format_t format); mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) { - // This array shouldn't be a view. - CHECK(!IsView()); - if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) { LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; @@ -719,10 +720,26 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc()); // If the required format is a default format, we don't need to worry about the shape. // If the shape isn't the same, it actually implicitly reshapes data. - if (required_format == def_format) { + if (required_format == def_format && !IsView()) { ptr_->SetMKLMem(shape_, dtype_); MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); + } else if (required_format == def_format) { + ptr_->CheckAndAlloc(); + CHECK(ptr_->shandle.dptr); + // When this is a view and a user wants the default layout, we can simply + // create a new mkldnn memory that points to the right memory. + std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory( + desc, static_cast<char *>(ptr_->shandle.dptr) + byte_offset_)); + MKLDNNStream::Get()->RegisterMem(mem); + return mem.get(); + } else if (IsView()) { + // If this is a view and a user wants to write data to it with special + // a MKLDNN format, we should reorder the data in the array and return NULL. + // In this way, the user will create a new NDArray for the special format + // and copy data back. + ptr_->Reorder2Default(); + return nullptr; } if (ptr_->mkl_mem_) diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 9083216..9fa93a1 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -290,10 +290,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs, } else { if (in_bufs.empty()) in_bufs.reserve(inputs.size()); - in_bufs.emplace_back(inputs[i].shape(), inputs[i].ctx(), - false, inputs[i].dtype()); - const mkldnn::memory *mem = inputs[i].GetMKLDNNData(); - in_bufs.back().CopyFrom(*mem); + in_bufs.push_back(inputs[i].Reorder2Default()); in_blobs[i] = in_bufs.back().data(); } } diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc index ccad068..e8fec50 100644 --- a/src/operator/nn/mkldnn/mkldnn_sum.cc +++ b/src/operator/nn/mkldnn/mkldnn_sum.cc @@ -59,8 +59,15 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, std::vector<float> scales(inputs.size(), 1); in_prims.reserve(inputs.size()); bool pd_same = true; + std::vector<NDArray> in_bufs(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { - auto in_mem = inputs[i].GetMKLDNNData(); + const mkldnn::memory *in_mem; + if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) { + in_bufs[i] = inputs[i].Reorder2Default(); + in_mem = in_bufs[i].GetMKLDNNData(); + } else { + in_mem = inputs[i].GetMKLDNNData(); + } in_prims.push_back(*in_mem); in_pds[i] = in_mem->get_primitive_desc(); } @@ -68,9 +75,16 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, mkldnn::sum::primitive_desc pdesc(scales, in_pds); pd_same = pd_same && (pdesc.dst_primitive_desc() == in_pds[0]); auto out_mem = const_cast<NDArray&>(out_data).CreateMKLDNNData(pdesc.dst_primitive_desc()); - bool addr_same = out_mem->get_data_handle() == inputs[0].GetMKLDNNData()->get_data_handle(); - if ((req == kWriteTo) || - (req == kWriteInplace && pd_same && addr_same)) { + bool addr_same = false; + const void *first_data_handle; + if (in_bufs[0].is_none()) + first_data_handle = inputs[0].GetMKLDNNData()->get_data_handle(); + else + first_data_handle = in_bufs[0].GetMKLDNNData()->get_data_handle(); + if (out_mem) + addr_same = out_mem->get_data_handle() == first_data_handle; + if (((req == kWriteTo) || (req == kWriteInplace && pd_same && addr_same)) + && out_mem) { // do sum computation directly on output NDArray MKLDNNStream *stream = MKLDNNStream::Get(); stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *out_mem)); diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index d73edc7..00469b0 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -43,16 +43,8 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs, return; } else if (inputs[0].storage_type() == kDefaultStorage && inputs[1].storage_type() == kDefaultStorage) { - // This happens if inputs are supposed to be in MKLDNN format - // but MKLDNN doesn't support the data type or the shape. We're - // forced to convert it to the default format. - std::vector<TBlob> in_blobs(2); - std::vector<TBlob> out_blobs(1); - in_blobs[0] = inputs[0].data(); - in_blobs[1] = inputs[1].data(); - out_blobs[0] = outputs[0].data(); - ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>(attrs, ctx, in_blobs, - req, out_blobs); + FallBackCompute(ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>, + attrs, ctx, inputs, req, outputs); return; } #endif diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc index a6dd8ad..5db4256 100644 --- a/tests/cpp/operator/mkldnn.cc +++ b/tests/cpp/operator/mkldnn.cc @@ -82,12 +82,17 @@ TEST(MKLDNN_UTIL_FUNC, AlignMem) { } // Init arrays with the default layout. -static void InitArray(NDArray *arr) { +static void InitArray(NDArray *arr, bool is_rand = false) { const TBlob &blob = arr->data(); mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>(); size_t size = blob.Size(); - for (size_t i = 0; i < size; i++) - data[i] = i; + if (is_rand) { + for (size_t i = 0; i < size; i++) + data[i] = std::rand(); + } else { + for (size_t i = 0; i < size; i++) + data[i] = i; + } } // Init arrays with the specified layout. @@ -354,6 +359,15 @@ OpAttrs GetLeakyReluOp() { return attrs; } +OpAttrs GetSumOp() { + OpAttrs attrs; + attrs.attrs.op = Op::Get("elemwise_add"); + attrs.dispatches.resize(2); + attrs.dispatches[0] = DispatchMode::kFCompute; + attrs.dispatches[1] = DispatchMode::kFComputeEx; + return attrs; +} + /* * We want to get a few types of NDArrays for testing: * 1. Normal NDArray @@ -411,34 +425,65 @@ std::vector<NDArray> GetTestInputArrays() { * pass them to all operators. * In the inference mode, the MKLDNN memory in the weight array will be * reordered to 5 dimensions. - * 4. Reused NDArray (this is created by the MXNet executor). This type of + * 4. Reshaped/sliced NDArray + * 5. Reused NDArray (this is created by the MXNet executor). This type of * NDArrays can only be used as output arrays. + * 6. Reused NDArray converted from an array with a different data type. + * 7. Reused reshaped/sliced NDArray. + * 8. Reused NDArray with MKLDNN layout. + * 9. Reused NDArray with MKLDNN layout of different dimensions. */ std::vector<NDArray> GetTestOutputArrays(const TShape &shape, const std::vector<mkldnn::memory::primitive_desc> &pds) { std::vector<NDArray> in_arrs; + // Type 1. in_arrs.emplace_back(shape, Context()); - InitArray(&in_arrs.back()); + InitArray(&in_arrs.back(), true); + + // Type 4. + TShape tmp_shape = shape; + tmp_shape[0] = shape[0] * 2; + NDArray arr0(tmp_shape, Context()); + InitArray(&arr0, true); + in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1)); + // Type 5. // Get a reused version. nnvm::TShape s(1); s[0] = shape.Size(); - NDArray arr(s, Context()); - arr = arr.AsArray(shape, arr.dtype()); - InitArray(&arr); - in_arrs.emplace_back(arr); + NDArray arr1(s, Context()); + arr1 = arr1.AsArray(shape, arr1.dtype()); + InitArray(&arr1, true); + in_arrs.emplace_back(arr1); + + // Type 6. + s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag); + NDArray arr2(s, Context(), true, mshadow::kUint8); + arr2 = arr2.AsArray(shape, mshadow::default_type_flag); + InitArray(&arr2, true); + in_arrs.emplace_back(arr2); + + // Type 7 + s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag) * 2; + NDArray arr3(s, Context(), true, mshadow::kUint8); + tmp_shape[0] = shape[0] * 2; + arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag); + InitArray(&arr3, true); + in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1)); for (auto pd : pds) { if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t)) continue; + // Type 2, 3. in_arrs.emplace_back(shape, Context()); InitMKLDNNArray(&in_arrs.back(), pd, true); + // Type 8, 9. // Get a reused version. nnvm::TShape s(1); s[0] = shape.Size(); - arr = NDArray(s, Context()); + NDArray arr = NDArray(s, Context()); arr = arr.AsArray(shape, arr.dtype()); InitMKLDNNArray(&arr, pd, true); in_arrs.emplace_back(arr); @@ -446,10 +491,10 @@ std::vector<NDArray> GetTestOutputArrays(const TShape &shape, return in_arrs; } -using VerifyFunc = std::function<void (const NDArray &in_arr, const NDArray &arr)>; +using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs, const NDArray &arr)>; -void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) { - NDArray tmp1 = in_arr.Reorder2Default(); +void VerifyCopyResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) { + NDArray tmp1 = in_arrs[0]->Reorder2Default(); NDArray tmp2 = arr.Reorder2Default(); EXPECT_EQ(tmp1.shape().Size(), tmp2.shape().Size()); TBlob d1 = tmp1.data(); @@ -458,6 +503,40 @@ void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) { tmp1.shape().Size() * sizeof(mshadow::default_real_t)), 0); } +void VerifySumResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) { + NDArray in1 = in_arrs[0]->Reorder2Default(); + NDArray in2 = in_arrs[1]->Reorder2Default(); + NDArray out = arr.Reorder2Default(); + EXPECT_EQ(in1.shape().Size(), in2.shape().Size()); + EXPECT_EQ(in1.shape().Size(), out.shape().Size()); + + mshadow::default_real_t *d1 = in1.data().dptr<mshadow::default_real_t>(); + mshadow::default_real_t *d2 = in2.data().dptr<mshadow::default_real_t>(); + mshadow::default_real_t *o = out.data().dptr<mshadow::default_real_t>(); + for (size_t i = 0; i < in1.shape().Size(); i++) + EXPECT_EQ(d1[i] + d2[i], o[i]); +} + +TEST(MKLDNN_NDArray, CopyFrom) { + TestArrayShapes tas = GetTestArrayShapes(); + std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; + + std::vector<NDArray> in_arrs = GetTestInputArrays(); + for (auto in_arr : in_arrs) { + std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr.shape(), pds); + for (auto out_arr : out_arrs) { + if (in_arr.IsMKLDNNData() && in_arr.IsView()) + in_arr = in_arr.Reorder2Default(); + const mkldnn::memory *mem = in_arr.GetMKLDNNData(); + out_arr.CopyFrom(*mem); + MKLDNNStream::Get()->Submit(); + std::vector<NDArray *> inputs(1); + inputs[0] = &in_arr; + VerifyCopyResult(inputs, out_arr); + } + } +} + void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) { std::vector<NDArray*> inputs(1); std::vector<NDArray*> outputs(1); @@ -478,7 +557,53 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) { Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, dispatch, mxnet::OpStatePtr()); out_arr.WaitToRead(); - verify_fn(in_arr, out_arr); + verify_fn(inputs, out_arr); + } + } + } + + for (auto dispatch : dispatches) { + in_arrs = GetTestInputArrays(); + for (auto arr : in_arrs) { + // If the array is a view, we shouldn't write data to it. + if (arr.IsView()) + continue; + + NDArray orig = arr.Copy(arr.ctx()); + req[0] = kWriteInplace; + inputs[0] = &arr; + outputs[0] = &arr; + Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, + dispatch, mxnet::OpStatePtr()); + arr.WaitToRead(); + inputs[0] = &orig; + verify_fn(inputs, arr); + } + } +} + +void TestBinaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) { + std::vector<NDArray*> inputs(2); + std::vector<NDArray*> outputs(1); + std::vector<OpReqType> req(1); + std::vector<DispatchMode> dispatches = attrs.dispatches; + + TestArrayShapes tas = GetTestArrayShapes(); + std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; + + std::vector<NDArray> in_arrs = GetTestInputArrays(); + for (auto in_arr1 : in_arrs) { + for (auto dispatch : dispatches) { + std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr1.shape(), pds); + for (auto out_arr : out_arrs) { + req[0] = kWriteTo; + inputs[0] = &in_arr1; + inputs[1] = &in_arr1; + outputs[0] = &out_arr; + Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, + outputs, req, dispatch, mxnet::OpStatePtr()); + out_arr.WaitToRead(); + verify_fn(inputs, out_arr); } } } @@ -493,11 +618,15 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) { NDArray orig = arr.Copy(arr.ctx()); req[0] = kWriteInplace; inputs[0] = &arr; + inputs[1] = &arr; outputs[0] = &arr; Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, dispatch, mxnet::OpStatePtr()); arr.WaitToRead(); - verify_fn(orig, arr); + std::vector<NDArray *> orig_inputs(2); + orig_inputs[0] = &orig; + orig_inputs[1] = &orig; + verify_fn(orig_inputs, arr); } } } @@ -507,4 +636,10 @@ TEST(IMPERATIVE, UnaryOp) { TestUnaryOp(attrs, VerifyCopyResult); } + +TEST(IMPERATIVE, BinaryOp) { + OpAttrs attrs = GetSumOp(); + TestBinaryOp(attrs, VerifySumResult); +} + #endif -- To stop receiving notification emails like this one, please contact anirudh2...@apache.org.