This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1267c6a  fix wrong dist-kvstore push/pull/rsp_pull (#7762)
1267c6a is described below

commit 1267c6a87661b69a115d4d5dde795c17a548afbc
Author: Haibin Lin <[email protected]>
AuthorDate: Thu Sep 7 10:30:26 2017 -0700

    fix wrong dist-kvstore push/pull/rsp_pull (#7762)
---
 include/mxnet/kvstore.h           |   4 +-
 src/kvstore/kvstore_dist.h        | 137 +++++++++++++++++++-------------------
 src/kvstore/kvstore_dist_server.h |   1 +
 src/kvstore/kvstore_local.h       |  40 +++++------
 4 files changed, 91 insertions(+), 91 deletions(-)

diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index bca88a5..ddaa207 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -184,7 +184,7 @@ class KVStore {
    */
   virtual void PullRowSparse(const std::vector<int>& str_keys,
                              const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                             const int priority = 0) = 0;
+                             int priority = 0) = 0;
 
   /*!
    * \brief pull a list of key-value pairs from the store, where each key is a string.
@@ -196,7 +196,7 @@ class KVStore {
    */
   virtual void PullRowSparse(const std::vector<std::string>& str_keys,
                              const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                             const int priority = 0) = 0;
+                             int priority = 0) = 0;
 
   /**
    * \brief the prototype of user-defined updater
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 0cc83a0..6ce6b5a 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -80,8 +80,63 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
-  void Init(const std::vector<int>& keys,
-            const std::vector<NDArray>& values) override {
+  void set_updater(const Updater& updater) override {
+    CHECK(updater) << "invalid updater";
+    if (IsServerNode()) {
+      CHECK_NOTNULL(server_)->set_updater(updater);
+    } else {
+      updater_ = updater;
+    }
+  }
+
+  void Barrier() override {
+    ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
+  }
+
+  void SendCommandToServers(int cmd_id,
+                            const std::string& cmd_body) override {
+    CHECK_NOTNULL(ps_worker_);
+    ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup));
+  }
+
+  int get_group_size() const override { return ps::NumWorkers(); }
+
+  int get_rank() const override { return ps::MyRank(); }
+
+  int get_num_dead_node(int node_id, int timeout) const override {
+    int number = 0;
+    auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout);
+    const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id);
+    std::unordered_set<int> watch_set(watch_nodes.begin(), watch_nodes.end());
+    for (int r : dead_nodes) {
+      if (watch_set.find(r) != watch_set.end()) number++;
+    }
+    return number;
+  }
+
+  void RunServer(const Controller& controller) override {
+    CHECK(!IsWorkerNode());
+    if (IsServerNode()) {
+      server_ = new KVStoreDistServer();
+      server_->set_controller(controller);
+    }
+
+    ps::StartAsync("mxnet_server\0");
+    if (!ps::Postoffice::Get()->is_recovery()) {
+      ps::Postoffice::Get()->Barrier(
+        ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
+    }
+    if (server_) server_->Run();
+    ps::Finalize();
+    if (server_) {
+      delete server_;
+    }
+    server_ = nullptr;
+  }
+
+ private:
+  void InitImpl(const std::vector<int>& keys,
+                const std::vector<NDArray>& values) override {
     CheckUnique(keys);
     for (size_t i = 0; i < keys.size(); ++i) {
       comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
@@ -100,15 +155,15 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
-  void Push(const std::vector<int>& keys,
-            const std::vector<NDArray>& values,
-            int priority) override {
+  void PushImpl(const std::vector<int>& keys,
+                const std::vector<NDArray>& values,
+                int priority) override {
     Push_(keys, values, priority, true);
   }
 
-  void Pull(const std::vector<int>& keys,
-            const std::vector<NDArray*>& values,
-            int priority) override {
+  void PullImpl(const std::vector<int>& keys,
+                const std::vector<NDArray*>& values,
+                int priority) override {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray*> > grouped_vals;
     GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -155,9 +210,9 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
-  void PullRowSparse(const std::vector<int>& keys,
-                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                     const int priority = 0) {
+  void PullRowSparseImpl(const std::vector<int>& keys,
+                         const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                         int priority = 0) override {
     std::vector<int> uniq_keys;
     std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
     GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
@@ -198,66 +253,10 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
-  void set_updater(const Updater& updater) override {
-    CHECK(updater) << "invalid updater";
-    if (IsServerNode()) {
-      CHECK_NOTNULL(server_)->set_updater(updater);
-    } else {
-      updater_ = updater;
-    }
-  }
-
-  void Barrier() override {
-    ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
-  }
-
-
-  void SendCommandToServers(int cmd_id,
-                            const std::string& cmd_body) override {
-    CHECK_NOTNULL(ps_worker_);
-    ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup));
-  }
-
-  int get_group_size() const override { return ps::NumWorkers(); }
-
-  int get_rank() const override { return ps::MyRank(); }
-
-  int get_num_dead_node(int node_id, int timeout) const override {
-    int number = 0;
-    auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout);
-    const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id);
-    std::unordered_set<int> watch_set(watch_nodes.begin(), watch_nodes.end());
-    for (int r : dead_nodes) {
-      if (watch_set.find(r) != watch_set.end()) number++;
-    }
-    return number;
-  }
-
-  void RunServer(const Controller& controller) override {
-    CHECK(!IsWorkerNode());
-    if (IsServerNode()) {
-      server_ = new KVStoreDistServer();
-      server_->set_controller(controller);
-    }
-
-    ps::StartAsync("mxnet_server\0");
-    if (!ps::Postoffice::Get()->is_recovery()) {
-      ps::Postoffice::Get()->Barrier(
-        ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
-    }
-    if (server_) server_->Run();
-    ps::Finalize();
-    if (server_) {
-      delete server_;
-    }
-    server_ = nullptr;
-  }
-
- private:
   void Push_(const std::vector<int>& keys,
              const std::vector<NDArray>& values,
              int priority,
-             bool do_merge)  {
+             bool do_merge) {
     // first aggregate the values over keys
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray> > grouped_vals;
@@ -320,7 +319,7 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   // pull row sparse weight into `recv_buf` based on indices given by `indices`
-  void PullRowSparse_(int key, NDArray *recv_buf, const NDArray& indices, int priority) {
+  void PullRowSparse_(const int key, NDArray *recv_buf, const NDArray& indices, int priority) {
     using namespace rowsparse;
     auto pull_from_servers = [this, key, recv_buf, indices]
                              (RunContext rctx, Engine::CallbackOnComplete cb) {
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index 43a10b0..88bdcab 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -339,6 +339,7 @@ class KVStoreDistServer {
       auto len = unit_len * num_rows;
       // concat values
       response.vals.resize(len);
+      #pragma omp parallel for
       for (size_t i = 1; i <= num_rows; i++) {
         int key = DecodeKey(req_data.keys[i]);
         int64_t row_id = key - master_key;
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index db1d04a..ac2968b 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -67,7 +67,7 @@ class KVStoreLocal : public KVStore {
   void Init(const std::vector<int>& keys,
             const std::vector<NDArray>& values) override {
     SetKeyType(kIntKey);
-    Init_(keys, values);
+    InitImpl(keys, values);
   }
 
   void Init(const std::vector<std::string>& str_keys,
@@ -84,28 +84,28 @@ class KVStoreLocal : public KVStore {
       reverse_str_key_dict_[key] = str_key;
       keys[i] = key;
     }
-    Init_(keys, values);
+    InitImpl(keys, values);
   }
 
   void Push(const std::vector<int>& keys,
             const std::vector<NDArray>& values,
             int priority) override {
     SetKeyType(kIntKey);
-    Push_(keys, values, priority);
+    PushImpl(keys, values, priority);
   }
 
   void Pull(const std::vector<int>& keys,
             const std::vector<NDArray*>& values,
             int priority) override {
     SetKeyType(kIntKey);
-    Pull_(keys, values, priority);
+    PullImpl(keys, values, priority);
   }
 
   void PullRowSparse(const std::vector<int>& keys,
                      const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
                      int priority = 0) override {
     SetKeyType(kIntKey);
-    PullRowSparse_(keys, val_rowids, priority);
+    PullRowSparseImpl(keys, val_rowids, priority);
   }
 
   void Push(const std::vector<std::string>& str_keys,
@@ -114,7 +114,7 @@ class KVStoreLocal : public KVStore {
     SetKeyType(kStringKey);
     std::vector<int> keys(str_keys.size());
     LookupKeys(str_keys, &keys);
-    Push_(keys, values, priority);
+    PushImpl(keys, values, priority);
   }
 
   void Pull(const std::vector<std::string>& str_keys,
@@ -123,21 +123,21 @@ class KVStoreLocal : public KVStore {
     SetKeyType(kStringKey);
     std::vector<int> keys(str_keys.size());
     LookupKeys(str_keys, &keys);
-    Pull_(keys, values, priority);
+    PullImpl(keys, values, priority);
   }
 
   void PullRowSparse(const std::vector<std::string>& str_keys,
                      const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                     const int priority = 0) override {
+                     int priority = 0) override {
     SetKeyType(kStringKey);
     std::vector<int> keys(str_keys.size());
     LookupKeys(str_keys, &keys);
-    PullRowSparse_(keys, val_rowids, priority);
+    PullRowSparseImpl(keys, val_rowids, priority);
   }
 
  private:
-  void Init_(const std::vector<int>& keys,
-             const std::vector<NDArray>& values) {
+  virtual void InitImpl(const std::vector<int>& keys,
+                        const std::vector<NDArray>& values) {
     for (size_t i = 0; i < keys.size(); ++i) {
       CHECK(local_.find(keys[i]) == local_.end())
           << "duplicate init of key " << keys[i];
@@ -146,9 +146,9 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void Push_(const std::vector<int>& keys,
-             const std::vector<NDArray>& values,
-             int priority) {
+  virtual void PushImpl(const std::vector<int>& keys,
+                        const std::vector<NDArray>& values,
+                        int priority) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray> > grouped_vals;
     GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
@@ -185,9 +185,9 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void Pull_(const std::vector<int>& keys,
-             const std::vector<NDArray*>& values,
-             int priority) {
+  virtual void PullImpl(const std::vector<int>& keys,
+                        const std::vector<NDArray*>& values,
+                        int priority) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray*> > grouped_vals;
     GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
@@ -200,9 +200,9 @@ class KVStoreLocal : public KVStore {
     }
   }
 
-  void PullRowSparse_(const std::vector<int>& keys,
-                      const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
-                      int priority = 0) {
+  virtual void PullRowSparseImpl(const std::vector<int>& keys,
+                                 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                                 int priority = 0) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
     GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to