SINGA-21 Code review 5

review server.h, server.cc
  - format code
  - remove thread_id field
  - rename variables: nUpdates_ -> n_updates_, nPendingSync_ -> n_pending_sync_
  - fix a bug in HandleUpdate that used *msg after it had been left in an unknown state
TODO:
  - give each socket a unique id from cluster
  - buffer the unprocessed message instead of sending it back to the stub

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d3e1fca3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d3e1fca3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d3e1fca3

Branch: refs/heads/master
Commit: d3e1fca38b97e06ca113369d9a4f583750105a39
Parents: 3161175
Author: wang sheng <[email protected]>
Authored: Tue Sep 22 17:22:33 2015 +0800
Committer: wang sheng <[email protected]>
Committed: Tue Sep 22 17:28:41 2015 +0800

----------------------------------------------------------------------
 include/trainer/server.h  |  74 ++++++++++++++---------
 include/trainer/trainer.h |   2 +-
 src/trainer/server.cc     | 104 +++++++++++++++++++++--------------------
 src/trainer/trainer.cc    |   6 +--
 4 files changed, 92 insertions(+), 94 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/include/trainer/server.h
----------------------------------------------------------------------
diff --git a/include/trainer/server.h b/include/trainer/server.h
index 3f1c12d..3f3539a 100644
--- a/include/trainer/server.h
+++ b/include/trainer/server.h
@@ -19,17 +19,20 @@
  *
  *************************************************************/
-#ifndef INCLUDE_TRAINER_SERVER_H_
-#define INCLUDE_TRAINER_SERVER_H_
+#ifndef SINGA_TRAINER_SERVER_H_
+#define SINGA_TRAINER_SERVER_H_
+
 #include <memory>
 #include <unordered_map>
-#include <utils/param.h>
-#include <utils/updater.h>
-#include "proto/job.pb.h"
+#include <vector>
 #include "communication/socket.h"
+#include "proto/job.pb.h"
+#include "utils/param.h"
+#include "utils/updater.h"
 namespace singa {
-/* Repsond to worker's get/put/udpate request, and periodically syncing with
+
+/* Repsond to worker's get/put/udpate request, and periodically syncing with
  * other servers.
  *
  * Normally, the Server creates a response message for each request which
@@ -39,33 +42,26 @@ namespace singa {
  * it just sends it to the router. The router will decide to re-send the
  * request to the server or send it to the worker.
  */
-class Server{
+class Server {
  public:
-  Server(int thread_id, int group_id, int server_id);
+  Server(int group_id, int server_id);
   virtual ~Server();
-  void Setup(const UpdaterProto& proto,
-             const std::vector<int>& slice2group,
-             const std::vector<int>& slice2server);
+  void Setup(const UpdaterProto& proto, const std::vector<int>& slice2group,
+             const std::vector<int>& slice2server);
   void Run();
-  const int grp_id() const {
-    return grp_id_;
-  }
-  const int id() const {
-    return id_;
-  }
+  inline int grp_id() const { return grp_id_; }
+  inline int id() const { return id_; }
  protected:
-
-  /**
-   * Process GET request.
+  /**
+   * Process GET request.
    *
    * @return the orignal message or a response message which contains the values
    * of the Param with the request version.
    */
-  virtual Msg* HandleGet(Msg** msg);
-
-  /**
-   * Process Update request.
+  virtual Msg* HandleGet(Msg** msg);
+  /**
+   * Process Update request.
    *
    * It waits until received the gradients from all workers from the same worker
    * group. After updating, it responses to each sender with the new Param
@@ -86,16 +82,14 @@ class Server{
    * @return the orignal message or response message
    */
   const std::vector<Msg*> HandleUpdate(Msg **msg);
-
-  /**
-   * Process PUT request.
+  /**
+   * Process PUT request.
    *
    * @return the original message or response message. If we don't want to
    * acknowledge the put request, then return nullptr.
-   */
-  virtual Msg* HandlePut(Msg **msg);
-
-  /**
+   */
+  virtual Msg* HandlePut(Msg **msg);
+  /**
    * Handle sync request from other server groups.
    *
    * It adds updates of Param (slice) from other server groups directly to
@@ -106,8 +100,7 @@ class Server{
    * @param msg request msg containing the parameter updates
    * @return response msg that contains the fresh parameter values.
    */
-  virtual Msg* HandleSyncRequest(Msg** msg);
-
+  virtual Msg* HandleSyncRequest(Msg** msg);
   /**
    * Handle sync response.
    *
@@ -121,17 +114,20 @@ class Server{
   void HandleSyncResponse(Msg** msg);
  protected:
-  int thread_id_,grp_id_, id_;
-  Updater* updater_;
+  int grp_id_ = -1;
+  int id_ = -1;
+  Updater* updater_ = nullptr;
   //!< map from slice ID to slice and deleted in the destructor
   std::unordered_map<int, ParamEntry*> shard_;
   std::vector<int> slice2group_, slice2server_;
   //!< num of updates from last sync with master server group for a param/slice
-  std::vector<int> nUpdates_;
+  std::vector<int> n_updates_;
   //!< num of sync requests that have not been responded
-  std::vector<int> nPendingSync_;
+  std::vector<int> n_pending_sync_;
   std::vector<Blob<float>> last_sync_;
   std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
 };
-} /* Server */
-#endif //INCLUDE_TRAINER_SERVER_H_
+
+}  // namespace singa
+
+#endif  // SINGA_TRAINER_SERVER_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/include/trainer/trainer.h
----------------------------------------------------------------------
diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h
index 6630e51..0b03dea 100644
--- a/include/trainer/trainer.h
+++ b/include/trainer/trainer.h
@@ -73,7 +73,7 @@ class Trainer{
    * @param jobConf
    * @return server instances
    */
-  vector<Server*> CreateServers(int nthread, const JobProto& jobConf);
+  vector<Server*> CreateServers(const JobProto& jobConf);
   /**
    * Create workers instances.
    * @param nthread total num of threads in current procs which is used to

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
index 18fe7d2..29f6a68 100644
--- a/src/trainer/server.cc
+++ b/src/trainer/server.cc
@@ -22,30 +22,30 @@
 #include <thread>
 #include <chrono>
 #include "mshadow/tensor.h"
+#include "proto/common.pb.h"
 #include "trainer/server.h"
 #include "utils/param.h"
 #include "utils/singleton.h"
 #include "utils/factory.h"
 #include "utils/cluster.h"
-#include "proto/common.pb.h"
 namespace singa {
 using namespace mshadow;
 using std::vector;
-Server::Server(int thread_id,int group_id, int server_id):
-  thread_id_(thread_id),grp_id_(group_id), id_(server_id){
+Server::Server(int group_id, int server_id) {
+  grp_id_ = group_id;
+  id_ = server_id;
 }
-void Server::Setup(const UpdaterProto& proto,
-                   const vector<int>& slice2group,
-                   const vector<int>& slice2server) {
+void Server::Setup(const UpdaterProto& proto, const vector<int>& slice2group,
+                   const vector<int>& slice2server) {
   updater_ = Updater::Create(proto);
   slice2group_ = slice2group;
   slice2server_ = slice2server;
-  nUpdates_.resize(slice2group_.size(), 0);
-  nPendingSync_.resize(slice2group_.size(), 0);
+  n_updates_.resize(slice2group_.size(), 0);
+  n_pending_sync_.resize(slice2group_.size(), 0);
   last_sync_.resize(slice2group_.size());
 }
@@ -57,14 +57,14 @@ Server::~Server() {
     delete param;
 }
-void Stop(void * running) {
+void Stop(void* running) {
   *static_cast<bool *>(running) = false;
 }
 void Server::Run() {
   LOG(ERROR) << "Server (group = " << grp_id_ <<", id = " << id_ << ") start";
-
-  auto dealer = new Dealer(2*thread_id_);
+  // TODO(wangsh): give each dealer a unique id
+  auto dealer = new Dealer(0);
   CHECK(dealer->Connect(kInprocRouterEndpoint));
   Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub));
   ping->set_type(kConnect);
@@ -77,7 +77,7 @@ void Server::Run() {
   // start recv loop and process requests
   while (running) {
     // must use poller here; otherwise Receive() gets stuck after workers stop.
-    auto *sock = poll.Wait(cluster->poll_time());
+    auto* sock = poll.Wait(cluster->poll_time());
     if (poll.Terminated()) {
       LOG(ERROR) << "Connection broken!";
       exit(0);
@@ -85,35 +85,35 @@ void Server::Run() {
       continue;
     }
     Msg* msg = dealer->Receive();
-    if (msg == nullptr) break; // interrupted
+    if (msg == nullptr) break;  // interrupted
     Msg* response = nullptr;
     int type = msg->type();
     int slice_id = SliceID(msg->trgt_val());
     if (type == kPut) {
       response = HandlePut(&msg);
+    } else if (shard_.find(slice_id) == shard_.end()) {
+      // TODO(wangsh): buffer the msg instead, and process it after the
+      // corresponding put request is done
+      // delay the processing by re-queue the msg. May sleep for a while?
+      response = msg;
     } else {
-      if (shard_.find(slice_id) == shard_.end()) {
-        // delay the processing by re-queue the msg. May sleep for a while?
-        response = msg;
-      } else {
-        switch (type) {
-          case kGet:
-            response = HandleGet(&msg);
-            break;
-          case kUpdate:
-            for (auto reply : HandleUpdate(&msg))
-              dealer->Send(&reply);
-            break;
-          case kSyncRequest:
-            response = HandleSyncRequest(&msg);
-            break;
-          case kSyncResponse:
-            HandleSyncResponse(&msg);
-            break;
-          default:
-            LOG(ERROR)<<"Unknown message type "<<type;
-            break;
-        }
+      switch (type) {
+        case kGet:
+          response = HandleGet(&msg);
+          break;
+        case kUpdate:
+          for (auto reply : HandleUpdate(&msg))
+            dealer->Send(&reply);
+          break;
+        case kSyncRequest:
+          response = HandleSyncRequest(&msg);
+          break;
+        case kSyncResponse:
+          HandleSyncResponse(&msg);
+          break;
+        default:
+          LOG(ERROR) << "Unknown message type: " << type;
+          break;
       }
     }
     if (response != nullptr)
@@ -125,7 +125,6 @@ void Server::Run() {
   msg->set_type(kStop);
   dealer->Send(&msg);
   std::this_thread::sleep_for(std::chrono::milliseconds(1000));
-
   LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") stops";
   delete dealer;
 }
@@ -154,8 +153,8 @@ Msg* Server::HandlePut(Msg **msg) {
     last_sync_[slice_id].ReshapeLike(param->data());
     last_sync_[slice_id].CopyFrom(param->data());
   }
-  LOG(INFO)<<"server (group = " << grp_id_ << ", id = " << id_ <<") put slice="
-    << slice_id << " size=" << param->size();
+  LOG(INFO) << "server (group = " << grp_id_ << ", id = " << id_
+            <<") put slice=" << slice_id << " size=" << param->size();
   return response;
 }
@@ -163,9 +162,9 @@ Msg* Server::HandleGet(Msg **msg) {
   int val = (*msg)->trgt_val();
   auto param = shard_.at(SliceID(val))->shares.at(0);
   // re-queue the request if the param is not updated to the required version
-  if(param->version()<(*msg)->trgt_version())
+  if (param->version() < (*msg)->trgt_version()) {
     return *msg;
-  else {
+  } else {
     // LOG(ERROR) << "get " << slice << " from "<<(*msg)->src_first();
     auto reply = param->HandleGetMsg(msg, false);
     reply->set_trgt(val, param->version());
@@ -183,12 +182,14 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
   (*msg)->ParseFormatFrame("i", &num_update);
   (*msg)->FirstFrame();
   entry->num_update += num_update;
-  // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src()) << ", " << num_update << " total " << entry->num_total;
+  // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src())
+  //            << ", " << num_update << " total " << entry->num_total;
   // do update until recv gradients from all shares of this param/slice
   if (entry->num_update >= entry->num_total) {
     CHECK_EQ(entry->num_update, entry->num_total);
     auto& request = buffer_requests_.at(sliceid);
     int step = (*msg)->trgt_version();
+    int trgt_val = (*msg)->trgt_val();
     auto param = entry->shares.at(0);
     // extract and aggregate gradients
     param->ParseUpdateMsgs(request);
@@ -196,16 +197,16 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
     param->set_local_version(param->local_version() + 1);
     // response to all shares of this param
     for (auto response : param->GenUpdateResponseMsgs(&request, false)) {
-      response->set_trgt((*msg)->trgt_val(), param->local_version());
+      response->set_trgt(trgt_val, param->local_version());
       ret.push_back(response);
     }
     entry->num_update = 0;
-    nUpdates_[sliceid]++;
+    n_updates_[sliceid]++;
     // sync with master group after at least sync_freq local updates
     // the last check is to avoid sending msg to stopped servers
     if (slice2group_[sliceid] != grp_id_
-        && nUpdates_[sliceid] >= Cluster::Get()->sync_freq()
-        && nPendingSync_[sliceid] <= Cluster::Get()->sync_freq()) {
+        && n_updates_[sliceid] >= Cluster::Get()->sync_freq()
+        && n_pending_sync_[sliceid] <= Cluster::Get()->sync_freq()) {
       auto shape = Shape1(param->size());
       Tensor<cpu, 1> tmp(last_sync_[sliceid].mutable_cpu_data(), shape);
       Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
@@ -213,14 +214,15 @@ const vector<Msg*> Server::HandleUpdate(Msg **msg) {
       int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer);
       Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr);
       sync->set_type(kSyncRequest);
-      sync->set_trgt((*msg)->trgt_val(), param->local_version());
+      sync->set_trgt(trgt_val, param->local_version());
       sync->AddFrame(tmp.dptr, param->size() * sizeof(float));
       Copy(tmp, cur);
       ret.push_back(sync);
-      nUpdates_[sliceid] = 0;
-      nPendingSync_[sliceid]++;
+      n_updates_[sliceid] = 0;
+      n_pending_sync_[sliceid]++;
     }
   }
+  // message already pushed to buffer, just need to reset the pointer
   *msg = nullptr;
   return ret;
 }
@@ -247,14 +249,14 @@ void Server::HandleSyncResponse(Msg **msg) {
   Msg* msgg = *msg;
   int slice = SliceID(msgg->trgt_val());
   auto param = shard_.at(slice)->shares.at(0);
-  auto shape=Shape1(param->size());
+  auto shape = Shape1(param->size());
   Tensor<cpu, 1> prev(last_sync_[param->id()].mutable_cpu_data(), shape);
   Tensor<cpu, 1> cur(param->mutable_cpu_data(), shape);
   Tensor<cpu, 1> master(static_cast<float*>(msgg->FrameData()), shape);
   cur += master - prev;  // cur = master + (cur - prev);
   Copy(prev, cur);
   DeleteMsg(msg);
-  nPendingSync_[slice]--;
+  n_pending_sync_[slice]--;
 }
-} /* singa */
+}  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d3e1fca3/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
index 4a4c183..c928d91 100644
--- a/src/trainer/trainer.cc
+++ b/src/trainer/trainer.cc
@@ -144,7 +144,7 @@ void Trainer::SetupWorkerServer(
     server->Setup(job_conf.updater(), slice2group, slice2server_);
 }
-vector<Server*> Trainer::CreateServers(int nthreads, const JobProto& job) {
+vector<Server*> Trainer::CreateServers(const JobProto& job) {
   auto cluster = Cluster::Get();
   vector<Server*> servers;
   if (!cluster->has_server())
@@ -160,7 +160,7 @@ vector<Server*> Trainer::CreateServers(int nthreads, const JobProto& job) {
   int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3];
   for (int gid = gstart; gid < gend; gid++) {
     for (int sid = start; sid < end; sid++) {
-      auto server = new Server(nthreads++, gid, sid);
+      auto server = new Server(gid, sid);
       servers.push_back(server);
     }
   }
@@ -244,7 +244,7 @@ void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) {
   int nthreads = 1;
   const vector<Worker*> workers = CreateWorkers(nthreads, *job);
   nthreads += workers.size();
-  const vector<Server*> servers = CreateServers(nthreads, *job);
+  const vector<Server*> servers = CreateServers(*job);
   SetupWorkerServer(*job, workers, servers);
 #ifdef USE_MPI
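
----------------------------------------------------------------------
Note on the HandleUpdate fix above: per the commit message, the old code read
(*msg)->trgt_val() after the buffered update requests (which, per the new
comment, include *msg itself) had already been handed to Param for parsing and
response generation, so the message was in an unknown state. The commit copies
trgt_val into a local before that happens. Below is a minimal, self-contained
sketch of the same pattern; FakeMsg, ConsumeRequests and HandleUpdateSketch
are hypothetical stand-ins, not SINGA's real Msg/Param API.

#include <iostream>
#include <vector>

struct FakeMsg {   // stand-in for a request message
  int trgt_val;
  int trgt_version;
};

// Deletes every buffered request, the way parsing the update messages
// invalidates them in the real server.
void ConsumeRequests(std::vector<FakeMsg*>* requests) {
  for (auto* m : *requests) delete m;
  requests->clear();
}

std::vector<int> HandleUpdateSketch(FakeMsg** msg,
                                    std::vector<FakeMsg*>* buffered) {
  // Copy what is needed later while *msg is still guaranteed to be valid.
  int trgt_val = (*msg)->trgt_val;
  ConsumeRequests(buffered);             // *msg may be among the consumed messages
  std::vector<int> responses{trgt_val};  // safe: uses the copied value
  *msg = nullptr;                        // caller must not touch the message again
  return responses;
}

int main() {
  auto* m = new FakeMsg{42, 1};
  std::vector<FakeMsg*> buffered{m};
  for (int v : HandleUpdateSketch(&m, &buffered)) std::cout << v << "\n";
  return 0;
}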

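----------------------------------------------------------------------
Sketch for the second TODO item ("buffer the unprocessed message"): instead of
handing a request for a not-yet-Put slice back to the stub, the server could
park it locally and replay it once HandlePut registers the slice. This is only
an illustration under that assumption; PendingMsg and RequestBuffer are
hypothetical names, not part of SINGA.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct PendingMsg {   // stand-in for a queued request message
  int slice_id;
  std::string payload;
};

class RequestBuffer {
 public:
  // Called when a request arrives before its slice exists in the shard.
  void Park(PendingMsg msg) {
    pending_[msg.slice_id].push_back(std::move(msg));
  }
  // Called right after the put for slice_id succeeds; returns the parked
  // messages in arrival order so they can be processed normally.
  std::vector<PendingMsg> Drain(int slice_id) {
    std::vector<PendingMsg> ready;
    auto it = pending_.find(slice_id);
    if (it != pending_.end()) {
      ready = std::move(it->second);
      pending_.erase(it);
    }
    return ready;
  }

 private:
  std::unordered_map<int, std::vector<PendingMsg>> pending_;
};

int main() {
  RequestBuffer buf;
  buf.Park({3, "get slice 3"});
  buf.Park({3, "update slice 3"});
  // ... later, after the put for slice 3 succeeds ...
  for (const auto& m : buf.Drain(3)) std::cout << m.payload << "\n";
  return 0;
}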