Repository: incubator-singa
Updated Branches:
  refs/heads/dev 541ad6898 -> 17bfb1967


SINGA-233 A new communication framework for SINGA

In this ticket, we implement a new communication framework for SINGA. We 
abstract each physical computing node as an endpoint, and add two interfaces, 
i.e., send and recv, to the endpoint so that users can directly call them to 
accomplish data transfer.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/889abf8a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/889abf8a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/889abf8a

Branch: refs/heads/dev
Commit: 889abf8aea1e5eeef82cc6de65be5f5e20865398
Parents: db5478e
Author: caiqc <[email protected]>
Authored: Thu Aug 4 15:51:47 2016 +0800
Committer: caiqc <[email protected]>
Committed: Wed Aug 10 09:42:20 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                      |  11 +-
 include/singa/io/network/endpoint.h | 130 +++++++
 include/singa/io/network/integer.h  |  73 ++++
 include/singa/io/network/message.h  |  75 ++++
 src/CMakeLists.txt                  |   1 +
 src/io/network/endpoint.cc          | 650 +++++++++++++++++++++++++++++++
 src/io/network/message.cc           | 115 ++++++
 test/singa/test_ep.cc               |  69 ++++
 8 files changed, 1120 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23f8ef6..4e529f0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,12 +18,12 @@ SET(SINGA_INCLUDE_DIR
     
"${CMAKE_SOURCE_DIR}/include;${CMAKE_SOURCE_DIR}/lib/cnmem/include;${PROJECT_BINARY_DIR}")
 INCLUDE_DIRECTORIES(${SINGA_INCLUDE_DIR})
 
-OPTION(USE_CBLAS "Use CBlas libs" ON)
-OPTION(USE_CUDA "Use Cuda libs" ON)
-OPTION(USE_CUDNN "Use Cudnn libs" ON)
+OPTION(USE_CBLAS "Use CBlas libs" OFF)
+OPTION(USE_CUDA "Use Cuda libs" OFF)
+OPTION(USE_CUDNN "Use Cudnn libs" OFF)
 OPTION(USE_OPENCV "Use opencv" OFF)
 OPTION(USE_LMDB "Use LMDB libs" OFF)
-OPTION(USE_PYTHON "Generate py wrappers" ON)
+OPTION(USE_PYTHON "Generate py wrappers" OFF)
 OPTION(USE_OPENCL "Use OpenCL" OFF)
 #OPTION(BUILD_OPENCL_TESTS "Build OpenCL tests" OFF)
 
@@ -46,6 +46,9 @@ IF (USE_CUDA)
     ADD_SUBDIRECTORY(lib/cnmem)
     LIST(APPEND SINGA_LINKER_LIBS cnmem)
 ENDIF()
+
+LIST(APPEND SINGA_LINKER_LIBS ev)
+
 ADD_SUBDIRECTORY(src)
 ADD_SUBDIRECTORY(test)
 ADD_SUBDIRECTORY(examples)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/include/singa/io/network/endpoint.h
----------------------------------------------------------------------
diff --git a/include/singa/io/network/endpoint.h 
b/include/singa/io/network/endpoint.h
new file mode 100644
index 0000000..1079fcc
--- /dev/null
+++ b/include/singa/io/network/endpoint.h
@@ -0,0 +1,130 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_COMM_END_POINT_H_
+#define SINGA_COMM_END_POINT_H_
+
+#include <ev.h>
+#include <thread>
+#include <unordered_map>
+#include <map>
+#include <vector>
+#include <condition_variable>
+#include <mutex>
+#include <atomic>
+#include <string>
+#include <netinet/in.h>
+
+#include "singa/io/network/message.h"
+
+namespace singa {
+
+#define LOCKED 1
+#define UNLOCKED 0
+
+#define SIG_EP 1
+#define SIG_MSG 2
+
+#define CONN_INIT 0
+#define CONN_PENDING 1
+#define CONN_EST 2
+#define CONN_ERROR 3
+
+#define MAX_RETRY_CNT 3
+
+class NetworkThread;
+class EndPointFactory;
+
+/**
+ * One remote peer. Application threads call send()/recv(); all other state
+ * is private and driven by NetworkThread / EndPointFactory (friends), which
+ * access the members below while holding mtx_.
+ */
+class EndPoint {
+    private:
+        std::queue<Message*> send_;    // queued for transmission (DATA and ACK)
+        std::queue<Message*> recv_;    // complete DATA messages awaiting recv()
+        std::queue<Message*> to_ack_;  // DATA messages written but not yet ACKed
+        std::condition_variable cv_;   // signalled on new message / status change
+        std::mutex mtx_;
+        struct sockaddr_in addr_;
+        // two endpoints may simultaneously connect to each other:
+        // fd_[0] is the socket opened by this side (NetworkThread::onNewEp),
+        // fd_[1] is the socket accepted from the peer (NetworkThread::onNewConn)
+        int fd_[2] = {-1, -1};
+        int conn_status_ = CONN_INIT;
+        int pending_cnt_ = 0;          // threads currently blocked in getEp(host)
+        int retry_cnt_ = 0;            // reconnect attempts, capped by MAX_RETRY_CNT
+        NetworkThread* thread_ = nullptr;
+        // explicit: a NetworkThread* must never silently convert to an EndPoint
+        explicit EndPoint(NetworkThread* t):thread_(t){}
+        ~EndPoint();
+        friend class NetworkThread;
+        friend class EndPointFactory;
+    public:
+        // Queues msg for transmission. Returns -1 when the endpoint is in
+        // CONN_ERROR, 0 when msg carries no data, otherwise the message size.
+        int send(Message*);
+        // Blocks until a message is available or the endpoint enters
+        // CONN_ERROR; returns nullptr when nothing could be received.
+        Message* recv();
+};
+
+// Owns every EndPoint, keyed by the peer's IPv4 address converted to host
+// byte order (see the ntohl() calls in endpoint.cc).
+class EndPointFactory {
+    private:
+        std::unordered_map<uint32_t, EndPoint*> ip_ep_map_;
+        // NOTE(review): map_cv_ is not referenced in the visible sources
+        std::condition_variable map_cv_;
+        std::mutex map_mtx_;  // guards ip_ep_map_
+        NetworkThread* thread_;
+        // lookup only: returns nullptr when no endpoint exists for ip
+        EndPoint* getEp(uint32_t ip);
+        // lookup that creates the endpoint when it does not exist yet
+        EndPoint* getOrCreateEp(uint32_t ip);
+        friend class NetworkThread;
+    public:
+        EndPointFactory(NetworkThread* thread) : thread_(thread) {}
+        // deletes every endpoint stored in ip_ep_map_
+        ~EndPointFactory();
+        // Resolves host, creates the endpoint on first use and blocks until
+        // its connection attempt has finished; nullptr on failure.
+        EndPoint* getEp(const char* host);
+        // appends every endpoint whose status is CONN_INIT to neps
+        void getNewEps(std::vector<EndPoint*>& neps);
+};
+
+// Background I/O thread: runs a libev loop that accepts connections,
+// connects to new endpoints and moves messages between sockets and the
+// per-endpoint queues.
+class NetworkThread{
+    private:
+        struct ev_loop *loop_;  // assigned inside doWork()
+        ev_async ep_sig_;       // woken by notify(SIG_EP)  -> onNewEp()
+        ev_async msg_sig_;      // woken by notify(SIG_MSG) -> onSend()
+
+        std::thread* thread_;   // runs doWork()
+
+        // per-socket libev watchers and the peer address (host byte order)
+        std::unordered_map<int, ev_io> fd_wwatcher_map_;
+        std::unordered_map<int, ev_io> fd_rwatcher_map_;
+        std::unordered_map<int, uint32_t> fd_ip_map_;
+
+        // partially received message per socket, filled in by onRecv()
+        std::map<int, Message> pending_msgs_;
+
+        ev_io socket_watcher_;  // watches the listening socket
+        int port_;              // port used for both listening and connecting
+        int socket_fd_;         // listening socket
+
+        // cleans up after a lost socket; the bool selects whether a
+        // reconnect / resend is attempted (defaults to true)
+        void handleConnLost(int, EndPoint*, bool = true);
+        void doWork();
+        int asyncSend(int);
+        void asyncSendPendingMsg(EndPoint*);
+    public:
+        EndPointFactory* epf_;
+
+        NetworkThread(int);
+        //void join();
+        // wakes the event loop; signal is SIG_EP or SIG_MSG, others are ignored
+        void notify(int signal);
+
+        // event handlers invoked from the libev callbacks in endpoint.cc
+        void onRecv(int fd);
+        void onSend(int fd = -1);  // -1: try every known socket
+        void onConnEst(int fd);
+        void onNewEp();
+        void onNewConn();
+};
+}
+#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/include/singa/io/network/integer.h
----------------------------------------------------------------------
diff --git a/include/singa/io/network/integer.h 
b/include/singa/io/network/integer.h
new file mode 100644
index 0000000..9c2799d
--- /dev/null
+++ b/include/singa/io/network/integer.h
@@ -0,0 +1,73 @@
+/************************************************************
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *************************************************************/
+
+#ifndef INTEGER_H_
+#define INTEGER_H_
+
+#include <cstdint>
+
+namespace singa{
+/// Returns true when the host already stores integers in network
+/// (big-endian) byte order.
+static inline bool isNetworkOrder() {
+    const int probe = 1;
+    return 1 != *reinterpret_cast<const uint8_t*>(&probe);
+}
+
+/// Returns a copy of v with its bytes in reverse order.
+template <typename T>
+static inline T byteSwap(const T& v) {
+    const int size = sizeof(v);
+    T ret;
+    uint8_t *dest = reinterpret_cast<uint8_t *>(&ret);
+    const uint8_t *src = reinterpret_cast<const uint8_t*>(&v);
+    for (int i = 0; i < size; ++i) {
+        dest[i] = src[size - i - 1];
+    }
+    return ret;
+}
+
+/// Converts v from host to network byte order.
+template <typename T>
+static inline T hton(const T& v)
+{
+    return isNetworkOrder() ? v : byteSwap(v);
+}
+
+/// Converts v from network to host byte order (the swap is symmetric).
+template <typename T>
+static inline T ntoh(const T& v)
+{
+    return hton(v);
+}
+
+// recursion terminators for the variadic helpers below
+static inline int appendInteger(char* buf) {return 0;}
+static inline int readInteger(char* buf) {return 0;}
+
+/// Serializes value, values... back to back into buf in network byte order
+/// and returns the number of bytes written. buf must be large enough.
+template<typename Type, typename... Types>
+static int appendInteger(char* buf, Type value, Types... values) {
+    // copy byte-wise: buf is usually not aligned for Type (fields are packed
+    // without padding), so storing through a Type* would be undefined
+    const Type wire = hton(value);
+    const char* src = reinterpret_cast<const char*>(&wire);
+    for (unsigned i = 0; i < sizeof(Type); ++i)
+        buf[i] = src[i];
+    return sizeof(Type) + appendInteger(buf + sizeof(Type), values...);
+}
+
+/// Deserializes value, values... from buf (network byte order) and returns
+/// the number of bytes consumed.
+template<typename Type, typename... Types>
+static int readInteger(char* buf, Type& value, Types&... values) {
+    // byte-wise copy for the same alignment reason as in appendInteger
+    Type wire;
+    char* dst = reinterpret_cast<char*>(&wire);
+    for (unsigned i = 0; i < sizeof(Type); ++i)
+        dst[i] = buf[i];
+    value = ntoh(wire);
+    return sizeof(Type) + readInteger(buf + sizeof(Type), values...);
+}
+
+}
+#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/include/singa/io/network/message.h
----------------------------------------------------------------------
diff --git a/include/singa/io/network/message.h 
b/include/singa/io/network/message.h
new file mode 100644
index 0000000..0f691c0
--- /dev/null
+++ b/include/singa/io/network/message.h
@@ -0,0 +1,75 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#ifndef SINGA_COMM_MESSAGE_H_
+#define SINGA_COMM_MESSAGE_H_
+
+#include <mutex>
+#include <queue>
+
+namespace singa {
+
+#define MSG_DATA 0
+#define MSG_ACK 1
+
+class NetworkThread;
+class EndPoint;
+/**
+ * A message exchanged between endpoints. A DATA message is one contiguous
+ * buffer msg_ = [header | metadata | payload]; the header is
+ * (type_, id_, msize_, psize_) serialized with appendInteger. An ACK message
+ * consists only of (type_, id_) and is kept in mdata_.
+ */
+class Message{
+    private:
+        // initialized here so that every constructor starts from a defined state
+        uint8_t type_ = MSG_DATA;
+        uint32_t id_ = 0;
+        std::size_t msize_ = 0;      // metadata size in bytes
+        std::size_t psize_ = 0;      // payload size in bytes
+        std::size_t processed_ = 0;  // bytes sent / received so far
+        char* msg_ = nullptr;        // released with free() in the destructor
+        static const int hsize_ = sizeof(id_) + 2 * sizeof(std::size_t) + sizeof(type_);
+        char mdata_[hsize_];         // header scratch area, also holds ACKs
+        friend class NetworkThread;
+        friend class EndPoint;
+    public:
+        Message(int = MSG_DATA, uint32_t = 0);
+        Message(const Message&) = delete;
+        Message(Message&&);
+        ~Message();
+
+        void setMetadata(const void*, int);
+        void setPayload(const void*, int);
+
+        std::size_t getMetadata(void**);
+        std::size_t getPayload(void**);
+
+        // total number of bytes on the wire (header included)
+        std::size_t getSize();
+        void setId(uint32_t);
+};
+
+// Queue of Message objects paired with a mutex.
+// NOTE(review): the member definitions are not part of the visible sources;
+// presumably each operation locks lock_ — confirm in message.cc.
+class MessageQueue 
+{
+    public:
+        void push(Message&);
+        Message& front();
+        void pop();
+        std::size_t size();
+    private:
+        std::mutex lock_;
+        std::queue<Message> mqueue_;
+};
+}
+#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/src/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 65a81fc..38c6ac0 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -68,6 +68,7 @@ TARGET_LINK_LIBRARIES(singa_model ${SINGA_LINKER_LIBS})
 LIST(APPEND SINGA_LINKER_LIBS singa_model)
 
 AUX_SOURCE_DIRECTORY(io io_source)
+AUX_SOURCE_DIRECTORY(io/network io_source)
 ADD_LIBRARY(singa_io SHARED ${io_source})
 TARGET_LINK_LIBRARIES(singa_io ${SINGA_LINKER_LIBS})
 LIST(APPEND SINGA_LINKER_LIBS singa_io)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/src/io/network/endpoint.cc
----------------------------------------------------------------------
diff --git a/src/io/network/endpoint.cc b/src/io/network/endpoint.cc
new file mode 100644
index 0000000..2926a05
--- /dev/null
+++ b/src/io/network/endpoint.cc
@@ -0,0 +1,650 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/io/network/endpoint.h"
+#include "singa/io/network/integer.h"
+#include "singa/utils/logging.h"
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <atomic>
+
+namespace singa {
+
+// libev trampolines: the loop's userdata is the owning NetworkThread (set in
+// NetworkThread::doWork), so each callback forwards to the matching handler.
+
+// SIG_EP wake-up: new endpoints are waiting to be connected
+static void async_ep_cb(struct ev_loop* loop, ev_async* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onNewEp();
+}
+
+// SIG_MSG wake-up: new messages are waiting to be sent
+static void async_msg_cb(struct ev_loop* loop, ev_async* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onSend();
+}
+
+// a socket with unsent data became writable
+static void writable_cb(struct ev_loop* loop, ev_io* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onSend(ev->fd);
+}
+
+// a connected socket has data to read
+static void readable_cb(struct ev_loop* loop, ev_io* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onRecv(ev->fd);
+}
+
+// a non-blocking connect() made progress
+static void conn_cb(struct ev_loop* loop, ev_io* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onConnEst(ev->fd);
+}
+
+// the listening socket has a pending connection
+static void accept_cb(struct ev_loop* loop, ev_io* ev, int revent) {
+    NetworkThread* self = static_cast<NetworkThread*>(ev_userdata(loop));
+    self->onNewConn();
+}
+
+/**
+ * Releases every message still owned by this endpoint. Each queue is
+ * drained from itself: testing one queue while popping another one pops
+ * from an empty queue and leaks the messages of the tested queue.
+ */
+EndPoint::~EndPoint() {
+    while(!recv_.empty()) {
+        delete recv_.front();
+        recv_.pop();
+    }
+    while(!to_ack_.empty()) {
+        delete to_ack_.front();
+        to_ack_.pop();
+    }
+    while(!send_.empty()) {
+        delete send_.front();
+        send_.pop();
+    }
+}
+
+/**
+ * Queues a DATA message for asynchronous transmission. The content of *msg
+ * is moved into an internally owned Message, so *msg is empty afterwards.
+ *
+ * @param msg message of type MSG_DATA
+ * @return -1 if the endpoint is disconnected, 0 if msg carries neither
+ *         metadata nor payload, otherwise the size of the queued message
+ *         (header included)
+ */
+int EndPoint::send(Message* msg) {
+    CHECK(msg->type_ == MSG_DATA);
+    static std::atomic<uint32_t> id(0);
+    std::unique_lock<std::mutex> lock(this->mtx_);
+
+    if (this->conn_status_ == CONN_ERROR) {
+        LOG(INFO) << "EndPoint " << inet_ntoa(addr_.sin_addr) << " is disconnected";
+        return -1;
+    }
+
+    if (msg->psize_ == 0 && msg->msize_ == 0)
+        // no data to send
+        return 0;
+
+    msg->setId(id++);
+
+    // take the size before the move: afterwards *msg no longer has its sizes
+    const int total = static_cast<int>(msg->getSize());
+
+    send_.push(new Message(static_cast<Message&&>(*msg)));
+
+    thread_->notify(SIG_MSG);
+    return total;
+}
+
+/**
+ * Blocks until a received message is available or the endpoint enters
+ * CONN_ERROR.
+ *
+ * @return the oldest received message, or nullptr when none is available
+ */
+Message* EndPoint::recv() {
+    std::unique_lock<std::mutex> guard(mtx_);
+    cv_.wait(guard, [this] {
+        return !recv_.empty() || conn_status_ == CONN_ERROR;
+    });
+
+    if (recv_.empty())
+        return nullptr;
+
+    Message* head = recv_.front();
+    recv_.pop();
+    return head;
+}
+
+// Destroys every endpoint stored in the address map.
+EndPointFactory::~EndPointFactory() {
+    for (auto it = ip_ep_map_.begin(); it != ip_ep_map_.end(); ++it)
+        delete it->second;
+}
+
+// Returns the endpoint registered for ip (host byte order), creating and
+// registering a fresh one when none exists yet.
+EndPoint* EndPointFactory::getOrCreateEp(uint32_t ip) {
+    std::unique_lock<std::mutex> guard(map_mtx_);
+    auto found = ip_ep_map_.find(ip);
+    if (found != ip_ep_map_.end())
+        return found->second;
+
+    EndPoint* created = new EndPoint(this->thread_);
+    ip_ep_map_[ip] = created;
+    return created;
+}
+
+// Returns the endpoint registered for ip (host byte order), or nullptr when
+// there is none.
+EndPoint* EndPointFactory::getEp(uint32_t ip) {
+    std::unique_lock<std::mutex> guard(map_mtx_);
+    auto found = ip_ep_map_.find(ip);
+    return found == ip_ep_map_.end() ? nullptr : found->second;
+}
+
+/**
+ * Returns the endpoint for host, creating it (and asking the network thread
+ * to connect) on first use. Blocks until the connection attempt has left
+ * the CONN_INIT / CONN_PENDING states.
+ *
+ * @param host host name or dotted address, resolved with gethostbyname
+ * @return the connected endpoint, or nullptr if host cannot be resolved or
+ *         the endpoint ended up in CONN_ERROR
+ */
+EndPoint* EndPointFactory::getEp(const char* host) {
+    // get the ip address of host
+    struct hostent *he;
+    struct in_addr **list;
+
+    if ((he = gethostbyname(host)) == nullptr) {
+        LOG(INFO) << "Unable to resolve host " << host;
+        return nullptr;
+    }
+
+    list = (struct in_addr**) he->h_addr_list;
+    if (list == nullptr || list[0] == nullptr) {
+        LOG(INFO) << "Unable to resolve host " << host;
+        return nullptr;
+    }
+    uint32_t ip = ntohl(list[0]->s_addr);
+
+    EndPoint* ep = nullptr;
+    {
+        // scoped guard: the map mutex is released even if `new` throws
+        std::lock_guard<std::mutex> map_lock(map_mtx_);
+        auto found = ip_ep_map_.find(ip);
+        if (found == ip_ep_map_.end()) {
+            ep = new EndPoint(this->thread_);
+            ip_ep_map_[ip] = ep;
+
+            // copy the address info
+            bcopy(list[0], &ep->addr_.sin_addr, sizeof(struct in_addr));
+
+            thread_->notify(SIG_EP);
+        } else {
+            ep = found->second;
+        }
+    }
+
+    std::unique_lock<std::mutex> eplock(ep->mtx_);
+    while (ep->conn_status_ == CONN_PENDING || ep->conn_status_ == CONN_INIT) {
+        ep->pending_cnt_++;
+        ep->cv_.wait(eplock);
+        ep->pending_cnt_--;
+    }
+
+    if (ep->conn_status_ == CONN_ERROR) {
+        ep = nullptr;
+    }
+
+    return ep;
+}
+
+// Appends to neps every registered endpoint that is still in CONN_INIT,
+// i.e. for which no connection attempt has been started yet.
+void EndPointFactory::getNewEps(std::vector<EndPoint*>& neps) {
+    std::unique_lock<std::mutex> map_guard(map_mtx_);
+    for (auto it = ip_ep_map_.begin(); it != ip_ep_map_.end(); ++it) {
+        EndPoint* candidate = it->second;
+        std::unique_lock<std::mutex> ep_guard(candidate->mtx_);
+        if (candidate->conn_status_ != CONN_INIT)
+            continue;
+        neps.push_back(candidate);
+    }
+}
+
+/**
+ * Creates the endpoint factory and starts the I/O thread.
+ *
+ * @param port TCP port used both for listening and for connecting to peers
+ */
+NetworkThread::NetworkThread(int port) {
+    this->port_ = port;
+    // the factory must exist before the worker starts: the event handlers
+    // running on that thread dereference epf_
+    this->epf_ = new EndPointFactory(this);
+    thread_ = new std::thread([this] {doWork();});
+    // NOTE(review): loop_ and the async watchers are set up inside doWork(),
+    // so a notify() issued before the worker gets there still races with it
+}
+
+/**
+ * Body of the I/O thread: sets up the libev loop and the async wake-up
+ * watchers, opens the listening socket on port_ and runs the loop forever.
+ * Any set-up failure is fatal.
+ */
+void NetworkThread::doWork() {
+
+    // prepare event loop
+    if (!(loop_ = ev_default_loop(0))) {
+        // nothing below can work without a loop
+        LOG(FATAL) << "Unable to initialize the libev default loop";
+    }
+
+    ev_async_init(&ep_sig_, async_ep_cb);
+    ev_async_start(loop_, &ep_sig_);
+
+    ev_async_init(&msg_sig_, async_msg_cb);
+    ev_async_start(loop_, &msg_sig_);
+
+    // bind and listen
+    struct sockaddr_in addr;
+    if ((socket_fd_ = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
+        LOG(FATAL) << "Socket Error: " << strerror(errno);
+    }
+
+    bzero(&addr, sizeof(addr));
+    addr.sin_family = AF_INET;
+    addr.sin_port = htons(this->port_);
+    addr.sin_addr.s_addr = INADDR_ANY;
+
+    if (bind(socket_fd_, (struct sockaddr*)&addr, sizeof(addr))) {
+        LOG(FATAL) << "Bind Error: " << strerror(errno);
+    }
+
+    if (listen(socket_fd_, 10)) {
+        LOG(FATAL) << "Listen Error: " << strerror(errno);
+    }
+
+    ev_io_init(&socket_watcher_, accept_cb, socket_fd_, EV_READ);
+    ev_io_start(loop_, &socket_watcher_);
+
+    // the callbacks recover `this` through the loop's userdata
+    ev_set_userdata(loop_, this);
+
+    while(1)
+        ev_run(loop_, 0);
+}
+
+// Wakes the event loop from another thread: SIG_EP triggers onNewEp(),
+// SIG_MSG triggers onSend(); any other value is ignored.
+void NetworkThread::notify(int signal) {
+    if (signal == SIG_EP) {
+        ev_async_send(loop_, &ep_sig_);
+    } else if (signal == SIG_MSG) {
+        ev_async_send(loop_, &msg_sig_);
+    }
+}
+
+/**
+ * Starts a non-blocking connect for every endpoint that is still in
+ * CONN_INIT. Depending on the outcome of connect() the endpoint becomes
+ * CONN_EST (connected immediately), CONN_PENDING (completion is reported
+ * through conn_cb) or CONN_ERROR (waiting threads are woken up).
+ */
+void NetworkThread::onNewEp() {
+    std::vector<EndPoint*> neps;
+    this->epf_->getNewEps(neps);
+
+    for (auto& ep : neps) {
+        std::unique_lock<std::mutex>  ep_lock(ep->mtx_);
+        // the status may have changed since getNewEps() released the lock
+        if (ep->conn_status_ != CONN_INIT)
+            continue;
+
+        int& fd = ep->fd_[0];
+        fd = socket(AF_INET, SOCK_STREAM, 0);
+        if (fd < 0) {
+            // resources not available
+            LOG(FATAL) << "Unable to create socket";
+        }
+
+        // set this fd non-blocking
+        fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
+
+        this->fd_ip_map_[fd] = ntohl(ep->addr_.sin_addr.s_addr);
+
+        // initialize the address
+        ep->addr_.sin_family = AF_INET;
+        ep->addr_.sin_port = htons(port_);
+        bzero(&(ep->addr_.sin_zero), 8);
+
+        if (connect(fd, (struct sockaddr*)&ep->addr_,
+                sizeof(struct sockaddr)) ) {
+            // save errno first: logging may overwrite it
+            const int connect_errno = errno;
+            LOG(INFO) << "Connect Error: " << strerror(connect_errno);
+            if (connect_errno != EINPROGRESS) {
+                // hard failure: release the socket instead of leaking it
+                this->fd_ip_map_.erase(fd);
+                close(fd);
+                fd = -1;
+                ep->conn_status_ = CONN_ERROR;
+                ep->cv_.notify_all();
+                continue;
+            }
+            ep->conn_status_ = CONN_PENDING;
+            ev_io_init(&this->fd_wwatcher_map_[fd], conn_cb, fd, EV_WRITE);
+            ev_io_start(this->loop_, &this->fd_wwatcher_map_[fd]);
+        } else {
+            // connection established immediately
+            LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) << " fd = "<< fd;
+            ep->conn_status_ = CONN_EST;
+            ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]);
+
+            // poll for new msgs
+            ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ);
+            ev_io_start(this->loop_, &this->fd_rwatcher_map_[fd]);
+
+            asyncSendPendingMsg(ep);
+            ep->cv_.notify_all();
+        }
+    }
+}
+
+/**
+ * Called when the socket of a pending non-blocking connect becomes
+ * writable. connect() is issued again to learn the outcome: success (or
+ * EISCONN) marks the endpoint CONN_EST and starts reading from it; any
+ * error other than EINPROGRESS is handed to handleConnLost().
+ *
+ * @param fd socket whose connect is in progress
+ */
+void NetworkThread::onConnEst(int fd) {
+
+    EndPoint* ep = epf_->getEp(this->fd_ip_map_[fd]);
+
+    std::unique_lock<std::mutex> lock(ep->mtx_);
+
+    if (connect(fd, (struct sockaddr*)&ep->addr_, sizeof(struct sockaddr)) < 0 && errno != EISCONN) {
+        // save errno first: logging may overwrite it
+        const int connect_errno = errno;
+        LOG(INFO) << "Unable to connect to " << inet_ntoa(ep->addr_.sin_addr) << ": "<< strerror(connect_errno);
+        if (connect_errno == EINPROGRESS) {
+            // continue to watch this socket
+            return;
+        }
+
+        handleConnLost(ep->fd_[0], ep);
+
+        switch(ep->conn_status_) {
+            case CONN_INIT:
+            case CONN_PENDING:
+                // a reconnect was scheduled; waiters stay blocked until it ends
+                return;
+            default:
+                break;
+        }
+
+    } else {
+        LOG(INFO) << "Connected to " << inet_ntoa(ep->addr_.sin_addr) << ", fd = "<< fd;
+        ep->conn_status_ = CONN_EST;
+        // connect established; poll for new msgs
+        ev_io_stop(this->loop_, &this->fd_wwatcher_map_[fd]);
+
+        ev_io_init(&this->fd_rwatcher_map_[fd], readable_cb, fd, EV_READ);
+        ev_io_start(this->loop_, &this->fd_rwatcher_map_[fd]);
+    }
+
+    if (ep->conn_status_ == CONN_EST && ep->to_ack_.size() > 0)
+        // if there are pending message, it means these msgs were sent over
+        // previous sockets that have been lost now
+        // we need to resend these msgs to the remote side
+        asyncSendPendingMsg(ep);
+
+    // Finally notify all waiting threads
+    ep->cv_.notify_all();
+}
+
+/**
+ * Accepts one incoming connection on the listening socket, attaches it to
+ * the endpoint of the remote address (created on demand) as fd_[1] and
+ * starts reading from it. Messages still waiting for an ACK are queued for
+ * retransmission.
+ */
+void NetworkThread::onNewConn() {
+    // accept new tcp connection
+    struct sockaddr_in addr;
+    socklen_t len = sizeof(addr);
+    int fd = accept(socket_fd_, (struct sockaddr*)&addr, &len);
+    if (fd < 0) {
+        LOG(INFO) << "Accept Error: " << strerror(errno);
+        return;
+    }
+
+    LOG(INFO) << "Accept a client from " << inet_ntoa(addr.sin_addr) << ", fd = " << fd;
+
+    // set this fd as non-blocking
+    fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
+
+    EndPoint* ep;
+    uint32_t a = ntohl(addr.sin_addr.s_addr);
+
+    ep = epf_->getOrCreateEp(a);
+    std::unique_lock<std::mutex> lock(ep->mtx_);
+
+    if (ep->fd_[1] >= 0) {
+        // the previous connection is lost
+        handleConnLost(ep->fd_[1], ep, false);
+    }
+
+    if (ep->fd_[0] == fd) {
+        // this fd is reused
+        // NOTE(review): handleConnLost() also calls close(fd), which here is
+        // the descriptor just returned by accept() — confirm this is intended
+        handleConnLost(fd, ep, false);
+    }
+
+    fd_ip_map_[fd] = a;
+    ev_io_init(&fd_rwatcher_map_[fd], readable_cb, fd, EV_READ);
+    ev_io_start(loop_, &fd_rwatcher_map_[fd]);
+
+    // record the remote address
+    bcopy(&addr, &ep->addr_, len);
+
+    ep->conn_status_ = CONN_EST;
+    ep->fd_[1] = fd;
+
+    if (ep->to_ack_.size() > 0)
+        // see if there are any messages waiting for ack
+        // if yes, resend them
+        asyncSendPendingMsg(ep);
+
+    // this connection is initiated by the remote side,
+    // so we don't need to notify the waiting thread
+    // later threads wanting to send to this ep, however,
+    // are able to reuse this ep
+}
+
+/**
+ * Flushes queued messages. With fd == -1 (a SIG_MSG wake-up) every known
+ * socket is tried, otherwise only fd. Sockets on which asyncSend() reports
+ * a failure are handed to handleConnLost() afterwards.
+ */
+void NetworkThread::onSend(int fd) {
+    std::vector<int> broken;
+
+    if (fd != -1) {
+        if (asyncSend(fd) < 0)
+            broken.push_back(fd);
+    } else {
+        // this is a signal of new message to send
+        for (auto it = fd_ip_map_.begin(); it != fd_ip_map_.end(); ++it) {
+            const int sock = it->first;
+            if (asyncSend(sock) < 0)
+                broken.push_back(sock);
+        }
+    }
+
+    // clean up only after the iteration: handleConnLost() edits fd_ip_map_
+    for (auto it = broken.begin(); it != broken.end(); ++it) {
+        EndPoint* owner = epf_->getEp(fd_ip_map_.at(*it));
+        std::unique_lock<std::mutex> guard(owner->mtx_);
+        handleConnLost(*it, owner);
+    }
+}
+
+/**
+ * Re-queues the unacknowledged messages of ep in front of the messages
+ * that have not been sent yet, then wakes the sender if there is anything
+ * to transmit. The caller is expected to hold ep's lock.
+ */
+void NetworkThread::asyncSendPendingMsg(EndPoint* ep) {
+    LOG(INFO) << "There are " << ep->send_.size() << " to-send msgs, and " << ep->to_ack_.size() << " to-ack msgs";
+
+    // append the unsent messages behind the unacknowledged ones ...
+    for (; !ep->send_.empty(); ep->send_.pop())
+        ep->to_ack_.push(ep->send_.front());
+
+    // ... and make the combined queue the new send queue
+    ep->send_.swap(ep->to_ack_);
+
+    if (!ep->send_.empty())
+        notify(SIG_MSG);
+}
+
+/**
+ * @brief non-blocking send: writes the queued messages of the endpoint that
+ * owns fd until the queue is empty or the socket would block. Sent DATA
+ * messages move to the to_ack_ queue, sent ACK messages are deleted. The
+ * endpoint lock is taken internally.
+ *
+ * @param fd socket to write to
+ *
+ * @return 0 on success (also when nothing could be written yet), -1 if
+ * write() failed with an error other than EWOULDBLOCK / EAGAIN
+ */
+int NetworkThread::asyncSend(int fd) {
+
+    EndPoint* ep = epf_->getEp(fd_ip_map_[fd]);
+
+    std::unique_lock<std::mutex> ep_lock(ep->mtx_);
+
+    if (ep->conn_status_ != CONN_EST)
+        goto out;
+
+    while(!ep->send_.empty()) {
+
+        Message& msg = *ep->send_.front();
+        int nbytes;
+
+        while(msg.processed_ < msg.getSize()) {
+            // an ACK consists only of the small header kept in mdata_
+            const char* src = (msg.type_ == MSG_ACK) ? msg.mdata_ : msg.msg_;
+            nbytes = write(fd, src + msg.processed_, msg.getSize() - msg.processed_);
+            // save errno first: logging may overwrite it
+            const int write_errno = errno;
+
+            LOG(INFO) << "Send " << nbytes << " bytes to " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd;
+
+            if (nbytes == -1) {
+                if (write_errno == EWOULDBLOCK || write_errno == EAGAIN) {
+                    // resume from writable_cb once the socket drains
+                    ev_io_init(&fd_wwatcher_map_[fd], writable_cb, fd, EV_WRITE);
+                    ev_io_start(loop_, &fd_wwatcher_map_[fd]);
+                    goto out;
+                } else {
+                    // this connection is lost; reset the send status
+                    // so that next time the whole msg would be sent entirely
+                    msg.processed_ = 0;
+                    goto err;
+                }
+            } else
+                msg.processed_ += nbytes;
+        }
+
+        CHECK(msg.processed_ == msg.getSize());
+
+        if (msg.type_ != MSG_ACK) {
+            // keep DATA messages until the peer acknowledges them
+            msg.processed_ = 0;
+            ep->to_ack_.push(&msg);
+        } else {
+            delete &msg;
+        }
+
+        ep->send_.pop();
+    }
+out:
+    if (ep->send_.empty())
+        ev_io_stop(loop_, &this->fd_wwatcher_map_[fd]);
+    return 0;
+err:
+    return -1;
+}
+
+/**
+ * Reads everything currently available on fd. The partially received
+ * message of each socket lives in pending_msgs_, so a message may be
+ * assembled across several calls. ACKs release the acknowledged messages
+ * from to_ack_; every complete DATA message is moved to the endpoint's
+ * recv_ queue and answered with an ACK. A read error or EOF is handed to
+ * handleConnLost().
+ *
+ * @param fd readable socket
+ */
+void NetworkThread::onRecv(int fd) {
+
+    Message* m = &pending_msgs_[fd];
+    Message& msg = (*m);
+    int nread;
+    EndPoint* ep = epf_->getEp(fd_ip_map_[fd]);
+
+    LOG(INFO) << "Start to read from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << " over fd " << fd;
+
+    std::unique_lock<std::mutex> lock(ep->mtx_);
+    while(1) {
+
+        if (msg.processed_ < Message::hsize_) {
+            // still collecting the header
+            nread = read(fd, msg.mdata_ + msg.processed_,
+                    Message::hsize_ - msg.processed_);
+
+            if (nread <= 0) {
+                if (errno != EWOULDBLOCK || nread == 0) {
+                    // socket error or shuts down
+                    if (nread < 0)
+                        LOG(INFO) << "Faile to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": " << strerror(errno);
+                    else
+                        LOG(INFO) << "Faile to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": Connection reset by remote side";
+                    // erases pending_msgs_[fd]; msg must not be used afterwards
+                    handleConnLost(fd, ep);
+                }
+                break;
+            }
+
+            msg.processed_ += nread;
+            // an ACK is shorter than a full header: consume all leading ACKs
+            while (msg.processed_ >= sizeof(msg.type_) + sizeof(msg.id_)) {
+                readInteger(msg.mdata_, msg.type_, msg.id_);
+                if(msg.type_ == MSG_ACK) {
+                    LOG(INFO) << "Receive an ACK message from " << inet_ntoa(ep->addr_.sin_addr) << " for MSG " << msg.id_;
+                    // the ACK covers every message with an id up to msg.id_
+                    while (!ep->to_ack_.empty()) {
+                        Message* acked = ep->to_ack_.front();
+                        if (acked->id_ <= msg.id_) {
+                            delete acked;
+                            ep->to_ack_.pop();
+                        } else {
+                            break;
+                        }
+                    }
+
+                    // reset
+                    msg.processed_ -= sizeof(msg.type_) + sizeof(msg.id_);
+                    memmove(msg.mdata_,
+                            msg.mdata_ + sizeof(msg.type_) + sizeof(msg.id_),
+                            msg.processed_);
+
+                } else break;
+            }
+
+            if (msg.processed_ < Message::hsize_) {
+                continue;
+            }
+
+            // got the whole metadata;
+            readInteger(msg.mdata_, msg.type_, msg.id_, msg.msize_, msg.psize_);
+            LOG(INFO) << "Receive a message: id = " << msg.id_ << ", msize_ = " << msg.msize_ << ", psize_ = " << msg.psize_;
+        }
+
+        // start reading the real data
+        if (msg.msg_ == nullptr) {
+            // must come from malloc: Message::~Message releases msg_ with free()
+            msg.msg_ = static_cast<char*>(malloc(msg.getSize()));
+            CHECK(msg.msg_ != nullptr);
+            memcpy(msg.msg_, msg.mdata_, Message::hsize_);
+        }
+
+        nread = read(fd, msg.msg_ + msg.processed_, msg.getSize() - msg.processed_);
+        if (nread <= 0) {
+            if (errno != EWOULDBLOCK || nread == 0) {
+                // socket error or shuts down
+                if (nread < 0)
+                    LOG(INFO) << "Faile to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": " << strerror(errno);
+                else
+                    LOG(INFO) << "Faile to receive from EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ": Connection reset by remote side";
+                handleConnLost(fd, ep);
+            }
+            break;
+        }
+
+        msg.processed_ += nread;
+        if (msg.processed_ == msg.getSize()) {
+            LOG(INFO) << "Receive a " << msg.processed_ << " bytes DATA message from " << inet_ntoa(ep->addr_.sin_addr) << " with id " << msg.id_;
+            ep->recv_.push(new Message(static_cast<Message&&>(msg)));
+            // notify of waiting thread
+            ep->cv_.notify_one();
+            ep->send_.push(new Message(MSG_ACK, msg.id_));
+            msg.processed_ = 0;
+        }
+    }
+}
+
+/**
+ * @brief clean up for the lost connection; the caller should acquire the lock
+ * for the respective endpoint
+ *
+ * Drops the per-socket state (pending message, address mapping, watchers)
+ * and closes fd. With reconn set, the slot of fd in ep->fd_ is cleared and
+ * either the endpoint's other socket is used to continue sending or, when
+ * there is none, a reconnect is scheduled (at most MAX_RETRY_CNT times)
+ * before the endpoint is marked CONN_ERROR.
+ *
+ * @param fd the lost socket, must be >= 0
+ * @param ep endpoint owning fd
+ * @param reconn whether to fail over / reconnect
+ */
+void NetworkThread::handleConnLost(int fd, EndPoint* ep, bool reconn) {
+    CHECK(fd >= 0);
+    LOG(INFO) << "Lost connection to EndPoint " << inet_ntoa(ep->addr_.sin_addr) << ", fd = " << fd;
+
+    this->pending_msgs_.erase(fd);
+    this->fd_ip_map_.erase(fd);
+    // stop the watchers before their storage is erased from the maps
+    ev_io_stop(loop_, &this->fd_wwatcher_map_[fd]);
+    ev_io_stop(loop_, &this->fd_rwatcher_map_[fd]);
+    fd_wwatcher_map_.erase(fd);
+    fd_rwatcher_map_.erase(fd);
+    close(fd);
+
+    if (fd == ep->fd_[0] || ep->fd_[0] < 0) {
+        // restart the message at the head of the queue from its first byte
+        if (!ep->send_.empty())
+            ep->send_.front()->processed_ = 0;
+    }
+
+    if (reconn) {
+
+        // sfd: the endpoint's other socket (or -1 when there is none)
+        int sfd = (fd == ep->fd_[0]) ? ep->fd_[1] : ep->fd_[0];
+
+        if (fd == ep->fd_[0])
+            ep->fd_[0] = -1;
+        else
+            ep->fd_[1] = -1;
+
+        // see if the other fd is ok or not
+        if (sfd < 0) {
+            if (ep->retry_cnt_ < MAX_RETRY_CNT) {
+                // notify myself for retry
+                ep->retry_cnt_++;
+                ep->conn_status_ = CONN_INIT;
+                LOG(INFO) << "Reconnect to EndPoint " << inet_ntoa(ep->addr_.sin_addr);
+                this->notify(SIG_EP);
+            } else {
+                LOG(INFO) << "Maximum retry count achieved for EndPoint " << inet_ntoa(ep->addr_.sin_addr);
+                ep->conn_status_ = CONN_ERROR;
+                // notify all threads that this ep is no longer connected
+                ep->cv_.notify_all();
+            }
+        } else {
+            // if there is another working fd, try to send data over this fd
+            if (!ep->send_.empty())
+                this->notify(SIG_MSG);
+        }
+    }
+}
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/src/io/network/message.cc
----------------------------------------------------------------------
diff --git a/src/io/network/message.cc b/src/io/network/message.cc
new file mode 100644
index 0000000..ac3fc14
--- /dev/null
+++ b/src/io/network/message.cc
@@ -0,0 +1,115 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include <cstdlib>
+#include <cstring>
+
+#include <atomic>
+
+#include "singa/io/network/message.h"
+#include "singa/io/network/integer.h"
+
+namespace singa {
+
+/**
+ * Move constructor: takes over the buffer and header fields of msg.
+ *
+ * The message id is transferred together with the type and the two section
+ * sizes, so a message keeps its identity when it is moved (for example into
+ * a MessageQueue). The source is left empty (no buffer, zero sizes) so that
+ * its destructor does not release the buffer now owned by this object.
+ *
+ * NOTE(review): the inline header buffer mdata_ used by ACK messages is not
+ * copied here, as in the previous implementation -- confirm whether ACK
+ * messages are ever moved.
+ */
+Message::Message(Message&& msg) {
+    // Plain assignments instead of swaps: nothing is read from this object
+    // before it has been given a value.
+    type_ = msg.type_;
+    id_ = msg.id_;
+    msize_ = msg.msize_;
+    psize_ = msg.psize_;
+    msg_ = msg.msg_;
+
+    // Detach the buffer from the source.
+    msg.msg_ = nullptr;
+    msg.msize_ = 0;
+    msg.psize_ = 0;
+}
+
+/**
+ * Creates a message of the given type carrying the given id.
+ *
+ * For MSG_ACK messages the (type, id) pair is serialized right away into the
+ * inline buffer mdata_; other message types serialize nothing here.
+ */
+Message::Message(int type, uint32_t ack_msg_id)
+    : type_(type), id_(ack_msg_id) {
+    if (type_ != MSG_ACK)
+        return;
+
+    appendInteger(mdata_, type_, id_);
+}
+
+/**
+ * Releases the serialized message buffer, if one was allocated.
+ */
+Message::~Message() {
+    // free() accepts a null pointer and does nothing in that case
+    free(msg_);
+}
+
+/**
+ * Returns the serialized size of the message in bytes.
+ *
+ * An ACK consists only of the type and the id; every other message consists
+ * of the header (hsize_) followed by the metadata and the payload.
+ */
+std::size_t Message::getSize() {
+    if (type_ != MSG_ACK)
+        return this->hsize_ + this->psize_ + this->msize_;
+
+    return sizeof(type_) + sizeof(id_);
+}
+
+/**
+ * Sets the message id and refreshes the serialized (type, id) prefix.
+ *
+ * When a buffer already exists the prefix is rewritten in place. When no
+ * buffer has been allocated yet, writing through msg_ would dereference a
+ * null pointer, so:
+ *   - an ACK message is re-serialized into its inline buffer mdata_, exactly
+ *     as the constructor does;
+ *   - any other message gets a header-only buffer with the complete header.
+ * The process is aborted if that header buffer cannot be allocated.
+ */
+void Message::setId(uint32_t id) {
+    this->id_ = id;
+
+    if (msg_ != nullptr) {
+        appendInteger(msg_, type_, id_);
+        return;
+    }
+
+    if (type_ == MSG_ACK) {
+        appendInteger(mdata_, type_, id_);
+        return;
+    }
+
+    // No metadata/payload was ever set: materialize the header so the
+    // message can still be serialized.
+    msg_ = static_cast<char*>(malloc(hsize_));
+    if (msg_ == nullptr)
+        std::abort();
+    appendInteger(msg_, type_, id_, msize_, psize_);
+}
+
+/**
+ * Stores a copy of the metadata section and rebuilds the message buffer.
+ *
+ * @param buf  start of the metadata; may be null when size is not positive
+ * @param size number of metadata bytes; a null buf or a non-positive size
+ *             is treated as "no metadata"
+ *
+ * The buffer layout is: header (hsize_ bytes), metadata, payload. A payload
+ * that was set earlier is carried over into the new buffer, and the previous
+ * buffer is released instead of being leaked. The process is aborted if the
+ * allocation fails.
+ */
+void Message::setMetadata(const void* buf, int size) {
+    const bool has_data = (buf != nullptr && size > 0);
+    const std::size_t old_msize = this->msize_;
+    char* old_msg = msg_;
+
+    this->msize_ = has_data ? size : 0;
+
+    // Size the buffer explicitly: getSize() reports only type + id for ACK
+    // messages and would under-allocate here.
+    char* fresh = static_cast<char*>(malloc(hsize_ + msize_ + psize_));
+    if (fresh == nullptr)
+        std::abort();
+
+    appendInteger(fresh, type_, id_, msize_, psize_);
+    if (has_data)
+        memcpy(fresh + hsize_, buf, msize_);
+
+    // Keep an already stored payload; it sits right after the old metadata.
+    if (old_msg != nullptr && psize_ > 0)
+        memcpy(fresh + hsize_ + msize_, old_msg + hsize_ + old_msize, psize_);
+
+    free(old_msg);
+    msg_ = fresh;
+}
+
+/**
+ * Stores a copy of the payload section behind the metadata.
+ *
+ * @param buf  start of the payload; may be null when size is not positive
+ * @param size number of payload bytes; a null buf or a non-positive size is
+ *             treated as "no payload"
+ *
+ * The buffer is resized to header + metadata + payload and the complete
+ * header is rewritten, so the header is valid even when setMetadata() was
+ * never called. The process is aborted if the reallocation fails.
+ */
+void Message::setPayload(const void* buf, int size) {
+    const bool has_data = (buf != nullptr && size > 0);
+    const std::size_t new_psize = has_data ? static_cast<std::size_t>(size) : 0;
+
+    // realloc() into a temporary: on failure the old buffer stays owned by
+    // msg_ rather than being overwritten with a null pointer.
+    char* grown = static_cast<char*>(realloc(msg_, hsize_ + msize_ + new_psize));
+    if (grown == nullptr)
+        std::abort();
+    msg_ = grown;
+    this->psize_ = new_psize;
+
+    // Same header layout as setMetadata(): type, id, msize, psize.
+    appendInteger(msg_, type_, id_, msize_, psize_);
+    if (has_data)
+        memcpy(msg_ + hsize_ + msize_, buf, psize_);
+}
+
+/**
+ * Exposes the metadata section of the message.
+ *
+ * @param p receives a pointer to the metadata inside the message buffer, or
+ *          nullptr when the message carries no metadata
+ * @return  the metadata size in bytes
+ */
+std::size_t Message::getMetadata(void** p) {
+    void* start = nullptr;
+    if (this->msize_ != 0)
+        start = msg_ + hsize_;  // metadata directly follows the header
+
+    *p = start;
+    return this->msize_;
+}
+
+/**
+ * Exposes the payload section of the message.
+ *
+ * @param p receives a pointer to the payload inside the message buffer, or
+ *          nullptr when the message carries no payload
+ * @return  the payload size in bytes
+ */
+std::size_t Message::getPayload(void** p) {
+    void* start = nullptr;
+    if (this->psize_ != 0)
+        start = msg_ + hsize_ + msize_;  // payload follows header + metadata
+
+    *p = start;
+    return this->psize_;
+}
+
+/**
+ * Moves msg to the back of the queue.
+ *
+ * The content of msg is moved out; the caller must not rely on its content
+ * afterwards. The mutex is held through a scoped guard so that it is
+ * released even if the underlying push throws.
+ */
+void MessageQueue::push(Message& msg) {
+    std::lock_guard<std::mutex> guard(this->lock_);
+    this->mqueue_.push(std::move(msg));
+}
+
+/**
+ * Removes the message at the front of the queue.
+ *
+ * Popping an empty queue is a no-op (calling pop() on an empty container is
+ * undefined behaviour). The mutex is held through a scoped guard.
+ */
+void MessageQueue::pop() {
+    std::lock_guard<std::mutex> guard(this->lock_);
+    if (!this->mqueue_.empty())
+        this->mqueue_.pop();
+}
+
+/**
+ * Returns a reference to the message at the front of the queue.
+ *
+ * The mutex only protects the lookup; the returned reference is used by the
+ * caller after the lock has been released.
+ */
+Message& MessageQueue::front() {
+    std::lock_guard<std::mutex> guard(this->lock_);
+    return this->mqueue_.front();
+}
+
+/**
+ * Returns the number of queued messages, read while holding the mutex.
+ */
+std::size_t MessageQueue::size() {
+    std::lock_guard<std::mutex> guard(this->lock_);
+    const std::size_t count = this->mqueue_.size();
+    return count;
+}
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/889abf8a/test/singa/test_ep.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_ep.cc b/test/singa/test_ep.cc
new file mode 100644
index 0000000..2435f28
--- /dev/null
+++ b/test/singa/test_ep.cc
@@ -0,0 +1,69 @@
+#include "singa/io/network/endpoint.h"
+#include "singa/io/network/integer.h"
+#include "singa/io/network/message.h"
+#include <assert.h>
+#include <unistd.h>
+
+#include "singa/utils/logging.h"
+
+#define SIZE 100
+#define PORT 10000
+
+using namespace singa;
+/**
+ * Ping-pong driver for the endpoint API.
+ *
+ * Options:
+ *   -p <port>  port handed to the NetworkThread (default PORT)
+ *   -h <host>  peer to exchange messages with (default "localhost")
+ *
+ * A message with SIZE bytes of metadata and SIZE bytes of payload is sent to
+ * the peer; every message received back is checked against the expected
+ * content and then sent again, for a bounded number of rounds.
+ *
+ * Returns 1 when an option is missing its value, 0 otherwise.
+ */
+int main(int argc, char** argv) {
+    char md[SIZE];
+    char payload[SIZE];
+
+    // string literals are const; binding one to char* is ill-formed
+    const char* host = "localhost";
+    int port = PORT;
+
+    for (int i = 1; i < argc; ++i) {
+        const bool is_port = (strcmp(argv[i], "-p") == 0);
+        const bool is_host = (strcmp(argv[i], "-h") == 0);
+
+        if (!is_port && !is_host) {
+            fprintf(stderr, "Invalid option %s\n", argv[i]);
+            continue;
+        }
+
+        // both options take a value; do not read past the end of argv
+        if (i + 1 >= argc) {
+            fprintf(stderr, "Missing value for option %s\n", argv[i]);
+            return 1;
+        }
+
+        ++i;
+        if (is_port)
+            port = atoi(argv[i]);
+        else
+            host = argv[i];
+    }
+
+    memset(md, 'a', SIZE);
+    memset(payload, 'b', SIZE);
+
+    Message* m = new Message();
+    m->setMetadata(md, SIZE);
+    m->setPayload(payload, SIZE);
+
+    NetworkThread* t = new NetworkThread(port);
+
+    EndPointFactory* epf = t->epf_;
+
+    // sleep
+    // NOTE(review): presumably gives the network thread (and the peer) time
+    // to come up before connecting -- confirm.
+    sleep(3);
+
+    EndPoint* ep = epf->getEp(host);
+
+    int cnt = 0;
+
+    while (ep && cnt++ <= 100 && ep->send(m) > 0) {
+
+        LOG(INFO) << "Send a " << m->getSize() << " bytes message";
+
+        Message* m1 = ep->recv();
+
+        if (!m1)
+            break;
+
+        char *p;
+
+        LOG(INFO) << "Receive a " << m1->getSize() << " bytes message";
+
+        // the echoed message must carry exactly what was sent
+        CHECK(m1->getMetadata((void**)&p) == SIZE);
+        CHECK(0 == strncmp(p, md, SIZE));
+        CHECK(m1->getPayload((void**)&p) == SIZE);
+        CHECK(0 == strncmp(p, payload, SIZE));
+
+        // the received message becomes the one sent in the next round
+        delete m;
+        m = m1;
+    }
+
+    return 0;
+}


Reply via email to