Repository: incubator-singa
Updated Branches:
  refs/heads/master 2bbed5fc1 -> 56d32e8a0


SINGA-19 Slice large Param objects for load-balance
Tested with a single worker, a two-worker group, and two worker groups.
TODO: test with multiple servers and server groups for distributed Hogwild
and AllReduce.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e0a52a62
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e0a52a62
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e0a52a62

Branch: refs/heads/master
Commit: e0a52a62577cc9845130b9d2c664007ec354804c
Parents: 2bbed5f
Author: wang wei <[email protected]>
Authored: Tue Jun 23 00:06:13 2015 +0800
Committer: wang wei <[email protected]>
Committed: Tue Jun 23 15:57:27 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/model.conf |   3 +-
 include/communication/msg.h |   2 +-
 include/trainer/server.h    |  20 +-
 include/trainer/trainer.h   | 138 +++++-----
 include/utils/cluster.h     |  13 +-
 include/utils/common.h      |   2 +
 include/utils/param.h       | 145 ++++++++---
 src/neuralnet/layer.cc      |   8 +-
 src/proto/cluster.proto     |  36 ++-
 src/proto/model.proto       |  88 ++-----
 src/trainer/server.cc       |  33 +--
 src/trainer/trainer.cc      | 544 +++++++++++++++++++++++----------------
 src/trainer/worker.cc       |  13 +-
 src/utils/cluster.cc        |  49 +++-
 src/utils/common.cc         |  15 ++
 src/utils/param.cc          | 221 +++++++++-------
 16 files changed, 786 insertions(+), 544 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
index 0cdf8b0..bfd7683 100644
--- a/examples/cifar10/model.conf
+++ b/examples/cifar10/model.conf
@@ -1,8 +1,9 @@
 name: "cifar10-convnet"
 train_steps: 1000
-test_steps:100
+test_steps:10
 test_frequency:300
 display_frequency:30
+alg: kBackPropagation
 updater{
   momentum:0.0
   weight_decay:0.004

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/communication/msg.h
----------------------------------------------------------------------
diff --git a/include/communication/msg.h b/include/communication/msg.h
index ba7d064..b83c738 100644
--- a/include/communication/msg.h
+++ b/include/communication/msg.h
@@ -45,7 +45,7 @@ class Msg {
   }
   inline int target_first() const { return target_first_; }
   inline int target_second() const { return target_second_; }
-  /**
+ /**
    * Copy src and dst address, including first, id, flag
    */
   inline Msg* CopyAddr() {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index f9fc80b..b07741f 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -1,6 +1,7 @@
 #ifndef INCLUDE_TRAINER_SERVER_H_
 #define INCLUDE_TRAINER_SERVER_H_
 #include <memory>
+#include <unordered_map>
 #include <utils/param.h>
 #include <utils/updater.h>
 #include "proto/model.pb.h"
@@ -8,6 +9,7 @@
 
 using std::shared_ptr;
 namespace singa {
+typedef std::unordered_map<int, shared_ptr<Param>> ServerShard;
 /* Repsond to worker's get/put/udpate request, and periodically syncing with
   * other servers.
   *
@@ -20,10 +22,10 @@ namespace singa {
   */
 class Server{
  public:
-  typedef std::map<int, shared_ptr<Param>> ParamShard;
 
   Server(int thread_id, int group_id, int server_id);
-  void Setup(const UpdaterProto& proto, shared_ptr<ParamShard> shard);
+  void Setup(const UpdaterProto& proto, shared_ptr<ServerShard> shard,
+      const vector<int>& slice2group);
   void Run();
 
  protected:
@@ -48,7 +50,7 @@ class Server{
    * @return the original message or response message. If we don't want need to
    * acknowledge the put request, then return nullptr.
         */
-       virtual Msg* HandlePut(shared_ptr<Param> param, Msg **msg);
+       virtual void HandlePut(shared_ptr<Param> param, Msg **msg);
 
        /**
    * TODO Process SYNC request.
@@ -57,21 +59,15 @@ class Server{
 
        /**
    * TODO Process SYNC response.
-        */
        virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg);
-
-  /**
-   * Scheduler for synchronizing server groups.
-   *
-   * TODO implement the Caffe's synchronization scheduler for data parallelism
-   */
-  virtual bool SyncNow();
+        */
 
  protected:
   int thread_id_,group_id_, server_id_;
   shared_ptr<Dealer> dealer_;
   shared_ptr<Updater> updater_;
-  shared_ptr<ParamShard> shard_;
+  shared_ptr<ServerShard> shard_;
+  vector<int> slice2group_;
 };
 } /* Server */
 #endif //INCLUDE_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 18250af..6e93f80 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -1,5 +1,6 @@
 #ifndef INCLUDE_TRAINER_TRAINER_H_
 #define INCLUDE_TRAINER_TRAINER_H_
+#include <unordered_map>
 #include "proto/cluster.pb.h"
 #include "proto/model.pb.h"
 #include "utils/updater.h"
@@ -13,63 +14,73 @@
 
 namespace singa {
 /**
- * Every running process has a training object which launches one or more
- * worker (and server) threads.
- *
- * The main thread runs a loop to forward messages between workers and servers.
+ * Callback function for zookeeper
  */
-
-class Trainer{
+void HandleWorkerFinish(void * ctx);
 /**
- * ParamInfo is used to construct a parameter shard.
- *
- * For each worker group:
- *   Every unique Param object is associated with a ParamCounter object whose
- *   param field points the to Param object itself.
- *
- *   Param objects sharing the same values (due to data parallelism) are
- *   associated with the same ParamCounter whose param field also shares the
- *   same values.
- *
- *   Usage: we need to aggregate gradients from all workers for the shared
- *   parameters before sending the update request. The nUpdate counter counts
- *   the number.
- *
- * TODO test with different physical architectures.
+ * Zookeeper handler context used by HandleWorkerFinish(void*)function.
  */
-  public:
-  class ParamInfo{
+typedef struct HandleContext_{
+  shared_ptr<Dealer> dealer;
+  int group_id, id;
+} HandleContext;
+/**
+  * ParamInfo is used to construct a parameter shard.
+  *
+  * For each worker group:
+  *   Every unique Param object is associated with a ParamCounter object whose
+  *   param field points the to Param object itself.
+  *
+  *   Param objects sharing the same values (due to data parallelism) are
+  *   associated with the same ParamCounter whose param field also shares the
+  *   same values.
+  *
+  *   Usage: we need to aggregate gradients from all workers for the shared
+  *   parameters before sending the update request. The nUpdate counter counts
+  *   the number.
+  *
+  * TODO test with different physical architectures.
+  */
+class ParamInfo{
    public:
-    ParamInfo(shared_ptr<Param> p,int local, int owner):
-      num_update(0), next_version(0),num_local(local), num_total(1),
-      owner_procs(owner){
-        shares.push_back(p);
-      }
-
-    /**
-      * Associate the counter to a Param object.
-      *
-      * @param p
-      * @param local 1 if this Param object is used by workers in this procs, 0
-      *  otherwise
-      * @param owner the procs id of the worker who ownes this Param object
-      */
-    void AddParam(shared_ptr<Param> p, bool local){
-      num_local+=local;
-      num_total+=1;
-      if(local)
-        shares.push_back(p);
+  ParamInfo(shared_ptr<Param> p,int local, int owner):
+    num_update(0), next_version(0),num_local(local), num_total(1),
+    owner_procs(owner){
+      shares.push_back(p);
     }
-    int num_update, next_version; //!< all counters are atomic
 
-    int num_local; //!< # local workers uses the shared parameter
-    int num_total; //!< # total workers uses the shared parameter
-    int owner_procs; //!< the procs id of the worker that owns the parameter
-    vector<shared_ptr<Param>> shares;
-  };
+  /**
+    * Associate the counter to a Param object.
+    *
+    * @param p
+    * @param local 1 if this Param object is used by workers in this procs, 0
+    *  otherwise
+    * @param owner the procs id of the worker who ownes this Param object
+    */
+  void AddParam(shared_ptr<Param> p, bool local){
+    num_local+=local;
+    num_total+=1;
+    if(local)
+      shares.push_back(p);
+  }
+  int num_update, next_version; //!< all counters are atomic
+
+  int num_local; //!< # local workers uses the shared parameter
+  int num_total; //!< # total workers uses the shared parameter
+  int owner_procs; //!< the procs id of the worker that owns the parameter
+  vector<shared_ptr<Param>> shares;
+};
+
+typedef std::map<int, shared_ptr<ParamInfo>> WorkerShard;
 
- typedef std::map<int, shared_ptr<ParamInfo>> ParamShard;
+/**
+ * Every running process has a training object which launches one or more
+ * worker (and server) threads.
+ *
+ * The main thread runs a loop to forward messages between workers and servers.
+ */
 
+class Trainer{
  public:
   /**
    * Start the training in one process
@@ -84,8 +95,13 @@ class Trainer{
   // point.
 
  protected:
-  void Run(int nworkers, int nservers,
-      const std::map<int, shared_ptr<ParamShard>>& shards);
+
+  vector<shared_ptr<Server>> CreateServers(int nthread, const ModelProto& 
mproto,
+      const vector<int> slices, vector<HandleContext>* ctx);
+  vector<shared_ptr<Worker>> CreateWorkers(int nthread,
+      const ModelProto& mproto, vector<int> *slice_size);
+
+  void Run(int nworkers, int nservers);
   /**
    * Register default implementations for all base classes used in the system,
    * e.g., the Updater, BaseMsg, etc.
@@ -99,37 +115,35 @@ class Trainer{
 
   /**
    * Workers from the same group resident in the same process share the same
-   * ParamShard which contains ParamCounters for Param objects used/updated by
+   * WorkerShard which contains ParamCounters for Param objects used/updated by
    * these worekrs. Shared Param objects are associated with the same
    * ParamCounter.
    */
 
-  /**
-   * @return server id where the parameter is maintained.
-   */
-  virtual int Sharding(int param_id);
-
        /**
         * Generate a request message to Get the parameter object.
         */
-       virtual Msg* HandleGet(shared_ptr<ParamInfo>counter, Msg** msg);
-       virtual Msg* HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+       virtual const vector<Msg*> HandleGet(shared_ptr<ParamInfo>counter, 
Msg** msg);
+       virtual void HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg);
 
        /**
         * Generate a request message to Update the parameter object.
         */
-       virtual Msg* HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg);
-  virtual int HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+       virtual const vector<Msg*> HandleUpdate(shared_ptr<ParamInfo>counter, 
Msg** msg);
+  virtual void HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg);
 
   /**
         * Generate a request message to Put the parameter object.
         */
-       virtual Msg* HandlePut(shared_ptr<ParamInfo>counter, Msg** msg);
+       virtual const vector<Msg*> HandlePut(shared_ptr<ParamInfo>counter, 
Msg** msg);
        virtual Msg* HandleConnect(Msg** msg);
 
  protected:
   int procs_id_;
   shared_ptr<Router> router_;
+  std::unordered_map<int, shared_ptr<WorkerShard>> worker_shards_;
+  shared_ptr<ServerShard> server_shard_;
+  vector<int> slice2server_;
 };
 } /* singa */
 #endif // INCLUDE_TRAINER_TRAINER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/cluster.h
----------------------------------------------------------------------
diff --git a/include/utils/cluster.h b/include/utils/cluster.h
index fcdb241..9648bfe 100644
--- a/include/utils/cluster.h
+++ b/include/utils/cluster.h
@@ -5,6 +5,7 @@
 #include <utility>
 #include <memory>
 #include <vector>
+#include <unordered_map>
 #include "proto/cluster.pb.h"
 #include "utils/cluster_rt.h"
 
@@ -39,7 +40,7 @@ class Cluster {
    */
   bool has_server()const {
     if(server_worker_separate()){
-      CHECK_LT(procs_id_, nprocs());
+      CHECK_LT(procs_id_, nprocs_);
       return procs_id_>=nworker_procs();
     }else
       return procs_id_<nserver_procs();
@@ -51,7 +52,7 @@ class Cluster {
     if(server_worker_separate()){
       return procs_id_<nworker_procs();
     }else
-      return procs_id_<nprocs();
+      return procs_id_<nprocs_;
   }
   /**
    * @return global procs id, which starts from 0.
@@ -67,7 +68,7 @@ class Cluster {
     return nserver_groups()*nservers_per_group()/nservers_per_procs();
   }
   int nprocs() const {
-    return cluster_.nprocs();
+    return nprocs_;
   }
 
   const string endpoint() const {
@@ -77,7 +78,7 @@ class Cluster {
    * @return endpoint of the router of a procs with the specified id
    */
   const string endpoint(int procs_id) const {
-    CHECK_LT(procs_id, nprocs());
+    CHECK_LT(procs_id, nprocs_);
     CHECK_GE(procs_id, 0);
     return endpoints_.at(procs_id);
   }
@@ -121,18 +122,22 @@ class Cluster {
     return cluster_rt_;
   }
 
+  int ProcsIDOf(int group_id, int id, int flag);
  private:
   Cluster(const ClusterProto &cluster, int procs_id) ;
   void SetupFolders(const ClusterProto &cluster);
+  int Hash(int gid, int id, int flag);
 
  private:
   int procs_id_;
+  int nprocs_;
   std::vector<std::string> endpoints_;
   // cluster config proto
   ClusterProto cluster_;
   shared_ptr<ClusterRuntime> cluster_rt_;
   // make this class a singlton
   static shared_ptr<Cluster> instance_;
+  std::unordered_map<int, int> procs_ids_;
 };
 
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/common.h
----------------------------------------------------------------------
diff --git a/include/utils/common.h b/include/utils/common.h
index 1644598..98a1cd7 100644
--- a/include/utils/common.h
+++ b/include/utils/common.h
@@ -42,6 +42,8 @@ inline void Sleep(int millisec=1){
 }
 */
 
+int gcd(int a, int b);
+int LeastCommonMultiple(int a, int b);
 inline float rand_real(){
   return  static_cast<float>(rand())/(RAND_MAX+1.0f);
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/include/utils/param.h
----------------------------------------------------------------------
diff --git a/include/utils/param.h b/include/utils/param.h
index e55480b..897c97a 100644
--- a/include/utils/param.h
+++ b/include/utils/param.h
@@ -12,33 +12,110 @@ namespace singa {
 class Param {
  public:
   Param();
-  virtual ~Param(){};
-
-  virtual Msg* GenGetMsg(bool copy, int v=-1);
-  virtual Msg* GenPutMsg(bool copy, int v=-1);
-  virtual Msg* GenUpdateMsg(bool copy, int v=-1);
-  virtual Msg* GenSyncMsg(bool copy, int v=-1);
+  virtual ~Param(){ }
+  /**
+   * Generate the message for a get request, i.e., get parameters from a server
+   *
+   * This function is called at worker/stub side.
+   * @param copy decides whether to copy the parameter values from the server.
+   * @param slice_idx index of the slice from which the message is generated.
+   * @return generated message without setting src, dst, target fields.
+   */
+  virtual Msg* GenGetMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a put request, i.e., put parameters to a server.
+   * \copydetails GenGetMsg(bool, int);
+   */
+  virtual Msg* GenPutMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a update request, i.e., pass info to server for
+   * parameter update.
+   * \copydetails GenGetMsg(bool, int);
+   */
+  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
+  /**
+   * Generate the message for a synchronization request between server groups.
+   *
+   * This function is called at server side where the Param is actually a slice
+   * of an original Param object.
+   * */
+  virtual Msg* GenSyncMsg();
+  /**
+   * Generate the message to response the update request.
+   *
+   * This function is called at the server side, where the Param is actually a 
slice
+   * of an original Param object.
+   * @param copy if true copy the parameter value into the message, otherwise
+   * only transfer the pointer of the parameter values.
+   * @return response message pointer
+   */
+  virtual Msg* GenUpdateResponseMsg(bool copy);
 
+  /**
+   * Server handling function for get request.
+   *
+   * @param msg  request message
+   * @return resposne message
+   */
   virtual Msg* HandleGetMsg(Msg** msg);
+  /**
+   * Server handling function for put request.
+   *
+   * \copydetails HandleGetMsg(Msg**)
+   */
   virtual Msg* HandlePutMsg(Msg** msg);
+  /**
+   * Server handling function for synchronization message
+   *
+   * \copydetails HandleGetMsg(Msg**)
+   */
   virtual Msg* HandleSyncMsg(Msg** msg);
-  virtual const std::pair<bool, int> ParseUpdateMsg(Msg** msg);
-  virtual Msg* GenUpdateResponseMsg(bool copy, int v=-1);
-
 
-  virtual int ParseGetResponseMsg(Msg** msg);
-  virtual int ParsePutResponseMsg(Msg** msg);
-  virtual int ParseUpdateResponseMsg(Msg** msg);
-  virtual int ParseSyncResponseMsg(Msg** msg);
+  /**
+   * Server parses update request message.
+   *
+   * @param msg
+   * @return 1 for copy, 0 for no copy
+   */
+  virtual int ParseUpdateMsg(Msg** msg);
+  /**
+   * Worker/Stub parsing function for get response.
+   *
+   * @param msg
+   * @param slice_idx index for the slice
+   */
+  virtual int ParseGetResponseMsg(Msg** msg, int slice_idx);
+  /**
+   * Worker/Server parsing function for update response
+   *
+   * \copydetails ParseGetResponseMsg(Msg**, int);
+   */
+  virtual int ParseUpdateResponseMsg(Msg** msg, int slice_idx);
+  /**
+   * Server parsing function for synchronization response.
+   *
+   * \copydetails ParseGetResponseMsg(Msg** , int);
+   */
+  virtual int ParseSyncResponseMsg(Msg** msg, int slice_idx);
 
   /**
-   * setup param shape
+   * Setup param object
+   *
+   * @param proto includes learning rate/weight decay multipliers
+   * @param shape
    */
-  virtual void Setup(const ParamProto& proto, const std::vector<int>& shape, 
int fan_in);
+  virtual void Setup(const ParamProto& proto, const std::vector<int>& shape);
   /*
-   * fill the data according to initmethod, i.e., random/gaussian/fixed value
+   * Fill the values according to initmethod, e.g., gaussian distribution
+   *
+   * @param version initial version
+   */
+  virtual void InitValues(int version=0);
+  /**
+   * Share the data blob from other Param objects.
+   *
+   * @param other the Param object whose owner owns the data blob
    */
-  virtual void Init(int v=0);
   void ShareData(shared_ptr<Param> other){
     proto_.set_owner(other->owner());
     if(data_!=nullptr)
@@ -52,11 +129,6 @@ class Param {
   float weight_decay_multiplier() {
     return proto_.weight_decay_multiplier();
   }
-  /*
-  const int split_threshold(){
-    return proto_.split_threshold();
-  }
-  */
   const std::string& name() {
     return proto_.name();
   }
@@ -137,28 +209,35 @@ class Param {
   float* mutable_cpu_history(){
     return history_.mutable_cpu_data();
   }
+  int slice_start() const {
+    return slice_start_;
+  }
+
+  int num_slices() const {
+    return num_slices_;
+  }
+
+  void AddSlice(int slice_id, int size);
+
  protected:
+  void ParseResponseMsg(Msg** msg, int slice_idx);
+
+ protected:
+
   /**
    * name of the parameter used to share wights between neuralnets
    */
   std::string name_;
   shared_ptr<Blob<float>> data_;
+  int slice_start_, num_slices_;
+  vector<int> slice_offset_, slice_size_;
+  vector<bool> pending_put_,pending_get_, pending_update_;
+  int num_pending_requests_;
   //! gradient, history gradient of this parameter
   Blob<float> grad_, history_;
   ParamProto proto_;
-  int fan_in_;
   int local_version_;
 };
-/**
- * To support the shared memory and distributed Hogwild algorithm.
- * Each worker group has one worker. Workers from the same process share the
- * memory space for parameter values. Each process has one server group which
- * also shares the same memory space. Messages except synchronization messages
- * only transfer pointers to parameter value or gradient space. Hence memory
- * copy is avoided for intra-process communication.
- */
-class HogwildParam: public Param{
-};
 
 }  // namespace singa
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/neuralnet/layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/layer.cc b/src/neuralnet/layer.cc
index de13ba7..04ce72a 100644
--- a/src/neuralnet/layer.cc
+++ b/src/neuralnet/layer.cc
@@ -46,9 +46,9 @@ void ConvolutionLayer::Setup(const LayerProto& proto,
 
   Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
   weight_=shared_ptr<Param>(factory->Create("Param"));
-  weight_->Setup(proto.param(0), vector<int>{num_filters_, col_height_}, 
col_height_);
+  weight_->Setup(proto.param(0), vector<int>{num_filters_, col_height_});
   bias_=shared_ptr<Param>(factory->Create("Param"));
-  bias_->Setup(proto.param(1), vector<int>{num_filters_},0);
+  bias_->Setup(proto.param(1), vector<int>{num_filters_});
 }
 
 void ConvolutionLayer::SetupAfterPartition(const LayerProto& proto,
@@ -173,8 +173,8 @@ void InnerProductLayer::Setup(const LayerProto& proto,
   Factory<Param>* factory=Singleton<Factory<Param>>::Instance();
   weight_=shared_ptr<Param>(factory->Create("Param"));
   bias_=shared_ptr<Param>(factory->Create("Param"));
-  weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_}, vdim_*hdim_);
-  bias_->Setup(proto.param(1), vector<int>{hdim_},0);
+  weight_->Setup(proto.param(0), vector<int>{vdim_, hdim_});
+  bias_->Setup(proto.param(1), vector<int>{hdim_});
 }
 void InnerProductLayer::SetupAfterPartition(const LayerProto& proto,
       const vector<int> &shape,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/proto/cluster.proto
----------------------------------------------------------------------
diff --git a/src/proto/cluster.proto b/src/proto/cluster.proto
index f79d7d4..52cfd51 100644
--- a/src/proto/cluster.proto
+++ b/src/proto/cluster.proto
@@ -1,8 +1,8 @@
 package singa;
 
 message ClusterProto{
-  optional int32 nworker_groups=1;
-  optional int32 nserver_groups=2;
+  optional int32 nworker_groups=1 [default=1];
+  optional int32 nserver_groups=2 [default=1];
   optional int32 nworkers_per_group=3 [default=1];
   optional int32 nservers_per_group=4 [default=1];
   optional int32 nworkers_per_procs=5 [default=1];
@@ -11,27 +11,24 @@ message ClusterProto{
   // Used in standalone mode, one ip or hostname per line
   // For YARN or Mesos version, the processes are allocted dynamically,
   // hence no need to specify the hosts statically
-  optional string hostfile=10;
+  optional string hostfile=10 [default=""];
 
   // servers and workers in different processes?
   optional bool server_worker_separate=11 [default=false];
 
-  // if configured, must be consistent with the one computed from 1-6
-  optional int32 nprocs=12;
-
   // port number is used by ZeroMQ
   optional int32 start_port=13 [default=6723];
   // local workspace, train/val/test shards, checkpoint files
   required string workspace=14;
   // relative path to workspace. if not set, use the default dir of glog
-  optional string log_dir=15;
+  optional string log_dir=15 [default="/tmp"];
   // ip/hostname : port [, ip/hostname : port]
   optional string zookeeper_host=16 [default="localhost:2181"];
   // message size limit, default 1MB
   // optional int32 largest_message=20 [default=1048576];
   // optional float bandwidth=21 [default=100];//MB/s
 
-       repeated ServerTopology server_group = 20;
+       //repeated ServerTopology server_group = 20;
 
   optional int32 stub_timeout=30 [default=5000];
   optional int32 worker_timeout=31 [default=5000];
@@ -50,3 +47,26 @@ message ServerTopology{
   // neighbor group id
        repeated int32 neighbor = 3;
 }
+enum MsgType{
+  kGet=0;
+  kPut=1;
+  kSync=2;
+  kUpdate=3;
+  kSyncRequest=4;
+  kSyncResponse=5;
+  kStop=6;
+  kData=7;
+  kRGet=8;
+  kRUpdate=9;
+  kConnect=10;
+  kMetric=11;
+};
+
+enum EntityType{
+  kWorkerParam=0;
+  kWorkerLayer=1;
+  kServer=2;
+  kStub=3;
+  kRuntime=4;
+};
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index c6e3495..8cb45a3 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -1,26 +1,4 @@
 package singa;
-enum MsgType{
-  kGet=0;
-  kPut=1;
-  kSync=2;
-  kUpdate=3;
-  kSyncRequest=4;
-  kSyncResponse=5;
-  kStop=6;
-  kData=7;
-  kRGet=8;
-  kRUpdate=9;
-  kConnect=10;
-  kMetric=11;
-};
-
-enum EntityType{
-  kWorkerParam=0;
-  kWorkerLayer=1;
-  kServer=2;
-  kStub=3;
-  kRuntime=4;
-};
 enum Phase {
   kTrain = 0;
   kValidation=1;
@@ -33,25 +11,17 @@ enum ShareOption{
   kWhole=1;
 };
 message ModelProto{
-  optional string name = 1;
-  // relative path to system folder
-  optional string train_folder=2 [default="train"];
-  optional string test_folder=3 [default="test"];
-  optional string validation_folder=4 [default="validation"];
+  required string name = 1;
   // start display after this num steps
   optional int32 display_after_steps = 6 [default = 0];
   // frequency of display
   optional int32 display_frequency = 7 [default = 0];
 
-  // the time of validation
-  //optional int32 validation_step = 9 [default = 0];
   // start validation after this num steps
   optional int32 validation_after_steps = 10 [default = 0];
   // frequency of validation
   optional int32 validation_frequency = 11 [default = 0];
 
-  // the time of test
-  //optional int32 test_step = 12 [default = 0];
   // start test after this num steps
   optional int32 test_after_steps = 13 [default = 0];
   // frequency of test
@@ -63,24 +33,23 @@ message ModelProto{
 
 
   // total num of steps for training
-  optional int32 train_steps = 20;
+  required int32 train_steps = 20;
   // total num of steps for validation
-  optional int32 validation_steps=21;
+  optional int32 validation_steps=21 [default=0];
   // total num of steps for test
-  optional int32 test_steps=22;
+  optional int32 test_steps=22 [default=0];
   // last snapshot step
-  optional int32 step=29 [default=0];
+  optional int32 step=29;
 
-  optional UpdaterProto updater=31;
+  required UpdaterProto updater=31;
   // There are two basic algorithms for calculating gradients.
   // Different deep learning models use different algorithms.
   enum GradCalcAlg{
     kBackPropagation = 1;
     kContrastiveDivergence = 2;
   }
-  optional GradCalcAlg alg= 32 [default = kBackPropagation];
-  optional bool hogwild=33 [default=false];
-  optional NetProto neuralnet = 40;
+  required GradCalcAlg alg= 32 [default = kBackPropagation];
+  required NetProto neuralnet = 40;
   optional bool debug=41 [default=false];
   optional int32 warmup_steps=50 [default=0];
 }
@@ -93,7 +62,7 @@ message NetProto{
 message ParamProto {
   // for the program to identify it and share among layers.
   // e.g., "conv1_weight","fc_bias"
-  optional string name = 1;
+  required string name = 1;
   optional int32 id=2;
   // in most situations, user do not need to config this,
   // the program will calculate it
@@ -149,25 +118,24 @@ message BlobProtos{
   repeated string names=3;
 }
 
-
-
 enum PartitionType{
   kDataPartition=0;
   kLayerPartition=1;
   kNone=2;
 }
+
 enum ConnectionType{
   kOneToOne=0;
   kOneToAll=1;
 }
 
 message LayerProto {
-  optional string name = 1; // the layer name
-  optional string type = 2; // the layer type from the enum above
+  required string name = 1; // the layer name
+  required string type = 2; // the layer type from the enum above
   repeated string srclayers=3;
   optional int32 locationid=4 [default=0]; // todo make locationID an array
   optional int32 partitionid=5 [default=0];
-  optional PartitionType partition_type=6;
+  optional PartitionType partition_type=6 [default=kNone];
   optional string datablob=7;
   // can be pos/neg neuron value for CD, neuron value/grad for BP
   //repeated DAryProto ary = 10;
@@ -204,10 +172,10 @@ message RGBImage {
   optional float scale=1 [default=1.0];
   optional int32 cropsize=2 [default=0];
   optional bool mirror=3 [default=false];
-  optional string meanfile=4;
+  optional string meanfile=4 [default=""];
 }
 message SplitProto{
-  optional int32 num_splits=1;
+  required int32 num_splits=1;
 }
 // scaled tan: A*tan(B*x)
 message TanhProto{
@@ -225,7 +193,7 @@ message SoftmaxLossProto {
 }
 // Message that stores parameters used by ConvolutionLayer
 message ConvolutionProto {
-  optional uint32 num_filters = 1; // The number of outputs for the layer
+  required uint32 num_filters = 1; // The number of outputs for the layer
   optional bool bias_term = 2 [default = true]; // whether to have bias terms
   // Pad, kernel size, and stride are all given as a single value for equal
   // dimensions in height and width or as Y, X pairs.
@@ -235,19 +203,17 @@ message ConvolutionProto {
 }
 
 message ConcateProto{
-  optional int32 concate_dimension=1;
-  optional int32 concate_num=2;
+  required int32 concate_dimension=1;
+  required int32 concate_num=2;
 }
 
 // Message that stores parameters used by DataLayer
 message DataProto {
-  // Specify the data source.
-  optional string source = 1;
   // path to the data file/folder, absolute or relative to the
   // ClusterProto::workspace
-  optional string path=2;
+  required string path=2;
   // Specify the batch size.
-  optional uint32 batchsize = 4;
+  required uint32 batchsize = 4;
   // skip [0,random_skip] records
   optional uint32 random_skip=5 [default=0];
 }
@@ -273,13 +239,13 @@ message DropoutProto {
 }
 // Message that stores parameters used by InnerProductLayer
 message InnerProductProto {
-  optional uint32 num_output = 1; // The number of outputs for the layer
+  required uint32 num_output = 1; // The number of outputs for the layer
   optional bool bias_term = 2 [default = true]; // whether to have bias terms
 }
 
 // Message that stores parameters used by LRNLayer
 message LRNProto {
-  optional uint32 local_size = 1 [default = 5];
+  optional int32 local_size = 1 [default = 5];
   optional float alpha = 2 [default = 1.];
   optional float beta = 3 [default = 0.75];
   enum NormRegion {
@@ -305,8 +271,8 @@ message PoolingProto {
 }
 
 message SliceProto{
-  optional int32 slice_dimension=1;
-  optional int32 slice_num=2;
+  required int32 slice_dimension=1;
+  required int32 slice_num=2;
 }
 // Message that stores parameters used by ReLULayer
 message ReLUProto {
@@ -356,9 +322,9 @@ message UpdaterProto {
   optional float pow=7 [default=0];
   optional float delta=8 [default=0.0000001];
   optional float rho=9 [default=0.9];
-  optional float base_learning_rate=12;
-  optional float final_learning_rate=13;
-  optional int32 learning_rate_change_frequency = 14;
+  optional float base_learning_rate=12 [default=0];
+  optional float final_learning_rate=13 [default=0];
+  optional int32 learning_rate_change_frequency = 14 [default=0];
   enum ChangeProto {
     kFixed = 0;
     kInverse_t= 1;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index e0fcb48..36b04b6 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -13,15 +13,15 @@ Server::Server(int thread_id,int group_id, int server_id):
   thread_id_(thread_id),group_id_(group_id), server_id_(server_id){}
 
 void Server::Setup(const UpdaterProto& proto,
-    shared_ptr<Server::ParamShard> shard){
+    shared_ptr<ServerShard> shard, const vector<int>& slice2group){
        //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id 
= " <<id_;
   updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
       ->Create("Updater"));
   updater_->Init(proto);
   shard_=shard;
+  slice2group_=slice2group;
 }
 
-
 void Server::Run(){
   dealer_=std::make_shared<Dealer>(2*thread_id_);
   dealer_->Connect(kInprocRouterEndpoint);
@@ -61,7 +61,7 @@ void Server::Run(){
         param->set_id(pid);
         (*shard_)[pid]=param;
       }
-      response = HandlePut(param, &msg);
+      param->HandlePutMsg(&msg);
     }else{
       int pid=msg->target_first();
       if(shard_->find(pid)==shard_->end()){
@@ -83,10 +83,6 @@ void Server::Run(){
             VLOG(3)<<"Handle SYNC-REQUEST";
             response = HandleSyncRequest(param, &msg);
             break;
-          case kSyncResponse:
-            VLOG(3) << "Handle SYNC response";
-            HandleSyncResponse(param, &msg);
-            break;
         }
         if (response!=nullptr){
           dealer_->Send(&response);
@@ -96,11 +92,8 @@ void Server::Run(){
   }
 }
 
-bool Server::SyncNow(){
-  return false;
-}
-Msg* Server::HandlePut(shared_ptr<Param> param, Msg **msg){
-  return param->HandlePutMsg(msg);
+void Server::HandlePut(shared_ptr<Param> param, Msg **msg){
+  param->HandlePutMsg(msg);
 }
 
 Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){
@@ -110,11 +103,15 @@ Msg* Server::HandleGet(shared_ptr<Param> param, Msg 
**msg){
 Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) {
   //repsonse of the format: <identity><type: kData><paramId><param content>
   auto* tmp=static_cast<Msg*>((*msg)->CopyAddr());
-  const std::pair<bool, int> copy_step=param->ParseUpdateMsg(msg);
-  updater_->Update(copy_step.second, param);
-  param->set_version(param->version()+1);
-  auto response=param->GenUpdateResponseMsg(copy_step.first, param->version());
   tmp->SwapAddr();
+  int paramid=(*msg)->target_first();
+  int sliceid=(*msg)->target_second();
+  int step=(*msg)->target_third();
+  bool copy=param->ParseUpdateMsg(msg);
+  updater_->Update(step, param);
+  param->set_version(param->version()+1);
+  auto response=param->GenUpdateResponseMsg(copy);
+  response->set_target(paramid, sliceid, param->version());
   response->SetAddr(tmp);
   delete tmp;
   return response;
@@ -124,8 +121,4 @@ Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg 
**msg){
   return param->HandleSyncMsg(msg);
 }
 
-int Server::HandleSyncResponse(shared_ptr<Param> param, Msg **msg){
-  return param->ParseSyncResponseMsg(msg);
-}
-
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index cd0189c..989a020 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -1,6 +1,7 @@
 #include <thread>
 #include <vector>
 #include <map>
+#include <queue>
 #include <glog/logging.h>
 #include "trainer/trainer.h"
 #include "mshadow/tensor.h"
@@ -8,24 +9,6 @@ using std::vector;
 using std::map;
 
 namespace singa {
-int ProcsIDOf(int group_id, int id, int flag){
-  int procsid=-1;
-  auto cluster=Cluster::Get();
-  if(flag==kServer){
-    procsid=group_id*cluster->nservers_per_group()/
-      cluster->nservers_per_procs()+id/cluster->nservers_per_procs();
-    if(cluster->server_worker_separate())
-      procsid+=cluster->nworker_procs();
-  }else if(flag==kWorkerLayer || flag==kWorkerParam){
-    procsid=group_id*cluster->nworkers_per_group()
-      /cluster->nworkers_per_procs();
-    if(cluster->nworkers_per_group()>cluster->nworkers_per_procs())
-      procsid+=id/cluster->nworkers_per_procs();
-  }else{
-    LOG(ERROR)<<"Unkown flag ("<<flag<<")";
-  }
-  return procsid;
-}
 
 void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
   // register all layers appearing in the neural net
@@ -36,11 +19,6 @@ void Trainer::RegisterDefaultClasses(const 
singa::ModelProto& proto){
       "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
 }
 
-typedef struct HandleContext_{
-  shared_ptr<Dealer> dealer;
-  int group_id, id;
-} HandleContext;
-
 void HandleWorkerFinish(void * ctx){
   HandleContext* hctx=static_cast<HandleContext*> (ctx);
   Msg* msg=new Msg();
@@ -50,125 +28,225 @@ void HandleWorkerFinish(void * ctx){
   hctx->dealer->Send(&msg);
 }
 
-void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
-    int procs_id){
-  procs_id_=procs_id;
-  RegisterDefaultClasses(mproto);
-
-  auto cluster=Cluster::Get(cproto, procs_id);
-  router_=make_shared<Router>();
-  router_->Bind(kInprocRouterEndpoint);
-  if(cluster->nprocs()>1)
-    router_->Bind(cluster->endpoint());
+const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num,
+    const vector<shared_ptr<Param>>& params){
+  CHECK_GT(num,0);
+  vector<int> param_size;
+  int avg=0;
+  for(const auto& x:params){
+    if(x->owner()==x->id())
+      avg+=x->size();
+  }
+  avg/=num;
+  int diff=avg/10;
+  LOG(INFO)<<"Slicer, param avg="<<avg<<", diff= "<<diff;
 
-  // create servers
-  vector<shared_ptr<Server>> servers;
-  vector<HandleContext> ctx;
-  int nthreads=1; // the first socket is the router
-  if(cluster->has_server()){ // todo move sever creation to a method
-    int pid=cluster->procs_id();
-    if(cluster->server_worker_separate())
-      pid-=cluster->nworker_procs();
-    int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group();
-    int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
-    int end=start+cluster->nservers_per_group();
-    // the ParamShard for servers consists of a dictionary of Param objects
-    auto shard=make_shared<Server::ParamShard>();
-    if(start<end){
-      auto dealer=make_shared<Dealer>();
-      dealer->Connect(kInprocRouterEndpoint);
-      for(int sid=start;sid<end;sid++){
-        auto server=make_shared<Server>(nthreads++, gid, sid);
-        server->Setup(mproto.updater(), shard);
-        servers.push_back(server);
-        HandleContext hc{dealer, gid, sid};
-        ctx.push_back(hc);
-        CHECK(cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish,
-            &ctx.back()));
+  int capacity=avg, sliceid=0;
+  std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
+  for(auto& param: params){
+    if(param->id()!=param->owner())
+      continue;
+    int x=param->size(), paramid=param->id();
+    LOG(INFO)<<"param id="<<paramid<<", total size="<<x;
+    while(x>0){
+      int size=0;
+      if(capacity>x){
+        capacity-=x;
+        size=x;
+        x=0;
+      }else if(capacity+diff>x){
+        capacity=avg;
+        size=x;
+        x=0;
+      }else if(capacity>diff){
+        x-=capacity;
+        size=capacity;
+        capacity=avg;
+      }else{
+        capacity=avg;
+      }
+      if(size){
+        paramid2slices[paramid].push_back(std::make_pair(sliceid++, size));
+        LOG(INFO)<<"param id="<<paramid<<", slice size="<<size;
       }
     }
   }
-  // create workers
-  vector<shared_ptr<Worker>> workers;
-  std::map<int, shared_ptr<Trainer::ParamShard>> shards;
-  if(cluster->has_worker()){ //move worker creation to a method
-    auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
-        cluster->nworkers_per_group());
-    //LOG(ERROR)<<net->ToString();
-    int pid=cluster->procs_id();
-    int gstart, gend, wstart, wend;
-    if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
-      // all workers in this procs are from the same group
-      gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group();
-      gend=gstart+1;
-      wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group();
-      wend=wstart+cluster->nworkers_per_group();
+  return paramid2slices;
+}
+const vector<int> PartitionSlice(int num, const vector<int>& slices){
+  int avg=0;
+  for(int x: slices)
+    avg+=x;
+  avg/=num;
+  int box=avg, boxid=0, diff=avg/10;
+  vector<int> slice2box;
+  for(int x: slices){
+    if(box>=x){
+      box-=x;
+      slice2box.push_back(boxid);
+    }else if(box+diff>=x){
+      slice2box.push_back(boxid);
+      box=avg;
+      boxid++;
     }else{
-      // there are multiple groups in this procs
-      CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0);
-      int groups_per_procs=
-        cluster->nworkers_per_procs()/cluster->nworkers_per_group();
-      gstart=pid*groups_per_procs;
-      gend=(pid+1)*groups_per_procs;
-      wstart=0;
-      wend=cluster->nworkers_per_group();
+      box=avg;
+      boxid++;
     }
-    for(int gid=gstart;gid<gend;gid++){
-      shared_ptr<NeuralNet> train_net, test_net, validation_net;
-      if(gid==gstart)
-        train_net=net;
-      else{
-        train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
+  }
+  CHECK_LE(boxid, num);
+  return slice2box;
+}
+vector<shared_ptr<Server>> Trainer::CreateServers(int nthreads,
+    const ModelProto & mproto,
+    const vector<int> slices,
+    vector<HandleContext>* ctx){
+  auto cluster=Cluster::Get();
+  vector<shared_ptr<Server>> servers;
+  if(!cluster->has_server())
+    return servers;
+
+  int pid=cluster->procs_id();
+  if(cluster->server_worker_separate())
+    pid-=cluster->nworker_procs();
+  int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group();
+  int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
+  int end=start+cluster->nservers_per_group();
+  // the ServerShard for servers consists of a dictionary of Param objects
+  server_shard_=make_shared<ServerShard>();
+  auto slice2group=PartitionSlice(cluster->nserver_groups(), slices);
+  if(start<end){
+    auto dealer=make_shared<Dealer>();
+    dealer->Connect(kInprocRouterEndpoint);
+    for(int sid=start;sid<end;sid++){
+      auto server=make_shared<Server>(nthreads++, gid, sid);
+      server->Setup(mproto.updater(), server_shard_, slice2group);
+      servers.push_back(server);
+      HandleContext hc{dealer, gid, sid};
+      ctx->push_back(hc);
+      CHECK(cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish,
+            &(ctx->back())));
+    }
+  }
+  return servers;
+}
+
+vector<shared_ptr<Worker>> Trainer::CreateWorkers(int nthreads,
+    const ModelProto& mproto, vector<int> *slice_size){
+  auto cluster=Cluster::Get();
+  vector<shared_ptr<Worker>> workers;
+  if(!cluster->has_worker())
+    return workers;
+  //LOG(ERROR)<<net->ToString();
+  int pid=cluster->procs_id();
+  int gstart, gend, wstart, wend;
+  if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
+    // all workers in this procs are from the same group
+    gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group();
+    gend=gstart+1;
+    wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group();
+    wend=wstart+cluster->nworkers_per_group();
+  }else{
+    // there are multiple groups in this procs
+    CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0);
+    int groups_per_procs=
+      cluster->nworkers_per_procs()/cluster->nworkers_per_group();
+    gstart=pid*groups_per_procs;
+    gend=(pid+1)*groups_per_procs;
+    wstart=0;
+    wend=cluster->nworkers_per_group();
+  }
+  auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
+      cluster->nworkers_per_group());
+  int lcm=LeastCommonMultiple(cluster->nserver_groups(), 
cluster->nservers_per_group());
+  auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
+  for(auto param: net->params()){
+    if(param->id()==param->owner())
+      for(auto entry: paramid2slices[param->id()])
+        slice_size->push_back(entry.second);
+  }
+
+  for(int gid=gstart;gid<gend;gid++){
+    shared_ptr<NeuralNet> train_net, test_net, validation_net;
+    if(gid==gstart)
+      train_net=net;
+    else{
+      train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
+          cluster->nworkers_per_group());
+      // the train net for other groups may share parameter values from the
+      // first group
+      if(cluster->share_memory())
+        train_net->ShareParams(net, kValueOnly);
+    }
+    if(gid==0){
+      // validation and test are performed only by the first group
+      if(mproto.test_steps()){
+        test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest,
             cluster->nworkers_per_group());
-        // the train net for other groups may share parameter values from the
-        // first group
-        if(cluster->share_memory())
-          train_net->ShareParams(net, kValueOnly);
+        if(test_net!=nullptr)
+          test_net->ShareParams(train_net, kValueOnly);
       }
-      if(gid==0){
-        // validation and test are performed only by the first group
-        if(mproto.test_steps()){
-          test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest,
-              cluster->nworkers_per_group());
-          if(test_net!=nullptr)
-            test_net->ShareParams(train_net, kValueOnly);
-        }
-        if(mproto.validation_steps()){
-          validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), 
kValidation,
-              cluster->nworkers_per_group());
-          if(validation_net!=nullptr)
-            validation_net->ShareParams(train_net, kValueOnly);
-        }
+      if(mproto.validation_steps()){
+        validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), 
kValidation,
+            cluster->nworkers_per_group());
+        if(validation_net!=nullptr)
+          validation_net->ShareParams(train_net, kValueOnly);
       }
-      // create ParamShard for the workers
-      auto shard=make_shared<Trainer::ParamShard>();
-      shards[gid]=shard;
-      for(auto layer: train_net->layers()){
-        int procsid=ProcsIDOf(gid, layer->partitionid(),kWorkerParam);
-        bool local=procsid==cluster->procs_id();
-        for(auto param: layer->GetParams()){
-          int owner_procs=param->owner()==param->id()?procsid:procs_id_;
-          if(shard->find(param->owner())==shard->end())
-            (*shard)[param->owner()]=
-              make_shared<ParamInfo>(param, local, owner_procs);
-          else
-            shard->at(param->owner())->AddParam(param, local);
+    }
+    // create ServerShard for the workers
+    auto shard=make_shared<WorkerShard>();
+    worker_shards_[gid]=shard;
+    for(auto layer: train_net->layers()){
+      int procsid=cluster->ProcsIDOf(gid, layer->partitionid(), kWorkerLayer);
+      bool local=procsid==cluster->procs_id();
+      for(auto param: layer->GetParams()){
+        for(auto entry :paramid2slices[param->owner()]){
+          param->AddSlice(entry.first,  entry.second);
         }
+        int owner_procs=param->owner()==param->id()?procsid:procs_id_;
+        if(shard->find(param->owner())==shard->end())
+          (*shard)[param->owner()]=
+            make_shared<ParamInfo>(param, local, owner_procs);
+        else
+          shard->at(param->owner())->AddParam(param, local);
       }
-      for(int wid=wstart;wid<wend;wid++){
-        shared_ptr<Worker> worker=nullptr;
-        if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
-          worker=make_shared<BPWorker>(nthreads++,gid, wid);
-        else{
+    }
+    for(int wid=wstart;wid<wend;wid++){
+      shared_ptr<Worker> worker=nullptr;
+      if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
+        worker=make_shared<BPWorker>(nthreads++,gid, wid);
+      else{
         // TODO add CDWorker
-        }
-        worker->Setup(mproto, train_net);
-        worker->set_test_net(test_net);
-        worker->set_validation_net(validation_net);
-        workers.push_back(worker);
       }
+      worker->Setup(mproto, train_net);
+      worker->set_test_net(test_net);
+      worker->set_validation_net(validation_net);
+      workers.push_back(worker);
     }
   }
+  return workers;
+}
+
+void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
+    int procs_id){
+  procs_id_=procs_id;
+  RegisterDefaultClasses(mproto);
+
+  auto cluster=Cluster::Get(cproto, procs_id);
+  router_=make_shared<Router>();
+  router_->Bind(kInprocRouterEndpoint);
+  if(cluster->nprocs()>1)
+    router_->Bind(cluster->endpoint());
+
+  int nthreads=1;
+  // create workers
+  vector<int> slices;
+  vector<shared_ptr<Worker>> workers=CreateWorkers(nthreads, mproto, &slices);
+  slice2server_=PartitionSlice(cluster->nservers_per_group(), slices);
+  nthreads+=workers.size();
+  // create servers
+  vector<HandleContext> ctx;
+  vector<shared_ptr<Server>> servers=CreateServers(nthreads, mproto, slices,
+      &ctx);
 
 #ifdef USE_MPI
   for(int i=0;i<nSocket;i++){
@@ -180,17 +258,16 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
     threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
     threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run(workers.size(), servers.size(), shards);
+  Run(workers.size(), servers.size());
   for(auto& thread: threads)
     thread.join();
 }
 
-void Trainer::Run(int nworkers, int nservers,
-    const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){
+void Trainer::Run(int nworkers, int nservers){
   auto cluster=Cluster::Get();
   procs_id_=cluster->procs_id();
   map<int, shared_ptr<Dealer>> interprocs_dealers;
-  Metric perf;
+  std::queue<Msg*> msg_queue;
   bool stop=false;
   while(!stop){
     Msg* msg=router_->Receive();
@@ -198,13 +275,16 @@ void Trainer::Run(int nworkers, int nservers,
       LOG(ERROR)<<"Connection broken!";
       exit(0);
     }
-    while(msg!=nullptr){
+    msg_queue.push(msg);
+    while(!msg_queue.empty()){
+      msg=msg_queue.front();
+      msg_queue.pop();
       int dst_flag=msg->dst_flag();
       int type=msg->type();
       int dst_procs=msg->dst_first();
       if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
         if(type==kConnect){
-          msg =HandleConnect(&msg);
+          msg_queue.push(HandleConnect(&msg));
         }else if(type==kStop){
           if(msg->src_flag()==kServer)
             nservers--;
@@ -223,63 +303,63 @@ void Trainer::Run(int nworkers, int nservers,
             msg->next_frame();
             Metric cur;
             cur.ParseString(string((char*)msg->frame_data(), 
msg->frame_size()));
-            perf.AddMetrics(cur);
-            LOG(ERROR)<<prefix<<" step-" <<step<<", "<<perf.ToString();
-            perf.Reset();
+            LOG(ERROR)<<prefix<<" step-" <<step<<", "<<cur.ToString();
           }
           DeleteMsg(&msg);
         }else if(cluster->nserver_groups()>0){
-          int group_id=msg->src_first();
+          int group_id;
           int paramid=msg->target_first();
-          auto entry=shards.at(group_id)->at(paramid);
+          shared_ptr<ParamInfo> entry;
           switch (type){ // TODO process other requests, e.g. RESTful
             case kUpdate:
-              msg=HandleUpdate(entry, &msg);
+              group_id=msg->src_first();
+              entry=worker_shards_.at(group_id)->at(paramid);
+              for(auto x:HandleUpdate(entry, &msg))
+                msg_queue.push(x);
               break;
             case kRUpdate:
+              group_id=msg->dst_second();
+              entry=worker_shards_.at(group_id)->at(paramid);
               HandleUpdateResponse(entry, &msg);
               break;
             case kGet:
-              msg=HandleGet(entry, &msg);
+              group_id=msg->src_first();
+              entry=worker_shards_.at(group_id)->at(paramid);
+              for(auto x:HandleGet(entry, &msg))
+                msg_queue.push(x);
               break;
             case kRGet:
-              msg=HandleGetResponse(entry, &msg);
+              group_id=msg->dst_second();
+              entry=worker_shards_.at(group_id)->at(paramid);
+              HandleGetResponse(entry, &msg);
               break;
             case kPut:
-              msg=HandlePut(entry, &msg);
+              group_id=msg->src_first();
+              entry=worker_shards_.at(group_id)->at(paramid);
+              for(auto x:HandlePut(entry, &msg))
+                msg_queue.push(x);
               break;
             default:
               break;
           }
         }else{
-          delete msg;
-          msg=nullptr;
+          DeleteMsg(&msg);
         }
       }else{
         int dst_procs_id;
         if(dst_flag==kStub){
           dst_procs_id=msg->dst_first();
         }else{
-          dst_procs_id=ProcsIDOf(msg->dst_first(), msg->dst_second(), 
msg->dst_flag());
+          dst_procs_id=cluster->ProcsIDOf(msg->dst_first(),
+              msg->dst_second(), msg->dst_flag());
         }
         if(dst_procs_id!=procs_id_){
-        /*
-          // forward to other procs
-          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
-          interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
-          interprocs_dealers[procs_id]->Send(&msg);
-          */
         }else{
           router_->Send(&msg);
         }
       }
     }
   }
-  /*
-  perf.Avg();
-  if(perf_step>=0)
-    LOG(ERROR)<<perf_prefix<<" step-"<<perf_step<<", "<<perf.ToString();
-    */
 }
 Msg* Trainer::HandleConnect(Msg** msg){
   string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
@@ -294,64 +374,45 @@ Msg* Trainer::HandleConnect(Msg** msg){
   *msg=NULL;
   return reply;
 }
-int Trainer::Sharding(int param_id){
-  return param_id%Cluster::Get()->nservers_per_group();
-}
-/*
-int Worker::Sharding(int param_id){
-  static map<int, int> id2procs;
-  if(id2procs.find(param_id)==id2procs.end()){
-  auto cluster=Cluster::Get();
-  int server_group=group_id_%cluster->nserver_groups();
-  int nprocs_per_server_group=
-    cluster->nservers_per_group()/cluster->nservers_per_procs();
-  int procsid=server_group*nprocs_per_server_group+
-    param_id%nprocs_per_server_group;
-  procsid= cluster->server_worker_separate()?
-    cluster->nworker_procs()+procsid:procsid;
-  id2procs[param_id]=procsid;
-  }
-  return id2procs[param_id];
-}
-*/
-
 
-Msg* Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
-  Msg* msgg=*msg, *reply=nullptr;
+const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
+  Msg* msgg=*msg;
+  vector<Msg*> replies;
   int version=msgg->target_second();
   if(msgg->src_flag()==kStub){
     if(version<=pi->shares.at(0)->version()){
-      reply=pi->shares.at(0)->HandleGetMsg(msg);
+      pi->shares.at(0)->HandleGetMsg(msg);
     }else if(version>pi->next_version){
       // reinsert into a msg queue.
     }
   }else if(version>pi->next_version){
     pi->next_version=version;
-    int gid=msgg->src_first(), pid=msgg->target_first();
-    int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group();
-    int dstid=Sharding(pid);
-    int dstprocs=ProcsIDOf(dstgroup, dstid, kServer);
-    reply=pi->shares.at(0)->GenGetMsg(dstprocs!=procs_id_);
-    reply->set_src(procs_id_, gid, kStub);
-    reply->set_dst(dstgroup, dstid, kServer);
+    int gid=msgg->src_first();
+    int group=gid/Cluster::Get()->nworker_groups_per_server_group();
+    auto param=pi->shares.at(0);
+    for(int idx=0, id=param->slice_start();idx<param->num_slices();idx++){
+      int server=slice2server_[id+idx];
+      int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
+      auto x=param->GenGetMsg(procs!=procs_id_, idx);
+      x->set_target(param->owner(), id+idx, param->local_version()+1);
+      x->set_src(procs_id_, gid, kStub);
+      x->set_dst(group, server, kServer);
+      replies.push_back(x);
+    }
   }
-  return reply;
-}
-
-Msg* Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){
-  pi->shares.at(0)->ParseGetResponseMsg(msg);
-  return nullptr;
-  // process get requests in waiting queue
+  return replies;
 }
 
-Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){
-  Msg* msgg=*msg, *update=nullptr;
+const vector<Msg*> Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){
+  Msg* msgg=*msg ;
+  vector<Msg*> ret;
   int step= msgg->target_second();
   if(msgg->src_flag()==kStub){
-    if(pi->num_update<pi->num_local)
-      return *msg; //wait unitl local updates are ready
-    int n;
-    sscanf((char*)(*msg)->frame_data(), "%d", &n);
+    if(pi->num_update<pi->num_local){
+      ret.push_back(*msg);
+      return ret; //wait until local updates are ready
+    }
+    int n; sscanf((char*)(*msg)->frame_data(), "%d", &n);
     pi->num_update+=n;
     auto it=pi->shares.begin();
     auto shape=mshadow::Shape1((*it)->size());
@@ -368,45 +429,70 @@ Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** 
msg){
     }
     agg/=pi->num_total;
     if(pi->num_local<pi->num_total){
-      update=pi->shares.at(0)->GenUpdateMsg(pi->owner_procs!=procs_id_, step);
+      /*
       int gid=msgg->src_first();
-      update->set_src(procs_id_, gid,kStub);
-      update->set_dst(pi->owner_procs, gid, kStub);
+      for(auto update: pi->shares.at(0)->GenUpdateMsg(step)){
+        update->set_src(procs_id_, gid,kStub);
+        update->set_dst(pi->owner_procs, gid, kStub);
+        ret.push_back(update);
+      }
       pi->num_update=0;
+      */
     }
   }
   if(pi->num_update==pi->num_total){
-    int gid=msgg->src_first();
-    int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group();
-    int dstid=Sharding(msgg->target_first());
-    int dstprocs=ProcsIDOf(dstgroup, dstid, kServer);
-    update=pi->shares.at(0)->GenUpdateMsg(dstprocs!=procs_id_, step);
-    update->set_src(procs_id_, gid, kStub);
-    update->set_dst(dstgroup, dstid, kServer);
+    auto param=pi->shares.at(0);
+    int 
group=msgg->src_first()/Cluster::Get()->nworker_groups_per_server_group();
+    int srcgid=msgg->src_first();
+    for(int idx=0, id=param->slice_start(); idx<param->num_slices();idx++){
+      int server=slice2server_[idx+id];
+      int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
+      auto x=param->GenUpdateMsg(procs!=procs_id_, idx);
+      x->set_target(param->owner(), id+idx, step);
+      x->set_src(procs_id_, srcgid, kStub);
+      x->set_dst(group, server, kServer);
+      ret.push_back(x);
+    }
     pi->num_update=0;
   }
-  delete *msg;
-  *msg=NULL;
-  return update;
+  DeleteMsg(msg);
+  return ret;
 }
 
-int Trainer::HandleUpdateResponse(shared_ptr<Trainer::ParamInfo> pi, Msg** 
msg){
-  HandleGetResponse(pi, msg);
-  return 1;
-}
-
-Msg* Trainer::HandlePut(shared_ptr<Trainer::ParamInfo>pi, Msg** msg){
+const vector<Msg*> Trainer::HandlePut(shared_ptr<ParamInfo>pi, Msg** msg){
+  vector<Msg*> ret;
   CHECK_NE((*msg)->src_flag(), kStub);
   int gid=(*msg)->src_first();
-  int id=(*msg)->target_first();
-  int dstgroup=gid/Cluster::Get()->nworker_groups_per_server_group();
-  int dstid=Sharding(id);
-  int dstprocs=ProcsIDOf(dstgroup, dstid, kServer);
-  Msg* put=pi->shares.at(0)->GenPutMsg(dstprocs!=procs_id_);
-  put->set_src(procs_id_, gid , kStub);
-  put->set_dst(dstgroup, dstid, kServer);
-  delete *msg;
-  *msg=NULL;
-  return put;
+  auto param=pi->shares.at(0);
+  int group=gid/Cluster::Get()->nworker_groups_per_server_group();
+  for(int idx=0, start=param->slice_start();idx<param->num_slices(); idx++){
+    int server=slice2server_[start+idx];
+    int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
+    auto x=param->GenPutMsg(procs!=procs_id_, idx);
+    x->set_target(param->owner(), start+idx, param->version());
+    x->set_src(procs_id_, gid, kStub);
+    x->set_dst(group, server, kServer);
+    ret.push_back(x);
+  }
+  DeleteMsg(msg);
+  return ret;
+}
+
+void Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){
+  int version=(*msg)->target_third();
+  int sliceid=(*msg)->target_second();
+  auto param=pi->shares.at(0);
+  if(param->ParseGetResponseMsg(msg,sliceid-param->slice_start()))
+    param->set_version(version);
+  // process get requests in waiting queue
+}
+
+
+void Trainer::HandleUpdateResponse(shared_ptr<ParamInfo> pi, Msg** msg){
+  int version=(*msg)->target_third();
+  int sliceid=(*msg)->target_second();
+  auto param=pi->shares.at(0);
+  if(param->ParseUpdateResponseMsg(msg,sliceid-param->slice_start()))
+    param->set_version(version);
 }
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index f047f0f..0835bbb 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -19,10 +19,7 @@ void Worker::Setup(const ModelProto& model,
   train_net_=train_net;
   modelproto_=model;
   auto cluster=Cluster::Get();
-  if(cluster->nserver_groups()&&cluster->server_update()){
-    int sgid=group_id_/cluster->nworker_groups_per_server_group();
-    CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid));
-  }else{
+  if(!(cluster->nserver_groups()&&cluster->server_update())){
     updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
         ->Create("Updater"));
     updater_->Init(model.updater());
@@ -30,6 +27,12 @@ void Worker::Setup(const ModelProto& model,
 }
 
 void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){
+  if(updater_==nullptr){
+    auto cluster=Cluster::Get();
+    int sgid=group_id_/cluster->nworker_groups_per_server_group();
+    CHECK(cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid));
+  }
+
   dealer->Connect(kInprocRouterEndpoint);
   Msg* ping=new Msg();
   ping->set_src(group_id_, worker_id_, type);
@@ -60,7 +63,7 @@ void Worker::Run(){
       for(auto param: layer->GetParams()){
         if(param->owner() == param->id()){
           if(group_id_==0)
-            param->Init(0);
+            param->InitValues(0);
           else
             Get(param, modelproto_.warmup_steps());
         }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index b00a3cd..c344627 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -12,24 +12,38 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) 
{
   procs_id_=procs_id;
   cluster_ = cluster;
   SetupFolders(cluster);
-  size_t nprocs;
   if(server_worker_separate())
-    nprocs=nworker_procs()+nserver_procs();
+    nprocs_=nworker_procs()+nserver_procs();
   else
-    nprocs=std::max(nworker_procs(), nserver_procs());
-  CHECK_LT(procs_id, nprocs);
-  if (cluster_.has_nprocs())
-    CHECK_EQ(cluster.nprocs(), nprocs);
-  else
-    cluster_.set_nprocs(nprocs);
-  if(nprocs>1){
+    nprocs_=std::max(nworker_procs(), nserver_procs());
+  CHECK_LT(procs_id, nprocs_);
+  if(nprocs_>1){
     std::ifstream ifs(cluster.hostfile(), std::ifstream::in);
     std::string line;
-    while(std::getline(ifs, line)&&endpoints_.size()<nprocs){
+    while(std::getline(ifs, line)&&endpoints_.size()<nprocs_){
       endpoints_.push_back(line);
     }
-    CHECK_EQ(endpoints_.size(), nprocs);
+    CHECK_EQ(endpoints_.size(), nprocs_);
+  }
+
+  // locate the process id of every worker/server
+  int ngrps=cluster_.nworker_groups(), grp_size=cluster_.nworkers_per_group();
+  int procs;
+  for(int i=0;i<ngrps;i++){
+    for(int j=0;j<grp_size;j++){
+      procs=(i*grp_size+j) / cluster_.nworkers_per_procs();
+      procs_ids_[Hash(i,j,kWorkerLayer)]=procs;
+      procs_ids_[Hash(i,j,kWorkerParam)]=procs;
+    }
+  }
+  ngrps=cluster_.nserver_groups(), grp_size=cluster_.nservers_per_group();
+  int offset=cluster_.server_worker_separate()? procs:0;
+  for(int i=0;i<ngrps;i++){
+    for(int j=0;j<grp_size;j++){
+      procs_ids_[Hash(i,j,kServer)]=(i*grp_size+j) / 
cluster_.nservers_per_procs()+offset;
+    }
   }
+
   auto rt=new ZKClusterRT(cluster_.zookeeper_host());
   rt->Init();
   cluster_rt_=shared_ptr<ClusterRuntime>(static_cast<ClusterRuntime*>(rt));
@@ -52,4 +66,17 @@ shared_ptr<Cluster> Cluster::Get() {
   }
   return instance_;
 }
+int Cluster::Hash(int gid, int id, int flag){
+  int ret=-1;
+  if(flag==kServer){
+    ret=(flag*cluster_.nserver_groups()+gid)*cluster_.nservers_per_group() + id;
+  }else{
+    ret=(flag*cluster_.nworker_groups()+gid)*cluster_.nworkers_per_group() + id;
+  }
+  return ret;
+}
+int Cluster::ProcsIDOf(int group_id, int id, int flag){
+  return procs_ids_.at(Hash(group_id, id, flag));
+}
+
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 0697060..783b1f9 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -85,5 +85,20 @@ void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
   int fd= open(filename, O_CREAT|O_WRONLY|O_TRUNC, 0644);
   CHECK(proto.SerializeToFileDescriptor(fd));
 }
+int gcd(int a, int b)
+{
+  for (;;)
+  {
+    if (a == 0) return b;
+    b %= a;
+    if (b == 0) return a;
+    a %= b;
+  }
+}
+int LeastCommonMultiple(int a, int b)
+{
+  int temp = gcd(a, b);
 
+  return temp ? (a / temp * b) : 0;
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e0a52a62/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index a7e5230..75cc4cc 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -3,6 +3,7 @@
 #include <chrono>
 #include <random>
 #include "utils/param.h"
+#include "proto/cluster.pb.h"
 #include "mshadow/tensor.h"
 #include "utils/singleton.h"
 using namespace mshadow;
@@ -10,59 +11,136 @@ using std::vector;
 using std::string;
 namespace singa {
 
-Param::Param():data_(nullptr), local_version_(-1){}
+Param::Param():data_(nullptr), slice_start_(0), num_slices_(0),
+  num_pending_requests_(0),local_version_(-1){
+}
+void Param::Setup(const ParamProto& proto, const vector<int>& shape){
+  data_=std::make_shared<Blob<float>>(shape);
+  grad_.Reshape(shape);
+  history_.Reshape(shape);
+  proto_=proto;
+}
 
-Msg* Param::GenPutMsg(bool copy, int v){
+void Param::AddSlice(int slice_id, int size){
+  int offset=0;
+  if(slice_size_.size()>0){
+    // slices must be added in contiguous, increasing id order
+    CHECK_EQ(slice_start_+num_slices_, slice_id);
+    offset=slice_offset_.back()+slice_size_.back();
+  }
+  else{
+    slice_start_=slice_id;
+    offset=0;
+  }
+  slice_offset_.push_back(offset);
+  slice_size_.push_back(size);
+  pending_get_.push_back(false);
+  pending_update_.push_back(false);
+  pending_put_.push_back(false);
+  num_slices_++;
+}
+
+void Param::InitValues(int version){
+  Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
+  auto random=TSingleton<Random<cpu>>::Instance();
+  switch (proto_.init_method()) {
+  case ParamProto::kConstant:
+    data=proto_.value();
+    break;
+  case ParamProto::kUniform:
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value();
+    break;
+    /*
+  case ParamProto::kUniformSqrtFanIn:
+    CHECK_GT(fan_in_,0);
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(fan_in_ / 3.0f);
+    break;
+    */
+  case ParamProto::kUniformSqrtFanInOut:
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]);
+    break;
+  case ParamProto::kGaussian:
+    random->SampleGaussian(data, proto_.mean(), proto_.std());
+    if(proto_.value())
+      data*= proto_.value();
+    break;
+  case ParamProto::kGaussainSqrtFanIn:
+    random->SampleGaussian(data, proto_.mean(), proto_.std());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(data_->shape()[0]);
+    break;
+  default:
+    LOG(ERROR) << "Illegal parameter init method ";
+    break;
+  }
+  set_version(version);
+}
+
+/**************Message related functions********/
+Msg* Param::GenPutMsg(bool copy, int idx){
+  CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kPut);
-  msg->set_target(owner(), version());
   char buf[128];
-  sprintf(buf, "%d %f %f", size(),
+  sprintf(buf, "%d %f %f", slice_size_[idx],
       learning_rate_multiplier(), weight_decay_multiplier());
+  void *ptr=mutable_cpu_data()+slice_offset_[idx];
   if(copy){
     sprintf(buf+strlen(buf), " %p ", nullptr);
     msg->add_frame(buf, strlen(buf));
-    msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+    msg->add_frame(ptr, slice_size_[idx]*sizeof(float));
   }else{
-    //share the data blob which includes the blob version
-    sprintf(buf+strlen(buf), " %p ", data_.get());
+    sprintf(buf+strlen(buf), " %p ", ptr);
     msg->add_frame(buf, strlen(buf));
   }
+  pending_put_[idx]=true;
+  num_pending_requests_++;
        return msg;
 }
 
-Msg* Param::GenGetMsg(bool copy, int v){
+Msg* Param::GenGetMsg(bool copy, int idx){
+  CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kGet);
-  msg->set_target(owner(), local_version()+1);
   msg->add_frame(&copy, sizeof(bool));
+  pending_get_[idx]=true;
+  num_pending_requests_++;
   return msg;
 }
 
-Msg* Param::GenUpdateMsg(bool copy, int v){
+Msg* Param::GenUpdateMsg(bool copy, int idx){
+  CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kUpdate);
-  msg->set_target(owner(), v);
   msg->add_frame(&copy, sizeof(bool));
+  void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx];
   if(copy)
-    msg->add_frame(mutable_cpu_grad(), size()*sizeof(float));
+    msg->add_frame(ptr, slice_size_[idx]*sizeof(float));
   else{ // to share values of grad blob
-    char buf[32]; sprintf(buf, " %p ", &grad_);
+    char buf[32]; sprintf(buf, " %p ", ptr);
     msg->add_frame(buf, strlen(buf));
   }
+  pending_update_[idx]=true;
+  num_pending_requests_++;
   return msg;
 }
 
-Msg* Param::GenSyncMsg(bool copy, int v){
+Msg* Param::GenSyncMsg(){
   return nullptr;
 }
 
 Msg* Param::HandlePutMsg(Msg** msg){
   int size;
   float lr, wc;
-  void* ptr;
-  sscanf(static_cast<char*>((*msg)->frame_data()), "%d %f %f %p ",
-      &size, &lr, &wc, &ptr);
+  float* ptr;
+  sscanf(static_cast<char*>((*msg)->frame_data()),
+      "%d %f %f %p ", &size, &lr, &wc, &ptr);
   proto_.set_learning_rate_multiplier(lr);
   proto_.set_weight_decay_multiplier(wc);
   vector<int> shape{size};
@@ -70,49 +148,45 @@ Msg* Param::HandlePutMsg(Msg** msg){
   history_.Reshape(shape);
   data_=std::make_shared<Blob<float>>(shape);
   if(ptr==nullptr){
-    data_->set_version((*msg)->target_second());
     CHECK((*msg)->next_frame());
     CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
     memcpy(mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float));
   } else{
-    data_->ShareData(*static_cast<Blob<float>*>(ptr));
+    data_->set_cpu_data(ptr);
   }
   DeleteMsg(msg);
   return nullptr;
 }
 
 Msg* Param::HandleGetMsg(Msg** msg){
-  if((*msg)->target_second()<=version()){
-    bool* copy=static_cast<bool*>((*msg)->frame_data());
-    (*msg)->next_frame();
-    if(*copy)
-      (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
-    (*msg)->SwapAddr();
-    (*msg)->set_type(kRGet);
-  }
+  bool* copy=static_cast<bool*>((*msg)->frame_data());
+  (*msg)->next_frame();
+  if(*copy)
+    (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
+  // else the mem space is shared among all workers and servers
+  (*msg)->SwapAddr();
+  (*msg)->set_type(kRGet);
   return *msg;
 }
 
-const std::pair<bool, int> Param::ParseUpdateMsg(Msg** msg){
-  int step=(*msg)->target_second();
+int Param::ParseUpdateMsg(Msg** msg){
   bool* copy=static_cast<bool*>((*msg)->frame_data());
   (*msg)->next_frame();
   if(*copy){
     CHECK((*msg)->frame_size());
     memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
   }else {// use the same data field of the grad blob
-    Blob<float>* ptr=nullptr;
+    float* ptr=nullptr;
     sscanf(static_cast<char*>((*msg)->frame_data()), " %p ", &ptr);
-    grad_.ShareData(*ptr);
+    grad_.set_cpu_data(ptr);
   }
   DeleteMsg(msg);
-  return std::make_pair(*copy, step);
+  return *copy;
 }
 
-Msg* Param::GenUpdateResponseMsg(bool copy, int v){
+Msg* Param::GenUpdateResponseMsg(bool copy){
   Msg* msg=new Msg();
   msg->set_type(kRUpdate);
-  msg->set_target(owner(), v);
   msg->add_frame(&copy, sizeof(bool));
   if(copy)
     msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
@@ -124,74 +198,35 @@ Msg* Param::HandleSyncMsg(Msg** msg){
   return nullptr;
 }
 
-int Param::ParseSyncResponseMsg(Msg** msg){
+int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){
   DeleteMsg(msg);
   return 1;
 }
-int Param::ParsePutResponseMsg(Msg **msg){
-  return ParseSyncResponseMsg(msg);
+
+int Param::ParseGetResponseMsg(Msg **msg, int slice_idx){
+  CHECK_EQ(pending_get_[slice_idx], true);
+  pending_get_[slice_idx]=false;
+  ParseResponseMsg(msg, slice_idx);
+  return (--num_pending_requests_)%num_slices_==0;
+}
+
+int Param::ParseUpdateResponseMsg(Msg **msg, int slice_idx){
+  CHECK_EQ(pending_update_[slice_idx], true);
+  pending_update_[slice_idx]=false;
+  ParseResponseMsg(msg, slice_idx);
+  return (--num_pending_requests_)%num_slices_==0;
 }
-int Param::ParseGetResponseMsg(Msg **msg){
+
+void Param::ParseResponseMsg(Msg** msg, int slice_idx){
   bool *copy=static_cast<bool*>((*msg)->frame_data());
   (*msg)->next_frame();
   if(*copy){
     CHECK((*msg)->frame_size());
-    memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size());
-  }  // must be set after all other settings are done!
-  set_version((*msg)->target_second());
+    memcpy(mutable_cpu_data()+slice_offset_[slice_idx-slice_start_],
+        (*msg)->frame_data(), (*msg)->frame_size());
+  }
   DeleteMsg(msg);
-  return 1;
 }
-int Param::ParseUpdateResponseMsg(Msg **msg){
-  return ParseGetResponseMsg(msg);
-}
-
-void Param::Setup(const ParamProto& proto, const vector<int>& shape,
-    int fan_in){
-  data_=std::make_shared<Blob<float>>(shape);
-  grad_.Reshape(shape);
-  history_.Reshape(shape);
-  proto_=proto;
-  fan_in_=fan_in;
 }
 
-void Param::Init(int v){
-  Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(size()));
-  auto random=TSingleton<Random<cpu>>::Instance();
-  switch (proto_.init_method()) {
-  case ParamProto::kConstant:
-    data=proto_.value();
-    break;
-  case ParamProto::kUniform:
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    if(proto_.value())
-      data*= proto_.value();
-    break;
-  case ParamProto::kUniformSqrtFanIn:
-    CHECK_GT(fan_in_,0);
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    if(proto_.value())
-      data*= proto_.value()/ sqrt(fan_in_ / 3.0f);
-    break;
-  case ParamProto::kUniformSqrtFanInOut:
-    random->SampleUniform(data, proto_.low(), proto_.high());
-    if(proto_.value())
-      data*= proto_.value()/ sqrt(data_->shape()[0] +data_->shape()[1]);
-    break;
-  case ParamProto::kGaussian:
-    random->SampleGaussian(data, proto_.mean(), proto_.std());
-    if(proto_.value())
-      data*= proto_.value();
-    break;
-  case ParamProto::kGaussainSqrtFanIn:
-    random->SampleGaussian(data, proto_.mean(), proto_.std());
-    if(proto_.value())
-      data*= proto_.value()/ sqrt(data_->shape()[0]);
-    break;
-  default:
-    LOG(ERROR) << "Illegal parameter init method ";
-    break;
-  }
-  set_version(v);
-}
-}  // namespace singa
+// namespace singa

Reply via email to