This is an automated email from the ASF dual-hosted git repository. alexey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/kudu.git
commit bdeaa122f6b14c41b3640ab5d575b11d4c5ffe1c Author: Alexey Serbin <[email protected]> AuthorDate: Wed Mar 17 13:21:00 2021 -0700 [security] small re-factoring on TlsHandshake A follow-up patch needs to know what side the TlsHandshake is running at. This changelist introduces that piece of information into the TlsHandshake class and also does other related re-factoring in the related code. Change-Id: Id8dfb16d0f59dc467d1d4dcdf5c1bc568062087e Reviewed-on: http://gerrit.cloudera.org:8080/17193 Tested-by: Kudu Jenkins Reviewed-by: Grant Henke <[email protected]> Reviewed-by: Attila Bukor <[email protected]> --- src/kudu/rpc/client_negotiation.cc | 4 +-- src/kudu/rpc/server_negotiation.cc | 4 +-- src/kudu/security/tls_context.cc | 28 ++++++------------- src/kudu/security/tls_context.h | 9 +++---- src/kudu/security/tls_handshake-test.cc | 18 ++++++------- src/kudu/security/tls_handshake.cc | 48 ++++++++++++++++++++++++++++++--- src/kudu/security/tls_handshake.h | 28 +++++++++---------- src/kudu/security/tls_socket-test.cc | 8 +++--- 8 files changed, 84 insertions(+), 63 deletions(-) diff --git a/src/kudu/rpc/client_negotiation.cc b/src/kudu/rpc/client_negotiation.cc index efa69a1..e5696b9 100644 --- a/src/kudu/rpc/client_negotiation.cc +++ b/src/kudu/rpc/client_negotiation.cc @@ -126,6 +126,7 @@ ClientNegotiation::ClientNegotiation(unique_ptr<Socket> socket, : socket_(std::move(socket)), helper_(SaslHelper::CLIENT), tls_context_(tls_context), + tls_handshake_(security::TlsHandshakeType::CLIENT), encryption_(encryption), tls_negotiated_(false), authn_token_(std::move(authn_token)), @@ -191,8 +192,7 @@ Status ClientNegotiation::Negotiate(unique_ptr<ErrorStatusPB>* rpc_error) { // TODO(KUDU-1921): allow the client to require TLS. if (encryption_ != RpcEncryption::DISABLED && ContainsKey(server_features_, TLS)) { - RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::CLIENT, - &tls_handshake_)); + RETURN_NOT_OK(tls_context_->InitiateHandshake(&tls_handshake_)); if (negotiated_authn_ == AuthenticationType::SASL) { // When using SASL authentication, verifying the server's certificate is diff --git a/src/kudu/rpc/server_negotiation.cc b/src/kudu/rpc/server_negotiation.cc index ac55399..c636856 100644 --- a/src/kudu/rpc/server_negotiation.cc +++ b/src/kudu/rpc/server_negotiation.cc @@ -160,6 +160,7 @@ ServerNegotiation::ServerNegotiation(unique_ptr<Socket> socket, : socket_(std::move(socket)), helper_(SaslHelper::SERVER), tls_context_(tls_context), + tls_handshake_(security::TlsHandshakeType::SERVER), encryption_(encryption), tls_negotiated_(false), token_verifier_(token_verifier), @@ -223,8 +224,7 @@ Status ServerNegotiation::Negotiate() { if (encryption_ != RpcEncryption::DISABLED && tls_context_->has_cert() && ContainsKey(client_features_, TLS)) { - RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::SERVER, - &tls_handshake_)); + RETURN_NOT_OK(tls_context_->InitiateHandshake(&tls_handshake_)); if (negotiated_authn_ != AuthenticationType::CERTIFICATE) { // The server does not need to verify the client's certificate unless it's diff --git a/src/kudu/security/tls_context.cc b/src/kudu/security/tls_context.cc index dcc1ee4..e32e78a 100644 --- a/src/kudu/security/tls_context.cc +++ b/src/kudu/security/tls_context.cc @@ -531,33 +531,21 @@ Status TlsContext::LoadCertificateAuthority(const string& certificate_path) { return AddTrustedCertificate(c); } -Status TlsContext::InitiateHandshake(TlsHandshakeType handshake_type, - TlsHandshake* handshake) const { +Status TlsContext::InitiateHandshake(TlsHandshake* handshake) const { SCOPED_OPENSSL_NO_PENDING_ERRORS; + DCHECK(handshake); CHECK(ctx_); - CHECK(!handshake->ssl_); + c_unique_ptr<SSL> ssl; { + // This lock is to protect against concurrent change of certificates + // while calling SSL_new() here. shared_lock<RWMutex> lock(lock_); - handshake->adopt_ssl(ssl_make_unique(SSL_new(ctx_.get()))); + ssl = ssl_make_unique(SSL_new(ctx_.get())); } - if (!handshake->ssl_) { + if (!ssl) { return Status::RuntimeError("failed to create SSL handle", GetOpenSSLErrors()); } - - SSL_set_bio(handshake->ssl(), - BIO_new(BIO_s_mem()), - BIO_new(BIO_s_mem())); - - switch (handshake_type) { - case TlsHandshakeType::SERVER: - SSL_set_accept_state(handshake->ssl()); - break; - case TlsHandshakeType::CLIENT: - SSL_set_connect_state(handshake->ssl()); - break; - } - - return Status::OK(); + return handshake->Init(std::move(ssl)); } } // namespace security diff --git a/src/kudu/security/tls_context.h b/src/kudu/security/tls_context.h index edf37f3..a13e838 100644 --- a/src/kudu/security/tls_context.h +++ b/src/kudu/security/tls_context.h @@ -26,19 +26,17 @@ #include <boost/optional/optional.hpp> #include "kudu/gutil/port.h" +#include "kudu/security/cert.h" // IWYU pragma: keep #include "kudu/security/openssl_util.h" -#include "kudu/security/tls_handshake.h" #include "kudu/util/locks.h" #include "kudu/util/rw_mutex.h" #include "kudu/util/status.h" -// IWYU pragma: no_include "kudu/security/cert.h" namespace kudu { namespace security { -class Cert; // IWYU pragma: keep -class CertSignRequest;// IWYU pragma: keep class PrivateKey; +class TlsHandshake; // TlsContext wraps data required by the OpenSSL library for creating and // accepting TLS protected channels. A single TlsContext instance should be used @@ -162,8 +160,7 @@ class TlsContext { Status LoadCertificateAuthority(const std::string& certificate_path) WARN_UNUSED_RESULT; // Initiates a new TlsHandshake instance. - Status InitiateHandshake(TlsHandshakeType handshake_type, - TlsHandshake* handshake) const WARN_UNUSED_RESULT; + Status InitiateHandshake(TlsHandshake* handshake) const WARN_UNUSED_RESULT; // Return the number of certs that have been marked as trusted. // Used by tests. diff --git a/src/kudu/security/tls_handshake-test.cc b/src/kudu/security/tls_handshake-test.cc index 5e6a71d..698af49 100644 --- a/src/kudu/security/tls_handshake-test.cc +++ b/src/kudu/security/tls_handshake-test.cc @@ -40,6 +40,7 @@ #include "kudu/util/test_macros.h" #include "kudu/util/test_util.h" +using kudu::security::ca::CertSigner; using std::string; using std::vector; @@ -48,8 +49,6 @@ DECLARE_int32(ipki_server_key_size); namespace kudu { namespace security { -using ca::CertSigner; - struct Case { PkiConfig client_pki; TlsVerificationMode client_verification; @@ -91,9 +90,10 @@ class TestTlsHandshakeBase : public KuduTest { // verification modes are set to 'client_verify' and 'server_verify' respectively. Status RunHandshake(TlsVerificationMode client_verify, TlsVerificationMode server_verify) { - TlsHandshake client, server; - RETURN_NOT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client)); - RETURN_NOT_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server)); + TlsHandshake client(TlsHandshakeType::CLIENT); + RETURN_NOT_OK(client_tls_.InitiateHandshake(&client)); + TlsHandshake server(TlsHandshakeType::SERVER); + RETURN_NOT_OK(server_tls_.InitiateHandshake(&server)); client.set_verification_mode(client_verify); server.set_verification_mode(server_verify); @@ -186,10 +186,10 @@ TEST_F(TestTlsHandshake, TestHandshakeSequence) { ASSERT_OK(ConfigureTlsContext(PkiConfig::SIGNED, ca_cert, ca_key, &client_tls_)); ASSERT_OK(ConfigureTlsContext(PkiConfig::SIGNED, ca_cert, ca_key, &server_tls_)); - TlsHandshake server; - TlsHandshake client; - ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server)); - ASSERT_OK(server_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client)); + TlsHandshake server(TlsHandshakeType::SERVER); + ASSERT_OK(client_tls_.InitiateHandshake(&server)); + TlsHandshake client(TlsHandshakeType::CLIENT); + ASSERT_OK(server_tls_.InitiateHandshake(&client)); string buf1; string buf2; diff --git a/src/kudu/security/tls_handshake.cc b/src/kudu/security/tls_handshake.cc index 52162c1..96e3b88 100644 --- a/src/kudu/security/tls_handshake.cc +++ b/src/kudu/security/tls_handshake.cc @@ -23,10 +23,12 @@ #include <memory> #include <string> +#include <utility> #include "kudu/gutil/strings/strip.h" #include "kudu/gutil/strings/substitute.h" #include "kudu/security/cert.h" +#include "kudu/security/openssl_util.h" #include "kudu/security/tls_socket.h" #include "kudu/util/net/socket.h" #include "kudu/util/status.h" @@ -79,22 +81,60 @@ void TlsHandshake::SetSSLVerify() { SSL_set_verify(ssl_.get(), ssl_mode, /* callback = */nullptr); } +TlsHandshake::TlsHandshake(TlsHandshakeType type) + : type_(type) { +} + +Status TlsHandshake::Init(c_unique_ptr<SSL> s) { + SCOPED_OPENSSL_NO_PENDING_ERRORS; + DCHECK(s); + + if (ssl_) { + return Status::IllegalState("TlsHandshake is already initialized"); + } + + auto rbio = ssl_make_unique(BIO_new(BIO_s_mem())); + if (!rbio) { + return Status::RuntimeError( + "failed to create memory-based read BIO", GetOpenSSLErrors()); + } + auto wbio = ssl_make_unique(BIO_new(BIO_s_mem())); + if (!wbio) { + return Status::RuntimeError( + "failed to create memory-based write BIO", GetOpenSSLErrors()); + } + ssl_ = std::move(s); + auto* ssl = ssl_.get(); + SSL_set_bio(ssl, rbio.release(), wbio.release()); + + switch (type_) { + case TlsHandshakeType::SERVER: + SSL_set_accept_state(ssl); + break; + case TlsHandshakeType::CLIENT: + SSL_set_connect_state(ssl); + break; + } + return Status::OK(); +} + Status TlsHandshake::Continue(const string& recv, string* send) { SCOPED_OPENSSL_NO_PENDING_ERRORS; if (!has_started_) { SetSSLVerify(); has_started_ = true; } - CHECK(ssl_); + DCHECK(ssl_); + auto* ssl = ssl_.get(); - BIO* rbio = SSL_get_rbio(ssl_.get()); + BIO* rbio = SSL_get_rbio(ssl); int n = BIO_write(rbio, recv.data(), recv.size()); DCHECK(n == recv.size() || (n == -1 && recv.empty())); DCHECK_EQ(BIO_ctrl_pending(rbio), recv.size()); - int rc = SSL_do_handshake(ssl_.get()); + int rc = SSL_do_handshake(ssl); if (rc != 1) { - int ssl_err = SSL_get_error(ssl_.get(), rc); + int ssl_err = SSL_get_error(ssl, rc); // WANT_READ and WANT_WRITE indicate that the handshake is not yet complete. if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) { return Status::RuntimeError("TLS Handshake error", GetSSLErrorDescription(ssl_err)); diff --git a/src/kudu/security/tls_handshake.h b/src/kudu/security/tls_handshake.h index da70331..41a7627 100644 --- a/src/kudu/security/tls_handshake.h +++ b/src/kudu/security/tls_handshake.h @@ -20,7 +20,6 @@ #include <functional> #include <memory> #include <string> -#include <utility> #include <glog/logging.h> @@ -65,13 +64,16 @@ enum class TlsVerificationMode { // TlsHandshake manages an ongoing TLS handshake between a client and server. // // TlsHandshake instances are default constructed, but must be initialized -// before use using TlsContext::InitiateHandshake. +// before using: call the Init() method to initialize an instance. class TlsHandshake { public: - - TlsHandshake() = default; + explicit TlsHandshake(TlsHandshakeType type); ~TlsHandshake() = default; + // Initialize the instance for the specified type of handshake + // using the given SSL handle. + Status Init(c_unique_ptr<SSL> s) WARN_UNUSED_RESULT; + // Set the verification mode for this handshake. The default verification mode // is VERIFY_REMOTE_CERT_AND_HOST. // @@ -136,21 +138,9 @@ class TlsHandshake { std::string GetCipherDescription() const; private: - friend class TlsContext; - - bool has_started_ = false; - TlsVerificationMode verification_mode_ = TlsVerificationMode::VERIFY_REMOTE_CERT_AND_HOST; - // Set the verification mode on the underlying SSL object. void SetSSLVerify(); - // Set the SSL to use during the handshake. Called once by - // TlsContext::InitiateHandshake before starting the handshake processes. - void adopt_ssl(c_unique_ptr<SSL> ssl) { - CHECK(!ssl_); - ssl_ = std::move(ssl); - } - SSL* ssl() { return ssl_.get(); } @@ -161,9 +151,15 @@ class TlsHandshake { // Verifies that the handshake is valid for the provided socket. Status Verify(const Socket& socket) const WARN_UNUSED_RESULT; + // The type of TLS handshake this wrapper represents: client or server. + const TlsHandshakeType type_; + // Owned SSL handle. c_unique_ptr<SSL> ssl_; + bool has_started_ = false; + TlsVerificationMode verification_mode_ = TlsVerificationMode::VERIFY_REMOTE_CERT_AND_HOST; + Cert local_cert_; Cert remote_cert_; }; diff --git a/src/kudu/security/tls_socket-test.cc b/src/kudu/security/tls_socket-test.cc index 2db41d8..53b1185 100644 --- a/src/kudu/security/tls_socket-test.cc +++ b/src/kudu/security/tls_socket-test.cc @@ -114,8 +114,8 @@ void TlsSocketTest::ConnectClient(const Sockaddr& addr, unique_ptr<Socket>* sock ASSERT_OK(client_sock->Init(addr.family(), 0)); ASSERT_OK(client_sock->Connect(addr)); - TlsHandshake client; - ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client)); + TlsHandshake client(TlsHandshakeType::CLIENT); + ASSERT_OK(client_tls_.InitiateHandshake(&client)); ASSERT_OK(DoNegotiationSide(client_sock.get(), &client, "client")); ASSERT_OK(client.Finish(&client_sock)); *sock = std::move(client_sock); @@ -146,8 +146,8 @@ class EchoServer { Sockaddr remote; CHECK_OK(listener_.Accept(sock.get(), &remote, /*flags=*/0)); - TlsHandshake server; - CHECK_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server)); + TlsHandshake server(TlsHandshakeType::SERVER); + CHECK_OK(server_tls_.InitiateHandshake(&server)); CHECK_OK(DoNegotiationSide(sock.get(), &server, "server")); CHECK_OK(server.Finish(&sock));
