SINGA-50 Improve the code of ParseUpdateMsgs function

Refactored the code to reuse the memory of one incoming message, or of one local
worker's grad blob, to store the aggregated parameter gradients.
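
A minimal sketch of the aggregation idea described above, assuming flat float
buffers and a hypothetical helper name (accumulate_grads); this is not SINGA's
API, only the strategy in isolation:

    #include <cstddef>
    #include <vector>

    // Pick one of the incoming gradient buffers as the accumulation target
    // instead of copying everything into a separate server-side buffer,
    // then sum the remaining buffers into it in place.
    float* accumulate_grads(const std::vector<float*>& worker_grads, size_t n) {
      float* target = worker_grads.front();  // reuse this buffer's memory
      for (float* g : worker_grads) {
        if (g == target) continue;
        for (size_t i = 0; i < n; ++i)
          target[i] += g[i];
      }
      return target;  // the caller points its grad blob at this memory
    }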


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b33e50d6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b33e50d6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b33e50d6

Branch: refs/heads/master
Commit: b33e50d65fb67ee9d3cff2d3242ff36005204fca
Parents: d269b67
Author: Wei Wang <[email protected]>
Authored: Fri Aug 14 22:02:46 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Sat Aug 15 14:59:11 2015 +0800

----------------------------------------------------------------------
 include/utils/param.h |  18 +++---
 src/utils/param.cc    | 142 +++++++++++++++++++++------------------------
 2 files changed, 74 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b33e50d6/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 18293b6..8fabe71 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -1,5 +1,5 @@
-#ifndef INCLUDE_UTILS_PARAM_H_
-#define INCLUDE_UTILS_PARAM_H_
+#ifndef SINGA_UTILS_PARAM_H_
+#define SINGA_UTILS_PARAM_H_
 #include <vector>
 #include <string>
 #include "proto/job.pb.h"
@@ -28,7 +28,7 @@ namespace singa {
 class Param {
  public:
   Param();
-  virtual ~Param(){ }
+  virtual ~Param() {}
   /**
    * Setup param object
    *
@@ -41,7 +41,7 @@ class Param {
    *
    * @param version initial version
    */
-  virtual void InitValues(int version=0);
+  virtual void InitValues(int version = 0);
   /**
    * Share the data blob from other Param objects.
    *
@@ -113,7 +113,7 @@ class Param {
   }
 
   void set_local_version(int v) {
-    local_version_=v;
+    local_version_ = v;
   }
   const std::string& share_from() const {
     return proto_.share_from();
@@ -280,7 +280,7 @@ class Param {
   vector<int> slice_offset_, slice_size_;
 
   //!< for debug checking
-  vector<bool> pending_put_,pending_get_, pending_update_;
+  vector<bool> pending_put_, pending_get_, pending_update_;
   int num_pending_requests_;
 
   shared_ptr<Blob<float>> data_;
@@ -309,8 +309,8 @@ class ParamEntry{
    */
   void AddParam(bool local, Param* p);
   int num_update, next_version;
-  int num_local; //!< # local workers using the shared parameter
-  int num_total; //!< # total workers using the shared parameter
+  int num_local;  //!< # local workers using the shared parameter
+  int num_total;  //!< # total workers using the shared parameter
   //!< Shares are deleted by neuralnet's destructor
   vector<Param*> shares;
 };
@@ -329,4 +329,4 @@ inline int SliceID(int param_trgt) {
 }
 }  // namespace singa
 
-#endif  // INCLUDE_UTILS_PARAM_H_
+#endif  // SINGA_UTILS_PARAM_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b33e50d6/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index b470ea2..61173cb 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -6,30 +6,30 @@
 #include "proto/job.pb.h"
 #include "mshadow/tensor.h"
 #include "utils/singleton.h"
+namespace singa {
 using namespace mshadow;
 using std::vector;
 using std::string;
-namespace singa {
 
 Param::Param():local_version_(-1), slice_start_(0), num_slices_(0),
   num_pending_requests_(0), data_(nullptr) {
 }
-void Param::Setup(const ParamProto& proto, const vector<int>& shape){
-  data_=std::make_shared<Blob<float>>(shape);
+void Param::Setup(const ParamProto& proto, const vector<int>& shape) {
+  data_ = std::make_shared<Blob<float>>(shape);
   grad_.Reshape(shape);
   history_.Reshape(shape);
   proto_.CopyFrom(proto);
 }
 
-void Param::AddSlice(int slice_id, int size){
-  int offset=0;
-  if(slice_size_.size()>0){
-    //must be added in order
-    CHECK_EQ(slice_start_+num_slices_, slice_id);
-    offset=slice_offset_.back()+slice_size_.back();
+void Param::AddSlice(int slice_id, int size) {
+  int offset = 0;
+  if (slice_size_.size() > 0) {
+    // must be added in order
+    CHECK_EQ(slice_start_ + num_slices_, slice_id);
+    offset = slice_offset_.back() + slice_size_.back();
   } else {
-    slice_start_=slice_id;
-    offset=0;
+    slice_start_ = slice_id;
+    offset = 0;
   }
   slice_offset_.push_back(offset);
   slice_size_.push_back(size);
@@ -39,11 +39,11 @@ void Param::AddSlice(int slice_id, int size){
   num_slices_++;
 }
 
-void Param::InitValues(int version){
+void Param::InitValues(int version) {
   Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
-  auto random=TSingleton<Random<cpu>>::Instance();
+  auto random = TSingleton<Random<cpu>>::Instance();
   switch (proto_.init_method()) {
-  case InitMethod::kConstant:
+  case ParamProto::kConstant:
     data = proto_.value();
     break;
   case InitMethod::kUniform:
@@ -61,8 +61,8 @@ void Param::InitValues(int version){
     break;
   case InitMethod::kUniformSqrtFanInOut:
     random->SampleUniform(data, proto_.low(), proto_.high());
-    if(proto_.value())
-      data *= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]);
+    if (proto_.value())
+      data *= proto_.value() / sqrt(data_->shape()[0] + data_->shape()[1]);
     break;
   case InitMethod::kGaussian:
     random->SampleGaussian(data, proto_.mean(), proto_.std());
@@ -71,8 +71,8 @@ void Param::InitValues(int version){
     break;
   case InitMethod::kGaussainSqrtFanIn:
     random->SampleGaussian(data, proto_.mean(), proto_.std());
-    if(proto_.value())
-      data *= proto_.value()/ sqrt(data_->shape()[0]);
+    if (proto_.value())
+      data *= proto_.value() / sqrt(data_->shape()[0]);
     break;
   default:
     LOG(ERROR) << "Illegal parameter init method ";
@@ -90,50 +90,49 @@ void Param::ToProto(BlobProto* blob) {
 /**************Message related functions********/
 Msg* Param::GenPutMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kPut);
-  void *ptr=mutable_cpu_data()+slice_offset_[idx];
+  void *ptr = mutable_cpu_data() + slice_offset_[idx];
   void *p = ptr;
   if (copy) p = nullptr;
   msg->AddFormatFrame("iffp", slice_size_[idx],
       learning_rate_multiplier(), weight_decay_multiplier(), p);
   if (copy) {
-    msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
+    msg->AddFrame(ptr, slice_size_[idx] * sizeof(float));
   }
-  //pending_put_[idx]=true;
-  //num_pending_requests_++;
-       return msg;
+  // pending_put_[idx]=true;
+  // num_pending_requests_++;
+  return msg;
 }
 
 Msg* Param::GenGetMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kGet);
-  msg->AddFormatFrame("ip",  copy, data_->cpu_data()+slice_offset_[idx]);
-  pending_get_[idx]=true;
+  msg->AddFormatFrame("ip",  copy, data_->cpu_data() + slice_offset_[idx]);
+  pending_get_[idx] = true;
   num_pending_requests_++;
   return msg;
 }
 
 Msg* Param::GenUpdateMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kUpdate);
   msg->AddFormatFrame("i", copy);
-  void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx];
-  if(copy){
-    //LOG(ERROR)<<"Copy in gen update";
+  void* ptr = grad_.mutable_cpu_data() + slice_offset_[idx];
+  if (copy) {
     msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
-  } else { // to share values of grad blob
-    msg->AddFormatFrame("p", ptr);
+  } else {
+    msg->AddFormatFrame("p", ptr);  // to share values of grad blob
   }
-  pending_update_[idx]=true;
+  pending_update_[idx] = true;
   num_pending_requests_++;
   return msg;
 }
 
 Msg* Param::GenSyncMsg(int offset, int size) {
-  Msg* msg=new Msg();
+  Msg* msg = new Msg();
   msg->set_type(kSyncRequest);
   msg->set_trgt(ParamTrgt(-1, id()), local_version());
   // always copy data because syn is between server groups in diff procs
@@ -155,7 +154,7 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
     CHECK((*msg)->NextFrame());
     CHECK_EQ(size* sizeof(float), (*msg)->FrameSize());
     memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float));
-  }else{
+  } else {
     data_->set_cpu_data(ptr);
   }
   if (!reserve)
@@ -167,9 +166,9 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
   int copy;
   float* ptr;
   (*msg)->ParseFormatFrame("ip", &copy, &ptr);
-  if(copy)
+  if (copy) {
     (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size());
-  else if(ptr!=data_->cpu_data()){
+  } else if (ptr != data_->cpu_data()) {
     memcpy(ptr, data_->cpu_data(), sizeof(float)*size());
     data_->set_cpu_data(ptr);
   }
@@ -180,42 +179,33 @@ Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
 }
 
 void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
-  bool reset = true;
-  vector<int> copies;
+  CHECK_GT(msgs.size(), 0);
+  float* server_grad = nullptr;
+  vector<float*> worker_grad;
   for (auto *msg : msgs) {
     int copy;
     msg->ParseFormatFrame("i", &copy);
-    reset = reset && copy;
-    copies.push_back(copy);
-  }
-  int idx = 0;
-  for (auto *msg : msgs) {
-    CHECK(msg->NextFrame());
-    if (copies.at(idx++)) {
-      float* server_grad = mutable_cpu_grad();
-      float* worker_grad = static_cast<float*> (msg->FrameData());
-      if (reset) {
-        memcpy(server_grad, worker_grad, sizeof(float) * size());
-        reset = false;
-      } else {
-        for (int i =0; i < size(); i++)
-          server_grad[i] += worker_grad[i];
-      }
+    msg->NextFrame();
+    float* ptr = nullptr;
+    if (copy) {
+      ptr = static_cast<float*> (msg->FrameData());
+      CHECK_EQ(size() * sizeof(float), msg->FrameSize());
     } else {
-      float* ptr = nullptr;
       msg->ParseFormatFrame("p", &ptr);
-      if (grad_.cpu_data() != ptr) {
-        memcpy(ptr, grad_.cpu_data(), msg->FrameSize());
-        grad_.set_cpu_data(ptr);
-      }
+      server_grad = ptr;
     }
+    worker_grad.push_back(ptr);
   }
-
-  if (msgs.size() > 1) {
-    float* server_grad = mutable_cpu_grad();
-    for (int i = 0; i < size(); i++)
-      server_grad[i] /= msgs.size();
+  if (server_grad == nullptr)
+    server_grad = worker_grad.at(0);
+  for (float* grad : worker_grad) {
+    if (grad != server_grad) {
+      for (int i =0; i < size(); i++) {
+        server_grad[i] += grad[i];
+      }
+    }
   }
+  grad_.set_cpu_data(server_grad);
 }
 
 const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
@@ -248,33 +238,33 @@ int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
 
 int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_get_[slice_idx], true);
-  pending_get_[slice_idx]=false;
+  pending_get_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_)%num_slices_==0;
+  return (--num_pending_requests_)%num_slices_ == 0;
 }
 
 int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_update_[slice_idx], true);
-  pending_update_[slice_idx]=false;
+  pending_update_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_) % num_slices_==0;
+  return (--num_pending_requests_) % num_slices_ == 0;
 }
 
 void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
   int copy;
   msg->ParseFormatFrame("i", &copy);
   msg->NextFrame();
-  if(copy) {
+  if (copy) {
     CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
     memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
         msg->FrameData(), msg->FrameSize());
   }
-  //LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
+  // LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
 }
 
 void Param::ShareFrom(const Param& other) {
   proto_.set_owner(other.owner());
-  if(data_!=nullptr) {
+  if (data_ != nullptr) {
     CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
           other.data_->shape().begin()));
   }
@@ -299,9 +289,7 @@ ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
 void ParamEntry::AddParam(bool local, Param* p) {
   num_local += local;
   num_total += 1;
-  if(local)
+  if (local)
     shares.push_back(p);
 }
-}
-
-// namespace singa
+}  // namespace singa
