zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r161921308
 
 

 ##########
 File path: src/common/exec_utils.h
 ##########
 @@ -43,19 +43,61 @@ namespace common {
           indices are not recorded
  * \return true if any source NDArray need to cast storage
  */
-inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
-                              std::vector<TBlob> *blobs,
-                              std::vector<NDArray> *temp_src,
-                              std::vector<NDArray> *temp_dst,
-                              std::unordered_map<uint32_t, uint32_t> *idx_map 
= nullptr) {
+inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
+                                const std::vector<NDArray> *bufs,
+                                std::vector<TBlob> *blobs,
+                                std::vector<NDArray> *temp_src,
+                                std::vector<NDArray> *temp_dst,
+                                std::unordered_map<uint32_t, uint32_t> 
*idx_map) {
   bool require_cast = false;
   for (size_t i = 0; i < src.size(); i++) {
     auto& nd = src[i];
-    if (nd.storage_type() != kDefaultStorage) {
-      if (idx_map != nullptr) {
-        (*idx_map)[i] = temp_dst->size();
-      }
-      NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // We have to make sure it's default storage and default layout.
+    is_default = nd.IsDefault();
+#endif
+    if (!is_default) {
+      (*idx_map)[i] = temp_dst->size();
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), 
nd.ctx(),
+                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      CHECK(temp.IsDefault());
+#endif
+      temp_src->emplace_back(nd);
+      temp_dst->emplace_back(temp);
+      blobs->emplace_back(temp.data());
+      require_cast = true;
+    } else {
+      blobs->push_back(nd.data());
+    }
+  }
+  return require_cast;
+}
+
+inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> *bufs,
+                                 std::vector<TBlob> *blobs,
+                                 std::vector<NDArray> *temp_src,
+                                 std::vector<NDArray> *temp_dst) {
+  bool require_cast = false;
+  for (size_t i = 0; i < src.size(); i++) {
+    auto& nd = src[i];
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // If it's writeTo, we don't need to worry whether it contains valid data.
+    if (req[i] == kWriteTo && is_default)
 
 Review comment:
   The goal is to remove `Mkl_ptr_` from the NDArray when the NDArray is reused.
   When the request is `kWriteTo`, the output array doesn't contain any valid
   data, so we should notify the NDArray of this. Otherwise, the NDArray will
   always assume its data is stored in a special layout.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to