eric-haibin-lin commented on a change in pull request #10183: [MXNET-120] Float16 support for distributed training
URL: https://github.com/apache/incubator-mxnet/pull/10183#discussion_r178666240
 
 

 ##########
 File path: src/kvstore/kvstore_dist_server.h
 ##########
 @@ -220,175 +313,229 @@ class KVStoreDistServer {
     }
   }
 
-  void DataHandleRowSparse(const ps::KVMeta& req_meta,
-                       const ps::KVPairs<real_t>& req_data,
-                       ps::KVServer<real_t>* server) {
+  void AccumulateRowSparseGrads(const DataHandleType type,
+                                const NDArray& recved,
+                                UpdateBuf* updateBuf) {
+    NDArray out(kRowSparseStorage, updateBuf->merged.shape(), Context(), true,
+                has_multi_precision_copy(type) ? mshadow::kFloat32 : type.dtype);
+    if (has_multi_precision_copy(type)) CopyFromTo(recved, updateBuf->temp_array);
+    const NDArray& to_merge = has_multi_precision_copy(type) ? updateBuf->temp_array : recved;
+    // accumulate row_sparse gradients
+    // TODO(haibin) override + operator for row_sparse NDArray
+    // instead of calling BinaryComputeRspRsp directly
+    using namespace mshadow;
+    Engine::Get()->PushAsync(
+    [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
+      {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out});
+      on_complete();
+    }, to_merge.ctx(), {to_merge.var(), updateBuf->merged.var()}, {out.var()},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+    CopyFromTo(out, &(updateBuf->merged), 0);
+    updateBuf->merged.WaitToRead();
+  }
+
+  void RowSparsePullResponse(const DataHandleType type,
+                             const int master_key,
+                             const size_t num_rows,
+                             const ps::KVMeta& req_meta,
+                             const ps::KVPairs<char>& req_data,
+                             ps::KVServer<char>* server) {
+    if (log_verbose_) LOG(INFO) << "pull: " << master_key;
+    ps::KVPairs<char> response;
+    if (num_rows == 0) {
+      std::vector<int> lens(req_data.keys.size(), 0);
+      response.keys = req_data.keys;
+      response.lens.CopyFrom(lens.begin(), lens.end());
+      server->Response(req_meta, response);
+      return;
+    }
+    const NDArray& stored = store_[master_key];
+    if (has_multi_precision_copy(type)) stored.WaitToRead();
 
 Review comment:
  Why do we need to wait on `stored` here?
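
  For context, a minimal sketch of the scenario the question is about, assuming the usual MXNet engine semantics; the helper name and `recved_fp16` below are illustrative, not code from the PR. CopyFromTo() only schedules the fp16 -> fp32 cast on the engine, so if the pull response is then built by reading the master copy's buffer directly on the request thread (outside any engine-pushed operation), a WaitToRead() is needed to make sure the pending write has finished:

    // Illustrative sketch only (not code from the PR): why reading an
    // NDArray's buffer on the calling thread needs WaitToRead().
    #include <mxnet/ndarray.h>

    using mxnet::Context;
    using mxnet::NDArray;

    // Hypothetical helper; mirrors the multi-precision pull path in spirit.
    void SketchPullFromMasterCopy(const NDArray& recved_fp16) {
      // fp32 master copy, as kept when multi-precision mode is on.
      NDArray stored(recved_fp16.shape(), Context::CPU(), false, mshadow::kFloat32);

      // Asynchronous: the engine runs the fp16 -> fp32 cast/copy later.
      CopyFromTo(recved_fp16, &stored);

      // Without this, the raw read below could observe a buffer the engine
      // is still writing.
      stored.WaitToRead();

      // Build the response from the raw buffer on the request thread,
      // outside of any engine-pushed operation.
      const void* raw = stored.data().dptr_;
      (void)raw;  // e.g. copied into the ps::KVPairs<char> response values
    }

  If the response were instead assembled inside an operation pushed to the engine with `stored.var()` as a read dependency, the explicit wait would not be required.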
