common/Unit.hpp | 1 net/ServerSocket.hpp | 7 +- net/Socket.cpp | 14 +++++ net/Socket.hpp | 97 +++++++++++++++++++++++++++------- net/WebSocketHandler.hpp | 4 - net/loolnb.cpp | 8 +- test/UnitFuzz.cpp | 1 wsd/ClientSession.cpp | 8 +- wsd/ClientSession.hpp | 2 wsd/LOOLWSD.cpp | 131 ++++++++++++++++++++++------------------------- 10 files changed, 168 insertions(+), 105 deletions(-)
New commits: commit cc131ab46556c9548c0fa6121d33752905542cfb Author: Michael Meeks <[email protected]> Date: Thu Apr 20 23:11:07 2017 -0400 SocketDisposition - controls socket movement & lifecycle. It is important that moving sockets between polls happens in a safe and readable fashion; this enforces that. Squashes: wsd: clear the socket buffer only after processing the request Because POST requests need to consume the complete request message, we shouldn't clear the buffer before handling the POST request. SocketDisposition: push it down the stack, and clean up around that. Dung out overlapping return enumerations. Move more work into 'move' callbacks at a safer time, etc. Change-Id: I5deaa32bbb0f9017e3b453010065771e7b2a0c0d Reviewed-on: https://gerrit.libreoffice.org/37389 Reviewed-by: Ashod Nakashian <[email protected]> Tested-by: Michael Meeks <[email protected]> diff --git a/common/Unit.hpp b/common/Unit.hpp index e8197fd1..5f8d20ea 100644 --- a/common/Unit.hpp +++ b/common/Unit.hpp @@ -177,6 +177,7 @@ public: /// Intercept incoming requests, so unit tests can silently communicate virtual bool filterHandleRequest( TestRequest /* type */, + SocketDisposition & /* disposition */, WebSocketHandler & /* handler */) { return false; diff --git a/net/ServerSocket.hpp b/net/ServerSocket.hpp index 805430ea..4d4bb353 100644 --- a/net/ServerSocket.hpp +++ b/net/ServerSocket.hpp @@ -88,8 +88,9 @@ public: void dumpState(std::ostream& os) override; - HandleResult handlePoll(std::chrono::steady_clock::time_point /* now */, - int events) override + void handlePoll(SocketDisposition &, + std::chrono::steady_clock::time_point /* now */, + int events) override { if (events & POLLIN) { @@ -103,8 +104,6 @@ public: LOG_DBG("Accepted client #" << clientSocket->getFD()); _clientPoller.insertNewSocket(clientSocket); } - - return Socket::HandleResult::CONTINUE; } private: diff --git a/net/Socket.cpp b/net/Socket.cpp index b38dd3fe..70b63e78 100644 --- a/net/Socket.cpp +++ 
b/net/Socket.cpp @@ -121,6 +121,20 @@ void ServerSocket::dumpState(std::ostream& os) os << "\t" << getFD() << "\t<accept>\n"; } + +void SocketDisposition::execute() +{ + // We should have hard ownership of this socket. + assert(_socket->getThreadOwner() == std::this_thread::get_id()); + if (_socketMove) + { + // Drop pretentions of ownership before _socketMove. + _socket->setThreadOwner(std::thread::id(0)); + _socketMove(_socket); + } + _socketMove = nullptr; +} + namespace { void dump_hex (const char *legend, const char *prefix, std::vector<char> buffer) diff --git a/net/Socket.hpp b/net/Socket.hpp index 70b229b3..89dd746e 100644 --- a/net/Socket.hpp +++ b/net/Socket.hpp @@ -38,6 +38,56 @@ #include "Util.hpp" #include "SigUtil.hpp" +namespace Poco +{ + namespace Net + { + class HTTPResponse; + } +} + +class Socket; + +/// Helper to allow us to easily defer the movement of a socket +/// between polls to clarify thread ownership. +class SocketDisposition +{ + typedef std::function<void(const std::shared_ptr<Socket> &)> MoveFunction; + enum class Type { CONTINUE, CLOSED, MOVE }; + + Type _disposition; + MoveFunction _socketMove; + std::shared_ptr<Socket> _socket; + +public: + SocketDisposition(const std::shared_ptr<Socket> &socket) : + _disposition(Type::CONTINUE), + _socket(socket) + {} + ~SocketDisposition() + { + assert (!_socketMove); + } + void setMove() + { + _disposition = Type::MOVE; + } + void setMove(MoveFunction moveFn) + { + _socketMove = moveFn; + _disposition = Type::MOVE; + } + void setClosed() + { + _disposition = Type::CLOSED; + } + bool isMove() { return _disposition == Type::MOVE; } + bool isClosed() { return _disposition == Type::CLOSED; } + + /// Perform the queued up work. + void execute(); +}; + /// A non-blocking, streaming socket. 
class Socket { @@ -80,8 +130,9 @@ public: int &timeoutMaxMs) = 0; /// Handle results of events returned from poll - enum class HandleResult { CONTINUE, SOCKET_CLOSED, MOVED }; - virtual HandleResult handlePoll(std::chrono::steady_clock::time_point now, int events) = 0; + virtual void handlePoll(SocketDisposition &disposition, + std::chrono::steady_clock::time_point now, + int events) = 0; /// manage latency issues around packet aggregation void setNoDelay(bool noDelay = true) @@ -192,6 +243,11 @@ public: } } + const std::thread::id &getThreadOwner() + { + return _owner; + } + /// Asserts in the debug builds, otherwise just logs. void assertCorrectThread() { @@ -244,7 +300,6 @@ private: std::thread::id _owner; }; - /// Handles non-blocking socket event polling. /// Only polls on N-Sockets and invokes callback and /// doesn't manage buffers or client data. @@ -401,26 +456,30 @@ public: // Fire the poll callbacks and remove dead fds. std::chrono::steady_clock::time_point newNow = std::chrono::steady_clock::now(); + for (int i = static_cast<int>(size) - 1; i >= 0; --i) { - Socket::HandleResult res = Socket::HandleResult::SOCKET_CLOSED; + SocketDisposition disposition(_pollSockets[i]); try { - res = _pollSockets[i]->handlePoll(newNow, _pollFds[i].revents); + _pollSockets[i]->handlePoll(disposition, newNow, + _pollFds[i].revents); } catch (const std::exception& exc) { LOG_ERR("Error while handling poll for socket #" << _pollFds[i].fd << " in " << _name << ": " << exc.what()); + disposition.setClosed(); } - if (res == Socket::HandleResult::SOCKET_CLOSED || - res == Socket::HandleResult::MOVED) + if (disposition.isMove() || disposition.isClosed()) { LOG_DBG("Removing socket #" << _pollFds[i].fd << " (of " << _pollSockets.size() << ") from " << _name); _pollSockets.erase(_pollSockets.begin() + i); } + + disposition.execute(); } } @@ -589,14 +648,8 @@ public: /// Will be called exactly once. 
virtual void onConnect(const std::shared_ptr<StreamSocket>& socket) = 0; - enum class SocketOwnership - { - UNCHANGED, //< Same socket poll, business as usual. - MOVED //< The socket poll is now different. - }; - /// Called after successful socket reads. - virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() = 0; + virtual void handleIncomingMessage(SocketDisposition &disposition) = 0; /// Prepare our poll record; adjust @timeoutMaxMs downwards /// for timeouts, based on current time @now. @@ -759,15 +812,16 @@ protected: /// Called when a polling event is received. /// @events is the mask of events that triggered the wake. - HandleResult handlePoll(std::chrono::steady_clock::time_point now, - const int events) override + void handlePoll(SocketDisposition &disposition, + std::chrono::steady_clock::time_point now, + const int events) override { assertCorrectThread(); _socketHandler->checkTimeout(now); if (!events) - return Socket::HandleResult::CONTINUE; + return; // FIXME: need to close input, but not output (?) bool closed = (events & (POLLHUP | POLLERR | POLLNVAL)); @@ -787,8 +841,9 @@ protected: while (!_inBuffer.empty() && oldSize != _inBuffer.size()) { oldSize = _inBuffer.size(); - if (_socketHandler->handleIncomingMessage() == SocketHandlerInterface::SocketOwnership::MOVED) - return Socket::HandleResult::MOVED; + _socketHandler->handleIncomingMessage(disposition); + if (disposition.isMove()) + return; } do @@ -823,8 +878,8 @@ protected: _socketHandler->onDisconnect(); } - return _closed ? HandleResult::SOCKET_CLOSED : - HandleResult::CONTINUE; + if (_closed) + disposition.setClosed(); } /// Override to write data out to socket. diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp index a863afad..69c8ed3b 100644 --- a/net/WebSocketHandler.hpp +++ b/net/WebSocketHandler.hpp @@ -250,7 +250,7 @@ public: } /// Implementation of the SocketHandlerInterface. 
- virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + virtual void handleIncomingMessage(SocketDisposition&) override { auto socket = _socket.lock(); if (socket == nullptr) @@ -262,8 +262,6 @@ public: while (handleOneIncomingMessage(socket)) ; // can have multiple msgs in one recv'd packet. } - - return SocketHandlerInterface::SocketOwnership::UNCHANGED; } int getPollEvents(std::chrono::steady_clock::time_point now, diff --git a/net/loolnb.cpp b/net/loolnb.cpp index a014173a..e268b067 100644 --- a/net/loolnb.cpp +++ b/net/loolnb.cpp @@ -45,7 +45,7 @@ public: { } - virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + virtual void handleIncomingMessage(SocketDisposition &disposition) override { LOG_TRC("incoming WebSocket message"); if (_wsState == WSState::HTTP) @@ -89,16 +89,16 @@ public: std::string str = oss.str(); socket->_outBuffer.insert(socket->_outBuffer.end(), str.begin(), str.end()); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } else if (tokens.count() == 2 && tokens[1] == "ws") { upgradeToWebSocket(req); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } } - return WebSocketHandler::handleIncomingMessage(); + WebSocketHandler::handleIncomingMessage(disposition); } virtual void handleMessage(const bool fin, const WSOpCode code, std::vector<char> &data) override diff --git a/test/UnitFuzz.cpp b/test/UnitFuzz.cpp index 49575b5d..68367884 100644 --- a/test/UnitFuzz.cpp +++ b/test/UnitFuzz.cpp @@ -121,6 +121,7 @@ public: virtual bool filterHandleRequest( TestRequest /* type */, + SocketDisposition & /* disposition */, WebSocketHandler & /* socket */) override { #if 0 // loolnb diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp index 11db33af..4c965d7b 100644 --- a/wsd/ClientSession.cpp +++ b/wsd/ClientSession.cpp @@ -50,13 +50,13 @@ ClientSession::~ClientSession() LOG_INF("~ClientSession dtor [" << getName() << "], current number of 
connections: " << curConnections); } -SocketHandlerInterface::SocketOwnership ClientSession::handleIncomingMessage() +void ClientSession::handleIncomingMessage(SocketDisposition &disposition) { if (UnitWSD::get().filterHandleRequest( - UnitWSD::TestRequest::Client, *this)) - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + UnitWSD::TestRequest::Client, disposition, *this)) + return; - return Session::handleIncomingMessage(); + Session::handleIncomingMessage(disposition); } bool ClientSession::_handleInput(const char *buffer, int length) diff --git a/wsd/ClientSession.hpp b/wsd/ClientSession.hpp index d5ea76df..a8e85127 100644 --- a/wsd/ClientSession.hpp +++ b/wsd/ClientSession.hpp @@ -30,7 +30,7 @@ public: virtual ~ClientSession(); - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override; + void handleIncomingMessage(SocketDisposition &) override; void setReadOnly() override; diff --git a/wsd/LOOLWSD.cpp b/wsd/LOOLWSD.cpp index c6e009c8..d9a9704c 100644 --- a/wsd/LOOLWSD.cpp +++ b/wsd/LOOLWSD.cpp @@ -1370,16 +1370,17 @@ private: } /// Called after successful socket reads. - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + void handleIncomingMessage(SocketDisposition &disposition) override { if (UnitWSD::get().filterHandleRequest( - UnitWSD::TestRequest::Prisoner, *this)) - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + UnitWSD::TestRequest::Prisoner, disposition, *this)) + return; if (_childProcess.lock()) { // FIXME: inelegant etc. - derogate to websocket code - return WebSocketHandler::handleIncomingMessage(); + WebSocketHandler::handleIncomingMessage(disposition); + return; } auto socket = _socket.lock(); @@ -1392,7 +1393,7 @@ private: if (itBody == in.end()) { LOG_TRC("#" << socket->getFD() << " doesn't have enough data yet."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } // Skip the marker. 
@@ -1424,7 +1425,7 @@ private: if (request.getURI().find(NEW_CHILD_URI) != 0) { LOG_ERR("Invalid incoming URI."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } // New Child is spawned. @@ -1445,7 +1446,7 @@ private: if (pid <= 0) { LOG_ERR("Invalid PID in child URI [" << request.getURI() << "]."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } in.clear(); @@ -1456,24 +1457,20 @@ private: auto child = std::make_shared<ChildProcess>(pid, socket, request); - // Drop pretentions of ownership before adding to the list. - socket->setThreadOwner(std::thread::id(0)); - _childProcess = child; // weak - addNewChild(child); // Remove from prisoner poll since there is no activity // until we attach the childProcess (with this socket) // to a docBroker, which will do the polling. - return SocketHandlerInterface::SocketOwnership::MOVED; + disposition.setMove([child](const std::shared_ptr<Socket> &){ + addNewChild(child); + }); } catch (const std::exception& exc) { // Probably don't have enough data just yet. // TODO: timeout if we never get enough. } - - return SocketHandlerInterface::SocketOwnership::UNCHANGED; } /// Prisoner websocket fun ... (for now) @@ -1528,7 +1525,7 @@ private: } /// Called after successful socket reads. - SocketHandlerInterface::SocketOwnership handleIncomingMessage() override + void handleIncomingMessage(SocketDisposition &disposition) override { auto socket = _socket.lock(); std::vector<char>& in = socket->_inBuffer; @@ -1539,8 +1536,8 @@ private: marker.begin(), marker.end()); if (itBody == in.end()) { - LOG_TRC("#" << socket->getFD() << " doesn't have enough data yet."); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + LOG_DBG("#" << socket->getFD() << " doesn't have enough data yet."); + return; } // Skip the marker. 
@@ -1575,17 +1572,16 @@ private: if (contentLength != Poco::Net::HTTPMessage::UNKNOWN_CONTENT_LENGTH && available < contentLength) { LOG_DBG("Not enough content yet: ContentLength: " << contentLength << ", available: " << available); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } } catch (const std::exception& exc) { // Probably don't have enough data just yet. // TODO: timeout if we never get enough. - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } - SocketHandlerInterface::SocketOwnership socketOwnership = SocketHandlerInterface::SocketOwnership::UNCHANGED; try { // Routing @@ -1601,12 +1597,13 @@ private: // Admin connections else if (reqPathSegs.size() >= 2 && reqPathSegs[0] == "lool" && reqPathSegs[1] == "adminws") { - LOG_ERR("Admin request: " << request.getURI()); + LOG_INF("Admin request: " << request.getURI()); if (AdminSocketHandler::handleInitialRequest(_socket, request)) { - // Hand the socket over to the Admin poll. - Admin::instance().insertNewSocket(socket); - socketOwnership = SocketHandlerInterface::SocketOwnership::MOVED; + disposition.setMove([](const std::shared_ptr<Socket> &moveSocket){ + // Hand the socket over to the Admin poll. + Admin::instance().insertNewSocket(moveSocket); + }); } } // Client post and websocket connections @@ -1631,11 +1628,11 @@ private: reqPathTokens.count() > 0 && reqPathTokens[0] == "lool") { // All post requests have url prefix 'lool'. 
- socketOwnership = handlePostRequest(request, message); + handlePostRequest(request, message, disposition); } else if (reqPathTokens.count() > 2 && reqPathTokens[0] == "lool" && reqPathTokens[2] == "ws") { - socketOwnership = handleClientWsUpgrade(request, reqPathTokens[1]); + handleClientWsUpgrade(request, reqPathTokens[1], disposition); } else { @@ -1652,10 +1649,6 @@ private: socket->shutdown(); } } - - // if we succeeded - remove the request from our input buffer - // we expect one request per socket - in.clear(); } catch (const std::exception& exc) { @@ -1665,7 +1658,9 @@ private: LOOLProtocol::getAbbreviatedMessage(in) << "]: " << exc.what()); } - return socketOwnership; + // if we succeeded - remove the request from our input buffer + // we expect one request per socket + in.erase(in.begin(), itBody); } int getPollEvents(std::chrono::steady_clock::time_point /* now */, @@ -1799,14 +1794,14 @@ private: return "application/octet-stream"; } - SocketHandlerInterface::SocketOwnership handlePostRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message) + void handlePostRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message, + SocketDisposition &disposition) { LOG_INF("Post request: [" << request.getURI() << "]"); Poco::Net::HTTPResponse response; auto socket = _socket.lock(); - SocketHandlerInterface::SocketOwnership socketOwnership = SocketHandlerInterface::SocketOwnership::UNCHANGED; StringTokenizer tokens(request.getURI(), "/?"); if (tokens.count() >= 3 && tokens[2] == "convert-to") { @@ -1844,21 +1839,23 @@ private: auto clientSession = createNewClientSession(nullptr, _id, uriPublic, docBroker, isReadOnly); if (clientSession) { - // Transfer the client socket to the DocumentBroker. 
- socketOwnership = SocketHandlerInterface::SocketOwnership::MOVED; + disposition.setMove([docBroker, clientSession, format] + (const std::shared_ptr<Socket> &moveSocket) + { // Perform all of this after removing the socket // Make sure the thread is running before adding callback. docBroker->startThread(); // We no longer own this socket. - socket->setThreadOwner(std::thread::id(0)); + moveSocket->setThreadOwner(std::thread::id(0)); - docBroker->addCallback([docBroker, socket, clientSession, format]() + docBroker->addCallback([docBroker, moveSocket, clientSession, format]() { - clientSession->setSaveAsSocket(socket); + auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket); + clientSession->setSaveAsSocket(streamSocket); // Move the socket into DocBroker. - docBroker->addSocketToPoll(socket); + docBroker->addSocketToPoll(moveSocket); // First add and load the session. docBroker->addSession(clientSession); @@ -1882,6 +1879,7 @@ private: std::vector<char> saveasRequest(saveas.begin(), saveas.end()); clientSession->handleMessage(true, WebSocketHandler::WSOpCode::Text, saveasRequest); }); + }); sent = true; } @@ -1895,8 +1893,7 @@ private: // TODO: We should differentiate between bad request and failed conversion. 
throw BadRequestException("Failed to convert and send file."); } - - return socketOwnership; + return; } else if (tokens.count() >= 4 && tokens[3] == "insertfile") { @@ -1936,7 +1933,7 @@ private: File(tmpPath).moveTo(fileName); response.setContentLength(0); socket->send(response); - return socketOwnership; + return; } } } @@ -2003,21 +2000,21 @@ private: { LOG_ERR("Download file [" << filePath.toString() << "] not found."); } - (void)responded; - return socketOwnership; + return; } throw BadRequestException("Invalid or unknown request."); } - SocketHandlerInterface::SocketOwnership handleClientWsUpgrade(const Poco::Net::HTTPRequest& request, const std::string& url) + void handleClientWsUpgrade(const Poco::Net::HTTPRequest& request, const std::string& url, + SocketDisposition &disposition) { auto socket = _socket.lock(); if (!socket) { LOG_WRN("No socket to handle client WS upgrade for request: " << request.getURI() << ", url: " << url); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } LOG_INF("Client WS request: " << request.getURI() << ", url: " << url << ", socket #" << socket->getFD()); @@ -2029,7 +2026,7 @@ private: { LOG_ERR("Limit on maximum number of connections of " << MAX_CONNECTIONS << " reached."); shutdownLimitReached(ws); - return SocketHandlerInterface::SocketOwnership::UNCHANGED; + return; } LOG_INF("Starting GET request handler for session [" << _id << "] on url [" << url << "]."); @@ -2057,8 +2054,6 @@ private: LOG_INF("URL [" << url << "] is " << (isReadOnly ? "readonly" : "writable") << "."); - SocketHandlerInterface::SocketOwnership socketOwnership = SocketHandlerInterface::SocketOwnership::UNCHANGED; - // Request a kit process for this doc. 
auto docBroker = findOrCreateDocBroker(ws, url, docKey, _id, uriPublic); if (docBroker) @@ -2066,28 +2061,30 @@ private: auto clientSession = createNewClientSession(&ws, _id, uriPublic, docBroker, isReadOnly); if (clientSession) { - // Transfer the client socket to the DocumentBroker. - - // Remove from current poll as we're moving ownership. - socketOwnership = SocketHandlerInterface::SocketOwnership::MOVED; + // Transfer the client socket to the DocumentBroker when we get back to the poll: + disposition.setMove([docBroker, clientSession] + (const std::shared_ptr<Socket> &moveSocket) + { + // Make sure the thread is running before adding callback. + docBroker->startThread(); - // Make sure the thread is running before adding callback. - docBroker->startThread(); + // We no longer own this socket. + moveSocket->setThreadOwner(std::thread::id(0)); - // We no longer own this socket. - socket->setThreadOwner(std::thread::id(0)); + docBroker->addCallback([docBroker, moveSocket, clientSession]() + { + auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket); - docBroker->addCallback([docBroker, socket, clientSession]() - { - // Set the ClientSession to handle Socket events. - socket->setHandler(clientSession); - LOG_DBG("Socket #" << socket->getFD() << " handler is " << clientSession->getName()); + // Set the ClientSession to handle Socket events. + streamSocket->setHandler(clientSession); + LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName()); - // Move the socket into DocBroker. - docBroker->addSocketToPoll(socket); + // Move the socket into DocBroker. + docBroker->addSocketToPoll(moveSocket); - // Add and load the session. - docBroker->addSession(clientSession); + // Add and load the session. 
+ docBroker->addSession(clientSession); + }); }); } else @@ -2098,8 +2095,6 @@ private: } else LOG_WRN("Failed to create DocBroker with docKey [" << docKey << "]."); - - return socketOwnership; } private: _______________________________________________ Libreoffice-commits mailing list [email protected] https://lists.freedesktop.org/mailman/listinfo/libreoffice-commits
