THRIFT-3420 C++: TSSLSockets are not interruptable Client: C++ Patch: Martin Haimberger
This closes #690 Project: http://git-wip-us.apache.org/repos/asf/thrift/repo Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/0ad6ee95 Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/0ad6ee95 Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/0ad6ee95 Branch: refs/heads/master Commit: 0ad6ee95e002f41dd628d4044f901468f43ffc32 Parents: ae971ce Author: Martin Haimberger <[email protected]> Authored: Fri Nov 13 03:18:50 2015 -0800 Committer: Nobuaki Sukegawa <[email protected]> Committed: Mon Nov 23 17:09:27 2015 +0900 ---------------------------------------------------------------------- .../src/thrift/transport/TSSLServerSocket.cpp | 7 +- lib/cpp/src/thrift/transport/TSSLSocket.cpp | 259 +++++++++++++++-- lib/cpp/src/thrift/transport/TSSLSocket.h | 46 +++ lib/cpp/src/thrift/transport/TServerSocket.cpp | 16 +- lib/cpp/src/thrift/transport/TServerSocket.h | 4 +- lib/cpp/test/CMakeLists.txt | 7 +- lib/cpp/test/Makefile.am | 4 +- lib/cpp/test/TSSLSocketInterruptTest.cpp | 283 +++++++++++++++++++ 8 files changed, 595 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp index 7e1484d..89423b4 100644 --- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp @@ -48,7 +48,12 @@ TSSLServerSocket::TSSLServerSocket(int port, } boost::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client) { - return factory_->createSocket(client); + if (interruptableChildren_) { + return factory_->createSocket(client, pChildInterruptSockReader_); + + } else { + return factory_->createSocket(client); + } } } } http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLSocket.cpp ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.cpp b/lib/cpp/src/thrift/transport/TSSLSocket.cpp index 98c5326..6e9a4de 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLSocket.cpp @@ -28,6 +28,14 @@ #ifdef HAVE_SYS_SOCKET_H #include <sys/socket.h> #endif +#ifdef HAVE_SYS_POLL_H +#include <sys/poll.h> +#endif +#ifdef HAVE_FCNTL_H +#include <fcntl.h> +#endif + + #include <boost/lexical_cast.hpp> #include <boost/shared_array.hpp> #include <openssl/err.h> @@ -189,14 +197,28 @@ TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx) : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener) + : TSocket(), server_(false), ssl_(NULL), ctx_(ctx) { + interruptListener_ = interruptListener; +} + TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket) : TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener) + : TSocket(socket, interruptListener), server_(false), ssl_(NULL), ctx_(ctx) { +} + TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port) : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) { } +TSSLSocket::TSSLSocket(boost::shared_ptr<SSLContext> ctx, string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener) + : TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) { + interruptListener_ = interruptListener; +} + TSSLSocket::~TSSLSocket() { close(); } @@ -222,16 +244,32 @@ bool TSSLSocket::peek() { checkHandshake(); int rc; uint8_t byte; - rc = SSL_peek(ssl_, &byte, 1); - if (rc < 0) { - int errno_copy = THRIFT_GET_SOCKET_ERROR; - string errors; - buildErrors(errors, errno_copy); - throw TSSLException("SSL_peek: " + errors); - } - if (rc == 0) { - ERR_clear_error(); - } + do { + rc = SSL_peek(ssl_, &byte, 1); + if (rc < 0) { + + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + continue; + default:;// do nothing + } + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_peek: " + errors); + } else if (rc == 0) { + ERR_clear_error(); + break; + } + } while (true); return (rc > 0); } @@ -244,7 +282,28 @@ void TSSLSocket::open() { void TSSLSocket::close() { if (ssl_ != NULL) { - int rc = SSL_shutdown(ssl_); + int rc; + + do { + rc = SSL_shutdown(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); + if (rc < 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; string errors; @@ -262,14 +321,36 @@ uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) { checkHandshake(); int32_t bytes = 0; for (int32_t retries = 0; retries < maxRecvRetries_; retries++) { + ERR_clear_error(); bytes = SSL_read(ssl_, buf, len); if (bytes >= 0) break; - int errno_copy = THRIFT_GET_SOCKET_ERROR; - if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) { - if (ERR_get_error() == 0 && errno_copy == THRIFT_EINTR) { + int32_t errno_copy = THRIFT_GET_SOCKET_ERROR; + int32_t error = SSL_get_error(ssl_, bytes); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + if (retries++ >= maxRecvRetries_) { + // THRIFT_EINTR needs to be handled manually and we can tolerate + // a certain number + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + if (waitForEvent(error == SSL_ERROR_WANT_READ) == TSSL_EINTR ) { + // repeat operation + if (retries++ < maxRecvRetries_) { + // THRIFT_EINTR needs to be handled manually and we can tolerate + // a certain number + continue; + } + throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries"); + } continue; - } + default:;// do nothing } string errors; buildErrors(errors, errno_copy); @@ -283,9 +364,23 @@ void TSSLSocket::write(const uint8_t* buf, uint32_t len) { // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX. uint32_t written = 0; while (written < len) { + ERR_clear_error(); int32_t bytes = SSL_write(ssl_, &buf[written], len - written); if (bytes <= 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, bytes); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + continue; + default:;// do nothing + } string errors; buildErrors(errors, errno_copy); throw TSSLException("SSL_write: " + errors); @@ -319,13 +414,76 @@ void TSSLSocket::checkHandshake() { if (ssl_ != NULL) { return; } + + // set underlying socket to non-blocking + int flags; + if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0 + || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) { + GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ", + THRIFT_GET_SOCKET_ERROR); + ::THRIFT_CLOSESOCKET(socket_); + return; + } + ssl_ = ctx_->createSSL(); + + //set read and write bios to non-blocking + BIO* wbio = BIO_new(BIO_s_mem()); + if (wbio == NULL) { + throw TSSLException("SSL_get_wbio returns NULL"); + } + BIO_set_nbio(wbio, 1); + + BIO* rbio = BIO_new(BIO_s_mem()); + if (rbio == NULL) { + throw TSSLException("SSL_get_rbio returns NULL"); + } + BIO_set_nbio(rbio, 1); + + SSL_set_bio(ssl_, rbio, wbio); + SSL_set_fd(ssl_, static_cast<int>(socket_)); int rc; if (server()) { - rc = SSL_accept(ssl_); + do { + rc = SSL_accept(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); } else { - rc = SSL_connect(ssl_); + do { + rc = SSL_connect(ssl_); + if (rc <= 0) { + int errno_copy = THRIFT_GET_SOCKET_ERROR; + int error = SSL_get_error(ssl_, rc); + switch (error) { + case SSL_ERROR_SYSCALL: + if ((errno_copy != THRIFT_EINTR) + || (errno_copy != THRIFT_EAGAIN)) { + break; + } + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + waitForEvent(error == SSL_ERROR_WANT_READ); + rc = 2; + default:;// do nothing + } + } + } while (rc == 2); } if (rc <= 0) { int errno_copy = THRIFT_GET_SOCKET_ERROR; @@ -443,6 +601,54 @@ void TSSLSocket::authorize() { } } +unsigned int TSSLSocket::waitForEvent(bool wantRead) { + int fdSocket; + BIO* bio; + + if (wantRead) { + bio = SSL_get_rbio(ssl_); + } else { + bio = SSL_get_wbio(ssl_); + } + + if (bio == NULL) { + throw TSSLException("SSL_get_?bio returned NULL"); + } + + if (BIO_get_fd(bio, &fdSocket) <= 0) { + throw TSSLException("BIO_get_fd failed"); + } + + struct THRIFT_POLLFD fds[2]; + std::memset(fds, 0, sizeof(fds)); + fds[0].fd = fdSocket; + fds[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLOUT; + + if (interruptListener_) { + fds[1].fd = *(interruptListener_.get()); + fds[1].events = THRIFT_POLLIN; + } + + int ret = THRIFT_POLL(fds, interruptListener_ ? 2 : 1, -1); + + if (ret < 0) { + // error cases + if (THRIFT_GET_SOCKET_ERROR == THRIFT_EINTR) { + return TSSL_EINTR; // repeat operation + } + int errno_copy = THRIFT_GET_SOCKET_ERROR; + GlobalOutput.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } else if (ret > 0){ + if (fds[1].revents & THRIFT_POLLIN) { + throw TTransportException(TTransportException::INTERRUPTED, "Interrupted"); + } + return TSSL_DATA; + } else { + throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)"); + } +} + // TSSLSocketFactory implementation uint64_t TSSLSocketFactory::count_ = 0; Mutex TSSLSocketFactory::mutex_; @@ -475,18 +681,37 @@ boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() { return ssl; } +boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener) { + boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, interruptListener)); + setup(ssl); + return ssl; +} + boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) { boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket)); setup(ssl); return ssl; } +boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener) { + boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket, interruptListener)); + setup(ssl); + return ssl; +} + boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port) { boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port)); setup(ssl); return ssl; } +boost::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener) { + boost::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port, interruptListener)); + setup(ssl); + return ssl; +} + + void TSSLSocketFactory::setup(boost::shared_ptr<TSSLSocket> ssl) { ssl->server(server()); if (access_ == NULL && !server()) { http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TSSLSocket.h ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/transport/TSSLSocket.h b/lib/cpp/src/thrift/transport/TSSLSocket.h index 02d414b..ba8abf4 100644 --- a/lib/cpp/src/thrift/transport/TSSLSocket.h +++ b/lib/cpp/src/thrift/transport/TSSLSocket.h @@ -43,6 +43,9 @@ enum SSLProtocol { LATEST = TLSv1_2 }; +#define TSSL_EINTR 0 +#define TSSL_DATA 1 + /** * Initialize OpenSSL library. This function, or some other * equivalent function to initialize OpenSSL, must be called before @@ -99,18 +102,35 @@ protected: */ TSSLSocket(boost::shared_ptr<SSLContext> ctx); /** + * Constructor with an interrupt signal. + */ + TSSLSocket(boost::shared_ptr<SSLContext> ctx, boost::shared_ptr<THRIFT_SOCKET> interruptListener); + /** * Constructor, create an instance of TSSLSocket given an existing socket. * * @param socket An existing socket */ TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket); /** + * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted. + * + * @param socket An existing socket + */ + TSSLSocket(boost::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener); + /** * Constructor. * * @param host Remote host name * @param port Remote port number */ TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port); + /** + * Constructor with an interrupt signal. + * + * @param host Remote host name + * @param port Remote port number + */ + TSSLSocket(boost::shared_ptr<SSLContext> ctx, std::string host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener); /** * Authorize peer access after SSL handshake completes. */ @@ -119,6 +139,15 @@ protected: * Initiate SSL handshake if not already initiated. */ void checkHandshake(); + /** + * Waits for an socket or shutdown event. + * + * @throw TTransportException::INTERRUPTED if interrupted is signaled. + * + * @return TSSL_EINTR if EINTR happened on the underlying socket + * TSSL_DATA if data is available on the socket. + */ + unsigned int waitForEvent(bool wantRead); bool server_; SSL* ssl_; @@ -144,12 +173,22 @@ public: */ virtual boost::shared_ptr<TSSLSocket> createSocket(); /** + * Create an instance of TSSLSocket with a fresh new socket, which is interruptable. + */ + virtual boost::shared_ptr<TSSLSocket> createSocket(boost::shared_ptr<THRIFT_SOCKET> interruptListener); + /** * Create an instance of TSSLSocket with the given socket. * * @param socket An existing socket. */ virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket); /** + * Create an instance of TSSLSocket with the given socket which is interruptable. + * + * @param socket An existing socket. + */ + virtual boost::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket, boost::shared_ptr<THRIFT_SOCKET> interruptListener); + /** * Create an instance of TSSLSocket. * * @param host Remote host to be connected to @@ -157,6 +196,13 @@ public: */ virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port); /** + * Create an instance of TSSLSocket. + * + * @param host Remote host to be connected to + * @param port Remote port to be connected to + */ + virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port, boost::shared_ptr<THRIFT_SOCKET> interruptListener); + /** * Set ciphers to be used in SSL handshake process. * * @param ciphers A list of ciphers http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TServerSocket.cpp ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp index 215cda6..137dc32 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.cpp +++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp @@ -119,7 +119,8 @@ using namespace std; using boost::shared_ptr; TServerSocket::TServerSocket(int port) - : port_(port), + : interruptableChildren_(true), + port_(port), serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG), sendTimeout_(0), @@ -130,7 +131,6 @@ TServerSocket::TServerSocket(int port) tcpSendBuffer_(0), tcpRecvBuffer_(0), keepAlive_(false), - interruptableChildren_(true), listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), @@ -138,7 +138,8 @@ TServerSocket::TServerSocket(int port) } TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) - : port_(port), + : interruptableChildren_(true), + port_(port), serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG), sendTimeout_(sendTimeout), @@ -149,7 +150,6 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) tcpSendBuffer_(0), tcpRecvBuffer_(0), keepAlive_(false), - interruptableChildren_(true), listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), @@ -157,7 +157,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) } TServerSocket::TServerSocket(const string& address, int port) - : port_(port), + : interruptableChildren_(true), + port_(port), address_(address), serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG), @@ -169,7 +170,6 @@ TServerSocket::TServerSocket(const string& address, int port) tcpSendBuffer_(0), tcpRecvBuffer_(0), keepAlive_(false), - interruptableChildren_(true), listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), @@ -177,7 +177,8 @@ TServerSocket::TServerSocket(const string& address, int port) } TServerSocket::TServerSocket(const string& path) - : port_(0), + : interruptableChildren_(true), + port_(0), path_(path), serverSocket_(THRIFT_INVALID_SOCKET), acceptBacklog_(DEFAULT_BACKLOG), @@ -189,7 +190,6 @@ TServerSocket::TServerSocket(const string& path) tcpSendBuffer_(0), tcpRecvBuffer_(0), keepAlive_(false), - interruptableChildren_(true), listening_(false), interruptSockWriter_(THRIFT_INVALID_SOCKET), interruptSockReader_(THRIFT_INVALID_SOCKET), http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/src/thrift/transport/TServerSocket.h ---------------------------------------------------------------------- diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h index 58e4afd..20a37e7 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.h +++ b/lib/cpp/src/thrift/transport/TServerSocket.h @@ -123,6 +123,8 @@ public: protected: boost::shared_ptr<TTransport> acceptImpl(); virtual boost::shared_ptr<TSocket> createSocket(THRIFT_SOCKET client); + bool interruptableChildren_; + boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets private: void notify(THRIFT_SOCKET notifySock); @@ -140,13 +142,11 @@ private: int tcpSendBuffer_; int tcpRecvBuffer_; bool keepAlive_; - bool interruptableChildren_; bool listening_; THRIFT_SOCKET interruptSockWriter_; // is notified on interrupt() THRIFT_SOCKET interruptSockReader_; // is used in select/poll with serverSocket_ for interruptability THRIFT_SOCKET childInterruptSockWriter_; // is notified on interruptChildren() - boost::shared_ptr<THRIFT_SOCKET> pChildInterruptSockReader_; // if interruptableChildren_ this is shared with child TSockets socket_func_t listenCallback_; socket_func_t acceptCallback_; http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/lib/cpp/test/CMakeLists.txt b/lib/cpp/test/CMakeLists.txt index 02932cb..491d343 100644 --- a/lib/cpp/test/CMakeLists.txt +++ b/lib/cpp/test/CMakeLists.txt @@ -92,7 +92,10 @@ if ( MSVC ) endif ( MSVC ) -set( TInterruptTest_SOURCES TSocketInterruptTest.cpp ) +set( TInterruptTest_SOURCES + TSocketInterruptTest.cpp + TSSLSocketInterruptTest.cpp +) if (WIN32) list(APPEND TInterruptTest_SOURCES TPipeInterruptTest.cpp @@ -108,7 +111,7 @@ LINK_AGAINST_THRIFT_LIBRARY(TInterruptTest thrift) if (NOT MSVC AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") target_link_libraries(TInterruptTest -lrt) endif () -add_test(NAME TInterruptTest COMMAND TInterruptTest) +add_test(NAME TInterruptTest COMMAND TInterruptTest "${CMAKE_CURRENT_SOURCE_DIR}/../../../test/keys") add_executable(TServerIntegrationTest TServerIntegrationTest.cpp) target_link_libraries(TServerIntegrationTest http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/Makefile.am ---------------------------------------------------------------------- diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am index 4b7a99f..1895afc 100755 --- a/lib/cpp/test/Makefile.am +++ b/lib/cpp/test/Makefile.am @@ -128,11 +128,13 @@ UnitTests_LDADD = \ $(BOOST_TEST_LDADD) TInterruptTest_SOURCES = \ - TSocketInterruptTest.cpp + TSocketInterruptTest.cpp \ + TSSLSocketInterruptTest.cpp TInterruptTest_LDADD = \ libtestgencpp.la \ $(BOOST_TEST_LDADD) \ + $(BOOST_FILESYSTEM_LDADD) \ $(BOOST_CHRONO_LDADD) \ $(BOOST_SYSTEM_LDADD) \ $(BOOST_THREAD_LDADD) http://git-wip-us.apache.org/repos/asf/thrift/blob/0ad6ee95/lib/cpp/test/TSSLSocketInterruptTest.cpp ---------------------------------------------------------------------- diff --git a/lib/cpp/test/TSSLSocketInterruptTest.cpp b/lib/cpp/test/TSSLSocketInterruptTest.cpp new file mode 100644 index 0000000..c723d0e --- /dev/null +++ b/lib/cpp/test/TSSLSocketInterruptTest.cpp @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <boost/test/auto_unit_test.hpp> +#include <boost/test/unit_test_suite.hpp> +#include <boost/bind.hpp> +#include <boost/chrono/duration.hpp> +#include <boost/date_time/posix_time/posix_time_duration.hpp> +#include <boost/thread/thread.hpp> +#include <boost/filesystem.hpp> +#include <boost/format.hpp> +#include <boost/shared_ptr.hpp> +#include <thrift/transport/TSSLSocket.h> +#include <thrift/transport/TSSLServerSocket.h> +#include "TestPortFixture.h" +#ifdef __linux__ +#include <signal.h> +#endif + +using apache::thrift::transport::TSSLServerSocket; +using apache::thrift::transport::TSSLSocket; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportException; +using apache::thrift::transport::TSSLSocketFactory; + +boost::filesystem::path keyDir; +boost::filesystem::path certFile(const std::string& filename) +{ + return keyDir / filename; +} +boost::mutex gMutex; + +struct GlobalFixtureSSL +{ + GlobalFixtureSSL() + { + using namespace boost::unit_test::framework; + for (int i = 0; i < master_test_suite().argc; ++i) + { + BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]); + } + +#ifdef __linux__ + // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has + // disconnected can cause a SIGPIPE signal... + signal(SIGPIPE, SIG_IGN); +#endif + + TSSLSocketFactory::setManualOpenSSLInitialization(true); + apache::thrift::transport::initializeOpenSSL(); + + keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys"; + if (!boost::filesystem::exists(certFile("server.crt"))) + { + keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]); + if (!boost::filesystem::exists(certFile("server.crt"))) + { + throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s)."); + } + } + } + + virtual ~GlobalFixtureSSL() + { + apache::thrift::transport::cleanupOpenSSL(); +#ifdef __linux__ + signal(SIGPIPE, SIG_DFL); +#endif + } +}; + +#if (BOOST_VERSION >= 105900) +BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL); +#else +BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL) +#endif + +BOOST_FIXTURE_TEST_SUITE(TSSLSocketInterruptTest, TestPortFixture) + +void readerWorker(boost::shared_ptr<TTransport> tt, uint32_t expectedResult) { + uint8_t buf[4]; + try { + tt->read(buf, 1); + BOOST_CHECK_EQUAL(expectedResult, tt->read(buf, 4)); + } catch (const TTransportException& tx) { + BOOST_CHECK_EQUAL(TTransportException::INTERNAL_ERROR, tx.getType()); + } +} + +void readerWorkerMustThrow(boost::shared_ptr<TTransport> tt) { + try { + uint8_t buf[400]; + tt->read(buf, 1); + tt->read(buf, 400); + BOOST_ERROR("should not have gotten here"); + } catch (const TTransportException& tx) { + BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, tx.getType()); + } +} + +boost::shared_ptr<TSSLSocketFactory> createServerSocketFactory() { + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory; + + pServerSocketFactory.reset(new TSSLSocketFactory()); + pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + pServerSocketFactory->loadCertificate(certFile("server.crt").native().c_str()); + pServerSocketFactory->loadPrivateKey(certFile("server.key").native().c_str()); + pServerSocketFactory->server(true); + return pServerSocketFactory; +} + +boost::shared_ptr<TSSLSocketFactory> createClientSocketFactory() { + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory; + + pClientSocketFactory.reset(new TSSLSocketFactory()); + pClientSocketFactory->authenticate(true); + pClientSocketFactory->loadCertificate(certFile("client.crt").native().c_str()); + pClientSocketFactory->loadPrivateKey(certFile("client.key").native().c_str()); + pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").native().c_str()); + return pClientSocketFactory; +} + +BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read_while_handshaking) { + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.listen(); + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory(); + boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort); + clientSock->open(); + boost::shared_ptr<TTransport> accepted = sock1.accept(); + boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted)); + boost::this_thread::sleep(boost::posix_time::milliseconds(50)); + // readThread is practically guaranteed to be blocking now + sock1.interruptChildren(); + BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)), + "server socket interruptChildren did not interrupt child read"); + clientSock->close(); + accepted->close(); + sock1.close(); +} + +BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read) { + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.listen(); + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory(); + boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort); + clientSock->open(); + boost::shared_ptr<TTransport> accepted = sock1.accept(); + boost::thread readThread(boost::bind(readerWorkerMustThrow, accepted)); + clientSock->write((const uint8_t*)"0", 1); + boost::this_thread::sleep(boost::posix_time::milliseconds(50)); + // readThread is practically guaranteed to be blocking now + sock1.interruptChildren(); + BOOST_CHECK_MESSAGE(readThread.try_join_for(boost::chrono::milliseconds(20)), + "server socket interruptChildren did not interrupt child read"); + accepted->close(); + clientSock->close(); + sock1.close(); +} + +BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_read) { + std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl; + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior + sock1.listen(); + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory(); + boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort); + clientSock->open(); + boost::shared_ptr<TTransport> accepted = sock1.accept(); + boost::thread readThread(boost::bind(readerWorker, accepted, 0)); + clientSock->write((const uint8_t*)"0", 1); + boost::this_thread::sleep(boost::posix_time::milliseconds(50)); + // readThread is practically guaranteed to be blocking here + sock1.interruptChildren(); + BOOST_CHECK_MESSAGE(!readThread.try_join_for(boost::chrono::milliseconds(200)), + "server socket interruptChildren interrupted child read"); + + // only way to proceed is to have the client disconnect + clientSock->close(); + readThread.join(); + accepted->close(); + sock1.close(); +} + +BOOST_AUTO_TEST_CASE(test_ssl_cannot_change_after_listen) { + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.listen(); + BOOST_CHECK_THROW(sock1.setInterruptableChildren(false), std::logic_error); + sock1.close(); +} + +void peekerWorker(boost::shared_ptr<TTransport> tt, bool expectedResult) { + uint8_t buf[400]; + + tt->read(buf, 1); + BOOST_CHECK_EQUAL(expectedResult, tt->peek()); +} + +void peekerWorkerInterrupt(boost::shared_ptr<TTransport> tt) { + uint8_t buf[400]; + try { + tt->read(buf, 1); + tt->peek(); + } catch (const TTransportException& tx) { + BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED, tx.getType()); + } +} + +BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_peek) { + std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl; + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.listen(); + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory(); + boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort); + clientSock->open(); + boost::shared_ptr<TTransport> accepted = sock1.accept(); + // peek() will return false if child is interrupted + boost::thread peekThread(boost::bind(peekerWorkerInterrupt, accepted)); + clientSock->write((const uint8_t*)"0", 1); + boost::this_thread::sleep(boost::posix_time::milliseconds(50)); + // peekThread is practically guaranteed to be blocking now + sock1.interruptChildren(); + BOOST_CHECK_MESSAGE(peekThread.try_join_for(boost::chrono::milliseconds(200)), + "server socket interruptChildren did not interrupt child peek"); +#ifdef __linux__ + signal(SIGPIPE, SIG_IGN); +#endif + clientSock->close(); + accepted->close(); + sock1.close(); +} + +BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_peek) { + std::cout << "An error message from SSL_Shutdown on the console is expected:" << std::endl; + boost::shared_ptr<TSSLSocketFactory> pServerSocketFactory = createServerSocketFactory(); + TSSLServerSocket sock1("localhost", m_serverPort, pServerSocketFactory); + sock1.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior + sock1.listen(); + boost::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory(); + boost::shared_ptr<TSSLSocket> clientSock = pClientSocketFactory->createSocket("localhost", m_serverPort); + clientSock->open(); + boost::shared_ptr<TTransport> accepted = sock1.accept(); + // peek() will return false when remote side is closed + boost::thread peekThread(boost::bind(peekerWorker, accepted, false)); + //boost::thread peekThread(boost::bind(peekerWorkerRead, clientSock, false)); + clientSock->write((const uint8_t*)"0", 1); + boost::this_thread::sleep(boost::posix_time::milliseconds(50)); + // peekThread is practically guaranteed to be blocking now + sock1.interruptChildren(); + BOOST_CHECK_MESSAGE(!peekThread.try_join_for(boost::chrono::milliseconds(200)), + "server socket interruptChildren interrupted child peek"); + + // only way to proceed is to have the client disconnect +#ifdef __linux__ + signal(SIGPIPE, SIG_IGN); +#endif + clientSock->close(); + peekThread.join(); + accepted->close(); + sock1.close(); +} + +BOOST_AUTO_TEST_SUITE_END()
