SINGA-156 Remove the dependency on ZMQ for single-process training. Move msg queue init into the Dealer and Router constructors.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/42f5253e Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/42f5253e Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/42f5253e Branch: refs/heads/master Commit: 42f5253eacde9a0ab87d3b4ed2382a137d9652d6 Parents: 65b8c8d Author: Wei Wang <[email protected]> Authored: Mon Apr 4 16:52:51 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Mon Apr 4 16:52:51 2016 +0800 ---------------------------------------------------------------------- include/singa/comm/socket.h | 3 ++- include/singa/server.h | 2 ++ src/comm/socket.cc | 8 ++++++++ src/driver.cc | 12 ------------ src/server.cc | 12 ++++++------ src/worker.cc | 2 +- 6 files changed, 19 insertions(+), 20 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/include/singa/comm/socket.h ---------------------------------------------------------------------- diff --git a/include/singa/comm/socket.h b/include/singa/comm/socket.h index de8cbde..3194d8c 100644 --- a/include/singa/comm/socket.h +++ b/include/singa/comm/socket.h @@ -43,7 +43,7 @@ class Dealer { /** * @param id used for identifying the msg queue of this dealer. */ - Dealer(int id) : id_(id) {} + Dealer(int id); ~Dealer(); /** * Setup the connection with the remote router. @@ -83,6 +83,7 @@ class Dealer { class Router { public: ~Router(); + Router(); /** * Bind the router to an endpoint for recv msg from remote dealer. 
* If the router is used for intra-communication only, then no need to call http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/include/singa/server.h ---------------------------------------------------------------------- diff --git a/include/singa/server.h b/include/singa/server.h index 4bffeae..d95862d 100644 --- a/include/singa/server.h +++ b/include/singa/server.h @@ -126,6 +126,8 @@ class Server { std::vector<int> n_pending_sync_; std::vector<Blob<float>> last_sync_; std::unordered_map<int, std::vector<Msg*>> buffer_requests_; + + Dealer* dealer_; }; } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/src/comm/socket.cc ---------------------------------------------------------------------- diff --git a/src/comm/socket.cc b/src/comm/socket.cc index 8245398..aa1ee85 100644 --- a/src/comm/socket.cc +++ b/src/comm/socket.cc @@ -31,6 +31,10 @@ Dealer::~Dealer() { #endif } +Dealer::Dealer(int id) : id_ (id) { + msgQueues[id]; +} + int Dealer::Connect(const std::string& endpoint) { if (endpoint.length() > 0) { #ifdef USE_ZMQ @@ -79,6 +83,10 @@ Router::~Router() { #endif } +Router::Router() { + msgQueues[-1]; +} + int Router::Bind(const std::string& endpoint) { int port = -1; if (endpoint.length() > 0) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/src/driver.cc ---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index 2952c62..b8f6735 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -232,18 +232,6 @@ void Driver::Train(const JobProto& job_conf) { net->ToGraph(true).ToJson()); const vector<Worker*> workers = CreateWorkers(job_conf, net); const vector<Server*> servers = CreateServers(job_conf, net); - // Add msg queues for each socket - for (auto worker : workers) { - msgQueues[Addr(worker->grp_id(), worker->id(), kWorkerParam)]; - msgQueues[Addr(worker->grp_id(), worker->id(), kWorkerLayer)]; -// LOG(ERROR) << "worker addr " << 
Addr(worker->grp_id(), worker->id(), kWorkerParam); -// LOG(ERROR) << "worker addr " << Addr(worker->grp_id(), worker->id(), kWorkerLayer); - } - for (auto server : servers) { - msgQueues[Addr(server->grp_id(), server->id(), kServer)]; -// LOG(ERROR) << "server addr " << Addr(server->grp_id(), server->id(), kServer); - } - msgQueues[-1]; vector<std::thread> threads; for (auto server : servers) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/src/server.cc ---------------------------------------------------------------------- diff --git a/src/server.cc b/src/server.cc index d5ef028..3b72243 100644 --- a/src/server.cc +++ b/src/server.cc @@ -44,6 +44,7 @@ Server::Server(int group_id, int server_id, updater_ = Updater::Create(job_conf.updater()); slice2group_ = slice2group; slice2server_ = slice2server; + dealer_ = new Dealer(Addr(grp_id_, id_, kServer)); } Server::~Server() { @@ -52,6 +53,7 @@ Server::~Server() { for (auto entry : shard_) for (auto param : entry.second->shares) delete param; + delete dealer_; } void Stop(void* running) { @@ -73,11 +75,10 @@ void Server::Run() { bool running = true; CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running)); - auto dealer = new Dealer(Addr(grp_id_, id_, kServer)); // start recv loop and process requests while (running) { // cannot use blocking Receive() here, it will get stuck after workers stop. 
- Msg* msg = dealer->Receive(cluster->poll_time()); + Msg* msg = dealer_->Receive(cluster->poll_time()); if (msg == nullptr) continue; Msg* response = nullptr; @@ -97,7 +98,7 @@ void Server::Run() { break; case kUpdate: for (auto reply : HandleUpdate(&msg)) - dealer->Send(&reply); + dealer_->Send(&reply); break; case kSyncRequest: response = HandleSyncRequest(&msg); @@ -111,16 +112,15 @@ void Server::Run() { } } if (response != nullptr) - dealer->Send(&response); + dealer_->Send(&response); } // send stop msg to stub Msg* msg = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub)); msg->set_type(kStop); - dealer->Send(&msg); + dealer_->Send(&msg); std::this_thread::sleep_for(std::chrono::milliseconds(1000)); LOG(ERROR) << "Server (group = " << grp_id_ << ", id = " << id_ << ") stops"; - delete dealer; } Msg* Server::HandlePut(Msg **msg) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/42f5253e/src/worker.cc ---------------------------------------------------------------------- diff --git a/src/worker.cc b/src/worker.cc index 6c461ce..5206513 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -53,7 +53,7 @@ void Worker::Setup(int grp_id, int id, const JobProto& conf, train_net_ = train_net; val_net_ = val_net; test_net_ = test_net; - bridge_dealer_ = dealer_ = nullptr; + InitSockets(train_net); } Worker::~Worker() {
