zheng-da commented on a change in pull request #8302: Refactor operators & 
MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r161939043
 
 

 ##########
 File path: include/mxnet/ndarray.h
 ##########
 @@ -611,6 +565,64 @@ class NDArray {
              << "CheckAndAllocAuxData is not intended for kDefaultStorage";
     ptr_->CheckAndAllocAuxData(i, aux_shape);
   }
+
+#if MXNET_USE_MKLDNN == 1
+  bool IsMKLDNN() const {
+    return ptr_->IsMKLDNN();
+  }
+  bool IsDefault() const {
+    return ptr_->IsDefault();
+  }
+  /*
+   * All functions below return a raw pointer to mkldnn memory. Actually there
+   * is a shared pointer that hold the memory either in NDArray or in MKLDNN
+   * stream. As long as we call these functions inside an operator, the return
+   * memory is always valid.
+   */
+
+  /*
+   * This function returns mkldnn::memory with the default primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData() const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc
+   * as long as the array size meets the required size in the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc) const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc.
+   * The returned mkldnn::memory will have the same physical layout as
+   * the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNDataReorder(
+      const mkldnn::memory::primitive_desc &desc) const;
+
+  void CopyFrom(const mkldnn::memory &mem);
+  mkldnn::memory *CreateMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc);
+
+  /*
+   * Reorder the memory to the specified layout.
+   */
+  void Reorder(const mkldnn::memory::primitive_desc &desc);
+  void Reorder2Default() {
+    CHECK_EQ(storage_type(), kDefaultStorage);
+    ptr_->Reorder2Default();
+  }
+
+  void InvalidateData() {
+    // When we invalidate data, we don't need to care about the MKLDNN format.
+    ptr_->Mkl_mem_ = nullptr;
+  }
+
+  /*
+   * This function is used inside operators to reshape an array.
+   * It's used by FullyConnected right now.
+   */
+  NDArray ReshapeMKLDNN(const TShape &shape) const;
 
 Review comment:
   If the array stores data in a special layout, Reshape converts the data in 
the array to the default layout, which allocates memory directly with malloc.
   ReshapeMKLDNN, by contrast, doesn't change the layout of the original array 
and always uses the temporary memory buffer to store the reordered data.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to