eric-haibin-lin commented on a change in pull request #10183: [MXNET-120] Float16 support for distributed training URL: https://github.com/apache/incubator-mxnet/pull/10183#discussion_r178666240
########## File path: src/kvstore/kvstore_dist_server.h ########## @@ -220,175 +313,229 @@ class KVStoreDistServer { } } - void DataHandleRowSparse(const ps::KVMeta& req_meta, - const ps::KVPairs<real_t>& req_data, - ps::KVServer<real_t>* server) { + void AccumulateRowSparseGrads(const DataHandleType type, + const NDArray& recved, + UpdateBuf* updateBuf) { + NDArray out(kRowSparseStorage, updateBuf->merged.shape(), Context(), true, + has_multi_precision_copy(type) ? mshadow::kFloat32 : type.dtype); + if (has_multi_precision_copy(type)) CopyFromTo(recved, updateBuf->temp_array); + const NDArray& to_merge = has_multi_precision_copy(type) ? updateBuf->temp_array : recved; + // accumulate row_sparse gradients + // TODO(haibin) override + operator for row_sparse NDArray + // instead of calling BinaryComputeRspRsp directly + using namespace mshadow; + Engine::Get()->PushAsync( + [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) { + op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>( + {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out}); + on_complete(); + }, to_merge.ctx(), {to_merge.var(), updateBuf->merged.var()}, {out.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + CopyFromTo(out, &(updateBuf->merged), 0); + updateBuf->merged.WaitToRead(); + } + + void RowSparsePullResponse(const DataHandleType type, + const int master_key, + const size_t num_rows, + const ps::KVMeta& req_meta, + const ps::KVPairs<char>& req_data, + ps::KVServer<char>* server) { + if (log_verbose_) LOG(INFO) << "pull: " << master_key; + ps::KVPairs<char> response; + if (num_rows == 0) { + std::vector<int> lens(req_data.keys.size(), 0); + response.keys = req_data.keys; + response.lens.CopyFrom(lens.begin(), lens.end()); + server->Response(req_meta, response); + return; + } + const NDArray& stored = store_[master_key]; + if (has_multi_precision_copy(type)) stored.WaitToRead(); Review comment: why need to wait on `stored`? 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services