Repository: incubator-singa Updated Branches: refs/heads/master 7954a87d2 -> 96bedb226
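For readers skimming this patch: the core refactoring in trainer.cc and common.cc below slices every Param into LCM(nserver_groups, nservers_per_group) pieces (SliceParams/Slice) and then assigns the slices greedily to boxes of roughly average capacity (PartitionSlices), once across server groups and once across servers within a group. The sketch below is a simplified standalone illustration of that greedy assignment, not SINGA code: GreedyPartition and the example sizes are invented for illustration, and the real PartitionSlices additionally tolerates roughly 10% overflow per box before opening a new one.

  // Simplified sketch (assumed, not SINGA's PartitionSlices) of greedy slice
  // assignment: pack slice sizes into `num` boxes of roughly average capacity
  // and record, for each slice, the box (server or server group) it lands in.
  #include <iostream>
  #include <numeric>
  #include <vector>

  std::vector<int> GreedyPartition(int num, const std::vector<int>& sizes) {
    int total = std::accumulate(sizes.begin(), sizes.end(), 0);
    int avg = total / num + total % num;  // per-box capacity, as in the patch
    std::vector<int> slice2box;
    int box = avg, boxid = 0;
    for (int x : sizes) {
      if (box < x && boxid + 1 < num) {   // current box is full, open the next
        box = avg;
        ++boxid;
      }
      box -= x;
      slice2box.push_back(boxid);
    }
    return slice2box;
  }

  int main() {
    // hypothetical slice lengths, e.g. produced by Slice(lcm, paramsize)
    std::vector<int> slices = {400, 200, 300, 300, 200, 400};
    for (int box : GreedyPartition(3, slices))  // 3 servers per group (assumed)
      std::cout << box << ' ';
    std::cout << '\n';  // prints: 0 0 1 1 2 2, i.e. 600 floats per server
    return 0;
  }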
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index a6a5dbf..3ecaad0 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -1,9 +1,10 @@ #include <thread> #include <vector> #include <map> -#include <queue> #include <chrono> #include <glog/logging.h> +#include "utils/cluster.h" +#include "utils/common.h" #include "proto/common.pb.h" #include "trainer/trainer.h" #include "mshadow/tensor.h" @@ -11,587 +12,486 @@ namespace singa { using std::vector; using std::map; +using std::queue; using namespace std::chrono; using std::make_shared; -typedef std::chrono::milliseconds TimeT; +/***********************Trainer****************************/ +Trainer::~Trainer() { + // free Params (i.e., slices) in server shard + for (auto entry : server_shard_) + for (auto param : entry.second->shares) + delete param; + delete router_; +} -void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ - // register all layers appearing in the neural net +void Trainer::RegisterDefaultClasses(const singa::ModelProto& model_conf) { + // register all implemented layers singa::NeuralNet::RegisterLayers(); - Singleton<Factory<singa::Param>>::Instance()->Register( - "Param", CreateInstance(singa::Param, singa::Param)); - Singleton<Factory<singa::Updater>>::Instance() ->Register( - "Updater", CreateInstance(singa::SGDUpdater, singa::Updater)); + auto param_factory = Singleton<Factory<singa::Param>>::Instance(); + param_factory->Register("Param", CreateInstance(Param, Param)); + auto updater_factory = Singleton<Factory<singa::Updater>>::Instance(); + updater_factory->Register("Updater", CreateInstance(SGDUpdater, Updater)); } -void HandleWorkerFinish(void * ctx){ - HandleContext* hctx=static_cast<HandleContext*> (ctx); - Msg* msg=new Msg(); - msg->set_src(-1,-1, kRuntime); - msg->set_dst(hctx->group_id, hctx->id, kServer); - msg->set_type(kStop); - hctx->dealer->Send(&msg); -} +const vector<int> SliceParams(const vector<Param*>& params) { + // for load-balance among servers in a group and among server groups + int nserver_grps = Cluster::Get()->nserver_groups(); + int nservers_per_grp = Cluster::Get()->nservers_per_group(); + int lcm = LeastCommonMultiple(nserver_grps, nservers_per_grp); -const std::unordered_map<int, vector<std::pair<int, int>>> -SliceParams(int num, const vector<Param*>& params){ + // collect sizes of unique Params + std::vector<int> paramsize; + for (auto param : params) + if (param->id() == param->owner()) + paramsize.push_back(param->size()); + // slice into lcm pieces to achieve good load-balance for both intra-group + // partition (among servers in a group) and inter-group partition (each group + // is assgined a sub-set of slices) + auto param_slice = Slice(lcm, paramsize); + // construct map from Param ID to its slices <slice id, len> std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices; - if (num==0) - return paramid2slices; - vector<int> param_size; - int avg=0; - for(const auto& x:params){ - if(x->owner()==x->id()) - avg+=x->size(); - } - avg=avg/num+avg%num; - int diff=avg/10; - LOG(INFO)<<"Slicer, param avg="<<avg<<", diff= "<<diff; - - int capacity=avg, sliceid=0, nbox=0; - for(auto& param: params){ - if(param->id()!=param->owner()) - continue; - int x=param->size(), paramid=param->id(); - LOG(INFO)<<"param id="<<paramid<<", total size="<<x; - while(x>0){ - int size=0; - 
if(capacity>=x){ - capacity-=x; - size=x; - x=0; - }else if(capacity+diff>=x){ - size=x; - x=0; - capacity=0; - }else if(capacity>=diff){ - x-=capacity; - size=capacity; - capacity=avg; - nbox++; - }else{ - capacity=avg; - nbox++; - } - if(size){ - paramid2slices[paramid].push_back(std::make_pair(sliceid++, size)); - LOG(INFO)<<"param id="<<paramid<<", slice size="<<size; + vector<int> slices; + auto it = param_slice.begin(); + int slice_id = 0; + for (auto param : params) { + if (param->id() == param->owner()) { + for (int len : *it) { + slices.push_back(len); + paramid2slices[param->id()].push_back(std::make_pair(slice_id++, len)); } + it++; } } - CHECK_LE(nbox, num); - return paramid2slices; + // add slice info for every Param + for (auto param : params) + for (auto entry : paramid2slices[param->owner()]) { + param->AddSlice(entry.first, entry.second); + LOG(INFO) << "param id " << param->id() << " owner=" << param->owner() + << ": " << entry.first << ", " << entry.second; + } + return slices; } -const vector<int> PartitionSlice(int num, const vector<int>& slices){ - int avg=0; - for(int x: slices) - avg+=x; - avg=avg/num+avg%num; - int box=avg, boxid=0, diff=avg/10; - vector<int> slice2box; - for(auto it=slices.begin(); it!=slices.end();){ - int x=*it; - if(box>=x){ - box-=x; - slice2box.push_back(boxid); - it++; - }else if(box+diff>=x){ - slice2box.push_back(boxid); - it++; - box=0; - }else{ - box=avg; - boxid++; + +void Trainer::SetupWorkerServer( + const ModelProto& model_conf, + const vector<Worker*>& workers, + const vector<Server*>& servers) { + auto cluster = Cluster::Get(); + int grp_size = cluster->nworkers_per_group(); + const auto& net_conf = model_conf.neuralnet(); + auto net = NeuralNet::Create(net_conf, kTrain, grp_size); + // MUST do SliceParam before share param/net with others + auto slices = SliceParams(net->params()); + shared_ptr<NeuralNet> train_net, test_net, valid_net; + int grp = workers.size() ? 
workers.at(0)->grp_id() : -1; + if (grp == 0 && model_conf.test_steps()) { + // test are performed only by the first group + test_net = NeuralNet::Create(net_conf, kTest, grp_size); + test_net->ShareParamsFrom(net); + } + if (grp == 0 && model_conf.validation_steps()) { + // validation are performed only by the first group + valid_net = NeuralNet::Create(net_conf, kValidation, grp_size); + valid_net->ShareParamsFrom(net); + } + bool prepare_param = true; + for (auto worker : workers) { + if (worker->grp_id() != grp) { + train_net = NeuralNet::Create(net_conf, kTrain, grp_size); + if(cluster->share_memory()) + train_net->ShareParamsFrom(net); + valid_net = test_net = nullptr; + grp = worker->grp_id(); + prepare_param = true; + } else { + train_net = net; + } + worker->Setup(model_conf, train_net, valid_net, test_net); + // Prepare ParamEntry + if (prepare_param) { + for (auto layer : train_net->layers()) { + bool local = layer->partition_id() >= workers.front()->id() + && layer->partition_id() <= workers.back()->id(); + for (auto param : layer->GetParams()) { + int hash = Hash(grp, param->owner()); + if (worker_shard_.find(hash) == worker_shard_.end()) + worker_shard_[hash] = new ParamEntry(); + worker_shard_[hash]->AddParam(local, param); + } + } + prepare_param = false; } } -// CHECK_LT(slice2box.back(), num); - CHECK_EQ(slice2box.size(), slices.size()); - int previd=slice2box[0]; - std::string disp; - for(size_t i=0;i<slice2box.size();i++) - if(previd!=slice2box[i]){ - disp+=", "+std::to_string(slices[i]); - previd=slice2box[i]; - } else - disp+=" "+std::to_string(slices[i]); - LOG(INFO)<<"partition slice (avg ="<<avg<<", num="<<num<<"):"<<disp; - return slice2box; + // partition among server groups, each group maintains one sub-set for sync + auto slice2group = PartitionSlices(cluster->nserver_groups(), slices); + for (auto server : servers) + server->Setup(model_conf.updater(), &server_shard_, slice2group); + // partition within one server group, each server updates for one sub-set + slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices); } -vector<Server*> Trainer::CreateServers(int nthreads, - const ModelProto & mproto, - const vector<int> slices, - vector<HandleContext*>* ctx){ - auto cluster=Cluster::Get(); + +vector<Server*> Trainer::CreateServers(int nthreads, const ModelProto& mconf) { + auto cluster = Cluster::Get(); vector<Server*> servers; - if(!cluster->has_server()) + if (!cluster->has_server()) return servers; - int pid=cluster->procs_id(); - if(cluster->server_worker_separate()) - pid-=cluster->nworker_procs(); - int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group(); - int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group(); - int end=start+cluster->nservers_per_procs(); - // the ServerShard for servers consists of a dictionary of Param objects - server_shard_=make_shared<ServerShard>(); - auto slice2group=PartitionSlice(cluster->nserver_groups(), slices); - if(start<end){ - auto dealer=make_shared<Dealer>(); - dealer->Connect(kInprocRouterEndpoint); - for(int sid=start;sid<end;sid++){ - auto server=new Server(nthreads++, gid, sid); - server->Setup(mproto.updater(), server_shard_, slice2group); - servers.push_back(server); - auto *hc=new HandleContext{dealer, gid, sid}; - ctx->push_back(hc); - CHECK(cluster->runtime()->WatchSGroup(gid, sid, HandleWorkerFinish, - ctx->back())); - } + int pid = cluster->procs_id(); + // if true, server procs (logical) id starts after worker procs + if (cluster->server_worker_separate()) + 
pid -= cluster->nworker_procs(); + int procs_size = cluster->nservers_per_procs(); + int grp_size = cluster->nservers_per_group(); + int gid = pid * procs_size / grp_size; + int start = pid * procs_size % grp_size; + int end = start + procs_size; + for (int sid = start; sid < end; sid++) { + auto server = new Server(nthreads++, gid, sid); + servers.push_back(server); } return servers; } -vector<Worker*> Trainer::CreateWorkers(int nthreads, - const ModelProto& mproto, vector<int> *slice_size){ +vector<Worker*> Trainer::CreateWorkers(int nthreads, const ModelProto& mconf){ auto cluster=Cluster::Get(); - auto net=NeuralNet::Create(mproto.neuralnet(), kTrain, - cluster->nworkers_per_group()); - int lcm=LeastCommonMultiple(cluster->nserver_groups(), cluster->nservers_per_group()); - auto paramid2slices=SliceParams(lcm, net->params()); // sliceid, size - for(auto param: net->params()){ - if(param->id() == param->owner()) - for(auto entry: paramid2slices[param->id()]) - slice_size->push_back(entry.second); - } - vector<Worker*> workers; if(!cluster->has_worker()) return workers; - //LOG(ERROR)<<net->ToString(); - int pid=cluster->procs_id(); + int pid = cluster->procs_id(); + int grp_size = cluster->nworkers_per_group(); + int procs_size = cluster->nworkers_per_procs(); int gstart, gend, wstart, wend; - if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){ + if (grp_size >= procs_size) { // all workers in this procs are from the same group - gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group(); - gend=gstart+1; - wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group(); - wend=wstart+cluster->nworkers_per_group(); - }else{ - // there are multiple groups in this procs - CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0); - int groups_per_procs= - cluster->nworkers_per_procs()/cluster->nworkers_per_group(); - gstart=pid*groups_per_procs; - gend=(pid+1)*groups_per_procs; - wstart=0; - wend=cluster->nworkers_per_group(); + gstart = pid * procs_size / grp_size; + gend = gstart + 1; + wstart = pid * procs_size % grp_size; + wend = wstart + procs_size; + } else { + // there are multiple (complete) groups in this procs. 
+ CHECK_EQ(procs_size % grp_size, 0); + int groups_per_procs = procs_size / grp_size; + gstart = pid * groups_per_procs; + gend = (pid+1) * groups_per_procs; + wstart = 0; + wend = grp_size; } - for(int gid=gstart;gid<gend;gid++){ - shared_ptr<NeuralNet> train_net, test_net, validation_net; - if(gid==gstart) - train_net=net; - else{ - train_net=NeuralNet::Create(mproto.neuralnet(), kTrain, - cluster->nworkers_per_group()); - // the train net for other groups may share parameter values from the - // first group - if(cluster->share_memory()) - train_net->ShareParams(net); - } - if(gid==0){ - // validation and test are performed only by the first group - if(mproto.test_steps()){ - test_net=NeuralNet::Create(mproto.neuralnet(), kTest, - cluster->nworkers_per_group()); - if(test_net!=nullptr) - test_net->ShareParams(train_net); - } - if(mproto.validation_steps()){ - validation_net=NeuralNet::Create(mproto.neuralnet(), kValidation, - cluster->nworkers_per_group()); - if(validation_net!=nullptr) - validation_net->ShareParams(train_net); - } - } - // create ServerShard for the workers - auto shard=make_shared<WorkerShard>(); - worker_shards_[gid]=shard; - for(auto layer: train_net->layers()){ - int procsid=cluster->ProcsIDOf(gid, layer->partition_id(), kWorkerLayer); - bool local=procsid==cluster->procs_id(); - for(auto param: layer->GetParams()){ - for(auto entry :paramid2slices[param->owner()]){ - param->AddSlice(entry.first, entry.second); - } - int owner_procs=param->owner()==param->id()?procsid:procs_id_; - if(shard->find(param->owner())==shard->end()) - (*shard)[param->owner()]= - make_shared<ParamInfo>(param, local, owner_procs); - else - shard->at(param->owner())->AddParam(param, local); - } - } - for(int wid=wstart;wid<wend;wid++){ + for (int gid = gstart; gid < gend; gid++) { + for (int wid = wstart; wid < wend; wid++) { Worker* worker=nullptr; - if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation) + if (mconf.alg() == ModelProto_GradCalcAlg_kBackPropagation) worker = new BPWorker(nthreads++,gid, wid); - else{ - // TODO add CDWorker + else { + // TODO add CDWorker and BPTTWorker } - worker->Setup(mproto, train_net); - worker->set_test_net(test_net); - worker->set_validation_net(validation_net); workers.push_back(worker); } } return workers; } -void Trainer::Start(const ModelProto& mproto, const GlobalProto& gproto, - const ClusterProto& cproto, - int procs_id){ - // procs_id is only used for resume training - CHECK_EQ(procs_id, -1); - RegisterDefaultClasses(mproto); +void Trainer::Start(const ModelProto& mconf, const GlobalProto& gconf, + const ClusterProto& cconf, int job){ + RegisterDefaultClasses(mconf); - auto cluster=Cluster::Get(gproto, cproto, procs_id); - router_=make_shared<Router>(); + // register job to zookeeper + auto cluster=Cluster::Get(gconf, cconf, job); + if (mconf.resume()) { + // TODO(wangwei) resume from checkpoint + // load param slices to server_shard_ and reset running step of worker + // mproto.set_step(step); + } + + router_ = new Router(); router_->Bind(kInprocRouterEndpoint); - if(cluster->nprocs()>1){ - const string hostip=cluster->hostip(); - int port=router_->Bind("tcp://"+hostip+":*"); - cluster->Register(hostip+":"+std::to_string(port)); - }else + if (cluster->nprocs() > 1) { + const string hostip = cluster->hostip(); + int port = router_->Bind("tcp://" + hostip + ":*"); + // register endpoint to zookeeper + cluster->Register(hostip + ":" + std::to_string(port)); + } else { cluster->set_procs_id(0); + } - procs_id_ = cluster->procs_id(); - int 
nthreads=1; - // create workers - vector<int> slices; - vector<Worker*> workers=CreateWorkers(nthreads, mproto, &slices); - if(cluster->nserver_groups()&&cluster->nservers_per_group()) - slice2server_=PartitionSlice(cluster->nservers_per_group(), slices); - nthreads+=workers.size(); - // create servers - vector<HandleContext*> ctx; - vector<Server*> servers=CreateServers(nthreads, mproto, slices, - &ctx); + int nthreads = 1; + const vector<Worker*> workers = CreateWorkers(nthreads, mconf); + nthreads += workers.size(); + const vector<Server*> servers = CreateServers(nthreads, mconf); + SetupWorkerServer(mconf, workers, servers); #ifdef USE_MPI - for(int i=0;i<nSocket;i++){ + for (int i = 0; i < nthreads; i++) MPIQueues.push_back(make_shared<SafeQueue>()); - } #endif vector<std::thread> threads; - for(auto server: servers) - threads.push_back(std::thread(&Server::Run,server)); - for(auto worker: workers) - threads.push_back(std::thread(&Worker::Run,worker)); + for(auto server : servers) + threads.push_back(std::thread(&Server::Run, server)); + for(auto worker : workers) + threads.push_back(std::thread(&Worker::Run, worker)); Run(workers, servers); - for(auto& thread: threads) + for(auto& thread : threads) thread.join(); - for(auto x: ctx) - delete x; - for(auto x : servers) - delete x; - for(auto x : workers) - delete x; + for(auto server : servers) + delete server; + for(auto worker : workers) + delete worker; } -inline int bandwidth(int bytes, system_clock::time_point start){ +inline int bandwidth(int bytes, system_clock::time_point start) { auto now=system_clock::now(); - auto duration=duration_cast<TimeT> (now - start); + auto duration=duration_cast<std::chrono::milliseconds> (now - start); return static_cast<int>(bytes*1000.f/duration.count()); } -void Trainer::Run(const vector<Worker*>& workers, - const vector<Server*>& servers){ - auto cluster=Cluster::Get(); - procs_id_=cluster->procs_id(); - LOG(INFO)<<"Stub in process "<<procs_id_<<" starts"; - map<int, shared_ptr<Dealer>> interprocs_dealers; +void Trainer::Run( + const vector<Worker*>& workers, + const vector<Server*>& servers) { + int nworkers = workers.size(), nservers = servers.size(); + auto cluster = Cluster::Get(); + procs_id_ = cluster->procs_id(); + LOG(INFO) << "Stub in process " << procs_id_ << " starts"; + + // for sync among server groups + auto start = std::chrono::system_clock::now(); + float trans_size = 0.f; // total size of msg transferred since start time + int sync_server_id = 0; + int max_bandwidth = cluster->bandwidth(); + int nserver_grps = cluster->nserver_groups(); + + map<int, Dealer*> inter_dealers; // for sending msg to other procs + std::queue<Msg*> msg_queue; + Poller poll(router_); bool stop=false; - auto start=std::chrono::system_clock::now(); - float amount=0.f; - Poller poll; - poll.Add(router_.get()); - int sync_server=0, nworkers=workers.size(), nservers=servers.size(); - while(!stop){ - // if the poll time is large, then the poller may not expire - // if it is small, then many reminder messages will be sent which may - // slow done the process of other request. TODO tune it. 
- auto *sock=poll.Wait(cluster->poll_time()); - if(poll.Terminated()){ - LOG(ERROR)<<"Connection broken!"; - exit(0); - }else if(sock==nullptr){ - if(cluster->nserver_groups()>1&& - bandwidth(amount, start)<cluster->bandwidth()){ - Msg* msg=new Msg(); - msg->set_src(-1,-1, kStub); - msg->set_dst(servers[sync_server]->group_id(), - servers[sync_server]->server_id(), kServer); - msg->set_type(kSyncReminder); - sync_server=(sync_server+1)%servers.size(); - router_->Send(&msg); + while (!stop || !msg_queue.empty()) { + if (msg_queue.empty()) { + // if the poll time is large, then the poller may not expire + // if it is small, then many reminder messages will be sent which may + // slow done the process of other request. TODO tune it. + auto *sock = poll.Wait(cluster->poll_time()); + if (poll.Terminated()) { + LOG(ERROR) << "Connection broken!"; + exit(0); + } else if (sock == nullptr) { + if (nserver_grps > 1 && bandwidth(trans_size, start) < max_bandwidth) { + Msg* msg = GenSyncReminderMsg(sync_server_id, servers); + router_->Send(&msg); + sync_server_id = (sync_server_id + 1) % nservers; + } + continue; } - continue; + Msg* msg = router_->Receive(); + msg_queue.push(msg); } - Msg* msg=router_->Receive(); - if(msg==nullptr){ - LOG(ERROR)<<"Connection broken!"; - exit(0); - } - msg_queue.push(msg); - while(!msg_queue.empty()){ - msg=msg_queue.front(); - msg_queue.pop(); - int dst_flag=msg->dst_flag(); - int type=msg->type(); - int dst_procs=msg->dst_first(); - if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){ - if(type==kConnect){ - msg_queue.push(HandleConnect(&msg)); - }else if(type==kStop){ - if(msg->src_flag()==kServer) - nservers--; - else if (msg->src_flag()==kWorkerParam) - nworkers--; - DeleteMsg(&msg); - if(nworkers==0&&nservers==0){ - stop=true; - break; - } - }else if(type==kMetric){ - if(msg->src_first()==0){ - int step=msg->trgt_first(); - string prefix((char*)msg->frame_data(), msg->frame_size()); - msg->next_frame(); - Metric cur; - cur.ParseFrom(string((char*)msg->frame_data(), msg->frame_size())); - LOG(ERROR)<<prefix<<" step-" <<step<<", "<<cur.ToLogString(); - } - DeleteMsg(&msg); - }else if(cluster->nserver_groups()>0){ - int group_id; - int paramid=msg->trgt_first(); - shared_ptr<ParamInfo> entry; - switch (type){ // TODO process other requests, e.g. 
RESTful - case kUpdate: - group_id=msg->src_first(); - entry=worker_shards_.at(group_id)->at(paramid); - for(auto x:HandleUpdate(entry, &msg)) - msg_queue.push(x); - break; - case kRUpdate: - group_id=msg->dst_second(); - entry=worker_shards_.at(group_id)->at(paramid); - HandleUpdateResponse(entry, &msg); - break; - case kGet: - group_id=msg->src_first(); - entry=worker_shards_.at(group_id)->at(paramid); - for(auto x:HandleGet(entry, &msg)) - msg_queue.push(x); - break; - case kRGet: - group_id=msg->dst_second(); - entry=worker_shards_.at(group_id)->at(paramid); - HandleGetResponse(entry, &msg); - break; - case kPut: - group_id=msg->src_first(); - entry=worker_shards_.at(group_id)->at(paramid); - for(auto x:HandlePut(entry, &msg)) - msg_queue.push(x); - break; - default: - LOG(ERROR)<<"Unknow message type:"<<type; - break; - } - }else{ - DeleteMsg(&msg); - } - }else{ - int dst_procs_id; - if(dst_flag==kStub){ - dst_procs_id=msg->dst_first(); - }else{ - dst_procs_id=cluster->ProcsIDOf(msg->dst_first(), - msg->dst_second(), msg->dst_flag()); - } - if(dst_procs_id!=procs_id_){ - // forward to other procs - if (interprocs_dealers.find(dst_procs_id)==interprocs_dealers.end()){ - auto dealer=make_shared<Dealer>(); - interprocs_dealers[dst_procs_id]=dealer; - while(cluster->endpoint(dst_procs_id)==""){ - std::this_thread::sleep_for( - std::chrono::milliseconds(3000));//kCollectSleepTime)); - LOG(ERROR)<<"waiting for procs "<< dst_procs_id<<" to register"; - } - dealer->Connect("tcp://"+cluster->endpoint(dst_procs_id)); - } - if(bandwidth(amount, start) <=cluster->bandwidth()){ - start=std::chrono::system_clock::now(); - amount=0; - } - amount+=msg->size(); - //LOG(ERROR)<<"send inter msg of type "<<msg->type(); - interprocs_dealers[dst_procs_id]->Send(&msg); - }else{ - if(type==kSyncRequest){ - char buf[32]; - sprintf(buf, "%d", cluster->bandwidth()-bandwidth(amount, start)); - msg->add_frame(buf, strlen(buf)); - } - router_->Send(&msg); + Msg* msg = msg_queue.front(); + msg_queue.pop(); + int type = msg->type(), dst = msg->dst(), flag = AddrType(dst); + if (flag == kStub && (AddrProc(dst) == procs_id_ || AddrGrp(dst) == -1)) { + if (type == kConnect) { + DeleteMsg(&msg); + } else if (type == kMetric) { + DisplayMetric(&msg); + } else if (type == kStop) { + int src_flag = AddrType(msg->src()); + if (src_flag == kServer) nservers--; + else if (src_flag == kWorkerParam) nworkers--; + DeleteMsg(&msg); + if (nworkers == 0 && nservers == 0) break; + } else if (nserver_grps > 0) { + HandleLocalMsg(&msg_queue, &msg); + } else { + DeleteMsg(&msg); + } + } else { + int dst_procs = AddrProc(dst); + if (flag != kStub) + dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag); + if (dst_procs != procs_id_) { + if (bandwidth(trans_size, start) <= cluster->bandwidth()) { + start = std::chrono::system_clock::now(); + trans_size = 0; } + trans_size += msg->size(); + + if (inter_dealers.find(dst_procs) == inter_dealers.end()) + inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs); + inter_dealers[dst_procs]->Send(&msg); + } else { + if (type == kSyncRequest) + msg->AddFormatFrame("i", max_bandwidth - bandwidth(trans_size, start)); + router_->Send(&msg); } } } - LOG(INFO)<<"Stub in process "<<procs_id_<<" stops"; + LOG(ERROR) << "Stub in process " << procs_id_ << " stops"; + for (auto& entry : inter_dealers) + delete entry.second; } -Msg* Trainer::HandleConnect(Msg** msg){ - string ping((char*)(*msg)->frame_data(), (*msg)->frame_size()); - CHECK_STREQ("PING", ping.c_str()); - // ping-pong for debug 
- (*msg)->SwapAddr(); - Msg* reply=new Msg(); - reply->SetAddr(*msg); - reply->add_frame("PONG", 4); - reply->set_type(kConnect); + +Msg* Trainer::GenSyncReminderMsg(int server, const vector<Server*>& servers ) { + Msg* msg = new Msg(); + msg->set_src(Addr(-1,-1, kStub)); + msg->set_dst(Addr(servers[server]->grp_id(), servers[server]->id(), kServer)); + msg->set_type(kSyncReminder); + return msg; +} + +void Trainer::DisplayMetric(Msg** msg) { + Msg* msgg = *msg; + // only display metrics from the first group + if (AddrGrp(msgg->src()) == 0) { + int step = msgg->trgt_version(); + char prefix[128]; + msgg->ParseFormatFrame("s", prefix); + CHECK(msgg->NextFrame()); + const string perf(static_cast<char*>(msgg->FrameData()), msgg->FrameSize());; + Metric cur(perf); + LOG(ERROR) << prefix << " step-" << step <<", " << cur.ToLogString(); + } DeleteMsg(msg); - return reply; } -const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){ - Msg* msgg=*msg; - vector<Msg*> replies; - int version=msgg->trgt_third(); - if(msgg->src_flag()==kStub){ - LOG(FATAL)<<"Not implemented"; - /* - if(version<=pi->shares.at(0)->version()){ - replies.push_back(pi->shares.at(0)->HandleGetMsg(msg)); - }else if(version>pi->next_version){ - // reinsert into a msg queue. - replies.push_back(mmsg); - } - */ - }else if(version>pi->next_version){ - pi->next_version=version; - int gid=msgg->src_first(); - int group=gid/Cluster::Get()->nworker_groups_per_server_group(); - auto param=pi->shares.at(0); - for(int idx=0, id=param->slice_start();idx<param->num_slices();idx++){ - int server=slice2server_[id+idx]; - int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); - auto x=param->GenGetMsg(procs!=procs_id_, idx); - x->set_trgt(param->owner(), id+idx, param->local_version()+1); - x->set_src(procs_id_, gid, kStub); - x->set_dst(group, server, kServer); - //LOG(ERROR)<<"stub handle get for "<<idx+id<<","<<group<<","<<server; - replies.push_back(x); + +Dealer* Trainer::CreateInterProcsDealer(int dst_procs) { + // forward to other procs + auto cluster = Cluster::Get(); + auto dealer = new Dealer(); + while(cluster->endpoint(dst_procs)=="") { + //kCollectSleepTime)); + std::this_thread::sleep_for(std::chrono::milliseconds(3000)); + LOG(ERROR)<<"waiting for procs "<< dst_procs<<" to register"; + } + dealer->Connect("tcp://"+cluster->endpoint(dst_procs)); + return dealer; +} + +void Trainer::HandleLocalMsg(queue<Msg*>* msg_queue, Msg** msg) { + Msg* msgg = *msg; + int paramid = ParamID(msgg->trgt_val()); + int type = msgg->type(); + int grp; + ParamEntry *entry = nullptr; + switch (type) { // TODO process other requests, e.g. 
RESTful + case kUpdate: + grp = AddrGrp(msgg->src()); + entry = worker_shard_.at(Hash(grp, paramid)); + for(auto update_msg : HandleUpdate(entry, msg)) + msg_queue->push(update_msg); + break; + case kRUpdate: + grp = AddrGrp(msgg->dst()); + entry = worker_shard_.at(Hash(grp, paramid)); + HandleUpdateResponse(entry, msg); + break; + case kGet: + grp = AddrGrp(msgg->src()); + entry = worker_shard_.at(Hash(grp, paramid)); + for(auto get_msg : HandleGet(entry, msg)) + msg_queue->push(get_msg); + break; + case kRGet: + grp = AddrGrp(msgg->dst()); + entry = worker_shard_.at(Hash(grp, paramid)); + HandleGetResponse(entry, msg); + break; + case kPut: + grp = AddrGrp(msgg->src()); + entry = worker_shard_.at(Hash(grp, paramid)); + for(auto put_msg : HandlePut(entry, msg)) + msg_queue->push(put_msg); + break; + default: + LOG(ERROR)<<"Unknow message type:"<<type; + break; + } +} + +void Trainer::GenMsgs(int type, int version, ParamEntry* entry, + Msg* msg, vector<Msg*> *ret) { + int src_grp = AddrGrp(msg->src()); + int dst_grp = src_grp / Cluster::Get()->nworker_groups_per_server_group(); + auto param=entry->shares.at(0); + for (int idx = 0 ; idx < param->num_slices(); idx++) { + int slice_id =param->slice_start() + idx; + int server = slice2server_[slice_id]; + int procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer); + Msg* new_msg = nullptr; + if (type == kPut) { + CHECK_GT(entry->num_total, 0); + new_msg = param->GenPutMsg(procs != procs_id_, idx); + new_msg->AddFormatFrame("i", entry->num_total); + } else if (type == kGet) { + new_msg = param->GenGetMsg(procs != procs_id_, idx); + } else if (type == kUpdate) { + new_msg = param->GenUpdateMsg(procs != procs_id_, idx); + new_msg->AddFormatFrame("i", entry->num_local); + } else { + LOG(FATAL) << "Wrong type"; } + new_msg->set_trgt(ParamTrgt(param->owner(), slice_id), version); + new_msg->set_src(Addr(src_grp, procs_id_, kStub)); + new_msg->set_dst(Addr(dst_grp, server, kServer)); + ret->push_back(new_msg); } - return replies; } -const vector<Msg*> Trainer::HandleUpdate(shared_ptr<ParamInfo>pi, Msg** msg){ - Msg* msgg=*msg ; +const vector<Msg*> Trainer::HandleGet(ParamEntry* entry, Msg** msg) { vector<Msg*> ret; - int step= msgg->trgt_third(); - if(msgg->src_flag()==kStub){ - if(pi->num_update<pi->num_local){ - ret.push_back(*msg); - return ret; //wait unitl local updates are ready - } - int n; sscanf((char*)(*msg)->frame_data(), "%d", &n); - pi->num_update+=n; - auto it=pi->shares.begin(); - auto shape=mshadow::Shape1((*it)->size()); - mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape); - mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape); - agg+=grad; - }else if(++pi->num_update>=pi->num_local){ - auto it=pi->shares.begin(); - auto shape=mshadow::Shape1((*it)->size()); - mshadow::Tensor<mshadow::cpu,1> agg((*it)->mutable_cpu_grad(), shape); - for(++it;it!=pi->shares.end();it++){ - mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape); - agg+=grad; - } - agg/=pi->num_total; - if(pi->num_local<pi->num_total){ - /* - int gid=msgg->src_first(); - for(auto update: pi->shares.at(0)->GenUpdateMsg(step)){ - update->set_src(procs_id_, gid,kStub); - update->set_dst(pi->owner_procs, gid, kStub); - ret.push_back(update); - } - pi->num_update=0; - */ - } + int version = (*msg)->trgt_version(); + if (version > entry->next_version) { + entry->next_version = version; + GenMsgs(kGet, version, entry, *msg, &ret); } - if(pi->num_update==pi->num_total){ - auto param=pi->shares.at(0); - int 
group=msgg->src_first()/Cluster::Get()->nworker_groups_per_server_group(); - int srcgid=msgg->src_first(); - for(int idx=0, id=param->slice_start(); idx<param->num_slices();idx++){ - int server=slice2server_[idx+id]; - int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); - auto x=param->GenUpdateMsg(procs!=procs_id_, idx); - x->set_trgt(param->owner(), id+idx, step); - x->set_src(procs_id_, srcgid, kStub); - x->set_dst(group, server, kServer); - ret.push_back(x); + DeleteMsg(msg); + return ret; +} + +const vector<Msg*> Trainer::HandleUpdate(ParamEntry *entry, Msg** msg) { + vector<Msg*> ret; + entry->num_update++; + if (entry->num_update >= entry->num_local) { + // average local gradient + if (entry->num_local > 1) { + auto it = entry->shares.begin(); + auto shape=mshadow::Shape1((*it)->size()); + mshadow::Tensor<mshadow::cpu,1> sum((*it)->mutable_cpu_grad(), shape); + for (++it; it != entry->shares.end(); it++) { + mshadow::Tensor<mshadow::cpu,1> grad((*it)->mutable_cpu_grad(), shape); + sum += grad; + } + sum /= entry->num_total; } - pi->num_update=0; + int step = (*msg)->trgt_version(); + GenMsgs(kUpdate, step, entry, *msg, &ret); + entry->num_update = 0; } DeleteMsg(msg); return ret; } -const vector<Msg*> Trainer::HandlePut(shared_ptr<ParamInfo>pi, Msg** msg){ +const vector<Msg*> Trainer::HandlePut(ParamEntry* entry, Msg** msg) { vector<Msg*> ret; - CHECK_NE((*msg)->src_flag(), kStub); - int gid=(*msg)->src_first(); - int version=(*msg)->trgt_third(); - auto param=pi->shares.at(0); - int group=gid/Cluster::Get()->nworker_groups_per_server_group(); - for(int idx=0, start=param->slice_start();idx<param->num_slices(); idx++){ - int server=slice2server_[start+idx]; - int procs=Cluster::Get()->ProcsIDOf(group, server, kServer); - auto x=param->GenPutMsg(procs!=procs_id_, idx); - x->set_trgt(param->owner(), start+idx, version); - x->set_src(procs_id_, gid, kStub); - x->set_dst(group, server, kServer); - ret.push_back(x); - //LOG(ERROR)<<"stub handle put "<<start+idx<<"to "<<group<<","<<server; - } + int version = (*msg)->trgt_version(); + GenMsgs(kPut, version, entry, *msg, &ret); DeleteMsg(msg); return ret; } -void Trainer::HandleGetResponse(shared_ptr<ParamInfo>pi, Msg** msg){ - int version=(*msg)->trgt_third(); - int sliceid=(*msg)->trgt_second(); - auto param=pi->shares.at(0); - if(param->ParseGetResponseMsg(msg,sliceid-param->slice_start())) +void Trainer::HandleGetResponse(ParamEntry* entry, Msg** msg) { + int version = (*msg)->trgt_version(); + int sliceid = SliceID((*msg)->trgt_val()); + auto param = entry->shares.at(0); + if (param->ParseGetResponseMsg(*msg, sliceid-param->slice_start())) param->set_version(version); - // process get requests in waiting queue + DeleteMsg(msg); } - -void Trainer::HandleUpdateResponse(shared_ptr<ParamInfo> pi, Msg** msg){ - int sliceid=(*msg)->trgt_second(); - int version=(*msg)->trgt_third(); - auto param=pi->shares.at(0); - if(param->ParseUpdateResponseMsg(msg,sliceid-param->slice_start())){ +void Trainer::HandleUpdateResponse(ParamEntry* entry, Msg** msg) { + int version = (*msg)->trgt_version(); + int sliceid = SliceID((*msg)->trgt_val()); + auto param = entry->shares.at(0); + if (param->ParseUpdateResponseMsg(*msg, sliceid-param->slice_start())) param->set_version(version); - } + DeleteMsg(msg); } } /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index 
80a6283..bf98f0b 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -1,300 +1,328 @@ #include <glog/logging.h> #include <thread> -#include <memory> -#include <iostream> #include <chrono> #include <thread> #include "utils/singleton.h" +#include "utils/cluster.h" #include "utils/factory.h" #include "trainer/worker.h" #include "proto/model.pb.h" + namespace singa { using std::thread; -using std::make_shared; -Worker::Worker(int thread_id, int group_id, int worker_id): - thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){ +Worker::Worker(int thread_id, int grp_id, int id): + thread_id_(thread_id), grp_id_(grp_id), id_(id), + layer_dealer_(nullptr), dealer_(nullptr), updater_(nullptr) { } -void Worker::Setup(const ModelProto& model, - shared_ptr<NeuralNet> train_net){ - train_net_=train_net; - modelproto_=model; - auto cluster=Cluster::Get(); - if(!(cluster->nserver_groups()&&cluster->server_update())){ - updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance() - ->Create("Updater")); +void Worker::Setup( + const ModelProto& model, shared_ptr<NeuralNet> train_net, + shared_ptr<NeuralNet> valid_net, shared_ptr<NeuralNet> test_net) { + modelproto_.CopyFrom(model); + train_net_ = train_net; + validation_net_ = valid_net; + test_net_ = test_net; + auto cluster = Cluster::Get(); + // if no server or user requires worker to do param update + if (!(cluster->nserver_groups() && cluster->server_update())) { + updater_ = Singleton<Factory<Updater>>::Instance()->Create("Updater"); updater_->Init(model.updater()); } } -void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){ - if(updater_==nullptr){ - auto cluster=Cluster::Get(); - int sgid=group_id_/cluster->nworker_groups_per_server_group(); - CHECK(cluster->runtime()->JoinSGroup(group_id_, worker_id_, sgid)); +Worker::~Worker() { + if (updater_ != nullptr) + delete updater_; + if (layer_dealer_) + delete layer_dealer_; + if (dealer_) + delete dealer_; +} + +void Worker::InitLocalParams() { + // for each server grp, its first subscriber worker grp does the param init + if (grp_id_ % Cluster::Get()->nworker_groups_per_server_group() == 0) { + for (auto layer: train_net_->layers()){ + if (layer->partition_id() == id_) { + for (auto param : layer->GetParams()) { + // only owners fill the memory of parameter values. 
+ if(param->owner() == param->id()) + param->InitValues(0); + } + } + } + Metric perf; + // warmup training before put params to servers + for (; step_ < modelproto_.warmup_steps(); step_++) + TrainOneBatch(step_, &perf); + for (auto layer : train_net_->layers()) { + if (layer->partition_id() == id_) + for (auto param : layer->GetParams()) + if (param->owner() == param->id()) + Put(param, step_); + } + } + // wait owners in the same procs init params, then no get requests sent + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + for (auto layer : train_net_->layers()) { + if (layer->partition_id() == id_) + for (auto param : layer->GetParams()) + if (param->owner() != param->id()) + Get(param, modelproto_.warmup_steps()); } +} +void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) { dealer->Connect(kInprocRouterEndpoint); - Msg* ping=new Msg(); - ping->set_src(group_id_, worker_id_, type); - ping->set_dst(-1,-1,kStub); + Msg* ping = new Msg(Addr(grp, id, entity), Addr(-1, -1, kStub)); ping->set_type(kConnect); - ping->add_frame("PING", 4); dealer->Send(&ping); - ping=dealer->Receive(); - string pong((char*)ping->frame_data(), ping->frame_size()); - CHECK_STREQ("PONG", pong.c_str()); - delete ping; } -void Worker::Run(){ - LOG(ERROR)<<"Worker (group_id = "<<group_id_ - <<", id = "<<worker_id_<<") starts"; - dealer_=make_shared<Dealer>(2*thread_id_); - ConnectStub(dealer_, kWorkerParam); - for(auto layer: train_net_->layers()) - if(layer->partition_id()==worker_id_) - if(layer->is_bridgedstlayer()||layer->is_bridgesrclayer()){ - layer_dealer_=make_shared<Dealer>(2*thread_id_+1); - ConnectStub(layer_dealer_, kWorkerLayer); +void Worker::Run() { + LOG(ERROR) << "Worker (group = " << grp_id_ <<", id = " << id_ << ") start"; + auto cluster = Cluster::Get(); + if (updater_==nullptr) { + int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group(); + CHECK(cluster->runtime()->JoinSGroup(grp_id_, id_, svr_grp)); + } + dealer_ = new Dealer(2*thread_id_); + ConnectStub(grp_id_, id_, dealer_, kWorkerParam); + for (auto layer : train_net_->layers()) { + if (layer->partition_id() == id_) { + if (layer->is_bridgelayer()) { + layer_dealer_ = new Dealer(2*thread_id_+1); + ConnectStub(grp_id_, id_, layer_dealer_, kWorkerLayer); break; } - step_=modelproto_.step(); - // init params - for(auto layer: train_net_->layers()){ - if(layer->partition_id()==worker_id_) - for(auto param: layer->GetParams()){ - // only owners fill the memory of parameter values. - // others share the memory with owners hence do not need to put/get. 
- if(param->owner() == param->id()){ - if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0) - param->InitValues(0); - else{ - Get(param, modelproto_.warmup_steps()); - } - } - } + } } + + step_ = modelproto_.step(); + InitLocalParams(); Metric perf; - if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0){ - for(step_=0;step_<modelproto_.warmup_steps();step_++) - RunOneBatch(step_, &perf); - for(auto layer: train_net_->layers()){ - if(layer->partition_id()==worker_id_) - for(auto param: layer->GetParams()) - if(param->owner()==param->id()) - Put(param, step_); + while (!StopNow(step_)) { + if (ValidateNow(step_)) { + //LOG(ERROR)<<"Validation at step "<<step; + CollectAll(validation_net_, step_); + Test(modelproto_.validation_steps(), kValidation, validation_net_); + } + if (TestNow(step_)) { + //LOG(ERROR)<<"Test at step "<<step; + CollectAll(test_net_, step_); + Test(modelproto_.test_steps(), kTest, test_net_); + } + TrainOneBatch(step_, &perf); + //LOG(ERROR)<<"Train "<<step; + if (DisplayNow(step_)) { + Report("Train", perf); + perf.Reset(); } - } - while(!StopNow(step_)){ - RunOneBatch(step_, &perf); step_++; } - Stop(); - LOG(ERROR)<<"Worker (group_id = "<<group_id_ - <<", id = "<<worker_id_<<") stops"; -} - -void Worker::Stop(){ - auto cluster=Cluster::Get(); - if(updater_ == nullptr){ - int sgid=group_id_/cluster->nworker_groups_per_server_group(); - cluster->runtime()->LeaveSGroup(group_id_, worker_id_, sgid); + // clean up + if(updater_ == nullptr) { + int svr_grp = grp_id_ / cluster->nworker_groups_per_server_group(); + cluster->runtime()->LeaveSGroup(grp_id_, id_, svr_grp); } - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1,-1, kStub); + // notify the stub on worker stop + Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1,-1, kStub)); msg->set_type(kStop); - dealer_->Send(&msg); // use param dealer to send the stop msg + dealer_->Send(&msg); // use param dealer to send the stop msg + + LOG(ERROR) << "Worker (group = " <<grp_id_ << ", id = " << id_ << ") stop"; +} + +void Worker::Resume() { + // TODO(wangwei) } -int Worker::Put(Param* param, int step){ - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1, -1, kStub); + +int Worker::Put(Param* param, int step) { + Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); + msg->set_trgt(ParamTrgt(param->owner(), 0), step); msg->set_type(kPut); - msg->set_trgt(param->owner(), 0, step); dealer_->Send(&msg); return 1; } -int Worker::Get(Param* param, int step){ - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1, -1, kStub); + +int Worker::Get(Param* param, int step) { + if (param->version() >= step) + return 1; + Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); + msg->set_trgt(ParamTrgt(param->owner(), 0), step); msg->set_type(kGet); - msg->set_trgt(param->owner(), 0, step); dealer_->Send(&msg); return 1; } -int Worker::Update(Param* param, int step){ + +int Worker::Update(Param* param, int step) { param->set_local_version(param->version()); - if(updater_){ + if (updater_) { updater_->Update(step, param); - param->set_version(param->version()+1); - }else{ - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1, -1, kStub); + param->set_version(param->version() + 1); + } else { + Msg* msg=new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); + msg->set_trgt(ParamTrgt(param->owner(), 0), 
step); msg->set_type(kUpdate); - msg->set_trgt(param->owner(), 0, step); dealer_->Send(&msg); } return 1; } -int Worker::CollectAll(shared_ptr<NeuralNet> net, int step){ - auto& layers=net->layers(); - for(auto& layer: layers){ - if(layer->partition_id()==worker_id_) - for(Param* p: layer->GetParams()){ +int Worker::CollectAll(shared_ptr<NeuralNet> net, int step) { + auto& layers = net->layers(); + for (auto& layer : layers){ + if (layer->partition_id() == id_) + for (Param* p: layer->GetParams()) { Collect(p, step); } } return 1; } -int Worker::Collect(Param* param, int step){ - while(param->version()<=param->local_version()){ +int Worker::Collect(Param* param, int step) { + while (param->version() <= param->local_version()) std::this_thread::sleep_for(std::chrono::milliseconds(kCollectSleepTime)); - } return 1; } -void Worker::DisplayPerformance(const string& prefix, const Metric & perf) { - Msg* msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerParam); - msg->set_dst(-1,-1, kStub); +void Worker::Report(const string& prefix, const Metric & perf) { + Msg* msg = new Msg(Addr(grp_id_, id_, kWorkerParam), Addr(-1, -1, kStub)); + msg->set_trgt(0, step_); msg->set_type(kMetric); - msg->set_trgt(step_,0,0); - msg->add_frame(prefix.c_str(), prefix.length()); const string disp = perf.ToString(); - msg->add_frame(disp.c_str(), disp.length()); + msg->AddFormatFrame("s", prefix.c_str()); + msg->AddFrame(disp.c_str(), disp.length()); dealer_->Send(&msg); } -void Worker::RunOneBatch(int step, Metric* perf){ - if(ValidateNow(step)){ - //LOG(ERROR)<<"Validation at step "<<step; - CollectAll(validation_net_, step); - Test(modelproto_.validation_steps(),kValidation, validation_net_); - } - if(TestNow(step)){ - //LOG(ERROR)<<"Test at step "<<step; - CollectAll(test_net_, step); - Test(modelproto_.test_steps(), kTest, test_net_); +void Worker::ReceiveBlobs( + bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net) { + while (!layer->ready()) { + auto msg = layer_dealer_->Receive(); + CHECK_EQ(AddrGrp(msg->src()), grp_id_); + string name(static_cast<char*>(msg->FrameData()), msg->FrameSize()); + auto receive_layer = net->name2layer(name); + CHECK(receive_layer->is_bridgelayer()); + auto data = receive_layer->mutable_data(nullptr); + msg->NextFrame(); + memcpy(data->mutable_cpu_data(), msg->FrameData(), msg->FrameSize()); + static_cast<BridgeLayer*>(receive_layer)->set_ready(true); + delete msg; } - TrainOneBatch(step, perf); - //LOG(ERROR)<<"Train "<<step; - if(perf!=nullptr){ - if(DisplayNow(step)){ - DisplayPerformance("Train", *perf); - perf->Reset(); - } - } - /* - if(CheckpointNow(step)){ - pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step)); - } - */ -} - -void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){ } -void Worker::SendBlob(){ +void Worker::SendBlobs( + bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net) { + auto dst=layer->dstlayers().at(0); + Msg *msg=new Msg(); + msg->set_src(Addr(grp_id_, id_, kWorkerLayer)); + msg->set_dst(Addr(grp_id_, dst->partition_id(), kWorkerLayer)); + msg->AddFrame(dst->name().c_str(), dst->name().length()); + auto const & blob=layer->data(nullptr); + msg->AddFrame(blob.cpu_data(), blob.count()*sizeof(float)); + layer_dealer_->Send(&msg); } -void Worker::Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net){ +void Worker::Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net) { Metric perf; - for(int step=0;step<nsteps;step++){ + for (int step = 0; step < nsteps; step++) TestOneBatch(step, phase, 
net, &perf); - } - //perf.Avg(); - if(phase==kValidation) - DisplayPerformance("Validation", perf); - else if (phase==kTest) - DisplayPerformance("Test", perf); + if (phase == kValidation) + Report("Validation", perf); + else if (phase == kTest) + Report("Test", perf); +} +bool Worker::DisplayNow(int step) const { + return (modelproto_.display_frequency() > 0 + && step >= modelproto_.display_after_steps() + && ((step - modelproto_.display_after_steps()) + % modelproto_.display_frequency() == 0)); } -/****************************BPWorker**********************************/ +bool Worker::DisplayDebugInfo(int step) const { + return DisplayNow(step) && modelproto_.debug() && grp_id_ == 0; +} +bool Worker::StopNow(int step) const { + return step >= modelproto_.train_steps(); +} +bool Worker::CheckpointNow(int step) const { + return (grp_id_ == 0 + && modelproto_.checkpoint_frequency() > 0 + && step >= modelproto_.checkpoint_after_steps() + && ((step - modelproto_.checkpoint_after_steps()) + % modelproto_.checkpoint_frequency() == 0)); +} +bool Worker::TestNow(const int step) const { + return (grp_id_ == 0 + && modelproto_.test_frequency() > 0 + && modelproto_.test_steps() > 0 + && step >= modelproto_.test_after_steps() + && ((step - modelproto_.test_after_steps()) + % modelproto_.test_frequency() == 0)); +} +bool Worker::ValidateNow(const int step) const { + return (grp_id_ == 0 + && modelproto_.validation_frequency() > 0 + && modelproto_.validation_steps() > 0 + && step >= modelproto_.validation_after_steps() + && ((step - modelproto_.validation_after_steps()) + % modelproto_.validation_frequency() == 0)); +} + +/****************************BPWorker**********************************/ BPWorker::BPWorker(int thread_id, int group_id, int worker_id): - Worker(thread_id, group_id, worker_id){ + Worker(thread_id, group_id, worker_id) { } -void BPWorker::Forward(int step, Phase phase, shared_ptr<NeuralNet> net, - Metric* perf){ - auto& layers=net->layers(); - for(auto& layer: layers){ - if(layer->partition_id()==worker_id_){ - if(layer->is_bridgedstlayer()){ - auto* dst=static_cast<BridgeDstLayer*>(layer); - while(!dst->ready()){ - auto msg=layer_dealer_->Receive(); - CHECK_EQ(msg->src_first(), group_id_); - string name((char*)msg->frame_data(), msg->frame_size()); - auto tmp=net->name2layer(name); - CHECK(tmp->is_bridgedstlayer()); - auto* dstlayer=static_cast<BridgeDstLayer*>(tmp); - auto data=dstlayer->mutable_data(nullptr); - msg->next_frame(); - memcpy(data->mutable_cpu_data(), msg->frame_data(), msg->frame_size()); - dstlayer->set_ready(true); - delete msg; - } - } - if(phase==kTrain){ - for(Param* p: layer->GetParams()){ +void BPWorker::Forward( + int step, Phase phase, shared_ptr<NeuralNet> net, Metric* perf) { + for (auto& layer : net->layers()) { + if (layer->partition_id() == id_) { + if (layer->is_bridgedstlayer()) // recv data from other workers + ReceiveBlobs(true, false, static_cast<BridgeLayer*>(layer), net); + if (phase == kTrain) { + for (Param* p : layer->GetParams()) { // wait until param is updated Collect(p, step); } } - //clock_t s=clock(); layer->ComputeFeature(phase, perf); - //LOG(ERROR)<<layer->name()<<":"<<(clock()-s)*1.0/CLOCKS_PER_SEC; - if(layer->is_bridgesrclayer()){ - auto dst=layer->dstlayers().at(0); - Msg *msg=new Msg(); - msg->set_src(group_id_, worker_id_, kWorkerLayer); - msg->set_dst(group_id_, dst->partition_id(), kWorkerLayer); - msg->add_frame(dst->name().c_str(), dst->name().length()); - auto const & blob=layer->data(nullptr); - 
msg->add_frame(blob.cpu_data(), blob.count()*sizeof(float)); - layer_dealer_->Send(&msg); - } - if(phase == kTrain && DisplayDebugInfo(step)) + if (layer->is_bridgesrclayer()) // send data to other workers + SendBlobs(true, false, static_cast<BridgeLayer*>(layer), net); + if (DisplayDebugInfo(step)) LOG(INFO) << layer->DebugString(step, kForward); } } } -void BPWorker::Backward(int step, shared_ptr<NeuralNet> net){ +void BPWorker::Backward(int step, shared_ptr<NeuralNet> net) { auto& layers=net->layers(); for (auto it = layers.rbegin(); it != layers.rend(); it++){ - Layer* layer=*it; - if(layer->partition_id()==worker_id_){ - if(layer->is_bridgesrclayer()){ - //auto* src=static_cast<BridgeSrcLayer*>(layer.get()); - // receive grad blobs + Layer* layer = *it; + if (layer->partition_id() == id_) { + if(layer->is_bridgesrclayer()) { + // ReceiveBlobs(false, true, layer, net); } layer->ComputeGradient(kTrain); - if(DisplayDebugInfo(step)) + if (DisplayDebugInfo(step)) LOG(INFO) << layer->DebugString(step, kBackward); - for(Param* p: layer->GetParams()) + for (Param* p : layer->GetParams()) Update(p, step); - if(layer->is_bridgedstlayer()){ - // send grad blobs + if (layer->is_bridgedstlayer()) { + // SendBlobs(false, true, layer); } } } } -void BPWorker::TrainOneBatch(int step, Metric* perf){ +void BPWorker::TrainOneBatch(int step, Metric* perf) { Forward(step, kTrain, train_net_, perf); Backward(step, train_net_); - auto losslayers=train_net_->losslayers(); } void BPWorker::TestOneBatch(int step, Phase phase, - shared_ptr<NeuralNet> net, Metric* perf){ + shared_ptr<NeuralNet> net, Metric* perf) { Forward(step, phase, net, perf); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/cluster.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc index 0c4eefa..9c57c42 100644 --- a/src/utils/cluster.cc +++ b/src/utils/cluster.cc @@ -56,13 +56,14 @@ Cluster::Cluster(const GlobalProto & global, const ClusterProto &cluster, hostip_=GetHostIP(); } -void Cluster::Register(const string& endpoint){ +void Cluster::Register(const string& endpoint) { procs_id_=cluster_rt_->RegistProc(endpoint); CHECK_GE(procs_id_,0); CHECK_LT(procs_id_,nprocs()); LOG(ERROR) << "proc #" << procs_id_ << " -> " << endpoint; } -const string Cluster::endpoint(int procsid) const{ + +const string Cluster::endpoint(int procsid) const { CHECK_LT(procsid, nprocs()); CHECK_GE(procsid, 0); if(endpoints_.size()) @@ -70,6 +71,7 @@ const string Cluster::endpoint(int procsid) const{ else return cluster_rt_->GetProcHost(procsid); } + void Cluster::SetupFolders(const ClusterProto &cluster){ // create visulization folder mkdir(vis_folder().c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/common.cc ---------------------------------------------------------------------- diff --git a/src/utils/common.cc b/src/utils/common.cc index 11a19f8..f733497 100644 --- a/src/utils/common.cc +++ b/src/utils/common.cc @@ -160,6 +160,10 @@ void SetupLog(const std::string& log_dir, const std::string& model) { google::SetLogDestination(google::FATAL, fatal.c_str()); } +Metric::Metric(const std::string& str) { + ParseFrom(str); +} + void Metric::Add(const string& name, float value) { if(entry_.find(name) == entry_.end()) entry_[name] = std::make_pair(1, value); @@ -176,7 +180,7 @@ void Metric::Reset() { e.second.second = 0; } } -const string Metric::ToLogString() const{ +const 
string Metric::ToLogString() const { string ret; size_t k = 0; for(auto e : entry_) { @@ -188,7 +192,7 @@ const string Metric::ToLogString() const{ return ret; } -const string Metric::ToString() const{ +const string Metric::ToString() const { MetricProto proto; for(auto e : entry_) { proto.add_name(e.first); @@ -208,4 +212,89 @@ void Metric::ParseFrom(const string& msg) { entry_[proto.name(i)] = std::make_pair(proto.count(i), proto.val(i)); } } + + +const vector<vector<int>> Slice(int num, const vector<int>& sizes) { + vector<vector<int>> slices; + if (num == 0) + return slices; + int avg = 0; + for(int x : sizes) + avg += x; + avg = avg / num + avg % num; + int diff = avg / 10; + LOG(INFO) << "Slicer, param avg=" << avg << ", diff= " << diff; + + int capacity = avg, nbox = 0; + for (int x : sizes) { + vector<int> slice; + string slicestr = ""; + while (x > 0) { + int size=0; + if (capacity >= x) { + capacity -= x; + size = x; + x = 0; + }else if(capacity + diff >= x) { + size = x; + x = 0; + capacity = 0; + }else if (capacity >= diff) { + x -= capacity; + size = capacity; + capacity = avg; + nbox++; + } else { + capacity = avg; + nbox++; + } + if (size) { + slice.push_back(size); + slicestr += ", " + std::to_string(size); + } + } + LOG(INFO) << slicestr; + slices.push_back(slice); + } + CHECK_LE(nbox, num); + return slices; +} + +const vector<int> PartitionSlices(int num, const vector<int>& slices) { + vector<int> slice2box; + if (num == 0) + return slice2box; + int avg = 0; + for(int x : slices) + avg += x; + avg = avg / num + avg % num; + int box = avg, boxid = 0, diff = avg / 10; + for (auto it = slices.begin(); it != slices.end();) { + int x = *it; + if (box >= x) { + box -= x; + slice2box.push_back(boxid); + it++; + } else if (box + diff >= x) { + slice2box.push_back(boxid); + it++; + box = 0; + } else { + box = avg; + boxid++; + } + } + CHECK_EQ(slice2box.size(), slices.size()); + int previd = -1; + std::string disp; + for (size_t i = 0; i < slice2box.size(); i++) { + if (previd != slice2box[i]) { + previd = slice2box[i]; + disp += " box = " +std::to_string(previd) + ":"; + } + disp += " " + std::to_string(slices[i]); + } + LOG(INFO) << "partition slice (avg =" << avg << ", num="<<num<<"):" << disp; + return slice2box; +} } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/585e275f/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index 24a0541..8b1f113 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -11,8 +11,8 @@ using std::vector; using std::string; namespace singa { -Param::Param():data_(nullptr), slice_start_(0), num_slices_(0), - num_pending_requests_(0),local_version_(-1){ +Param::Param():local_version_(-1), slice_start_(0), num_slices_(0), + num_pending_requests_(0), data_(nullptr) { } void Param::Setup(const ParamProto& proto, const vector<int>& shape){ data_=std::make_shared<Blob<float>>(shape); @@ -82,96 +82,87 @@ void Param::InitValues(int version){ } /**************Message related functions********/ -Msg* Param::GenPutMsg(bool copy, int idx){ +Msg* Param::GenPutMsg(bool copy, int idx) { CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kPut); - char buf[128]; - sprintf(buf, "%d %f %f", slice_size_[idx], - learning_rate_multiplier(), weight_decay_multiplier()); void *ptr=mutable_cpu_data()+slice_offset_[idx]; - if(copy){ - sprintf(buf+strlen(buf), " %p ", nullptr); - msg->add_frame(buf, strlen(buf)); - msg->add_frame(ptr, 
slice_size_[idx]*sizeof(float)); -}else{ - sprintf(buf+strlen(buf), " %p ", ptr); - msg->add_frame(buf, strlen(buf)); + void *p = ptr; + if (copy) p = nullptr; + msg->AddFormatFrame("iffp", slice_size_[idx], + learning_rate_multiplier(), weight_decay_multiplier(), p); + if (copy) { + msg->AddFrame(ptr, slice_size_[idx]*sizeof(float)); } //pending_put_[idx]=true; //num_pending_requests_++; return msg; } -Msg* Param::GenGetMsg(bool copy, int idx){ +Msg* Param::GenGetMsg(bool copy, int idx) { CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kGet); - char buf[32]; sprintf(buf, " %d %p ", copy, - data_->cpu_data()+slice_offset_[idx]); - msg->add_frame(buf, sizeof(buf)); + msg->AddFormatFrame("ip", copy, data_->cpu_data()+slice_offset_[idx]); pending_get_[idx]=true; num_pending_requests_++; return msg; } -Msg* Param::GenUpdateMsg(bool copy, int idx){ +Msg* Param::GenUpdateMsg(bool copy, int idx) { CHECK_LT(idx, num_slices_); Msg* msg=new Msg(); msg->set_type(kUpdate); - char buf[8]; sprintf(buf, " %d ", copy); - msg->add_frame(buf, sizeof(buf)); + msg->AddFormatFrame("i", copy); void* ptr=grad_.mutable_cpu_data()+slice_offset_[idx]; if(copy){ //LOG(ERROR)<<"Copy in gen update"; - msg->add_frame(ptr, slice_size_[idx]*sizeof(float)); - } - else{ // to share values of grad blob - char buf[32]; sprintf(buf, " %p ", ptr); - msg->add_frame(buf, strlen(buf)); + msg->AddFrame(ptr, slice_size_[idx]*sizeof(float)); + } else { // to share values of grad blob + msg->AddFormatFrame("p", ptr); } pending_update_[idx]=true; num_pending_requests_++; return msg; } -Msg* Param::GenSyncMsg(int offset, int size){ +Msg* Param::GenSyncMsg(int offset, int size) { Msg* msg=new Msg(); msg->set_type(kSyncRequest); - msg->set_trgt(-1, id(), local_version()); - msg->add_frame(mutable_cpu_data(), data_->count()*sizeof(float)); + msg->set_trgt(ParamTrgt(-1, id()), local_version()); + // always copy data because syn is between server groups in diff procs + msg->AddFrame(mutable_cpu_data(), data_->count()*sizeof(float)); return msg; } -Msg* Param::HandlePutMsg(Msg** msg){ +Msg* Param::HandlePutMsg(Msg** msg, bool reserve) { int size; float lr, wc; float* ptr; - sscanf(static_cast<char*>((*msg)->frame_data()), - "%d %f %f %p ", &size, &lr, &wc, &ptr); + (*msg)->ParseFormatFrame("iffp", &size, &lr, &wc, &ptr); proto_.set_learning_rate_multiplier(lr); proto_.set_weight_decay_multiplier(wc); vector<int> shape{size}; ParamProto proto; Setup(proto, shape); - if(ptr==nullptr){ - CHECK((*msg)->next_frame()); - CHECK_EQ(size* sizeof(float), (*msg)->frame_size()); - memcpy(mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float)); + if (ptr == nullptr) { + CHECK((*msg)->NextFrame()); + CHECK_EQ(size* sizeof(float), (*msg)->FrameSize()); + memcpy(mutable_cpu_data(), (*msg)->FrameData(), size*sizeof(float)); }else{ data_->set_cpu_data(ptr); } - DeleteMsg(msg); + if (!reserve) + DeleteMsg(msg); return nullptr; } -Msg* Param::HandleGetMsg(Msg** msg){ +Msg* Param::HandleGetMsg(Msg** msg, bool reserve) { int copy; float* ptr; - sscanf(static_cast<char*>((*msg)->frame_data()), " %d %p ", &copy, &ptr); - (*msg)->next_frame(); + (*msg)->ParseFormatFrame("ip", &copy, &ptr); if(copy) - (*msg)->add_frame(mutable_cpu_data(), sizeof(float)*size()); + (*msg)->AddFrame(mutable_cpu_data(), sizeof(float)*size()); else if(ptr!=data_->cpu_data()){ memcpy(ptr, data_->cpu_data(), sizeof(float)*size()); data_->set_cpu_data(ptr); @@ -182,73 +173,127 @@ Msg* Param::HandleGetMsg(Msg** msg){ return *msg; } -int Param::ParseUpdateMsg(Msg** msg){ - int 
@@ -182,73 +173,127 @@ Msg* Param::HandleGetMsg(Msg** msg){
   return *msg;
 }
 
-int Param::ParseUpdateMsg(Msg** msg){
-  int copy;
-  sscanf(static_cast<char*>((*msg)->frame_data()), " %d ", &copy);
-  (*msg)->next_frame();
-  if(copy){
-    //LOG(ERROR)<<"Copy in parse update";
-    CHECK((*msg)->frame_size());
-    memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
-  }else {// use the same data field of the grad blob
-    float* ptr=nullptr;
-    sscanf(static_cast<char*>((*msg)->frame_data()), " %p ", &ptr);
-    grad_.set_cpu_data(ptr);
+void Param::ParseUpdateMsgs(const vector<Msg*>& msgs) {
+  bool reset = true;
+  vector<int> copies;
+  for (auto *msg : msgs) {
+    int copy;
+    msg->ParseFormatFrame("i", &copy);
+    reset = reset && copy;
+    copies.push_back(copy);
+  }
+  int idx = 0;
+  for (auto *msg : msgs) {
+    CHECK(msg->NextFrame());
+    if (copies.at(idx++)) {
+      float* server_grad = mutable_cpu_grad();
+      float* worker_grad = static_cast<float*> (msg->FrameData());
+      if (reset) {
+        memcpy(server_grad, worker_grad, sizeof(float) * size());
+        reset = false;
+      } else {
+        for (int i =0; i < size(); i++)
+          server_grad[i] += worker_grad[i];
+      }
+    } else {
+      float* ptr = nullptr;
+      msg->ParseFormatFrame("p", &ptr);
+      if (grad_.cpu_data() != ptr) {
+        memcpy(ptr, grad_.cpu_data(), msg->FrameSize());
+        grad_.set_cpu_data(ptr);
+      }
+    }
   }
-  DeleteMsg(msg);
-  return copy;
-}
-Msg* Param::GenUpdateResponseMsg(bool copy){
-  Msg* msg=new Msg();
-  msg->set_type(kRUpdate);
-  char buf[8]; sprintf(buf, " %d ", copy);
-  msg->add_frame(buf, sizeof(buf));
-  if(copy){
-    //LOG(ERROR)<<"Copy in gen";
-    // LOG(ERROR)<<"gen copy resonse for "<<id()<<", "<<size();
-    msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+  if (msgs.size() > 1) {
+    float* server_grad = mutable_cpu_grad();
+    for (int i = 0; i < size(); i++)
+      server_grad[i] /= msgs.size();
   }
-  // LOG(ERROR)<<"gen share resonse for "<<id()<<", "<<size();
+}
 
-  return msg;
+const vector<Msg*> Param::GenUpdateResponseMsgs(const vector<Msg*>& msgs) {
+  vector<Msg*> ret;
+  for (auto msg : msgs) {
+    msg->FirstFrame();
+    msg->SwapAddr();
+    msg->set_type(kRUpdate);
+    int copy;
+    msg->ParseFormatFrame("i", &copy);
+    if (copy) {
+      msg->NextFrame();
+      CHECK_EQ(msg->FrameSize(), sizeof(float) * size());
+      memcpy(msg->FrameData(), mutable_cpu_data(), msg->FrameSize());
+    }
+    ret.push_back(msg);
+  }
+  return ret;
 }
 
-Msg* Param::HandleSyncMsg(Msg** msg){
-  DeleteMsg(msg);
+Msg* Param::HandleSyncMsg(Msg** msg, bool reserve) {
+  if (!reserve)
+    DeleteMsg(msg);
   return nullptr;
 }
 
-int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){
-  DeleteMsg(msg);
+int Param::ParseSyncResponseMsg(Msg* msg, int slice_idx) {
   return 1;
 }
 
-int Param::ParseGetResponseMsg(Msg **msg, int slice_idx){
+int Param::ParseGetResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_get_[slice_idx], true);
   pending_get_[slice_idx]=false;
   ParseResponseMsg(msg, slice_idx);
   return (--num_pending_requests_)%num_slices_==0;
 }
 
-int Param::ParseUpdateResponseMsg(Msg **msg, int slice_idx){
+int Param::ParseUpdateResponseMsg(Msg *msg, int slice_idx) {
   CHECK_EQ(pending_update_[slice_idx], true);
   pending_update_[slice_idx]=false;
   ParseResponseMsg(msg, slice_idx);
-  return (--num_pending_requests_)%num_slices_==0;
+  return (--num_pending_requests_) % num_slices_==0;
 }
 
-void Param::ParseResponseMsg(Msg** msg, int slice_idx){
+void Param::ParseResponseMsg(Msg* msg, int slice_idx) {
   int copy;
-  sscanf(static_cast<char*>((*msg)->frame_data()), " %d ", &copy);
-  (*msg)->next_frame();
-  if(copy){
-    CHECK_EQ((*msg)->frame_size(), slice_size_[slice_idx]*sizeof(float));
+  msg->ParseFormatFrame("i", &copy);
+  msg->NextFrame();
+  if(copy) {
+    CHECK_EQ(msg->FrameSize(), slice_size_[slice_idx]*sizeof(float));
     memcpy(mutable_cpu_data()+slice_offset_[slice_idx],
-        (*msg)->frame_data(), (*msg)->frame_size());
+        msg->FrameData(), msg->FrameSize());
   }
   //LOG(ERROR)<<"parse response norm "<<data_->asum_data()<<" of "<<id();
-  DeleteMsg(msg);
+}
+
+void Param::ShareFrom(const Param& other) {
+  proto_.set_owner(other.owner());
+  if(data_!=nullptr)
+    CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
+        other.data_->shape().begin()));
+  data_ = other.data_;
+  slice_offset_ = other.slice_offset_;
+  slice_size_ = other.slice_size_;
+  slice_start_ = other.slice_start_;
+  num_slices_ = other.num_slices_;
+  pending_get_ = other.pending_get_;
+  pending_put_ = other.pending_put_;
+  pending_update_ = other.pending_update_;
+}
+
+/************************ParamEntry***************************/
+ParamEntry::ParamEntry():
+    num_update(0), next_version(-1), num_local(0), num_total(0) {
+}
+
+ParamEntry::ParamEntry(int total, Param* p) : num_update(0), num_total(total) {
+  shares.push_back(p);
+}
+void ParamEntry::AddParam(bool local, Param* p) {
+  num_local += local;
+  num_total += 1;
+  if(local)
+    shares.push_back(p);
+}
 }
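Param::ParseUpdateMsgs above aggregates the update messages a server receives for one Param: when every message carries its own gradient copy, the first copy overwrites the server-side gradient, later copies are accumulated, and the sum is finally divided by the number of messages. The stand-alone sketch below walks through that averaging rule with hypothetical worker gradients; it is not tied to the Msg/Param classes.

// Stand-alone sketch of the averaging rule in Param::ParseUpdateMsgs
// (made-up inputs, independent of the Msg/Param classes).
#include <cstdio>
#include <vector>

int main() {
  // Gradients for the same 4-element slice from three hypothetical workers.
  std::vector<std::vector<float>> worker_grads = {
      {0.1f, 0.2f, 0.3f, 0.4f},
      {0.3f, 0.2f, 0.1f, 0.0f},
      {0.2f, 0.2f, 0.2f, 0.2f}};
  std::vector<float> server_grad(4, 0.f);
  bool reset = true;
  for (const auto& g : worker_grads) {
    for (size_t i = 0; i < g.size(); ++i)
      server_grad[i] = reset ? g[i] : server_grad[i] + g[i];  // overwrite first, then accumulate
    reset = false;
  }
  for (auto& v : server_grad)
    v /= worker_grads.size();   // average over the contributing workers
  for (float v : server_grad)
    std::printf("%.3f ", v);    // expected: 0.200 0.200 0.200 0.200
  std::printf("\n");
  return 0;
}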
