Repository: incubator-singa
Updated Branches:
  refs/heads/master 7954a87d2 -> 96bedb226


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index a6a5dbf..3ecaad0 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -1,9 +1,10 @@
 #include <thread>
 #include <vector>
 #include <map>
-#include <queue>
 #include <chrono>
 #include <glog/logging.h>
+#include "utils/cluster.h"
+#include "utils/common.h"
 #include "proto/common.pb.h"
 #include "trainer/trainer.h"
 #include "mshadow/tensor.h"
@@ -11,587 +12,486 @@
 namespace singa {
 using std::vector;
 using std::map;
+using std::queue;
 using namespace std::chrono;
 using std::make_shared;
 
-typedef std::chrono::milliseconds TimeT;
+/***********************Trainer****************************/
+Trainer::~Trainer() {
+  // free Params (i.e., slices) in server shard
+  for (auto entry : server_shard_)
+    for (auto param : entry.second->shares)
+      delete param;
+  delete router_;
+}
 
-void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
-  // register all layers appearing in the neural net
+void Trainer::RegisterDefaultClasses(const singa::ModelProto& model_conf) {
+  // register all implemented layers
   singa::NeuralNet::RegisterLayers();
-  Singleton<Factory<singa::Param>>::Instance()->Register(
-      "Param", CreateInstance(singa::Param, singa::Param));
-  Singleton<Factory<singa::Updater>>::Instance() ->Register(
-      "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
+  auto param_factory = Singleton<Factory<singa::Param>>::Instance();
+  param_factory->Register("Param", CreateInstance(Param, Param));
+  auto updater_factory = Singleton<Factory<singa::Updater>>::Instance();
+  updater_factory->Register("Updater", CreateInstance(SGDUpdater, Updater));
 }
 
-void HandleWorkerFinish(void * ctx){
-  HandleContext* hctx=static_cast<HandleContext*> (ctx);
-  Msg* msg=new Msg();
-  msg->set_src(-1,-1, kRuntime);
-  msg->set_dst(hctx->group_id, hctx->id, kServer);
-  msg->set_type(kStop);
-  hctx->dealer->Send(&msg);
-}
+const vector<int> SliceParams(const vector<Param*>& params) {
+  // for load-balance among servers in a group and among server groups
+  int nserver_grps = Cluster::Get()->nserver_groups();
+  int nservers_per_grp = Cluster::Get()->nservers_per_group();
+  int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp);
 
-const std::unordered_map<int, vector<std::pair<int, int>>>
-SliceParams(int num, const vector<Param*>& params){
+  // collect sizes of unique Params
+  std::vector<int> paramsize;
+  for (auto param : params)
+    if (param->id() == param->owner())
+      paramsize.push_back(param->size());
+  // slice into lcm pieces to achieve good load-balance for both intra-group
+  // partition (among servers in a group) and inter-group partition (each group
+  // is assigned a subset of slices)
+  auto param_slice = Slice(lcm, paramsize);
+  // construct map from Param ID to its slices <slice id, len>
   std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices;
-  if (num==0)
-    return paramid2slices;
-  vector<int> param_size;
-  int avg=0;
-  for(const auto& x:params){
-    if(x->owner()==x->id())
-      avg+=x->size();
-  }
-  avg=avg/num+avg%num;
-  int diff=avg/10;
-  LOG(INFO)<<"Slicer, param avg="<<avg<<", diff= "<<diff;
-
-  int capacity=avg, sliceid=0, nbox=0;
-  for(auto& param: params){
-    if(param->id()!=param->owner())
-      continue;
-    int x=param->size(), paramid=param->id();
-    LOG(INFO)<<"param id="<<paramid<<", total size="<<x;
-    while(x>0){
-      int size=0;
-      if(capacity>=x){
-        capacity-=x;
-        size=x;
-        x=0;
-      }else if(capacity+diff>=x){
-        size=x;
-        x=0;
-        capacity=0;
-      }else if(capacity>=diff){
-        x-=capacity;
-        size=capacity;
-        capacity=avg;
-        nbox++;
-      }else{
-        capacity=avg;
-        nbox++;
-      }
-      if(size){
-        paramid2slices[paramid].push_back(std::make_pair(sliceid++, size));
-        LOG(INFO)<<"param id="<<paramid<<", slice size="<<size;
+  vector<int> slices;
+  auto it = param_slice.begin();
+  int slice_id = 0;
+  for (auto param : params) {
+    if (param->id() == param->owner()) {
+      for (int len : *it) {
+        slices.push_back(len);
+        paramid2slices[param->id()].push_back(std::make_pair(slice_id++, len));
       }
+      it++;
     }
   }
-  CHECK_LE(nbox, num);
-  return paramid2slices;
+  // add slice info for every Param
+  for (auto param : params)
+    for (auto entry : paramid2slices[param->owner()]) {
+      param->AddSlice(entry.first, entry.second);
+      LOG(INFO) << "param id " << param->id() << " owner=" << param->owner()
+        << ": " << entry.first << ", " << entry.second;
+    }
+  return slices;
 }
-const vector<int> PartitionSlice(int num, const vector<int>& slices){
-  int avg=0;
-  for(int x: slices)
-    avg+=x;
-  avg=avg/num+avg%num;
-  int box=avg, boxid=0, diff=avg/10;
-  vector<int> slice2box;
-  for(auto it=slices.begin(); it!=slices.end();){
-    int x=*it;
-    if(box>=x){
-      box-=x;
-      slice2box.push_back(boxid);
-      it++;
-    }else if(box+diff>=x){
-      slice2box.push_back(boxid);
-      it++;
-      box=0;
-    }else{
-      box=avg;
-      boxid++;
+
+void Trainer::SetupWorkerServer(
+    const ModelProto& model_conf,
+    const vector<Worker*>& workers,
+    const vector<Server*>& servers) {
+  auto cluster = Cluster::Get();
+  int grp_size = cluster->nworkers_per_group();
+  const auto& net_conf = model_conf.neuralnet();
+  auto net = NeuralNet::Create(net_conf, kTrain, grp_size);
+  // MUST call SliceParams before sharing the params/net with others
+  auto slices = SliceParams(net->params());
+  shared_ptr<NeuralNet> train_net, test_net, valid_net;
+  int grp = workers.size() ? workers.at(0)->grp_id() : -1;
+  if (grp == 0 && model_conf.test_steps()) {
+    // tests are performed only by the first group
+    test_net = NeuralNet::Create(net_conf, kTest, grp_size);
+    test_net->ShareParamsFrom(net);
+  }
+  if (grp == 0 && model_conf.validation_steps()) {
+    // validation is performed only by the first group
+    valid_net = NeuralNet::Create(net_conf, kValidation, grp_size);
+    valid_net->ShareParamsFrom(net);
+  }
+  bool prepare_param = true;
+  for (auto worker : workers) {
+    if (worker->grp_id() != grp) {
+      train_net = NeuralNet::Create(net_conf, kTrain, grp_size);
+      if (cluster->share_memory())
+        train_net->ShareParamsFrom(net);
+      valid_net = test_net = nullptr;
+      grp = worker->grp_id();
+      prepare_param = true;
+    } else {
+      train_net = net;
+    }
+    worker->Setup(model_conf, train_net, valid_net, test_net);
+    // Prepare ParamEntry
+    if (prepare_param) {
+      for (auto layer : train_net->layers()) {
+        bool local = layer->partition_id() >= workers.front()->id()
+          && layer->partition_id() <= workers.back()->id();
+        for (auto param : layer->GetParams()) {
+          int hash = Hash(grp, param->owner());
+          if (worker_shard_.find(hash) == worker_shard_.end())
+            worker_shard_[hash] = new ParamEntry();
+          worker_shard_[hash]->AddParam(local, param);
+        }
+      }
+      prepare_param = false;
     }
   }
-//  CHECK_LT(slice2box.back(), num);
-  CHECK_EQ(slice2box.size(), slices.size());
-  int previd=slice2box[0];
-  std::string disp;
-  for(size_t i=0;i<slice2box.size();i++)
-    if(previd!=slice2box[i]){
-      disp+=", "+std::to_string(slices[i]);
-      previd=slice2box[i];
-    } else
-      disp+=" "+std::to_string(slices[i]);
-  LOG(INFO)<<"partition slice (avg ="<<avg<<", num="<<num<<"):"<<disp;
-  return slice2box;
+  // partition among server groups; each group maintains one subset for sync
+  auto slice2group = PartitionSlices(cluster->nserver_groups(), slices);
+  for (auto server : servers)
+    server->Setup(model_conf.updater(), &server_shard_, slice2group);
+  // partition within one server group; each server updates one subset
+  slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices);
 }
-vector<Server*> Trainer::CreateServers(int nthreads,
-    const ModelProto & mproto,
-    const vector<int> slices,
-    vector<HandleContext*>* ctx){
-  auto cluster=Cluster::Get();
+
+vector<Server*> Trainer::CreateServers(int nthreads, const ModelProto& mconf) {
+  auto cluster = Cluster::Get();
   vector<Server*> servers;
-  if(!cluster->has_server())
+  if (!cluster->has_server())
     return servers;
 
-  int pid=cluster->procs_id();
-  if(cluster->server_worker_separate())
-    pid-=cluster->nworker_procs();
-  int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group();
-  int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
-  int end=start+cluster->nservers_per_procs();
-  // the ServerShard for servers consists of a dictionary of Param objects
-  server_shard_=make_shared<ServerShard>();
-  auto slice2group=PartitionSlice(cluster->nserver_groups(), slices);
-  if(start<end){
-    auto dealer=make_shared<Dealer>();
-    dealer->Connect(kInprocRouterEndpoint);
-    for(int sid=start;sid<end;sid++){
-      auto server=new Server(nthreads++, gid, sid);
-      server->Setup(mproto.updater(), server_shard_, slice2group);
-      servers.push_back(server);
-      auto *hc=new HandleContext{dealer, gid, sid};
-      ctx->push_back(hc);
-      CHECK(cluster->runtime()->WatchSGroup(gid, sid, HandleWorkerFinish,
-            ctx->back()));
-    }
+  int pid = cluster->procs_id();
+  // if separated, server procs' (logical) ids start after the worker procs
+  if (cluster->server_worker_separate())
+    pid -= cluster->nworker_procs();
+  int procs_size = cluster->nservers_per_procs();
+  int grp_size = cluster->nservers_per_group();
+  int gid = pid *  procs_size / grp_size;
+  int start = pid * procs_size % grp_size;
+  int end = start + procs_size;
+  for (int sid = start; sid < end; sid++) {
+    auto server = new Server(nthreads++, gid, sid);
+    servers.push_back(server);
   }
   return servers;
 }
 
-vector<Worker*> Trainer::CreateWorkers(int nthreads,
-    const ModelProto& mproto, vector<int> *slice_size){
+vector<Worker*> Trainer::CreateWorkers(int nthreads, const ModelProto& mconf) {
   auto cluster=Cluster::Get();
-  auto net=NeuralNet::Create(mproto.neuralnet(), kTrain,
-      cluster->nworkers_per_group());
-  int lcm=LeastCommonMultiple(cluster->nserver_groups(), cluster->nservers_per_group());
-  auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size
-  for(auto param: net->params()){
-    if(param->id() == param->owner())
-      for(auto entry: paramid2slices[param->id()])
-        slice_size->push_back(entry.second);
-  }
-
   vector<Worker*> workers;
   if(!cluster->has_worker())
     return workers;
-  //LOG(ERROR)<<net->ToString();
-  int pid=cluster->procs_id();
+  int pid = cluster->procs_id();
+  int grp_size = cluster->nworkers_per_group();
+  int procs_size = cluster->nworkers_per_procs();
   int gstart, gend, wstart, wend;
-  if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
+  if (grp_size >= procs_size) {
     // all workers in this procs are from the same group
-    gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group();
-    gend=gstart+1;
-    wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group();
-    wend=wstart+cluster->nworkers_per_group();
-  }else{
-    // there are multiple groups in this procs
-    CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0);
-    int groups_per_procs=
-      cluster->nworkers_per_procs()/cluster->nworkers_per_group();
-    gstart=pid*groups_per_procs;
-    gend=(pid+1)*groups_per_procs;
-    wstart=0;
-    wend=cluster->nworkers_per_group();
+    gstart = pid * procs_size / grp_size;
+    gend = gstart + 1;
+    wstart = pid * procs_size % grp_size;
+    wend = wstart + procs_size;
+  } else {
+    // there are multiple (complete) groups in this procs.
+    CHECK_EQ(procs_size % grp_size, 0);
+    int groups_per_procs = procs_size / grp_size;
+    gstart = pid * groups_per_procs;
+    gend = (pid+1) * groups_per_procs;
+    wstart = 0;
+    wend = grp_size;
   }
-  for(int gid=gstart;gid<gend;gid++){
-    shared_ptr<NeuralNet> train_net, test_net, validation_net;
-    if(gid==gstart)
-      train_net=net;
-    else{
-      train_net=NeuralNet::Create(mproto.neuralnet(), kTrain,
-          cluster->nworkers_per_group());
-      // the train net for other groups may share parameter values from the
-      // first group
-      if(cluster->share_memory())
-        train_net->ShareParams(net);
-    }
-    if(gid==0){
-      // validation and test are performed only by the first group
-      if(mproto.test_steps()){
-        test_net=NeuralNet::Create(mproto.neuralnet(), kTest,
-            cluster->nworkers_per_group());
-        if(test_net!=nullptr)
-          test_net->ShareParams(train_net);
-      }
-      if(mproto.validation_steps()){
-        validation_net=NeuralNet::Create(mproto.neuralnet(), kValidation,
-            cluster->nworkers_per_group());
-        if(validation_net!=nullptr)
-          validation_net->ShareParams(train_net);
-      }
-    }
-    // create ServerShard for the workers
-    auto shard=make_shared<WorkerShard>();
-    worker_shards_[gid]=shard;
-    for(auto layer: train_net->layers()){
-      int procsid=cluster->ProcsIDOf(gid, layer->partition_id(), kWorkerLayer);
-      bool local=procsid==cluster->procs_id();
-      for(auto param: layer->GetParams()){
-        for(auto entry :paramid2slices[param->owner()]){
-          param->AddSlice(entry.first,  entry.second);
-        }
-        int owner_procs=param->owner()==param->id()?procsid:procs_id_;
-        if(shard->find(param->owner())==shard->end())
-          (*shard)[param->owner()]=
-            make_shared<ParamInfo>(param, local, owner_procs);
-        else
-          shard->at(param->owner())->AddParam(param, local);
-      }
-    }
-    for(int wid=wstart;wid<wend;wid++){
+  for (int gid = gstart; gid < gend; gid++) {
+    for (int wid = wstart; wid < wend; wid++) {
       Worker* worker=nullptr;
-      if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
+      if (mconf.alg() == ModelProto_GradCalcAlg_kBackPropagation)
         worker = new BPWorker(nthreads++,gid, wid);
-      else{
-        // TODO add CDWorker
+      else {
+        // TODO add CDWorker and BPTTWorker
       }
-      worker->Setup(mproto, train_net);
-      worker->set_test_net(test_net);
-      worker->set_validation_net(validation_net);
       workers.push_back(worker);
     }
   }
   return workers;
 }
 
-void Trainer::Start(const ModelProto& mproto, const GlobalProto& gproto, 
-                    const ClusterProto& cproto,
-    int procs_id){
-  // procs_id is only used for resume training
-  CHECK_EQ(procs_id, -1);
-  RegisterDefaultClasses(mproto);
+void Trainer::Start(const ModelProto& mconf, const GlobalProto& gconf,
+                    const ClusterProto& cconf, int job){
+  RegisterDefaultClasses(mconf);
 
-  auto cluster=Cluster::Get(gproto, cproto, procs_id);
-  router_=make_shared<Router>();
+  // register job to zookeeper
+  auto cluster=Cluster::Get(gconf, cconf, job);
+  if (mconf.resume()) {
+    // TODO(wangwei) resume from checkpoint
+    // load param slices to server_shard_ and reset running step of worker
+    // mproto.set_step(step);
+  }
+
+  router_ = new Router();
   router_->Bind(kInprocRouterEndpoint);
-  if(cluster->nprocs()>1){
-    const string hostip=cluster->hostip();
-    int port=router_->Bind("tcp://"+hostip+":*");
-    cluster->Register(hostip+":"+std::to_string(port));
-  }else
+  if (cluster->nprocs() > 1) {
+    const string hostip = cluster->hostip();
+    int port = router_->Bind("tcp://" + hostip + ":*");
+    // register endpoint to zookeeper
+    cluster->Register(hostip + ":" + std::to_string(port));
+  } else {
     cluster->set_procs_id(0);
+  }
 
-  procs_id_ = cluster->procs_id();
-  int nthreads=1;
-  // create workers
-  vector<int> slices;
-  vector<Worker*> workers=CreateWorkers(nthreads, mproto, &slices);
-  if(cluster->nserver_groups()&&cluster->nservers_per_group())
-    slice2server_=PartitionSlice(cluster->nservers_per_group(), slices);
-  nthreads+=workers.size();
-  // create servers
-  vector<HandleContext*> ctx;
-  vector<Server*> servers=CreateServers(nthreads, mproto, slices,
-      &ctx);
+  int nthreads = 1;
+  const vector<Worker*> workers = CreateWorkers(nthreads, mconf);
+  nthreads += workers.size();
+  const vector<Server*> servers = CreateServers(nthreads, mconf);
+  SetupWorkerServer(mconf, workers, servers);
 
 #ifdef USE_MPI
-  for(int i=0;i<nSocket;i++){
+  for (int i = 0; i < nthreads; i++)
     MPIQueues.push_back(make_shared<SafeQueue>());
-  }
 #endif
   vector<std::thread> threads;
-  for(auto server: servers)
-    threads.push_back(std::thread(&Server::Run,server));
-  for(auto worker: workers)
-    threads.push_back(std::thread(&Worker::Run,worker));
+  for (auto server : servers)
+    threads.push_back(std::thread(&Server::Run, server));
+  for (auto worker : workers)
+    threads.push_back(std::thread(&Worker::Run, worker));
   Run(workers, servers);
-  for(auto& thread: threads)
+  for (auto& thread : threads)
     thread.join();
-  for(auto x: ctx)
-    delete x;
-  for(auto x : servers)
-    delete x;
-  for(auto x : workers)
-    delete x;
+  for (auto server : servers)
+    delete server;
+  for (auto worker : workers)
+    delete worker;
 }
 
-inline int bandwidth(int bytes, system_clock::time_point start){
+inline int bandwidth(int bytes, system_clock::time_point start) {
   auto now=system_clock::now();
-  auto duration=duration_cast<TimeT> (now - start);
+  auto duration=duration_cast<std::chrono::milliseconds> (now - start);
   return static_cast<int>(bytes*1000.f/duration.count());
 }
 
-void Trainer::Run(const vector<Worker*>& workers,
-    const vector<Server*>& servers){
-  auto cluster=Cluster::Get();
-  procs_id_=cluster->procs_id();
-  LOG(INFO)<<"Stub in process "<<procs_id_<<" starts";
-  map<int, shared_ptr<Dealer>> interprocs_dealers;
+void Trainer::Run(
+    const vector<Worker*>& workers,
+    const vector<Server*>& servers) {
+  int nworkers = workers.size(), nservers = servers.size();
+  auto cluster = Cluster::Get();
+  procs_id_ = cluster->procs_id();
+  LOG(INFO) << "Stub in process " << procs_id_ << " starts";
+
+  // for sync among server groups
+  auto start = std::chrono::system_clock::now();
+  float trans_size = 0.f;  // total size of msg transferred since start time
+  int sync_server_id = 0;
+  int max_bandwidth = cluster->bandwidth();
+  int nserver_grps = cluster->nserver_groups();
+
+  map<int, Dealer*> inter_dealers;  // for sending msg to other procs
+
   std::queue<Msg*> msg_queue;
+  Poller poll(router_);
   bool stop=false;
-  auto start=std::chrono::system_clock::now();
-  float amount=0.f;
-  Poller poll;
-  poll.Add(router_.get());
-  int sync_server=0, nworkers=workers.size(), nservers=servers.size();
-  while(!stop){
-    // if the poll time is large, then the poller may not expire
-    // if it is small, then many reminder messages will be sent which may
-    // slow done the process of other request. TODO tune it.
-    auto *sock=poll.Wait(cluster->poll_time());
-    if(poll.Terminated()){
-      LOG(ERROR)<<"Connection broken!";
-      exit(0);
-    }else if(sock==nullptr){
-      if(cluster->nserver_groups()>1&&
-          bandwidth(amount, start)<cluster->bandwidth()){
-        Msg* msg=new Msg();
-        msg->set_src(-1,-1, kStub);
-        msg->set_dst(servers[sync_server]->group_id(),
-            servers[sync_server]->server_id(), kServer);
-        msg->set_type(kSyncReminder);
-        sync_server=(sync_server+1)%servers.size();
-        router_->Send(&msg);
+  while (!stop || !msg_queue.empty()) {
+    if (msg_queue.empty()) {
+      // if the poll time is too large, the poller may not expire in time;
+      // if it is too small, many reminder messages will be sent, which may
+      // slow down the processing of other requests. TODO tune it.
+      auto *sock = poll.Wait(cluster->poll_time());
+      if (poll.Terminated()) {
+        LOG(ERROR) << "Connection broken!";
+        exit(0);
+      } else if (sock == nullptr) {
+        if (nserver_grps > 1 && bandwidth(trans_size, start) < max_bandwidth) {
+          Msg* msg = GenSyncReminderMsg(sync_server_id, servers);
+          router_->Send(&msg);
+          sync_server_id = (sync_server_id + 1) % nservers;
+        }
+        continue;
       }
-      continue;
+      Msg* msg = router_->Receive();
+      msg_queue.push(msg);
     }
-    Msg* msg=router_->Receive();
-    if(msg==nullptr){
-      LOG(ERROR)<<"Connection broken!";
-      exit(0);
-    }
-    msg_queue.push(msg);
-    while(!msg_queue.empty()){
-      msg=msg_queue.front();
-      msg_queue.pop();
-      int dst_flag=msg->dst_flag();
-      int type=msg->type();
-      int dst_procs=msg->dst_first();
-      if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){
-        if(type==kConnect){
-          msg_queue.push(HandleConnect(&msg));
-        }else if(type==kStop){
-          if(msg->src_flag()==kServer)
-            nservers--;
-          else if (msg->src_flag()==kWorkerParam)
-            nworkers--;
-          DeleteMsg(&msg);
-          if(nworkers==0&&nservers==0){
-            stop=true;
-            break;
-          }
-        }else if(type==kMetric){
-          if(msg->src_first()==0){
-            int step=msg->trgt_first();
-            string prefix((char*)msg->frame_data(), msg->frame_size());
-            msg->next_frame();
-            Metric cur;
-            cur.ParseFrom(string((char*)msg->frame_data(), msg->frame_size()));
-            LOG(ERROR)<<prefix<<" step-" <<step<<", "<<cur.ToLogString();
-          }
-          DeleteMsg(&msg);
-        }else if(cluster->nserver_groups()>0){
-          int group_id;
-          int paramid=msg->trgt_first();
-          shared_ptr<ParamInfo> entry;
-          switch (type){ // TODO process other requests, e.g. RESTful
-            case kUpdate:
-              group_id=msg->src_first();
-              entry=worker_shards_.at(group_id)->at(paramid);
-              for(auto x:HandleUpdate(entry, &msg))
-                msg_queue.push(x);
-              break;
-            case kRUpdate:
-              group_id=msg->dst_second();
-              entry=worker_shards_.at(group_id)->at(paramid);
-              HandleUpdateResponse(entry, &msg);
-              break;
-            case kGet:
-              group_id=msg->src_first();
-              entry=worker_shards_.at(group_id)->at(paramid);
-              for(auto x:HandleGet(entry, &msg))
-                msg_queue.push(x);
-              break;
-            case kRGet:
-              group_id=msg->dst_second();
-              entry=worker_shards_.at(group_id)->at(paramid);
-              HandleGetResponse(entry, &msg);
-              break;
-            case kPut:
-              group_id=msg->src_first();
-              entry=worker_shards_.at(group_id)->at(paramid);
-              for(auto x:HandlePut(entry, &msg))
-                msg_queue.push(x);
-              break;
-            default:
-              LOG(ERROR)<<"Unknow message type:"<<type;
-              break;
-          }
-        }else{
-          DeleteMsg(&msg);
-        }
-      }else{
-        int dst_procs_id;
-        if(dst_flag==kStub){
-          dst_procs_id=msg->dst_first();
-        }else{
-          dst_procs_id=cluster->ProcsIDOf(msg->dst_first(),
-              msg->dst_second(), msg->dst_flag());
-        }
-        if(dst_procs_id!=procs_id_){
-          // forward to other procs
-          if (interprocs_dealers.find(dst_procs_id)==interprocs_dealers.end()){
-            auto dealer=make_shared<Dealer>();
-            interprocs_dealers[dst_procs_id]=dealer;
-            while(cluster->endpoint(dst_procs_id)==""){
-              std::this_thread::sleep_for(
-                  std::chrono::milliseconds(3000));//kCollectSleepTime));
-              LOG(ERROR)<<"waiting for procs "<< dst_procs_id<<" to register";
-            }
-            dealer->Connect("tcp://"+cluster->endpoint(dst_procs_id));
-          }
-          if(bandwidth(amount, start) <=cluster->bandwidth()){
-            start=std::chrono::system_clock::now();
-            amount=0;
-          }
-          amount+=msg->size();
-          //LOG(ERROR)<<"send inter msg of type "<<msg->type();
-          interprocs_dealers[dst_procs_id]->Send(&msg);
-        }else{
-          if(type==kSyncRequest){
-            char buf[32];
-            sprintf(buf, "%d", cluster->bandwidth()-bandwidth(amount, start));
-            msg->add_frame(buf, strlen(buf));
-          }
-          router_->Send(&msg);
+    Msg* msg = msg_queue.front();
+    msg_queue.pop();
+    int type = msg->type(), dst = msg->dst(), flag = AddrType(dst);
+    if (flag == kStub && (AddrProc(dst) == procs_id_ || AddrGrp(dst) == -1)) {
+      if (type == kConnect) {
+        DeleteMsg(&msg);
+      } else if (type == kMetric) {
+        DisplayMetric(&msg);
+      } else if (type == kStop) {
+        int src_flag = AddrType(msg->src());
+        if (src_flag == kServer) nservers--;
+        else if (src_flag == kWorkerParam) nworkers--;
+        DeleteMsg(&msg);
+        if (nworkers == 0 && nservers == 0) break;
+      } else if (nserver_grps > 0) {
+        HandleLocalMsg(&msg_queue, &msg);
+      } else {
+        DeleteMsg(&msg);
+      }
+    } else {
+      int dst_procs = AddrProc(dst);
+      if (flag != kStub)
+        dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag);
+      if (dst_procs != procs_id_) {
+        if (bandwidth(trans_size, start) <= cluster->bandwidth()) {
+          start = std::chrono::system_clock::now();
+          trans_size = 0;
         }
+        trans_size += msg->size();
+
+        if (inter_dealers.find(dst_procs) == inter_dealers.end())
+          inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs);
+        inter_dealers[dst_procs]->Send(&msg);
+      } else {
+        if (type == kSyncRequest)
+          msg->AddFormatFrame("i", max_bandwidth - bandwidth(trans_size, start));
+        router_->Send(&msg);
       }
     }
   }
-  LOG(INFO)<<"Stub in process "<<procs_id_<<" stops";
+  LOG(ERROR) << "Stub in process " << procs_id_ << " stops";
+  for (auto& entry : inter_dealers)
+    delete entry.second;
 }
-Msg* Trainer::HandleConnect(Msg** msg){
-  string ping((char*)(*msg)->frame_data(), (*msg)->frame_size());
-  CHECK_STREQ("PING", ping.c_str());
-  // ping-pong for debug
-  (*msg)->SwapAddr();
-  Msg* reply=new Msg();
-  reply->SetAddr(*msg);
-  reply->add_frame("PONG", 4);
-  reply->set_type(kConnect);
+
+Msg* Trainer::GenSyncReminderMsg(int server, const vector<Server*>& servers) {
+  Msg* msg = new Msg();
+  msg->set_src(Addr(-1, -1, kStub));
+  msg->set_dst(Addr(servers[server]->grp_id(), servers[server]->id(), kServer));
+  msg->set_type(kSyncReminder);
+  return msg;
+}
+
+void Trainer::DisplayMetric(Msg** msg) {
+  Msg* msgg = *msg;
+  // only display metrics from the first group
+  if (AddrGrp(msgg->src()) == 0) {
+    int step = msgg->trgt_version();
+    char prefix[128];
+    msgg->ParseFormatFrame("s", prefix);
+    CHECK(msgg->NextFrame());
+    const string perf(static_cast<char*>(msgg->FrameData()), msgg->FrameSize());
+    Metric cur(perf);
+    LOG(ERROR) << prefix << " step-" << step << ", " << cur.ToLogString();
+  }
   DeleteMsg(msg);
-  return reply;
 }
-const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){
-  Msg* msgg=*msg;
-  vector<Msg*> replies;
-  int version=msgg->trgt_third();
-  if(msgg->src_flag()==kStub){
-    LOG(FATAL)<<"Not implemented";
-    /*
-    if(version<=pi->shares.at(0)->version()){
-      replies.push_back(pi->shares.at(0)->HandleGetMsg(msg));
-    }else if(version>pi->next_version){
-      // reinsert into a msg queue.
-      replies.push_back(mmsg);
-    }
-    */
-  }else if(version>pi->next_version){
-    pi->next_version=version;
-    int gid=msgg->src_first();
-    int group=gid/Cluster::Get()->nworker_groups_per_server_group();
-    auto param=pi->shares.at(0);
-    for(int idx=0, id=param->slice_start();idx<param->num_slices();idx++){
-      int server=slice2server_[id+idx];
-      int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
-      auto x=param->GenGetMsg(procs!=procs_id_, idx);
-      x->set_trgt(param->owner(), id+idx, param->local_version()+1);
-      x->set_src(procs_id_, gid, kStub);
-      x->set_dst(group, server, kServer);
-      //LOG(ERROR)<<"stub handle get for "<<idx+id<<","<<group<<","<<server;
-      replies.push_back(x);
+
+Dealer* Trainer::CreateInterProcsDealer(int dst_procs) {
+  // forward to other procs
+  auto cluster = Cluster::Get();
+  auto dealer = new Dealer();
+  while (cluster->endpoint(dst_procs) == "") {
+    // TODO use kCollectSleepTime instead of the hard-coded 3000 ms
+    std::this_thread::sleep_for(std::chrono::milliseconds(3000));
+    LOG(ERROR) << "waiting for procs " << dst_procs << " to register";
+  }
+  dealer->Connect("tcp://" + cluster->endpoint(dst_procs));
+  return dealer;
+}
+
+void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) {
+  Msg* msgg = *msg;
+  int paramid = ParamID(msgg->trgt_val());
+  int type = msgg->type();
+  int grp;
+  ParamEntry *entry = nullptr;
+  switch (type) {  // TODO process other requests, e.g. RESTful
+    case kUpdate:
+      grp = AddrGrp(msgg->src());
+      entry = worker_shard_.at(Hash(grp, paramid));
+      for (auto update_msg : HandleUpdate(entry, msg))
+        msg_queue->push(update_msg);
+      break;
+    case kRUpdate:
+      grp = AddrGrp(msgg->dst());
+      entry = worker_shard_.at(Hash(grp, paramid));
+      HandleUpdateResponse(entry, msg);
+      break;
+    case kGet:
+      grp = AddrGrp(msgg->src());
+      entry = worker_shard_.at(Hash(grp, paramid));
+      for (auto get_msg : HandleGet(entry, msg))
+        msg_queue->push(get_msg);
+      break;
+    case kRGet:
+      grp = AddrGrp(msgg->dst());
+      entry = worker_shard_.at(Hash(grp, paramid));
+      HandleGetResponse(entry, msg);
+      break;
+    case kPut:
+      grp = AddrGrp(msgg->src());
+      entry = worker_shard_.at(Hash(grp, paramid));
+      for (auto put_msg : HandlePut(entry, msg))
+        msg_queue->push(put_msg);
+      break;
+    default:
+      LOG(ERROR) << "Unknown message type: " << type;
+      break;
+  }
+}
+
+void Trainer::GenMsgs(int type, int version, ParamEntry* entry,
+    Msg* msg, vector<Msg*> *ret) {
+  int src_grp = AddrGrp(msg->src());
+  int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group();
+  auto param = entry->shares.at(0);
+  for (int idx = 0; idx < param->num_slices(); idx++) {
+    int slice_id = param->slice_start() + idx;
+    int server = slice2server_[slice_id];
+    int procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer);
+    Msg* new_msg = nullptr;
+    if (type == kPut) {
+      CHECK_GT(entry->num_total, 0);
+      new_msg = param->GenPutMsg(procs != procs_id_, idx);
+      new_msg->AddFormatFrame("i", entry->num_total);
+    } else if (type == kGet) {
+      new_msg = param->GenGetMsg(procs != procs_id_, idx);
+    } else if (type == kUpdate) {
+      new_msg = param->GenUpdateMsg(procs != procs_id_, idx);
+      new_msg->AddFormatFrame("i", entry->num_local);
+    } else {
+      LOG(FATAL) << "Wrong type";
     }
+    new_msg->set_trgt(ParamTrgt(param->owner(), slice_id), version);
+    new_msg->set_src(Addr(src_grp, procs_id_, kStub));
+    new_msg->set_dst(Addr(dst_grp, server, kServer));
+    ret->push_back(new_msg);
   }
-  return replies;
 }
 
-const vector<Msg*> Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){
-  Msg* msgg=*msg ;
+const vector<Msg*> Trainer::HandleGet(ParamEntry* entry, Msg** msg) {
   vector<Msg*> ret;
-  int step= msgg->trgt_third();
-  if(msgg->src_flag()==kStub){
-    if(pi->num_update<pi->num_local){
-      ret.push_back(*msg);
-      return ret; //wait unitl local updates are ready
-    }
-    int n; sscanf((char*)(*msg)->frame_data(), "%d", &n);
-    pi->num_update+=n;
-    auto it=pi->shares.begin();
-    auto shape=mshadow::Shape1((*it)->size());
-    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
-    mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
-    agg+=grad;
-  }else if(++pi->num_update>=pi->num_local){
-    auto it=pi->shares.begin();
-    auto shape=mshadow::Shape1((*it)->size());
-    mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape);
-    for(++it;it!=pi->shares.end();it++){
-      mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
-      agg+=grad;
-    }
-    agg/=pi->num_total;
-    if(pi->num_local<pi->num_total){
-      /*
-      int gid=msgg->src_first();
-      for(auto update: pi->shares.at(0)->GenUpdateMsg(step)){
-        update->set_src(procs_id_, gid,kStub);
-        update->set_dst(pi->owner_procs, gid, kStub);
-        ret.push_back(update);
-      }
-      pi->num_update=0;
-      */
-    }
+  int version = (*msg)->trgt_version();
+  if (version > entry->next_version) {
+    entry->next_version = version;
+    GenMsgs(kGet, version, entry, *msg, &ret);
   }
-  if(pi->num_update==pi->num_total){
-    auto param=pi->shares.at(0);
-    int group=msgg->src_first()/Cluster::Get()->nworker_groups_per_server_group();
-    int srcgid=msgg->src_first();
-    for(int idx=0, id=param->slice_start(); idx<param->num_slices();idx++){
-      int server=slice2server_[idx+id];
-      int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
-      auto x=param->GenUpdateMsg(procs!=procs_id_, idx);
-      x->set_trgt(param->owner(), id+idx, step);
-      x->set_src(procs_id_, srcgid, kStub);
-      x->set_dst(group, server, kServer);
-      ret.push_back(x);
+  DeleteMsg(msg);
+  return ret;
+}
+
+const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) {
+  vector<Msg*> ret;
+  entry->num_update++;
+  if (entry->num_update >= entry->num_local) {
+    // average local gradient
+    if (entry->num_local > 1) {
+      auto it = entry->shares.begin();
+      auto shape=mshadow::Shape1((*it)->size());
+      mshadow::Tensor<mshadow::cpu,1> sum((*it)->mutable_cpu_grad(), shape);
+      for (++it; it != entry->shares.end(); it++) {
+        mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape);
+        sum += grad;
+      }
+      sum /= entry->num_total;
     }
-    pi->num_update=0;
+    int step = (*msg)->trgt_version();
+    GenMsgs(kUpdate, step, entry, *msg, &ret);
+    entry->num_update = 0;
   }
   DeleteMsg(msg);
   return ret;
 }
 
-const vector<Msg*> Trainer::HandlePut(shared_ptr<ParamInfo>pi, Msg** msg){
+const vector<Msg*> Trainer::HandlePut(ParamEntry* entry, Msg** msg) {
   vector<Msg*> ret;
-  CHECK_NE((*msg)->src_flag(), kStub);
-  int gid=(*msg)->src_first();
-  int version=(*msg)->trgt_third();
-  auto param=pi->shares.at(0);
-  int group=gid/Cluster::Get()->nworker_groups_per_server_group();
-  for(int idx=0, start=param->slice_start();idx<param->num_slices(); idx++){
-    int server=slice2server_[start+idx];
-    int procs=Cluster::Get()->ProcsIDOf(group, server, kServer);
-    auto x=param->GenPutMsg(procs!=procs_id_, idx);
-    x->set_trgt(param->owner(), start+idx, version);
-    x->set_src(procs_id_, gid, kStub);
-    x->set_dst(group, server, kServer);
-    ret.push_back(x);
-    //LOG(ERROR)<<"stub handle put "<<start+idx<<"to "<<group<<","<<server;
-  }
+  int version = (*msg)->trgt_version();
+  GenMsgs(kPut, version, entry, *msg, &ret);
   DeleteMsg(msg);
   return ret;
 }
 
-void Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){
-  int version=(*msg)->trgt_third();
-  int sliceid=(*msg)->trgt_second();
-  auto param=pi->shares.at(0);
-  if(param->ParseGetResponseMsg(msg,sliceid-param->slice_start()))
+void Trainer::HandleGetResponse(ParamEntry* entry, Msg** msg) {
+  int version = (*msg)->trgt_version();
+  int sliceid = SliceID((*msg)->trgt_val());
+  auto param = entry->shares.at(0);
+  if (param->ParseGetResponseMsg(*msg, sliceid-param->slice_start()))
     param->set_version(version);
-  // process get requests in waiting queue
+  DeleteMsg(msg);
 }
 
-
-void Trainer::HandleUpdateResponse(shared_ptr<ParamInfo> pi, Msg** msg){
-  int sliceid=(*msg)->trgt_second();
-  int version=(*msg)->trgt_third();
-  auto param=pi->shares.at(0);
-  if(param->ParseUpdateResponseMsg(msg,sliceid-param->slice_start())){
+void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) {
+  int version = (*msg)->trgt_version();
+  int sliceid = SliceID((*msg)->trgt_val());
+  auto param = entry->shares.at(0);
+  if (param->ParseUpdateResponseMsg(*msg, sliceid-param->slice_start()))
     param->set_version(version);
-  }
+  DeleteMsg(msg);
 }
 } /* singa */

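The stub loop in Trainer::Run above throttles sync reminders and inter-procs sends with the bandwidth() helper, which reports the transfer rate in bytes per second since a reference time point. A minimal standalone sketch of that idea follows; only bandwidth() mirrors the helper in trainer.cc, the numbers in main() are made up for illustration.

#include <chrono>
#include <cstdio>
#include <thread>

using std::chrono::duration_cast;
using std::chrono::milliseconds;
using std::chrono::system_clock;

// same formula as the helper in trainer.cc: bytes * 1000 / elapsed_ms
inline int bandwidth(int bytes, system_clock::time_point start) {
  auto ms = duration_cast<milliseconds>(system_clock::now() - start).count();
  return static_cast<int>(bytes * 1000.f / ms);
}

int main() {
  auto start = system_clock::now();
  std::this_thread::sleep_for(milliseconds(200));
  // ~1 MB transferred in ~0.2 s gives roughly 5 MB/s; the stub compares this
  // value against Cluster::bandwidth() before sending more traffic.
  std::printf("rate = %d bytes/s\n", bandwidth(1 << 20, start));
  return 0;
}
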
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
index 80a6283..bf98f0b 100644
--- a/src/trainer/worker.cc
+++ b/src/trainer/worker.cc
@@ -1,300 +1,328 @@
 #include <glog/logging.h>
 #include <thread>
-#include <memory>
-#include <iostream>
 #include <chrono>
 #include <thread>
 #include "utils/singleton.h"
+#include "utils/cluster.h"
 #include "utils/factory.h"
 #include "trainer/worker.h"
 #include "proto/model.pb.h"
+
 namespace singa {
 using std::thread;
-using std::make_shared;
 
-Worker::Worker(int thread_id, int group_id, int worker_id):
-  thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){
+Worker::Worker(int thread_id, int grp_id, int id):
+  thread_id_(thread_id), grp_id_(grp_id), id_(id),
+  layer_dealer_(nullptr), dealer_(nullptr), updater_(nullptr) {
 }
 
-void Worker::Setup(const ModelProto& model,
-    shared_ptr<NeuralNet> train_net){
-  train_net_=train_net;
-  modelproto_=model;
-  auto cluster=Cluster::Get();
-  if(!(cluster->nserver_groups()&&cluster->server_update())){
-    updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
-        ->Create("Updater"));
+void Worker::Setup(
+    const ModelProto& model, shared_ptr<NeuralNet> train_net,
+    shared_ptr<NeuralNet> valid_net, shared_ptr<NeuralNet> test_net) {
+  modelproto_.CopyFrom(model);
+  train_net_ = train_net;
+  validation_net_ = valid_net;
+  test_net_ = test_net;
+  auto cluster = Cluster::Get();
+  // create a local updater if there is no server, or the user requires workers to do the updates
+  if (!(cluster->nserver_groups() && cluster->server_update())) {
+    updater_ = Singleton<Factory<Updater>>::Instance()->Create("Updater");
     updater_->Init(model.updater());
   }
 }
 
-void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){
-  if(updater_==nullptr){
-    auto cluster=Cluster::Get();
-    int sgid=group_id_/cluster->nworker_groups_per_server_group();
-    CHECK(cluster->runtime()->JoinSGroup(group_id_, worker_id_, sgid));
+Worker::~Worker() {
+  if (updater_ != nullptr)
+    delete updater_;
+  if (layer_dealer_)
+    delete layer_dealer_;
+  if (dealer_)
+    delete dealer_;
+}
+
+void Worker::InitLocalParams() {
+  // for each server grp, its first subscriber worker grp does the param init
+  if (grp_id_ % Cluster::Get()->nworker_groups_per_server_group() == 0) {
+    for (auto layer : train_net_->layers()) {
+      if (layer->partition_id() == id_) {
+        for (auto param : layer->GetParams()) {
+          // only owners fill the memory of parameter values.
+          if(param->owner() == param->id())
+            param->InitValues(0);
+        }
+      }
+    }
+    Metric perf;
+    // run warmup training steps before putting params to the servers
+    for (; step_ < modelproto_.warmup_steps(); step_++)
+      TrainOneBatch(step_, &perf);
+    for (auto layer : train_net_->layers()) {
+      if (layer->partition_id() == id_)
+        for (auto param : layer->GetParams())
+          if (param->owner() == param->id())
+            Put(param, step_);
+    }
+  }
+  // wait for owners in the same procs to init params, so that no get requests are sent
+  std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+  for (auto layer : train_net_->layers()) {
+    if (layer->partition_id() == id_)
+      for (auto param : layer->GetParams())
+        if (param->owner() != param->id())
+          Get(param, modelproto_.warmup_steps());
   }
+}
 
+void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) {
   dealer->Connect(kInprocRouterEndpoint);
-  Msg* ping=new Msg();
-  ping->set_src(group_id_, worker_id_, type);
-  ping->set_dst(-1,-1,kStub);
+  Msg* ping = new Msg(Addr(grp, id, entity), Addr(-1, -1, kStub));
   ping->set_type(kConnect);
-  ping->add_frame("PING", 4);
   dealer->Send(&ping);
-  ping=dealer->Receive();
-  string pong((char*)ping->frame_data(), ping->frame_size());
-  CHECK_STREQ("PONG", pong.c_str());
-  delete ping;
 }
 
-void Worker::Run(){
-  LOG(ERROR)<<"Worker (group_id = "<<group_id_
-    <<", id = "<<worker_id_<<") starts";
-  dealer_=make_shared<Dealer>(2*thread_id_);
-  ConnectStub(dealer_, kWorkerParam);
-  for(auto layer: train_net_->layers())
-    if(layer->partition_id()==worker_id_)
-      if(layer->is_bridgedstlayer()||layer->is_bridgesrclayer()){
-        layer_dealer_=make_shared<Dealer>(2*thread_id_+1);
-        ConnectStub(layer_dealer_, kWorkerLayer);
+void Worker::Run() {
+  LOG(ERROR) << "Worker (group = " << grp_id_ << ", id = " << id_ << ") starts";
+  auto cluster = Cluster::Get();
+  if (updater_ == nullptr) {
+    int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
+    CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp));
+  }
+  dealer_ = new Dealer(2*thread_id_);
+  ConnectStub(grp_id_, id_, dealer_, kWorkerParam);
+  for (auto layer : train_net_->layers()) {
+    if (layer->partition_id() == id_) {
+      if (layer->is_bridgelayer()) {
+        layer_dealer_ = new Dealer(2*thread_id_+1);
+        ConnectStub(grp_id_, id_, layer_dealer_, kWorkerLayer);
         break;
       }
-  step_=modelproto_.step();
-  // init params
-  for(auto layer: train_net_->layers()){
-    if(layer->partition_id()==worker_id_)
-      for(auto param: layer->GetParams()){
-        // only owners fill the memory of parameter values.
-        // others share the memory with owners hence do not need to put/get.
-        if(param->owner() == param->id()){
-          if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0)
-            param->InitValues(0);
-          else{
-            Get(param, modelproto_.warmup_steps());
-          }
-        }
-      }
+    }
   }
+
+  step_ = modelproto_.step();
+  InitLocalParams();
   Metric perf;
-  if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0){
-    for(step_=0;step_<modelproto_.warmup_steps();step_++)
-      RunOneBatch(step_, &perf);
-    for(auto layer: train_net_->layers()){
-      if(layer->partition_id()==worker_id_)
-        for(auto param: layer->GetParams())
-          if(param->owner()==param->id())
-            Put(param, step_);
+  while (!StopNow(step_)) {
+    if (ValidateNow(step_)) {
+      //LOG(ERROR)<<"Validation at step "<<step;
+      CollectAll(validation_net_, step_);
+      Test(modelproto_.validation_steps(), kValidation, validation_net_);
+    }
+    if (TestNow(step_)) {
+      //LOG(ERROR)<<"Test at step "<<step;
+      CollectAll(test_net_, step_);
+      Test(modelproto_.test_steps(), kTest, test_net_);
+    }
+    TrainOneBatch(step_, &perf);
+    //LOG(ERROR)<<"Train "<<step;
+    if (DisplayNow(step_)) {
+      Report("Train", perf);
+      perf.Reset();
     }
-  }
-  while(!StopNow(step_)){
-    RunOneBatch(step_, &perf);
     step_++;
   }
 
-  Stop();
-  LOG(ERROR)<<"Worker (group_id = "<<group_id_
-    <<", id = "<<worker_id_<<") stops";
-}
-
-void Worker::Stop(){
-  auto cluster=Cluster::Get();
-  if(updater_ == nullptr){
-    int sgid=group_id_/cluster->nworker_groups_per_server_group();
-    cluster->runtime()->LeaveSGroup(group_id_, worker_id_, sgid);
+  // clean up
+  if(updater_ == nullptr) {
+    int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group();
+    cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp);
   }
-  Msg* msg=new Msg();
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(-1,-1, kStub);
+  // notify the stub on worker stop
+  Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1,-1, kStub));
   msg->set_type(kStop);
-  dealer_->Send(&msg); // use param dealer to send the stop msg
+  dealer_->Send(&msg);  // use param dealer to send the stop msg
+
+  LOG(ERROR) << "Worker (group = " << grp_id_ << ", id = " << id_ << ") stops";
+}
+
+void Worker::Resume() {
+  // TODO(wangwei)
 }
-int Worker::Put(Param* param, int step){
-  Msg* msg=new Msg();
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(-1, -1, kStub);
+
+int Worker::Put(Param* param, int step) {
+  Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
   msg->set_type(kPut);
-  msg->set_trgt(param->owner(), 0, step);
   dealer_->Send(&msg);
   return 1;
 }
-int Worker::Get(Param* param, int step){
-  Msg* msg=new Msg();
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(-1, -1, kStub);
+
+int Worker::Get(Param* param, int step) {
+  if (param->version() >= step)
+    return 1;
+  Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(ParamTrgt(param->owner(), 0), step);
   msg->set_type(kGet);
-  msg->set_trgt(param->owner(), 0, step);
   dealer_->Send(&msg);
   return 1;
 }
-int Worker::Update(Param* param, int step){
+
+int Worker::Update(Param* param, int step) {
   param->set_local_version(param->version());
-  if(updater_){
+  if (updater_) {
     updater_->Update(step, param);
-    param->set_version(param->version()+1);
-  }else{
-    Msg* msg=new Msg();
-    msg->set_src(group_id_, worker_id_, kWorkerParam);
-    msg->set_dst(-1, -1, kStub);
+    param->set_version(param->version() + 1);
+  } else {
+    Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+    msg->set_trgt(ParamTrgt(param->owner(), 0), step);
     msg->set_type(kUpdate);
-    msg->set_trgt(param->owner(), 0, step);
     dealer_->Send(&msg);
   }
   return 1;
 }
 
-int Worker::CollectAll(shared_ptr<NeuralNet> net, int step){
-  auto& layers=net->layers();
-  for(auto& layer: layers){
-    if(layer->partition_id()==worker_id_)
-      for(Param* p: layer->GetParams()){
+int Worker::CollectAll(shared_ptr<NeuralNet> net, int step) {
+  auto& layers = net->layers();
+  for (auto& layer : layers){
+    if (layer->partition_id() == id_)
+      for (Param* p: layer->GetParams()) {
         Collect(p, step);
       }
   }
   return 1;
 }
-int Worker::Collect(Param* param, int step){
-  while(param->version()<=param->local_version()){
+int Worker::Collect(Param* param, int step) {
+  while (param->version() <= param->local_version())
     std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime));
-  }
   return 1;
 }
-void Worker::DisplayPerformance(const string& prefix, const Metric & perf) {
-  Msg* msg=new Msg();
-  msg->set_src(group_id_, worker_id_, kWorkerParam);
-  msg->set_dst(-1,-1, kStub);
+void Worker::Report(const string& prefix, const Metric & perf) {
+  Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub));
+  msg->set_trgt(0, step_);
   msg->set_type(kMetric);
-  msg->set_trgt(step_,0,0);
-  msg->add_frame(prefix.c_str(), prefix.length());
   const string disp = perf.ToString();
-  msg->add_frame(disp.c_str(), disp.length());
+  msg->AddFormatFrame("s", prefix.c_str());
+  msg->AddFrame(disp.c_str(), disp.length());
   dealer_->Send(&msg);
 }
 
-void Worker::RunOneBatch(int step, Metric* perf){
-  if(ValidateNow(step)){
-    //LOG(ERROR)<<"Validation at step "<<step;
-    CollectAll(validation_net_, step);
-    Test(modelproto_.validation_steps(),kValidation, validation_net_);
-  }
-  if(TestNow(step)){
-    //LOG(ERROR)<<"Test at step "<<step;
-    CollectAll(test_net_, step);
-    Test(modelproto_.test_steps(), kTest, test_net_);
+void Worker::ReceiveBlobs(
+    bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net) {
+  while (!layer->ready()) {
+    auto msg = layer_dealer_->Receive();
+    CHECK_EQ(AddrGrp(msg->src()), grp_id_);
+    string name(static_cast<char*>(msg->FrameData()), msg->FrameSize());
+    auto receive_layer = net->name2layer(name);
+    CHECK(receive_layer->is_bridgelayer());
+    auto blob = receive_layer->mutable_data(nullptr);  // avoid shadowing the bool data flag
+    msg->NextFrame();
+    memcpy(blob->mutable_cpu_data(), msg->FrameData(), msg->FrameSize());
+    static_cast<BridgeLayer*>(receive_layer)->set_ready(true);
+    delete msg;
   }
-  TrainOneBatch(step, perf);
-  //LOG(ERROR)<<"Train "<<step;
-  if(perf!=nullptr){
-    if(DisplayNow(step)){
-      DisplayPerformance("Train", *perf);
-      perf->Reset();
-    }
-  }
-  /*
-  if(CheckpointNow(step)){
-    pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step));
-  }
-  */
-}
-
-void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){
 }
 
-void Worker::SendBlob(){
+void Worker::SendBlobs(
+    bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net) {
+  auto dst = layer->dstlayers().at(0);
+  Msg *msg = new Msg();
+  msg->set_src(Addr(grp_id_, id_, kWorkerLayer));
+  msg->set_dst(Addr(grp_id_, dst->partition_id(), kWorkerLayer));
+  msg->AddFrame(dst->name().c_str(), dst->name().length());
+  const auto& blob = layer->data(nullptr);
+  msg->AddFrame(blob.cpu_data(), blob.count() * sizeof(float));
+  layer_dealer_->Send(&msg);
 }
 
-void Worker::Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net){
+void Worker::Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net) {
   Metric perf;
-  for(int step=0;step<nsteps;step++){
+  for (int step = 0; step < nsteps; step++)
     TestOneBatch(step, phase, net, &perf);
-  }
-  //perf.Avg();
-  if(phase==kValidation)
-    DisplayPerformance("Validation", perf);
-  else if (phase==kTest)
-    DisplayPerformance("Test", perf);
+  if (phase == kValidation)
+    Report("Validation", perf);
+  else if (phase == kTest)
+    Report("Test", perf);
+}
+bool Worker::DisplayNow(int step) const {
+  return (modelproto_.display_frequency() > 0
+      && step >= modelproto_.display_after_steps()
+      && ((step - modelproto_.display_after_steps())
+        % modelproto_.display_frequency() == 0));
 }
 
-/****************************BPWorker**********************************/
+bool Worker::DisplayDebugInfo(int step) const {
+  return DisplayNow(step) && modelproto_.debug() && grp_id_ == 0;
+}
+bool Worker::StopNow(int step) const {
+  return step >= modelproto_.train_steps();
+}
+bool Worker::CheckpointNow(int step) const {
+  return (grp_id_ == 0
+      && modelproto_.checkpoint_frequency() > 0
+      && step >= modelproto_.checkpoint_after_steps()
+      && ((step - modelproto_.checkpoint_after_steps())
+        % modelproto_.checkpoint_frequency() == 0));
+}
+bool Worker::TestNow(const int step) const {
+  return (grp_id_ == 0
+      && modelproto_.test_frequency() > 0
+      && modelproto_.test_steps() > 0
+      && step >= modelproto_.test_after_steps()
+      && ((step - modelproto_.test_after_steps())
+        % modelproto_.test_frequency() == 0));
+}
+bool Worker::ValidateNow(const int step) const {
+  return (grp_id_ == 0
+      && modelproto_.validation_frequency() > 0
+      && modelproto_.validation_steps() > 0
+      && step >= modelproto_.validation_after_steps()
+      && ((step - modelproto_.validation_after_steps())
+        % modelproto_.validation_frequency() == 0));
+}
 
+
+/****************************BPWorker**********************************/
 BPWorker::BPWorker(int thread_id, int group_id, int worker_id):
-  Worker(thread_id, group_id, worker_id){
+  Worker(thread_id, group_id, worker_id) {
 }
 
-void BPWorker::Forward(int step, Phase phase, shared_ptr<NeuralNet> net,
-    Metric* perf){
-  auto& layers=net->layers();
-  for(auto& layer: layers){
-    if(layer->partition_id()==worker_id_){
-      if(layer->is_bridgedstlayer()){
-        auto* dst=static_cast<BridgeDstLayer*>(layer);
-        while(!dst->ready()){
-          auto msg=layer_dealer_->Receive();
-          CHECK_EQ(msg->src_first(), group_id_);
-          string name((char*)msg->frame_data(), msg->frame_size());
-          auto tmp=net->name2layer(name);
-          CHECK(tmp->is_bridgedstlayer());
-          auto* dstlayer=static_cast<BridgeDstLayer*>(tmp);
-          auto data=dstlayer->mutable_data(nullptr);
-          msg->next_frame();
-          memcpy(data->mutable_cpu_data(), msg->frame_data(), msg->frame_size());
-          dstlayer->set_ready(true);
-          delete msg;
-        }
-      }
-      if(phase==kTrain){
-        for(Param* p: layer->GetParams()){
+void BPWorker::Forward(
+    int step, Phase phase, shared_ptr<NeuralNet> net, Metric* perf) {
+  for (auto& layer : net->layers()) {
+    if (layer->partition_id() == id_) {
+      if (layer->is_bridgedstlayer())  // recv data from other workers
+        ReceiveBlobs(true, false, static_cast<BridgeLayer*>(layer), net);
+      if (phase == kTrain) {
+        for (Param* p : layer->GetParams()) {  // wait until param is updated
           Collect(p, step);
         }
       }
-      //clock_t s=clock();
       layer->ComputeFeature(phase, perf);
-      //LOG(ERROR)<<layer->name()<<":"<<(clock()-s)*1.0/CLOCKS_PER_SEC;
-      if(layer->is_bridgesrclayer()){
-        auto dst=layer->dstlayers().at(0);
-        Msg *msg=new Msg();
-        msg->set_src(group_id_, worker_id_, kWorkerLayer);
-        msg->set_dst(group_id_, dst->partition_id(), kWorkerLayer);
-        msg->add_frame(dst->name().c_str(), dst->name().length());
-        auto const & blob=layer->data(nullptr);
-        msg->add_frame(blob.cpu_data(), blob.count()*sizeof(float));
-        layer_dealer_->Send(&msg);
-      }
-      if(phase == kTrain && DisplayDebugInfo(step))
+      if (layer->is_bridgesrclayer())  // send data to other workers
+        SendBlobs(true, false, static_cast<BridgeLayer*>(layer), net);
+      if (DisplayDebugInfo(step))
         LOG(INFO) << layer->DebugString(step, kForward);
     }
   }
 }
 
-void BPWorker::Backward(int step, shared_ptr<NeuralNet> net){
+void BPWorker::Backward(int step, shared_ptr<NeuralNet> net) {
   auto& layers=net->layers();
   for (auto it = layers.rbegin(); it != layers.rend(); it++){
-    Layer* layer=*it;
-    if(layer->partition_id()==worker_id_){
-      if(layer->is_bridgesrclayer()){
-        //auto* src=static_cast<BridgeSrcLayer*>(layer.get());
-        // receive grad blobs
+    Layer* layer = *it;
+    if (layer->partition_id() == id_) {
+      if(layer->is_bridgesrclayer()) {
+        // ReceiveBlobs(false, true, layer, net);
       }
       layer->ComputeGradient(kTrain);
-      if(DisplayDebugInfo(step))
+      if (DisplayDebugInfo(step))
         LOG(INFO) << layer->DebugString(step, kBackward);
-      for(Param* p: layer->GetParams())
+      for (Param* p : layer->GetParams())
         Update(p, step);
-      if(layer->is_bridgedstlayer()){
-        // send grad blobs
+      if (layer->is_bridgedstlayer()) {
+        // SendBlobs(false, true, layer);
       }
     }
   }
 }
 
-void BPWorker::TrainOneBatch(int step, Metric* perf){
+void BPWorker::TrainOneBatch(int step, Metric* perf) {
   Forward(step, kTrain, train_net_, perf);
   Backward(step, train_net_);
-  auto losslayers=train_net_->losslayers();
 }
 
 void BPWorker::TestOneBatch(int step, Phase phase,
-    shared_ptr<NeuralNet> net, Metric* perf){
+    shared_ptr<NeuralNet> net, Metric* perf) {
   Forward(step, phase, net, perf);
 }
 

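The Worker::DisplayNow/TestNow/ValidateNow/CheckpointNow predicates above share one trigger pattern: fire every frequency steps once step has reached the corresponding *_after_steps value. A minimal sketch of that pattern, assuming a hypothetical TriggerNow helper (not part of the code base) and made-up frequency values:

#include <cassert>

// fires every freq steps once step has reached after (freq <= 0 disables it)
inline bool TriggerNow(int step, int after, int freq) {
  return freq > 0 && step >= after && (step - after) % freq == 0;
}

int main() {
  // e.g. display_after_steps = 100, display_frequency = 50
  assert(!TriggerNow(99, 100, 50));   // before the after-steps threshold
  assert(TriggerNow(100, 100, 50));   // first trigger step
  assert(!TriggerNow(120, 100, 50));
  assert(TriggerNow(150, 100, 50));   // every 50 steps afterwards
  return 0;
}
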
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
index 0c4eefa..9c57c42 100644
--- a/src/utils/cluster.cc
+++ b/src/utils/cluster.cc
@@ -56,13 +56,14 @@ Cluster::Cluster(const GlobalProto & global, const ClusterProto &cluster,
   hostip_=GetHostIP();
 }
 
-void Cluster::Register(const string& endpoint){
+void Cluster::Register(const string& endpoint) {
   procs_id_=cluster_rt_->RegistProc(endpoint);
   CHECK_GE(procs_id_,0);
   CHECK_LT(procs_id_,nprocs());
   LOG(ERROR) << "proc #" << procs_id_ << " -> " << endpoint;
 }
-const string Cluster::endpoint(int procsid) const{
+
+const string Cluster::endpoint(int procsid) const {
   CHECK_LT(procsid, nprocs());
   CHECK_GE(procsid, 0);
   if(endpoints_.size())
@@ -70,6 +71,7 @@ const string Cluster::endpoint(int procsid) const{
   else
     return cluster_rt_->GetProcHost(procsid);
 }
+
 void Cluster::SetupFolders(const ClusterProto &cluster){
   // create visulization folder
   mkdir(vis_folder().c_str(),  S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
index 11a19f8..f733497 100644
--- a/src/utils/common.cc
+++ b/src/utils/common.cc
@@ -160,6 +160,10 @@ void SetupLog(const std::string& log_dir, const std::string& model) {
   google::SetLogDestination(google::FATAL, fatal.c_str());
 }
 
+Metric::Metric(const std::string& str) {
+  ParseFrom(str);
+}
+
 void Metric::Add(const string& name, float value) {
   if(entry_.find(name) == entry_.end())
     entry_[name] = std::make_pair(1, value);
@@ -176,7 +180,7 @@ void Metric::Reset() {
     e.second.second = 0;
   }
 }
-const string Metric::ToLogString() const{
+const string Metric::ToLogString() const {
   string ret;
   size_t k = 0;
   for(auto e : entry_) {
@@ -188,7 +192,7 @@ const string Metric::ToLogString() const{
   return ret;
 }
 
-const string Metric::ToString() const{
+const string Metric::ToString() const {
   MetricProto proto;
   for(auto e : entry_) {
     proto.add_name(e.first);
@@ -208,4 +212,89 @@ void Metric::ParseFrom(const string& msg) {
     entry_[proto.name(i)] = std::make_pair(proto.count(i), proto.val(i));
   }
 }
+
+
+const vector<vector<int>> Slice(int num, const vector<int>& sizes) {
+  vector<vector<int>> slices;
+  if (num == 0)
+    return slices;
+  int avg = 0;
+  for (int x : sizes)
+    avg += x;
+  avg = avg / num + avg % num;
+  int diff = avg / 10;
+  LOG(INFO) << "Slicer, param avg=" << avg << ", diff= " << diff;
+
+  int capacity = avg, nbox = 0;
+  for (int x : sizes) {
+    vector<int> slice;
+    string slicestr = "";
+    while (x > 0) {
+      int size = 0;
+      if (capacity >= x) {
+        capacity -= x;
+        size = x;
+        x = 0;
+      } else if (capacity + diff >= x) {
+        size = x;
+        x = 0;
+        capacity = 0;
+      } else if (capacity >= diff) {
+        x -= capacity;
+        size = capacity;
+        capacity = avg;
+        nbox++;
+      } else {
+        capacity = avg;
+        nbox++;
+      }
+      if (size) {
+        slice.push_back(size);
+        slicestr += ", " + std::to_string(size);
+      }
+    }
+    LOG(INFO) << slicestr;
+    slices.push_back(slice);
+  }
+  CHECK_LE(nbox, num);
+  return slices;
+}
+
+const vector<int> PartitionSlices(int num, const vector<int>& slices) {
+  vector<int> slice2box;
+  if (num == 0)
+    return slice2box;
+  int avg = 0;
+  for (int x : slices)
+    avg += x;
+  avg = avg / num + avg % num;
+  int box = avg, boxid = 0, diff = avg / 10;
+  for (auto it = slices.begin(); it != slices.end();) {
+    int x = *it;
+    if (box >= x) {
+      box -= x;
+      slice2box.push_back(boxid);
+      it++;
+    } else if (box + diff >= x) {
+      slice2box.push_back(boxid);
+      it++;
+      box = 0;
+    } else {
+      box = avg;
+      boxid++;
+    }
+  }
+  CHECK_EQ(slice2box.size(), slices.size());
+  int previd = -1;
+  std::string disp;
+  for (size_t i = 0; i < slice2box.size(); i++) {
+    if (previd != slice2box[i]) {
+      previd = slice2box[i];
+      disp += " box = " +std::to_string(previd) + ":";
+    }
+    disp += " " + std::to_string(slices[i]);
+  }
+  LOG(INFO) << "partition slice (avg =" << avg << ", num="<<num<<"):" << disp;
+  return slice2box;
+}
 }  // namespace singa
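The two helpers added above implement the server-side load-balancing: Slice(num, sizes) cuts each Param's element count into pieces of roughly sum/num elements (with about 10% slack so a Param is not split for a tiny remainder), and PartitionSlices(num, slices) packs consecutive slices into num boxes, i.e. servers or server groups, using the same greedy rule. The following is a simplified, self-contained sketch of the cutting step for intuition only, not the exact SINGA code; CutIntoPieces and the example sizes are made up.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Chop each parameter size into pieces of at most ~avg elements so that the
// pieces can be spread evenly over `num` servers (or server groups).
std::vector<int> CutIntoPieces(int num, const std::vector<int>& sizes) {
  int avg = std::accumulate(sizes.begin(), sizes.end(), 0) / num + 1;
  std::vector<int> pieces;
  for (int x : sizes)
    while (x > 0) {
      int piece = std::min(x, avg);
      pieces.push_back(piece);
      x -= piece;
    }
  return pieces;
}

int main() {
  // e.g. two Params of 1000 and 600 elements, 4 boxes
  for (int p : CutIntoPieces(4, {1000, 600}))
    std::cout << p << " ";          // prints: 401 401 198 401 199
  std::cout << "\n";
  return 0;
}

The real Slice() keeps a running per-box capacity with a diff = avg/10 overflow allowance before opening a new box, and PartitionSlices() applies the same rule to map slice ids to box ids, logging the resulting mapping via LOG(INFO).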

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 24a0541..8b1f113 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -11,8 +11,8 @@ using std::vector;
 using std::string;
 namespace singa {
 
-Param::Param():data_(nullptr), slice_start_(0), num_slices_(0),
-  num_pending_requests_(0),local_version_(-1){
+Param::Param():local_version_(-1), slice_start_(0), num_slices_(0),
+  num_pending_requests_(0), data_(nullptr) {
 }
 void Param::Setup(const ParamProto& proto, const vector<int>& shape){
   data_=std::make_shared<Blob<float>>(shape);
@@ -82,96 +82,87 @@ void Param::InitValues(int version){
 }
 
 /**************Message related functions********/
-Msg* Param::GenPutMsg(bool copy, int idx){
+Msg* Param::GenPutMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kPut);
-  char buf[128];
-  sprintf(buf, "%d %f %f", slice_size_[idx],
-      learning_rate_multiplier(), weight_decay_multiplier());
   void *ptr=mutable_cpu_data()+slice_offset_[idx];
-  if(copy){
-    sprintf(buf+strlen(buf), " %p ", nullptr);
-    msg->add_frame(buf, strlen(buf));
-    msg->add_frame(ptr, slice_size_[idx]*sizeof(float));
-  }else{
-    sprintf(buf+strlen(buf), " %p ", ptr);
-    msg->add_frame(buf, strlen(buf));
+  void *p = ptr;
+  if (copy) p = nullptr;
+  msg->AddFormatFrame("iffp", slice_size_[idx],
+      learning_rate_multiplier(), weight_decay_multiplier(), p);
+  if (copy) {
+    msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
   }
   //pending_put_[idx]=true;
   //num_pending_requests_++;
        return msg;
 }
 
-Msg* Param::GenGetMsg(bool copy, int idx){
+Msg* Param::GenGetMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kGet);
-  char buf[32]; sprintf(buf, " %d %p ", copy,
-      data_->cpu_data()+slice_offset_[idx]);
-  msg->add_frame(buf, sizeof(buf));
+  msg->AddFormatFrame("ip",  copy, data_->cpu_data()+slice_offset_[idx]);
   pending_get_[idx]=true;
   num_pending_requests_++;
   return msg;
 }
 
-Msg* Param::GenUpdateMsg(bool copy, int idx){
+Msg* Param::GenUpdateMsg(bool copy, int idx) {
   CHECK_LT(idx, num_slices_);
   Msg* msg=new Msg();
   msg->set_type(kUpdate);
-  char buf[8]; sprintf(buf, " %d ", copy);
-  msg->add_frame(buf, sizeof(buf));
+  msg->AddFormatFrame("i", copy);
   void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx];
   if(copy){
     //LOG(ERROR)<<"Copy in gen update";
-    msg->add_frame(ptr, slice_size_[idx]*sizeof(float));
-  }
-  else{ // to share values of grad blob
-    char buf[32]; sprintf(buf, " %p ", ptr);
-    msg->add_frame(buf, strlen(buf));
+    msg->AddFrame(ptr, slice_size_[idx]*sizeof(float));
+  } else { // to share values of grad blob
+    msg->AddFormatFrame("p", ptr);
   }
   pending_update_[idx]=true;
   num_pending_requests_++;
   return msg;
 }
 
-Msg* Param::GenSyncMsg(int offset, int size){
+Msg* Param::GenSyncMsg(int offset, int size) {
   Msg* msg=new Msg();
   msg->set_type(kSyncRequest);
-  msg->set_trgt(-1, id(), local_version());
-  msg->add_frame(mutable_cpu_data(), data_->count()*sizeof(float));
+  msg->set_trgt(ParamTrgt(-1, id()), local_version());
+  // always copy data because sync is between server groups in different procs
+  msg->AddFrame(mutable_cpu_data(), data_->count()*sizeof(float));
   return msg;
 }
 
-Msg* Param::HandlePutMsg(Msg** msg){
+Msg* Param::HandlePutMsg(Msg** msg, bool reserve) {
   int size;
   float lr, wc;
   float* ptr;
-  sscanf(static_cast<char*>((*msg)->frame_data()),
-      "%d %f %f %p ", &size, &lr, &wc, &ptr);
+  (*msg)->ParseFormatFrame("iffp", &size, &lr, &wc, &ptr);
   proto_.set_learning_rate_multiplier(lr);
   proto_.set_weight_decay_multiplier(wc);
   vector<int> shape{size};
   ParamProto proto;
   Setup(proto, shape);
-  if(ptr==nullptr){
-    CHECK((*msg)->next_frame());
-    CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
-    memcpy(mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float));
+  if (ptr == nullptr) {
+    CHECK((*msg)->NextFrame());
+    CHECK_EQ(size* sizeof(float), (*msg)->FrameSize());
+    memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float));
   }else{
     data_->set_cpu_data(ptr);
   }
-  DeleteMsg(msg);
+  if (!reserve)
+    DeleteMsg(msg);
   return nullptr;
 }
 
-Msg* Param::HandleGetMsg(Msg** msg){
+Msg* Param::HandleGetMsg(Msg** msg, bool reserve) {
   int copy;
   float* ptr;
-  sscanf(static_cast<char*>((*msg)->frame_data()), " %d %p ", &copy, &ptr);
-  (*msg)->next_frame();
+  (*msg)->ParseFormatFrame("ip", &copy, &ptr);
   if(copy)
-    (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size());
+    (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size());
   else if(ptr!=data_->cpu_data()){
     memcpy(ptr, data_->cpu_data(), sizeof(float)*size());
     data_->set_cpu_data(ptr);
@@ -182,73 +173,127 @@ Msg* Param::HandleGetMsg(Msg** msg){
   return *msg;
 }
 
-int Param::ParseUpdateMsg(Msg** msg){
-  int copy;
-  sscanf(static_cast<char*>((*msg)->frame_data()), " %d ", &copy);
-  (*msg)->next_frame();
-  if(copy){
-    //LOG(ERROR)<<"Copy in parse update";
-    CHECK((*msg)->frame_size());
-    memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
-  }else {// use the same data field of the grad blob
-    float* ptr=nullptr;
-    sscanf(static_cast<char*>((*msg)->frame_data()), " %p ", &ptr);
-    grad_.set_cpu_data(ptr);
+void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
+  bool reset = true;
+  vector<int> copies;
+  for (auto *msg : msgs) {
+    int copy;
+    msg->ParseFormatFrame("i", &copy);
+    reset = reset && copy;
+    copies.push_back(copy);
+  }
+  int idx = 0;
+  for (auto *msg : msgs) {
+    CHECK(msg->NextFrame());
+    if (copies.at(idx++)) {
+      float* server_grad = mutable_cpu_grad();
+      float* worker_grad = static_cast<float*> (msg->FrameData());
+      if (reset) {
+        memcpy(server_grad, worker_grad, sizeof(float) * size());
+        reset = false;
+      } else {
+        for (int i = 0; i < size(); i++)
+          server_grad[i] += worker_grad[i];
+      }
+    } else {
+      float* ptr = nullptr;
+      msg->ParseFormatFrame("p", &ptr);
+      if (grad_.cpu_data() != ptr) {
+        memcpy(ptr, grad_.cpu_data(), msg->FrameSize());
+        grad_.set_cpu_data(ptr);
+      }
+    }
   }
-  DeleteMsg(msg);
-  return copy;
-}
 
-Msg* Param::GenUpdateResponseMsg(bool copy){
-  Msg* msg=new Msg();
-  msg->set_type(kRUpdate);
-  char buf[8]; sprintf(buf, " %d ", copy);
-  msg->add_frame(buf, sizeof(buf));
-  if(copy){
-    //LOG(ERROR)<<"Copy in gen";
-  //  LOG(ERROR)<<"gen copy resonse for "<<id()<<", "<<size();
-    msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+  if (msgs.size() > 1) {
+    float* server_grad = mutable_cpu_grad();
+    for (int i = 0; i < size(); i++)
+      server_grad[i] /= msgs.size();
   }
-  //  LOG(ERROR)<<"gen share resonse for "<<id()<<", "<<size();
+}
 
-  return msg;
+const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
+  vector<Msg*> ret;
+  for (auto msg : msgs) {
+    msg->FirstFrame();
+    msg->SwapAddr();
+    msg->set_type(kRUpdate);
+    int copy;
+    msg->ParseFormatFrame("i", &copy);
+    if (copy) {
+      msg->NextFrame();
+      CHECK_EQ(msg->FrameSize(), sizeof(float) * size());
+      memcpy(msg->FrameData(), mutable_cpu_data(), msg->FrameSize());
+    }
+    ret.push_back(msg);
+  }
+  return ret;
 }
 
-Msg* Param::HandleSyncMsg(Msg** msg){
-  DeleteMsg(msg);
+Msg* Param::HandleSyncMsg(Msg** msg, bool reserve) {
+  if (!reserve)
+    DeleteMsg(msg);
   return nullptr;
 }
 
-int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){
-  DeleteMsg(msg);
+int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
   return 1;
 }
 
-int Param::ParseGetResponseMsg(Msg **msg, int slice_idx){
+int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_get_[slice_idx], true);
   pending_get_[slice_idx]=false;
   ParseResponseMsg(msg, slice_idx);
   return (--num_pending_requests_)%num_slices_==0;
 }
 
-int Param::ParseUpdateResponseMsg(Msg **msg, int slice_idx){
+int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_update_[slice_idx], true);
   pending_update_[slice_idx]=false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_)%num_slices_==0;
+  return (--num_pending_requests_) % num_slices_ == 0;
 }
 
-void Param::ParseResponseMsg(Msg** msg, int slice_idx){
+void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
   int copy;
-  sscanf(static_cast<char*>((*msg)->frame_data()), " %d ", &copy);
-  (*msg)->next_frame();
-  if(copy){
-        CHECK_EQ((*msg)->frame_size(), slice_size_[slice_idx]*sizeof(float));
+  msg->ParseFormatFrame("i", &copy);
+  msg->NextFrame();
+  if (copy) {
+    CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
     memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
-        (*msg)->frame_data(), (*msg)->frame_size());
+        msg->FrameData(), msg->FrameSize());
   }
   //LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
-  DeleteMsg(msg);
+}
+
+void Param::ShareFrom(const Param& other) {
+  proto_.set_owner(other.owner());
+  if (data_ != nullptr)
+    CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
+          other.data_->shape().begin()));
+  data_ = other.data_;
+  slice_offset_ = other.slice_offset_;
+  slice_size_ = other.slice_size_;
+  slice_start_ = other.slice_start_;
+  num_slices_ = other.num_slices_;
+  pending_get_ = other.pending_get_;
+  pending_put_ = other.pending_put_;
+  pending_update_ = other.pending_update_;
+}
+
+/************************ParamEntry***************************/
+ParamEntry::ParamEntry():
+  num_update(0), next_version(-1), num_local(0), num_total(0) {
+}
+
+ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
+  shares.push_back(p);
+}
+void ParamEntry::AddParam(bool local, Param* p) {
+  num_local += local;
+  num_total += 1;
+  if (local)
+    shares.push_back(p);
 }
 }
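Two changes in the rewritten param.cc are worth spelling out. First, the hand-rolled sprintf/sscanf frame encoding is replaced by typed AddFormatFrame()/ParseFormatFrame() calls, which keep sender and receiver agreed on the frame layout. Second, Param::ParseUpdateMsgs() now aggregates update requests from several local worker groups before a single Update: when all requests carry copied gradients, the first one overwrites the server-side gradient buffer, the rest are accumulated element-wise, and the sum is divided by the number of messages, so the server effectively updates with the averaged gradient. A tiny sketch of that reduction on plain float vectors (not the Msg-based SINGA code; AverageGrads is a made-up name):

#include <cstdio>
#include <vector>

// Overwrite with the first gradient, accumulate the rest, divide by the
// number of contributors -- the same reduction ParseUpdateMsgs() performs
// on the frames of the queued update messages.
std::vector<float> AverageGrads(const std::vector<std::vector<float>>& grads) {
  std::vector<float> acc = grads.front();
  for (size_t k = 1; k < grads.size(); ++k)
    for (size_t i = 0; i < acc.size(); ++i)
      acc[i] += grads[k][i];
  for (float& v : acc)
    v /= grads.size();
  return acc;
}

int main() {
  auto avg = AverageGrads({{1.f, 2.f}, {3.f, 4.f}});  // two worker groups
  std::printf("%.1f %.1f\n", avg[0], avg[1]);         // prints: 2.0 3.0
  return 0;
}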
 
