SINGA-8 Implement distributed Hogwild. The program has only been lightly tested, using two processes. TODO: 1. Read process endpoints from ZooKeeper instead of hard-coding them. 2. Split large parameters to avoid load imbalance among server groups. Currently, server groups are assigned an (almost) equal number of Param objects, but these objects may differ considerably in memory size.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f4370118 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f4370118 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f4370118 Branch: refs/heads/master Commit: f4370118c91f688fdc8c84d0d590096f2e93586c Parents: a019958 Author: wang wei <[email protected]> Authored: Wed Jun 17 16:17:19 2015 +0800 Committer: wang wei <[email protected]> Committed: Thu Jun 25 11:49:32 2015 +0800 ---------------------------------------------------------------------- examples/cifar10/cluster-dist.conf | 8 +++ examples/cifar10/hostfile | 22 +----- include/communication/msg.h | 2 +- include/communication/socket.h | 11 ++- include/trainer/server.h | 14 ++-- include/trainer/trainer.h | 5 +- include/utils/cluster.h | 10 ++- include/utils/param.h | 3 +- src/communication/socket.cc | 8 ++- src/proto/cluster.proto | 6 ++ src/test/test_paramslicer.cc | 47 +++++++++++++ src/trainer/server.cc | 121 ++++++++++++++++++++++++++------ src/trainer/trainer.cc | 62 +++++++++++++++- src/trainer/worker.cc | 8 ++- src/utils/param.cc | 15 ++-- 15 files changed, 272 insertions(+), 70 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/examples/cifar10/cluster-dist.conf ---------------------------------------------------------------------- diff --git a/examples/cifar10/cluster-dist.conf b/examples/cifar10/cluster-dist.conf new file mode 100644 index 0000000..1a4e2c2 --- /dev/null +++ b/examples/cifar10/cluster-dist.conf @@ -0,0 +1,8 @@ +nworker_groups: 2 +nserver_groups: 2 +nservers_per_group: 1 +nworkers_per_group: 1 +nworkers_per_procs: 1 +workspace: "examples/cifar10/" +hostfile: "examples/cifar10/hostfile" +poll_time: 100 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/examples/cifar10/hostfile 
---------------------------------------------------------------------- diff --git a/examples/cifar10/hostfile b/examples/cifar10/hostfile index 83e06e5..eda7414 100644 --- a/examples/cifar10/hostfile +++ b/examples/cifar10/hostfile @@ -1,20 +1,2 @@ -awan-2-26-0 -awan-2-27-0 -awan-2-28-0 -awan-2-29-0 -awan-2-30-0 -awan-2-31-0 -awan-2-32-0 -awan-2-33-0 -awan-2-34-0 -awan-2-35-0 -awan-2-36-0 -awan-2-37-0 -awan-2-38-0 -awan-2-39-0 -awan-2-40-0 -awan-2-41-0 -awan-2-42-0 -awan-2-43-0 -awan-2-44-0 -awan-2-45-0 +localhost:9733 +localhost:9734 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/communication/msg.h ---------------------------------------------------------------------- diff --git a/include/communication/msg.h b/include/communication/msg.h index c3ef1c7..60a359a 100644 --- a/include/communication/msg.h +++ b/include/communication/msg.h @@ -23,6 +23,7 @@ class Msg { * @param second worker/server id within the group * @param flag 0 for server, 1 for worker, 2 for stub */ +<<<<<<< HEAD inline void set_src(int first, int second, int flag) { src_ = (first << kOff1) | (second << kOff2) | flag; } @@ -78,7 +79,6 @@ class Msg { void ParseFromZmsg(zmsg_t* msg); zmsg_t* DumpToZmsg(); #endif - protected: static const unsigned int kOff1 = 16; static const unsigned int kOff2 = 4; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/communication/socket.h ---------------------------------------------------------------------- diff --git a/include/communication/socket.h b/include/communication/socket.h index d1cb400..b98656e 100644 --- a/include/communication/socket.h +++ b/include/communication/socket.h @@ -19,10 +19,10 @@ class SocketInterface { public: virtual ~SocketInterface() {} /** - * Send a message to connected socket(s), non-blocking. The message - * will be deallocated after sending, thus should not be used after + * Send a message to connected socket(s), non-blocking. 
The message + * will be deallocated after sending, thus should not be used after * calling Send(); - * + * * @param msg The message to be sent * @return 1 for success queuing the message for sending, 0 for failure */ @@ -56,6 +56,11 @@ class Poller { */ SocketInterface* Wait(int duration); + /** + * @return true if the poller is terminated due to process interupt + */ + virtual bool Terminated()=0; + protected: #ifdef USE_ZMQ zpoller_t *poller_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/trainer/server.h ---------------------------------------------------------------------- diff --git a/include/trainer/server.h b/include/trainer/server.h index b07741f..a8995fb 100644 --- a/include/trainer/server.h +++ b/include/trainer/server.h @@ -27,6 +27,12 @@ class Server{ void Setup(const UpdaterProto& proto, shared_ptr<ServerShard> shard, const vector<int>& slice2group); void Run(); + const int group_id() const { + return group_id_; + } + const int server_id() const { + return server_id_; + } protected: @@ -50,24 +56,20 @@ class Server{ * @return the original message or response message. If we don't want need to * acknowledge the put request, then return nullptr. */ - virtual void HandlePut(shared_ptr<Param> param, Msg **msg); + virtual Msg* HandlePut(Msg **msg); /** * TODO Process SYNC request. */ virtual Msg* HandleSyncRequest(shared_ptr<Param> param, Msg** msg); - /** - * TODO Process SYNC response. 
- virtual int HandleSyncResponse(shared_ptr<Param> param, Msg** msg); - */ - protected: int thread_id_,group_id_, server_id_; shared_ptr<Dealer> dealer_; shared_ptr<Updater> updater_; shared_ptr<ServerShard> shard_; vector<int> slice2group_; + std::map<int, shared_ptr<Blob<float>>> last_data_; }; } /* Server */ #endif //INCLUDE_TRAINER_SERVER_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/trainer/trainer.h ---------------------------------------------------------------------- diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h index ed93374..fbbfd0b 100644 --- a/include/trainer/trainer.h +++ b/include/trainer/trainer.h @@ -95,13 +95,14 @@ class Trainer{ // point. protected: - vector<shared_ptr<Server>> CreateServers(int nthread, const ModelProto& mproto, const vector<int> slices, vector<HandleContext*>* ctx); vector<shared_ptr<Worker>> CreateWorkers(int nthread, const ModelProto& mproto, vector<int> *slice_size); - void Run(int nworkers, int nservers); + void Run(const vector<shared_ptr<Worker>>& workers, + const vector<shared_ptr<Server>>& servers, + const std::map<int, shared_ptr<ParamShard>>& shards); /** * Register default implementations for all base classes used in the system, * e.g., the Updater, BaseMsg, etc. 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/utils/cluster.h ---------------------------------------------------------------------- diff --git a/include/utils/cluster.h b/include/utils/cluster.h index 9648bfe..0eeb808 100644 --- a/include/utils/cluster.h +++ b/include/utils/cluster.h @@ -112,11 +112,15 @@ class Cluster { } /** - * bandwidth MB/s - float bandwidth() const { + * bandwidth Bytes/s + */ + const int bandwidth() const { return cluster_.bandwidth(); } - */ + + const int poll_time() const { + return cluster_.poll_time(); + } shared_ptr<ClusterRuntime> runtime() const { return cluster_rt_; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/include/utils/param.h ---------------------------------------------------------------------- diff --git a/include/utils/param.h b/include/utils/param.h index 897c97a..d449fba 100644 --- a/include/utils/param.h +++ b/include/utils/param.h @@ -71,6 +71,7 @@ class Param { */ virtual Msg* HandleSyncMsg(Msg** msg); +<<<<<<< HEAD /** * Server parses update request message. 
* @@ -105,6 +106,7 @@ class Param { * @param shape */ virtual void Setup(const ParamProto& proto, const std::vector<int>& shape); + virtual void Setup(const vector<int>& shape); /* * Fill the values according to initmethod, e.g., gaussian distribution * @@ -238,7 +240,6 @@ class Param { ParamProto proto_; int local_version_; }; - } // namespace singa #endif // INCLUDE_UTILS_PARAM_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/communication/socket.cc ---------------------------------------------------------------------- diff --git a/src/communication/socket.cc b/src/communication/socket.cc index 5321724..38c0d79 100644 --- a/src/communication/socket.cc +++ b/src/communication/socket.cc @@ -19,9 +19,14 @@ SocketInterface* Poller::Wait(int timeout) { zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout)); if (sock != nullptr) return zsock2Socket_[sock]; - return nullptr; + else + return nullptr; +} +bool Poller::Terminated(){ + return zpoller_terminated(poller_); } + Dealer::Dealer() : Dealer(-1) {} Dealer::Dealer(int id) : id_(id) { @@ -31,6 +36,7 @@ Dealer::Dealer(int id) : id_(id) { CHECK_NOTNULL(poller_); } +<<<<<<< HEAD Dealer::~Dealer() { zsock_destroy(&dealer_); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/proto/cluster.proto ---------------------------------------------------------------------- diff --git a/src/proto/cluster.proto b/src/proto/cluster.proto index 4f7e661..3317f2a 100644 --- a/src/proto/cluster.proto +++ b/src/proto/cluster.proto @@ -38,6 +38,12 @@ message ClusterProto { optional bool server_update = 40 [default = true]; // share memory space between worker groups in one procs optional bool share_memory = 41 [default = true]; + + // bandwidth of ethernet, Bytes per second, default is 1 Gbps + optional int32 bandwidth=50 [default=134217728]; + // poll time in milliseconds + optional int32 poll_time=51 [default =100]; +>>>>>>> SINGA-8 Implement distributed Hogwild } message 
ServerTopology { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/test/test_paramslicer.cc ---------------------------------------------------------------------- diff --git a/src/test/test_paramslicer.cc b/src/test/test_paramslicer.cc new file mode 100644 index 0000000..bbff616 --- /dev/null +++ b/src/test/test_paramslicer.cc @@ -0,0 +1,47 @@ +#include "utils/param.h" +#include "gtest/gtest.h" + + +using namespace singa; + +const int param_size[]={2400,32,25600,32, 51200,64,57600,10}; + +class ParamSlicerTest : public ::testing::Test { + public: + ParamSlicerTest() { + ParamProto proto; + int nparams=sizeof(param_size)/sizeof(int); + for(int i=0;i<nparams;i++){ + vector<int> shape{param_size[i]}; + auto param=std::make_shared<Param>(); + param->Setup(proto, shape); + param->set_id(i); + params.push_back(param); + } + } + protected: + vector<shared_ptr<Param>> params; +}; + +// all params are stored in one box, no need to split +TEST_F(ParamSlicerTest, OneBox){ + int nparams=sizeof(param_size)/sizeof(int); + ParamSlicer slicer; + int num=1; + auto slices=slicer.Slice(num, params); + ASSERT_EQ(slices.size(),nparams); + ASSERT_EQ(slicer.Get(1).size(),1); + ASSERT_EQ(slicer.Get(2).size(),1); + ASSERT_EQ(slicer.Get(nparams-1).back(), slices.size()-1); +} + +// there are multiple boxes +TEST_F(ParamSlicerTest, MultipleBox){ + int nparams=sizeof(param_size)/sizeof(int); + ParamSlicer slicer; + int num=4; + auto slices=slicer.Slice(num, params); + ASSERT_EQ(slicer.Get(1).size(),1); + ASSERT_EQ(slicer.Get(3).size(),1); + ASSERT_EQ(slicer.Get(nparams-1).back(), slices.size()-1); +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc index 04f6040..5185c51 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -1,13 +1,14 @@ #include <list> #include <tuple> #include <queue> 
+#include "mshadow/tensor.h" #include "trainer/server.h" #include "utils/param.h" #include "utils/singleton.h" #include "utils/factory.h" #include "utils/cluster.h" - +using namespace mshadow; namespace singa { Server::Server(int thread_id,int group_id, int server_id): thread_id_(thread_id),group_id_(group_id), server_id_(server_id){} @@ -23,21 +24,23 @@ void Server::Setup(const UpdaterProto& proto, } void Server::Run(){ + LOG(INFO)<<"Server (group_id= "<<group_id_<<", id="<<server_id_<<") starts"; dealer_=std::make_shared<Dealer>(2*thread_id_); dealer_->Connect(kInprocRouterEndpoint); - + auto cluster=Cluster::Get(); Msg* ping=new Msg(); ping->set_src(group_id_, server_id_, kServer); ping->set_dst(-1,-1,kStub); ping->add_frame("PING", 4); ping->set_type(kConnect); dealer_->Send(&ping); + int syncEntry=0; //start recv loop and process requests while (true){ Msg* msg=dealer_->Receive(); if (msg==nullptr) break; - Msg* response=nullptr; + Msg* response=nullptr, *sync=nullptr; int type=msg->type(); if (type== kStop){ msg->set_src(group_id_, server_id_, kServer); @@ -48,26 +51,47 @@ void Server::Run(){ // TODO remove receiving pong msg string pong((char*)msg->frame_data(), msg->frame_size()); CHECK_STREQ("PONG", pong.c_str()); - delete msg; + DeleteMsg(&msg); }else if(type==kPut){ - int pid=msg->trgt_second(); - shared_ptr<Param> param=nullptr; - if(shard_->find(pid)!=shard_->end()){ - LOG(ERROR)<<"Param ("<<pid<<") is put more than once"; - param=shard_->at(pid); - }else{ - param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance() - ->Create("Param")); - param->set_id(pid); - (*shard_)[pid]=param; - } - HandlePut(param, &msg); + response = HandlePut(&msg); }else{ int pid=msg->trgt_second(); if(shard_->find(pid)==shard_->end()){ // delay the processing by re-queue the msg. 
response=msg; DLOG(ERROR)<<"Requeue msg"; + }else if(type==kSyncReminder){ + DeleteMsg(&msg); + unsigned nchecks=0, nparams=shard_->size(); + while(nchecks<nparams + &&group_locator_->at(shard_->at(syncEntry))!=group_id_){ + syncEntry=(syncEntry+1)%nparams; + nchecks++; + } + if(nchecks==nparams) continue; + auto param=shard_->at(syncEntry); + if(param->local_version()!=param->version()){ + sync=param->GenSyncMsg(true); + for(int i=0;i<cluster->nserver_groups();i++){ + if(i!=group_id_) { + Msg* tmp=sync; + if(i<cluster->nserver_groups()-1) + tmp= new Msg(sync); + tmp->set_dst(i, server_locator_->at(param), kServer); + tmp->set_src(group_id_, server_id_, kServer); + dealer_->Send(&tmp); + param->set_version(param->local_version()); + //DLOG(ERROR)<<"sync"; + } + } + } + }else { + int pid=msg->target_first(); + if(shard_->find(pid)==shard_->end()){ + // delay the processing by re-queue the msg. + response=msg; + LOG(ERROR)<<"Requeue"; +>>>>>>> SINGA-8 Implement distributed Hogwild } else{ auto param=shard_->at(pid); switch (type){ @@ -80,20 +104,42 @@ void Server::Run(){ case kSyncRequest: response = HandleSyncRequest(param, &msg); break; - } - if (response!=nullptr){ - dealer_->Send(&response); + default: + LOG(ERROR)<<"Unknown message type "<<type; + break; } } } + if (response!=nullptr) + dealer_->Send(&response); } + LOG(INFO)<<"Server (group_id= "<<group_id_<<", id="<<server_id_<<") stops"; } -void Server::HandlePut(shared_ptr<Param> param, Msg **msg){ +Msg* Server::HandlePut(Msg **msg){ int version=(*msg)->trgt_third(); - param->HandlePutMsg(msg); + int pid=(*msg)->target_first(); + shared_ptr<Param> param=nullptr; + if(shard_->find(pid)!=shard_->end()){ + LOG(ERROR)<<"Param ("<<pid<<") is put more than once"; + param=shard_->at(pid); + }else{ + auto factory=Singleton<Factory<Param>>::Instance(); + param=shared_ptr<Param>(factory ->Create("Param")); + param->set_id(pid); + (*shard_)[pid]=param; + } + auto response=param->HandlePutMsg(msg); // must set version 
after HandlePutMsg which allocates the memory param->set_version(version); + if(Cluster::Get()->nserver_groups()>1 && + group_locator_->at(param)!=group_id_){ + last_data_[pid]=std::make_shared<Blob<float>>(); + last_data_[pid]->ReshapeLike(param->data()); + last_data_[pid]->CopyFrom(param->data()); + } + LOG(INFO)<<"Server put param "<<pid<<" size="<<param->size()<<" Bytes"; + return response; } Msg* Server::HandleGet(shared_ptr<Param> param, Msg **msg){ @@ -124,7 +170,36 @@ Msg* Server::HandleUpdate(shared_ptr<Param> param, Msg **msg) { } Msg* Server::HandleSyncRequest(shared_ptr<Param> param, Msg **msg){ - return param->HandleSyncMsg(msg); + Msg* response=nullptr; + auto shape=Shape1(param->size()); + CHECK_EQ((*msg)->frame_size(), param->size()*sizeof(float)); + Tensor<cpu, 1> tmp(static_cast<float*>((*msg)->frame_data()), shape); + Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape); + if(group_locator_->at(param)==group_id_){ + cur+=tmp; + param->set_local_version(param->local_version()+1); + }else{ + TensorContainer<cpu, 1> diff(shape); + Tensor<cpu, 1> prev(last_data_[param->id()]->mutable_cpu_data(), shape); + diff=cur-prev; + (*msg)->next_frame(); + int bandwidth; + sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &bandwidth); + if(bandwidth>0){ + response=new Msg(); + response->set_type(kSyncRequest); + response->set_target(param->id(), param->version()); + response->add_frame(diff.dptr, param->size()*sizeof(float)); + (*msg)->SwapAddr(); + response->SetAddr(*msg); + prev=diff+tmp; + Copy(cur, prev); + }else{ + Copy(prev, tmp); + cur=tmp+diff; + } + } + DeleteMsg(msg); + return response; } - } /* singa */ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index 6c08a3a..bdc1416 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -2,12 +2,16 @@ #include <vector> 
#include <map> #include <queue> +#include <chrono> #include <glog/logging.h> #include "proto/common.pb.h" #include "trainer/trainer.h" #include "mshadow/tensor.h" using std::vector; using std::map; +using namespace std::chrono; + +typedef std::chrono::milliseconds TimeT; namespace singa { @@ -21,14 +25,17 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ } void HandleWorkerFinish(void * ctx){ + /* HandleContext* hctx=static_cast<HandleContext*> (ctx); Msg* msg=new Msg(); msg->set_src(-1,-1, kRuntime); msg->set_dst(hctx->group_id, hctx->id, kServer); msg->set_type(kStop); hctx->dealer->Send(&msg); + */ } +<<<<<<< HEAD const std::unordered_map<int, vector<std::pair<int, int>>> SliceParams(int num, const vector<shared_ptr<Param>>& params){ std::unordered_map<int, vector<std::pair<int, int>>> paramid2slices; @@ -276,20 +283,51 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, threads.push_back(std::thread(&Server::Run,server.get())); for(auto worker: workers) threads.push_back(std::thread(&Worker::Run,worker.get())); - Run(workers.size(), servers.size()); + Run(workers, servers, shards); for(auto& thread: threads) thread.join(); for(auto x: ctx) delete x; } -void Trainer::Run(int nworkers, int nservers){ +inline int bandwidth(int bytes, system_clock::time_point start){ + auto now=system_clock::now(); + auto duration=duration_cast<TimeT> (now - start); + return static_cast<int>(bytes*1000.f/duration.count()); +} +void Trainer::Run(const vector<shared_ptr<Worker>>& workers, + const vector<shared_ptr<Server>>& servers, + const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ auto cluster=Cluster::Get(); procs_id_=cluster->procs_id(); + LOG(INFO)<<"Stub in process "<<procs_id_<<" starts"; map<int, shared_ptr<Dealer>> interprocs_dealers; std::queue<Msg*> msg_queue; bool stop=false; + auto start=std::chrono::system_clock::now(); + float amount=0.f; + Poller poll; + poll.Add(router_.get()); + int sync_server=0, 
nworkers=workers.size(), nservers=servers.size(); while(!stop){ + Socket *sock=poll.Wait(cluster->poll_time()); + if(poll.Terminated()){ + LOG(ERROR)<<"Connection broken!"; + exit(0); + }else if(sock==nullptr){ + if(cluster->nserver_groups()>1&& + bandwidth(amount, start)<cluster->bandwidth()){ + Msg* msg=new Msg(); + msg->set_src(-1,-1, kStub); + msg->set_dst(servers[sync_server]->group_id(), + servers[sync_server]->server_id(), kServer); + msg->set_type(kSyncReminder); + sync_server=(sync_server+1)%servers.size(); + router_->Send(&msg); + //LOG(ERROR)<<"Reminder"; + } + continue; + } Msg* msg=router_->Receive(); if(msg==nullptr){ LOG(ERROR)<<"Connection broken!"; @@ -360,6 +398,7 @@ void Trainer::Run(int nworkers, int nservers){ msg_queue.push(x); break; default: + LOG(ERROR)<<"Unknow message type:"<<type; break; } }else{ @@ -374,12 +413,30 @@ void Trainer::Run(int nworkers, int nservers){ msg->dst_second(), msg->dst_flag()); } if(dst_procs_id!=procs_id_){ + // forward to other procs + if (interprocs_dealers.find(dst_procs_id)==interprocs_dealers.end()){ + auto dealer=make_shared<Dealer>(); + interprocs_dealers[dst_procs_id]=dealer; + dealer->Connect("tcp://"+cluster->endpoint(dst_procs_id)); + } + if(bandwidth(amount, start) <=cluster->bandwidth()){ + start=std::chrono::system_clock::now(); + amount=0; + } + amount+=msg->size(); + interprocs_dealers[dst_procs_id]->Send(&msg); }else{ + if(type==kSyncRequest){ + char buf[32]; + sprintf(buf, "%d", cluster->bandwidth()-bandwidth(amount, start)); + msg->add_frame(buf, strlen(buf)); + } router_->Send(&msg); } } } } + LOG(INFO)<<"Stub in process "<<procs_id_<<" stops"; } Msg* Trainer::HandleConnect(Msg** msg){ string ping((char*)(*msg)->frame_data(), (*msg)->frame_size()); @@ -394,7 +451,6 @@ Msg* Trainer::HandleConnect(Msg** msg){ *msg=NULL; return reply; } - const vector<Msg*> Trainer::HandleGet(shared_ptr<ParamInfo> pi, Msg** msg){ Msg* msgg=*msg; vector<Msg*> replies; 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index 788e77c..37acb14 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -47,6 +47,7 @@ void Worker::ConnectStub(shared_ptr<Dealer> dealer, EntityType type){ } void Worker::Run(){ + LOG(INFO)<<"Worker (group_id= "<<group_id_<<", id="<<worker_id_<<") starts"; dealer_=make_shared<Dealer>(2*thread_id_); ConnectStub(dealer_, kWorkerParam); for(auto layer: train_net_->layers()) @@ -61,8 +62,10 @@ void Worker::Run(){ for(auto layer: train_net_->layers()){ if(layer->partitionid()==worker_id_) for(auto param: layer->GetParams()){ + // only owners fill the memory of parameter values. + // others share the memory with owners hence do not need to put/get. if(param->owner() == param->id()){ - if(group_id_==0) + if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0) param->InitValues(0); else Get(param, modelproto_.warmup_steps()); @@ -70,7 +73,7 @@ void Worker::Run(){ } } Metric perf; - if(group_id_==0){ + if(group_id_%Cluster::Get()->nworker_groups_per_server_group()==0){ for(step_=0;step_<modelproto_.warmup_steps();step_++) RunOneBatch(step_, &perf); for(auto layer: train_net_->layers()){ @@ -86,6 +89,7 @@ void Worker::Run(){ } Stop(); + LOG(INFO)<<"Worker (group_id= "<<group_id_<<", id="<<worker_id_<<") stops"; } void Worker::Stop(){ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f4370118/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index deff6f4..4ad17ce 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -133,8 +133,12 @@ Msg* Param::GenUpdateMsg(bool copy, int idx){ return msg; } -Msg* Param::GenSyncMsg(){ - return nullptr; +Msg* Param::GenSyncMsg(bool copy, int v){ + Msg* msg=new Msg(); + 
msg->set_type(kSyncRequest); + msg->set_target(id(), local_version()); + msg->add_frame(mutable_cpu_data(), size()*sizeof(float)); + return msg; } Msg* Param::HandlePutMsg(Msg** msg){ @@ -146,9 +150,9 @@ Msg* Param::HandlePutMsg(Msg** msg){ proto_.set_learning_rate_multiplier(lr); proto_.set_weight_decay_multiplier(wc); vector<int> shape{size}; - grad_.Reshape(shape); - history_.Reshape(shape); - data_=std::make_shared<Blob<float>>(shape); + Setup(shape); + set_local_version((*msg)->target_second()); + set_version((*msg)->target_second()); if(ptr==nullptr){ CHECK((*msg)->next_frame()); CHECK_EQ(size* sizeof(float), (*msg)->frame_size()); @@ -201,6 +205,7 @@ Msg* Param::HandleSyncMsg(Msg** msg){ return nullptr; } +<<<<<<< HEAD int Param::ParseSyncResponseMsg(Msg** msg, int slice_idx){ DeleteMsg(msg); return 1;
