zheng-da closed pull request #10275: [WIP] Fix the default layout in the MKLDNN integration
URL: https://github.com/apache/incubator-mxnet/pull/10275
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index b8b7f20fcf3..62f7e2cd942 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -886,9 +886,11 @@ class NDArray {
     void CheckAndAllocData(const TShape &shape, int dtype);
 
 #if MXNET_USE_MKLDNN == 1
-    // Have MKL memory reference to the data in the default storage
-    // or create memory for MKLDNN.
+    // Have MKL memory reference to the data in the default storage.
     void SetMKLMem(const TShape &shape, int dtype);
+    // Have MKL memory reference to the data in the default storage.
+    // The layout in the memory descriptor must be the default layout.
+    void SetMKLMem(const mkldnn::memory::primitive_desc &desc);
     // If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
     // save the result in shandle.
     void Reorder2Default();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 52b96fad692..ef6953230fe 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -416,6 +416,35 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd)
   mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
 }
 
+void NDArray::Chunk::SetMKLMem(const mkldnn::memory::primitive_desc &desc) {
+  // The shape of the array and the one of the MKL memory may mismatch.
+  // For example, if the array stores parameters, the MKL memory may store data
+  // in 5 dimensions while the NDArray stores data in 4 dimensions.
+  if (mkl_mem_ && mkl_mem_->get_primitive_desc() == desc) {
+    return;
+  }
+
+  if (shandle.dptr == nullptr) {
+    CHECK(delay_alloc);
+    CheckAndAlloc();
+  }
+  mkldnn::memory::primitive_desc pd = desc;
+  // When we set mkldnn memory, we can only use its default format.
+  if (pd.desc().data.format != GetDefaultFormat(pd.desc())) {
+    auto _desc = pd.desc();
+    mkldnn::memory::dims dims(_desc.data.dims, _desc.data.dims + _desc.data.ndims);
+    mkldnn::memory::format layout
+        = static_cast<mkldnn::memory::format>(GetDefaultFormat(_desc));
+    auto dtype = static_cast<mkldnn::memory::data_type>(_desc.data.data_type);
+    mkldnn::memory::desc data_md(dims, dtype, layout);
+    auto cpu_engine = CpuEngine::Get()->get_engine();
+    pd = mkldnn::memory::primitive_desc(data_md, cpu_engine);
+  }
+
+  CHECK(shandle.size >= pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+}
+
 void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   // The shape of the array and the one of the MKL memory may mismatch.
   // For example, if the array stores parameters, the MKL memory may store data
@@ -486,15 +515,13 @@ const mkldnn::memory *NDArray::GetMKLDNNData(
    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
+  // If mkl_mem hasn't been set up, we set it up now.
+  // TODO(zhengda) the input layout must be the default layout.
+  if (ptr_->mkl_mem_ == nullptr)
+    ptr_->SetMKLMem(desc);
   auto mem = GetMKLDNNData();
-  mkldnn::memory::primitive_desc _desc = desc;
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
-  // The MKL memory has the same format and shape as required,
-  // or both use the default format, we can return the MKL memory.
-  if (mem->get_primitive_desc() == desc
-      || (desc1.data.format == GetDefaultFormat(desc1)
-        && desc2.data.format == GetDefaultFormat(desc2))) {
+  // The memory has the requested format and shape, so we can return it.
+  if (mem->get_primitive_desc() == desc) {
     return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
   } else {
     return nullptr;
@@ -509,6 +536,9 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   }
   CHECK(storage_type() == kDefaultStorage);
 
+  // If mkl_mem hasn't been set up, we set it up now.
+  if (ptr_->mkl_mem_ == nullptr)
+    ptr_->SetMKLMem(desc);
   auto mem = GetMKLDNNData();
   // If the memory descriptor matches, it's easy.
   MKLDNNStream *stream = MKLDNNStream::Get();
@@ -516,16 +546,10 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
     return GetMKLDNNExact(mem, desc);
   }
 
-  mkldnn::memory::primitive_desc _desc = desc;
   // Now we need to determine if we should reorder the memory.
-  // If both use the default formats, we think we don't need to reorder.
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
-  if (desc1.data.format == GetDefaultFormat(desc1) &&
-      desc2.data.format == GetDefaultFormat(desc2)) {
-    mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle()));
-    stream->RegisterMem(ret);
-    return ret.get();
+  // If the primitive descriptors match exactly, we don't need to reorder.
+  if (mem->get_primitive_desc() == desc) {
+    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
   } else {
     auto ret = TmpMemMgr::Get()->Alloc(desc);
     stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 820cca1402c..dea543dd712 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -171,11 +171,8 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
   }
   if (mem == nullptr)
     mem = arr.GetMKLDNNDataReorder(target_pd);
-  if (mem->get_primitive_desc() == target_pd) return mem;
-
-  auto ret = TmpMemMgr::Get()->Alloc(target_pd);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*mem, *ret));
-  return ret;
+  CHECK(mem);
+  return mem;
 }
 
 mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
@@ -201,14 +198,19 @@ mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
   } else if (desc.data.ndims == 4) {
     switch (desc.data.format) {
       case mkldnn_nchw:
-      case mkldnn_nhwc:
-      case mkldnn_chwn:
       case mkldnn_nChw8c:
       case mkldnn_nChw16c:
         return mkldnn_nchw;
+      case mkldnn_nhwc:
+        return mkldnn_nhwc;
+      case mkldnn_chwn:
+        return mkldnn_chwn;
       case mkldnn_oihw:
+        return mkldnn_oihw;
       case mkldnn_ihwo:
+        return mkldnn_ihwo;
       case mkldnn_hwio:
+        return mkldnn_hwio;
       case mkldnn_OIhw8i8o:
       case mkldnn_OIhw16i16o:
       case mkldnn_OIhw8i16o2i:
@@ -218,9 +220,11 @@ mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
       case mkldnn_IOhw16o16i:
       case mkldnn_Oihw8o:
       case mkldnn_Oihw16o:
+      case mkldnn_OhIw16o4i:
+        return mkldnn_oihw;
       case mkldnn_Ohwi8o:
       case mkldnn_Ohwi16o:
-      case mkldnn_OhIw16o4i:
+        // TODO(zhengda) what is the right default format for these two?
         return mkldnn_oihw;
       default:
        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
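
For reference, below is a minimal, self-contained C++ sketch of the 4-D mapping that
the revised GetDefaultFormat hunk above implements. The enum is a hypothetical
stand-in for the mkldnn_memory_format_t constants and lists only the formats touched
by the hunk; it illustrates the mapping, not MXNet or MKL-DNN code.

// Standalone illustration only: Format is a hypothetical stand-in for
// mkldnn_memory_format_t, restricted to the 4-D formats touched above.
#include <iostream>
#include <stdexcept>

enum Format {
  nchw, nhwc, chwn, nChw8c, nChw16c,       // data layouts
  oihw, ihwo, hwio,                        // plain weight layouts
  OIhw8i8o, OhIw16o4i, Ohwi8o, Ohwi16o     // blocked weight layouts (subset)
};

// Mirrors the revised 4-D switch in GetDefaultFormat(mkldnn::memory::desc):
// blocked layouts collapse to their plain counterpart, while plain layouts
// such as nhwc, chwn, ihwo and hwio now map to themselves instead of being
// folded into nchw/oihw as before.
Format DefaultFormat4D(Format f) {
  switch (f) {
    case nchw: case nChw8c: case nChw16c: return nchw;
    case nhwc: return nhwc;
    case chwn: return chwn;
    case oihw: return oihw;
    case ihwo: return ihwo;
    case hwio: return hwio;
    case OIhw8i8o: case OhIw16o4i: return oihw;
    case Ohwi8o: case Ohwi16o: return oihw;  // still an open TODO in the PR
    default: throw std::runtime_error("unknown 4-D format");
  }
}

int main() {
  std::cout << (DefaultFormat4D(nChw16c)   == nchw) << "\n";  // 1: blocked data -> nchw
  std::cout << (DefaultFormat4D(nhwc)      == nhwc) << "\n";  // 1: plain layout preserved
  std::cout << (DefaultFormat4D(OhIw16o4i) == oihw) << "\n";  // 1: blocked weights -> oihw
  return 0;
}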


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
