1. Move the functions of pm_server (pm_worker) into server (trainer) to
simplify the logic. Workers now send simple messages to the stub thread, which
constructs the real update/get/put requests.
The stub thread also handles the responses from servers, e.g., the get/update
response is now handled by the stub. The workers then wait until their params'
versions are updated in the collect function.
Avoid deadlocks for param_dealer_ and layer_dealer_.
2. Tested data partition in a single group within one process.
3. Generate a JSON file under workspace/visualization representing the neural
net structure. Users can create an image using the Python script
(script/graph.py), which reads the JSON file.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b5b943c7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b5b943c7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b5b943c7

Branch: refs/heads/master
Commit: b5b943c7deb33c976272471446545f72d2b5a2be
Parents: 39969f2
Author: wang wei <[email protected]>
Authored: Sat May 16 22:59:49 2015 +0800
Committer: wang wei <[email protected]>
Committed: Sat May 16 22:59:49 2015 +0800

----------------------------------------------------------------------
 examples/cifar10/model.conf    |   4 +-
 include/communication/msg.h    |  77 +++++----
 include/communication/socket.h |   6 +-
 include/neuralnet/base_layer.h |  14 ++
 include/trainer/pm_server.h    |  91 ----------
 include/trainer/pm_worker.h    | 172 -------------------
 include/trainer/server.h       |  62 ++++++-
 include/trainer/trainer.h      |  91 +++++++++-
 include/trainer/worker.h       |  30 +---
 include/utils/common.h         |  31 ++++
 src/communication/socket.cc    |  16 +-
 src/main.cc                    |   1 +
 src/neuralnet/base_layer.cc    |  37 ++--
 src/neuralnet/neuralnet.cc     |  19 ++-
 src/proto/model.pb.h           |  82 ++++-----
 src/proto/model.proto          |   2 +-
 src/trainer/pm_server.cc       |  99 -----------
 src/trainer/pm_worker.cc       | 324 ------------------------------------
 src/trainer/server.cc          | 123 ++++++++++----
 src/trainer/trainer.cc         | 220 +++++++++++++++++++-----
 src/trainer/worker.cc          | 241 ++++++++++-----------------
 src/utils/graph.cc             |  50 ++++--
 src/utils/param.cc             |  71 ++++----
 23 files changed, 758 insertions(+), 1105 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/examples/cifar10/model.conf
----------------------------------------------------------------------
diff --git a/examples/cifar10/model.conf b/examples/cifar10/model.conf
index bace74d..76ce8db 100644
--- a/examples/cifar10/model.conf
+++ b/examples/cifar10/model.conf
@@ -1,8 +1,8 @@
 name: "cifar10-convnet"
 train_steps: 70000
-test_steps:5
+test_steps:100
 test_frequency:1000
-display_frequency:1
+display_frequency:30
 updater{
   momentum:0.9
   weight_decay:0.004

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/communication/msg.h
----------------------------------------------------------------------
diff --git a/include/communication/msg.h b/include/communication/msg.h
index 54b4601..21ac78e 100644
--- a/include/communication/msg.h
+++ b/include/communication/msg.h
@@ -13,27 +13,28 @@ class BaseMsg{
     */
   virtual ~BaseMsg(){};
   /**
-    * @param group_id worker/server group id
+    * @param first worker/server group id
     * @param id worker/server id within the group
     * @param flag 0 for server, 1 for worker, 2 for stub
     */
-  virtual void set_src(int group_id, int id, int flag)=0;
-  virtual void set_dst(int group_id, int id, int flag)=0;
+  virtual void set_src(int first, int second, int flag)=0;
+  virtual void set_dst(int first, int second, int flag)=0;
   virtual void set_src(int procs_id, int flag)=0;
   virtual void set_dst(int procs_id, int flag)=0;
-  virtual int src_group_id() const=0;
-  virtual int dst_group_id() const=0;
-  virtual int src_id() const=0;
-  virtual int dst_id() const=0;
+  virtual int src_first() const=0;
+  virtual int dst_first() const=0;
+  virtual int src_second() const=0;
+  virtual int dst_second() const=0;
   virtual int src_flag() const=0;
   virtual int dst_flag() const=0;
   virtual void set_type(int type)=0;
   virtual int type() const=0;
-  virtual void set_target(int target)=0;
-  virtual int target() const=0;
+  virtual void set_target(int first, int second)=0;
+  virtual int target_first() const=0;
+  virtual int target_second() const=0;
 
   /**
-   * Copy src and dst address, including group_id, id, flag
+   * Copy src and dst address, including first, id, flag
    */
   virtual BaseMsg* CopyAddr()=0;
   virtual void SetAddr(BaseMsg* msg)=0;
@@ -64,11 +65,11 @@ class Msg : public BaseMsg{
     if(msg_!=NULL)
       zmsg_destroy(&msg_);
   }
-  virtual void set_src(int group_id, int id, int flag){
-    src_=(group_id<<kOff1)|(id<<kOff2)|flag;
+  virtual void set_src(int first, int second, int flag){
+    src_=(first<<kOff1)|(second<<kOff2)|flag;
   }
-  virtual void set_dst(int group_id, int id, int flag){
-    dst_=(group_id<<kOff1)|(id<<kOff2)|flag;
+  virtual void set_dst(int first, int second, int flag){
+    dst_=(first<<kOff1)|(second<<kOff2)|flag;
   }
   virtual void set_src(int procs_id, int flag){
     set_src(procs_id, 0, flag);
@@ -82,20 +83,20 @@ class Msg : public BaseMsg{
   int dst() const {
     return dst_;
   }
-  virtual int src_group_id() const {
+  virtual int src_first() const {
     int ret=src_>>kOff1;
     return ret;
   }
 
-  virtual int dst_group_id() const{
+  virtual int dst_first() const{
     int ret=dst_>>kOff1;
     return ret;
   }
-  virtual int src_id() const{
+  virtual int src_second() const{
     int ret=(src_&kMask1)>>kOff2;
     return ret;
   }
-  virtual int dst_id() const{
+  virtual int dst_second() const{
     int ret=(dst_&kMask1)>>kOff2;
     return ret;
   }
@@ -113,22 +114,24 @@ class Msg : public BaseMsg{
   }
 
   virtual void set_type(int type){
-    target_=(type<<kOff3)|(target_&kMask3);
-  }
-  virtual void set_target(int target){
-    target_=(target_>>kOff3)<<kOff3;
-    target_=target_|target;
+    type_=type;
   }
   virtual int type() const{
-    int ret=target_>>kOff3;
-    return ret;
+    return type_;
   }
-  virtual int target() const{
-    int ret=target_&kMask3;
-    return ret;
+
+  virtual void set_target(int first, int second){
+    target_first_=first;
+    target_second_=second;
+  }
+  virtual int target_first() const{
+    return target_first_;
+  }
+  virtual int target_second() const{
+    return target_second_;
   }
 
-  virtual BaseMsg* CopyAddr(){
+ virtual BaseMsg* CopyAddr(){
     Msg* msg=new Msg();
     msg->src_=src_;
     msg->dst_=dst_;
@@ -158,25 +161,27 @@ class Msg : public BaseMsg{
 
   void ParseFromZmsg(zmsg_t* msg){
     char* tmp=zmsg_popstr(msg);
-    sscanf(tmp, "%d %d %d", &src_, &dst_, &target_);
+    sscanf(tmp, "%d %d %d %d %d",
+        &src_, &dst_, &type_, &target_first_, &target_second_);
     //LOG(ERROR)<<"recv "<<src_<<" "<<dst_<<" "<<target_;
     frame_=zmsg_next(msg);
     msg_=msg;
   }
 
   zmsg_t* DumpToZmsg(){
-    zmsg_pushstrf(msg_, "%d %d %d",src_, dst_,target_);
+    zmsg_pushstrf(msg_, "%d %d %d %d %d",
+        src_, dst_, type_, target_first_, target_second_);
     //LOG(ERROR)<<"send "<<src_<<" "<<dst_<<" "<<target_;
-    zmsg_t* tmp=msg_;
+    zmsg_t *tmp=msg_;
     msg_=NULL;
     return tmp;
   }
 
  protected:
-  static const unsigned int kOff1=16, kOff2=4, kOff3=24;
-  static const unsigned int kMask1=(1<<kOff1)-1, kMask2=(1<<kOff2)-1,
-               kMask3=(1<<kOff3)-1;
-  unsigned int src_, dst_, target_;
+  static const unsigned int kOff1=16, kOff2=4;
+  static const unsigned int kMask1=(1<<kOff1)-1, kMask2=(1<<kOff2)-1;
+  int src_, dst_;
+  int type_, target_first_, target_second_;
   zmsg_t* msg_;
   zframe_t *frame_;
 };

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/communication/socket.h
----------------------------------------------------------------------
diff --git a/include/communication/socket.h b/include/communication/socket.h
index 771c3da..4b1f467 100644
--- a/include/communication/socket.h
+++ b/include/communication/socket.h
@@ -16,7 +16,7 @@ class Socket{
     * @param  the message to be sent
     * @return 1 for success queuing the message for sending, 0 for failure
     */
-  virtual int Send(Msg* msg)=0;
+  virtual int Send(Msg** msg)=0;
   /**
     * Receive a message from any connected socket.
     *
@@ -84,7 +84,7 @@ class Dealer : public Socket{
     * @return 1 connection sets up successfully; 0 otherwise
     */
   virtual int Connect(string endpoint);
-  virtual int Send(Msg* msg);
+  virtual int Send(Msg** msg);
   virtual Msg* Receive();
   virtual void* InternalID() const{
     return dealer_;
@@ -123,7 +123,7 @@ class Router : public Socket{
  /**
    * If the destination socket has not connected yet, buffer this the message.
    */
-  virtual int Send(Msg* msg);
+  virtual int Send(Msg** msg);
   virtual Msg* Receive();
   virtual void* InternalID() const{
     return router_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/neuralnet/base_layer.h
----------------------------------------------------------------------
diff --git a/include/neuralnet/base_layer.h b/include/neuralnet/base_layer.h
index 8e49059..777c2cb 100644
--- a/include/neuralnet/base_layer.h
+++ b/include/neuralnet/base_layer.h
@@ -288,6 +288,19 @@ class BridgeSrcLayer: public Layer {
 
   virtual void ComputeFeature(bool training, const vector<SLayer>& srclayers);
   virtual void ComputeGradient(const vector<SLayer>& srclayers);
+  virtual const Blob<float>& data(const Layer* from) const {
+    return srclayers_[0]->data(this);
+  }
+  virtual Blob<float>* mutable_data(const Layer* from){
+    return srclayers_[0]->mutable_data(this);
+  }
+
+  virtual const Blob<float>& grad(const Layer* from) const {
+    return srclayers_[0]->grad(this);
+  }
+  virtual Blob<float>* mutable_grad(const Layer* from) {
+    return srclayers_[0]->mutable_grad(this);
+  }
   int dst_partition() const;
   virtual bool is_bridgesrclayer() const {
     return true;
@@ -478,6 +491,7 @@ class SliceLayer: public Layer {
  protected:
   int SliceID(const Layer* layer) const;
   vector<Blob<float>> datavec_, gradvec_;
+  int slice_dim_, slice_num_;
 };
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/pm_server.h
----------------------------------------------------------------------
diff --git a/include/trainer/pm_server.h b/include/trainer/pm_server.h
deleted file mode 100644
index b759844..0000000
--- a/include/trainer/pm_server.h
+++ /dev/null
@@ -1,91 +0,0 @@
-#ifndef INCLUDE_TRAINER_PM_SERVER_H_
-#define INCLUDE_TRAINER_PM_SERVER_H_
-
-#include <czmq.h>
-#include <memory>
-#include <vector>
-#include <map>
-#include <string.h>
-#include "proto/model.pb.h"
-#include "utils/updater.h"
-#include "utils/param.h"
-#include "communication/msg.h"
-#include "communication/socket.h"
-using std::vector;
-using std::string;
-using std::shared_ptr;
-
-namespace singa{
-
-/**
- * Parameter manager at the server side.
- *
- * Repsond to worker's get/put/udpate request, and periodically syncing with
- * other servers.
- *
- * Normally, the PMServer creates a response message for each request which
- * will be sent back to the one who issued the request. However, if the request
- * are not processed successfully, the original message will be returned. The
- * sever does not know the returned message (response or the original message),
- * it just sends it to the router. The router will decide to re-send the
- * request to the server or send it to the worker.
- *
- */
-class PMServer{
-public:
-  typedef std::map<int, shared_ptr<Param>> ParamShard;
-
-       void Setup(int group_id, int server_id, shared_ptr<ParamShard> shard,
-       const UpdaterProto& proto);
-
-       ~PMServer();
-
-       /**
-        * Process GET request.
-   *
-   * @return the orignal message or response message
-   */
-       virtual Msg* HandleGet(Msg** msg);
-
-       /**
-        * Process Update request.
-   *
-   * @return the orignal message or response message
-   */
-       virtual Msg* HandleUpdate(Msg** msg);
-
-       /**
-        * Process PUT request.
-   *
-   * @return the original message or response message. If we don't want need to
-   * acknowledge the put request, then return nullptr.
-        */
-       virtual Msg* HandlePut(Msg **msg);
-
-       /**
-   * TODO Process SYNC request.
-        */
-       virtual Msg* HandleSyncRequest(Msg** msg);
-
-       /**
-   * TODO Process SYNC response.
-        */
-       virtual int HandleSyncResponse(Msg** msg);
-
-  /**
-   * Scheduler for synchronizing server groups.
-   *
-   * TODO implement the Caffe's synchronization scheduler for data parallelism
-   */
-  virtual bool SyncNow();
-
- protected:
-  int group_id_, server_id_;
-  shared_ptr<ParamShard> shard_;
-  shared_ptr<Dealer> dealer_;
-  shared_ptr<Updater> updater_;
-};
-
-} // namespace singa
-
-#endif // INCLUDE_TRAINER_PM_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/pm_worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/pm_worker.h b/include/trainer/pm_worker.h
deleted file mode 100644
index 9b973d6..0000000
--- a/include/trainer/pm_worker.h
+++ /dev/null
@@ -1,172 +0,0 @@
-#ifndef INCLUDE_TRAINER_PM_WORKER_H_
-#define INCLUDE_TRAINER_PM_WORKER_H_
-
-#include <memory>
-#include <vector>
-#include <map>
-#include <string>
-#include <atomic>
-#include "utils/param.h"
-#include "communication/msg.h"
-
-using std::string;
-using std::vector;
-using std::shared_ptr;
-using std::map;
-
-namespace singa {
-
-/**
- * Counters used to construct a parameter shard.
- *
- * For each worker group:
- *   Every unique Param object is associated with a ParamCounter object whose
- *   param field points the to Param object itself.
- *
- *   Param objects sharing the same values (due to data parallelism) are
- *   associated with the same ParamCounter whose param field also shares the
- *   same values.
- *
- *   Usage: we need to aggregate gradients from all workers for the shared
- *   parameters before sending the update request. The nUpdate counter counts
- *   the number.
- *
- * TODO test with different physical architectures.
- */
-class ParamCounter{
-  public:
-  ParamCounter(shared_ptr<Param> p,int local, int owner):
-    nUpdate(0), nGet(0), nPut(0), nCollect(0), nLocal(local), nTotal(1),
-    owner_procs(owner){
-      shares.push_back(p);
-    }
-
-  /**
-   * Associate the counter to a Param object.
-   *
-   * @param p
-   * @param local 1 if this Param object is used by workers in this procs, 0
-   *  otherwise
-   * @param owner the procs id of the worker who ownes this Param object
-   */
-  void AddParam(shared_ptr<Param> p, int local, int owner){
-    nLocal+=local;
-    nTotal+=1;
-    if(owner>-1)
-      owner_procs=owner;
-    if(local>0){
-      shares.push_back(p);
-    }
-  }
-  std::atomic<int> nUpdate, nGet, nPut, nCollect; //!< all counters are atomic
-
-  int nLocal; //!< # local workers uses the shared parameter
-  int nTotal; //!< # total workers uses the shared parameter
-  int owner_procs; //!< the procs id of the worker that owns the parameter
-  vector<shared_ptr<Param>> shares;
-};
-
-/**
- * Parameter manager at the worker side.
- */
-class PMWorker{
-public:
-  /**
-   * Workers from the same group resident in the same process share the same
-   * ParamShard which contains ParamCounters for Param objects used/updated by
-   * these worekrs. Shared Param objects are associated with the same
-   * ParamCounter.
-   */
-  typedef std::map<int, shared_ptr<ParamCounter>> ParamShard;
-
-
-       void Setup(int group_id, int worker_id, shared_ptr<ParamShard> shard);
-
-  void set_id(int group_id, int worker_id){
-    group_id_=group_id;
-    worker_id_=worker_id;
-  }
-
-  /**
-   * @return server id where the parameter is maintained.
-   */
-  virtual int Sharding(int param_id);
-
-       /**
-        * Generate a request message to Get the parameter object.
-        */
-       virtual Msg* Get(shared_ptr<Param> param, int step);
-  virtual Msg* Get(Msg** msg);
-
-       /**
-        * Generate a request message to Update the parameter object.
-        */
-       virtual Msg* Update(shared_ptr<Param> param, int step);
-  virtual Msg* Update(Msg** msg);
-
-       /**
-        * Collect a Param object returned from server.
-        */
-       virtual Msg* Collect(Msg**);
-
-       /**
-        * Generate a request message to Put the parameter object.
-        */
-       virtual Msg* Put(shared_ptr<Param> param, int step);
-  virtual Msg* Put(Msg** msg);
-
- protected:
-  int group_id_, worker_id_;
-  shared_ptr<ParamShard> shard_;
-};
-
-/**
- * Testing worker functionality.The main thread reads the config file and set up the socket.
- *
- * Create the shared ParamShard, then starts worker thread which basically carries out the work.
- * Each thread creates a PMClient object.
- *
- * The main thread then enter the loops to forward messages.
- *
- * Requests from the worker thread is prepend the paramId, which is stripped by the main thread
- * before forwarding to the correct server.
- *
- * The 1st thread in Client 0 populates the servers with data (PUT request). Wait
- * for a while before starting the client thread (which does get/update
- * continuously).
-class SingaClient {
-public:
-       SingaClient(int worker_id, Topology &topology, vector<string> &hosts);
-       void StartClient();
-
-       int id() {
-               return id_;
-       }
-       ParamShard *param_shard() {
-               return param_shard_;
-       }
-       char *backend_endpoint() {
-               return backend_endpoint_;
-       }
-
-private:
-       int id_, local_id_, group_id_;
-       char backend_endpoint_[256];
-       vector<char*> neighbors_;
-       ParamShard *param_shard_;
-
-       int param_to_server_id(int paramId);//< mapping paramId to server ID
-};
-
-//Zthread function for the worker thread, in the global namespace.
-//Basically a loop of: compute, get, update, compute, etc.
-void ClientThread(void *args, zctx_t *ctx, void *pipe);
-
-vector<Param*> gen_random_params();
-void test_get(PMClient *client);
-void test_update(PMClient *client, vector<Param*> params);
-void test_collect(PMClient *client);
- */
-
-} // namespace singa
-#endif // INCLUDE_TRAINER_PM_WORKER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index 6ae09e4..f9fc80b 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -1,21 +1,77 @@
 #ifndef INCLUDE_TRAINER_SERVER_H_
 #define INCLUDE_TRAINER_SERVER_H_
 #include <memory>
-#include "trainer/pm_server.h"
+#include <utils/param.h>
+#include <utils/updater.h>
+#include "proto/model.pb.h"
 #include "communication/socket.h"
 
 using std::shared_ptr;
 namespace singa {
+/* Repsond to worker's get/put/udpate request, and periodically syncing with
+  * other servers.
+  *
+  * Normally, the Server creates a response message for each request which
+  * will be sent back to the one who issued the request. However, if the request
+  * are not processed successfully, the original message will be returned. The
+  * sever does not know the returned message (response or the original message),
+  * it just sends it to the router. The router will decide to re-send the
+  * request to the server or send it to the worker.
+  */
 class Server{
  public:
+  typedef std::map<int, shared_ptr<Param>> ParamShard;
+
   Server(int thread_id, int group_id, int server_id);
-  void Setup(const UpdaterProto& proto, shared_ptr<PMServer::ParamShard> 
shard);
+  void Setup(const UpdaterProto& proto, shared_ptr<ParamShard> shard);
   void Run();
 
  protected:
+
+       /**
+        * Process GET request.
+   *
+   * @return the orignal message or response message
+   */
+       virtual Msg* HandleGet(shared_ptr<Param> param, Msg** msg);
+
+       /**
+        * Process Update request.
+   *
+   * @return the orignal message or response message
+   */
+       virtual Msg* HandleUpdate(shared_ptr<Param> param, Msg** msg);
+
+       /**
+        * Process PUT request.
+   *
+   * @return the original message or response message. If we don't want need to
+   * acknowledge the put request, then return nullptr.
+        */
+       virtual Msg* HandlePut(shared_ptr<Param> param, Msg **msg);
+
+       /**
+   * TODO Process SYNC request.
+        */
+       virtual Msg* HandleSyncRequest(shared_ptr<Param> param, Msg** msg);
+
+       /**
+   * TODO Process SYNC response.
+        */
+       virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg);
+
+  /**
+   * Scheduler for synchronizing server groups.
+   *
+   * TODO implement the Caffe's synchronization scheduler for data parallelism
+   */
+  virtual bool SyncNow();
+
+ protected:
   int thread_id_,group_id_, server_id_;
-  shared_ptr<PMServer> pmserver_;
   shared_ptr<Dealer> dealer_;
+  shared_ptr<Updater> updater_;
+  shared_ptr<ParamShard> shard_;
 };
 } /* Server */
 #endif //INCLUDE_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 34d95f1..f5d2591 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -7,8 +7,6 @@
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "neuralnet/neuralnet.h"
-#include "trainer/pm_worker.h"
-#include "trainer/pm_server.h"
 #include "trainer/worker.h"
 #include "trainer/server.h"
 
@@ -19,7 +17,61 @@ namespace singa {
  *
  * The main thread runs a loop to forward messages between workers and servers.
  */
+
 class Trainer{
+/**
+ * ParamInfo is used to construct a parameter shard.
+ *
+ * For each worker group:
+ *   Every unique Param object is associated with a ParamCounter object whose
+ *   param field points the to Param object itself.
+ *
+ *   Param objects sharing the same values (due to data parallelism) are
+ *   associated with the same ParamCounter whose param field also shares the
+ *   same values.
+ *
+ *   Usage: we need to aggregate gradients from all workers for the shared
+ *   parameters before sending the update request. The nUpdate counter counts
+ *   the number.
+ *
+ * TODO test with different physical architectures.
+ */
+  public:
+  class ParamInfo{
+   public:
+    ParamInfo(shared_ptr<Param> p,int local, int owner):
+      num_update(0), next_version(0),num_local(local), num_total(1),
+      owner_procs(owner){
+        shares.push_back(p);
+      }
+
+    /**
+      * Associate the counter to a Param object.
+      *
+      * @param p
+      * @param local 1 if this Param object is used by workers in this procs, 0
+      *  otherwise
+      * @param owner the procs id of the worker who ownes this Param object
+      */
+    void AddParam(shared_ptr<Param> p, int local, int owner){
+      num_local+=local;
+      num_total+=1;
+      if(owner>-1)
+        owner_procs=owner;
+      if(local>0){
+        shares.push_back(p);
+      }
+    }
+    int num_update, next_version; //!< all counters are atomic
+
+    int num_local; //!< # local workers uses the shared parameter
+    int num_total; //!< # total workers uses the shared parameter
+    int owner_procs; //!< the procs id of the worker that owns the parameter
+    vector<shared_ptr<Param>> shares;
+  };
+
+ typedef std::map<int, shared_ptr<ParamInfo>> ParamShard;
+
  public:
   /**
    * Start the training in one process
@@ -34,7 +86,7 @@ class Trainer{
   // point.
 
  protected:
-  void Run();
+  void Run(const std::map<int, shared_ptr<ParamShard>>& shards);
   /**
    * Register default implementations for all base classes used in the system,
    * e.g., the Updater, BaseMsg, etc.
@@ -45,6 +97,39 @@ class Trainer{
    * implementation class as the value, e.g., <"Updater" SGDUpdater>.
    */
   void RegisterDefaultClasses(const singa::ModelProto& proto);
+
+  /**
+   * Workers from the same group resident in the same process share the same
+   * ParamShard which contains ParamCounters for Param objects used/updated by
+   * these worekrs. Shared Param objects are associated with the same
+   * ParamCounter.
+   */
+
+  /**
+   * @return server id where the parameter is maintained.
+   */
+  virtual int Sharding(int param_id);
+
+       /**
+        * Generate a request message to Get the parameter object.
+        */
+       virtual Msg* HandleGet(shared_ptr<ParamInfo>counter, Msg** msg);
+       virtual Msg* HandleGetResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+
+       /**
+        * Generate a request message to Update the parameter object.
+        */
+       virtual Msg* HandleUpdate(shared_ptr<ParamInfo>counter, Msg** msg);
+  virtual int HandleUpdateResponse(shared_ptr<ParamInfo>counter, Msg** msg);
+
+  /**
+        * Generate a request message to Put the parameter object.
+        */
+       virtual Msg* HandlePut(shared_ptr<ParamInfo>counter, Msg** msg);
+       virtual Msg* HandleConnect(Msg** msg);
+
+ protected:
+  int procs_id_;
 };
 } /* singa */
 #endif // INCLUDE_TRAINER_TRAINER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/trainer/worker.h
----------------------------------------------------------------------
diff --git a/include/trainer/worker.h b/include/trainer/worker.h
index 13f7798..afa56ae 100644
--- a/include/trainer/worker.h
+++ b/include/trainer/worker.h
@@ -4,35 +4,12 @@
 #include <exception>
 #include "neuralnet/neuralnet.h"
 #include "proto/model.pb.h"
-#include "trainer/pm_worker.h"
 #include "utils/cluster.h"
 #include "communication/socket.h"
 #include "communication/msg.h"
 
 namespace singa {
 /**
- * Collecting metrics, like accuracy, loss, etc.
- */
-class Performance{
- public:
-  /**
-   * Collect from LossLayer of net.
-   */
-  explicit Performance(shared_ptr<NeuralNet> net);
-  /**
-   * aggregate metrics from LossLayerS
-   */
-  void Update();
-  void Reset();
-  string ToString();
- private:
-  vector<string> name_;
-  shared_ptr<NeuralNet> net_;
-  vector<vector<float>> metric_;
-  int counter_; //!< inc by 1 for every Update
-};
-
-/**
  * The Worker class which runs the training algorithm.
  * The first worker group will initialize parameters of the Net,
  * and put them into the distributed memory/table.
@@ -41,8 +18,7 @@ class Worker {
  public:
   Worker(int thread_id, int group_id, int worker_id);
   ~Worker(){}
-  void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net,
-      shared_ptr<PMWorker::ParamShard> shard);
+  void Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net);
   void set_test_net(shared_ptr<NeuralNet> test_net){
     test_net_=test_net;
   }
@@ -61,7 +37,7 @@ class Worker {
     * Hence, no need to collect performance in every thread.
     * Only the main thread will pass none null perf.
     */
-  void RunOneBatch(int step, Performance* perf=nullptr);
+  void RunOneBatch(int step, Metric* perf=nullptr);
   /**
     * Train one mini-batch.
     * Test/Validation is done before training.
@@ -105,6 +81,7 @@ class Worker {
   const bool DisplayDebugInfo(const int step) const {
     return DisplayNow(step)&&modelproto_.debug()&&group_id_==0;
   }
+  const void DisplayPerformance(const Metric & perf, const string& prefix);
 
   /**
    * return true if the stop condition is satisfied, e.g., the maximum number
@@ -163,7 +140,6 @@ class Worker {
   int thread_id_,group_id_, worker_id_;
   int step_;
   ModelProto modelproto_;
-  shared_ptr<PMWorker> pmworker_;
   shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
   shared_ptr<Dealer> layer_dealer_, param_dealer_;
   Poller layer_poller_, param_poller_;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/include/utils/common.h
----------------------------------------------------------------------
diff --git a/include/utils/common.h b/include/utils/common.h
index 993c153..5a59127 100644
--- a/include/utils/common.h
+++ b/include/utils/common.h
@@ -47,5 +47,36 @@ inline float rand_real(){
   return  static_cast<float>(rand())/(RAND_MAX+1.0f);
 }
 
+class Metric{
+ public:
+  Metric():counter_(0){}
+  void AddMetric(string name, float value){
+    if(data_.find(name)==data_.end())
+      data_[name]=value;
+    else
+      data_[name]+=value;
+  }
+  void Reset(){
+    data_.clear();
+    counter_=0;
+  }
+  void Avg(){
+    for(auto& entry: data_)
+      entry.second/=counter_;
+  }
+  void Inc(){
+    counter_++;
+  }
+  const string ToString() const{
+    string disp="";
+    for(const auto& entry: data_){
+      disp+=entry.first+":"+std::to_string(entry.second)+"\t";
+    }
+    return disp;
+  }
+ private:
+  map<string, float> data_;
+  int counter_;
+};
 } /* singa */
 #endif  // INCLUDE_UTILS_COMMON_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/communication/socket.cc
----------------------------------------------------------------------
diff --git a/src/communication/socket.cc b/src/communication/socket.cc
index ef6174a..385950b 100644
--- a/src/communication/socket.cc
+++ b/src/communication/socket.cc
@@ -29,10 +29,11 @@ int Dealer::Connect(string endpoint){
     CHECK_EQ(zsock_connect(dealer_,"%s", endpoint.c_str()),0);
   return 1;
 }
-int Dealer::Send(Msg *msg){
-  zmsg_t* zmsg=(static_cast<Msg*>(msg))->DumpToZmsg();
+int Dealer::Send(Msg** msg){
+  zmsg_t* zmsg=(*msg)->DumpToZmsg();
   zmsg_send(&zmsg, dealer_);
-  delete msg;
+  delete *msg;
+  *msg=NULL;
   return 1;
 }
 
@@ -61,9 +62,9 @@ int Router::Bind(string endpoint){
   return 1;
 }
 
-int Router::Send(Msg *msg){
-  zmsg_t* zmsg=static_cast<Msg*>(msg)->DumpToZmsg();
-  int dstid=static_cast<Msg*>(msg)->dst();
+int Router::Send(Msg **msg){
+  zmsg_t* zmsg=(*msg)->DumpToZmsg();
+  int dstid=(*msg)->dst();
   if(id2addr_.find(dstid)!=id2addr_.end()){
     // the connection has already been set up
     zframe_t* addr=zframe_dup(id2addr_[dstid]);
@@ -77,7 +78,8 @@ int Router::Send(Msg *msg){
     nBufmsg_++;
     CHECK_LE(nBufmsg_, bufsize_);
   }
-  delete msg;
+  delete *msg;
+  *msg=NULL;
   return 1;
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/main.cc
----------------------------------------------------------------------
diff --git a/src/main.cc b/src/main.cc
index 89306d8..eaf88cc 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -20,6 +20,7 @@
 DEFINE_int32(procsID, 0, "Global process ID");
 DEFINE_string(cluster, "examples/mnist/cluster.conf", "Cluster config file");
 DEFINE_string(model, "examples/mnist/conv.conf", "Model config file");
+DEFINE_int32(sleep, 5, "sleep seconds");
 
 /**
  * Register layers, and other customizable classes.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/neuralnet/base_layer.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/base_layer.cc b/src/neuralnet/base_layer.cc
index e7702e9..d3ff24b 100644
--- a/src/neuralnet/base_layer.cc
+++ b/src/neuralnet/base_layer.cc
@@ -180,19 +180,20 @@ PrefetchLayer::~PrefetchLayer(){
 /************* Implementation for SliceLayer****************/
 void SliceLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){
-  int slice_dim=proto.slice_param().slice_dimension();
-  int slice_num=proto.slice_param().slice_num();
-  CHECK_GE(slice_dim,0);
-  CHECK_EQ(slice_num, dstlayers_.size());
+  slice_dim_=proto.slice_param().slice_dimension();
+  slice_num_=proto.slice_param().slice_num();
+  CHECK_GE(slice_dim_,0);
+  CHECK_EQ(slice_num_, dstlayers_.size());
   data_.Reshape(srclayers[0]->data(this).shape());
   grad_.ReshapeLike(data_);
-  datavec_.resize(slice_num);
-  gradvec_.resize(slice_num);
+  datavec_.resize(slice_num_);
+  gradvec_.resize(slice_num_);
+  CHECK_EQ(data_.count()%slice_num_, 0); // restrict equal slicing
   //LOG(ERROR)<<"slice dim "<<slice_dim<<" slice num "<<slice_num;
-  for(int i=0;i<slice_num;i++){
+  for(int i=0;i<slice_num_;i++){
     vector<int> newshape(data_.shape());
-    newshape[slice_dim]=newshape[slice_dim]/slice_num+
-      ((i==slice_num-1)?newshape[slice_dim]%slice_num:0);
+    newshape[slice_dim_]=newshape[slice_dim_]/slice_num_+
+      ((i==slice_num_-1)?newshape[slice_dim_]%slice_num_:0);
     datavec_[i].Reshape(newshape);
     gradvec_[i].Reshape(newshape);
     //LOG(ERROR)<<"slice "<<IntVecToString(newshape);
@@ -236,8 +237,22 @@ Blob<float>* SliceLayer::mutable_grad(const Layer* layer){
     return &grad_;
   return &gradvec_[SliceID(layer)];
 }
-void SliceLayer::ComputeFeature(bool training, const vector<shared_ptr<Layer>>& srclayers){}
-void SliceLayer::ComputeGradient(const vector<shared_ptr<Layer>>& srclayers){}
+void SliceLayer::ComputeFeature(bool training,
+    const vector<shared_ptr<Layer>>& srclayers){
+  CHECK_EQ(srclayers.size(),1);
+  if(slice_dim_==0){
+    const auto& blob=srclayers.at(0)->data(this);
+    int size=blob.count()/slice_num_;
+    for(int i=0;i<slice_num_;i++){
+      float* dst=datavec_[i].mutable_cpu_data();
+      const float* src=blob.cpu_data()+i*size;
+      memcpy(dst, src, size*sizeof(float));
+    }
+  }
+}
+void SliceLayer::ComputeGradient(const vector<shared_ptr<Layer>>& srclayers){
+
+}
 
 void SplitLayer::Setup(const LayerProto& proto,
     const vector<SLayer>& srclayers){

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/neuralnet/neuralnet.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuralnet.cc b/src/neuralnet/neuralnet.cc
index 4dac512..ca371f6 100644
--- a/src/neuralnet/neuralnet.cc
+++ b/src/neuralnet/neuralnet.cc
@@ -5,7 +5,7 @@
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "utils/graph.h"
-
+#include "utils/cluster.h"
 
 namespace singa {
 #define CreateLayer(id) CreateInstance(id, Layer)
@@ -61,8 +61,21 @@ NeuralNet::NeuralNet(NetProto net_proto, int group_size) {
 
   LOG(INFO)<<"Construct Neural Net...";
   ConstructNeuralNet(net_proto);
-  if(group_size_>1)
+  {
+    string vis_folder=Cluster::Get()->vis_folder();
+    std::ofstream fout(vis_folder+"/nopartition.json", std::ofstream::out);
+    fout<<ToString();
+    fout.flush();
+    fout.close();
+  }
+  if(group_size_>1){
     PartitionNeuralNet();
+    string vis_folder=Cluster::Get()->vis_folder();
+    std::ofstream fout(vis_folder+"/partition.json", std::ofstream::out);
+    fout<<ToString();
+    fout.flush();
+    fout.close();
+  }
   for(auto layer: layers_){
     DLOG(INFO)<<layer->name();
   }
@@ -88,7 +101,7 @@ void NeuralNet::ConstructNeuralNet(const NetProto& net_proto){
 
   // topology sort
   graph_.Sort();
-  //DLOG(INFO)<<"pure graph without partition\n"<< graph_.ToString();
+  //LOG(ERROR)<<"pure graph without partition\n"<< graph_.ToString();
 
   auto* factory=Singleton<Factory<Layer>>::Instance();
   // create Layers according to topology order

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/proto/model.pb.h
----------------------------------------------------------------------
diff --git a/src/proto/model.pb.h b/src/proto/model.pb.h
index 4f68462..567021a 100644
--- a/src/proto/model.pb.h
+++ b/src/proto/model.pb.h
@@ -567,6 +567,13 @@ class ModelProto : public ::google::protobuf::Message {
   inline bool debug() const;
   inline void set_debug(bool value);
 
+  // optional int32 warmup_steps = 50 [default = 0];
+  inline bool has_warmup_steps() const;
+  inline void clear_warmup_steps();
+  static const int kWarmupStepsFieldNumber = 50;
+  inline ::google::protobuf::int32 warmup_steps() const;
+  inline void set_warmup_steps(::google::protobuf::int32 value);
+
   // @@protoc_insertion_point(class_scope:singa.ModelProto)
  private:
   inline void set_has_name();
@@ -613,6 +620,8 @@ class ModelProto : public ::google::protobuf::Message {
   inline void clear_has_neuralnet();
   inline void set_has_debug();
   inline void clear_has_debug();
+  inline void set_has_warmup_steps();
+  inline void clear_has_warmup_steps();
 
   ::google::protobuf::UnknownFieldSet _unknown_fields_;
 
@@ -641,9 +650,10 @@ class ModelProto : public ::google::protobuf::Message {
   bool debug_;
   int alg_;
   ::singa::NetProto* neuralnet_;
+  ::google::protobuf::int32 warmup_steps_;
 
   mutable int _cached_size_;
-  ::google::protobuf::uint32 _has_bits_[(22 + 31) / 32];
+  ::google::protobuf::uint32 _has_bits_[(23 + 31) / 32];
 
   friend void  protobuf_AddDesc_model_2eproto();
   friend void protobuf_AssignDesc_model_2eproto();
@@ -3577,13 +3587,6 @@ class UpdaterProto : public ::google::protobuf::Message {
   inline ::google::protobuf::int32 sync_frequency() const;
   inline void set_sync_frequency(::google::protobuf::int32 value);
 
-  // optional int32 warmup_steps = 25 [default = 10];
-  inline bool has_warmup_steps() const;
-  inline void clear_warmup_steps();
-  static const int kWarmupStepsFieldNumber = 25;
-  inline ::google::protobuf::int32 warmup_steps() const;
-  inline void set_warmup_steps(::google::protobuf::int32 value);
-
   // optional float moving_rate = 26 [default = 0];
   inline bool has_moving_rate() const;
   inline void clear_moving_rate();
@@ -3651,8 +3654,6 @@ class UpdaterProto : public ::google::protobuf::Message {
   inline void clear_has_learning_rate_change_method();
   inline void set_has_sync_frequency();
   inline void clear_has_sync_frequency();
-  inline void set_has_warmup_steps();
-  inline void clear_has_warmup_steps();
   inline void set_has_moving_rate();
   inline void clear_has_moving_rate();
   inline void set_has_param_type();
@@ -3671,15 +3672,14 @@ class UpdaterProto : public ::google::protobuf::Message {
   ::google::protobuf::int32 learning_rate_change_frequency_;
   int learning_rate_change_method_;
   ::google::protobuf::int32 sync_frequency_;
-  ::google::protobuf::int32 warmup_steps_;
+  float moving_rate_;
   ::std::string* param_type_;
   static ::std::string* _default_param_type_;
   ::google::protobuf::RepeatedField< ::google::protobuf::int32 > step_;
   ::google::protobuf::RepeatedField< float > step_lr_;
-  float moving_rate_;
 
   mutable int _cached_size_;
-  ::google::protobuf::uint32 _has_bits_[(16 + 31) / 32];
+  ::google::protobuf::uint32 _has_bits_[(15 + 31) / 32];
 
   friend void  protobuf_AddDesc_model_2eproto();
   friend void protobuf_AssignDesc_model_2eproto();
@@ -4544,6 +4544,28 @@ inline void ModelProto::set_debug(bool value) {
   debug_ = value;
 }
 
+// optional int32 warmup_steps = 50 [default = 0];
+inline bool ModelProto::has_warmup_steps() const {
+  return (_has_bits_[0] & 0x00400000u) != 0;
+}
+inline void ModelProto::set_has_warmup_steps() {
+  _has_bits_[0] |= 0x00400000u;
+}
+inline void ModelProto::clear_has_warmup_steps() {
+  _has_bits_[0] &= ~0x00400000u;
+}
+inline void ModelProto::clear_warmup_steps() {
+  warmup_steps_ = 0;
+  clear_has_warmup_steps();
+}
+inline ::google::protobuf::int32 ModelProto::warmup_steps() const {
+  return warmup_steps_;
+}
+inline void ModelProto::set_warmup_steps(::google::protobuf::int32 value) {
+  set_has_warmup_steps();
+  warmup_steps_ = value;
+}
+
 // -------------------------------------------------------------------
 
 // NetProto
@@ -7917,37 +7939,15 @@ inline void UpdaterProto::set_sync_frequency(::google::protobuf::int32 value) {
   sync_frequency_ = value;
 }
 
-// optional int32 warmup_steps = 25 [default = 10];
-inline bool UpdaterProto::has_warmup_steps() const {
-  return (_has_bits_[0] & 0x00000800u) != 0;
-}
-inline void UpdaterProto::set_has_warmup_steps() {
-  _has_bits_[0] |= 0x00000800u;
-}
-inline void UpdaterProto::clear_has_warmup_steps() {
-  _has_bits_[0] &= ~0x00000800u;
-}
-inline void UpdaterProto::clear_warmup_steps() {
-  warmup_steps_ = 10;
-  clear_has_warmup_steps();
-}
-inline ::google::protobuf::int32 UpdaterProto::warmup_steps() const {
-  return warmup_steps_;
-}
-inline void UpdaterProto::set_warmup_steps(::google::protobuf::int32 value) {
-  set_has_warmup_steps();
-  warmup_steps_ = value;
-}
-
 // optional float moving_rate = 26 [default = 0];
 inline bool UpdaterProto::has_moving_rate() const {
-  return (_has_bits_[0] & 0x00001000u) != 0;
+  return (_has_bits_[0] & 0x00000800u) != 0;
 }
 inline void UpdaterProto::set_has_moving_rate() {
-  _has_bits_[0] |= 0x00001000u;
+  _has_bits_[0] |= 0x00000800u;
 }
 inline void UpdaterProto::clear_has_moving_rate() {
-  _has_bits_[0] &= ~0x00001000u;
+  _has_bits_[0] &= ~0x00000800u;
 }
 inline void UpdaterProto::clear_moving_rate() {
   moving_rate_ = 0;
@@ -7963,13 +7963,13 @@ inline void UpdaterProto::set_moving_rate(float value) {
 
 // optional string param_type = 27 [default = "Param"];
 inline bool UpdaterProto::has_param_type() const {
-  return (_has_bits_[0] & 0x00002000u) != 0;
+  return (_has_bits_[0] & 0x00001000u) != 0;
 }
 inline void UpdaterProto::set_has_param_type() {
-  _has_bits_[0] |= 0x00002000u;
+  _has_bits_[0] |= 0x00001000u;
 }
 inline void UpdaterProto::clear_has_param_type() {
-  _has_bits_[0] &= ~0x00002000u;
+  _has_bits_[0] &= ~0x00001000u;
 }
 inline void UpdaterProto::clear_param_type() {
   if (param_type_ != _default_param_type_) {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index 950bc2e..19727a9 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -78,6 +78,7 @@ message ModelProto{
   optional bool hogwild=33 [default=false];
   optional NetProto neuralnet = 40;
   optional bool debug=41 [default=false];
+  optional int32 warmup_steps=50 [default=0];
 }
 
 message NetProto{
@@ -366,7 +367,6 @@ message UpdaterProto {
   optional ChangeProto learning_rate_change_method = 16 [default = kFixed];
   optional int32 sync_frequency=17 [default=1];
   // warmup the parameters and then send to parameter servers.
-  optional int32 warmup_steps=25 [default=10];
   optional float moving_rate=26 [default=0];
   optional string param_type=27[default="Param"];
   repeated int32 step=28;

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/pm_server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_server.cc b/src/trainer/pm_server.cc
deleted file mode 100644
index 28fa28d..0000000
--- a/src/trainer/pm_server.cc
+++ /dev/null
@@ -1,99 +0,0 @@
-#include <gflags/gflags.h>
-#include <glog/logging.h>
-#include "trainer/pm_server.h"
-#include "utils/singleton.h"
-#include "utils/factory.h"
-#include <vector>
-
-using std::vector;
-
-namespace singa{
-void PMServer::Setup(int group_id, int server_id, shared_ptr<ParamShard> shard,
-      const UpdaterProto& proto){
-  group_id_=group_id;
-  server_id_=server_id;
-  shard_=shard;
-  updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
-      ->Create("Updater"));
-  updater_->Init(proto);
-}
-
-PMServer::~PMServer(){
-}
-
-bool PMServer::SyncNow(){
-  return false;
-}
-Msg* PMServer::HandlePut(Msg **msg){
-  int id=(*msg)->target();
-  shared_ptr<Param> param=nullptr;
-  if(shard_->find(id)!=shard_->end()){
-    LOG(ERROR)<<"Param ("<<id<<") is put more than once";
-    param=shard_->at(id);
-  }else{
-    param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance()
-        ->Create("Param"));
-    param->set_id(id);
-    (*shard_)[id]=param;
-  }
-  return param->HandlePutMsg(msg);
-}
-
-Msg* PMServer::HandleGet(Msg **msg){
-  int id=(*msg)->target();
-  shared_ptr<Param> param=nullptr;
-  if(shard_->find(id)!=shard_->end()){
-    param=shard_->at(id);
-    return param->HandleGetMsg(msg);
-       } else {
-               //re-construct msg to be re-queued.
-               //the calling function will send this message off
-    return *msg;
-       }
-}
-
-Msg* PMServer::HandleUpdate(Msg **msg) {
-  int id=(*msg)->target();
-  shared_ptr<Param> param=nullptr;
-  if(shard_->find(id)!=shard_->end()){
-               //repsonse of the format: <identity><type: kData><paramId><param content>
-    param=shard_->at(id);
-    Msg* tmp=static_cast<Msg*>((*msg)->CopyAddr());
-    param->ParseUpdateMsg(msg);
-    updater_->Update(param->version(), param);
-    param->set_version(param->version()+1);
-    auto response=param->GenUpdateResponseMsg();
-    tmp->SwapAddr();
-    response->SetAddr(tmp);
-    delete tmp;
-    return response;
-       } else {
-    LOG(ERROR)<<"Param ("<<id<<") is not maintained by server ("<<group_id_
-      <<", "<<server_id_<<")";
-               //re-construct msg to be re-queued.
-               return *msg;
-       }
-}
-
-Msg* PMServer::HandleSyncRequest(Msg **msg){
-  int id=(*msg)->target();
-  shared_ptr<Param> param=nullptr;
-  if(shard_->find(id)!=shard_->end()){
-               //repsonse of the format: <identity><type: kData><paramId><param content>
-    param=shard_->at(id);
-    return param->HandleSyncMsg(msg);
-       } else {
-               //re-construct msg to be re-queued.
-    return *msg;
-       }
-}
-
-int PMServer::HandleSyncResponse(Msg **msg){
-  int id=(*msg)->target();
-  CHECK(shard_->find(id)!=shard_->end());
-  return shard_->at(id)->ParseSyncResponseMsg(msg);
-}
-
-} // namespace singa
-
-

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/pm_worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc
deleted file mode 100644
index d2531e0..0000000
--- a/src/trainer/pm_worker.cc
+++ /dev/null
@@ -1,324 +0,0 @@
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <fcntl.h>
-#include "gflags/gflags.h"
-#include <glog/logging.h>
-#include "proto/model.pb.h"
-#include "trainer/pm_worker.h"
-#include "mshadow/tensor.h"
-#include "utils/cluster.h"
-
-
-namespace singa{
-
-void PMWorker::Setup(int group_id, int worker_id,
-    shared_ptr<ParamShard> shard){
-  group_id_=group_id;
-  worker_id_=worker_id;
-  shard_=shard;
-}
-int PMWorker::Sharding(int param_id){
-  return param_id%Cluster::Get()->nservers_per_group();
-}
-/*
-int PMWorker::Sharding(int param_id){
-  static map<int, int> id2procs;
-  if(id2procs.find(param_id)==id2procs.end()){
-  auto cluster=Cluster::Get();
-  int server_group=group_id_%cluster->nserver_groups();
-  int nprocs_per_server_group=
-    cluster->nservers_per_group()/cluster->nservers_per_procs();
-  int procsid=server_group*nprocs_per_server_group+
-    param_id%nprocs_per_server_group;
-  procsid= cluster->server_worker_separate()?
-    cluster->nworker_procs()+procsid:procsid;
-  id2procs[param_id]=procsid;
-  }
-  return id2procs[param_id];
-}
-*/
-
-Msg* PMWorker::Put(Msg** msg){
-  return *msg;
-}
-
-Msg* PMWorker::Put(shared_ptr<Param> param, int step){
-  int id=param->owner();
-  auto entry=shard_->at(id);
-  Msg* msg= param->GenPutMsg(&step);
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-      Sharding(id), kServer);
-  msg->set_type(kPut);
-  msg->set_target(id);
-  return msg;
-}
-
-Msg* PMWorker::Get(Msg** msg){
-  return *msg;
-}
-
-Msg* PMWorker::Get(shared_ptr<Param> param, int step){
-  int id=param->owner();
-  shared_ptr<ParamCounter> entry=shard_->at(id);
-  Msg *msg=nullptr;
-  if((entry->nGet+1)%entry->nLocal==0&&param->version()<step){
-    msg=param->GenGetMsg(&step);
-    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-        Sharding(id), kServer);
-    msg->set_src(group_id_, worker_id_, kWorkerParam);
-    msg->set_type(kGet);
-    msg->set_target(id);
-  }
-  entry->nGet++;
-  return msg;
-}
-
-Msg* PMWorker::Update(Msg** msg){
-  return *msg;
-}
-Msg* PMWorker::Update(shared_ptr<Param> param, int step){
-  int id=param->owner();
-  shared_ptr<ParamCounter> entry=shard_->at(id);
-  Msg* msg=nullptr;
-  if((entry->nUpdate+1)%entry->nLocal==0){
-    auto shape=mshadow::Shape1(param->size());
-    auto it=entry->shares.begin();
-    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
-    for(++it;it!=entry->shares.end();it++){
-      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
-      agg+=grad/entry->nTotal;
-    }
-    msg=entry->shares.at(0)->GenUpdateMsg(&step);
-    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
-        Sharding(id), kServer);
-    /*
-       entry->param->GenUpdateMsg(&step);
-       msg->set_dst(entry->owner_procs,kStub);
-       memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size());
-       */
-    msg->set_type(kUpdate);
-    msg->set_target(id);
-    msg->set_src(group_id_, worker_id_, kWorkerParam);
-  }
-  entry->nUpdate++;
-  return msg;
-}
-
-Msg* PMWorker::Collect(Msg** msg){
-  int id=(*msg)->target();
-  int type=(*msg)->type();
-  auto pp=shard_->at(id)->shares.at(0);
-  if(type==kRGet){
-    pp->ParseGetResponseMsg(msg);
-  }else if(type==kRUpdate){
-    pp->ParseUpdateResponseMsg(msg);
-  }
-  if(pp->owner()!=pp->id()){
-    // forwarding to workers on other procs
-  }
-  delete (*msg);
-  *msg=nullptr;
-  return nullptr;
-}
-
-/*
-//id is the global worker id
-SingaClient::SingaClient(int global_id, Topology &topology, vector<string> &hosts) {
-       //Read the config files and store endpoints
-       id_ = global_id;
-
-       int n_workers = hosts.size() - topology.nservers();
-       int n_worker_groups = topology.nworker_groups();
-       int group_size = n_workers/n_worker_groups;
-       int server_group_size = topology.nservers()/topology.server_group_size();
-       FLAGS_client_threads = topology.worker_threads();
-
-       local_id_ = (id_-topology.nservers())%group_size;//local worker id.
-       group_id_ = (id_-topology.nservers())/group_size;
-
-       VLOG(3) << "Parsing client config for "<<hosts[id_];
-
-       //connect to all server in the server group group_id_
-       int start_server_idx = group_id_*server_group_size;
-       int end_server_idx = start_server_idx+server_group_size;
-
-       for (int i = start_server_idx; i < end_server_idx; i++) {
-               char *neighbor_endpoint = (char*) malloc(256);
-               sprintf(neighbor_endpoint, "tcp://%s:%d", hosts[i].c_str(), topology.port());
-               neighbors_.push_back(neighbor_endpoint);
-               VLOG(3) << "Worker neighbor (server): "<<neighbor_endpoint;
-       }
-
-       sprintf(backend_endpoint_, "inproc://singanus%d",id_);
-
-       //Create shared paramshard
-       param_shard_ = new ParamShard(id_,0);
-}
-
-void SingaClient::StartClient(){
-       //Create and connect sockets to the server
-       vector<void *> server_sockets;
-       zctx_t *context = zctx_new();
-       int nservers = neighbors_.size();
-       int rc;
-       for (int i=0; i<nservers; i++){
-               void *socket = zsocket_new(context, ZMQ_DEALER);
-               rc = zsocket_connect(socket, neighbors_[i]);
-               VLOG(3) << "Connected to neighbor " <<neighbors_[i];
-               assert(rc==0);
-               server_sockets.push_back(socket);
-       }
-
-       //Create and bind backend socket
-       void *backend = zsocket_new(context, ZMQ_ROUTER);
-       rc = zsocket_bind(backend, backend_endpoint_);
-       assert(rc==0);
-
-       //Start client threads
-       for (int i=0; i<FLAGS_client_threads; i++){
-               void * socket = zthread_fork(context, ClientThread, this);
-               zmsg_t *control_msg = zmsg_new();
-               if (i==0 && local_id_==0)
-                       zmsg_pushstr(control_msg,POPULATE);
-               else
-                       zmsg_pushstr(control_msg, WAIT);
-               zmsg_send(&control_msg, socket);
-       }
-
-       //Star the message loop
-       bool is_running = true;
-       int nsockets= nservers+1;
-       while (is_running) {
-               zmq_pollitem_t items[nsockets];
-               for (int i = 0; i < nsockets-1; i++)
-                       items[i] = {server_sockets[i], 0, ZMQ_POLLIN, 0};
-               items[nsockets-1] = {backend, 0, ZMQ_POLLIN, 0};
-
-               int rc = zmq_poll(items,nsockets,-1);
-               if (rc<0) break;
-
-               for (int i=0; i<nsockets-1; i++){
-                       if (items[i].revents & ZMQ_POLLIN){
-                               zmsg_t *msg = zmsg_recv(server_sockets[i]);
-                               if (!msg){
-                                       is_running = false;
-                                       break;
-                               }
-                               //forward to backend
-                               zmsg_send(&msg, backend);
-                       }
-               }
-               if (items[nsockets-1].revents & ZMQ_POLLIN){
-                       //compute serverId from paramId and forward to the socket
-                       zmsg_t *msg = zmsg_recv(backend);
-                       if (!msg) is_running=false;
-                       zframe_t *identity = zmsg_pop(msg);
-                       zframe_t *type = zmsg_pop(msg);
-                       int paramId;
-                       sscanf(zmsg_popstr(msg), "%d", &paramId);
-                       zmsg_pushstrf(msg,"%d",paramId);
-                       zmsg_prepend(msg,&type);
-                       zmsg_prepend(msg,&identity);
-                       zmsg_send(&msg, server_sockets[param_to_server_id(paramId)]);
-               }
-       }
-
-       zsocket_destroy(context, backend);
-       for (int i=0; i<nsockets-1; i++)
-               zsocket_destroy(context, server_sockets[i]);
-       zctx_destroy(&context);
-}
-
-vector<Param*> gen_random_params() {
-       int size[] = { 1960000, 2500, 5000000, 2000, 3000000, 1500, 1500000, 1000, 500000, 500, 5000, 10 };
-       vector<Param*> params;
-       for (int i = 0; i < 12; i++) {
-               ParamProto proto;
-               proto.set_id(i);
-               proto.set_init_method(ParamProto::kGaussain);
-               Param* p = new Param();
-               p->Setup(proto, vector<int> { size[i] }, 0);
-               p->Init();
-               params.push_back(p);
-       }
-       return params;
-}
-
-//simple mapping
-int SingaClient::param_to_server_id(int paramId){
-       return paramId % neighbors_.size();
-}
-
-void ClientThread(void *args, zctx_t *ctx, void *pipe){
-       SingaClient *client = static_cast<SingaClient*>(args);
-
-       //Create back-end socket and connect to the main thread
-       void *backend = zsocket_new(ctx, ZMQ_DEALER);
-       int rc = zsocket_connect(backend, client->backend_endpoint());
-       assert(rc==0);
-       //Create PMClient object
-       PMClient *pmclient = new PMClient(client->id(), client->param_shard(), backend);
-
-       //FOR TESTING ONLY. REMOVE THIS!
-       //wait for control from main thread
-       vector<Param*> params = gen_random_params();
-       zmsg_t *control_msg = zmsg_recv(pipe);
-       zframe_t *msg = zmsg_pop(control_msg);
-       if (zframe_streq(msg,WAIT))
-               zclock_sleep(2000); //2s
-       else{
-               for (int i=0; i<params.size(); i++){
-                       pmclient->Put(i, params[i]);
-               }
-               VLOG(3)<<"Done PUT requests for populating servers.";
-               zclock_sleep(2000);
-       }
-       zframe_destroy(&msg);
-       //END TESTING
-       LOG(ERROR) << "Done putting";
-
-       //first, get the params
-
-       test_get(pmclient);
-       test_collect(pmclient);
-
-
-       int iterations = 1;
-       while (iterations<=200){
-               VLOG(3) << "Iteration "<<iterations;
-               test_update(pmclient, params);
-               test_collect(pmclient);
-               iterations++;
-       }
-
-       zsocket_destroy(ctx, backend);
-}
-
-void test_get(PMClient *client){
-       for (int i=0; i<12; i++){
-               Param pm;
-               int status = client->Get(i, &pm);
-               assert(status==NON_LOCAL);
-       }
-}
-
-void test_collect(PMClient *client){
-       for (int i=0; i<12; i++){
-               Param pm;
-               int64_t start_time = zclock_time();
-               while (!client->Collect(&pm))
-                       zclock_sleep(1);
-               int64_t end_time = zclock_time();
-               VLOG(3) << "Collected: " <<(end_time-start_time);
-       }
-}
-
-void test_update(PMClient *client, vector<Param*> params){
-       for (int i=0; i<params.size(); i++)
-               client->Update(i, params[i]);
-}
-*/
-
-
-} //namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 36c1302..cd2bc02 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -13,11 +13,12 @@ Server::Server(int thread_id,int group_id, int server_id):
   thread_id_(thread_id),group_id_(group_id), server_id_(server_id){}
 
 void Server::Setup(const UpdaterProto& proto,
-    shared_ptr<PMServer::ParamShard> shard){
+    shared_ptr<Server::ParamShard> shard){
        //VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_;
-  pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance()
-      ->Create("PMServer"));
-  pmserver_->Setup(group_id_, server_id_, shard, proto);
+  updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
+      ->Create("Updater"));
+  updater_->Init(proto);
+  shard_=shard;
 }
 
 void Server::Run(){
@@ -26,12 +27,10 @@ void Server::Run(){
 
   Msg* ping=new Msg();
   ping->set_src(group_id_, server_id_, kServer);
-  ping->set_dst(0,0,kStub);
+  ping->set_dst(-1,-1,kStub);
   ping->add_frame("PING", 4);
   ping->set_type(kConnect);
-  dealer_->Send(ping);
-  Poller poller;
-  poller.Add(dealer_.get());
+  dealer_->Send(&ping);
        //start recv loop and process requests
   while (true){
     Msg* msg=dealer_->Receive();
@@ -39,39 +38,89 @@ void Server::Run(){
       break;
     Msg* response=nullptr;
     int type=msg->type();
-    switch (type){
-      case kConnect:{
-        string pong((char*)msg->frame_data(), msg->frame_size());
-        CHECK_STREQ("PONG", pong.c_str());
-        delete msg;
-        break;
-                    }
-      case kPut:
-        response = pmserver_->HandlePut(&msg);
-        break;
-      case kGet:
-        response = pmserver_->HandleGet(&msg);
-        break;
-      case kUpdate:
-        response = pmserver_->HandleUpdate(&msg);
-        break;
-      case kSyncRequest:
-        VLOG(3)<<"Handle SYNC-REQUEST";
-        response = pmserver_->HandleSyncRequest(&msg);
-        break;
-      case kSyncResponse:
-        VLOG(3) << "Handle SYNC response";
-        pmserver_->HandleSyncResponse(&msg);
-        break;
-    }
-
-    if (response!=nullptr){
-      //LOG(ERROR)<<"type: "<<type<<" response to "<<response->dst_id();
-      dealer_->Send(response);
+    if (type==kConnect){
+      // TODO remove receiving pong msg
+      string pong((char*)msg->frame_data(), msg->frame_size());
+      CHECK_STREQ("PONG", pong.c_str());
+      delete msg;
+    }else if(type==kPut){
+      int pid=msg->target_first();
+      shared_ptr<Param> param=nullptr;
+      if(shard_->find(pid)!=shard_->end()){
+        LOG(ERROR)<<"Param ("<<pid<<") is put more than once";
+        param=shard_->at(pid);
+      }else{
+        param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance()
+            ->Create("Param"));
+        param->set_id(pid);
+        (*shard_)[pid]=param;
+      }
+      response = HandlePut(param, &msg);
+    }else{
+      int pid=msg->target_first();
+      if(shard_->find(pid)==shard_->end()){
+        // delay the processing by re-queue the msg.
+        response=msg;
+      } else{
+        CHECK(shard_->find(pid)!=shard_->end()) <<"Param ("<<pid
+          <<") is not maintained by server ("
+          <<group_id_ <<", " <<server_id_<<")";
+        auto param=shard_->at(pid);
+        switch (type){
+          case kGet:
+            response=HandleGet(param, &msg);
+            break;
+          case kUpdate:
+            response = HandleUpdate(param, &msg);
+            break;
+          case kSyncRequest:
+            VLOG(3)<<"Handle SYNC-REQUEST";
+            response = HandleSyncRequest(param, &msg);
+            break;
+          case kSyncResponse:
+            VLOG(3) << "Handle SYNC response";
+            HandleSyncResponse(param, &msg);
+            break;
+        }
+        if (response!=nullptr){
+          dealer_->Send(&response);
+        }
+      }
     }
   }
 }
 
+bool Server::SyncNow(){
+  return false;
+}
+Msg* Server::HandlePut(shared_ptr<Param> param, Msg **msg){
+  return param->HandlePutMsg(msg);
+}
+
+Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){
+  return param->HandleGetMsg(msg);
+}
 
+Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) {
+  //repsonse of the format: <identity><type: kData><paramId><param content>
+  auto* tmp=static_cast<Msg*>((*msg)->CopyAddr());
+  int v=(*msg)->target_second()+1;
+  param->ParseUpdateMsg(msg);
+  updater_->Update(param->version(), param);
+  param->set_version(param->version()+1);
+  auto response=param->GenUpdateResponseMsg(&v);
+  tmp->SwapAddr();
+  response->SetAddr(tmp);
+  delete tmp;
+  return response;
+}
+
+Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){
+  return param->HandleSyncMsg(msg);
+}
+
+int Server::HandleSyncResponse(shared_ptr<Param> param, Msg **msg){
+  return param->ParseSyncResponseMsg(msg);
+}
 
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 4ac51ce..35b8f6c 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -3,6 +3,7 @@
 #include <map>
 #include <glog/logging.h>
 #include "trainer/trainer.h"
+#include "mshadow/tensor.h"
 using std::vector;
 using std::map;
 
@@ -33,16 +34,11 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
       "Param", CreateInstance(singa::Param, singa::Param));
   Singleton<Factory<singa::Updater>>::Instance() ->Register(
       "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
-  Singleton<Factory<singa::PMWorker>>::Instance() ->Register(
-      "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker));
-  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
-      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
-  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
-      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
 }
 
 void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
     int procs_id){
+  procs_id_=procs_id;
   RegisterDefaultClasses(mproto);
 
   auto cluster=Cluster::Get(cproto, procs_id);
@@ -57,7 +53,7 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
     int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
     int end=start+cluster->nservers_per_group();
     // the ParamShard for servers consists of a dictionary of Param objects
-    auto shard=make_shared<PMServer::ParamShard>();
+    auto shard=make_shared<Server::ParamShard>();
     for(int sid=start;sid<end;sid++){
       auto server=make_shared<Server>(nthreads++,gid, sid);
       server->Setup(mproto.updater(), shard);
@@ -67,6 +63,7 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
 
   // create workers
   vector<shared_ptr<Worker>> workers;
+  std::map<int, shared_ptr<Trainer::ParamShard>> shards;
   if(cluster->has_worker()){
     auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain,
         cluster->nworkers_per_group());
@@ -117,14 +114,15 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
         }
       }
       // create ParamShard for the workers
-      auto shard=make_shared<PMWorker::ParamShard>();
+      auto shard=make_shared<Trainer::ParamShard>();
+      shards[gid]=shard;
       for(auto layer: train_net->layers()){
         int procsid=ProcsIDOf(gid, layer->partitionid(),kWorkerParam);
         int local=procsid==cluster->procs_id();
         for(auto param: layer->GetParams()){
           int owner=param->owner()<0||param->owner()==param->id()?procsid:-1;
           if(shard->find(param->owner())==shard->end())
-            (*shard)[param->owner()]=make_shared<ParamCounter>(param, local, 
owner);
+            (*shard)[param->owner()]=make_shared<ParamInfo>(param, local, 
owner);
           else
             shard->at(param->owner())->AddParam(param, local, owner);
         }
@@ -136,7 +134,7 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
         else{
         // TODO add CDWorker
         }
-        worker->Setup(mproto, train_net, shard);
+        worker->Setup(mproto, train_net);
         worker->set_test_net(test_net);
         worker->set_validation_net(validation_net);
         workers.push_back(worker);
@@ -154,13 +152,14 @@ void Trainer::Start(const ModelProto& mproto, const 
ClusterProto& cproto,
     threads.push_back(std::thread(&Server::Run,server.get()));
   for(auto worker: workers)
     threads.push_back(std::thread(&Worker::Run,worker.get()));
-  Run();
+  Run(shards);
   for(auto& thread: threads)
     thread.join();
 }
 
-void Trainer::Run(){
+void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& 
shards){
   auto cluster=Cluster::Get();
+  procs_id_=cluster->procs_id();
   auto router=make_shared<Router>();
   router->Bind(kInprocRouterEndpoint);
   if(cluster->nprocs()>1)
@@ -173,38 +172,179 @@ void Trainer::Run(){
       LOG(ERROR)<<"Connection broken!";
       exit(0);
     }
-    int dst_flag=msg->dst_flag();
-    int type=msg->type();
-    int group_id, id, procs_id;
-    switch (dst_flag){ // TODO process other requests, e.g. RESTful
-      case kStub:
+    while(msg!=nullptr){
+      int dst_flag=msg->dst_flag();
+      int type=msg->type();
+      int dst_procs=msg->dst_first();
+      if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
         if(type==kConnect){
-          string ping((char*)msg->frame_data(), msg->frame_size());
-          CHECK_STREQ("PING", ping.c_str());
-          msg->SwapAddr();
-          Msg* reply=new Msg();
-          reply->SetAddr(msg);
-          reply->add_frame("PONG", 4);
-          reply->set_type(kConnect);
-          delete msg;
-          router->Send(reply);
+          msg =HandleConnect(&msg);
         }else{
-          // TODO processing requests for worker group spanning multiple procs.
-          LOG(ERROR)<<"Unkown message type ("<<type<<") to stub";
+          int group_id=msg->src_first();
+          int paramid=msg->target_first();
+          auto entry=shards.at(group_id)->at(paramid);
+          switch (type){ // TODO process other requests, e.g. RESTful
+            case kUpdate:
+              msg=HandleUpdate(entry, &msg);
+              break;
+            case kRUpdate:
+              HandleUpdateResponse(entry, &msg);
+              break;
+            case kGet:
+              msg=HandleGet(entry, &msg);
+              break;
+            case kRGet:
+              msg=HandleGetResponse(entry, &msg);
+              break;
+            case kPut:
+              msg=HandlePut(entry, &msg);
+              break;
+            default:
+              break;
+          }
         }
-        break;
-      default:
-        group_id=msg->dst_group_id();
-        id=msg->dst_id();
-        procs_id=ProcsIDOf(group_id, id, dst_flag);
-        if(procs_id!=cluster->procs_id()){
-          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
-            interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
-          interprocs_dealers[procs_id]->Send(msg);
-        } else
-          router->Send(msg);
-        break;
+      }else{
+        int dst_procs_id;
+        if(dst_flag==kStub){
+          dst_procs_id=msg->dst_first();
+        }else{
+          dst_procs_id=ProcsIDOf(msg->dst_first(), msg->dst_second(), 
msg->dst_flag());
+        }
+        if(dst_procs_id!=procs_id_){
+          /*
+             // forward to other procs
+             if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
+             interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
+             interprocs_dealers[procs_id]->Send(&msg);
+             */
+        }else{
+          router->Send(&msg);
+        }
+      }
     }
   }
 }
+Msg* Trainer::HandleConnect(Msg** msg){
+  string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
+  CHECK_STREQ("PING", ping.c_str());
+  // ping-pong for debug
+  (*msg)->SwapAddr();
+  Msg* reply=new Msg();
+  reply->SetAddr(*msg);
+  reply->add_frame("PONG", 4);
+  reply->set_type(kConnect);
+  delete *msg;
+  *msg=NULL;
+  return reply;
+}
+int Trainer::Sharding(int param_id){
+  return param_id%Cluster::Get()->nservers_per_group();
+}
+/*
+int Worker::Sharding(int param_id){
+  static map<int, int> id2procs;
+  if(id2procs.find(param_id)==id2procs.end()){
+  auto cluster=Cluster::Get();
+  int server_group=group_id_%cluster->nserver_groups();
+  int nprocs_per_server_group=
+    cluster->nservers_per_group()/cluster->nservers_per_procs();
+  int procsid=server_group*nprocs_per_server_group+
+    param_id%nprocs_per_server_group;
+  procsid= cluster->server_worker_separate()?
+    cluster->nworker_procs()+procsid:procsid;
+  id2procs[param_id]=procsid;
+  }
+  return id2procs[param_id];
+}
+*/
+
+
+Msg* Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
+  Msg* msgg=*msg, *reply=nullptr;
+  int version=msgg->target_second();
+  if(msgg->src_flag()==kStub){
+    if(version<=pi->shares.at(0)->version()){
+      reply=pi->shares.at(0)->HandleGetMsg(msg);
+    }else if(version>pi->next_version){
+      // reinsert into a msg queue.
+    }
+  }else if(version>pi->next_version){
+    pi->next_version=version;
+    reply=pi->shares.at(0)->GenGetMsg(&version);
+    int gid=msgg->src_first(), pid=msgg->target_first();
+    reply->set_src(procs_id_, gid, kStub);
+    reply->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(pid), kServer);
+  }
+  return reply;
+}
+
+Msg* Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){
+  pi->shares.at(0)->ParseGetResponseMsg(msg);
+  return nullptr;
+  // process get requests in waiting queue
+}
+
+Msg* Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){
+  Msg* msgg=*msg, *update=nullptr;
+  if(msgg->src_flag()==kStub){
+    if(pi->num_update<pi->num_local)
+      return *msg; //wait until local updates are ready
+    int n;
+    sscanf((char*)(*msg)->frame_data(), "%d", &n);
+    pi->num_update+=n;
+    auto it=pi->shares.begin();
+    auto shape=mshadow::Shape1((*it)->size());
+    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
+    mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+    agg+=grad;
+  }else if(++pi->num_update>=pi->num_local){
+    auto it=pi->shares.begin();
+    auto shape=mshadow::Shape1((*it)->size());
+    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
+    for(++it;it!=pi->shares.end();it++){
+      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+      agg+=grad;
+    }
+    agg/=pi->num_total;
+    if(pi->num_local<pi->num_total){
+      int v=msgg->target_second();
+      update=pi->shares.at(0)->GenUpdateMsg(&v);
+      int gid=msgg->src_first();
+      update->set_src(procs_id_, gid,kStub);
+      update->set_dst(pi->owner_procs, gid, kStub);
+      pi->num_update=0;
+    }
+  }
+  if(pi->num_update==pi->num_total){
+    int v=msgg->target_second();
+    update=pi->shares.at(0)->GenUpdateMsg(&v);
+    int gid=msgg->src_first();
+    update->set_src(procs_id_, gid, kStub);
+    update->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding((*msg)->target_first()), kServer);
+    pi->num_update=0;
+  }
+  delete *msg;
+  *msg=NULL;
+  return update;
+}
+
+int Trainer::HandleUpdateResponse(shared_ptr<Trainer::ParamInfo> pi, Msg** 
msg){
+  HandleGetResponse(pi, msg);
+  return 1;
+}
+
+Msg* Trainer::HandlePut(shared_ptr<Trainer::ParamInfo>pi, Msg** msg){
+  CHECK_NE((*msg)->src_flag(), kStub);
+  Msg* put=pi->shares.at(0)->GenPutMsg();
+  int gid=(*msg)->src_first();
+  int id=(*msg)->target_first();
+  put->set_src(procs_id_, gid , kStub);
+  put->set_dst(gid/Cluster::Get()->nworker_groups_per_server_group(),
+      Sharding(id), kServer);
+  delete *msg;
+  *msg=NULL;
+  return put;
+}
 } /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 3f1c83f..7565d49 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -2,24 +2,23 @@
 #include <thread>
 #include <memory>
 #include <iostream>
+#include <chrono>
+#include <thread>
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "trainer/worker.h"
 #include "proto/model.pb.h"
 using std::thread;
+DECLARE_int32(sleep);
 namespace singa {
 Worker::Worker(int thread_id, int group_id, int worker_id):
-  thread_id_(thread_id),group_id_(group_id), worker_id_(worker_id){
+  thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){
   }
 
 void Worker::Setup(const ModelProto& model,
-    shared_ptr<NeuralNet> train_net,
-    shared_ptr<PMWorker::ParamShard> shard){
+    shared_ptr<NeuralNet> train_net){
   train_net_=train_net;
   modelproto_=model;
-  pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance()
-      ->Create("PMWorker"));
-  pmworker_->Setup(group_id_, worker_id_, shard);
 }
 
 void Worker::Run(){
@@ -29,13 +28,13 @@ void Worker::Run(){
   layer_dealer_=make_shared<Dealer>(2*thread_id_+1);
   layer_dealer_->Connect(kInprocRouterEndpoint);
 
-  {
+  { // TODO remove waiting pong msg
   Msg* ping=new Msg();
   ping->set_src(group_id_, worker_id_, kWorkerParam);
-  ping->set_dst(0,0,kStub);
+  ping->set_dst(-1,-1,kStub);
   ping->set_type(kConnect);
   ping->add_frame("PING", 4);
-  param_dealer_->Send(ping);
+  param_dealer_->Send(&ping);
   ping=param_dealer_->Receive();
   string pong((char*)ping->frame_data(), ping->frame_size());
   CHECK_STREQ("PONG", pong.c_str());
@@ -45,10 +44,10 @@ void Worker::Run(){
   {
   Msg* ping=new Msg();
   ping->set_src(group_id_, worker_id_, kWorkerLayer);
-  ping->set_dst(0,0,kStub);
+  ping->set_dst(-1,-1,kStub);
   ping->set_type(kConnect);
   ping->add_frame("PING", 4);
-  layer_dealer_->Send(ping);
+  layer_dealer_->Send(&ping);
   ping=layer_dealer_->Receive();
   string pong((char*)ping->frame_data(), ping->frame_size());
   CHECK_STREQ("PONG", pong.c_str());
@@ -60,37 +59,60 @@ void Worker::Run(){
     //LOG(ERROR)<<layer->partitionid()<<" : "<<layer->name();
     if(layer->partitionid()==worker_id_)
       for(auto param: layer->GetParams()){
-        if(group_id_==0&&param->owner()==param->id()){
-          param->Init(0);
-          Put(param, step_);
+        if(group_id_==0){
+          if(param->owner()==param->id()){
+            param->Init(0);
+            Put(param, step_);
+          }else{
+            Get(param, 0);
+          }
         }else{
-          Get(param, step_);
+          Get(param, modelproto_.warmup_steps());
         }
       }
   }
-  step_=modelproto_.step();
-  Performance perf(train_net_);
+  Metric perf;
+  if(group_id_==0&&step_<modelproto_.warmup_steps()){
+    for(step_=0;step_<modelproto_.warmup_steps();step_++)
+      RunOneBatch(step_, &perf);
+    for(auto layer: train_net_->layers()){
+      //LOG(ERROR)<<layer->partitionid()<<" : "<<layer->name();
+      if(layer->partitionid()==worker_id_)
+        for(auto param: layer->GetParams())
+          if(param->owner()==param->id())
+            Put(param, step_);
+    }
+  }
   while(!StopNow(step_)){
     RunOneBatch(step_, &perf);
     step_++;
   }
 }
 int Worker::Put(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Put(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kPut);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
   return 1;
 }
 int Worker::Get(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Get(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kGet);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
   return 1;
 }
 int Worker::Update(shared_ptr<Param> param, int step){
-  auto msg=pmworker_->Update(param, step);
-  if(msg!=nullptr)
-    param_dealer_->Send(msg);
+  Msg* msg=new Msg();
+  msg->set_src(group_id_, worker_id_, kWorkerParam);
+  msg->set_dst(-1, -1, kStub);
+  msg->set_type(kUpdate);
+  msg->set_target(param->owner(), step);
+  param_dealer_->Send(&msg);
   return 1;
 }
 
@@ -106,22 +128,24 @@ int Worker::CollectAll(shared_ptr<NeuralNet> net, int 
step){
 }
 int Worker::Collect(shared_ptr<Param> param, int step){
   while(param->version()<step){
-    Socket* which=param_poller_.Wait(10);
-    if(which!=nullptr){
-      Msg* msg=param_dealer_->Receive();
-      if(msg==nullptr)
-        return 0;
-      pmworker_->Collect(&msg);
-    }
+    std::this_thread::sleep_for(std::chrono::milliseconds(FLAGS_sleep));
   }
   return 1;
 }
+const void Worker::DisplayPerformance(const Metric & perf, const string& 
prefix){
+  /* TODO send perf to Stub thread for printing
+     Msg* msg=new Msg();
+     msg->set_src(group_id_, worker_id_, kWorkerParam);
+     msg->set_dst(-1,-1, kStub);
+     msg->set_type(kMetric);
+     const string disp=perf.ToString();
+     msg->AddFrame(disp.c_str(), disp.length());
+     param_dealer_->Send(&msg);
+     */
+  LOG(ERROR)<<prefix<<" "<<perf.ToString();
+}
 
-void Worker::RunOneBatch(int step, Performance* perf){
-  //DLOG(ERROR)<<"Step "<<step;
-  // Test will call Pull which updates the sync time
-  // Hence we store the sync time, and restore it later
-  //float tSyncData=tSyncData_, tSyncParam=tSyncParam_;
+void Worker::RunOneBatch(int step, Metric* perf){
   if(ValidateNow(step)){
     LOG(ERROR)<<"Validation at step "<<step;
     CollectAll(validation_net_, step);
@@ -132,20 +156,23 @@ void Worker::RunOneBatch(int step, Performance* perf){
     CollectAll(test_net_, step);
     Test(test_net_, modelproto_.test_steps(), perf!=nullptr);
   }
-  //tSyncData_=tSyncData; tSyncParam_=tSyncParam;
-
-  CollectAll(train_net_, step);
   TrainOneBatch(step);
   if(perf!=nullptr){
-    perf->Update();
+    auto losslayers=train_net_->losslayers();
+    for(auto layer: losslayers){
+      if(layer->partitionid()==worker_id_){
+        const float * ptr=layer->metric().cpu_data();
+        for(int j=0;j<layer->metric().count();j++)
+          perf->AddMetric(layer->name()+"-"+std::to_string(j), ptr[j]);
+      }
+    }
+    perf->Inc();
     if(DisplayNow(step)){
-      LOG(ERROR)<<"Training at step "<<step;
-      LOG(ERROR)<<"\t"<<perf->ToString();
+      perf->Avg();
+      DisplayPerformance(*perf, "Train at step "+std::to_string(step));
       perf->Reset();
-      //LOG(ERROR)<<"\t"<<TimerInfo();
     }
   }
-
   /*
   if(CheckpointNow(step)){
     pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step));
@@ -154,44 +181,32 @@ void Worker::RunOneBatch(int step, Performance* perf){
 }
 
 void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){
-  /*
-  int type;
-  char *name;
-  int64_t tick=zclock_mono();
-  zframe_t* frame=zframe_new_empty();
-
-  zsock_recv(pull_, "isf", &type, &name, &frame);
-  if(type==kDataFrame){
-    auto* dst=static_cast<BridgeDstLayer*>(
-        net->name2layer(string(name)).get());
-    memcpy(dst->mutable_data()->mutable_cpu_data(), zframe_data(frame),
-        zframe_size(frame));
-    dst->set_ready(true);
-  }else if(type==kGradFrame){
-    auto* 
src=static_cast<BridgeSrcLayer*>(net->name2layer(string(name)).get());
-    memcpy(src->mutable_grad()->mutable_cpu_data(), zframe_data(frame),
-        zframe_size(frame));
-    src->set_ready(true);
-  }
-  zframe_destroy(&frame);
-  delete name;
-  tSyncData_+=zclock_mono()-tick;
-  */
 }
 
 void Worker::SendBlob(){
-
 }
 
 void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){
-  Performance perf(net);
+  const auto& losslayers=net->losslayers();
+  Metric perf;
   for(int step=0;step<nsteps;step++){
     TestOneBatch(net, step, kTest);
-    if(disperf)
-      perf.Update();
+    if(disperf){
+      for(auto layer: losslayers){
+        if(layer->partitionid()==worker_id_){
+          const float * ptr=layer->metric().cpu_data();
+          for(int j=0;j<layer->metric().count();j++)
+            perf.AddMetric(layer->name()+"-"+std::to_string(j), ptr[j]);
+        }
+      }
+      perf.Inc();
+    }
+  }
+  if(disperf){
+    perf.Avg();
+    DisplayPerformance(perf, "Test");
+    perf.Reset();
   }
-  if(disperf)
-    LOG(ERROR)<<"\t"<<perf.ToString();
 }
 
 /****************************BPWorker**********************************/
@@ -204,7 +219,7 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int step, 
 bool training){
         auto* dst=static_cast<BridgeDstLayer*>(layer.get());
         while(!dst->ready()){
           auto msg=layer_dealer_->Receive();
-          CHECK_EQ(msg->src_group_id(), group_id_);
+          CHECK_EQ(msg->src_first(), group_id_);
           string name((char*)msg->frame_data(), msg->frame_size());
           auto tmp=net->name2layer(name);
           CHECK(tmp->is_bridgedstlayer());
@@ -232,7 +247,7 @@ void BPWorker::Forward(shared_ptr<NeuralNet> net, int step, 
 bool training){
         msg->add_frame(dst->name().c_str(), dst->name().length());
         auto const & blob=layer->data(nullptr);
         msg->add_frame(blob.cpu_data(), blob.count()*sizeof(float));
-        layer_dealer_->Send(msg);
+        layer_dealer_->Send(&msg);
       }
       
if(training&&DisplayDebugInfo(step)&&layer->mutable_data(nullptr)!=nullptr){
         LOG(INFO)<<StringPrintf("Forward layer  %10s data norm1 %13.9f",
@@ -280,76 +295,4 @@ void BPWorker::TestOneBatch(shared_ptr<NeuralNet> net,int 
step, Phase phase){
   Forward(net, step, false);
 }
 
-/*********************Implementation for Performance class*******************/
-Performance::Performance(shared_ptr<NeuralNet> net):net_(net), counter_(0){
-  for(auto& layer: net->losslayers()){
-    name_.push_back(layer->name());
-    metric_.push_back(vector<float>{});
-    metric_.back().resize(layer->metric().count(),0.f);
-  }
-}
-
-void Performance::Update(){
-  const auto& losslayers=net_->losslayers();
-  for(size_t i=0;i<losslayers.size();i++){
-    const float * ptr=losslayers[i]->metric().cpu_data();
-    vector<float>& m=metric_.at(i);
-    for(int j=0;j<losslayers[i]->metric().count();j++)
-      m[j]+=ptr[j];
-  }
-  counter_++;
-}
-
-void Performance::Reset(){
-  for(auto& m: metric_)
-    for(auto& x: m)
-      x=0.f;
-  counter_=0;
-}
-
-string Performance::ToString(){
-  string disp="";
-  for(size_t i=0;i<metric_.size();i++){
-    disp+="Output from "+name_[i]+" layer ";
-    vector<float> m=metric_.at(i);
-    for(size_t j=0;j<m.size();j++)
-        disp+=std::to_string(j)+" : "+std::to_string(m[j]/counter_)+"\t";
-    disp+="\n";
-  }
-  return disp;
-}
-/*
-void Executor::Setup(int local_threadid, const ModelProto& model){
-  tForward_=tBackward_=tSyncData_=tSyncParam_=0;
-  modelproto_=model;
-  local_threadid_=local_threadid;
-  if(model.prefetch()){
-    for(auto& layer: train_net_->datalayers()){
-      if(cluster_->group_threadid(local_threadid_)==layer->locationid())
-        localDataLayers_.push_back(layer);
-    }
-    if(localDataLayers_.size())
-      prefetch_thread_=std::thread(Executor::PrefetchData,
-          std::ref(localDataLayers_), true,1);
-  }
-  int gthreadid=cluster_->group_threadid(local_threadid);
-}
-
-void Executor::PrefetchData(const vector<DataLayer*>& datalayers, bool 
training,
-    int steps){
-  if(datalayers.size()==0)
-    return;
-  for(int i=0;i<steps;i++){
-    for(auto& layer: datalayers){
-      layer->Prefetching(training);
-      for(auto& dstlayer: layer->dstlayers()){
-        CHECK(dstlayer->is_parserlayer());
-        auto parserlayer=static_cast<ParserLayer*>(dstlayer.get());
-        parserlayer->Prefetching(training);
-      }
-    }
-  }
-}
-*/
-
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b5b943c7/src/utils/graph.cc
----------------------------------------------------------------------
diff --git a/src/utils/graph.cc b/src/utils/graph.cc
index 076c971..6fff83b 100644
--- a/src/utils/graph.cc
+++ b/src/utils/graph.cc
@@ -1,4 +1,5 @@
 #include <algorithm>
+#include <queue>
 #include "utils/graph.h"
 
 const string Graph::ToString() const {
@@ -78,30 +79,43 @@ void Graph::topology_sort_inner(SNode node,
 
 // sort to make `bottom' nodes be placed in the front positions
 void Graph::Sort() {
-  // adjacent list from upper layers to lower layers
-  std::map<string, bool> visited;
-  // prepare adjacent list; input layers will be processed firstly,
-  // hence no need to sort them (mark them as visited)
-  for (SNode node: nodes_) {
-    visited[node->name()] = false;
-  }
-  // the `top' layer in the net will be placed at the bottom of the stack
-  // and then be processed (i.e., forward) at last
-  std::stack<string > stack;
-  for (SNode node: nodes_) {
-    if (visited[node->name()] == false)
-      topology_sort_inner(node, &visited, &stack);
+  SNode start=nullptr;
+  map<string, bool> visited;
+  for(auto node: nodes_){
+    if(node->srcnodes().size()==0){
+      CHECK(start==nullptr);
+      start=node;
+    }
+    visited[node->name()]=false;
   }
+  int n=nodes_.size();
+  std::queue<SNode> tmp;
+  tmp.push(start);
   nodes_.clear();
-
-  while (!stack.empty()) {
-    nodes_.push_back(name2node_[stack.top()]);
-    stack.pop();
+  while(!tmp.empty()){
+    auto node=tmp.front();
+    tmp.pop();
+    bool visit=true;
+    for(auto src: node->srcnodes())
+      if(visited[src->name()]==false){
+        visit=false;
+        break;
+      }
+    if(visit){
+      nodes_.push_back(node);
+      visited[node->name()]=true;
+      for(auto dst: node->dstnodes()){
+        CHECK(visited.find(dst->name())!=visited.end())<<dst->name();
+        if(visited[dst->name()]==false){
+          tmp.push(dst);
+        }
+      }
+    }
   }
+  CHECK_EQ(nodes_.size(), n);
 }
 
 
-
 SNode Graph::InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes,
     const V& info, bool connect_dst){
   V myinfo=info;

Reply via email to