net/WebSocketHandler.hpp | 266 +++++++++++++++++++++++++++++++---------------- 1 file changed, 181 insertions(+), 85 deletions(-)
New commits: commit 023cbfc32da4b32a0db44ed0497c0ddeab121ff5 Author: Gabriel Masei <gabriel.ma...@1and1.ro> AuthorDate: Fri Mar 8 10:21:17 2019 +0200 Commit: Michael Meeks <michael.me...@collabora.com> CommitDate: Tue Mar 12 15:19:42 2019 +0100 Added support for defragmentation of incoming websocket fragmented messages and handled some protocol error cases Change-Id: I4d11a6527b6b131c65101fd53b71015529645f74 Reviewed-on: https://gerrit.libreoffice.org/68901 Reviewed-by: Michael Meeks <michael.me...@collabora.com> Tested-by: Michael Meeks <michael.me...@collabora.com> diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp index 88350df1e..1b24ab93b 100644 --- a/net/WebSocketHandler.hpp +++ b/net/WebSocketHandler.hpp @@ -37,6 +37,8 @@ private: std::atomic<bool> _shuttingDown; bool _isClient; bool _isMasking; + bool _inFragmentBlock; + bool _isManualDefrag; protected: struct WSFrameMask @@ -50,16 +52,29 @@ protected: public: /// Perform upgrade ourselves, or select a client web socket. - WebSocketHandler(bool isClient = false, bool isMasking = true) : + /// Parameters: + /// isClient: the instance should behave like a client (true) or like a server (false) + /// (from websocket perspective) + /// isMasking: a client should mask (true) or not (false) outgoing frames + /// isManualDefrag: the message handler should be called for every fragment of a message and + /// defragmentation should be handled inside message handler (true) or the message handler + /// should be called after all fragments of a message were received and the message + /// was defragmented (false). + WebSocketHandler(bool isClient = false, bool isMasking = true, bool isManualDefrag = false) : _lastPingSentTime(std::chrono::steady_clock::now()), _pingTimeUs(0), _shuttingDown(false), _isClient(isClient), - _isMasking(isClient && isMasking) + _isMasking(isClient && isMasking), + _inFragmentBlock(false), + _isManualDefrag(isManualDefrag) { } /// Upgrades itself to a websocket directly. + /// Parameters: + /// socket: the TCP socket which received the upgrade request + /// request: the HTTP upgrade request to WebSocket WebSocketHandler(const std::weak_ptr<StreamSocket>& socket, const Poco::Net::HTTPRequest& request) : _socket(socket), @@ -69,7 +84,9 @@ public: _pingTimeUs(0), _shuttingDown(false), _isClient(false), - _isMasking(false) + _isMasking(false), + _inFragmentBlock(false), + _isManualDefrag(false) { upgradeToWebSocket(request); } @@ -99,8 +116,8 @@ public: RESERVED_TLS_FAILURE = 1015 }; - /// Sends WS shutdown message to the peer. - void shutdown(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "") + /// Sends WS Close frame to the peer. + void sendCloseFrame(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "") { std::shared_ptr<StreamSocket> socket = _socket.lock(); if (socket == nullptr) @@ -126,7 +143,22 @@ public: #endif } - bool handleOneIncomingMessage(const std::shared_ptr<StreamSocket>& socket) + void shutdown(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "") + { + if (!_shuttingDown) + sendCloseFrame(statusCode, statusMessage); + std::shared_ptr<StreamSocket> socket = _socket.lock(); + if (socket) + { + socket->closeConnection(); + socket->getInBuffer().clear(); + } + _wsPayload.clear(); + _inFragmentBlock = false; + _shuttingDown = false; + } + + bool handleTCPStream(const std::shared_ptr<StreamSocket>& socket) { assert(socket && "Expected a valid socket instance."); @@ -177,7 +209,7 @@ public: headerLen += 8; } - unsigned char *data, *mask; + unsigned char *data, *mask = nullptr; if (hasMask) { @@ -187,117 +219,165 @@ public: if (payloadLen + headerLen > len) { // partial read wait for more data. - LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket message, have " << len << " bytes, message is " << payloadLen + headerLen << " bytes"); + LOG_TRC("#" << socket->getFD() << ": Still incomplete WebSocket frame, have " << len << " bytes, frame is " << payloadLen + headerLen << " bytes"); return false; } + if (hasMask && _isClient) + { + LOG_ERR("#" << socket->getFD() << ": Servers should not send masked frames. Only clients."); + shutdown(StatusCodes::PROTOCOL_ERROR); + return true; + } + LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket data of " << len << " bytes: " << Util::stringifyHexLine(socket->getInBuffer(), 0, std::min((size_t)32, len))); data = p + headerLen; - if (hasMask) + if (isControlFrame(code)) { - const size_t end = _wsPayload.size(); - _wsPayload.resize(end + payloadLen); - char* wsData = &_wsPayload[end]; - for (size_t i = 0; i < payloadLen; ++i) - *wsData++ = data[i] ^ mask[i % 4]; - } else - _wsPayload.insert(_wsPayload.end(), data, data + payloadLen); -#else - unsigned char * const p = reinterpret_cast<unsigned char*>(&socket->getInBuffer()[0]); - _wsPayload.insert(_wsPayload.end(), p, p + len); - const size_t headerLen = 0; - const size_t payloadLen = len; -#endif - - assert(_wsPayload.size() >= payloadLen); - - socket->getInBuffer().erase(socket->getInBuffer().begin(), socket->getInBuffer().begin() + headerLen + payloadLen); + //Process control frames -#if !MOBILEAPP + std::vector<char> ctrlPayload; - // FIXME: fin, aggregating payloads into _wsPayload etc. - LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket message code " << static_cast<unsigned>(code) << - ", fin? " << fin << ", mask? " << hasMask << ", payload length: " << _wsPayload.size() << + readPayload(data, payloadLen, mask, ctrlPayload); + socket->getInBuffer().erase(socket->getInBuffer().begin(), socket->getInBuffer().begin() + headerLen + payloadLen); + LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket frame code " << static_cast<unsigned>(code) << + ", fin? " << fin << ", mask? " << hasMask << ", payload length: " << payloadLen << ", residual socket data: " << socket->getInBuffer().size() << " bytes."); - bool doClose = false; - - switch (code) - { - case WSOpCode::Pong: - { - if (_isClient) + // All control frames MUST NOT be fragmented and MUST have a payload length of 125 bytes or less + if (!fin) { - LOG_ERR("#" << socket->getFD() << ": Servers should not send pongs, only clients"); - doClose = true; - break; + LOG_ERR("#" << socket->getFD() << ": A control frame cannot be fragmented."); + shutdown(StatusCodes::PROTOCOL_ERROR); + return true; } - else + if (payloadLen > 125) { - _pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds> - (std::chrono::steady_clock::now() - _lastPingSentTime).count(); - LOG_TRC("#" << socket->getFD() << ": Pong received: " << _pingTimeUs << " microseconds"); - break; + LOG_ERR("#" << socket->getFD() << ": The payload length of a control frame must not exceed 125 bytes."); + shutdown(StatusCodes::PROTOCOL_ERROR); + return true; } - } - case WSOpCode::Ping: - if (_isClient) + + switch (code) { - auto now = std::chrono::steady_clock::now(); - _pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds> - (now - _lastPingSentTime).count(); - sendPong(now, &_wsPayload[0], payloadLen, socket); + case WSOpCode::Pong: + if (_isClient) + { + LOG_ERR("#" << socket->getFD() << ": Servers should not send pongs, only clients"); + shutdown(StatusCodes::POLICY_VIOLATION); + return true; + } + else + { + _pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds> + (std::chrono::steady_clock::now() - _lastPingSentTime).count(); + LOG_TRC("#" << socket->getFD() << ": Pong received: " << _pingTimeUs << " microseconds"); + } + break; + case WSOpCode::Ping: + if (_isClient) + { + auto now = std::chrono::steady_clock::now(); + _pingTimeUs = std::chrono::duration_cast<std::chrono::microseconds> + (now - _lastPingSentTime).count(); + sendPong(now, &ctrlPayload[0], payloadLen, socket); + } + else + { + LOG_ERR("#" << socket->getFD() << ": Clients should not send pings, only servers"); + shutdown(StatusCodes::POLICY_VIOLATION); + return true; + } + break; + case WSOpCode::Close: + { + std::string message; + StatusCodes statusCode = StatusCodes::NORMAL_CLOSE; + if (!_shuttingDown) + { + // Peer-initiated shutdown must be echoed. + // Otherwise, this is the echo to _our_ shutdown message, which we should ignore. + LOG_TRC("#" << socket->getFD() << ": Peer initiated socket shutdown. Code: " << static_cast<int>(statusCode)); + if (ctrlPayload.size()) + { + statusCode = static_cast<StatusCodes>((((uint64_t)(unsigned char)ctrlPayload[0]) << 8) + + (((uint64_t)(unsigned char)ctrlPayload[1]) << 0)); + if (ctrlPayload.size() > 2) + message.assign(&ctrlPayload[2], &ctrlPayload[2] + ctrlPayload.size() - 2); + } + } + shutdown(statusCode, message); + return true; + } + default: + LOG_ERR("#" << socket->getFD() << ": Received unknown control code"); + shutdown(StatusCodes::PROTOCOL_ERROR); break; } - else + + return true; + } + + // Check data frames for errors + if (_inFragmentBlock) + { + if (code != WSOpCode::Continuation) { - LOG_ERR("#" << socket->getFD() << ": Clients should not send pings, only servers"); - doClose = true; + LOG_ERR("#" << socket->getFD() << ": A fragment that is not the first fragment of a message must have the opcode equal to 0."); + shutdown(StatusCodes::PROTOCOL_ERROR); + return true; } - break; - case WSOpCode::Close: - doClose = true; - break; - default: - handleMessage(fin, code, _wsPayload); - break; + } + else if (code == WSOpCode::Continuation) + { + LOG_ERR("#" << socket->getFD() << ": An unfragmented message or the first fragment of a fragmented message must have the opcode different than 0."); + shutdown(StatusCodes::PROTOCOL_ERROR); + return true; } + //Process data frame + readPayload(data, payloadLen, mask, _wsPayload); #else - handleMessage(true, WSOpCode::Binary, _wsPayload); - + unsigned char * const p = reinterpret_cast<unsigned char*>(&socket->getInBuffer()[0]); + _wsPayload.insert(_wsPayload.end(), p, p + len); + const size_t headerLen = 0; + const size_t payloadLen = len; #endif + socket->getInBuffer().erase(socket->getInBuffer().begin(), socket->getInBuffer().begin() + headerLen + payloadLen); + #if !MOBILEAPP - if (doClose) + + LOG_TRC("#" << socket->getFD() << ": Incoming WebSocket frame code " << static_cast<unsigned>(code) << + ", fin? " << fin << ", mask? " << hasMask << ", payload length: " << payloadLen << + ", residual socket data: " << socket->getInBuffer().size() << " bytes."); + + if (fin) { - if (!_shuttingDown) + //If is final fragment then process the accumulated message. + handleMessage(fin, code, _wsPayload); + _inFragmentBlock = false; + } + else + { + if (_isManualDefrag) { - // Peer-initiated shutdown must be echoed. - // Otherwise, this is the echo to _our_ shutdown message, which we should ignore. - const StatusCodes statusCode = static_cast<StatusCodes>((((uint64_t)(unsigned char)_wsPayload[0]) << 8) + - (((uint64_t)(unsigned char)_wsPayload[1]) << 0)); - LOG_TRC("#" << socket->getFD() << ": Client initiated socket shutdown. Code: " << static_cast<int>(statusCode)); - if (_wsPayload.size() > 2) - { - const std::string message(&_wsPayload[2], &_wsPayload[2] + _wsPayload.size() - 2); - shutdown(statusCode, message); - } - else - { - shutdown(statusCode); - } + //If the user wants to process defragmentation on its own then let him process it. + handleMessage(fin, code, _wsPayload); + _inFragmentBlock = true; } else { - LOG_TRC("#" << socket->getFD() << ": Client responded to our shutdown."); + _inFragmentBlock = true; + //If is not final fragment then wait for next fragment. + return false; } - - // TCP Close. - socket->closeConnection(); } +#else + handleMessage(true, WSOpCode::Binary, _wsPayload); + #endif _wsPayload.clear(); @@ -328,7 +408,7 @@ public: #endif else { - while (handleOneIncomingMessage(socket)) + while (handleTCPStream(socket)) ; // might have multiple messages in the accumulated buffer. } } @@ -512,6 +592,22 @@ private: protected: + bool isControlFrame(WSOpCode code){ return code >= WSOpCode::Close; } + + void readPayload(unsigned char *data, size_t dataLen, unsigned char* mask, std::vector<char>& payload) + { + if (mask) + { + size_t end = payload.size(); + payload.resize(end + dataLen); + char* wsData = &payload[end]; + for (size_t i = 0; i < dataLen; ++i) + *wsData++ = data[i] ^ mask[i % 4]; + } + else + payload.insert(payload.end(), data, data + dataLen); + } + /// To be overriden to handle the websocket messages the way you need. virtual void handleMessage(bool /*fin*/, WSOpCode /*code*/, std::vector<char> &/*data*/) { _______________________________________________ Libreoffice-commits mailing list libreoffice-comm...@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/libreoffice-commits