Repository: incubator-singa Updated Branches: refs/heads/master 0233049ce -> 914c1e722
SINGA-156 Remove the dependency on ZMQ for single process training. Update driver, Server, Worker to use the new Dealer/Router; Implement Msg class without ZMQ. Dealer must check its own msg queue to recv msgs. The Router may recv msgs from its msg queue (for intra-process comm) or from zmq (for inter-process comm). There are more msgs from intra-process comm, which may block the receiving of inter-process (zmq) msgs if not handled properly (e.g., if the Router stops receiving zmq msgs once it gets a msg from intra-process comm). Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/d8dffdf0 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/d8dffdf0 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/d8dffdf0 Branch: refs/heads/master Commit: d8dffdf02f90f338388d02e5032d5ca8d0b561e9 Parents: 0233049 Author: Wei Wang <[email protected]> Authored: Tue Mar 29 20:35:46 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Sat Apr 2 12:12:52 2016 +0800 ---------------------------------------------------------------------- include/singa/comm/msg.h | 9 +- include/singa/comm/socket.h | 160 +++++++-------------- include/singa/utils/safe_queue.h | 263 ++++++++++++++++++++++++++++++++++ src/comm/msg.cc | 77 ++++++++-- src/comm/socket.cc | 188 ++++++++++-------------- src/driver.cc | 19 ++- src/server.cc | 21 +-- src/stub.cc | 15 +- src/utils/cluster_rt.cc | 4 +- src/utils/param.cc | 2 + src/worker.cc | 16 +-- 11 files changed, 493 insertions(+), 281 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/include/singa/comm/msg.h ---------------------------------------------------------------------- diff --git a/include/singa/comm/msg.h b/include/singa/comm/msg.h index 50a9b81..ade7fc8 100644 --- a/include/singa/comm/msg.h +++ b/include/singa/comm/msg.h @@ -22,10 +22,12 @@ #ifndef SINGA_COMM_MSG_H_ #define SINGA_COMM_MSG_H_ +#include <utility>
+ // TODO(wangwei): make it a compiler argument #define USE_ZMQ -#include <utility> +#include <vector> #ifdef USE_ZMQ #include <czmq.h> #endif @@ -79,7 +81,7 @@ inline int AddrType(int addr) { } /** - * Msg used to transfer Param info (gradient or value), feature blob, etc + * Msg used to transfer Param info (gradient or value), feature blob, etc. * between workers, stubs and servers. * * Each msg has a source addr and dest addr identified by a unique integer. @@ -225,6 +227,9 @@ class Msg { #ifdef USE_ZMQ zmsg_t* msg_ = nullptr; zframe_t *frame_ = nullptr; +#else + std::vector<std::pair<void*, int>> frames_; + unsigned idx_ = 0; #endif }; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/include/singa/comm/socket.h ---------------------------------------------------------------------- diff --git a/include/singa/comm/socket.h b/include/singa/comm/socket.h index fae9ccb..de8cbde 100644 --- a/include/singa/comm/socket.h +++ b/include/singa/comm/socket.h @@ -25,150 +25,98 @@ #ifdef USE_ZMQ #include <czmq.h> #endif + #include <map> #include <string> #include <vector> +#include <unordered_map> +#include "singa/utils/safe_queue.h" #include "singa/comm/msg.h" namespace singa { - -const std::string kInprocRouterEndpoint = "inproc://router"; - -class SocketInterface { +/** + * Worker and Server use Dealer to communicate with Stub. + * Stub uses Dealer to communicate with remote Stub. + */ +class Dealer { public: - virtual ~SocketInterface() {} - /** - * Send a message to connected socket(s), non-blocking. The message - * will be deallocated after sending, thus should not be used after - * calling Send(); - * - * @param msg The message to be sent - * @return 1 for success queuing the message for sending, 0 for failure + /** + * @param id used for identifying the msg queue of this dealer. */ - virtual int Send(Msg** msg) = 0; + Dealer(int id) : id_(id) {} + ~Dealer(); /** - * Receive a message from any connected socket. 
+ * Setup the connection with the remote router. * - * @return a message pointer if success; nullptr if failure + * For local router, there is no need to connect it. + * + * @param endpoint Identifier of the remote router to connect. It follows + * ZeroMQ's format, i.e., IP:port, where IP is the connected process. + * @return 1 connection sets up successfully; 0 otherwise */ - virtual Msg* Receive() = 0; + int Connect(const std::string& endpoint); /** - * @return Identifier of the implementation dependent socket. E.g., zsock_t* - * for ZeroMQ implementation and rank for MPI implementation. + * Send a message to the local router (id=-1) or remote outer. It is + * non-blocking. The message will be deallocated after sending, thus + * should not be used after calling Send(); */ - virtual void* InternalID() const = 0; -}; - -class Poller { - public: - Poller(); - explicit Poller(SocketInterface* socket); - /** - * Add a socket for polling; Multiple sockets can be polled together by - * adding them into the same poller. - */ - void Add(SocketInterface* socket); - /** - * Poll for all sockets added into this poller. - * @param timeout Stop after this number of mseconds - * @return pointer To the socket if it has one message in the receiving - * queue; nullptr if no message in any sockets, - */ - SocketInterface* Wait(int duration); - + int Send(Msg** msg); /** - * @return true if the poller is terminated due to process interupt - */ - virtual bool Terminated(); - - protected: -#ifdef USE_ZMQ - zpoller_t *poller_; - std::map<zsock_t*, SocketInterface*> zsock2Socket_; -#endif -}; - -class Dealer : public SocketInterface { - public: - /* - * @param id Local dealer ID within a procs if the dealer is from worker or - * server thread, starts from 1 (0 is used by the router); or the connected - * remote procs ID for inter-process dealers from the stub thread. + * Recv msg from local router. + * + * @param timeout return if waiting for timeout microseconds. 
+ * @return a message pointer if success; nullptr if failure */ - Dealer(); - explicit Dealer(int id); - ~Dealer() override; - /** - * Setup the connection with the router. - * - * @param endpoint Identifier of the router. For intra-process - * connection, the endpoint follows the format of ZeroMQ, i.e., - * starting with "inproc://"; in Singa, since each process has one - * router, hence we can fix the endpoint to be "inproc://router" for - * intra-process. For inter-process, the endpoint follows ZeroMQ's - * format, i.e., IP:port, where IP is the connected process. - * @return 1 connection sets up successfully; 0 otherwise - */ - int Connect(const std::string& endpoint); - int Send(Msg** msg) override; - Msg* Receive() override; - void* InternalID() const override; + Msg* Receive(int timeout = 0); protected: - int id_ = -1; + std::string endpoint_; + int id_; #ifdef USE_ZMQ zsock_t* dealer_ = nullptr; - zpoller_t* poller_ = nullptr; #endif }; - -class Router : public SocketInterface { +/** + * In Singa, since each process has one router used by Stub, hence we fix the + * router to use the msg queue indexed by -1. + */ +class Router { public: - Router(); - /** - * There is only one router per procs, hence its local id is 0 and is not set - * explicitly. - * - * @param bufsize Buffer at most this number of messages - */ - explicit Router(int bufsize); - ~Router() override; + ~Router(); /** - * Setup the connection with dealers. + * Bind the router to an endpoint for recv msg from remote dealer. + * If the router is used for intra-communication only, then no need to call + * Bind. * - * It automatically binds to the endpoint for intra-process communication, - * i.e., "inproc://router". - * - * @param endpoint The identifier for the Dealer socket in other process + * @param endpoint identifier for the Dealer socket in other process * to connect. It has the format IP:Port, where IP is the host machine. 
- * If endpoint is empty, it means that all connections are - * intra-process connection. * @return number of connected dealers. */ int Bind(const std::string& endpoint); /** - * If the destination socket has not connected yet, buffer this the message. + * Send msg to local dealers by pushing the msg into the msg queue indexed by + * dst of the msg. + */ + int Send(Msg** msg); + /** + * Recv msg from local (msg queue) or remote dealer (via zmq). */ - int Send(Msg** msg) override; - Msg* Receive() override; - void* InternalID() const override; + Msg* Receive(int timeout = 0); protected: - int nBufmsg_ = 0; - int bufsize_ = 100; + std::string endpoint_; #ifdef USE_ZMQ zsock_t* router_ = nullptr; zpoller_t* poller_ = nullptr; - std::map<int, zframe_t*> id2addr_; - std::map<int, std::vector<zmsg_t*>> bufmsg_; #endif }; -#ifdef USE_MPI -// TODO(wangsheng): add intra-process communication using shared queue -std::vector<SafeQueue*> MPIQueues; -#endif - +/** + * Used for intra-process communication. + * Each dealer/router has a SafeQueue for recieving msgs. + * The sender pushes msgs onto the queue of the reciever's queue. 
+ */ +extern std::unordered_map<int, SafeQueue<Msg*>> msgQueues; } // namespace singa #endif // SINGA_COMM_SOCKET_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/include/singa/utils/safe_queue.h ---------------------------------------------------------------------- diff --git a/include/singa/utils/safe_queue.h b/include/singa/utils/safe_queue.h new file mode 100644 index 0000000..99adbf0 --- /dev/null +++ b/include/singa/utils/safe_queue.h @@ -0,0 +1,263 @@ +#ifndef SINGA_UTILS_SAFE_QUEUE_H_ +#define SINGA_UTILS_SAFE_QUEUE_H_ + +// source: http://gnodebian.blogspot.sg/2013/07/a-thread-safe-asynchronous-queue-in-c11.html +#include <queue> +#include <list> +#include <mutex> +#include <thread> +#include <cstdint> +#include <condition_variable> + +/** A thread-safe asynchronous queue */ +template <class T, class Container = std::list<T>> +class SafeQueue { + + typedef typename Container::value_type value_type; + typedef typename Container::size_type size_type; + typedef Container container_type; + + public: + + /*! Create safe queue. */ + SafeQueue() = default; + SafeQueue (SafeQueue&& sq) { + m_queue = std::move (sq.m_queue); + } + SafeQueue (const SafeQueue& sq) { + std::lock_guard<std::mutex> lock (sq.m_mutex); + m_queue = sq.m_queue; + } + + /*! Destroy safe queue. */ + ~SafeQueue() { + std::lock_guard<std::mutex> lock (m_mutex); + } + + /** + * Sets the maximum number of items in the queue. Defaults is 0: No limit + * \param[in] item An item. + */ + void set_max_num_items (unsigned int max_num_items) { + m_max_num_items = max_num_items; + } + + /** + * Pushes the item into the queue. + * \param[in] item An item. 
+ * \return true if an item was pushed into the queue + */ + bool push (const value_type& item) { + std::lock_guard<std::mutex> lock (m_mutex); + + if (m_max_num_items > 0 && m_queue.size() > m_max_num_items) + return false; + + m_queue.push (item); + m_condition.notify_one(); + return true; + } + + /** + * Pushes the item into the queue. + * \param[in] item An item. + * \return true if an item was pushed into the queue + */ + bool push (const value_type&& item) { + std::lock_guard<std::mutex> lock (m_mutex); + + if (m_max_num_items > 0 && m_queue.size() > m_max_num_items) + return false; + + m_queue.push (item); + m_condition.notify_one(); + return true; + } + + /** + * Pops item from the queue. If queue is empty, this function blocks until item becomes available. + * \param[out] item The item. + */ + void pop (value_type& item) { + std::unique_lock<std::mutex> lock (m_mutex); + m_condition.wait (lock, [this]() // Lambda funct + { + return !m_queue.empty(); + }); + item = m_queue.front(); + m_queue.pop(); + } + + /** + * Pops item from the queue using the contained type's move assignment operator, if it has one.. + * This method is identical to the pop() method if that type has no move assignment operator. + * If queue is empty, this function blocks until item becomes available. + * \param[out] item The item. + */ + void move_pop (value_type& item) { + std::unique_lock<std::mutex> lock (m_mutex); + m_condition.wait (lock, [this]() // Lambda funct + { + return !m_queue.empty(); + }); + item = std::move (m_queue.front()); + m_queue.pop(); + } + + /** + * Tries to pop item from the queue. + * \param[out] item The item. + * \return False is returned if no item is available. 
+ */ + bool try_pop (value_type& item) { + std::unique_lock<std::mutex> lock (m_mutex); + + if (m_queue.empty()) + return false; + + item = m_queue.front(); + m_queue.pop(); + return true; + } + + /** + * Tries to pop item from the queue using the contained type's move assignment operator, if it has one.. + * This method is identical to the try_pop() method if that type has no move assignment operator. + * \param[out] item The item. + * \return False is returned if no item is available. + */ + bool try_move_pop (value_type& item) { + std::unique_lock<std::mutex> lock (m_mutex); + + if (m_queue.empty()) + return false; + + item = std::move (m_queue.front()); + m_queue.pop(); + return true; + } + + /** + * Pops item from the queue. If the queue is empty, blocks for timeout microseconds, or until item becomes available. + * \param[out] t An item. + * \param[in] timeout The number of microseconds to wait. + * \return true if get an item from the queue, false if no item is received before the timeout. + */ + bool timeout_pop (value_type& item, std::uint64_t timeout) { + std::unique_lock<std::mutex> lock (m_mutex); + + if (m_queue.empty()) + { + if (timeout == 0) + return false; + + if (m_condition.wait_for (lock, std::chrono::microseconds (timeout)) == std::cv_status::timeout) + return false; + } + + item = m_queue.front(); + m_queue.pop(); + return true; + } + + /** + * Pops item from the queue using the contained type's move assignment operator, if it has one.. + * If the queue is empty, blocks for timeout microseconds, or until item becomes available. + * This method is identical to the try_pop() method if that type has no move assignment operator. + * \param[out] t An item. + * \param[in] timeout The number of microseconds to wait. + * \return true if get an item from the queue, false if no item is received before the timeout. 
+ */ + bool timeout_move_pop (value_type& item, std::uint64_t timeout) { + std::unique_lock<std::mutex> lock (m_mutex); + + if (m_queue.empty()) + { + if (timeout == 0) + return false; + + if (m_condition.wait_for (lock, std::chrono::microseconds (timeout)) == std::cv_status::timeout) + return false; + } + + item = std::move (m_queue.front()); + m_queue.pop(); + return true; + } + + /** + * Gets the number of items in the queue. + * \return Number of items in the queue. + */ + size_type size() const { + std::lock_guard<std::mutex> lock (m_mutex); + return m_queue.size(); + } + + /** + * Check if the queue is empty. + * \return true if queue is empty. + */ + bool empty() const { + std::lock_guard<std::mutex> lock (m_mutex); + return m_queue.empty(); + } + + /** + * Swaps the contents. + * \param[out] sq The SafeQueue to swap with 'this'. + */ + void swap (SafeQueue& sq) { + if (this != &sq) { + std::lock_guard<std::mutex> lock1 (m_mutex); + std::lock_guard<std::mutex> lock2 (sq.m_mutex); + m_queue.swap (sq.m_queue); + + if (!m_queue.empty()) + m_condition.notify_all(); + + if (!sq.m_queue.empty()) + sq.m_condition.notify_all(); + } + } + + /*! The copy assignment operator */ + SafeQueue& operator= (const SafeQueue& sq) { + if (this != &sq) { + std::lock_guard<std::mutex> lock1 (m_mutex); + std::lock_guard<std::mutex> lock2 (sq.m_mutex); + std::queue<T, Container> temp {sq.m_queue}; + m_queue.swap (temp); + + if (!m_queue.empty()) + m_condition.notify_all(); + } + + return *this; + } + + /*! The move assignment operator */ + SafeQueue& operator= (SafeQueue && sq) { + std::lock_guard<std::mutex> lock (m_mutex); + m_queue = std::move (sq.m_queue); + + if (!m_queue.empty()) m_condition.notify_all(); + + return *this; + } + + + private: + + std::queue<T, Container> m_queue; + mutable std::mutex m_mutex; + std::condition_variable m_condition; + unsigned int m_max_num_items = 0; +}; + +/*! Swaps the contents of two SafeQueue objects. 
*/ +template <class T, class Container> +void swap (SafeQueue<T, Container>& q1, SafeQueue<T, Container>& q2) { + q1.swap (q2); +} +#endif // SINGA_UTILS_SAFE_QUEUE_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/comm/msg.cc ---------------------------------------------------------------------- diff --git a/src/comm/msg.cc b/src/comm/msg.cc index 5c33026..94f3074 100644 --- a/src/comm/msg.cc +++ b/src/comm/msg.cc @@ -22,18 +22,25 @@ #include "singa/comm/msg.h" #include <glog/logging.h> +#include <stdarg.h> namespace singa { -#ifdef USE_ZMQ Msg::~Msg() { +#ifdef USE_ZMQ if (msg_ != nullptr) zmsg_destroy(&msg_); frame_ = nullptr; +#else + for (auto& frame : frames_) + delete static_cast<char*>(frame.first); +#endif } Msg::Msg() { +#ifdef USE_ZMQ msg_ = zmsg_new(); +#endif } Msg::Msg(const Msg& msg) { @@ -42,51 +49,49 @@ Msg::Msg(const Msg& msg) { type_ = msg.type_; trgt_val_ = msg.trgt_val_; trgt_version_ = msg.trgt_version_; +#ifdef USE_ZMQ msg_ = zmsg_dup(msg.msg_); +#endif } Msg::Msg(int src, int dst) { src_ = src; dst_ = dst; +#ifdef USE_ZMQ msg_ = zmsg_new(); +#endif } void Msg::SwapAddr() { std::swap(src_, dst_); } +#ifdef USE_ZMQ int Msg::size() const { return zmsg_content_size(msg_); } - void Msg::AddFrame(const void* addr, int nBytes) { zmsg_addmem(msg_, addr, nBytes); } - int Msg::FrameSize() { return zframe_size(frame_); } - -void* Msg::FrameData() { - return zframe_data(frame_); -} - char* Msg::FrameStr() { return zframe_strdup(frame_); } +void* Msg::FrameData() { + return zframe_data(frame_); +} bool Msg::NextFrame() { frame_ = zmsg_next(msg_); return frame_ != nullptr; } - void Msg::FirstFrame() { frame_ = zmsg_first(msg_); } - void Msg::LastFrame() { frame_ = zmsg_last(msg_); } - void Msg::ParseFromZmsg(zmsg_t* msg) { char* tmp = zmsg_popstr(msg); sscanf(tmp, "%d %d %d %d %d", @@ -103,6 +108,49 @@ zmsg_t* Msg::DumpToZmsg() { return tmp; } +#else + +int Msg::size() const { + int s = 0; + for (auto& entry : frames_) + s += 
entry.second; + return s; +} + +void Msg::AddFrame(const void* addr, int nBytes) { + char* tmp = new char[nBytes]; + memcpy(tmp, addr, nBytes); + frames_.push_back(std::make_pair(tmp, nBytes)); +} + +int Msg::FrameSize() { + return frames_.at(idx_).second; +} + +char* Msg::FrameStr() { + return static_cast<char*>(frames_.at(idx_).first); +} + +void* Msg::FrameData() { + return frames_.at(idx_).first; +} + +bool Msg::NextFrame() { + idx_++; +// LOG(ERROR) << "idx " << idx_ << " vs size " << frames_.size(); + return idx_ < frames_.size(); +} + +void Msg::FirstFrame() { + idx_ = 0; +} + +void Msg::LastFrame() { + idx_ = frames_.size() - 1; +} + +#endif + // frame marker indicating this frame is serialize like printf #define FMARKER "*singa*" @@ -156,14 +204,14 @@ int Msg::AddFormatFrame(const char *format, ...) { CHECK_LE(size, kMaxFrameLen); } va_end(argptr); - zmsg_addmem(msg_, dst, size); + AddFrame(dst, size); return size; } int Msg::ParseFormatFrame(const char *format, ...) { va_list argptr; va_start(argptr, format); - char* src = zframe_strdup(frame_); + char* src = FrameStr(); CHECK_STREQ(FMARKER, src); int size = strlen(FMARKER) + 1; while (*format) { @@ -207,9 +255,8 @@ int Msg::ParseFormatFrame(const char *format, ...) 
{ format++; } va_end(argptr); - delete src; + // delete src; return size; } -#endif } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/comm/socket.cc ---------------------------------------------------------------------- diff --git a/src/comm/socket.cc b/src/comm/socket.cc index 09a6913..8245398 100644 --- a/src/comm/socket.cc +++ b/src/comm/socket.cc @@ -23,158 +23,116 @@ #include <glog/logging.h> namespace singa { - -#ifdef USE_ZMQ -Poller::Poller() { - poller_ = zpoller_new(nullptr); -} - -Poller::Poller(SocketInterface* socket) { - poller_ = zpoller_new(nullptr); - Add(socket); -} - -void Poller::Add(SocketInterface* socket) { - zsock_t* zsock = static_cast<zsock_t*>(socket->InternalID()); - zpoller_add(poller_, zsock); - zsock2Socket_[zsock] = socket; -} - -SocketInterface* Poller::Wait(int timeout) { - zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout)); - if (sock != nullptr) - return zsock2Socket_[sock]; - else - return nullptr; -} - -bool Poller::Terminated() { - return zpoller_terminated(poller_); -} - - -Dealer::Dealer() : Dealer(-1) {} - -Dealer::Dealer(int id) : id_(id) { - dealer_ = zsock_new(ZMQ_DEALER); - CHECK_NOTNULL(dealer_); -} - +const int TIME_OUT = 2; // max blocking time in milliseconds. 
+std::unordered_map<int, SafeQueue<Msg*>> msgQueues; Dealer::~Dealer() { +#ifdef USE_ZMQ zsock_destroy(&dealer_); +#endif } int Dealer::Connect(const std::string& endpoint) { - CHECK_GT(endpoint.length(), 0); - if (endpoint.length()) { + if (endpoint.length() > 0) { +#ifdef USE_ZMQ + dealer_ = zsock_new(ZMQ_DEALER); + CHECK_NOTNULL(dealer_); CHECK_EQ(zsock_connect(dealer_, "%s", endpoint.c_str()), 0); - return 1; +#else + LOG(FATAL) << "No message passing lib is linked"; +#endif + endpoint_ = endpoint; } - return 0; + return 1; } int Dealer::Send(Msg** msg) { - zmsg_t* zmsg = (*msg)->DumpToZmsg(); - zmsg_send(&zmsg, dealer_); - delete *msg; - *msg = nullptr; + if (endpoint_.length()) { +#ifdef USE_ZMQ + zmsg_t* zmsg = (*msg)->DumpToZmsg(); + zmsg_send(&zmsg, dealer_); +#else + LOG(FATAL) << "No message passing lib is linked"; +#endif + delete *msg; + *msg = nullptr; + } else { + msgQueues.at(-1).push(*msg); + } return 1; } -Msg* Dealer::Receive() { - zmsg_t* zmsg = zmsg_recv(dealer_); - if (zmsg == nullptr) - return nullptr; - Msg* msg = new Msg(); - msg->ParseFromZmsg(zmsg); +Msg* Dealer::Receive(int timeout) { + Msg* msg = nullptr; + if (timeout > 0) { + if(!msgQueues.at(id_).timeout_pop(msg, timeout)) + return nullptr; + } else { + msgQueues.at(id_).pop(msg); + } + msg->FirstFrame(); return msg; } -void* Dealer::InternalID() const { - return dealer_; -} - -Router::Router() : Router(100) {} - -Router::Router(int bufsize) { - nBufmsg_ = 0; - bufsize_ = bufsize; - router_ = zsock_new(ZMQ_ROUTER); - CHECK_NOTNULL(router_); - poller_ = zpoller_new(router_); - CHECK_NOTNULL(poller_); -} - Router::~Router() { +#ifdef USE_ZMQ zsock_destroy(&router_); - for (auto it : id2addr_) - zframe_destroy(&it.second); - for (auto it : bufmsg_) { - for (auto *msg : it.second) - zmsg_destroy(&msg); - } +#endif } + int Router::Bind(const std::string& endpoint) { int port = -1; - if (endpoint.length()) { + if (endpoint.length() > 0) { + endpoint_ = endpoint; +#ifdef USE_ZMQ + router_ = 
zsock_new(ZMQ_ROUTER); + CHECK_NOTNULL(router_); port = zsock_bind(router_, "%s", endpoint.c_str()); + CHECK_NE(port, -1) << endpoint; + LOG(INFO) << "bind successfully to " << zsock_endpoint(router_); + poller_ = zpoller_new(router_); +#else + LOG(FATAL) << "No message passing lib is linked"; +#endif } - CHECK_NE(port, -1) << endpoint; - LOG(INFO) << "bind successfully to " << zsock_endpoint(router_); return port; } int Router::Send(Msg **msg) { - zmsg_t* zmsg = (*msg)->DumpToZmsg(); int dstid = (*msg)->dst(); - if (id2addr_.find(dstid) != id2addr_.end()) { - // the connection has already been set up - zframe_t* addr = zframe_dup(id2addr_[dstid]); - zmsg_prepend(zmsg, &addr); - zmsg_send(&zmsg, router_); + if (msgQueues.find(dstid) != msgQueues.end()) { + msgQueues.at(dstid).push(*msg); } else { - // the connection is not ready, buffer the message - if (bufmsg_.size() == 0) - nBufmsg_ = 0; - bufmsg_[dstid].push_back(zmsg); - ++nBufmsg_; - CHECK_LE(nBufmsg_, bufsize_); + LOG(FATAL) << "The dst queue not exist for dstid = " << dstid; } - delete *msg; - *msg = nullptr; return 1; } -Msg* Router::Receive() { - zmsg_t* zmsg = zmsg_recv(router_); - if (zmsg == nullptr) { - LOG(ERROR) << "Connection broken!"; - exit(0); - } - zframe_t* dealer = zmsg_pop(zmsg); - Msg* msg = new Msg(); - msg->ParseFromZmsg(zmsg); - if (id2addr_.find(msg->src()) == id2addr_.end()) { - // new connection, store the sender's identfier and send buffered messages - // for it - id2addr_[msg->src()] = dealer; - if (bufmsg_.find(msg->src()) != bufmsg_.end()) { - for (auto& it : bufmsg_.at(msg->src())) { - zframe_t* addr = zframe_dup(dealer); - zmsg_prepend(it, &addr); - zmsg_send(&it, router_); +Msg* Router::Receive(int timeout) { + Msg* msg = nullptr; + if (timeout == 0) + timeout = TIME_OUT; + while (msg == nullptr) { +#ifdef USE_ZMQ + if (router_ != nullptr) { + zsock_t* sock = static_cast<zsock_t*>(zpoller_wait(poller_, timeout)); + if (sock != NULL) { + zmsg_t* zmsg = zmsg_recv(router_); + if 
(zmsg == nullptr) { + LOG(ERROR) << "Connection broken!"; + exit(0); + } + zframe_t* dealer = zmsg_pop(zmsg); + zframe_destroy(&dealer); + Msg* remote_msg = new Msg(); + remote_msg->ParseFromZmsg(zmsg); + msgQueues.at(-1).push(remote_msg); } - bufmsg_.erase(msg->src()); } - } else { - zframe_destroy(&dealer); +#endif + msgQueues.at(-1).timeout_pop(msg, timeout * 10); } + msg->FirstFrame(); return msg; } -void* Router::InternalID() const { - return router_; -} -#endif - } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/driver.cc ---------------------------------------------------------------------- diff --git a/src/driver.cc b/src/driver.cc index 6163865..2952c62 100644 --- a/src/driver.cc +++ b/src/driver.cc @@ -25,6 +25,7 @@ #include <set> #include <string> #include <vector> +#include "singa/comm/socket.h" #include "singa/neuralnet/layer.h" #include "singa/utils/common.h" #include "singa/utils/tinydir.h" @@ -231,12 +232,18 @@ void Driver::Train(const JobProto& job_conf) { net->ToGraph(true).ToJson()); const vector<Worker*> workers = CreateWorkers(job_conf, net); const vector<Server*> servers = CreateServers(job_conf, net); - -#ifdef USE_MPI - int nthreads = workers.size() + servers.size() + 1; - for (int i = 0; i < nthreads; i++) - MPIQueues.push_back(make_shared<SafeQueue>()); -#endif + // Add msg queues for each socket + for (auto worker : workers) { + msgQueues[Addr(worker->grp_id(), worker->id(), kWorkerParam)]; + msgQueues[Addr(worker->grp_id(), worker->id(), kWorkerLayer)]; +// LOG(ERROR) << "worker addr " << Addr(worker->grp_id(), worker->id(), kWorkerParam); +// LOG(ERROR) << "worker addr " << Addr(worker->grp_id(), worker->id(), kWorkerLayer); + } + for (auto server : servers) { + msgQueues[Addr(server->grp_id(), server->id(), kServer)]; +// LOG(ERROR) << "server addr " << Addr(server->grp_id(), server->id(), kServer); + } + msgQueues[-1]; vector<std::thread> threads; for (auto server : servers) 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/server.cc ---------------------------------------------------------------------- diff --git a/src/server.cc b/src/server.cc index bd7b5f8..d5ef028 100644 --- a/src/server.cc +++ b/src/server.cc @@ -71,28 +71,15 @@ void Server::Run() { n_pending_sync_.resize(slice2group_.size(), 0); last_sync_.resize(slice2group_.size()); - // TODO(wangsh): give each dealer a unique id - auto dealer = new Dealer(0); - CHECK(dealer->Connect(kInprocRouterEndpoint)); - Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub)); - ping->set_type(kConnect); - dealer->Send(&ping); - bool running = true; CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running)); - Poller poll(dealer); + auto dealer = new Dealer(Addr(grp_id_, id_, kServer)); // start recv loop and process requests while (running) { - // must use poller here; otherwise Receive() gets stuck after workers stop. - auto* sock = poll.Wait(cluster->poll_time()); - if (poll.Terminated()) { - LOG(ERROR) << "Connection broken!"; - exit(0); - } else if (sock == nullptr) { + // cannot use blocking Receive() here, it will get stuck after workers stop. 
+ Msg* msg = dealer->Receive(cluster->poll_time()); + if (msg == nullptr) continue; - } - Msg* msg = dealer->Receive(); - if (msg == nullptr) break; // interrupted Msg* response = nullptr; int type = msg->type(); int slice_id = SliceID(msg->trgt_val()); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/stub.cc ---------------------------------------------------------------------- diff --git a/src/stub.cc b/src/stub.cc index c06128c..c7658fc 100644 --- a/src/stub.cc +++ b/src/stub.cc @@ -43,11 +43,13 @@ Stub::~Stub() { } void Stub::Setup() { router_ = new Router(); - router_->Bind(kInprocRouterEndpoint); auto cluster = Cluster::Get(); - const string hostip = cluster->hostip(); - int port = router_->Bind("tcp://" + hostip + ":*"); - endpoint_ = hostip + ":" + std::to_string(port); + if (cluster->nprocs() > 1) { + const string hostip = cluster->hostip(); + int port = router_->Bind("tcp://" + hostip + ":*"); + endpoint_ = hostip + ":" + std::to_string(port); + } else + endpoint_ = "localhost"; } /** * Get a hash id for a Param object from a group. @@ -116,6 +118,7 @@ void Stub::Run(const vector<int>& slice2server, msg = msg_queue.front(); msg_queue.pop(); } +// LOG(ERROR) << "stub recv msg " << msg; int type = msg->type(), dst = msg->dst(), flag = AddrType(dst); if (flag == kStub && (AddrProc(dst) == procs_id || AddrGrp(dst) == -1)) { // the following statements are ordered! 
@@ -174,6 +177,7 @@ void Stub::Run(const vector<int>& slice2server, inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs); inter_dealers[dst_procs]->Send(&msg); } else { +// LOG(ERROR) << "router send msg " << msg; router_->Send(&msg); } } @@ -186,7 +190,7 @@ void Stub::Run(const vector<int>& slice2server, Dealer* Stub::CreateInterProcsDealer(int dst_procs) { // forward to other procs auto cluster = Cluster::Get(); - auto dealer = new Dealer(); + auto dealer = new Dealer(-2); while (cluster->endpoint(dst_procs) == "") { // kCollectSleepTime)); std::this_thread::sleep_for(std::chrono::milliseconds(3000)); @@ -223,6 +227,7 @@ void Stub::GenMsgs(int type, int version, ParamEntry* entry, Msg* msg, new_msg->set_src(Addr(src_grp, procs_id, kStub)); new_msg->set_dst(Addr(dst_grp, server, kServer)); ret->push_back(new_msg); +// LOG(ERROR) << "stub gen msg " << new_msg; } } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/utils/cluster_rt.cc ---------------------------------------------------------------------- diff --git a/src/utils/cluster_rt.cc b/src/utils/cluster_rt.cc index 7a04ff7..9a7b8bd 100644 --- a/src/utils/cluster_rt.cc +++ b/src/utils/cluster_rt.cc @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at -* +* * http://www.apache.org/licenses/LICENSE-2.0 -* +* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/utils/param.cc ---------------------------------------------------------------------- diff --git a/src/utils/param.cc b/src/utils/param.cc index 158c777..e1c04c7 100644 --- a/src/utils/param.cc +++ b/src/utils/param.cc @@ -235,6 +235,7 @@ Msg* Param::GenPutMsg(bool copy, int idx) { if (copy) { msg->AddFrame(ptr, slice_size_[idx] * sizeof(float)); } +// LOG(ERROR) << "gen put msg: " << msg; return msg; } @@ -281,6 +282,7 @@ Msg* Param::HandlePutMsg(Msg** msg, bool reserve) { int size; float lr, wc; float* ptr; +// LOG(ERROR) << "handle put msg:" << *msg; (*msg)->ParseFormatFrame("iffp", &size, &lr, &wc, &ptr); ParamProto proto; proto.set_lr_scale(lr); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/d8dffdf0/src/worker.cc ---------------------------------------------------------------------- diff --git a/src/worker.cc b/src/worker.cc index 1e35ff9..6c461ce 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -123,24 +123,13 @@ void Worker::Test(int steps, Phase phase, NeuralNet* net) { Display(phase, " ", net); } -void ConnectStub(int grp, int id, Dealer* dealer, EntityType entity) { - dealer->Connect(kInprocRouterEndpoint); - Msg* ping = new Msg(Addr(grp, id, entity), Addr(-1, -1, kStub)); - ping->set_type(kConnect); - dealer->Send(&ping); -} - void Worker::InitSockets(const NeuralNet* net) { - // TODO(wangsh): provide a unique sock id from cluster - dealer_ = new Dealer(0); - ConnectStub(grp_id_, id_, dealer_, kWorkerParam); + dealer_ = new Dealer(Addr(grp_id_, id_, kWorkerParam)); for (auto layer : net->layers()) { if (layer->partition_id() == id_) { if (typeid(*layer) == typeid(BridgeDstLayer) || typeid(*layer) 
== typeid(BridgeSrcLayer)) { - // TODO(wangsh): provide a unique socket id from cluster - bridge_dealer_ = new Dealer(1); - ConnectStub(grp_id_, id_, bridge_dealer_, kWorkerLayer); + bridge_dealer_ = new Dealer(Addr(grp_id_, id_, kWorkerLayer)); break; } } @@ -253,6 +242,7 @@ int Worker::Put(int step, Param* param) { msg->set_trgt(ParamTrgt(param->owner(), 0), step); msg->set_type(kPut); dealer_->Send(&msg); +// LOG(ERROR) << "worker msg " << msg; return 1; }
