piiswrong commented on a change in pull request #8915: NVLink communication pattern updated
URL: https://github.com/apache/incubator-mxnet/pull/8915#discussion_r170377029
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -523,95 +537,236 @@ class CommDevice : public Comm {
     }
 
     InitBuffersAndComm(src);
-    auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-
-    const NDArrayStorageType stype = buf.merged.storage_type();
+    // merge buffer holds the first group of gpus
+    BufferEntry& buf = merge_buf_[key];
+    // stage buffer holds the data of the second group  or the first when merge
+    // buffer is empty
+    BufferEntry& stage = stage_buf_[key];
+    std::vector<NDArray> reduce_s;
+
+    const NDArrayStorageType stype = stage.merged.storage_type();
     if (stype == kDefaultStorage) {
-      CopyFromTo(src[0], &(buf.merged), priority);
-      reduce[0] = buf.merged;
-
-      if (buf.copy_buf.empty()) {
-        // TODO(mli) this results in large device memory usage for huge 
ndarray,
-        // such as the largest fullc in VGG. consider to do segment reduce with
-        // NDArray.Slice or gpu direct memory access. for the latter, we need 
to
-        // remove some ctx check, and also it reduces 20% perf
-        buf.copy_buf.resize(src.size()-1);
-        for (size_t i = 0; i < src.size()-1; ++i) {
-          buf.copy_buf[i] = NDArray(
-            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
-        }
+      if (buf.merged.is_none() && stage.copy_buf.empty()) {
+        stage.copy_buf.resize(src.size() - 1);
+        for (size_t i = 0; i < src.size() - 1; ++i)
+          stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                      false, stage.merged.dtype());
       }
-      for (size_t i = 0; i < src.size()-1; ++i) {
-        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-        reduce[i+1] = buf.copy_buf[i];
+      reduce_s.resize(stage.copy_buf.size() + 1);
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) ||
+            (buf.merged.is_none() && i == 0)) {
+          CopyFromTo(src[i], &stage.merged, priority);
+          reduce_s[0] = stage.merged;
+        } else if (id >= 4 || buf.merged.is_none()) {
+          CopyFromTo(src[i], &(stage.copy_buf[j]), priority);
+          reduce_s[j + 1] = stage.copy_buf[j];
+          j++;
+        }
       }
     } else {
-      if (buf.copy_buf.empty()) {
-        buf.copy_buf.resize(src.size());
-        for (size_t j = 0; j < src.size(); ++j) {
-          buf.copy_buf[j] = NDArray(
-            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
-            true, buf.merged.dtype());
+      if (buf.merged.is_none() && stage.copy_buf.empty()) {
+        stage.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j)
+          stage.copy_buf[j] =
+              NDArray(stage.merged.storage_type(), stage.merged.shape(),
+                      stage.merged.ctx(), true, stage.merged.dtype());
+      }
+      reduce_s.resize(stage.copy_buf.size());
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if (id >= 4 || buf.merged.is_none()) {
 
 Review comment:
   Why 4? Can we avoid the magic number here?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to