SINGA-21 Code review 5

review server.h, server.cc
 - format code
 - remove thread_id field
 - rename variables
   nUpdates_ -> n_updates_
   nPendingSync_ -> n_pending_sync_
 - fix a bug in HandleUpdate that used *msg in an unknown state (see the sketch below)
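
A minimal standalone sketch (not SINGA code) of the *msg lifetime issue fixed above: once the buffered request batch that contains *msg is handed to a routine that may consume or free it, reading fields through *msg is undefined, so the needed fields are cached first. DemoMsg, ConsumeBatch and HandleBatch are hypothetical placeholders, not SINGA APIs.

    // Hypothetical illustration only; not SINGA code.
    #include <vector>

    struct DemoMsg { int trgt_val; };

    // Pretend this parses and then frees every message in the batch.
    void ConsumeBatch(std::vector<DemoMsg*>* batch) {
      for (DemoMsg* m : *batch) delete m;
      batch->clear();
    }

    int HandleBatch(DemoMsg** msg, std::vector<DemoMsg*>* batch) {
      int trgt_val = (*msg)->trgt_val;  // cache BEFORE the batch is consumed
      ConsumeBatch(batch);              // *msg may now be dangling
      *msg = nullptr;                   // caller must not touch it again
      return trgt_val;                  // safe: uses the cached copy
    }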

TODO:
 - give each socket a unique id from the cluster
 - buffer the unprocessed message instead of sending it back to the stub (see the sketch below)
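
A minimal standalone sketch (not SINGA code) of the buffering idea in the TODO above: requests that arrive before the corresponding kPut has created the slice are queued per slice and replayed once the slice exists, instead of being bounced back to the stub. FakeMsg and PendingBuffer are hypothetical placeholders, not SINGA APIs.

    // Hypothetical illustration only; not SINGA code.
    #include <functional>
    #include <queue>
    #include <unordered_map>

    struct FakeMsg {   // stand-in for singa::Msg
      int type;        // e.g. kGet, kUpdate, kSyncRequest
      int slice_id;
    };

    class PendingBuffer {
     public:
      // Queue a message whose slice has not been Put yet.
      void Defer(int slice_id, FakeMsg* msg) {
        pending_[slice_id].push(msg);
      }
      // After the kPut for this slice is handled, replay everything waiting.
      void Replay(int slice_id, const std::function<void(FakeMsg*)>& handle) {
        auto it = pending_.find(slice_id);
        if (it == pending_.end()) return;
        while (!it->second.empty()) {
          handle(it->second.front());
          it->second.pop();
        }
        pending_.erase(it);
      }

     private:
      std::unordered_map<int, std::queue<FakeMsg*>> pending_;
    };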


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d3e1fca3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d3e1fca3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d3e1fca3

Branch: refs/heads/master
Commit: d3e1fca38b97e06ca113369d9a4f583750105a39
Parents: 3161175
Author: wang sheng <[email protected]>
Authored: Tue Sep 22 17:22:33 2015 +0800
Committer: wang sheng <[email protected]>
Committed: Tue Sep 22 17:28:41 2015 +0800

----------------------------------------------------------------------
 include/trainer/server.h  |  74 ++++++++++++++---------------
 include/trainer/trainer.h |   2 +-
 src/trainer/server.cc     | 104 +++++++++++++++++++++--------------------
 src/trainer/trainer.cc    |   6 +--
 4 files changed, 92 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index 3f1c12d..3f3539a 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -19,17 +19,20 @@
 *
 *************************************************************/
 
-#ifndef INCLUDE_TRAINER_SERVER_H_
-#define INCLUDE_TRAINER_SERVER_H_
+#ifndef SINGA_TRAINER_SERVER_H_
+#define SINGA_TRAINER_SERVER_H_
+
 #include <memory>
 #include <unordered_map>
-#include <utils/param.h>
-#include <utils/updater.h>
-#include "proto/job.pb.h"
+#include <vector>
 #include "communication/socket.h"
+#include "proto/job.pb.h"
+#include "utils/param.h"
+#include "utils/updater.h"
 
 namespace singa {
-/* Repsond to worker's get/put/udpate request, and periodically syncing with
+
+ /* Repsond to worker's get/put/udpate request, and periodically syncing with
   * other servers.
   *
   * Normally, the Server creates a response message for each request which
@@ -39,33 +42,26 @@ namespace singa {
   * it just sends it to the router. The router will decide to re-send the
   * request to the server or send it to the worker.
   */
-class Server{
+class Server {
  public:
-  Server(int thread_id, int group_id, int server_id);
+  Server(int group_id, int server_id);
   virtual ~Server();
-  void Setup(const UpdaterProto& proto,
-      const std::vector<int>& slice2group,
-      const std::vector<int>& slice2server);
+  void Setup(const UpdaterProto& proto, const std::vector<int>& slice2group,
+             const std::vector<int>& slice2server);
   void Run();
-  const int grp_id() const {
-    return grp_id_;
-  }
-  const int id() const {
-    return id_;
-  }
+  inline int grp_id() const { return grp_id_; }
+  inline int id() const { return id_; }
 
  protected:
-
-       /**
-        * Process GET request.
+  /**
+   * Process GET request.
    *
    * @return the orignal message or a response message which contains the values
    * of the Param with the request version.
    */
-       virtual Msg* HandleGet(Msg** msg);
-
-       /**
-        * Process Update request.
+  virtual Msg* HandleGet(Msg** msg);
+  /**
+   * Process Update request.
    *
    * It waits until received the gradients from all workers from the same worker
    * group. After updating, it responses to each sender with the new Param
@@ -86,16 +82,14 @@ class Server{
    * @return the orignal message or response message
    */
   const std::vector<Msg*> HandleUpdate(Msg **msg);
-
-       /**
-        * Process PUT request.
+  /**
+   * Process PUT request.
    *
    * @return the original message or response message. If we don't want to
    * acknowledge the put request, then return nullptr.
-        */
-       virtual Msg* HandlePut(Msg **msg);
-
-       /**
+   */
+  virtual Msg* HandlePut(Msg **msg);
+  /**
    * Handle sync request from other server groups.
    *
    * It adds updates of Param (slice) from other server groups directly to
@@ -106,8 +100,7 @@ class Server{
    * @param msg request msg containing the parameter updates
    * @return response msg that contains the fresh parameter values.
    */
-       virtual Msg* HandleSyncRequest(Msg** msg);
-
+  virtual Msg* HandleSyncRequest(Msg** msg);
   /**
    * Handle sync response.
    *
@@ -121,17 +114,20 @@ class Server{
   void HandleSyncResponse(Msg** msg);
 
  protected:
-  int thread_id_,grp_id_, id_;
-  Updater* updater_;
+  int grp_id_ = -1;
+  int id_ = -1;
+  Updater* updater_ = nullptr;
   //!< map from slice ID to slice and deleted in the destructor
   std::unordered_map<int, ParamEntry*> shard_;
   std::vector<int> slice2group_, slice2server_;
   //!< num of updates from last sync with master server group for a param/slice
-  std::vector<int> nUpdates_;
+  std::vector<int> n_updates_;
   //!< num of sync requests that have not been responded
-  std::vector<int> nPendingSync_;
+  std::vector<int> n_pending_sync_;
   std::vector<Blob<float>> last_sync_;
   std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
 };
-} /* Server */
-#endif //INCLUDE_TRAINER_SERVER_H_
+
+}  // namespace singa
+
+#endif  // SINGA_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 6630e51..0b03dea 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -73,7 +73,7 @@ class Trainer{
    * @param jobConf
    * @return server instances
    */
-  vector<Server*> CreateServers(int nthread, const JobProto& jobConf);
+  vector<Server*> CreateServers(const JobProto& jobConf);
   /**
    * Create workers instances.
    * @param nthread total num of threads in current procs which is used to

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 18fe7d2..29f6a68 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -22,30 +22,30 @@
 #include <thread>
 #include <chrono>
 #include "mshadow/tensor.h"
+#include "proto/common.pb.h"
 #include "trainer/server.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "utils/cluster.h"
-#include "proto/common.pb.h"
 
 namespace singa {
 
 using namespace mshadow;
 using std::vector;
 
-Server::Server(int thread_id,int group_id, int server_id):
-  thread_id_(thread_id),grp_id_(group_id), id_(server_id){
+Server::Server(int group_id, int server_id) {
+  grp_id_ = group_id;
+  id_ = server_id;
 }
 
-void Server::Setup(const UpdaterProto& proto,
-    const vector<int>& slice2group,
-    const vector<int>& slice2server) {
+void Server::Setup(const UpdaterProto& proto, const vector<int>& slice2group,
+                   const vector<int>& slice2server) {
   updater_ = Updater::Create(proto);
   slice2group_ = slice2group;
   slice2server_ = slice2server;
-  nUpdates_.resize(slice2group_.size(), 0);
-  nPendingSync_.resize(slice2group_.size(), 0);
+  n_updates_.resize(slice2group_.size(), 0);
+  n_pending_sync_.resize(slice2group_.size(), 0);
   last_sync_.resize(slice2group_.size());
 }
 
@@ -57,14 +57,14 @@ Server::~Server() {
       delete param;
 }
 
-void Stop(void * running) {
+void Stop(void* running) {
   *static_cast<bool *>(running) = false;
 }
 
 void Server::Run() {
   LOG(ERROR) << "Server (group = " << grp_id_ <<", id = " << id_ << ") start";
-
-  auto dealer = new Dealer(2*thread_id_);
+  // TODO(wangsh): give each dealer a unique id
+  auto dealer = new Dealer(0);
   CHECK(dealer->Connect(kInprocRouterEndpoint));
   Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
   ping->set_type(kConnect);
@@ -77,7 +77,7 @@ void Server::Run() {
   // start recv loop and process requests
   while (running) {
     // must use poller here; otherwise Receive() gets stuck after workers stop.
-    auto *sock = poll.Wait(cluster->poll_time());
+    auto* sock = poll.Wait(cluster->poll_time());
     if (poll.Terminated()) {
       LOG(ERROR) << "Connection broken!";
       exit(0);
@@ -85,35 +85,35 @@ void Server::Run() {
       continue;
     }
     Msg* msg = dealer->Receive();
-    if (msg == nullptr) break; //  interrupted
+    if (msg == nullptr) break;  // interrupted
     Msg* response = nullptr;
     int type = msg->type();
     int slice_id = SliceID(msg->trgt_val());
     if (type == kPut) {
       response = HandlePut(&msg);
+    } else if (shard_.find(slice_id) == shard_.end()) {
+      // TODO(wangsh): buffer the msg instead, and process it after the
+      //               corresponding put request is done
+      // delay the processing by re-queue the msg. May sleep for a while?
+      response = msg;
     } else {
-      if (shard_.find(slice_id) == shard_.end()) {
-        // delay the processing by re-queue the msg. May sleep for a while?
-        response = msg;
-      }  else {
-        switch (type) {
-          case kGet:
-            response = HandleGet(&msg);
-            break;
-          case kUpdate:
-            for (auto reply : HandleUpdate(&msg))
-              dealer->Send(&reply);
-            break;
-          case kSyncRequest:
-            response = HandleSyncRequest(&msg);
-            break;
-          case kSyncResponse:
-            HandleSyncResponse(&msg);
-            break;
-          default:
-            LOG(ERROR)<<"Unknown message type "<<type;
-            break;
-        }
+      switch (type) {
+        case kGet:
+          response = HandleGet(&msg);
+          break;
+        case kUpdate:
+          for (auto reply : HandleUpdate(&msg))
+            dealer->Send(&reply);
+          break;
+        case kSyncRequest:
+          response = HandleSyncRequest(&msg);
+          break;
+        case kSyncResponse:
+          HandleSyncResponse(&msg);
+          break;
+        default:
+          LOG(ERROR) << "Unknown message type: " << type;
+          break;
       }
     }
     if (response != nullptr)
@@ -125,7 +125,6 @@ void Server::Run() {
   msg->set_type(kStop);
   dealer->Send(&msg);
   std::this_thread::sleep_for(std::chrono::milliseconds(1000));
-
   LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") stops";
   delete dealer;
 }
@@ -154,8 +153,8 @@ Msg* Server::HandlePut(Msg **msg) {
     last_sync_[slice_id].ReshapeLike(param->data());
     last_sync_[slice_id].CopyFrom(param->data());
   }
-  LOG(INFO)<<"server (group = " << grp_id_ << ", id = " << id_ <<") put slice="
-    << slice_id << " size=" << param->size();
+  LOG(INFO) << "server (group = " << grp_id_ << ", id = " << id_
+            <<") put slice=" << slice_id << " size=" << param->size();
   return response;
 }
 
@@ -163,9 +162,9 @@ Msg* Server::HandleGet(Msg **msg) {
   int val = (*msg)->trgt_val();
   auto param = shard_.at(SliceID(val))->shares.at(0);
   // re-queue the request if the param is not updated to the required version
-  if(param->version()<(*msg)->trgt_version())
+  if (param->version() < (*msg)->trgt_version()) {
     return *msg;
-  else {
+  } else {
     // LOG(ERROR) << "get " << slice << " from "<<(*msg)->src_first();
     auto reply = param->HandleGetMsg(msg, false);
     reply->set_trgt(val, param->version());
@@ -183,12 +182,14 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
   (*msg)->ParseFormatFrame("i", &num_update);
   (*msg)->FirstFrame();
   entry->num_update += num_update;
-  // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src()) << ", " << num_update << " total " << entry->num_total;
+  // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src())
+  //            << ", " << num_update << " total " << entry->num_total;
   // do update until recv gradients from all shares of this param/slice
   if (entry->num_update >= entry->num_total) {
     CHECK_EQ(entry->num_update, entry->num_total);
     auto& request = buffer_requests_.at(sliceid);
     int step = (*msg)->trgt_version();
+    int trgt_val = (*msg)->trgt_val();
     auto param = entry->shares.at(0);
     // extract and aggregate gradients
     param->ParseUpdateMsgs(request);
@@ -196,16 +197,16 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
     param->set_local_version(param->local_version() + 1);
     // response to all shares of this param
     for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
-      response->set_trgt((*msg)->trgt_val(), param->local_version());
+      response->set_trgt(trgt_val, param->local_version());
       ret.push_back(response);
     }
     entry->num_update = 0;
-    nUpdates_[sliceid]++;
+    n_updates_[sliceid]++;
     // sync with master group after at least sync_freq local updates
     // the last check is to avoid sending msg to stopped servers
     if (slice2group_[sliceid] != grp_id_
-        && nUpdates_[sliceid] >= Cluster::Get()->sync_freq()
-        && nPendingSync_[sliceid] <= Cluster::Get()->sync_freq()) {
+        && n_updates_[sliceid] >= Cluster::Get()->sync_freq()
+        && n_pending_sync_[sliceid] <= Cluster::Get()->sync_freq()) {
       auto shape = Shape1(param->size());
       Tensor<cpu, 1> tmp(last_sync_[sliceid].mutable_cpu_data(), shape);
       Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
@@ -213,14 +214,15 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
       int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer);
       Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr);
       sync->set_type(kSyncRequest);
-      sync->set_trgt((*msg)->trgt_val(), param->local_version());
+      sync->set_trgt(trgt_val, param->local_version());
       sync->AddFrame(tmp.dptr, param->size() * sizeof(float));
       Copy(tmp, cur);
       ret.push_back(sync);
-      nUpdates_[sliceid] = 0;
-      nPendingSync_[sliceid]++;
+      n_updates_[sliceid] = 0;
+      n_pending_sync_[sliceid]++;
     }
   }
+  // message already pushed to buffer, just need to reset the pointer
   *msg = nullptr;
   return ret;
 }
@@ -247,14 +249,14 @@ void Server::HandleSyncResponse(Msg **msg) {
   Msg* msgg = *msg;
   int slice = SliceID(msgg->trgt_val());
   auto param = shard_.at(slice)->shares.at(0);
-  auto shape=Shape1(param->size());
+  auto shape = Shape1(param->size());
   Tensor<cpu, 1> prev(last_sync_[param->id()].mutable_cpu_data(), shape);
   Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
   Tensor<cpu, 1> master(static_cast<float*>(msgg->FrameData()), shape);
   cur += master - prev;  // cur = master + (cur - prev);
   Copy(prev, cur);
   DeleteMsg(msg);
-  nPendingSync_[slice]--;
+  n_pending_sync_[slice]--;
 }
 
-} /* singa */
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 4a4c183..c928d91 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -144,7 +144,7 @@ void Trainer::SetupWorkerServer(
     server->Setup(job_conf.updater(), slice2group, slice2server_);
 }
 
-vector<Server*> Trainer::CreateServers(int nthreads, const JobProto& job) {
+vector<Server*> Trainer::CreateServers(const JobProto& job) {
   auto cluster = Cluster::Get();
   vector<Server*> servers;
   if (!cluster->has_server())
@@ -160,7 +160,7 @@ vector<Server*> Trainer::CreateServers(int nthreads, const JobProto& job) {
   int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int sid = start; sid < end; sid++) {
-      auto server = new Server(nthreads++, gid, sid);
+      auto server = new Server(gid, sid);
       servers.push_back(server);
     }
   }
@@ -244,7 +244,7 @@ void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
   int nthreads = 1;
   const vector<Worker*> workers = CreateWorkers(nthreads, *job);
   nthreads += workers.size();
-  const vector<Server*> servers = CreateServers(nthreads, *job);
+  const vector<Server*> servers = CreateServers(*job);
   SetupWorkerServer(*job, workers, servers);
 
 #ifdef USE_MPI
