Laurawly commented on a change in pull request #8915: NVLink communication 
pattern updated 
URL: https://github.com/apache/incubator-mxnet/pull/8915#discussion_r163997689
 
 

 ##########
 File path: src/kvstore/comm.h
 ##########
 @@ -526,101 +541,238 @@ class CommDevice : public Comm {
     }
 
     InitBuffersAndComm(src);
+    auto& stage = stage_buf_[key];
     auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-    CopyFromTo(src[0], &(buf.merged), priority);
-    reduce[0] = buf.merged;
-
-    if (buf.copy_buf.empty()) {
-      // TODO(mli) this results in large device memory usage for huge ndarray,
-      // such as the largest fullc in VGG. consider to do segment reduce with
-      // NDArray.Slice or gpu direct memory access. for the latter, we need to
-      // remove some ctx check, and also it reduces 20% perf
-      buf.copy_buf.resize(src.size()-1);
-      for (size_t i = 0; i < src.size()-1; ++i) {
-        buf.copy_buf[i] = NDArray(
-          buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+
+    if (buf.merged.is_none() && stage.copy_buf.empty()) {
+      stage.copy_buf.resize(src.size() - 1);
+      for (size_t i = 0; i < src.size() - 1; ++i)
+        stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                    false, stage.merged.dtype());
+    } else if (!buf.merged.is_none()) {
+      if (buf.copy_buf.empty()) {
+        buf.copy_buf.resize(g1.size());
+        for (size_t i = 0; i < g1.size(); ++i)
+          buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), 
false,
+                                    buf.merged.dtype());
+      }
+      if (stage.copy_buf.empty()) {
+        stage.copy_buf.resize(g2.size() - 1);
+        for (size_t i = 0; i < g2.size() - 1; ++i)
+          stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                      false, stage.merged.dtype());
       }
     }
-    for (size_t i = 0; i < src.size()-1; ++i) {
-      CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-      reduce[i+1] = buf.copy_buf[i];
+    std::vector<NDArray> reduce_s(stage.copy_buf.size() + 1);
+    for (size_t i = 0, j = 0; i < src.size(); ++i) {
+      int id = src[i].ctx().dev_id;
+      if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) ||
+          (buf.merged.is_none() && i == 0)) {
+        CopyFromTo(src[i], &(stage.merged), priority);
+        reduce_s[0] = stage.merged;
+      } else if (id >= NVLINK_SUPPORT || buf.merged.is_none()) {
+        CopyFromTo(src[i], &(stage.copy_buf[j]), priority);
+        reduce_s[j + 1] = stage.copy_buf[j];
+        j++;
+      }
     }
+    ElementwiseSum(reduce_s, &stage.merged);
+    // Main reduce result on gpu 0 including the partial result from gpu
+    // NVLINK_SUPPORT
+    if (!buf.merged.is_none()) {
+      std::vector<NDArray> reduce(buf.copy_buf.size() + 1);
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if (id == buf.merged.ctx().dev_id) {
+          CopyFromTo(src[i], &(buf.merged), priority);
+          reduce[0] = buf.merged;
+        } else if (id < NVLINK_SUPPORT) {
+          CopyFromTo(src[i], &(buf.copy_buf[j]), priority);
+          reduce[j + 1] = buf.copy_buf[j];
+          j++;
+        }
+      }
 
-    ElementwiseSum(reduce, &buf.merged);
+      CopyFromTo(stage.merged, &(buf.copy_buf[buf.copy_buf.size() - 1]),
+                 priority);
+      reduce[reduce.size() - 1] = buf.copy_buf[buf.copy_buf.size() - 1];
+      ElementwiseSum(reduce, &buf.merged);
+    } else {
+      return stage.merged;
+    }
     return buf.merged;
   }
 
   const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
                                   int priority) {
     InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-    if (buf.copy_buf.empty()) {
+    auto& stage = stage_buf_[key];
+    if (buf.merged.is_none() && stage.copy_buf.empty()) {
       // one buf for each context
-      buf.copy_buf.resize(src.size());
-      buf.compressed_recv_buf.resize(src.size());
-      buf.compressed_send_buf.resize(src.size());
-      buf.residual.resize(src.size());
+      stage.copy_buf.resize(src.size());
+      stage.compressed_recv_buf.resize(src.size());
+      stage.compressed_send_buf.resize(src.size());
+      stage.residual.resize(src.size());
 
       for (size_t i = 0; i < src.size(); ++i) {
-        buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = 0;
+        stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                    false, stage.merged.dtype());
+        stage.residual[i] = NDArray(stage.merged.shape(), src[i].ctx(), false,
+                                    stage.merged.dtype());
+        stage.residual[i] = 0;
+        int64_t small_size =
+            gc_->GetCompressedSize(stage.merged.shape().Size());
+        stage.compressed_recv_buf[i] =
+            NDArray(TShape{small_size}, stage.merged.ctx(), false,
+                    stage.merged.dtype());
+        stage.compressed_send_buf[i] = NDArray(TShape{small_size}, 
src[i].ctx(),
+                                               false, stage.merged.dtype());
+      }
+    } else if (!buf.merged.is_none()) {
+      if (buf.copy_buf.empty() && stage.copy_buf.empty()) {
+        buf.copy_buf.resize(g1.size() + 1);
+        buf.compressed_recv_buf.resize(g1.size() + 1);
+        buf.compressed_send_buf.resize(g1.size() + 1);
+        buf.residual.resize(g1.size() + 1);
 
 Review comment:
  I tested tests/nightly/test_kvstore.py and it passes.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to