SINGA-233 New communication interface for SINGA. Reformat code following Google style. Update CMake files to ignore communication code if ENABLE_DIST is off.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/45620d59 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/45620d59 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/45620d59 Branch: refs/heads/dev Commit: 45620d59f00faa7d75e704259cb4692e3d3d9663 Parents: 4353ce9 Author: Wei Wang <[email protected]> Authored: Wed Aug 10 13:22:47 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Wed Aug 10 13:22:47 2016 +0800 ---------------------------------------------------------------------- CMakeLists.txt | 6 +- cmake/Templates/singa_config.h.in | 2 + include/singa/io/network.h | 195 ++--- src/io/network/endpoint.cc | 1248 ++++++++++++++++---------------- src/io/network/message.cc | 85 +-- test/CMakeLists.txt | 17 +- test/singa/test_convolution.cc | 4 + test/singa/test_ep.cc | 141 ++-- 8 files changed, 880 insertions(+), 818 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e529f0..b306e8b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ OPTION(USE_OPENCV "Use opencv" OFF) OPTION(USE_LMDB "Use LMDB libs" OFF) OPTION(USE_PYTHON "Generate py wrappers" OFF) OPTION(USE_OPENCL "Use OpenCL" OFF) +OPTION(ENABLE_DIST "enable distributed training" OFF) #OPTION(BUILD_OPENCL_TESTS "Build OpenCL tests" OFF) INCLUDE("cmake/Dependencies.cmake") @@ -47,7 +48,10 @@ IF (USE_CUDA) LIST(APPEND SINGA_LINKER_LIBS cnmem) ENDIF() -LIST(APPEND SINGA_LINKER_LIBS ev) +# TODO(wangwei) detect the ev lib +IF (ENABLE_DIST) + LIST(APPEND SINGA_LINKER_LIBS ev) +ENDIF() ADD_SUBDIRECTORY(src) ADD_SUBDIRECTORY(test) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/cmake/Templates/singa_config.h.in 
---------------------------------------------------------------------- diff --git a/cmake/Templates/singa_config.h.in b/cmake/Templates/singa_config.h.in index 0220d18..75eb062 100644 --- a/cmake/Templates/singa_config.h.in +++ b/cmake/Templates/singa_config.h.in @@ -17,6 +17,8 @@ #cmakedefine USE_OPENCL +#cmakedefine ENABLE_DIST + // lmdb #cmakedefine USE_LMDB http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/include/singa/io/network.h ---------------------------------------------------------------------- diff --git a/include/singa/io/network.h b/include/singa/io/network.h index 846c94b..63983ad 100644 --- a/include/singa/io/network.h +++ b/include/singa/io/network.h @@ -21,7 +21,8 @@ #ifndef SINGA_COMM_NETWORK_H_ #define SINGA_COMM_NETWORK_H_ - +#include "singa/singa_config.h" +#ifdef ENABLE_DIST #include <ev.h> #include <thread> #include <unordered_map> @@ -58,107 +59,113 @@ class NetworkThread; class EndPoint; class EndPointFactory; -class Message{ - private: - uint8_t type_; - uint32_t id_; - std::size_t msize_ = 0; - std::size_t psize_ = 0; - std::size_t processed_ = 0; - char* msg_ = nullptr; - static const int hsize_ = sizeof(id_) + 2 * sizeof(std::size_t) + sizeof(type_); - char mdata_[hsize_]; - friend class NetworkThread; - friend class EndPoint; - public: - Message(int = MSG_DATA, uint32_t = 0); - Message(const Message&) = delete; - Message(Message&&); - ~Message(); - - void setMetadata(const void*, int); - void setPayload(const void*, int); - - std::size_t getMetadata(void**); - std::size_t getPayload(void**); - - std::size_t getSize(); - void setId(uint32_t); +class Message { +private: + uint8_t type_; + uint32_t id_; + std::size_t msize_ = 0; + std::size_t psize_ = 0; + std::size_t processed_ = 0; + char *msg_ = nullptr; + static const int hsize_ = + sizeof(id_) + 2 * sizeof(std::size_t) + sizeof(type_); + char mdata_[hsize_]; + friend class NetworkThread; + friend class EndPoint; + +public: + Message(int = MSG_DATA, uint32_t = 0); 
+ Message(const Message &) = delete; + Message(Message &&); + ~Message(); + + void setMetadata(const void *, int); + void setPayload(const void *, int); + + std::size_t getMetadata(void **); + std::size_t getPayload(void **); + + std::size_t getSize(); + void setId(uint32_t); }; class EndPoint { - private: - std::queue<Message*> send_; - std::queue<Message*> recv_; - std::queue<Message*> to_ack_; - std::condition_variable cv_; - std::mutex mtx_; - struct sockaddr_in addr_; - ev_timer timer_; - ev_tstamp last_msg_time_; - int fd_[2] = {-1, -1}; // two endpoints simultaneously connect to each other - int pfd_ = -1; - bool is_socket_loop_ = false; - int conn_status_ = CONN_INIT; - int pending_cnt_ = 0; - int retry_cnt_ = 0; - NetworkThread* thread_ = nullptr; - EndPoint(NetworkThread* t); - ~EndPoint(); - friend class NetworkThread; - friend class EndPointFactory; - public: - int send(Message*); - Message* recv(); +private: + std::queue<Message *> send_; + std::queue<Message *> recv_; + std::queue<Message *> to_ack_; + std::condition_variable cv_; + std::mutex mtx_; + struct sockaddr_in addr_; + ev_timer timer_; + ev_tstamp last_msg_time_; + int fd_[2] = { -1, -1 }; // two endpoints simultaneously connect to each other + int pfd_ = -1; + bool is_socket_loop_ = false; + int conn_status_ = CONN_INIT; + int pending_cnt_ = 0; + int retry_cnt_ = 0; + NetworkThread *thread_ = nullptr; + EndPoint(NetworkThread *t); + ~EndPoint(); + friend class NetworkThread; + friend class EndPointFactory; + +public: + int send(Message *); + Message *recv(); }; class EndPointFactory { - private: - std::unordered_map<uint32_t, EndPoint*> ip_ep_map_; - std::condition_variable map_cv_; - std::mutex map_mtx_; - NetworkThread* thread_; - EndPoint* getEp(uint32_t ip); - EndPoint* getOrCreateEp(uint32_t ip); - friend class NetworkThread; - public: - EndPointFactory(NetworkThread* thread) : thread_(thread) {} - ~EndPointFactory(); - EndPoint* getEp(const char* host); - void 
getNewEps(std::vector<EndPoint*>& neps); +private: + std::unordered_map<uint32_t, EndPoint *> ip_ep_map_; + std::condition_variable map_cv_; + std::mutex map_mtx_; + NetworkThread *thread_; + EndPoint *getEp(uint32_t ip); + EndPoint *getOrCreateEp(uint32_t ip); + friend class NetworkThread; + +public: + EndPointFactory(NetworkThread *thread) : thread_(thread) {} + ~EndPointFactory(); + EndPoint *getEp(const char *host); + void getNewEps(std::vector<EndPoint *> &neps); }; -class NetworkThread{ - private: - struct ev_loop *loop_; - ev_async ep_sig_; - ev_async msg_sig_; - ev_io socket_watcher_; - int port_; - int socket_fd_; - std::thread* thread_; - std::unordered_map<int, ev_io> fd_wwatcher_map_; - std::unordered_map<int, ev_io> fd_rwatcher_map_; - std::unordered_map<int, EndPoint*> fd_ep_map_; - std::map<int, Message> pending_msgs_; - - void handleConnLost(int, EndPoint*, bool = true); - void doWork(); - int asyncSend(int); - void asyncSendPendingMsg(EndPoint*); - void afterConnEst(EndPoint* ep, int fd, bool active); - public: - EndPointFactory* epf_; - - NetworkThread(int); - void notify(int signal); - - void onRecv(int fd); - void onSend(int fd = -1); - void onConnEst(int fd); - void onNewEp(); - void onNewConn(); - void onTimeout(struct ev_timer* timer); +class NetworkThread { +private: + struct ev_loop *loop_; + ev_async ep_sig_; + ev_async msg_sig_; + ev_io socket_watcher_; + int port_; + int socket_fd_; + std::thread *thread_; + std::unordered_map<int, ev_io> fd_wwatcher_map_; + std::unordered_map<int, ev_io> fd_rwatcher_map_; + std::unordered_map<int, EndPoint *> fd_ep_map_; + std::map<int, Message> pending_msgs_; + + void handleConnLost(int, EndPoint *, bool = true); + void doWork(); + int asyncSend(int); + void asyncSendPendingMsg(EndPoint *); + void afterConnEst(EndPoint *ep, int fd, bool active); + +public: + EndPointFactory *epf_; + + NetworkThread(int); + void notify(int signal); + + void onRecv(int fd); + void onSend(int fd = -1); + void 
onConnEst(int fd); + void onNewEp(); + void onNewConn(); + void onTimeout(struct ev_timer *timer); }; } +#endif // ENABLE_DIST #endif http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/src/io/network/endpoint.cc ---------------------------------------------------------------------- diff --git a/src/io/network/endpoint.cc b/src/io/network/endpoint.cc index 96a2e4a..e61acdb 100644 --- a/src/io/network/endpoint.cc +++ b/src/io/network/endpoint.cc @@ -18,9 +18,11 @@ * under the License. * *************************************************************/ +#include "singa/singa_config.h" +#ifdef ENABLE_DIST #include "singa/io/network.h" -#include "singa/io/integer.h" +#include "singa/utils/integer.h" #include "singa/utils/logging.h" #include <sys/socket.h> @@ -34,374 +36,376 @@ namespace singa { -static void async_ep_cb(struct ev_loop* loop, ev_async* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onNewEp(); +static void async_ep_cb(struct ev_loop *loop, ev_async *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onNewEp(); } -static void async_msg_cb(struct ev_loop* loop, ev_async* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onSend(); +static void async_msg_cb(struct ev_loop *loop, ev_async *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onSend(); } -static void writable_cb(struct ev_loop* loop, ev_io* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onSend(ev->fd); +static void writable_cb(struct ev_loop *loop, ev_io *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onSend(ev->fd); } -static void readable_cb(struct ev_loop* loop, ev_io* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onRecv(ev->fd); +static void readable_cb(struct ev_loop *loop, ev_io *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onRecv(ev->fd); } -static void conn_cb(struct 
ev_loop* loop, ev_io* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onConnEst(ev->fd); +static void conn_cb(struct ev_loop *loop, ev_io *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onConnEst(ev->fd); } -static void accept_cb(struct ev_loop* loop, ev_io* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onNewConn(); +static void accept_cb(struct ev_loop *loop, ev_io *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onNewConn(); } -static void timeout_cb(struct ev_loop* loop, ev_timer* ev, int revent) { - reinterpret_cast<NetworkThread*>(ev_userdata(loop))->onTimeout(ev); +static void timeout_cb(struct ev_loop *loop, ev_timer *ev, int revent) { + reinterpret_cast<NetworkThread *>(ev_userdata(loop))->onTimeout(ev); } -EndPoint::EndPoint(NetworkThread* t) : thread_(t) { - this->timer_.data = reinterpret_cast<void*>(this); +EndPoint::EndPoint(NetworkThread *t) : thread_(t) { + this->timer_.data = reinterpret_cast<void *>(this); } EndPoint::~EndPoint() { - while(!recv_.empty()) { - delete send_.front(); - send_.pop(); - } - while(!to_ack_.empty()) { - delete send_.front(); - send_.pop(); - } - while(!send_.empty()) { - delete send_.front(); - send_.pop(); - } + while (!recv_.empty()) { + delete send_.front(); + send_.pop(); + } + while (!to_ack_.empty()) { + delete send_.front(); + send_.pop(); + } + while (!send_.empty()) { + delete send_.front(); + send_.pop(); + } } -int EndPoint::send(Message* msg) { - CHECK(msg->type_ == MSG_DATA); - static std::atomic<uint32_t> id(0); - std::unique_lock<std::mutex> lock(this->mtx_); +int EndPoint::send(Message *msg) { + CHECK(msg->type_ == MSG_DATA); + static std::atomic<uint32_t> id(0); + std::unique_lock<std::mutex> lock(this->mtx_); - if (this->conn_status_ == CONN_ERROR) { - LOG(INFO) << "EndPoint " << inet_ntoa(addr_.sin_addr) << " is disconnected"; - return -1; - } + if (this->conn_status_ == CONN_ERROR) { + LOG(INFO) << 
"EndPoint " << inet_ntoa(addr_.sin_addr) << " is disconnected"; + return -1; + } - if (msg->psize_ == 0 && msg->msize_ == 0) - // no data to send - return 0; + if (msg->psize_ == 0 && msg->msize_ == 0) + // no data to send + return 0; - msg->setId(id++); + msg->setId(id++); - send_.push(new Message(static_cast<Message&&>(*msg))); + send_.push(new Message(static_cast<Message &&>(*msg))); - thread_->notify(SIG_MSG); - return msg->getSize(); + thread_->notify(SIG_MSG); + return msg->getSize(); } -Message* EndPoint::recv() { - std::unique_lock<std::mutex> lock(this->mtx_); - while(this->recv_.empty() && conn_status_ != CONN_ERROR) - this->cv_.wait(lock); - - Message* ret = nullptr; - if (!recv_.empty()) { - ret = recv_.front(); - recv_.pop(); - } - return ret; +Message *EndPoint::recv() { + std::unique_lock<std::mutex> lock(this->mtx_); + while (this->recv_.empty() && conn_status_ != CONN_ERROR) + this->cv_.wait(lock); + + Message *ret = nullptr; + if (!recv_.empty()) { + ret = recv_.front(); + recv_.pop(); + } + return ret; } EndPointFactory::~EndPointFactory() { - for (auto& p : ip_ep_map_) - { - delete p.second; - } + for (auto &p : ip_ep_map_) { + delete p.second; + } } -EndPoint* EndPointFactory::getOrCreateEp(uint32_t ip) { - std::unique_lock<std::mutex> lock(map_mtx_); - if (0 == ip_ep_map_.count(ip)) { - ip_ep_map_[ip] = new EndPoint(this->thread_); - } - return ip_ep_map_[ip]; +EndPoint *EndPointFactory::getOrCreateEp(uint32_t ip) { + std::unique_lock<std::mutex> lock(map_mtx_); + if (0 == ip_ep_map_.count(ip)) { + ip_ep_map_[ip] = new EndPoint(this->thread_); + } + return ip_ep_map_[ip]; } -EndPoint* EndPointFactory::getEp(uint32_t ip) { - std::unique_lock<std::mutex> lock(map_mtx_); - if (0 == ip_ep_map_.count(ip)) { - return nullptr; - } - return ip_ep_map_[ip]; +EndPoint *EndPointFactory::getEp(uint32_t ip) { + std::unique_lock<std::mutex> lock(map_mtx_); + if (0 == ip_ep_map_.count(ip)) { + return nullptr; + } + return ip_ep_map_[ip]; } -EndPoint* 
EndPointFactory::getEp(const char* host) { - // get the ip address of host - struct hostent *he; - struct in_addr **list; - - if ((he = gethostbyname(host)) == nullptr) { - LOG(INFO) << "Unable to resolve host " << host; - return nullptr; - } - - list = (struct in_addr**) he->h_addr_list; - uint32_t ip = ntohl(list[0]->s_addr); - - EndPoint* ep = nullptr; - map_mtx_.lock(); - if (0 == ip_ep_map_.count(ip)) { - ep = new EndPoint(this->thread_); - ep->thread_ = this->thread_; - ip_ep_map_[ip] = ep; - - // copy the address info - bcopy(list[0], &ep->addr_.sin_addr, sizeof(struct in_addr)); - - thread_->notify(SIG_EP); - } - ep = ip_ep_map_[ip]; - map_mtx_.unlock(); +EndPoint *EndPointFactory::getEp(const char *host) { + // get the ip address of host + struct hostent *he; + struct in_addr **list; + + if ((he = gethostbyname(host)) == nullptr) { + LOG(INFO) << "Unable to resolve host " << host; + return nullptr; + } + + list = (struct in_addr **)he->h_addr_list; + uint32_t ip = ntohl(list[0]->s_addr); + + EndPoint *ep = nullptr; + map_mtx_.lock(); + if (0 == ip_ep_map_.count(ip)) { + ep = new EndPoint(this->thread_); + ep->thread_ = this->thread_; + ip_ep_map_[ip] = ep; + + // copy the address info + bcopy(list[0], &ep->addr_.sin_addr, sizeof(struct in_addr)); + + thread_->notify(SIG_EP); + } + ep = ip_ep_map_[ip]; + map_mtx_.unlock(); + + std::unique_lock<std::mutex> eplock(ep->mtx_); + while (ep->conn_status_ == CONN_PENDING || ep->conn_status_ == CONN_INIT) { + ep->pending_cnt_++; + ep->cv_.wait(eplock); + ep->pending_cnt_--; + } + + if (ep->conn_status_ == CONN_ERROR) { + ep = nullptr; + } + + return ep; +} +void EndPointFactory::getNewEps(std::vector<EndPoint *> &neps) { + std::unique_lock<std::mutex> lock(this->map_mtx_); + for (auto &p : this->ip_ep_map_) { + EndPoint *ep = p.second; std::unique_lock<std::mutex> eplock(ep->mtx_); - while (ep->conn_status_ == CONN_PENDING || ep->conn_status_ == CONN_INIT) { - ep->pending_cnt_++; - ep->cv_.wait(eplock); - 
ep->pending_cnt_--; + if (ep->conn_status_ == CONN_INIT) { + neps.push_back(ep); } - - if (ep->conn_status_ == CONN_ERROR) { - ep = nullptr; - } - - return ep; -} - -void EndPointFactory::getNewEps(std::vector<EndPoint*>& neps) { - std::unique_lock<std::mutex> lock(this->map_mtx_); - for (auto& p : this->ip_ep_map_) { - EndPoint* ep = p.second; - std::unique_lock<std::mutex> eplock(ep->mtx_); - if (ep->conn_status_ == CONN_INIT) { - neps.push_back(ep); - } - } + } } NetworkThread::NetworkThread(int port) { - this->port_ = port; - thread_ = new std::thread([this] {doWork();}); - this->epf_ = new EndPointFactory(this); + this->port_ = port; + thread_ = new std::thread([this] { doWork(); }); + this->epf_ = new EndPointFactory(this); } void NetworkThread::doWork() { - // prepare event loop - if (!(loop_ = ev_default_loop(0))) { - // log here - } + // prepare event loop + if (!(loop_ = ev_default_loop(0))) { + // log here + } - ev_async_init(&ep_sig_, async_ep_cb); - ev_async_start(loop_, &ep_sig_); + ev_async_init(&ep_sig_, async_ep_cb); + ev_async_start(loop_, &ep_sig_); - ev_async_init(&msg_sig_, async_msg_cb); - ev_async_start(loop_, &msg_sig_); + ev_async_init(&msg_sig_, async_msg_cb); + ev_async_start(loop_, &msg_sig_); - // bind and listen - struct sockaddr_in addr; - if ((socket_fd_ = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - LOG(FATAL) << "Socket Error: " << strerror(errno); - } + // bind and listen + struct sockaddr_in addr; + if ((socket_fd_ = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + LOG(FATAL) << "Socket Error: " << strerror(errno); + } - bzero(&addr, sizeof(addr)); - addr.sin_family = AF_INET; - addr.sin_port = htons(this->port_); - addr.sin_addr.s_addr = INADDR_ANY; + bzero(&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(this->port_); + addr.sin_addr.s_addr = INADDR_ANY; - if (bind(socket_fd_, (struct sockaddr*)&addr, sizeof(addr))) { - LOG(FATAL) << "Bind Error: " << strerror(errno); - } + if (bind(socket_fd_, (struct sockaddr 
*)&addr, sizeof(addr))) { + LOG(FATAL) << "Bind Error: " << strerror(errno); + } - if (listen(socket_fd_, 10)) { - LOG(FATAL) << "Listen Error: " << strerror(errno); - } + if (listen(socket_fd_, 10)) { + LOG(FATAL) << "Listen Error: " << strerror(errno); + } - ev_io_init(&socket_watcher_, accept_cb, socket_fd_, EV_READ); - ev_io_start(loop_, &socket_watcher_); + ev_io_init(&socket_watcher_, accept_cb, socket_fd_, EV_READ); + ev_io_start(loop_, &socket_watcher_); - ev_set_userdata(loop_, this); + ev_set_userdata(loop_, this); - while(1) - ev_run(loop_, 0); + while (1) + ev_run(loop_, 0); } void NetworkThread::notify(int signal) { - switch(signal) { - case SIG_EP: - ev_async_send(this->loop_, &this->ep_sig_); - break; - case SIG_MSG: - ev_async_send(this->loop_, &this->msg_sig_); - break; - default: - break; - } + switch (signal) { + case SIG_EP: + ev_async_send(this->loop_, &this->ep_sig_); + break; + case SIG_MSG: + ev_async_send(this->loop_, &this->msg_sig_); + break; + default: + break; + } } void NetworkThread::onNewEp() { - std::vector<EndPoint*> neps; - this->epf_->getNewEps(neps); - - for (auto& ep : neps) { - std::unique_lock<std::mutex> ep_lock(ep->mtx_); - int& fd = ep->fd_[0]; - if (ep->conn_status_ == CONN_INIT) { - - fd = socket(AF_INET, SOCK_STREAM, 0); - if (fd < 0) { - // resources not available - LOG(FATAL) << "Unable to create socket"; - } + std::vector<EndPoint *> neps; + this->epf_->getNewEps(neps); - // set this fd non-blocking - fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); - - this->fd_ep_map_[fd] = ep; - - // initialize the addess - ep->addr_.sin_family = AF_INET; - ep->addr_.sin_port = htons(port_); - bzero(&(ep->addr_.sin_zero), 8); - - LOG(INFO) << "Connecting to " << inet_ntoa(ep->addr_.sin_addr) << " fd = "<< fd; - if (connect(fd, (struct sockaddr*)&ep->addr_, - sizeof(struct sockaddr)) ) { - LOG(INFO) << "Connect Error: " << strerror(errno); - if (errno != EINPROGRESS) { - ep->conn_status_ = CONN_ERROR; - ep->cv_.notify_all(); 
- continue; - } else { - ep->conn_status_ = CONN_PENDING; - ev_io_init(&this->fd_wwatcher_map_[fd], conn_cb, fd, EV_WRITE); - ev_io_start(this->loop_, &this->fd_wwatcher_map_[fd]); - } - } else { - afterConnEst(ep, fd, true); + for (auto &ep : neps) { + std::unique_lock<std::mutex> ep_lock(ep->mtx_); + int &fd = ep->fd_[0]; + if (ep->conn_status_ == CONN_INIT) { + + fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + // resources not available + LOG(FATAL) << "Unable to create socket"; + } + + // set this fd non-blocking + fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); + + this->fd_ep_map_[fd] = ep; + + // initialize the addess + ep->addr_.sin_family = AF_INET; + ep->addr_.sin_port = htons(port_); + bzero(&(ep->addr_.sin_zero), 8); + + LOG(INFO) << "Connecting to " << inet_ntoa(ep->addr_.sin_addr) + << " fd = " << fd; + if (connect(fd, (struct sockaddr *)&ep->addr_, sizeof(struct sockaddr))) { + LOG(INFO) << "Connect Error: " << strerror(errno); + if (errno != EINPROGRESS) { + ep->conn_status_ = CONN_ERROR; + ep->cv_.notify_all(); + continue; + } else { + ep->conn_status_ = CONN_PENDING; + ev_io_init(&this->fd_wwatcher_map_[fd], conn_cb, fd, EV_WRITE); + ev_io_start(this->loop_, &this->fd_wwatcher_map_[fd]); + } + } else { + afterConnEst(ep, fd, true); - // connection established immediately - // LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) << " fd = "<< fd; - // ep->conn_status_ = CONN_EST; + // connection established immediately + // LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) << " fd + // = "<< fd; + // ep->conn_status_ = CONN_EST; - // //ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]); - // ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE); + // //ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]); + // ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE); - // // poll for new msgs - // ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); - // ev_io_start(this->loop_, 
&this->fd_rwatcher_map_[fd]); + // // poll for new msgs + // ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); + // ev_io_start(this->loop_, &this->fd_rwatcher_map_[fd]); - // asyncSendPendingMsg(ep); - // ep->cv_.notify_all(); - } - } + // asyncSendPendingMsg(ep); + // ep->cv_.notify_all(); + } } + } } void NetworkThread::onConnEst(int fd) { - //EndPoint* ep = epf_->getEp(this->fd_ip_map_[fd]); - CHECK(fd_ep_map_.count(fd) > 0); - EndPoint* ep = fd_ep_map_.at(fd); + // EndPoint* ep = epf_->getEp(this->fd_ip_map_[fd]); + CHECK(fd_ep_map_.count(fd) > 0); + EndPoint *ep = fd_ep_map_.at(fd); - std::unique_lock<std::mutex> lock(ep->mtx_); + std::unique_lock<std::mutex> lock(ep->mtx_); - if (connect(fd, (struct sockaddr*)&ep->addr_, sizeof(struct sockaddr)) < 0 && errno != EISCONN) { - LOG(INFO) << "Unable to connect to " << inet_ntoa(ep->addr_.sin_addr) << ": "<< strerror(errno); - if (errno == EINPROGRESS) { - // continue to watch this socket - return; - } + if (connect(fd, (struct sockaddr *)&ep->addr_, sizeof(struct sockaddr)) < 0 && + errno != EISCONN) { + LOG(INFO) << "Unable to connect to " << inet_ntoa(ep->addr_.sin_addr) + << ": " << strerror(errno); + if (errno == EINPROGRESS) { + // continue to watch this socket + return; + } - handleConnLost(ep->fd_[0], ep); + handleConnLost(ep->fd_[0], ep); - if (ep->conn_status_ == CONN_EST && ep->conn_status_ == CONN_ERROR) - ep->cv_.notify_all(); + if (ep->conn_status_ == CONN_EST && ep->conn_status_ == CONN_ERROR) + ep->cv_.notify_all(); - } else { + } else { - afterConnEst(ep, fd, true); + afterConnEst(ep, fd, true); - //ep->conn_status_ = CONN_EST; - //// connect established; poll for new msgs - //ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]); - //ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE); + // ep->conn_status_ = CONN_EST; + //// connect established; poll for new msgs + // ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]); + // ev_io_init(&fd_wwatcher_map_[fd], 
writable_cb, fd, EV_WRITE); - //ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); - //ev_io_start(this->loop_, &this->fd_rwatcher_map_[fd]); - } + // ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); + // ev_io_start(this->loop_, &this->fd_rwatcher_map_[fd]); + } } void NetworkThread::onNewConn() { - // accept new tcp connection - struct sockaddr_in addr; - socklen_t len = sizeof(addr); - int fd = accept(socket_fd_, (struct sockaddr*)&addr, &len); - if (fd < 0) { - LOG(INFO) << "Accept Error: " << strerror(errno); - return; - } + // accept new tcp connection + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + int fd = accept(socket_fd_, (struct sockaddr *)&addr, &len); + if (fd < 0) { + LOG(INFO) << "Accept Error: " << strerror(errno); + return; + } - LOG(INFO) << "Accept a client from " << inet_ntoa(addr.sin_addr) << ", fd = " << fd; + LOG(INFO) << "Accept a client from " << inet_ntoa(addr.sin_addr) + << ", fd = " << fd; - // set this fd as non-blocking - fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); + // set this fd as non-blocking + fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK); - EndPoint* ep; - uint32_t a = ntohl(addr.sin_addr.s_addr); + EndPoint *ep; + uint32_t a = ntohl(addr.sin_addr.s_addr); - ep = epf_->getOrCreateEp(a); - std::unique_lock<std::mutex> lock(ep->mtx_); + ep = epf_->getOrCreateEp(a); + std::unique_lock<std::mutex> lock(ep->mtx_); - // Passive connection - afterConnEst(ep, fd, false); + // Passive connection + afterConnEst(ep, fd, false); - // record the remote address - bcopy(&addr, &ep->addr_, len); + // record the remote address + bcopy(&addr, &ep->addr_, len); } -void NetworkThread::onTimeout(struct ev_timer* timer) { - - EndPoint* ep = reinterpret_cast<EndPoint*>(timer->data); +void NetworkThread::onTimeout(struct ev_timer *timer) { - ev_tstamp timeout = EP_TIMEOUT + ep->last_msg_time_; - ev_tstamp now = ev_now(loop_); + EndPoint *ep = reinterpret_cast<EndPoint *>(timer->data); 
- std::unique_lock<std::mutex> lock(ep->mtx_); - if (now > timeout) { - if (!ep->to_ack_.empty() || !ep->send_.empty()) { - - LOG(INFO) << "EndPoint " << inet_ntoa(ep->addr_.sin_addr) << " timeouts"; - // we consider this ep has been disconnected - for (int i = 0; i < 2; ++i) - { - int fd = ep->fd_[i]; - if (fd >= 0) - handleConnLost(fd, ep); - } - return; - } + ev_tstamp timeout = EP_TIMEOUT + ep->last_msg_time_; + ev_tstamp now = ev_now(loop_); - timer->repeat = EP_TIMEOUT; + std::unique_lock<std::mutex> lock(ep->mtx_); + if (now > timeout) { + if (!ep->to_ack_.empty() || !ep->send_.empty()) { - } else { - timer->repeat = timeout - now; + LOG(INFO) << "EndPoint " << inet_ntoa(ep->addr_.sin_addr) << " timeouts"; + // we consider this ep has been disconnected + for (int i = 0; i < 2; ++i) { + int fd = ep->fd_[i]; + if (fd >= 0) + handleConnLost(fd, ep); + } + return; } - ev_timer_again(loop_, &ep->timer_); + timer->repeat = EP_TIMEOUT; + + } else { + timer->repeat = timeout - now; + } + + ev_timer_again(loop_, &ep->timer_); } /** @@ -411,322 +415,351 @@ void NetworkThread::onTimeout(struct ev_timer* timer) { * @param fd * @param active indicate whethen this socket is locally initiated or not */ -void NetworkThread::afterConnEst(EndPoint* ep, int fd, bool active) { - - if (active) - LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) << ", fd = "<< fd; - - int sfd; - - if (active) { - ep->fd_[0] = fd; - sfd = ep->fd_[1]; - } else { - if (ep->fd_[1] >= 0) { - // the previous connection is lost - handleConnLost(ep->fd_[1], ep, false); - } - ep->fd_[1] = fd; - sfd = ep->fd_[0]; - } - - if (sfd == fd) { - // this fd is a reuse of a previous socket fd - // so we first need to clean the resouce for that fd - // we duplicate this fd to let the resouce of the oldf fd can be freed - // also indicate there is no need to reconnect - fd = dup(fd); - handleConnLost(sfd, ep, false); - } - - // initialize io watchers and add the read watcher to the ev loop - 
ev_io_init(&fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); - ev_io_start(loop_, &fd_rwatcher_map_[fd]); - - // stop watching the writable watcher if necessary - if (active) - ev_io_stop(loop_, &fd_wwatcher_map_[fd]); - ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE); - - ep->last_msg_time_ = ev_now(loop_); - - // see whether there is already a established connection for this fd - if (ep->conn_status_ == CONN_EST && sfd >= 0) { - // check if fd and sfd are associate with the same socket - struct sockaddr_in addr; - socklen_t len; - if (getsockname(fd, (struct sockaddr*)&addr, &len)) { - LOG(INFO) << "Unable to get local socket address: " << strerror(errno); - } else { - // see whether the local address of fd is the same as the remote side - // of sfd, which has already been stored in ep->addr_ - if (addr.sin_addr.s_addr == ep->addr_.sin_addr.s_addr && addr.sin_port == ep->addr_.sin_port) { - LOG(INFO) << fd << " and " << sfd << " are associated with the same socket"; - ep->is_socket_loop_ = true; - } else { - // this socket is redundant, we close it maunally if the local ip - // is smaller than the peer ip - if ((addr.sin_addr.s_addr < ep->addr_.sin_addr.s_addr) - || (addr.sin_addr.s_addr == ep->addr_.sin_addr.s_addr && addr.sin_port < ep->addr_.sin_port)) - handleConnLost(fd, ep, false); - } - } +void NetworkThread::afterConnEst(EndPoint *ep, int fd, bool active) { + + if (active) + LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) + << ", fd = " << fd; + + int sfd; + + if (active) { + ep->fd_[0] = fd; + sfd = ep->fd_[1]; + } else { + if (ep->fd_[1] >= 0) { + // the previous connection is lost + handleConnLost(ep->fd_[1], ep, false); + } + ep->fd_[1] = fd; + sfd = ep->fd_[0]; + } + + if (sfd == fd) { + // this fd is a reuse of a previous socket fd + // so we first need to clean the resouce for that fd + // we duplicate this fd to let the resouce of the oldf fd can be freed + // also indicate there is no need to reconnect + fd = dup(fd); + 
handleConnLost(sfd, ep, false); + } + + // initialize io watchers and add the read watcher to the ev loop + ev_io_init(&fd_rwatcher_map_[fd], readable_cb, fd, EV_READ); + ev_io_start(loop_, &fd_rwatcher_map_[fd]); + + // stop watching the writable watcher if necessary + if (active) + ev_io_stop(loop_, &fd_wwatcher_map_[fd]); + ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE); + + ep->last_msg_time_ = ev_now(loop_); + + // see whether there is already a established connection for this fd + if (ep->conn_status_ == CONN_EST && sfd >= 0) { + // check if fd and sfd are associate with the same socket + struct sockaddr_in addr; + socklen_t len; + if (getsockname(fd, (struct sockaddr *)&addr, &len)) { + LOG(INFO) << "Unable to get local socket address: " << strerror(errno); } else { - ep->pfd_ = fd; // set the primary fd - ep->conn_status_ = CONN_EST; - - // start timeout watcher to detect the liveness of EndPoint - ev_init(&ep->timer_, timeout_cb); - ep->timer_.repeat = EP_TIMEOUT; - ev_timer_start(loop_, &ep->timer_); - //timeout_cb(loop_, &ep->timer_, EV_TIMER); - } - - if (fd == ep->pfd_) { - this->asyncSendPendingMsg(ep); - } - - fd_ep_map_[fd] = ep; - - // Finally notify all waiting threads - // if this connection is initiaed by remote side, - // we dont need to notify the waiting thread - // later threads wanting to send to this ep, however, - // are able to reuse this ep - if (active) { - ep->cv_.notify_all(); - } + // see whether the local address of fd is the same as the remote side + // of sfd, which has already been stored in ep->addr_ + if (addr.sin_addr.s_addr == ep->addr_.sin_addr.s_addr && + addr.sin_port == ep->addr_.sin_port) { + LOG(INFO) << fd << " and " << sfd + << " are associated with the same socket"; + ep->is_socket_loop_ = true; + } else { + // this socket is redundant, we close it maunally if the local ip + // is smaller than the peer ip + if ((addr.sin_addr.s_addr < ep->addr_.sin_addr.s_addr) || + (addr.sin_addr.s_addr == 
ep->addr_.sin_addr.s_addr && + addr.sin_port < ep->addr_.sin_port)) + handleConnLost(fd, ep, false); + } + } + } else { + ep->pfd_ = fd; // set the primary fd + ep->conn_status_ = CONN_EST; + + // start timeout watcher to detect the liveness of EndPoint + ev_init(&ep->timer_, timeout_cb); + ep->timer_.repeat = EP_TIMEOUT; + ev_timer_start(loop_, &ep->timer_); + // timeout_cb(loop_, &ep->timer_, EV_TIMER); + } + + if (fd == ep->pfd_) { + this->asyncSendPendingMsg(ep); + } + + fd_ep_map_[fd] = ep; + + // Finally notify all waiting threads + // if this connection is initiaed by remote side, + // we dont need to notify the waiting thread + // later threads wanting to send to this ep, however, + // are able to reuse this ep + if (active) { + ep->cv_.notify_all(); + } } void NetworkThread::onSend(int fd) { - std::vector<int> invalid_fd; - - if (fd == -1) { - //LOG(INFO) << "There are " << fd_ip_map_.size() << " connections"; - // this is a signal of new message to send - for(auto& p : fd_ep_map_) { - // send message - //LOG(INFO) << "Try to send over fd " << p.first; - if (asyncSend(p.first) < 0) - invalid_fd.push_back(p.first); - } - } else { - if (asyncSend(fd) < 0) - invalid_fd.push_back(fd); - } - - for (auto& p : invalid_fd) { - //EndPoint* ep = epf_->getEp(fd_ip_map_.at(p)); - EndPoint* ep = fd_ep_map_.at(p); - std::unique_lock<std::mutex> lock(ep->mtx_); - handleConnLost(p, ep); - } + std::vector<int> invalid_fd; + + if (fd == -1) { + // LOG(INFO) << "There are " << fd_ip_map_.size() << " connections"; + // this is a signal of new message to send + for (auto &p : fd_ep_map_) { + // send message + // LOG(INFO) << "Try to send over fd " << p.first; + if (asyncSend(p.first) < 0) + invalid_fd.push_back(p.first); + } + } else { + if (asyncSend(fd) < 0) + invalid_fd.push_back(fd); + } + + for (auto &p : invalid_fd) { + // EndPoint* ep = epf_->getEp(fd_ip_map_.at(p)); + EndPoint *ep = fd_ep_map_.at(p); + std::unique_lock<std::mutex> lock(ep->mtx_); + handleConnLost(p, 
ep); + } } -void NetworkThread::asyncSendPendingMsg(EndPoint* ep) { - // simply put the pending msgs to the send queue +void NetworkThread::asyncSendPendingMsg(EndPoint *ep) { + // simply put the pending msgs to the send queue - LOG(INFO) << "There are " << ep->send_.size() << " to-send msgs, and " << ep->to_ack_.size() << " to-ack msgs"; + LOG(INFO) << "There are " << ep->send_.size() << " to-send msgs, and " + << ep->to_ack_.size() << " to-ack msgs"; - if (!ep->to_ack_.empty()) { - while (!ep->send_.empty()) { - ep->to_ack_.push(ep->send_.front()); - ep->send_.pop(); - } - std::swap(ep->send_, ep->to_ack_); + if (!ep->to_ack_.empty()) { + while (!ep->send_.empty()) { + ep->to_ack_.push(ep->send_.front()); + ep->send_.pop(); } + std::swap(ep->send_, ep->to_ack_); + } - if (ep->send_.size() > 0) { - notify(SIG_MSG); - } + if (ep->send_.size() > 0) { + notify(SIG_MSG); + } } /** - * @brief non-locking send; + * @brief non-locking send; * * @param ep * */ int NetworkThread::asyncSend(int fd) { - //EndPoint* ep = epf_->getEp(fd_ip_map_[fd]); - CHECK(fd_ep_map_.count(fd) > 0); - EndPoint* ep = fd_ep_map_.at(fd); - - std::unique_lock<std::mutex> ep_lock(ep->mtx_); - - if (fd != ep->pfd_ ) - // we only send over the primary fd - // return -1 to indicate this fd is redundant - return ep->is_socket_loop_ ? 0 : -1; + // EndPoint* ep = epf_->getEp(fd_ip_map_[fd]); + CHECK(fd_ep_map_.count(fd) > 0); + EndPoint *ep = fd_ep_map_.at(fd); - if (ep->conn_status_ != CONN_EST) - // This happens during reconnection - goto out; + std::unique_lock<std::mutex> ep_lock(ep->mtx_); - while(!ep->send_.empty()) { + if (fd != ep->pfd_) + // we only send over the primary fd + // return -1 to indicate this fd is redundant + return ep->is_socket_loop_ ? 
0 : -1; - Message& msg = *ep->send_.front(); - int nbytes; + if (ep->conn_status_ != CONN_EST) + // This happens during reconnection + goto out; - while(msg.processed_ < msg.getSize()) { - if (msg.type_ == MSG_ACK) { - nbytes = write(fd, msg.mdata_ + msg.processed_, msg.getSize() - msg.processed_); - } - else - nbytes = write(fd, msg.msg_ + msg.processed_, msg.getSize() - msg.processed_); - - if (nbytes == -1) { - if (errno == EWOULDBLOCK) { - if (!ev_is_active(&fd_wwatcher_map_[fd]) && !ev_is_pending(&fd_wwatcher_map_[fd])) - ev_io_start(loop_, &fd_wwatcher_map_[fd]); - goto out; - } else { - // this connection is lost; reset the send status - // so that next time the whole msg would be sent entirely - msg.processed_ = 0; - goto err; - } - } else { - ep->last_msg_time_ = ev_now(loop_); - msg.processed_ += nbytes; - } + while (!ep->send_.empty()) { - //std::size_t m, p; - //uint8_t type; - //uint32_t id; - //if (msg.msg_) { - // readInteger(msg.msg_, type, id, m, p); - // LOG(INFO) << "Send " << msg.processed_ << " bytes to " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd << " for the current DATA MSG " << msg.id_ << ", " << id << ", " << m << ", " << p; - //} - } + Message &msg = *ep->send_.front(); + int nbytes; - CHECK(msg.processed_ == msg.getSize()); + while (msg.processed_ < msg.getSize()) { + if (msg.type_ == MSG_ACK) { + nbytes = write(fd, msg.mdata_ + msg.processed_, + msg.getSize() - msg.processed_); + } else + nbytes = write(fd, msg.msg_ + msg.processed_, + msg.getSize() - msg.processed_); - if (msg.type_ != MSG_ACK) { - LOG(INFO) << "Send a DATA message to " << inet_ntoa(ep->addr_.sin_addr) << " for MSG " << msg.id_ << ", len = " << msg.getSize() << " over fd " << fd; - msg.processed_ = 0; - ep->to_ack_.push(&msg); + if (nbytes == -1) { + if (errno == EWOULDBLOCK) { + if (!ev_is_active(&fd_wwatcher_map_[fd]) && + !ev_is_pending(&fd_wwatcher_map_[fd])) + ev_io_start(loop_, &fd_wwatcher_map_[fd]); + goto out; } else { - //LOG(INFO) << "Send an ACK 
message to " << inet_ntoa(ep->addr_.sin_addr) << " for MSG " << msg.id_; - delete &msg; + // this connection is lost; reset the send status + // so that next time the whole msg would be sent entirely + msg.processed_ = 0; + goto err; } + } else { + ep->last_msg_time_ = ev_now(loop_); + msg.processed_ += nbytes; + } + + // std::size_t m, p; + // uint8_t type; + // uint32_t id; + // if (msg.msg_) { + // readInteger(msg.msg_, type, id, m, p); + // LOG(INFO) << "Send " << msg.processed_ << " bytes to " << + // inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd << " for the current + // DATA MSG " << msg.id_ << ", " << id << ", " << m << ", " << p; + //} + } + + CHECK(msg.processed_ == msg.getSize()); + + if (msg.type_ != MSG_ACK) { + LOG(INFO) << "Send a DATA message to " << inet_ntoa(ep->addr_.sin_addr) + << " for MSG " << msg.id_ << ", len = " << msg.getSize() + << " over fd " << fd; + msg.processed_ = 0; + ep->to_ack_.push(&msg); + } else { + // LOG(INFO) << "Send an ACK message to " << inet_ntoa(ep->addr_.sin_addr) + // << " for MSG " << msg.id_; + delete &msg; + } - ep->send_.pop(); + ep->send_.pop(); - //for test - // if (ep->retry_cnt_ == 0) { - // LOG(INFO) << "Disconnect with Endpoint " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; - // close(fd); - // goto err; - // } - } + // for test + // if (ep->retry_cnt_ == 0) { + // LOG(INFO) << "Disconnect with Endpoint " << + // inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; + // close(fd); + // goto err; + // } + } out: - if (ep->send_.empty()) - ev_io_stop(loop_, &this->fd_wwatcher_map_[fd]); - return 0; + if (ep->send_.empty()) + ev_io_stop(loop_, &this->fd_wwatcher_map_[fd]); + return 0; err: - return -1; + return -1; } void NetworkThread::onRecv(int fd) { - Message* m = &pending_msgs_[fd]; - Message& msg = (*m); - int nread; - //EndPoint* ep = epf_->getEp(fd_ip_map_[fd]); - - CHECK(fd_ep_map_.count(fd) > 0); - EndPoint* ep = fd_ep_map_.at(fd); - - //LOG(INFO) << "Start to read from EndPoint " << 
inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; - - std::unique_lock<std::mutex> lock(ep->mtx_); - - ep->last_msg_time_ = ev_now(loop_); - while(1) { - if (msg.processed_ < Message::hsize_) { - nread = read(fd, msg.mdata_ + msg.processed_, - Message::hsize_ - msg.processed_); - - if (nread <= 0) { - if (errno != EWOULDBLOCK || nread == 0) { - // socket error or shuts down - if (nread < 0) - LOG(INFO) << "Fail to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": " << strerror(errno); - else - LOG(INFO) << "Fail to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": Connection reset by remote side"; - handleConnLost(fd, ep); - } - break; - } - - msg.processed_ += nread; - while (msg.processed_ >= sizeof(msg.type_) + sizeof(msg.id_)) { - readInteger(msg.mdata_, msg.type_, msg.id_); - if(msg.type_ == MSG_ACK) { - LOG(INFO) << "Receive an ACK message from " << inet_ntoa(ep->addr_.sin_addr) << " for MSG " << msg.id_; - while (!ep->to_ack_.empty()) { - Message* m = ep->to_ack_.front(); - if (m->id_ <= msg.id_) { - delete m; - ep->to_ack_.pop(); - } else { - break; - } - } - - // reset - msg.processed_ -= sizeof(msg.type_) + sizeof(msg.id_); - memmove(msg.mdata_, - msg.mdata_ + sizeof(msg.type_) + sizeof(msg.id_), - msg.processed_); - - } else break; - } - - if (msg.processed_ < Message::hsize_) { - continue; - } - - // got the whole metadata; - readInteger(msg.mdata_, msg.type_, msg.id_, msg.msize_, msg.psize_); - - LOG(INFO) << "Receive a message: id = " << msg.id_ << ", msize_ = " << msg.msize_ << ", psize_ = " << msg.psize_ << " from " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; + Message *m = &pending_msgs_[fd]; + Message &msg = (*m); + int nread; + // EndPoint* ep = epf_->getEp(fd_ip_map_[fd]); + + CHECK(fd_ep_map_.count(fd) > 0); + EndPoint *ep = fd_ep_map_.at(fd); + + // LOG(INFO) << "Start to read from EndPoint " << + // inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; + + std::unique_lock<std::mutex> lock(ep->mtx_); + + 
ep->last_msg_time_ = ev_now(loop_); + while (1) { + if (msg.processed_ < Message::hsize_) { + nread = read(fd, msg.mdata_ + msg.processed_, + Message::hsize_ - msg.processed_); + + if (nread <= 0) { + if (errno != EWOULDBLOCK || nread == 0) { + // socket error or shuts down + if (nread < 0) + LOG(INFO) << "Fail to receive from EndPoint " + << inet_ntoa(ep->addr_.sin_addr) << ": " + << strerror(errno); + else + LOG(INFO) << "Fail to receive from EndPoint " + << inet_ntoa(ep->addr_.sin_addr) + << ": Connection reset by remote side"; + handleConnLost(fd, ep); } - - // start reading the real data - if (msg.msg_ == nullptr) { - msg.msg_ = new char[msg.getSize()]; - memcpy(msg.msg_, msg.mdata_, Message::hsize_); - } - - nread = read(fd, msg.msg_ + msg.processed_, msg.getSize() - msg.processed_); - if (nread <= 0) { - if (errno != EWOULDBLOCK || nread == 0) { - // socket error or shuts down - if (nread < 0) - LOG(INFO) << "Fail to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": " << strerror(errno); - else - LOG(INFO) << "Fail to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": Connection reset by remote side"; - handleConnLost(fd, ep); + break; + } + + msg.processed_ += nread; + while (msg.processed_ >= sizeof(msg.type_) + sizeof(msg.id_)) { + readInteger(msg.mdata_, msg.type_, msg.id_); + if (msg.type_ == MSG_ACK) { + LOG(INFO) << "Receive an ACK message from " + << inet_ntoa(ep->addr_.sin_addr) << " for MSG " << msg.id_; + while (!ep->to_ack_.empty()) { + Message *m = ep->to_ack_.front(); + if (m->id_ <= msg.id_) { + delete m; + ep->to_ack_.pop(); + } else { + break; } - break; - } - - msg.processed_ += nread; - - //LOG(INFO) << "Receive a message: id = " << msg.id_ << ", msize_ = " << msg.msize_ << ", psize_ = " << msg.psize_ << ", processed_ = " << msg.processed_ << " from " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd; - - if (msg.processed_ == msg.getSize()) { - LOG(INFO) << "Receive a " << msg.processed_ << " bytes DATA 
message from " << inet_ntoa(ep->addr_.sin_addr) << " with id " << msg.id_; - ep->recv_.push(new Message(static_cast<Message&&>(msg))); - // notify of waiting thread - ep->cv_.notify_one(); - ep->send_.push(new Message(MSG_ACK, msg.id_)); - msg.processed_ = 0; - } - } + } + + // reset + msg.processed_ -= sizeof(msg.type_) + sizeof(msg.id_); + memmove(msg.mdata_, msg.mdata_ + sizeof(msg.type_) + sizeof(msg.id_), + msg.processed_); + + } else + break; + } + + if (msg.processed_ < Message::hsize_) { + continue; + } + + // got the whole metadata; + readInteger(msg.mdata_, msg.type_, msg.id_, msg.msize_, msg.psize_); + + LOG(INFO) << "Receive a message: id = " << msg.id_ + << ", msize_ = " << msg.msize_ << ", psize_ = " << msg.psize_ + << " from " << inet_ntoa(ep->addr_.sin_addr) << " over fd " + << fd; + } + + // start reading the real data + if (msg.msg_ == nullptr) { + msg.msg_ = new char[msg.getSize()]; + memcpy(msg.msg_, msg.mdata_, Message::hsize_); + } + + nread = read(fd, msg.msg_ + msg.processed_, msg.getSize() - msg.processed_); + if (nread <= 0) { + if (errno != EWOULDBLOCK || nread == 0) { + // socket error or shuts down + if (nread < 0) + LOG(INFO) << "Fail to receive from EndPoint " + << inet_ntoa(ep->addr_.sin_addr) << ": " << strerror(errno); + else + LOG(INFO) << "Fail to receive from EndPoint " + << inet_ntoa(ep->addr_.sin_addr) + << ": Connection reset by remote side"; + handleConnLost(fd, ep); + } + break; + } + + msg.processed_ += nread; + + // LOG(INFO) << "Receive a message: id = " << msg.id_ << ", msize_ = " << + // msg.msize_ << ", psize_ = " << msg.psize_ << ", processed_ = " << + // msg.processed_ << " from " << inet_ntoa(ep->addr_.sin_addr) << " over fd + // " << fd; + + if (msg.processed_ == msg.getSize()) { + LOG(INFO) << "Receive a " << msg.processed_ << " bytes DATA message from " + << inet_ntoa(ep->addr_.sin_addr) << " with id " << msg.id_; + ep->recv_.push(new Message(static_cast<Message &&>(msg))); + // notify of waiting thread + 
ep->cv_.notify_one(); + ep->send_.push(new Message(MSG_ACK, msg.id_)); + msg.processed_ = 0; + } + } } /** @@ -737,59 +770,62 @@ void NetworkThread::onRecv(int fd) { * @param ep * @param reconn */ -void NetworkThread::handleConnLost(int fd, EndPoint* ep, bool reconn) { - CHECK(fd >= 0); - LOG(INFO) << "Lost connection to EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ", fd = " << fd; - - this->pending_msgs_.erase(fd); - this->fd_ep_map_.erase(fd); - ev_io_stop(loop_, &this->fd_wwatcher_map_[fd]); - ev_io_stop(loop_, &this->fd_rwatcher_map_[fd]); - fd_wwatcher_map_.erase(fd); - fd_rwatcher_map_.erase(fd); - close(fd); - - if (fd == ep->pfd_) { - if (!ep->send_.empty()) - ep->send_.front()->processed_ = 0; - } - - int sfd = (fd == ep->fd_[0]) ? ep->fd_[1] : ep->fd_[0]; - if (fd == ep->fd_[0]) - ep->fd_[0] = -1; - else - ep->fd_[1] = -1; - - if (reconn) { - // see if the other fd is alive or not - if (sfd < 0) { - if (ep->conn_status_ == CONN_EST) - ev_timer_stop(loop_, &ep->timer_); - if (ep->retry_cnt_ < MAX_RETRY_CNT) { - // notify myself for retry - ep->retry_cnt_++; - ep->conn_status_ = CONN_INIT; - LOG(INFO) << "Reconnect to EndPoint " << inet_ntoa(ep->addr_.sin_addr); - this->notify(SIG_EP); - } else { - LOG(INFO) << "Maximum retry count achieved for EndPoint " << inet_ntoa(ep->addr_.sin_addr); - ep->conn_status_ = CONN_ERROR; - - // notify all threads that this ep is no longer connected - ep->cv_.notify_all(); - } - } else { - if (!ep->is_socket_loop_) { - // if there is another working fd, set this fd as primary and - // send data over this fd - ep->pfd_ = sfd; - ep->last_msg_time_ = ev_now(loop_); - asyncSendPendingMsg(ep); - } else { - handleConnLost(sfd, ep); - } - } - } +void NetworkThread::handleConnLost(int fd, EndPoint *ep, bool reconn) { + CHECK(fd >= 0); + LOG(INFO) << "Lost connection to EndPoint " << inet_ntoa(ep->addr_.sin_addr) + << ", fd = " << fd; + + this->pending_msgs_.erase(fd); + this->fd_ep_map_.erase(fd); + ev_io_stop(loop_, 
&this->fd_wwatcher_map_[fd]); + ev_io_stop(loop_, &this->fd_rwatcher_map_[fd]); + fd_wwatcher_map_.erase(fd); + fd_rwatcher_map_.erase(fd); + close(fd); + + if (fd == ep->pfd_) { + if (!ep->send_.empty()) + ep->send_.front()->processed_ = 0; + } + + int sfd = (fd == ep->fd_[0]) ? ep->fd_[1] : ep->fd_[0]; + if (fd == ep->fd_[0]) + ep->fd_[0] = -1; + else + ep->fd_[1] = -1; + + if (reconn) { + // see if the other fd is alive or not + if (sfd < 0) { + if (ep->conn_status_ == CONN_EST) + ev_timer_stop(loop_, &ep->timer_); + if (ep->retry_cnt_ < MAX_RETRY_CNT) { + // notify myself for retry + ep->retry_cnt_++; + ep->conn_status_ = CONN_INIT; + LOG(INFO) << "Reconnect to EndPoint " << inet_ntoa(ep->addr_.sin_addr); + this->notify(SIG_EP); + } else { + LOG(INFO) << "Maximum retry count achieved for EndPoint " + << inet_ntoa(ep->addr_.sin_addr); + ep->conn_status_ = CONN_ERROR; + + // notify all threads that this ep is no longer connected + ep->cv_.notify_all(); + } + } else { + if (!ep->is_socket_loop_) { + // if there is another working fd, set this fd as primary and + // send data over this fd + ep->pfd_ = sfd; + ep->last_msg_time_ = ev_now(loop_); + asyncSendPendingMsg(ep); + } else { + handleConnLost(sfd, ep); + } + } + } } - } + +#endif // ENABLE_DIST http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/src/io/network/message.cc ---------------------------------------------------------------------- diff --git a/src/io/network/message.cc b/src/io/network/message.cc index 5bf9b8e..32f29b7 100644 --- a/src/io/network/message.cc +++ b/src/io/network/message.cc @@ -18,6 +18,8 @@ * under the License. 
* *************************************************************/ +#include "singa/singa_config.h" +#ifdef ENABLE_DIST #include <cstdlib> #include <cstring> @@ -25,68 +27,69 @@ #include <atomic> #include "singa/io/network.h" -#include "singa/io/integer.h" +#include "singa/utils/integer.h" namespace singa { -Message::Message(Message&& msg) { - std::swap(msize_, msg.msize_); - std::swap(psize_, msg.psize_); - std::swap(msg_, msg.msg_); - std::swap(type_, msg.type_); - std::swap(id_, msg.id_); +Message::Message(Message &&msg) { + std::swap(msize_, msg.msize_); + std::swap(psize_, msg.psize_); + std::swap(msg_, msg.msg_); + std::swap(type_, msg.type_); + std::swap(id_, msg.id_); } -Message::Message(int type, uint32_t ack_msg_id): type_(type), id_(ack_msg_id) { - if (type_ == MSG_ACK) - appendInteger(mdata_, type_, id_); +Message::Message(int type, uint32_t ack_msg_id) : type_(type), id_(ack_msg_id) { + if (type_ == MSG_ACK) + appendInteger(mdata_, type_, id_); } Message::~Message() { - if (msg_) - free(msg_); + if (msg_) + free(msg_); } std::size_t Message::getSize() { - if (type_ == MSG_ACK) - return sizeof(type_) + sizeof(id_); - else - return this->hsize_ + this->psize_ + this->msize_; + if (type_ == MSG_ACK) + return sizeof(type_) + sizeof(id_); + else + return this->hsize_ + this->psize_ + this->msize_; } void Message::setId(uint32_t id) { - this->id_ = id; - appendInteger(msg_, type_, id_); + this->id_ = id; + appendInteger(msg_, type_, id_); } -void Message::setMetadata(const void* buf, int size) { - this->msize_ = size; - msg_ = (char*) malloc (this->getSize()); - appendInteger(msg_, type_, id_, msize_, psize_); - memcpy(msg_ + hsize_, buf, size); +void Message::setMetadata(const void *buf, int size) { + this->msize_ = size; + msg_ = (char *)malloc(this->getSize()); + appendInteger(msg_, type_, id_, msize_, psize_); + memcpy(msg_ + hsize_, buf, size); } -void Message::setPayload(const void* buf, int size) { - this->psize_ = size; - msg_ = (char*) realloc(msg_, 
this->getSize()); - appendInteger(msg_ + hsize_ - sizeof(psize_), psize_); - memcpy(msg_ + hsize_ + msize_, buf, size); +void Message::setPayload(const void *buf, int size) { + this->psize_ = size; + msg_ = (char *)realloc(msg_, this->getSize()); + appendInteger(msg_ + hsize_ - sizeof(psize_), psize_); + memcpy(msg_ + hsize_ + msize_, buf, size); } -std::size_t Message::getMetadata(void** p) { - if (this->msize_ == 0) - *p = nullptr; - else - *p = msg_ + hsize_; - return this->msize_; +std::size_t Message::getMetadata(void **p) { + if (this->msize_ == 0) + *p = nullptr; + else + *p = msg_ + hsize_; + return this->msize_; } -std::size_t Message::getPayload(void** p) { - if (this->psize_ == 0) - *p = nullptr; - else - *p = msg_ + hsize_ + msize_; - return this->psize_; +std::size_t Message::getPayload(void **p) { + if (this->psize_ == 0) + *p = nullptr; + else + *p = msg_ + hsize_ + msize_; + return this->psize_; } - } + +#endif // ENABLE_DIST http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/test/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fda871d..f196928 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,22 +1,27 @@ INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/include) + +IF(ENABLE_DIST) + ADD_EXECUTABLE(test_ep "singa/test_ep.cc") + ADD_DEPENDENCIES(test_ep singa_io) + TARGET_LINK_LIBRARIES(test_ep singa_utils singa_io protobuf ${SINGA_LINKER_LIBS}) +ENDIF() + ADD_LIBRARY(gtest STATIC EXCLUDE_FROM_ALL "gtest/gtest.h" "gtest/gtest-all.cc") AUX_SOURCE_DIRECTORY(singa singa_test_source) +LIST(REMOVE_ITEM singa_test_source "singa/test_ep.cc") IF(NOT USE_OPENCL) MESSAGE(STATUS "Skipping OpenCL tests") LIST(REMOVE_ITEM singa_test_source "singa/test_opencl.cc") ENDIF() + ADD_EXECUTABLE(test_singa "gtest/gtest_main.cc" ${singa_test_source}) ADD_DEPENDENCIES(test_singa singa_core singa_utils) 
-MESSAGE(STATUS "link libs" ${singa_linker_libs}) +#MESSAGE(STATUS "link libs" ${singa_linker_libs}) TARGET_LINK_LIBRARIES(test_singa gtest singa_core singa_utils singa_model singa_io proto protobuf ${SINGA_LINKER_LIBS}) -#SET_TARGET_PROPERTIES(test_singa PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread") +SET_TARGET_PROPERTIES(test_singa PROPERTIES LINK_FLAGS "${LINK_FLAGS} -pthread ") -#ADD_EXECUTABLE(test_ep "singa/test_ep.cc") -#ADD_DEPENDENCIES(test_ep singa_io) -#TARGET_LINK_LIBRARIES(test_ep singa_core singa_utils singa_model -# singa_io proto protobuf ${SINGA_LINKER_LIBS}) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/test/singa/test_convolution.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_convolution.cc b/test/singa/test_convolution.cc index b5f3605..c3ddcee 100644 --- a/test/singa/test_convolution.cc +++ b/test/singa/test_convolution.cc @@ -18,6 +18,9 @@ * under the License. * *************************************************************/ +#include "singa/singa_config.h" + +#ifdef USE_CBLAS #include "../src/model/layer/convolution.h" #include "gtest/gtest.h" @@ -202,3 +205,4 @@ TEST(Convolution, Backward) { dwptr[7]); EXPECT_FLOAT_EQ(dy[0] * x[4] + dy[4] * x[13], dwptr[8]); } +#endif // USE_CBLAS http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/45620d59/test/singa/test_ep.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_ep.cc b/test/singa/test_ep.cc index cc04064..0d862e5 100644 --- a/test/singa/test_ep.cc +++ b/test/singa/test_ep.cc @@ -18,95 +18,96 @@ * under the License. 
* *************************************************************/ - -#include "singa/io/integer.h" +#include "singa/singa_config.h" +#ifdef ENABLE_DIST #include "singa/io/network.h" +#include "singa/utils/integer.h" +#include "singa/utils/logging.h" #include <assert.h> #include <unistd.h> +#include <string.h> +#include <memory> -#include "singa/utils/logging.h" #define SIZE 10000000 #define PORT 10000 #define ITER 10 using namespace singa; -int main(int argc, char** argv) { - char* md = new char[SIZE]; - char* payload = new char[SIZE]; - - char* host = "localhost"; - int port = PORT; - - for (int i = 1; i < argc; ++i) - { - if (strcmp(argv[i], "-p") == 0) - port = atoi(argv[++i]); - else if (strcmp(argv[i], "-h") == 0) - host = argv[++i]; - else - fprintf(stderr, "Invalid option %s\n", argv[i]); - } - - memset(md, 'a', SIZE); - memset(payload, 'b', SIZE); - - NetworkThread* t = new NetworkThread(port); - - EndPointFactory* epf = t->epf_; - - // sleep - sleep(3); - - EndPoint* ep = epf->getEp(host); - - Message* m[ITER]; - for (int i = 0; i < ITER; ++i) - { - m[i] = new Message(); - m[i]->setMetadata(md, SIZE); - m[i]->setPayload(payload, SIZE); +int main(int argc, char **argv) { + char *md = new char[SIZE]; + char *payload = new char[SIZE]; + + const char *host = "localhost"; + int port = PORT; + + for (int i = 1; i < argc; ++i) { + if (strcmp(argv[i], "-p") == 0) + port = atoi(argv[++i]); + else if (strcmp(argv[i], "-h") == 0) + host = argv[++i]; + else + fprintf(stderr, "Invalid option %s\n", argv[i]); + } + + memset(md, 'a', SIZE); + memset(payload, 'b', SIZE); + + NetworkThread *t = new NetworkThread(port); + + EndPointFactory *epf = t->epf_; + + // sleep + sleep(3); + + EndPoint *ep = epf->getEp(host); + + Message *m[ITER]; + for (int i = 0; i < ITER; ++i) { + m[i] = new Message(); + m[i]->setMetadata(md, SIZE); + m[i]->setPayload(payload, SIZE); + } + + while (1) { + for (int i = 0; i < ITER; ++i) { + if (ep->send(m[i]) < 0) + return 1; + delete m[i]; } - 
while (1) { - for (int i = 0; i < ITER; ++i) - { - if (ep->send(m[i]) < 0) return 1; - delete m[i]; - } - - for (int i = 0; i < ITER; ++i) - { - m[i] = ep->recv(); - if (!m[i]) - return 1; - char *p; - CHECK(m[i]->getMetadata((void**)&p) == SIZE); - CHECK(0 == strncmp(p, md, SIZE)); - CHECK(m[i]->getPayload((void**)&p) == SIZE); - CHECK(0 == strncmp(p, payload, SIZE)); - } + for (int i = 0; i < ITER; ++i) { + m[i] = ep->recv(); + if (!m[i]) + return 1; + char *p; + CHECK(m[i]->getMetadata((void **)&p) == SIZE); + CHECK(0 == strncmp(p, md, SIZE)); + CHECK(m[i]->getPayload((void **)&p) == SIZE); + CHECK(0 == strncmp(p, payload, SIZE)); } + } - //while(ep && cnt++ <= 5 && ep->send(m) > 0 ) { + // while(ep && cnt++ <= 5 && ep->send(m) > 0 ) { - // LOG(INFO) << "Send a " << m->getSize() << " bytes message"; + // LOG(INFO) << "Send a " << m->getSize() << " bytes message"; - // Message* m1 = ep->recv(); + // Message* m1 = ep->recv(); - // if (!m1) - // break; + // if (!m1) + // break; - // char *p; + // char *p; - // LOG(INFO) << "Receive a " << m1->getSize() << " bytes message"; + // LOG(INFO) << "Receive a " << m1->getSize() << " bytes message"; - // CHECK(m1->getMetadata((void**)&p) == SIZE); - // CHECK(0 == strncmp(p, md, SIZE)); - // CHECK(m1->getPayload((void**)&p) == SIZE); - // CHECK(0 == strncmp(p, payload, SIZE)); + // CHECK(m1->getMetadata((void**)&p) == SIZE); + // CHECK(0 == strncmp(p, md, SIZE)); + // CHECK(m1->getPayload((void**)&p) == SIZE); + // CHECK(0 == strncmp(p, payload, SIZE)); - // delete m; - // m = m1; - //} + // delete m; + // m = m1; + //} } +#endif // ENABLE_DIST
