This is an automated email from the ASF dual-hosted git repository. szaszm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/nifi-minifi-cpp.git
commit 0553df446aefb395b34c1e4e2b5ced31f07a4e04 Author: Adam Debreceni <[email protected]> AuthorDate: Wed Dec 1 14:23:32 2021 +0100 MINIFICPP-1692 TLSSocket: Break infinite loop when no more data can be read Closes #1218 Signed-off-by: Marton Szasz <[email protected]> --- cmake/BuildTests.cmake | 13 +++ .../TLSClientSocketSupportedProtocolsTest.cpp | 118 ++++--------------- libminifi/src/io/tls/TLSSocket.cpp | 4 +- libminifi/test/SimpleSSLTestServer.h | 130 +++++++++++++++++++++ libminifi/test/unit/tls/TLSStreamTests.cpp | 82 +++++++++++++ 5 files changed, 247 insertions(+), 100 deletions(-) diff --git a/cmake/BuildTests.cmake b/cmake/BuildTests.cmake index ecb5c70..6d5402e 100644 --- a/cmake/BuildTests.cmake +++ b/cmake/BuildTests.cmake @@ -101,6 +101,7 @@ target_include_directories(${CATCH_MAIN_LIB} SYSTEM BEFORE PRIVATE "${CMAKE_SOUR SET(TEST_RESOURCES ${TEST_DIR}/resources) GETSOURCEFILES(UNIT_TESTS "${TEST_DIR}/unit/") +GETSOURCEFILES(TLS_UNIT_TESTS "${TEST_DIR}/unit/tls/") GETSOURCEFILES(NANOFI_UNIT_TESTS "${NANOFI_TEST_DIR}") GETSOURCEFILES(INTEGRATION_TESTS "${TEST_DIR}/integration/") @@ -115,6 +116,18 @@ FOREACH(testfile ${UNIT_TESTS}) ENDFOREACH() message("-- Finished building ${UNIT_TEST_COUNT} unit test file(s)...") +if (NOT OPENSSL_OFF) + SET(UNIT_TEST_COUNT 0) + FOREACH(testfile ${TLS_UNIT_TESTS}) + get_filename_component(testfilename "${testfile}" NAME_WE) + add_executable("${testfilename}" "${TEST_DIR}/unit/tls/${testfile}") + createTests("${testfilename}") + MATH(EXPR UNIT_TEST_COUNT "${UNIT_TEST_COUNT}+1") + add_test(NAME "${testfilename}" COMMAND "${testfilename}" "${TEST_RESOURCES}/" WORKING_DIRECTORY ${TEST_DIR}) + ENDFOREACH() + message("-- Finished building ${UNIT_TEST_COUNT} TLS unit test file(s)...") +endif() + if(NOT WIN32 AND ENABLE_NANOFI) SET(UNIT_TEST_COUNT 0) FOREACH(testfile ${NANOFI_UNIT_TESTS}) diff --git a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp index 928bc56..6cc643c 100644 --- a/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp +++ b/extensions/standard-processors/tests/integration/TLSClientSocketSupportedProtocolsTest.cpp @@ -19,6 +19,7 @@ #include <sys/stat.h> #include <chrono> #include <thread> +#include <filesystem> #undef NDEBUG #include <cassert> #include <utility> @@ -26,114 +27,34 @@ #include <string> #include "properties/Configure.h" #include "io/tls/TLSSocket.h" +#include "SimpleSSLTestServer.h" namespace minifi = org::apache::nifi::minifi; -#ifdef WIN32 -#pragma comment(lib, "Ws2_32.lib") -using SocketDescriptor = SOCKET; -#else -using SocketDescriptor = int; -static constexpr SocketDescriptor INVALID_SOCKET = -1; -#endif /* WIN32 */ - - -class SimpleSSLTestServer { - public: - SimpleSSLTestServer(const SSL_METHOD* method, const std::string& port, const std::string& path) - : port_(port), had_connection_(false) { - ctx_ = SSL_CTX_new(method); - configureContext(path); - socket_descriptor_ = createSocket(std::stoi(port_)); - } - - ~SimpleSSLTestServer() { - SSL_shutdown(ssl_); - SSL_free(ssl_); - SSL_CTX_free(ctx_); - } - - void waitForConnection() { - server_read_thread_ = std::thread([this]() -> void { - SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr); - if (client != INVALID_SOCKET) { - ssl_ = SSL_new(ctx_); - SSL_set_fd(ssl_, client); - had_connection_ = (SSL_accept(ssl_) == 1); - } - }); - } - - void shutdownServer() { -#ifdef WIN32 - shutdown(socket_descriptor_, SD_BOTH); - closesocket(socket_descriptor_); -#else - shutdown(socket_descriptor_, SHUT_RDWR); - close(socket_descriptor_); -#endif - server_read_thread_.join(); - } - - bool hadConnection() const { - return had_connection_; - } - - private: - SSL_CTX *ctx_ = nullptr; - SSL* ssl_ = nullptr; - std::string port_; - SocketDescriptor socket_descriptor_; - bool had_connection_; - std::thread server_read_thread_; - - void configureContext(const std::string& path) { - SSL_CTX_set_ecdh_auto(ctx_, 1); - /* Set the key and cert */ - assert(SSL_CTX_use_certificate_file(ctx_, (path + "cn.crt.pem").c_str(), SSL_FILETYPE_PEM) == 1); - assert(SSL_CTX_use_PrivateKey_file(ctx_, (path + "cn.ckey.pem").c_str(), SSL_FILETYPE_PEM) == 1); - } - - static SocketDescriptor createSocket(int port) { - struct sockaddr_in addr; - - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - addr.sin_addr.s_addr = htonl(INADDR_ANY); - - SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0); - assert(socket_descriptor >= 0); - assert(bind(socket_descriptor, (struct sockaddr*)&addr, sizeof(addr)) >= 0); - assert(listen(socket_descriptor, 1) >= 0); - - return socket_descriptor; - } -}; - class SimpleSSLTestServerTLSv1 : public SimpleSSLTestServer { public: - SimpleSSLTestServerTLSv1(const std::string& port, const std::string& path) - : SimpleSSLTestServer(TLSv1_server_method(), port, path) { + SimpleSSLTestServerTLSv1(int port, const std::filesystem::path& key_dir) + : SimpleSSLTestServer(TLSv1_server_method(), port, key_dir) { } }; class SimpleSSLTestServerTLSv1_1 : public SimpleSSLTestServer { public: - SimpleSSLTestServerTLSv1_1(const std::string& port, const std::string& path) - : SimpleSSLTestServer(TLSv1_1_server_method(), port, path) { + SimpleSSLTestServerTLSv1_1(int port, const std::filesystem::path& key_dir) + : SimpleSSLTestServer(TLSv1_1_server_method(), port, key_dir) { } }; class SimpleSSLTestServerTLSv1_2 : public SimpleSSLTestServer { public: - SimpleSSLTestServerTLSv1_2(const std::string& port, const std::string& path) - : SimpleSSLTestServer(TLSv1_2_server_method(), port, path) { + SimpleSSLTestServerTLSv1_2(int port, const std::filesystem::path& key_dir) + : SimpleSSLTestServer(TLSv1_2_server_method(), port, key_dir) { } }; class TLSClientSocketSupportedProtocolsTest { public: - explicit TLSClientSocketSupportedProtocolsTest(const std::string& key_dir) + explicit TLSClientSocketSupportedProtocolsTest(const std::filesystem::path& key_dir) : key_dir_(key_dir), configuration_(std::make_shared<minifi::Configure>()) { } @@ -147,14 +68,13 @@ class TLSClientSocketSupportedProtocolsTest { protected: void configureSecurity() { host_ = minifi::io::Socket::getMyHostName(); - port_ = "38777"; if (!key_dir_.empty()) { configuration_->set(minifi::Configure::nifi_remote_input_secure, "true"); - configuration_->set(minifi::Configure::nifi_security_client_certificate, key_dir_ + "cn.crt.pem"); - configuration_->set(minifi::Configure::nifi_security_client_private_key, key_dir_ + "cn.ckey.pem"); - configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, key_dir_ + "cn.pass"); - configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, key_dir_ + "nifi-cert.pem"); - configuration_->set(minifi::Configure::nifi_default_directory, key_dir_); + configuration_->set(minifi::Configure::nifi_security_client_certificate, (key_dir_ / "cn.crt.pem").string()); + configuration_->set(minifi::Configure::nifi_security_client_private_key, (key_dir_ / "cn.ckey.pem").string()); + configuration_->set(minifi::Configure::nifi_security_client_pass_phrase, (key_dir_ / "cn.pass").string()); + configuration_->set(minifi::Configure::nifi_security_client_ca_certificate, (key_dir_ / "nifi-cert.pem").string()); + configuration_->set(minifi::Configure::nifi_default_directory, key_dir_.string()); } } @@ -166,11 +86,14 @@ class TLSClientSocketSupportedProtocolsTest { template <class TLSTestSever> void verifyTLSProtocolCompatibility(const bool should_be_compatible) { - TLSTestSever server(port_, key_dir_); + // bind to random port + TLSTestSever server(0, key_dir_); server.waitForConnection(); + int port = server.getPort(); + const auto socket_context = std::make_shared<minifi::io::TLSContext>(configuration_); - client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, std::stoi(port_), 0); + client_socket_ = std::make_unique<minifi::io::TLSSocket>(socket_context, host_, port, 0); const bool client_initialized_successfully = (client_socket_->initialize() == 0); assert(client_initialized_successfully == should_be_compatible); server.shutdownServer(); @@ -180,8 +103,7 @@ class TLSClientSocketSupportedProtocolsTest { protected: std::unique_ptr<minifi::io::TLSSocket> client_socket_; std::string host_; - std::string port_; - std::string key_dir_; + std::filesystem::path key_dir_; std::shared_ptr<minifi::Configure> configuration_; }; diff --git a/libminifi/src/io/tls/TLSSocket.cpp b/libminifi/src/io/tls/TLSSocket.cpp index 5d76e8c..af8772a 100644 --- a/libminifi/src/io/tls/TLSSocket.cpp +++ b/libminifi/src/io/tls/TLSSocket.cpp @@ -434,9 +434,9 @@ size_t TLSSocket::read(uint8_t *buf, size_t buflen) { const auto ssl_read_size = gsl::narrow<int>(std::min(buflen, gsl::narrow<size_t>(std::numeric_limits<int>::max()))); status = SSL_read(fd_ssl, buf, ssl_read_size); sslStatus = SSL_get_error(fd_ssl, status); - } while (status < 0 && sslStatus == SSL_ERROR_WANT_READ); + } while (status <= 0 && sslStatus == SSL_ERROR_WANT_READ); - if (status < 0) + if (status <= 0) break; buflen -= gsl::narrow<size_t>(status); diff --git a/libminifi/test/SimpleSSLTestServer.h b/libminifi/test/SimpleSSLTestServer.h new file mode 100644 index 0000000..b6cecf5 --- /dev/null +++ b/libminifi/test/SimpleSSLTestServer.h @@ -0,0 +1,130 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <openssl/ssl.h> +#include <openssl/err.h> +#include <filesystem> +#include <string> +#include "io/tls/TLSSocket.h" + +#ifdef WIN32 +#include <winsock2.h> +#include <ws2tcpip.h> +#pragma comment(lib, "Ws2_32.lib") +using SocketDescriptor = SOCKET; +#else +using SocketDescriptor = int; +static constexpr SocketDescriptor INVALID_SOCKET = -1; +#endif /* WIN32 */ + +namespace minifi = org::apache::nifi::minifi; + +class SimpleSSLTestServer { + struct SocketInitializer { + SocketInitializer() { +#ifdef WIN32 + static WSADATA s_wsaData; + const int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData); + if (0 != iWinSockInitResult) { + throw std::runtime_error("Cannot initialize socket"); + } +#endif + } + }; + + public: + SimpleSSLTestServer(const SSL_METHOD* method, int port, const std::filesystem::path& key_dir) + : port_(port), had_connection_(false) { + static SocketInitializer socket_initializer{}; + minifi::io::OpenSSLInitializer::getInstance(); + ctx_ = SSL_CTX_new(method); + configureContext(key_dir); + socket_descriptor_ = createSocket(port_); + } + + ~SimpleSSLTestServer() { + SSL_shutdown(ssl_); + SSL_free(ssl_); + SSL_CTX_free(ctx_); + } + + void waitForConnection() { + server_read_thread_ = std::thread([this]() -> void { + SocketDescriptor client = accept(socket_descriptor_, nullptr, nullptr); + if (client != INVALID_SOCKET) { + ssl_ = SSL_new(ctx_); + SSL_set_fd(ssl_, client); + had_connection_ = (SSL_accept(ssl_) == 1); + } + }); + } + + void shutdownServer() { +#ifdef WIN32 + shutdown(socket_descriptor_, SD_BOTH); + closesocket(socket_descriptor_); +#else + shutdown(socket_descriptor_, SHUT_RDWR); + close(socket_descriptor_); +#endif + server_read_thread_.join(); + } + + bool hadConnection() const { + return had_connection_; + } + + int getPort() const { + struct sockaddr_in addr; + socklen_t addr_len = sizeof(addr); + assert(getsockname(socket_descriptor_, (struct sockaddr*)&addr, &addr_len) == 0); + return ntohs(addr.sin_port); + } + + private: + SSL_CTX *ctx_ = nullptr; + SSL* ssl_ = nullptr; + int port_; + SocketDescriptor socket_descriptor_; + bool had_connection_; + std::thread server_read_thread_; + + void configureContext(const std::filesystem::path& key_dir) { + SSL_CTX_set_ecdh_auto(ctx_, 1); + /* Set the key and cert */ + assert(SSL_CTX_use_certificate_file(ctx_, (key_dir / "cn.crt.pem").string().c_str(), SSL_FILETYPE_PEM) == 1); + assert(SSL_CTX_use_PrivateKey_file(ctx_, (key_dir / "cn.ckey.pem").string().c_str(), SSL_FILETYPE_PEM) == 1); + } + + static SocketDescriptor createSocket(int port) { + struct sockaddr_in addr; + + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = htonl(INADDR_ANY); + + SocketDescriptor socket_descriptor = socket(AF_INET, SOCK_STREAM, 0); + assert(socket_descriptor >= 0); + assert(bind(socket_descriptor, (struct sockaddr*)&addr, sizeof(addr)) >= 0); + assert(listen(socket_descriptor, 1) >= 0); + + return socket_descriptor; + } +}; diff --git a/libminifi/test/unit/tls/TLSStreamTests.cpp b/libminifi/test/unit/tls/TLSStreamTests.cpp new file mode 100644 index 0000000..9fc5939 --- /dev/null +++ b/libminifi/test/unit/tls/TLSStreamTests.cpp @@ -0,0 +1,82 @@ +/** + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef LOAD_EXTENSIONS +#undef NDEBUG + +#include <cassert> + +#include "io/tls/TLSServerSocket.h" +#include "io/tls/TLSSocket.h" +#include "../../TestBase.h" +#include "../../SimpleSSLTestServer.h" +#include "../utils/IntegrationTestUtils.h" + +using namespace std::chrono_literals; + +static std::shared_ptr<minifi::io::TLSContext> createContext(const std::filesystem::path& key_dir) { + auto configuration = std::make_shared<minifi::Configure>(); + configuration->set(minifi::Configure::nifi_remote_input_secure, "true"); + configuration->set(minifi::Configure::nifi_security_client_certificate, (key_dir / "cn.crt.pem").string()); + configuration->set(minifi::Configure::nifi_security_client_private_key, (key_dir / "cn.ckey.pem").string()); + configuration->set(minifi::Configure::nifi_security_client_pass_phrase, (key_dir / "cn.pass").string()); + configuration->set(minifi::Configure::nifi_security_client_ca_certificate, (key_dir / "nifi-cert.pem").string()); + configuration->set(minifi::Configure::nifi_default_directory, key_dir.string()); + + return std::make_shared<minifi::io::TLSContext>(configuration); +} + +int main(int argc, char** argv) { + if (argc < 2) { + throw std::logic_error("Specify the key directory"); + } + std::filesystem::path key_dir(argv[1]); + + LogTestController::getInstance().setTrace<minifi::io::Socket>(); + LogTestController::getInstance().setTrace<minifi::io::TLSSocket>(); + LogTestController::getInstance().setTrace<minifi::io::TLSServerSocket>(); + LogTestController::getInstance().setTrace<minifi::io::TLSContext>(); + + auto server = std::make_unique<SimpleSSLTestServer>(TLSv1_2_server_method(), 0, key_dir); + int port = server->getPort(); + server->waitForConnection(); + + std::string host = minifi::io::Socket::getMyHostName(); + + auto client_ctx = createContext(key_dir); + assert(client_ctx->initialize(false) == 0); + + minifi::io::TLSSocket client_socket(client_ctx, host, port); + assert(client_socket.initialize() == 0); + + std::atomic_bool read_complete{false}; + + std::thread read_thread{[&] { + std::vector<uint8_t> buffer; + auto read_count = client_socket.read(buffer, 10); + assert(read_count == 0); + read_complete = true; + }}; + + server->shutdownServer(); + server.reset(); + + assert(utils::verifyEventHappenedInPollTime(1s, [&] {return read_complete.load();})); + + read_thread.join(); +}
