This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 1267c6a fix wrong dist-kvstore push/pull/rsp_pull (#7762)
1267c6a is described below
commit 1267c6a87661b69a115d4d5dde795c17a548afbc
Author: Haibin Lin <[email protected]>
AuthorDate: Thu Sep 7 10:30:26 2017 -0700
fix wrong dist-kvstore push/pull/rsp_pull (#7762)
---
include/mxnet/kvstore.h | 4 +-
src/kvstore/kvstore_dist.h | 137 +++++++++++++++++++-------------------
src/kvstore/kvstore_dist_server.h | 1 +
src/kvstore/kvstore_local.h | 40 +++++------
4 files changed, 91 insertions(+), 91 deletions(-)
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index bca88a5..ddaa207 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -184,7 +184,7 @@ class KVStore {
*/
virtual void PullRowSparse(const std::vector<int>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
- const int priority = 0) = 0;
+ int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store, where each key is a
string.
@@ -196,7 +196,7 @@ class KVStore {
*/
virtual void PullRowSparse(const std::vector<std::string>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
- const int priority = 0) = 0;
+ int priority = 0) = 0;
/**
* \brief the prototype of user-defined updater
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 0cc83a0..6ce6b5a 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -80,8 +80,63 @@ class KVStoreDist : public KVStoreLocal {
}
}
- void Init(const std::vector<int>& keys,
- const std::vector<NDArray>& values) override {
+ void set_updater(const Updater& updater) override {
+ CHECK(updater) << "invalid updater";
+ if (IsServerNode()) {
+ CHECK_NOTNULL(server_)->set_updater(updater);
+ } else {
+ updater_ = updater;
+ }
+ }
+
+ void Barrier() override {
+ ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
+ }
+
+ void SendCommandToServers(int cmd_id,
+ const std::string& cmd_body) override {
+ CHECK_NOTNULL(ps_worker_);
+ ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup));
+ }
+
+ int get_group_size() const override { return ps::NumWorkers(); }
+
+ int get_rank() const override { return ps::MyRank(); }
+
+ int get_num_dead_node(int node_id, int timeout) const override {
+ int number = 0;
+ auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout);
+ const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id);
+ std::unordered_set<int> watch_set(watch_nodes.begin(), watch_nodes.end());
+ for (int r : dead_nodes) {
+ if (watch_set.find(r) != watch_set.end()) number++;
+ }
+ return number;
+ }
+
+ void RunServer(const Controller& controller) override {
+ CHECK(!IsWorkerNode());
+ if (IsServerNode()) {
+ server_ = new KVStoreDistServer();
+ server_->set_controller(controller);
+ }
+
+ ps::StartAsync("mxnet_server\0");
+ if (!ps::Postoffice::Get()->is_recovery()) {
+ ps::Postoffice::Get()->Barrier(
+ ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
+ }
+ if (server_) server_->Run();
+ ps::Finalize();
+ if (server_) {
+ delete server_;
+ }
+ server_ = nullptr;
+ }
+
+ private:
+ void InitImpl(const std::vector<int>& keys,
+ const std::vector<NDArray>& values) override {
CheckUnique(keys);
for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(),
values[i].dtype());
@@ -100,15 +155,15 @@ class KVStoreDist : public KVStoreLocal {
}
}
- void Push(const std::vector<int>& keys,
- const std::vector<NDArray>& values,
- int priority) override {
+ void PushImpl(const std::vector<int>& keys,
+ const std::vector<NDArray>& values,
+ int priority) override {
Push_(keys, values, priority, true);
}
- void Pull(const std::vector<int>& keys,
- const std::vector<NDArray*>& values,
- int priority) override {
+ void PullImpl(const std::vector<int>& keys,
+ const std::vector<NDArray*>& values,
+ int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -155,9 +210,9 @@ class KVStoreDist : public KVStoreLocal {
}
}
- void PullRowSparse(const std::vector<int>& keys,
- const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
- const int priority = 0) {
+ void PullRowSparseImpl(const std::vector<int>& keys,
+ const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
+ int priority = 0) override {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
@@ -198,66 +253,10 @@ class KVStoreDist : public KVStoreLocal {
}
}
- void set_updater(const Updater& updater) override {
- CHECK(updater) << "invalid updater";
- if (IsServerNode()) {
- CHECK_NOTNULL(server_)->set_updater(updater);
- } else {
- updater_ = updater;
- }
- }
-
- void Barrier() override {
- ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
- }
-
-
- void SendCommandToServers(int cmd_id,
- const std::string& cmd_body) override {
- CHECK_NOTNULL(ps_worker_);
- ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup));
- }
-
- int get_group_size() const override { return ps::NumWorkers(); }
-
- int get_rank() const override { return ps::MyRank(); }
-
- int get_num_dead_node(int node_id, int timeout) const override {
- int number = 0;
- auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout);
- const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id);
- std::unordered_set<int> watch_set(watch_nodes.begin(), watch_nodes.end());
- for (int r : dead_nodes) {
- if (watch_set.find(r) != watch_set.end()) number++;
- }
- return number;
- }
-
- void RunServer(const Controller& controller) override {
- CHECK(!IsWorkerNode());
- if (IsServerNode()) {
- server_ = new KVStoreDistServer();
- server_->set_controller(controller);
- }
-
- ps::StartAsync("mxnet_server\0");
- if (!ps::Postoffice::Get()->is_recovery()) {
- ps::Postoffice::Get()->Barrier(
- ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
- }
- if (server_) server_->Run();
- ps::Finalize();
- if (server_) {
- delete server_;
- }
- server_ = nullptr;
- }
-
- private:
void Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority,
- bool do_merge) {
+ bool do_merge) {
// first aggregate the values over keys
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
@@ -320,7 +319,7 @@ class KVStoreDist : public KVStoreLocal {
}
// pull row sparse weight into `recv_buf` based on indices given by `indices`
- void PullRowSparse_(int key, NDArray *recv_buf, const NDArray& indices, int
priority) {
+ void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray&
indices, int priority) {
using namespace rowsparse;
auto pull_from_servers = [this, key, recv_buf, indices]
(RunContext rctx, Engine::CallbackOnComplete cb) {
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 43a10b0..88bdcab 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -339,6 +339,7 @@ class KVStoreDistServer {
auto len = unit_len * num_rows;
// concat values
response.vals.resize(len);
+ #pragma omp parallel for
for (size_t i = 1; i <= num_rows; i++) {
int key = DecodeKey(req_data.keys[i]);
int64_t row_id = key - master_key;
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index db1d04a..ac2968b 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -67,7 +67,7 @@ class KVStoreLocal : public KVStore {
void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
SetKeyType(kIntKey);
- Init_(keys, values);
+ InitImpl(keys, values);
}
void Init(const std::vector<std::string>& str_keys,
@@ -84,28 +84,28 @@ class KVStoreLocal : public KVStore {
reverse_str_key_dict_[key] = str_key;
keys[i] = key;
}
- Init_(keys, values);
+ InitImpl(keys, values);
}
void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority) override {
SetKeyType(kIntKey);
- Push_(keys, values, priority);
+ PushImpl(keys, values, priority);
}
void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority) override {
SetKeyType(kIntKey);
- Pull_(keys, values, priority);
+ PullImpl(keys, values, priority);
}
void PullRowSparse(const std::vector<int>& keys,
const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
int priority = 0) override {
SetKeyType(kIntKey);
- PullRowSparse_(keys, val_rowids, priority);
+ PullRowSparseImpl(keys, val_rowids, priority);
}
void Push(const std::vector<std::string>& str_keys,
@@ -114,7 +114,7 @@ class KVStoreLocal : public KVStore {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
- Push_(keys, values, priority);
+ PushImpl(keys, values, priority);
}
void Pull(const std::vector<std::string>& str_keys,
@@ -123,21 +123,21 @@ class KVStoreLocal : public KVStore {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
- Pull_(keys, values, priority);
+ PullImpl(keys, values, priority);
}
void PullRowSparse(const std::vector<std::string>& str_keys,
const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
- const int priority = 0) override {
+ int priority = 0) override {
SetKeyType(kStringKey);
std::vector<int> keys(str_keys.size());
LookupKeys(str_keys, &keys);
- PullRowSparse_(keys, val_rowids, priority);
+ PullRowSparseImpl(keys, val_rowids, priority);
}
private:
- void Init_(const std::vector<int>& keys,
- const std::vector<NDArray>& values) {
+ virtual void InitImpl(const std::vector<int>& keys,
+ const std::vector<NDArray>& values) {
for (size_t i = 0; i < keys.size(); ++i) {
CHECK(local_.find(keys[i]) == local_.end())
<< "duplicate init of key " << keys[i];
@@ -146,9 +146,9 @@ class KVStoreLocal : public KVStore {
}
}
- void Push_(const std::vector<int>& keys,
- const std::vector<NDArray>& values,
- int priority) {
+ virtual void PushImpl(const std::vector<int>& keys,
+ const std::vector<NDArray>& values,
+ int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
@@ -185,9 +185,9 @@ class KVStoreLocal : public KVStore {
}
}
- void Pull_(const std::vector<int>& keys,
- const std::vector<NDArray*>& values,
- int priority) {
+ virtual void PullImpl(const std::vector<int>& keys,
+ const std::vector<NDArray*>& values,
+ int priority) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -200,9 +200,9 @@ class KVStoreLocal : public KVStore {
}
}
- void PullRowSparse_(const std::vector<int>& keys,
- const std::vector<std::pair<NDArray*, NDArray>>&
val_rowids,
- int priority = 0) {
+ virtual void PullRowSparseImpl(const std::vector<int>& keys,
+ const std::vector<std::pair<NDArray*,
NDArray>>& val_rowids,
+ int priority = 0) {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
--
To stop receiving notification emails like this one, please contact
"[email protected]" <[email protected]>.