http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/server.cc ---------------------------------------------------------------------- diff --git a/src/trainer/server.cc b/src/trainer/server.cc index cd2bc02..5d530da 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -21,6 +21,7 @@ void Server::Setup(const UpdaterProto& proto, shard_=shard; } + void Server::Run(){ dealer_=std::make_shared<Dealer>(2*thread_id_); dealer_->Connect(kInprocRouterEndpoint); @@ -38,7 +39,12 @@ void Server::Run(){ break; Msg* response=nullptr; int type=msg->type(); - if (type==kConnect){ + if (type== kStop){ + msg->set_src(group_id_, server_id_, kServer); + msg->set_dst(-1,-1, kStub); + dealer_->Send(&msg); + break; + }else if (type==kConnect){ // TODO remove receiving pong msg string pong((char*)msg->frame_data(), msg->frame_size()); CHECK_STREQ("PONG", pong.c_str());
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/trainer.cc ---------------------------------------------------------------------- diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index bc6867d..37e9883 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -36,6 +36,20 @@ void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){ "Updater", CreateInstance(singa::SGDUpdater, singa::Updater)); } +typedef struct HandleContext_{ + shared_ptr<Dealer> dealer; + int group_id, id; +} HandleContext; + +void HandleWorkerFinish(void * ctx){ + HandleContext* hctx=static_cast<HandleContext*> (ctx); + Msg* msg=new Msg(); + msg->set_src(-1,-1, kRuntime); + msg->set_dst(hctx->group_id, hctx->id, kServer); + msg->set_type(kStop); + hctx->dealer->Send(&msg); +} + void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, int procs_id){ procs_id_=procs_id; @@ -44,6 +58,7 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, auto cluster=Cluster::Get(cproto, procs_id); // create servers vector<shared_ptr<Server>> servers; + vector<HandleContext> ctx; int nthreads=1; // the first socket is the router if(cluster->has_server()){ int pid=cluster->procs_id(); @@ -54,10 +69,21 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, int end=start+cluster->nservers_per_group(); // the ParamShard for servers consists of a dictionary of Param objects auto shard=make_shared<Server::ParamShard>(); - for(int sid=start;sid<end;sid++){ - auto server=make_shared<Server>(nthreads++, gid, sid); - server->Setup(mproto.updater(), shard); - servers.push_back(server); + if(start<end){ + auto dealer=make_shared<Dealer>(); + dealer->Connect(kInprocRouterEndpoint); + for(int sid=start;sid<end;sid++){ + auto server=make_shared<Server>(nthreads++, gid, sid); + server->Setup(mproto.updater(), shard); + servers.push_back(server); + HandleContext hc; + hc.dealer=dealer; + hc.group_id=gid; + hc.id=sid; + ctx.push_back(hc); + cluster->runtime()->sWatchSGroup(gid, sid, HandleWorkerFinish, + &ctx.back()); + } } } @@ -152,12 +178,13 @@ void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto, threads.push_back(std::thread(&Server::Run,server.get())); for(auto worker: workers) threads.push_back(std::thread(&Worker::Run,worker.get())); - Run(shards); + Run(servers.size(), workers.size(), shards); for(auto& thread: threads) thread.join(); } -void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ +void Trainer::Run(int nworkers, int nservers, + const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ auto cluster=Cluster::Get(); procs_id_=cluster->procs_id(); auto router=make_shared<Router>(); @@ -166,7 +193,8 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ router->Bind(cluster->endpoint()); map<int, shared_ptr<Dealer>> interprocs_dealers; - while(true){ + bool stop=false; + while(!stop){ Msg* msg=router->Receive(); if(msg==nullptr){ LOG(ERROR)<<"Connection broken!"; @@ -179,6 +207,18 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ if(dst_flag == kStub&&(dst_procs==procs_id_||dst_procs==-1)){ if(type==kConnect){ msg =HandleConnect(&msg); + }else if(type==kStop){ + if(msg->src_flag()==kServer) + nworkers--; + else if (msg->src_flag()==kWorkerParam) + nservers--; + delete msg; + msg=nullptr; + if(nworkers==0&&nservers==0){ + stop=true; + break; + } + LOG(ERROR)<<"Stub recv Stop"; }else{ int group_id=msg->src_first(); int paramid=msg->target_first(); @@ -223,6 +263,7 @@ void Trainer::Run(const std::map<int, shared_ptr<Trainer::ParamShard>>& shards){ } } } + LOG(ERROR)<<"Stub finishes"; } Msg* Trainer::HandleConnect(Msg** msg){ string ping((char*)(*msg)->frame_data(), (*msg)->frame_size()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/trainer/worker.cc ---------------------------------------------------------------------- diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc index 9ef47a6..f0b54ea 100644 --- a/src/trainer/worker.cc +++ b/src/trainer/worker.cc @@ -12,16 +12,20 @@ using std::thread; namespace singa { Worker::Worker(int thread_id, int group_id, int worker_id): thread_id_(thread_id), group_id_(group_id), worker_id_(worker_id){ - } + +} void Worker::Setup(const ModelProto& model, shared_ptr<NeuralNet> train_net){ train_net_=train_net; modelproto_=model; + auto cluster=Cluster::Get(); + int sgid=group_id_/cluster->nworker_groups_per_server_group(); + cluster->runtime()->wJoinSGroup(group_id_, worker_id_, sgid); } void Worker::Run(){ - param_dealer_=make_shared<Dealer>(2*thread_id_); + param_dealer_=make_shared<Dealer>(2*thread_id_); param_dealer_->Connect(kInprocRouterEndpoint); param_poller_.Add(param_dealer_.get()); layer_dealer_=make_shared<Dealer>(2*thread_id_+1); @@ -87,6 +91,19 @@ void Worker::Run(){ RunOneBatch(step_, &perf); step_++; } + + Stop(); +} + +void Worker::Stop(){ + auto cluster=Cluster::Get(); + int sgid=group_id_/cluster->nworker_groups_per_server_group(); + cluster->runtime()->wLeaveSGroup(group_id_, worker_id_, sgid); + Msg* msg=new Msg(); + msg->set_src(group_id_, worker_id_, kWorkerParam); + msg->set_dst(-1,-1, kStub); + msg->set_type(kStop); + param_dealer_->Send(&msg); } int Worker::Put(shared_ptr<Param> param, int step){ Msg* msg=new Msg(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc index 66c4ac8..b00a3cd 100644 --- a/src/utils/cluster.cc +++ b/src/utils/cluster.cc @@ -30,6 +30,9 @@ Cluster::Cluster(const ClusterProto &cluster, int procs_id) { } CHECK_EQ(endpoints_.size(), nprocs); } + auto rt=new ZKClusterRT(cluster_.zookeeper_host()); + rt->Init(); + cluster_rt_=shared_ptr<ClusterRuntime>(static_cast<ClusterRuntime*>(rt)); } void Cluster::SetupFolders(const ClusterProto &cluster){ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c318a98c/src/utils/cluster_rt.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster_rt.cc b/src/utils/cluster_rt.cc index d88ab46..433623d 100644 --- a/src/utils/cluster_rt.cc +++ b/src/utils/cluster_rt.cc @@ -39,7 +39,7 @@ bool ZKClusterRT::Init(){ } bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){ - + string path = getSGroupPath(gid); struct Stat stat; @@ -49,7 +49,7 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){ if (ret == ZOK) ; //need to create zk node first else if (ret == ZNONODE){ - char buf[MAX_BUF_LEN]; + char buf[MAX_BUF_LEN]; ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, 0, buf, MAX_BUF_LEN); if (ret == ZOK){ LOG(INFO) << "zookeeper node " << buf << " created"; @@ -77,13 +77,13 @@ bool ZKClusterRT::sWatchSGroup(int gid, int sid, rt_callback fn, void *ctx){ } bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){ - + string path = getSGroupPath(s_group) + getWorkerPath(gid, wid); - char buf[MAX_BUF_LEN]; - + char buf[MAX_BUF_LEN]; + int ret = zoo_create(zkhandle_, path.c_str(), NULL, -1, &ZOO_OPEN_ACL_UNSAFE, ZOO_EPHEMERAL, buf, MAX_BUF_LEN); if (ret == ZOK){ - LOG(INFO) << "zookeeper node " << buf << " created"; + LOG(ERROR) << "zookeeper node " << buf << " created"; return true; } else if (ret == ZNODEEXISTS){ @@ -94,18 +94,18 @@ bool ZKClusterRT::wJoinSGroup(int gid, int wid, int s_group){ LOG(ERROR) << "zookeeper parent node " << getSGroupPath(s_group) << " not exist"; return false; } - + LOG(ERROR) << "Unhandled ZK error code: " << ret << " (zoo_create)"; return false; } bool ZKClusterRT::wLeaveSGroup(int gid, int wid, int s_group){ - + string path = getSGroupPath(s_group) + getWorkerPath(gid, wid); - + int ret = zoo_delete(zkhandle_, path.c_str(), -1); if (ret == ZOK){ - LOG(INFO) << "zookeeper node " << path << " deleted"; + LOG(ERROR) << "zookeeper node " << path << " deleted"; return true; } else if (ret == ZNONODE){
