Repository: incubator-singa
Updated Branches:
  refs/heads/master fbbcaafdb -> ae2030362


SINGA-21 Code review 4

review param.h, param.cc
  - ShareFrom(): init vectors by resizing instead of assigning
  - add a summary for put/get/update/sync request workflows
  - implement reserve flags correctly for all HandleXX functions
    TODO: remove all reserve-flag checks later
  - format the code

format updater.h, updater.cc
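
For reference, a minimal sketch of the reserve-flag contract documented on the
HandleXX functions: if reserve is true the caller keeps ownership of the
message, otherwise the handler frees it. The types below are simplified
stand-ins, not SINGA's real Msg/Param classes, and the temporary
CHECK(reserve) guards added in param.cc are left out.

#include <cstdio>

// Hypothetical stand-in for singa::Msg, only to keep the sketch self-contained.
struct Msg {
  int type = 0;
};

// Mirrors the DeleteMsg helper used in param.cc.
void DeleteMsg(Msg** msg) {
  delete *msg;
  *msg = nullptr;
}

// Ownership rule for the HandleXX functions: if reserve is true the caller
// keeps the message; if false, the handler consumes (frees) it.
Msg* HandlePutMsg(Msg** msg, bool reserve) {
  // ... parse and apply the put request here ...
  if (!reserve) DeleteMsg(msg);
  return nullptr;  // a put request produces no response
}

int main() {
  Msg* m = new Msg();
  HandlePutMsg(&m, false);  // handler frees m and nulls the pointer
  std::printf("msg freed: %s\n", m == nullptr ? "yes" : "no");
  return 0;
}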


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f99246e7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f99246e7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f99246e7

Branch: refs/heads/master
Commit: f99246e7840eb5fcec2918b37736cba0344db19e
Parents: fbbcaaf
Author: wang sheng <[email protected]>
Authored: Fri Aug 21 14:50:32 2015 +0800
Committer: wangwei <[email protected]>
Committed: Fri Aug 28 17:58:40 2015 +0800

----------------------------------------------------------------------
 include/utils/param.h     | 269 ++++++++++++++++++-----------------------
 include/utils/singleton.h |   2 +-
 include/utils/updater.h   |  42 ++++---
 src/trainer/server.cc     |   7 +-
 src/utils/param.cc        | 211 ++++++++++++++++++--------------
 src/utils/updater.cc      |  40 +++---
 6 files changed, 287 insertions(+), 284 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index 0d24e95..4909026 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -1,56 +1,56 @@
 #ifndef SINGA_UTILS_PARAM_H_
 #define SINGA_UTILS_PARAM_H_
-#include <vector>
+
+#include <memory>
 #include <string>
+#include <vector>
+#include "communication/msg.h"
 #include "proto/job.pb.h"
 #include "utils/blob.h"
-#include "communication/msg.h"
 
 namespace singa {
 
 /**
  * Base parameter generator which initializes parameter values.
  */
-
 class ParamGenerator {
  public:
   static ParamGenerator* Create(const ParamGenProto& proto);
-  virtual ~ParamGenerator() {}
 
-  virtual void Init(const ParamGenProto& proto) {
-    proto_ = proto;
-  }
+  virtual ~ParamGenerator() {}
 
+  virtual void Init(const ParamGenProto& proto) { proto_ = proto; }
   virtual void Fill(Blob<float>* data);
 
  protected:
   ParamGenProto proto_;
 };
 
-class GaussianGen: public ParamGenerator {
+class GaussianGen : public ParamGenerator {
  public:
   void  Fill(Blob<float>* data) override;
 };
 
-class UniformGen: public ParamGenerator {
+class GaussianSqrtFanInGen : public GaussianGen {
  public:
   void  Fill(Blob<float>* data) override;
 };
 
-class GaussianSqrtFanInGen: public GaussianGen {
+class UniformGen : public ParamGenerator {
  public:
   void  Fill(Blob<float>* data) override;
 };
 
-class UniformSqrtFanInGen: public UniformGen {
+class UniformSqrtFanInGen : public UniformGen {
  public:
   void Fill(Blob<float>* data) override;
 };
 
-class UniformSqrtFanInOutGen: public UniformGen {
+class UniformSqrtFanInOutGen : public UniformGen {
  public:
   void Fill(Blob<float>* data) override;
 };
+
 /**
  * Base parameter class.
  *
@@ -72,7 +72,8 @@ class UniformSqrtFanInOutGen: public UniformGen {
 class Param {
  public:
   static Param* Create(const ParamProto& proto);
-  Param();
+
+  Param() {}
   virtual ~Param() {}
   void Init(const ParamProto& proto) { proto_ = proto; }
   /**
@@ -87,160 +88,125 @@ class Param {
    *
    * @param version initial version
    */
-  virtual void InitValues(int version = 0);
+  virtual void InitValues();
+  virtual void InitValues(int version);
   /**
    * Share the data blob from other Param objects.
    *
    * @param other the Param object whose owner owns the data blob
    */
   void ShareFrom(const Param& other);
-
+  /**
+   * Init param values from checkpoint blob.
+   */
+  void FromProto(const BlobProto& blob);
+  /**
+   * Dump param values to blob.
+   */
+  void ToProto(BlobProto* blob);
+  /**
+   * Add a slice
+   *
+   * @param slice_id
+   * @param size num of floats for this slice
+   */
+  void AddSlice(int slice_id, int size);
   /**
    * Scale the learning rate when updating parameters in the Param object
    */
-  float lr_scale() {
-    return proto_.lr_scale();
-  }
+  inline float lr_scale() const { return proto_.lr_scale(); }
   /**
    * Scale the weight decay when updating parameters in the Param object
    */
-  float wd_scale() {
-    return proto_.wd_scale();
-  }
+  inline float wd_scale() const { return proto_.wd_scale(); }
   /**
    * Parameter name used for Param re-use in other model or sharing between
    * layers
    */
-  const std::string& name() {
-    return proto_.name();
-  }
-  void set_name(const std::string& name) {
-    proto_.set_name(name);
-  }
+  inline const std::string& name() const { return proto_.name(); }
+  inline void set_name(const std::string& name) { proto_.set_name(name); }
   /**
    * If it shares data from others, then owner is the id of that Param,
    * otherwise it is itself's id.
    */
-  const int owner() const {
-    return proto_.owner();
-  }
+  inline int owner() const { return proto_.owner(); }
   /**
    * ID start from 0 and ordered for all Param from the same neuralnet
    */
-  int id() const {
-    return proto_.id();
-  }
+  inline int id() const { return proto_.id(); }
   /**
    * Set ID
    */
-  void set_id(int id) {
+  inline void set_id(int id) {
     proto_.set_id(id);
     proto_.set_owner(id);
   }
-
   /**
    * Param version is stored inside the data blob to enable all Param objs
    * sharing the same values have the same version.
    * @return the param version
    */
-  int version() const {
-    return data_->version();
-  }
-
-  void set_version(int v) {
-    data_->set_version(v);
-  }
-
+  inline int version() const { return data_->version(); }
+  inline void set_version(int v) { data_->set_version(v); }
   /**
    * @return the version of the parameter value local to a worker
    */
-  int local_version() const {
-    return local_version_;
-  }
-
-  void set_local_version(int v) {
-    local_version_ = v;
-  }
-  const std::string& share_from() const {
-    return proto_.share_from();
-  }
+  inline int local_version() const { return local_version_; }
+  inline void set_local_version(int v) { local_version_ = v; }
+  inline const std::string& share_from() const { return proto_.share_from(); }
    /**
     * @return num of floats.
     */
-  int size() const {
-    return data_->count();
-  }
-  const Blob<float> &data() {
-    return *data_;
-  }
-  Blob<float> *mutable_data() {
-    return data_.get();
-  }
-  /**
-   * Return gradient of this parameter
-   */
-  const Blob<float> &grad() {
-    return grad_;
-  }
-  Blob<float> *mutable_grad() {
-    return &grad_;
-  }
-  float* mutable_cpu_data() {
-    return data_->mutable_cpu_data();
-  }
-  float* mutable_cpu_grad() {
-    return grad_.mutable_cpu_data();
-  }
-  float* mutable_cpu_history() {
-    return history_.mutable_cpu_data();
-  }
-
+  inline int size() const { return data_->count(); }
+  inline const Blob<float>& data() const { return *data_; }
+  inline Blob<float>* mutable_data() { return data_.get(); }
+  inline const Blob<float> &grad() const { return grad_; }
+  inline Blob<float> *mutable_grad() { return &grad_; }
+  inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
+  inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
+  inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
   /**
    * @return slice start ID
    */
-  int slice_start() const {
-    return slice_start_;
-  }
+  inline int slice_start() const { return slice_start_; }
+  inline int num_slices() const { return num_slices_; }
 
-  int num_slices() const {
-    return num_slices_;
-  }
-
-  /**
-   * Add a slice
-   *
-   * @param slice_id
-   * @param size num of floats for this slice
-   */
-  void AddSlice(int slice_id, int size);
   /**
-   * Init param values from checkpoint blob.
+   * Below are message/request related functions.
+   * The basic communication workflows are as follows:
+   *------------------------------------------------------------------------
+   *         |Put         |Get           |Update           |Sync
+   *------------------------------------------------------------------------
+   * Generate|(stub)      |(stub)        |(stub)           |(server)
+   * Message |GenPutMsg   |GenGetMsg     |GenUpdateMsg     |GenSyncMsg
+   *------------------------------------------------------------------------
+   * Handle  |(server)    |(server)      |(server)         |(server)
+   * Message |HandlePutMsg|HandleGetMsg  |ParseUpdateMsg   |HandleSyncMsg
+   *         |            |              |GenUpdateResMsg  |
+   *------------------------------------------------------------------------
+   * Handle  |            |(stub)        |(stub)           |(server)
+   * Response|            |ParseGetResMsg|ParseUpdateResMsg|ParseSyncResMsg
+   *------------------------------------------------------------------------
    */
-  void FromProto(const BlobProto& blob);
-  /**
-   * Dump param values to blob.
-   */
-  void ToProto(BlobProto* blob);
-  /**********************Msg related functions***************************/
 
   /**
-   * Generate the message for a get request, i.e., get parameters from a server
+   * Generate the message for a put request, i.e., put parameters to a server
    *
    * This function is called at worker/stub side.
    * @param copy decides whether to copy the parameter values from the server.
    * @param slice_idx index of the slice from which the message is generated.
    * @return generated message without setting src, dst, target fields.
    */
-  virtual Msg* GenGetMsg(bool copy, int slice_idx);
+  virtual Msg* GenPutMsg(bool copy, int slice_idx);
   /**
-   * Generate the message for a put request, i.e., put parameters to a server.
-   * \copydetails GenGetMsg(bool, int);
+   * Generate the message for a get request, i.e., get parameters from a server
+   * \copydetails GenPutMsg(bool, int);
    */
-  virtual Msg* GenPutMsg(bool copy, int slice_idx);
+  virtual Msg* GenGetMsg(bool copy, int slice_idx);
   /**
    * Generate the message for a update request, i.e., pass info to server for
    * parameter update.
-   * \copydetails GenGetMsg(bool, int);
+   * \copydetails GenPutMsg(bool, int);
    */
   virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
   /**
@@ -251,6 +217,26 @@ class Param {
    * */
   virtual Msg* GenSyncMsg(int offset, int size);
   /**
+   * Server handling function for put request.
+   *
+   * @param msg request
+   * @param reserve if true reserve the msg space for the calling function;
+   * otherwise the msg should be freed inside the function.
+   * @return response message
+   */
+  virtual Msg* HandlePutMsg(Msg** msg, bool reserve);
+  /**
+   * Server handling function for get request.
+   *
+   * \copydetails HandlePutMsg(Msg**, bool reserve)
+   */
+  virtual Msg* HandleGetMsg(Msg** msg, bool reserve);
+  /**
+   * Server parses update requests.
+   * \copydetails GenUpdateResponseMsgs(std::vector<Msg*>*, bool);
+   */
+  virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
+  /**
    * Generate the messages to response the update requests.
    *
    * This function is called at the server side, where the Param is actually a
@@ -263,29 +249,13 @@ class Param {
    * @return response messages
    */
   virtual const std::vector<Msg*>
-    GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
-
-  /**
-   * Server handling function for get request.
-   *
-   * @param msg request
-   * @param reserve if true reserve the msg space for the calling function;
-   * otherwise the msg should be freed inside the function.
-   * @return resposne message
-   */
-  virtual Msg* HandleGetMsg(Msg** msg, bool reserve = false);
-  /**
-   * Server handling function for put request.
-   *
-   * \copydetails HandleGetMsg(Msg**, bool reserve)
-   */
-  virtual Msg* HandlePutMsg(Msg** msg, bool reserve = false);
+    GenUpdateResponseMsgs(std::vector<Msg*>* msgs, bool reserve);
   /**
    * Server handling function for synchronization message
    *
    * \copydetails HandleGetMsg(Msg**, bool reserve)
    */
-  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve = false);
+  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve);
   /**
    * Worker/Stub parsing function for get response.
    *
@@ -300,11 +270,6 @@ class Param {
    */
   virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
   /**
-   * Server parse update requests.
-   * \copydetails GenUpdateResponseMsgs(const std::vector<Msg*>& msgs);
-   */
-  virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
-  /**
    * Server parsing function for synchronization response.
    *
    * \copydetails ParseGetResponseMsg(Msg** , int);
@@ -319,19 +284,21 @@ class Param {
   void ParseResponseMsg(Msg* msg, int slice_idx);
 
  protected:
-  int local_version_;
-  //!< the ID of the first slice
-  int slice_start_;
-  int num_slices_;
-  //!< offset and size of each slice
-  std::vector<int> slice_offset_, slice_size_;
-
-  //!< for debug checking
-  std::vector<bool> pending_put_, pending_get_, pending_update_;
-  int num_pending_requests_;
-
-  std::shared_ptr<Blob<float>> data_;
-  //! gradient, history gradient of this parameter
+  int local_version_ = -1;
+  // the ID of the first slice
+  int slice_start_ = 0;
+  int num_slices_ = 0;
+  // offset and size of each slice
+  std::vector<int> slice_offset_;
+  std::vector<int> slice_size_;
+  // for debug checking
+  // since put request has no feedback, we do not track its pending status
+  std::vector<bool> pending_get_;
+  std::vector<bool> pending_update_;
+  int num_pending_requests_ = 0;
+  // data field
+  std::shared_ptr<Blob<float>> data_ = nullptr;
+  // gradient, history gradient of this parameter
   Blob<float> grad_, history_;
   ParamProto proto_;
 };
@@ -344,9 +311,9 @@ class Param {
  * Param objects sharing the same values are associated with the same
  * ParamEntry.
  */
-class ParamEntry{
+class ParamEntry {
  public:
-  ParamEntry();
+  ParamEntry() {}
   ParamEntry(int total, Param* p);
   /**
    * Associate the counter to a Param object.
@@ -355,9 +322,10 @@ class ParamEntry{
    * @param local 1 if it is used by workers in this procs, 0 otherwise
    */
   void AddParam(bool local, Param* p);
-  int num_update, next_version;
-  int num_local;  //!< # local workers using the shared parameter
-  int num_total;  //!< # total workers using the shared parameter
+  int next_version = -1;  // next_version & num_update are directly used by stub
+  int num_update = 0;
+  int num_local = 0;  //!< # local workers using the shared parameter
+  int num_total = 0;  //!< # total workers using the shared parameter
   //!< Shares are deleted by neuralnet's destructor
   std::vector<Param*> shares;
 };
@@ -371,9 +339,10 @@ inline int ParamID(int param_trgt) {
 }
 
 inline int SliceID(int param_trgt) {
-  static int mask = (1 << 16) -1;
+  static const int mask = (1 << 16) -1;
   return param_trgt & mask;
 }
+
 }  // namespace singa
 
 #endif  // SINGA_UTILS_PARAM_H_
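
The workflow table added to param.h above maps each request type
(Put/Get/Update/Sync) to who generates, handles, and parses it. A minimal
sketch of the Get column only, using hypothetical stand-in types rather than
SINGA's real Msg/Param classes: the stub generates the request (GenGetMsg),
the server handles it (HandleGetMsg), and the stub parses the response
(ParseGetResponseMsg).

#include <cstdio>
#include <vector>

// Hypothetical simplified message: a type tag plus a float payload.
struct Msg {
  enum Type { kGet, kRGet } type;
  std::vector<float> payload;
};

// Stub side: generate the request and parse the response.
struct Stub {
  Msg GenGetMsg() { return Msg{Msg::kGet, {}}; }
  void ParseGetResponseMsg(const Msg& m, std::vector<float>* local) {
    *local = m.payload;  // copy the returned parameter values
  }
};

// Server side: answer a get request with the current parameter values.
struct Server {
  std::vector<float> param_values{0.1f, 0.2f, 0.3f};
  Msg HandleGetMsg(const Msg& req) {
    if (req.type != Msg::kGet) return Msg{Msg::kRGet, {}};
    return Msg{Msg::kRGet, param_values};
  }
};

int main() {
  Stub stub;
  Server server;
  std::vector<float> local;
  Msg reply = server.HandleGetMsg(stub.GenGetMsg());  // stub -> server -> stub
  stub.ParseGetResponseMsg(reply, &local);
  std::printf("stub received %zu floats, first = %g\n", local.size(), local[0]);
  return 0;
}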

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/singleton.h
----------------------------------------------------------------------
diff --git a/include/utils/singleton.h b/include/utils/singleton.h
index 5048266..f02c595 100644
--- a/include/utils/singleton.h
+++ b/include/utils/singleton.h
@@ -22,7 +22,7 @@ class Singleton {
 template<typename T>
 class TSingleton {
  public:
-  static T* Instance(){
+  static T* Instance() {
     static thread_local T data_;
     return &data_;
   }
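
A minimal, self-contained sketch of the thread-local singleton pattern that
TSingleton::Instance() above relies on: each thread lazily constructs and owns
its own instance. The Counter type here is hypothetical, used only for
illustration.

#include <cstdio>
#include <thread>

template <typename T>
class TSingleton {
 public:
  static T* Instance() {
    static thread_local T data_;  // one instance per thread, built on first use
    return &data_;
  }
};

struct Counter { int value = 0; };

int main() {
  TSingleton<Counter>::Instance()->value = 1;  // main thread's instance
  std::thread t([] {
    // a worker thread sees a separate, freshly constructed Counter (value 0)
    std::printf("worker value: %d\n", TSingleton<Counter>::Instance()->value);
  });
  t.join();
  std::printf("main value: %d\n", TSingleton<Counter>::Instance()->value);
  return 0;
}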

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/include/utils/updater.h
----------------------------------------------------------------------
diff --git a/include/utils/updater.h b/include/utils/updater.h
index 46d2c53..25fc31f 100644
--- a/include/utils/updater.h
+++ b/include/utils/updater.h
@@ -10,24 +10,20 @@ namespace singa {
  *
  * Generate learning rate for a given training step/iteration.
  * There are many different ways to change the learning rate through time/step.
- * Users can inherint this class to implment their own change method.
+ * Users can inherit this class to implement their own change method.
  */
 class LRGenerator {
  public:
   static LRGenerator* Create(const LRGenProto& proto);
-  virtual ~LRGenerator() {}
 
-  virtual void Init(const LRGenProto& proto) {
-    proto_ = proto;
-  }
+  virtual ~LRGenerator() {}
 
+  virtual void Init(const LRGenProto& proto) { proto_ = proto; }
   /**
    * @param step training step/iteration.
    * @return base learning rate regardless of step
    */
-  virtual float Get(int step) {
-    return proto_.base_lr();
-  }
+  virtual float Get(int step) { return proto_.base_lr(); }
 
  protected:
   LRGenProto proto_;
@@ -39,35 +35,43 @@ class FixedStepLRGen : public LRGenerator {
  private:
   int last_idx_ = 0;
 };
+
 class StepLRGen : public LRGenerator {
  public:
   float Get(int step) override;
 };
+
 class LinearLRGen : public LRGenerator {
  public:
   float Get(int step) override;
 };
+
 class ExpLRGen : public LRGenerator {
  public:
   float Get(int step) override;
 };
+
 class InvLRGen : public LRGenerator {
  public:
   float Get(int step) override;
 };
+
 class InvTLRGen : public LRGenerator {
  public:
   float Get(int step) override;
 };
+
 /**
  * Updater for Param.
  */
-class Updater{
+class Updater {
  public:
   static Updater* Create(const UpdaterProto& proto);
+
   virtual ~Updater() {}
+
   virtual void Init(const UpdaterProto &proto);
-  virtual void Update(int step, Param* param, float grad_scale = 1.0f) = 0;
+  virtual void Update(int step, Param* param, float grad_scale) = 0;
 
  protected:
   UpdaterProto proto_;
@@ -78,23 +82,24 @@ class Updater{
 
 class SGDUpdater : public Updater {
  public:
-  void Update(int step, Param* param, float grad_scale = 1.0f);
+  void Update(int step, Param* param, float grad_scale) override;
 };
 
-class AdaGradUpdater : public Updater{
+class AdaGradUpdater : public Updater {
  public:
-  void Update(int step, Param* param, float grad_scale = 1.0f) override;
+  void Update(int step, Param* param, float grad_scale) override;
 };
 
 
 class NesterovUpdater : public Updater {
  public:
-  void Update(int step, Param* param, float grad_scale = 1.0f) override;
+  void Update(int step, Param* param, float grad_scale) override;
 };
+
 /*
-class RMSPropUpdater : public Updater{
+class RMSPropUpdater : public Updater {
  public:
-  virtual void Update(int step, Param* param, float grad_scale=1.0f);
+  virtual void Update(int step, Param* param, float grad_scale);
 
  protected:
   float base_lr_;
@@ -103,9 +108,9 @@ class RMSPropUpdater : public Updater{
   float weight_decay_;
 };
 
-class AdaDeltaUpdater : public Updater{
+class AdaDeltaUpdater : public Updater {
  public:
-  virtual void Update(int step, Param* param, float grad_scale=1.0f);
+  virtual void Update(int step, Param* param, float grad_scale);
 
  protected:
   float rho_;
@@ -113,6 +118,7 @@ class AdaDeltaUpdater : public Updater{
   float weight_decay_;
 };
 */
+
 }  // namespace singa
 
 #endif  // SINGA_UTILS_UPDATER_H_
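
The LRGenerator subclasses above each map a training step to a learning rate.
As one concrete example, a self-contained sketch of the fixed-step schedule
that FixedStepLRGen::Get implements in updater.cc, with a hypothetical class
in place of the proto-backed configuration: advance an index whenever the step
crosses the next boundary and return that entry's rate (steps are assumed to
be queried in increasing order).

#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical stand-in for FixedStepLRGen; boundaries/rates replace the proto.
class FixedStepLR {
 public:
  FixedStepLR(std::vector<int> steps, std::vector<float> lrs)
      : steps_(std::move(steps)), lrs_(std::move(lrs)) {}
  float Get(int step) {
    // advance to the next boundary once the step has crossed it
    if (last_idx_ < static_cast<int>(steps_.size()) - 1 &&
        step >= steps_[last_idx_ + 1]) {
      last_idx_++;
    }
    return lrs_[last_idx_];
  }

 private:
  std::vector<int> steps_;
  std::vector<float> lrs_;
  int last_idx_ = 0;
};

int main() {
  FixedStepLR lr({0, 100, 200}, {0.1f, 0.01f, 0.001f});
  const int test_steps[] = {0, 50, 100, 150, 250};
  for (int step : test_steps)
    std::printf("step %d -> lr %g\n", step, lr.Get(step));
  return 0;
}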

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 1fda336..09bc75c 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -177,7 +177,7 @@ Msg* Server::HandleGet(Msg **msg) {
     return *msg;
   else {
     // LOG(ERROR) << "get " << slice << " from "<<(*msg)->src_first();
-    auto reply = param->HandleGetMsg(msg);
+    auto reply = param->HandleGetMsg(msg, false);
     reply->set_trgt(val, param->version());
     return reply;
   }
@@ -203,14 +203,13 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
     auto param = entry->shares.at(0);
     // extract and aggregate gradients
     param->ParseUpdateMsgs(request);
-    updater_->Update(step, param);
+    updater_->Update(step, param, 1.0f);
     param->set_local_version(param->local_version() + 1);
     // response to all shares of this param
-    for (auto response : param->GenUpdateResponseMsgs(request)) {
+    for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
       response->set_trgt((*msg)->trgt_val(), param->local_version());
       ret.push_back(response);
     }
-    request.clear();
     entry->num_update = 0;
   }
   *msg = nullptr;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 69f697b..a7c1897 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -1,14 +1,18 @@
+#include "utils/param.h"
+
 #include <glog/logging.h>
 #include <cmath>
-#include <chrono>
 #include <random>
-#include "utils/param.h"
-#include "proto/job.pb.h"
 #include "mshadow/tensor.h"
-#include "utils/singleton.h"
 #include "utils/factory.h"
+#include "utils/singleton.h"
+
 namespace singa {
-using namespace mshadow;
+
+using mshadow::cpu;
+using mshadow::Random;
+using mshadow::Shape1;
+using mshadow::Tensor;
 using std::vector;
 using std::string;
 
@@ -23,18 +27,20 @@ ParamGenerator* ParamGenerator::Create(const ParamGenProto& proto) {
   return gen;
 }
 
-void ParamGenerator::Fill (Blob<float>* blob) {
+void ParamGenerator::Fill(Blob<float>* blob) {
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
   data = proto_.value();
 }
-void GaussianGen::Fill (Blob<float>* blob) {
+
+void GaussianGen::Fill(Blob<float>* blob) {
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
   auto random = TSingleton<Random<cpu>>::Instance();
   random->SampleGaussian(data, proto_.mean(), proto_.std());
-  if(proto_.value() != 1)
+  if (proto_.value() != 1)
     data *= proto_.value();
 }
-void GaussianSqrtFanInGen::Fill (Blob<float>* blob) {
+
+void GaussianSqrtFanInGen::Fill(Blob<float>* blob) {
   // only valid for param matrix with num of cols as fan in
   CHECK_EQ(blob->shape().size(), 2);
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
@@ -42,15 +48,15 @@ void GaussianSqrtFanInGen::Fill (Blob<float>* blob) {
   data /= sqrt(blob->shape().at(1));
 }
 
-void UniformGen::Fill (Blob<float>* blob) {
+void UniformGen::Fill(Blob<float>* blob) {
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
   auto random = TSingleton<Random<cpu>>::Instance();
   random->SampleUniform(data, proto_.low(), proto_.high());
-  if(proto_.value() != 1)
+  if (proto_.value() != 1)
     data *= proto_.value();
 }
 
-void UniformSqrtFanInGen::Fill (Blob<float>* blob) {
+void UniformSqrtFanInGen::Fill(Blob<float>* blob) {
   // only valid for param matrix with num of cols as fan in
   CHECK_EQ(blob->shape().size(), 2);
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
@@ -58,16 +64,16 @@ void UniformSqrtFanInGen::Fill (Blob<float>* blob) {
   data /= sqrt(blob->shape().at(1) / 3.0f);
 }
 
-void UniformSqrtFanInOutGen::Fill (Blob<float>* blob) {
+void UniformSqrtFanInOutGen::Fill(Blob<float>* blob) {
   // only valid for param matrix with num of cols as fan in
   CHECK_EQ(blob->shape().size(), 2);
   Tensor<cpu, 1> data(blob->mutable_cpu_data(), Shape1(blob->count()));
   UniformGen::Fill(blob);
   data /= sqrt(blob->shape()[0] + blob->shape()[1]);
 }
-/*****************Param***********************************/
+
 Param* Param::Create(const ParamProto& proto) {
-  Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
+  Factory<Param>* factory = Singleton<Factory<Param>>::Instance();
   Param* p = nullptr;
   if (proto.has_user_type())
     p = factory->Create(proto.user_type());
@@ -77,16 +83,46 @@ Param* Param::Create(const ParamProto& proto) {
   return p;
 }
 
-Param::Param():local_version_(-1), slice_start_(0), num_slices_(0),
-  num_pending_requests_(0), data_(nullptr) {
-}
-
 void Param::Setup(const vector<int>& shape) {
   data_ = std::make_shared<Blob<float>>(shape);
   grad_.Reshape(shape);
   history_.Reshape(shape);
 }
 
+void Param::InitValues() {
+  InitValues(0);
+}
+
+void Param::InitValues(int version) {
+  ParamGenerator* gen = ParamGenerator::Create(proto_.init());
+  gen->Fill(data_.get());
+  set_version(version);
+}
+
+void Param::ShareFrom(const Param& other) {
+  proto_.set_owner(other.owner());
+  if (data_ != nullptr)
+    CHECK(data_->shape() == other.data_->shape());
+  data_ = other.data_;
+  if (grad_.count() == 0)
+    grad_.Reshape(data_->shape());
+  slice_start_ = other.slice_start_;
+  num_slices_ = other.num_slices_;
+  slice_offset_ = other.slice_offset_;
+  slice_size_ = other.slice_size_;
+  // change pending list size equal to slice size
+  pending_get_.resize(other.pending_get_.size());
+  pending_update_.resize(other.pending_update_.size());
+}
+
+void Param::FromProto(const BlobProto& blob) {
+  data_->FromProto(blob);
+}
+
+void Param::ToProto(BlobProto* blob) {
+  data_->ToProto(blob);
+}
+
 void Param::AddSlice(int slice_id, int size) {
   int offset = 0;
   if (slice_size_.size() > 0) {
@@ -101,37 +137,20 @@ void Param::AddSlice(int slice_id, int size) {
   slice_size_.push_back(size);
   pending_get_.push_back(false);
   pending_update_.push_back(false);
-  pending_put_.push_back(false);
   num_slices_++;
 }
 
-void Param::InitValues(int version) {
-  ParamGenerator* gen = ParamGenerator::Create(proto_.init());
-  gen->Fill(data_.get());
-  set_version(version);
-}
-void Param::FromProto(const BlobProto& blob) {
-  data_->FromProto(blob);
-}
-void Param::ToProto(BlobProto* blob) {
-  data_->ToProto(blob);
-}
-
-/**************Message related functions********/
 Msg* Param::GenPutMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg = new Msg();
   msg->set_type(kPut);
-  void *ptr = mutable_cpu_data() + slice_offset_[idx];
-  void *p = ptr;
+  void* ptr = mutable_cpu_data() + slice_offset_[idx];
+  void* p = ptr;
   if (copy) p = nullptr;
-  msg->AddFormatFrame("iffp", slice_size_[idx],
-      lr_scale(), wd_scale(), p);
+  msg->AddFormatFrame("iffp", slice_size_[idx], lr_scale(), wd_scale(), p);
   if (copy) {
     msg->AddFrame(ptr, slice_size_[idx] * sizeof(float));
   }
-  // pending_put_[idx]=true;
-  // num_pending_requests_++;
   return msg;
 }
 
@@ -171,6 +190,8 @@ Msg* Param::GenSyncMsg(int offset, int size) {
 }
 
 Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
+  // TODO(wangsheng) remove the check later
+  CHECK(reserve);
   int size;
   float lr, wc;
   float* ptr;
@@ -183,43 +204,59 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
   Setup(shape);
   if (ptr == nullptr) {
     CHECK((*msg)->NextFrame());
-    CHECK_EQ(size* sizeof(float), (*msg)->FrameSize());
-    memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float));
+    CHECK_EQ(size * sizeof(float), (*msg)->FrameSize());
+    memcpy(mutable_cpu_data(), (*msg)->FrameData(), size * sizeof(float));
   } else {
     data_->set_cpu_data(ptr);
   }
-  if (!reserve)
-    DeleteMsg(msg);
+  if (!reserve) DeleteMsg(msg);
   return nullptr;
 }
 
 Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
+  // TODO(wangsheng) remove the check later
+  CHECK(!reserve);
   int copy;
   float* ptr;
   (*msg)->ParseFormatFrame("ip", &copy, &ptr);
   if (copy) {
-    (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size());
+    (*msg)->AddFrame(mutable_cpu_data(), sizeof(float) * size());
   } else if (ptr != data_->cpu_data()) {
-    memcpy(ptr, data_->cpu_data(), sizeof(float)*size());
+    // this case reflects the following situation:
+    // worker 0 and the server are in the same process, while worker 1 is not.
+    // worker 1 "put" data into the server, so the server needs to allocate memory.
+    // then worker 0 "get" data from the server, so the server needs to:
+    //  1. copy the data to the space provided by worker 0
+    //  2. change its own pointer to that space in order to share memory
+    // in this case, the server always points to the last worker's space
+    memcpy(ptr, data_->cpu_data(), sizeof(float) * size());
     data_->set_cpu_data(ptr);
   }
   // else the mem space is shared among all worker and servers
-  (*msg)->SwapAddr();
-  (*msg)->set_type(kRGet);
-  return *msg;
+  Msg* ret = nullptr;
+  if (reserve) {
+    ret = new Msg(**msg);
+  } else {
+    // if the msg is not reserved, we reuse it as the return value
+    ret = *msg;
+    *msg = nullptr;
+  }
+  ret->SwapAddr();
+  ret->set_type(kRGet);
+  return ret;
 }
 
 void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
   CHECK_GT(msgs.size(), 0);
   float* server_grad = nullptr;
   vector<float*> worker_grad;
-  for (auto *msg : msgs) {
+  for (auto* msg : msgs) {
     int copy;
     msg->ParseFormatFrame("i", &copy);
     msg->NextFrame();
     float* ptr = nullptr;
     if (copy) {
-      ptr = static_cast<float*> (msg->FrameData());
+      ptr = static_cast<float*>(msg->FrameData());
       CHECK_EQ(size() * sizeof(float), msg->FrameSize());
     } else {
       msg->ParseFormatFrame("p", &ptr);
@@ -231,7 +268,8 @@ void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
     server_grad = worker_grad.at(0);
   for (float* grad : worker_grad) {
     if (grad != server_grad) {
-      for (int i =0; i < size(); i++) {
+      // TODO(wangsh) think about optimizing this later
+      for (int i = 0; i < size(); i++) {
         server_grad[i] += grad[i];
       }
     }
@@ -239,39 +277,41 @@ void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
   grad_.set_cpu_data(server_grad);
 }
 
-const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
+const vector<Msg*> Param::GenUpdateResponseMsgs(vector<Msg*>* msgs,
+                                                bool reserve) {
+  // TODO(wangsheng) remove the check later
+  CHECK(!reserve);
   vector<Msg*> ret;
-  for (auto msg : msgs) {
-    msg->FirstFrame();
-    msg->SwapAddr();
-    msg->set_type(kRUpdate);
+  for (Msg* msg : *msgs) {
+    Msg* ptr = reserve ? new Msg(*msg) : msg;
+    ptr->FirstFrame();
+    ptr->SwapAddr();
+    ptr->set_type(kRUpdate);
     int copy;
-    msg->ParseFormatFrame("i", &copy);
+    ptr->ParseFormatFrame("i", &copy);
     if (copy) {
-      msg->NextFrame();
-      CHECK_EQ(msg->FrameSize(), sizeof(float) * size());
-      memcpy(msg->FrameData(), mutable_cpu_data(), msg->FrameSize());
+      ptr->NextFrame();
+      CHECK_EQ(ptr->FrameSize(), sizeof(float) * size());
+      memcpy(ptr->FrameData(), mutable_cpu_data(), ptr->FrameSize());
     }
-    ret.push_back(msg);
+    ret.push_back(ptr);
   }
+  // if not reserved, we remove all pointers
+  if (!reserve) msgs->clear();
   return ret;
 }
 
 Msg* Param::HandleSyncMsg(Msg** msg, bool reserve) {
-  if (!reserve)
-    DeleteMsg(msg);
+  // TODO(wangwei) handle it later
+  if (!reserve) DeleteMsg(msg);
   return nullptr;
 }
 
-int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
-  return 1;
-}
-
 int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_get_[slice_idx], true);
   pending_get_[slice_idx] = false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_)%num_slices_ == 0;
+  return (--num_pending_requests_) % num_slices_ == 0;
 }
 
 int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
@@ -281,48 +321,37 @@ int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
   return (--num_pending_requests_) % num_slices_ == 0;
 }
 
+int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
+  // TODO(wangwei) handle it later
+  return 1;
+}
+
 void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
   int copy;
   msg->ParseFormatFrame("i", &copy);
   msg->NextFrame();
   if (copy) {
-    CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
-    memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
+    CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx] * sizeof(float));
+    memcpy(mutable_cpu_data() + slice_offset_[slice_idx],
         msg->FrameData(), msg->FrameSize());
   }
   // LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
 }
 
-void Param::ShareFrom(const Param& other) {
-  proto_.set_owner(other.owner());
-  if (data_ != nullptr) {
-    CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
-          other.data_->shape().begin()));
-  }
-  data_ = other.data_;
-  if (grad_.count() == 0)
-    grad_.Reshape(data_->shape());
-  slice_offset_ = other.slice_offset_;
-  slice_size_ = other.slice_size_;
-  slice_start_ = other.slice_start_;
-  num_slices_ = other.num_slices_;
-  pending_get_ = other.pending_get_;
-  pending_put_ = other.pending_put_;
-  pending_update_ = other.pending_update_;
-}
-
 /************************ParamEntry***************************/
 ParamEntry::ParamEntry():
   num_update(0), next_version(-1), num_local(0), num_total(0) {
 }
 
-ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
+ParamEntry::ParamEntry(int total, Param* p) {
+  num_total = total;
   shares.push_back(p);
 }
+
 void ParamEntry::AddParam(bool local, Param* p) {
   num_local += local;
   num_total += 1;
-  if (local)
-    shares.push_back(p);
+  if (local) shares.push_back(p);
 }
+
 }  // namespace singa
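
Param::ParseUpdateMsgs above aggregates gradients from all pending update
requests by accumulating every other worker's buffer into the first one, which
then serves as the server-side gradient. A small sketch of just that
accumulation rule, with hypothetical in-memory buffers standing in for the
parsed message frames:

#include <cstdio>
#include <vector>

int main() {
  // one gradient buffer per worker request (values are made up)
  std::vector<std::vector<float>> worker_grad = {
      {1.0f, 2.0f}, {0.5f, 0.5f}, {0.25f, 0.25f}};
  float* server_grad = worker_grad[0].data();  // first buffer is the accumulator
  const int size = static_cast<int>(worker_grad[0].size());
  for (auto& grad : worker_grad) {
    if (grad.data() != server_grad) {
      for (int i = 0; i < size; i++)
        server_grad[i] += grad[i];
    }
  }
  std::printf("aggregated grad: %g %g\n", server_grad[0], server_grad[1]);
  return 0;
}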

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f99246e7/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
index 24487d3..2c4b6c3 100644
--- a/src/utils/updater.cc
+++ b/src/utils/updater.cc
@@ -1,16 +1,21 @@
-
 #include "utils/updater.h"
-#include "mshadow/tensor.h"
+
 #include "mshadow/cxxnet_op.h"
+#include "mshadow/tensor.h"
 #include "utils/singleton.h"
 #include "utils/factory.h"
-#include "proto/job.pb.h"
-namespace  singa {
 
-using namespace mshadow;
-using namespace mshadow::expr;
+namespace singa {
+
+using mshadow::cpu;
+using mshadow::expr::F;
+using mshadow::op::sqrtop;
+using mshadow::op::square;
+using mshadow::Shape;
+using mshadow::Shape1;
+using mshadow::Tensor;
+using mshadow::TensorContainer;
 
-/**********************Learning rate generator******************************/
 LRGenerator* LRGenerator::Create(const LRGenProto& proto) {
   auto factory = Singleton<Factory<LRGenerator>>::Instance();
   LRGenerator* gen = nullptr;
@@ -23,9 +28,9 @@ LRGenerator* LRGenerator::Create(const LRGenProto& proto) {
 }
 
 float FixedStepLRGen::Get(int step) {
-  if (last_idx_ < proto_.fixedstep_conf().step_size() -1
+  if (last_idx_ < proto_.fixedstep_conf().step_size() - 1
       && step >= proto_.fixedstep_conf().step(last_idx_ + 1)) {
-      last_idx_ ++;
+      last_idx_++;
     }
   return proto_.fixedstep_conf().step_lr(last_idx_);
 }
@@ -38,7 +43,7 @@ float StepLRGen::Get(int step) {
 
 float LinearLRGen::Get(int step) {
   int freq = proto_.linear_conf().change_freq();
-  float r = step * 1.0  / freq;
+  float r = step * 1.0 / freq;
   return (1.0 - r) * proto_.base_lr() + r * proto_.linear_conf().final_lr();
 }
 
@@ -56,8 +61,6 @@ float InvTLRGen::Get(int step) {
  return proto_.base_lr() / (1 + step * 1. / proto_.inverset_conf().final_lr());
 }
 
-/***********************Updater********************************/
-
 Updater* Updater::Create(const UpdaterProto& proto) {
   auto factory = Singleton<Factory<Updater>>::Instance();
   Updater* updater = nullptr;
@@ -84,9 +87,8 @@ void SGDUpdater::Update(int step, Param* param, float grad_scale) {
   float wd = weight_decay_ * param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
-  if (wd > 0) {  // L2 regularization, should be done after timing grad_scale
+  if (wd > 0)  // L2 regularization, should be done after timing grad_scale
     grad += data * wd;
-  }
   if (momentum_ > 0) {
     Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
     history = history * momentum_ - lr * grad;
@@ -108,9 +110,8 @@ void NesterovUpdater::Update(int step, Param* param, float grad_scale) {
   float wd = weight_decay_*param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
-  if (wd > 0) {  // L2 regularization, should be done after timing grad_scale
+  if (wd > 0)  // L2 regularization, should be done after timing grad_scale
     grad += data * wd;
-  }
   Copy(tmp, history);
   history = history * momentum_ + lr * grad;
   tmp = history * (1 + momentum_) - tmp * momentum_;
@@ -126,11 +127,10 @@ void AdaGradUpdater::Update(int step, Param* param, float grad_scale) {
   float wd = weight_decay_*param->wd_scale();
   if (grad_scale != 1.f)
     grad *= grad_scale;
-  if (wd > 0) {  //  L2 regularization, should be done after timing grad_scale
+  if (wd > 0)  //  L2 regularization, should be done after timing grad_scale
     grad += data * wd;
-  }
-  history += F<op::square>(grad);
-  data -= lr * grad / (F<op::sqrtop>(history, proto_.delta()));
+  history += F<square>(grad);
+  data -= lr * grad / (F<sqrtop>(history, proto_.delta()));
 }
 
 /***********************RMSProp******************************

