This is an automated email from the ASF dual-hosted git repository.

alexey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kudu.git

commit bdeaa122f6b14c41b3640ab5d575b11d4c5ffe1c
Author: Alexey Serbin <[email protected]>
AuthorDate: Wed Mar 17 13:21:00 2021 -0700

    [security] small re-factoring on TlsHandshake
    
    A follow-up patch needs to know which side of the connection (client or
    server) a TlsHandshake is running on.  This changelist introduces that
    piece of information into the TlsHandshake class and also does some
    re-factoring of the related code.
    
    Change-Id: Id8dfb16d0f59dc467d1d4dcdf5c1bc568062087e
    Reviewed-on: http://gerrit.cloudera.org:8080/17193
    Tested-by: Kudu Jenkins
    Reviewed-by: Grant Henke <[email protected]>
    Reviewed-by: Attila Bukor <[email protected]>
---
 src/kudu/rpc/client_negotiation.cc      |  4 +--
 src/kudu/rpc/server_negotiation.cc      |  4 +--
 src/kudu/security/tls_context.cc        | 28 ++++++-------------
 src/kudu/security/tls_context.h         |  9 +++----
 src/kudu/security/tls_handshake-test.cc | 18 ++++++-------
 src/kudu/security/tls_handshake.cc      | 48 ++++++++++++++++++++++++++++++---
 src/kudu/security/tls_handshake.h       | 28 +++++++++----------
 src/kudu/security/tls_socket-test.cc    |  8 +++---
 8 files changed, 84 insertions(+), 63 deletions(-)

diff --git a/src/kudu/rpc/client_negotiation.cc b/src/kudu/rpc/client_negotiation.cc
index efa69a1..e5696b9 100644
--- a/src/kudu/rpc/client_negotiation.cc
+++ b/src/kudu/rpc/client_negotiation.cc
@@ -126,6 +126,7 @@ ClientNegotiation::ClientNegotiation(unique_ptr<Socket> 
socket,
     : socket_(std::move(socket)),
       helper_(SaslHelper::CLIENT),
       tls_context_(tls_context),
+      tls_handshake_(security::TlsHandshakeType::CLIENT),
       encryption_(encryption),
       tls_negotiated_(false),
       authn_token_(std::move(authn_token)),
@@ -191,8 +192,7 @@ Status 
ClientNegotiation::Negotiate(unique_ptr<ErrorStatusPB>* rpc_error) {
   // TODO(KUDU-1921): allow the client to require TLS.
   if (encryption_ != RpcEncryption::DISABLED &&
       ContainsKey(server_features_, TLS)) {
-    RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::CLIENT,
-                                                  &tls_handshake_));
+    RETURN_NOT_OK(tls_context_->InitiateHandshake(&tls_handshake_));
 
     if (negotiated_authn_ == AuthenticationType::SASL) {
       // When using SASL authentication, verifying the server's certificate is
diff --git a/src/kudu/rpc/server_negotiation.cc 
b/src/kudu/rpc/server_negotiation.cc
index ac55399..c636856 100644
--- a/src/kudu/rpc/server_negotiation.cc
+++ b/src/kudu/rpc/server_negotiation.cc
@@ -160,6 +160,7 @@ ServerNegotiation::ServerNegotiation(unique_ptr<Socket> 
socket,
     : socket_(std::move(socket)),
       helper_(SaslHelper::SERVER),
       tls_context_(tls_context),
+      tls_handshake_(security::TlsHandshakeType::SERVER),
       encryption_(encryption),
       tls_negotiated_(false),
       token_verifier_(token_verifier),
@@ -223,8 +224,7 @@ Status ServerNegotiation::Negotiate() {
   if (encryption_ != RpcEncryption::DISABLED &&
       tls_context_->has_cert() &&
       ContainsKey(client_features_, TLS)) {
-    RETURN_NOT_OK(tls_context_->InitiateHandshake(security::TlsHandshakeType::SERVER,
-                                                  &tls_handshake_));
+    RETURN_NOT_OK(tls_context_->InitiateHandshake(&tls_handshake_));
 
     if (negotiated_authn_ != AuthenticationType::CERTIFICATE) {
       // The server does not need to verify the client's certificate unless it's
diff --git a/src/kudu/security/tls_context.cc b/src/kudu/security/tls_context.cc
index dcc1ee4..e32e78a 100644
--- a/src/kudu/security/tls_context.cc
+++ b/src/kudu/security/tls_context.cc
@@ -531,33 +531,21 @@ Status TlsContext::LoadCertificateAuthority(const string& 
certificate_path) {
   return AddTrustedCertificate(c);
 }
 
-Status TlsContext::InitiateHandshake(TlsHandshakeType handshake_type,
-                                     TlsHandshake* handshake) const {
+Status TlsContext::InitiateHandshake(TlsHandshake* handshake) const {
   SCOPED_OPENSSL_NO_PENDING_ERRORS;
+  DCHECK(handshake);
   CHECK(ctx_);
-  CHECK(!handshake->ssl_);
+  c_unique_ptr<SSL> ssl;
   {
+    // This lock is to protect against concurrent change of certificates
+    // while calling SSL_new() here.
     shared_lock<RWMutex> lock(lock_);
-    handshake->adopt_ssl(ssl_make_unique(SSL_new(ctx_.get())));
+    ssl = ssl_make_unique(SSL_new(ctx_.get()));
   }
-  if (!handshake->ssl_) {
+  if (!ssl) {
     return Status::RuntimeError("failed to create SSL handle", GetOpenSSLErrors());
   }
-
-  SSL_set_bio(handshake->ssl(),
-              BIO_new(BIO_s_mem()),
-              BIO_new(BIO_s_mem()));
-
-  switch (handshake_type) {
-    case TlsHandshakeType::SERVER:
-      SSL_set_accept_state(handshake->ssl());
-      break;
-    case TlsHandshakeType::CLIENT:
-      SSL_set_connect_state(handshake->ssl());
-      break;
-  }
-
-  return Status::OK();
+  return handshake->Init(std::move(ssl));
 }
 
 } // namespace security
diff --git a/src/kudu/security/tls_context.h b/src/kudu/security/tls_context.h
index edf37f3..a13e838 100644
--- a/src/kudu/security/tls_context.h
+++ b/src/kudu/security/tls_context.h
@@ -26,19 +26,17 @@
 #include <boost/optional/optional.hpp>
 
 #include "kudu/gutil/port.h"
+#include "kudu/security/cert.h" // IWYU pragma: keep
 #include "kudu/security/openssl_util.h"
-#include "kudu/security/tls_handshake.h"
 #include "kudu/util/locks.h"
 #include "kudu/util/rw_mutex.h"
 #include "kudu/util/status.h"
-// IWYU pragma: no_include "kudu/security/cert.h"
 
 namespace kudu {
 namespace security {
 
-class Cert;           // IWYU pragma: keep
-class CertSignRequest;// IWYU pragma: keep
 class PrivateKey;
+class TlsHandshake;
 
 // TlsContext wraps data required by the OpenSSL library for creating and
 // accepting TLS protected channels. A single TlsContext instance should be 
used
@@ -162,8 +160,7 @@ class TlsContext {
   Status LoadCertificateAuthority(const std::string& certificate_path) 
WARN_UNUSED_RESULT;
 
   // Initiates a new TlsHandshake instance.
-  Status InitiateHandshake(TlsHandshakeType handshake_type,
-                           TlsHandshake* handshake) const WARN_UNUSED_RESULT;
+  Status InitiateHandshake(TlsHandshake* handshake) const WARN_UNUSED_RESULT;
 
   // Return the number of certs that have been marked as trusted.
   // Used by tests.
diff --git a/src/kudu/security/tls_handshake-test.cc 
b/src/kudu/security/tls_handshake-test.cc
index 5e6a71d..698af49 100644
--- a/src/kudu/security/tls_handshake-test.cc
+++ b/src/kudu/security/tls_handshake-test.cc
@@ -40,6 +40,7 @@
 #include "kudu/util/test_macros.h"
 #include "kudu/util/test_util.h"
 
+using kudu::security::ca::CertSigner;
 using std::string;
 using std::vector;
 
@@ -48,8 +49,6 @@ DECLARE_int32(ipki_server_key_size);
 namespace kudu {
 namespace security {
 
-using ca::CertSigner;
-
 struct Case {
   PkiConfig client_pki;
   TlsVerificationMode client_verification;
@@ -91,9 +90,10 @@ class TestTlsHandshakeBase : public KuduTest {
   // verification modes are set to 'client_verify' and 'server_verify' 
respectively.
   Status RunHandshake(TlsVerificationMode client_verify,
                       TlsVerificationMode server_verify) {
-    TlsHandshake client, server;
-    RETURN_NOT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, 
&client));
-    RETURN_NOT_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, 
&server));
+    TlsHandshake client(TlsHandshakeType::CLIENT);
+    RETURN_NOT_OK(client_tls_.InitiateHandshake(&client));
+    TlsHandshake server(TlsHandshakeType::SERVER);
+    RETURN_NOT_OK(server_tls_.InitiateHandshake(&server));
 
     client.set_verification_mode(client_verify);
     server.set_verification_mode(server_verify);
@@ -186,10 +186,10 @@ TEST_F(TestTlsHandshake, TestHandshakeSequence) {
   ASSERT_OK(ConfigureTlsContext(PkiConfig::SIGNED, ca_cert, ca_key, 
&client_tls_));
   ASSERT_OK(ConfigureTlsContext(PkiConfig::SIGNED, ca_cert, ca_key, 
&server_tls_));
 
-  TlsHandshake server;
-  TlsHandshake client;
-  ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::SERVER, &server));
-  ASSERT_OK(server_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client));
+  TlsHandshake server(TlsHandshakeType::SERVER);
+  ASSERT_OK(client_tls_.InitiateHandshake(&server));
+  TlsHandshake client(TlsHandshakeType::CLIENT);
+  ASSERT_OK(server_tls_.InitiateHandshake(&client));
 
   string buf1;
   string buf2;
diff --git a/src/kudu/security/tls_handshake.cc 
b/src/kudu/security/tls_handshake.cc
index 52162c1..96e3b88 100644
--- a/src/kudu/security/tls_handshake.cc
+++ b/src/kudu/security/tls_handshake.cc
@@ -23,10 +23,12 @@
 
 #include <memory>
 #include <string>
+#include <utility>
 
 #include "kudu/gutil/strings/strip.h"
 #include "kudu/gutil/strings/substitute.h"
 #include "kudu/security/cert.h"
+#include "kudu/security/openssl_util.h"
 #include "kudu/security/tls_socket.h"
 #include "kudu/util/net/socket.h"
 #include "kudu/util/status.h"
@@ -79,22 +81,60 @@ void TlsHandshake::SetSSLVerify() {
   SSL_set_verify(ssl_.get(), ssl_mode, /* callback = */nullptr);
 }
 
+TlsHandshake::TlsHandshake(TlsHandshakeType type)
+    : type_(type) {
+}
+
+Status TlsHandshake::Init(c_unique_ptr<SSL> s) {
+  SCOPED_OPENSSL_NO_PENDING_ERRORS;
+  DCHECK(s);
+
+  if (ssl_) {
+    return Status::IllegalState("TlsHandshake is already initialized");
+  }
+
+  auto rbio = ssl_make_unique(BIO_new(BIO_s_mem()));
+  if (!rbio) {
+    return Status::RuntimeError(
+        "failed to create memory-based read BIO", GetOpenSSLErrors());
+  }
+  auto wbio = ssl_make_unique(BIO_new(BIO_s_mem()));
+  if (!wbio) {
+    return Status::RuntimeError(
+        "failed to create memory-based write BIO", GetOpenSSLErrors());
+  }
+  ssl_ = std::move(s);
+  auto* ssl = ssl_.get();
+  SSL_set_bio(ssl, rbio.release(), wbio.release());
+
+  switch (type_) {
+    case TlsHandshakeType::SERVER:
+      SSL_set_accept_state(ssl);
+      break;
+    case TlsHandshakeType::CLIENT:
+      SSL_set_connect_state(ssl);
+      break;
+  }
+  return Status::OK();
+}
+
 Status TlsHandshake::Continue(const string& recv, string* send) {
   SCOPED_OPENSSL_NO_PENDING_ERRORS;
   if (!has_started_) {
     SetSSLVerify();
     has_started_ = true;
   }
-  CHECK(ssl_);
+  DCHECK(ssl_);
+  auto* ssl = ssl_.get();
 
-  BIO* rbio = SSL_get_rbio(ssl_.get());
+  BIO* rbio = SSL_get_rbio(ssl);
   int n = BIO_write(rbio, recv.data(), recv.size());
   DCHECK(n == recv.size() || (n == -1 && recv.empty()));
   DCHECK_EQ(BIO_ctrl_pending(rbio), recv.size());
 
-  int rc = SSL_do_handshake(ssl_.get());
+  int rc = SSL_do_handshake(ssl);
   if (rc != 1) {
-    int ssl_err = SSL_get_error(ssl_.get(), rc);
+    int ssl_err = SSL_get_error(ssl, rc);
     // WANT_READ and WANT_WRITE indicate that the handshake is not yet 
complete.
     if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) {
       return Status::RuntimeError("TLS Handshake error", 
GetSSLErrorDescription(ssl_err));
diff --git a/src/kudu/security/tls_handshake.h 
b/src/kudu/security/tls_handshake.h
index da70331..41a7627 100644
--- a/src/kudu/security/tls_handshake.h
+++ b/src/kudu/security/tls_handshake.h
@@ -20,7 +20,6 @@
 #include <functional>
 #include <memory>
 #include <string>
-#include <utility>
 
 #include <glog/logging.h>
 
@@ -65,13 +64,16 @@ enum class TlsVerificationMode {
 // TlsHandshake manages an ongoing TLS handshake between a client and server.
 //
 // TlsHandshake instances are default constructed, but must be initialized
-// before use using TlsContext::InitiateHandshake.
+// before using: call the Init() method to initialize an instance.
 class TlsHandshake {
  public:
-
-   TlsHandshake() = default;
+   explicit TlsHandshake(TlsHandshakeType type);
    ~TlsHandshake() = default;
 
+   // Initialize the instance for the specified type of handshake
+   // using the given SSL handle.
+   Status Init(c_unique_ptr<SSL> s) WARN_UNUSED_RESULT;
+
   // Set the verification mode for this handshake. The default verification 
mode
   // is VERIFY_REMOTE_CERT_AND_HOST.
   //
@@ -136,21 +138,9 @@ class TlsHandshake {
   std::string GetCipherDescription() const;
 
  private:
-  friend class TlsContext;
-
-  bool has_started_ = false;
-  TlsVerificationMode verification_mode_ = 
TlsVerificationMode::VERIFY_REMOTE_CERT_AND_HOST;
-
   // Set the verification mode on the underlying SSL object.
   void SetSSLVerify();
 
-  // Set the SSL to use during the handshake. Called once by
-  // TlsContext::InitiateHandshake before starting the handshake processes.
-  void adopt_ssl(c_unique_ptr<SSL> ssl) {
-    CHECK(!ssl_);
-    ssl_ = std::move(ssl);
-  }
-
   SSL* ssl() {
     return ssl_.get();
   }
@@ -161,9 +151,15 @@ class TlsHandshake {
   // Verifies that the handshake is valid for the provided socket.
   Status Verify(const Socket& socket) const WARN_UNUSED_RESULT;
 
+  // The type of TLS handshake this wrapper represents: client or server.
+  const TlsHandshakeType type_;
+
   // Owned SSL handle.
   c_unique_ptr<SSL> ssl_;
 
+  bool has_started_ = false;
+  TlsVerificationMode verification_mode_ = 
TlsVerificationMode::VERIFY_REMOTE_CERT_AND_HOST;
+
   Cert local_cert_;
   Cert remote_cert_;
 };
diff --git a/src/kudu/security/tls_socket-test.cc 
b/src/kudu/security/tls_socket-test.cc
index 2db41d8..53b1185 100644
--- a/src/kudu/security/tls_socket-test.cc
+++ b/src/kudu/security/tls_socket-test.cc
@@ -114,8 +114,8 @@ void TlsSocketTest::ConnectClient(const Sockaddr& addr, 
unique_ptr<Socket>* sock
   ASSERT_OK(client_sock->Init(addr.family(), 0));
   ASSERT_OK(client_sock->Connect(addr));
 
-  TlsHandshake client;
-  ASSERT_OK(client_tls_.InitiateHandshake(TlsHandshakeType::CLIENT, &client));
+  TlsHandshake client(TlsHandshakeType::CLIENT);
+  ASSERT_OK(client_tls_.InitiateHandshake(&client));
   ASSERT_OK(DoNegotiationSide(client_sock.get(), &client, "client"));
   ASSERT_OK(client.Finish(&client_sock));
   *sock = std::move(client_sock);
@@ -146,8 +146,8 @@ class EchoServer {
         Sockaddr remote;
         CHECK_OK(listener_.Accept(sock.get(), &remote, /*flags=*/0));
 
-        TlsHandshake server;
-        CHECK_OK(server_tls_.InitiateHandshake(TlsHandshakeType::SERVER, 
&server));
+        TlsHandshake server(TlsHandshakeType::SERVER);
+        CHECK_OK(server_tls_.InitiateHandshake(&server));
         CHECK_OK(DoNegotiationSide(sock.get(), &server, "server"));
         CHECK_OK(server.Finish(&sock));
 

Reply via email to