This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new f3d796dd Support TLS for replication (#1630)
f3d796dd is described below

commit f3d796dd4a25de4e8a0e3bb14bbd1043aed5c9c6
Author: Twice <[email protected]>
AuthorDate: Tue Sep 12 13:42:27 2023 +0900

    Support TLS for replication (#1630)
---
 kvrocks.conf                    |   7 +++
 src/cluster/replication.cc      |  74 +++++++++++++++++-----
 src/cluster/replication.h       |   7 ++-
 src/commands/cmd_replication.cc |  10 +--
 src/common/io_util.cc           | 133 +++++++++++++++++++++++++++++++++++++---
 src/common/io_util.h            |  14 +++++
 src/config/config.cc            |   1 +
 src/config/config.h             |   3 +
 src/main.cc                     |   2 +-
 src/server/server.cc            |   2 +-
 src/server/worker.cc            |   4 +-
 tests/gocase/tls/tls_test.go    |  57 +++++++++++++++++
 tests/gocase/util/server.go     |   4 ++
 13 files changed, 279 insertions(+), 39 deletions(-)

diff --git a/kvrocks.conf b/kvrocks.conf
index 6a015319..d2b027cb 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -381,6 +381,13 @@ redis-cursor-compatible no
 #
 # tls-session-cache-timeout 60
 
+# By default, a replica does not attempt to establish a TLS connection
+# with its master.
+#
+# Use the following directive to enable TLS on replication links.
+#
+# tls-replication yes
+
 ################################## SLOW LOG ###################################
 
 # The Kvrocks Slow Log is a mechanism to log queries that exceeded a specified
diff --git a/src/cluster/replication.cc b/src/cluster/replication.cc
index d6d8d281..38548418 100644
--- a/src/cluster/replication.cc
+++ b/src/cluster/replication.cc
@@ -37,6 +37,7 @@
 #include "fmt/format.h"
 #include "io_util.h"
 #include "rocksdb_crc32c.h"
+#include "scope_exit.h"
 #include "server/redis_reply.h"
 #include "server/server.h"
 #include "status.h"
@@ -45,6 +46,12 @@
 #include "time_util.h"
 #include "unique_fd.h"
 
+#ifdef ENABLE_OPENSSL
+#include <event2/bufferevent_ssl.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#endif
+
 Status FeedSlaveThread::Start() {
   auto s = util::CreateThread("feed-replica", [this] {
     sigset_t mask, omask;
@@ -54,7 +61,7 @@ Status FeedSlaveThread::Start() {
     sigaddset(&mask, SIGHUP);
     sigaddset(&mask, SIGPIPE);
     pthread_sigmask(SIG_BLOCK, &mask, &omask);
-    auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"));
+    auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"), 
conn_->GetBufferEvent());
     if (!s.IsOK()) {
       LOG(ERROR) << "failed to send OK response to the replica: " << s.Msg();
       return;
@@ -85,7 +92,7 @@ void FeedSlaveThread::Join() {
 void FeedSlaveThread::checkLivenessIfNeed() {
   if (++interval_ % 1000) return;
   const auto ping_command = redis::BulkString("ping");
-  auto s = util::SockSend(conn_->GetFD(), ping_command);
+  auto s = util::SockSend(conn_->GetFD(), ping_command, 
conn_->GetBufferEvent());
   if (!s.IsOK()) {
     LOG(ERROR) << "Ping slave[" << conn_->GetAddr() << "] err: " << s.Msg() << 
", would stop the thread";
     Stop();
@@ -134,7 +141,7 @@ void FeedSlaveThread::loop() {
     if (is_first_repl_batch || batches_bulk.size() >= kMaxDelayBytes || 
updates_in_batches >= kMaxDelayUpdates ||
         srv_->storage->LatestSeqNumber() - batch.sequence <= kMaxDelayUpdates) 
{
       // Send entire bulk which contain multiple batches
-      auto s = util::SockSend(conn_->GetFD(), batches_bulk);
+      auto s = util::SockSend(conn_->GetFD(), batches_bulk, 
conn_->GetBufferEvent());
       if (!s.IsOK()) {
         LOG(ERROR) << "Write error while sending batch to slave: " << s.Msg() 
<< ". batches: 0x"
                    << util::StringToHex(batches_bulk);
@@ -257,12 +264,35 @@ void ReplicationThread::CallbacksStateMachine::Start() {
       LOG(ERROR) << "[replication] Failed to connect the master, err: " << 
cfd.Msg();
       continue;
     }
+#ifdef ENABLE_OPENSSL
+    SSL *ssl = nullptr;
+    if (repl_->srv_->GetConfig()->tls_replication) {
+      ssl = SSL_new(repl_->srv_->ssl_ctx.get());
+      if (!ssl) {
+        LOG(ERROR) << "Failed to construct SSL structure for new connection: " 
<< SSLErrors{};
+        evutil_closesocket(*cfd);
+        return;
+      }
+      bev = bufferevent_openssl_socket_new(repl_->base_, *cfd, ssl, 
BUFFEREVENT_SSL_CONNECTING, BEV_OPT_CLOSE_ON_FREE);
+    } else {
+      bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
+    }
+#else
     bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
+#endif
     if (bev == nullptr) {
+#ifdef ENABLE_OPENSSL
+      if (ssl) SSL_free(ssl);
+#endif
       close(*cfd);
       LOG(ERROR) << "[replication] Failed to create the event socket";
       continue;
     }
+#ifdef ENABLE_OPENSSL
+    if (repl_->srv_->GetConfig()->tls_replication) {
+      bufferevent_openssl_set_allow_dirty_shutdown(bev, 1);
+    }
+#endif
   }
   if (bev == nullptr) {  // failed to connect the master and received the stop 
signal
     return;
@@ -728,9 +758,19 @@ Status ReplicationThread::parallelFetchFile(const 
std::string &dir,
           if (this->stop_flag_) {
             return {Status::NotOK, "replication thread was stopped"};
           }
-          int sock_fd = GET_OR_RET(util::SockConnect(this->host_, 
this->port_).Prefixed("connect the server err"));
+          ssl_st *ssl = nullptr;
+#ifdef ENABLE_OPENSSL
+          if (this->srv_->GetConfig()->tls_replication) {
+            ssl = SSL_new(this->srv_->ssl_ctx.get());
+            if (!ssl) return {Status::NotOK, "failed to create the SSL object"};  // never fall back to plaintext
+          }
+          // Stays armed for the whole task: ssl is used by every read/write below (SSL_free(nullptr) is a no-op).
+          auto exit = MakeScopeExit([ssl] { SSL_free(ssl); });
+#endif
+          int sock_fd = GET_OR_RET(util::SockConnect(this->host_, this->port_, ssl).Prefixed("connect the server err"));
+          // SSL_free leaves the socket open (BIO_NOCLOSE in SockConnect); `unique_fd` below closes it.
           UniqueFD unique_fd{sock_fd};
-          auto s = this->sendAuth(sock_fd);
+          auto s = this->sendAuth(sock_fd, ssl);
           if (!s.IsOK()) {
             return s.Prefixed("send the auth command err");
           }
@@ -770,12 +810,12 @@ Status ReplicationThread::parallelFetchFile(const 
std::string &dir,
           // command, so we need to fetch all files by multiple command 
interactions.
           if (srv_->GetConfig()->master_use_repl_port) {
             for (unsigned i = 0; i < fetch_files.size(); i++) {
-              s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]}, 
fn);
+              s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]}, 
fn, ssl);
               if (!s.IsOK()) break;
             }
           } else {
             if (!fetch_files.empty()) {
-              s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn);
+              s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn, ssl);
             }
           }
           return s;
@@ -790,13 +830,13 @@ Status ReplicationThread::parallelFetchFile(const 
std::string &dir,
   return Status::OK();
 }
 
-Status ReplicationThread::sendAuth(int sock_fd) {
+Status ReplicationThread::sendAuth(int sock_fd, ssl_st *ssl) {
   // Send auth when needed
   std::string auth = srv_->GetConfig()->masterauth;
   if (!auth.empty()) {
     UniqueEvbuf evbuf;
     const auto auth_command = redis::MultiBulkString({"AUTH", auth});
-    auto s = util::SockSend(sock_fd, auth_command);
+    auto s = util::SockSend(sock_fd, auth_command, ssl);
     if (!s.IsOK()) return s.Prefixed("send auth command err");
     while (true) {
       if (evbuffer_read(evbuf.get(), sock_fd, -1) <= 0) {
@@ -814,15 +854,15 @@ Status ReplicationThread::sendAuth(int sock_fd) {
 }
 
 Status ReplicationThread::fetchFile(int sock_fd, evbuffer *evbuf, const 
std::string &dir, const std::string &file,
-                                    uint32_t crc, const FetchFileCallback &fn) 
{
+                                    uint32_t crc, const FetchFileCallback &fn, 
ssl_st *ssl) {
   size_t file_size = 0;
 
   // Read file size line
   while (true) {
     UniqueEvbufReadln line(evbuf, EVBUFFER_EOL_CRLF_STRICT);
     if (!line) {
-      if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
-        return {Status::NotOK, fmt::format("read size: {}", strerror(errno))};
+      if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
+        return std::move(s).Prefixed("read size");
       }
       continue;
     }
@@ -854,8 +894,8 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer 
*evbuf, const std::str
       tmp_crc = rocksdb::crc32c::Extend(tmp_crc, data, data_len);
       remain -= data_len;
     } else {
-      if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
-        return {Status::NotOK, fmt::format("read sst file: {}", 
strerror(errno))};
+      if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
+        return std::move(s).Prefixed("read sst file");
       }
     }
   }
@@ -873,7 +913,7 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer 
*evbuf, const std::str
 }
 
 Status ReplicationThread::fetchFiles(int sock_fd, const std::string &dir, 
const std::vector<std::string> &files,
-                                     const std::vector<uint32_t> &crcs, const 
FetchFileCallback &fn) {
+                                     const std::vector<uint32_t> &crcs, const 
FetchFileCallback &fn, ssl_st *ssl) {
   std::string files_str;
   for (const auto &file : files) {
     files_str += file;
@@ -882,13 +922,13 @@ Status ReplicationThread::fetchFiles(int sock_fd, const 
std::string &dir, const
   files_str.pop_back();
 
   const auto fetch_command = redis::MultiBulkString({"_fetch_file", 
files_str});
-  auto s = util::SockSend(sock_fd, fetch_command);
+  auto s = util::SockSend(sock_fd, fetch_command, ssl);
   if (!s.IsOK()) return s.Prefixed("send fetch file command");
 
   UniqueEvbuf evbuf;
   for (unsigned i = 0; i < files.size(); i++) {
     DLOG(INFO) << "[fetch] Start to fetch file " << files[i];
-    s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn);
+    s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn, ssl);
     if (!s.IsOK()) {
       s = Status(Status::NotOK, "fetch file err: " + s.Msg());
       LOG(WARNING) << "[fetch] Fail to fetch file " << files[i] << ", err: " 
<< s.Msg();
diff --git a/src/cluster/replication.h b/src/cluster/replication.h
index 6bf5954b..b7f49717 100644
--- a/src/cluster/replication.h
+++ b/src/cluster/replication.h
@@ -32,6 +32,7 @@
 #include <vector>
 
 #include "event_util.h"
+#include "io_util.h"
 #include "server/redis_connection.h"
 #include "status.h"
 #include "storage/storage.h"
@@ -197,11 +198,11 @@ class ReplicationThread : private 
EventCallbackBase<ReplicationThread> {
   CBState fullSyncReadCB(bufferevent *bev);
 
   // Synchronized-Blocking ops
-  Status sendAuth(int sock_fd);
+  Status sendAuth(int sock_fd, ssl_st *ssl);
   Status fetchFile(int sock_fd, evbuffer *evbuf, const std::string &dir, const 
std::string &file, uint32_t crc,
-                   const FetchFileCallback &fn);
+                   const FetchFileCallback &fn, ssl_st *ssl);
   Status fetchFiles(int sock_fd, const std::string &dir, const 
std::vector<std::string> &files,
-                    const std::vector<uint32_t> &crcs, const FetchFileCallback 
&fn);
+                    const std::vector<uint32_t> &crcs, const FetchFileCallback 
&fn, ssl_st *ssl);
   Status parallelFetchFile(const std::string &dir, const 
std::vector<std::pair<std::string, uint32_t>> &files);
   static bool isRestoringError(const char *err);
   static bool isWrongPsyncNum(const char *err);
diff --git a/src/commands/cmd_replication.cc b/src/commands/cmd_replication.cc
index 8ccdfc19..23ae269a 100644
--- a/src/commands/cmd_replication.cc
+++ b/src/commands/cmd_replication.cc
@@ -102,7 +102,7 @@ class CommandPSync : public Commander {
     s = svr->AddSlave(conn, next_repl_seq_);
     if (!s.IsOK()) {
       std::string err = "-ERR " + s.Msg() + "\r\n";
-      s = util::SockSend(conn->GetFD(), err);
+      s = util::SockSend(conn->GetFD(), err, conn->GetBufferEvent());
       if (!s.IsOK()) {
         LOG(WARNING) << "failed to send error message to the replica: " << 
s.Msg();
       }
@@ -229,7 +229,7 @@ class CommandFetchMeta : public Commander {
       std::string files;
       auto s = 
engine::Storage::ReplDataManager::GetFullReplDataInfo(svr->storage, &files);
       if (!s.IsOK()) {
-        s = util::SockSend(repl_fd, "-ERR can't create db checkpoint");
+        s = util::SockSend(repl_fd, "-ERR can't create db checkpoint", bev);
         if (!s.IsOK()) {
           LOG(WARNING) << "[replication] Failed to send error response: " << 
s.Msg();
         }
@@ -237,7 +237,7 @@ class CommandFetchMeta : public Commander {
         return;
       }
       // Send full data file info
-      if (util::SockSend(repl_fd, files + CRLF).IsOK()) {
+      if (util::SockSend(repl_fd, files + CRLF, bev).IsOK()) {
         LOG(INFO) << "[replication] Succeed sending full data file info to " 
<< ip;
       } else {
         LOG(WARNING) << "[replication] Fail to send full data file info " << 
ip << ", error: " << strerror(errno);
@@ -291,8 +291,8 @@ class CommandFetchFile : public Commander {
         if (!fd) break;
 
         // Send file size and content
-        if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF).IsOK() &&
-            util::SockSendFile(repl_fd, *fd, file_size).IsOK()) {
+        if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF, 
bev).IsOK() &&
+            util::SockSendFile(repl_fd, *fd, file_size, bev).IsOK()) {
           LOG(INFO) << "[replication] Succeed sending file " << file << " to " 
<< ip;
         } else {
           LOG(WARNING) << "[replication] Fail to send file " << file << " to " 
<< ip << ", error: " << strerror(errno);
diff --git a/src/common/io_util.cc b/src/common/io_util.cc
index b4779f86..2c273580 100644
--- a/src/common/io_util.cc
+++ b/src/common/io_util.cc
@@ -29,10 +29,19 @@
 #include <poll.h>
 #include <sys/types.h>
 
+#include "fmt/ostream.h"
+#include "server/tls_util.h"
+
 #ifdef __linux__
 #include <sys/sendfile.h>
 #endif
 
+#ifdef ENABLE_OPENSSL
+#include <openssl/ssl.h>
+
+#include "event2/bufferevent_ssl.h"
+#endif
+
 #include "event_util.h"
 #include "scope_exit.h"
 #include "unique_fd.h"
@@ -194,7 +203,7 @@ StatusOr<int> SockConnect(const std::string &host, uint32_t 
port, int conn_timeo
 // NOTE: fd should be blocking here
 Status SockSend(int fd, const std::string &data) { return Write(fd, data); }
 
-// Implements SockSendFileCore to transfer data between file descriptors and
+// Implements SendFileImpl to transfer data between file descriptors and
 // avoid transferring data to and from user space.
 //
 // The function prototype is just like sendfile(2) on Linux. in_fd is a file
@@ -204,7 +213,7 @@ Status SockSend(int fd, const std::string &data) { return 
Write(fd, data); }
 //
 // The return value is the number of bytes written to out_fd, if the transfer
 // was successful. On error, -1 is returned, and errno is set appropriately.
-ssize_t SockSendFileCore(int out_fd, int in_fd, off_t offset, size_t count) {
+ssize_t SendFileImpl(int out_fd, int in_fd, off_t offset, size_t count) {
 #if defined(__linux__)
   return sendfile(out_fd, in_fd, &offset, count);
 
@@ -215,18 +224,37 @@ ssize_t SockSendFileCore(int out_fd, int in_fd, off_t 
offset, size_t count) {
   else
     return (ssize_t)len;
 
-#endif
+#else
   errno = ENOSYS;
   return -1;
+
+#endif
 }
 
-// Send file by sendfile actually according to different operation systems,
-// please note that, the out socket fd should be in blocking mode.
-Status SockSendFile(int out_fd, int in_fd, size_t size) {
+#ifdef ENABLE_OPENSSL
+// TLS counterpart of SendFileImpl: sends at most 16 KiB of in_fd, starting at `offset`, through `ssl`.
+// Returns the number of bytes written, or -1 with errno set; the caller only treats -1 as an error.
+ssize_t SendFileSSLImpl(ssl_st *ssl, int in_fd, off_t offset, size_t count) {
+  constexpr size_t BUFFER_SIZE = 16 * 1024;
+  char buf[BUFFER_SIZE];
+  // pread does not move the file offset, matching sendfile(2) with an explicit offset.
+  ssize_t nread = pread(in_fd, buf, count <= BUFFER_SIZE ? count : BUFFER_SIZE, offset);
+  if (nread == 0) errno = EIO;  // file shorter than announced; SSL_write with 0 bytes would make no progress
+  if (nread <= 0) return -1;
+  int nwritten = SSL_write(ssl, buf, static_cast<int>(nread));
+  // SSL_write reports failure as <= 0 and may leave errno stale (e.g. EINTR, which the caller retries).
+  if (nwritten <= 0) errno = EIO;
+  return nwritten <= 0 ? -1 : nwritten;
+}
+#endif
+
+template <auto F, typename FD, typename... Args>
+Status SockSendFileImpl(FD out_fd, int in_fd, size_t size, Args... args) {
+  constexpr size_t BUFFER_SIZE = 16 * 1024;
   off_t offset = 0;
   while (size != 0) {
-    size_t n = size <= 16 * 1024 ? size : 16 * 1024;
-    ssize_t nwritten = SockSendFileCore(out_fd, in_fd, offset, n);
+    size_t n = size <= BUFFER_SIZE ? size : BUFFER_SIZE;
+    ssize_t nwritten = F(out_fd, in_fd, offset, n, args...);
     if (nwritten == -1) {
       if (errno == EINTR)
         continue;
@@ -239,6 +267,27 @@ Status SockSendFile(int out_fd, int in_fd, size_t size) {
   return Status::OK();
 }
 
+// Sends a file with the platform-specific sendfile mechanism (see SendFileImpl).
+// Note that the out socket fd should be in blocking mode.
+Status SockSendFile(int out_fd, int in_fd, size_t size) { return 
SockSendFileImpl<SendFileImpl>(out_fd, in_fd, size); }
+
+Status SockSendFile(int out_fd, int in_fd, size_t size, ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    return SockSendFileImpl<SendFileSSLImpl>(ssl, in_fd, size);
+  }
+#endif
+  return SockSendFile(out_fd, in_fd, size);
+}
+
+Status SockSendFile(int out_fd, int in_fd, size_t size, bufferevent *bev) {
+#ifdef ENABLE_OPENSSL
+  return SockSendFile(out_fd, in_fd, size, bufferevent_openssl_get_ssl(bev));
+#else
+  return SockSendFile(out_fd, in_fd, size);
+#endif
+}
+
 Status SockSetBlocking(int fd, int blocking) {
   int flags = 0;
   // Old flags
@@ -384,8 +433,8 @@ std::vector<std::string> GetLocalIPAddresses() {
   return ip_addresses;
 }
 
-template <auto syscall, typename... Args>
-Status WriteImpl(int fd, std::string_view data, Args &&...args) {
+template <auto syscall, typename FD, typename... Args>
+Status WriteImpl(FD fd, std::string_view data, Args &&...args) {
   ssize_t n = 0;
   while (n < static_cast<ssize_t>(data.size())) {
     ssize_t nwritten = syscall(fd, data.data() + n, data.size() - n, 
std::forward<Args>(args)...);
@@ -401,4 +450,68 @@ Status Write(int fd, const std::string &data) { return 
WriteImpl<write>(fd, data
 
 Status Pwrite(int fd, const std::string &data, off_t offset) { return 
WriteImpl<pwrite>(fd, data, offset); }
 
+Status SockSend(int fd, const std::string &data, ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    return WriteImpl<SSL_write>(ssl, data);
+  }
+#endif
+  return SockSend(fd, data);
+}
+
+Status SockSend(int fd, const std::string &data, bufferevent *bev) {
+#ifdef ENABLE_OPENSSL
+  return SockSend(fd, data, bufferevent_openssl_get_ssl(bev));
+#else
+  return SockSend(fd, data);
+#endif
+}
+
+StatusOr<int> SockConnect(const std::string &host, uint32_t port, ssl_st *ssl, int conn_timeout, int timeout) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    auto fd = GET_OR_RET(SockConnect(host, port, conn_timeout, timeout));
+    // BIO_NOCLOSE: closing fd stays the caller's job (e.g. via UniqueFD), never SSL_free's.
+    auto bio = BIO_new_socket(fd, BIO_NOCLOSE);
+    if (!bio) { close(fd); return {Status::NotOK, "failed to create a socket BIO for SSL"}; }
+    SSL_set_bio(ssl, bio, bio);  // ownership of bio moves to ssl; it is released by SSL_free(ssl)
+
+    if (int err = SSL_connect(ssl); err != 1) {
+      // Do not BIO_free(bio): ssl owns it, so freeing it here would double-free in the caller's SSL_free.
+      close(fd);  // fd is never handed to the caller on this path, so close it to avoid a leak
+      return {Status::NotOK, fmt::format("socket failed to do SSL handshake: {}", fmt::streamed(SSLError(err)))};
+    }
+    return fd;
+  }
+#endif
+  return SockConnect(host, port, conn_timeout, timeout);
+}
+
+StatusOr<int> EvbufferRead(evbuffer *buf, evutil_socket_t fd, int howmuch, 
ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    constexpr int BUFFER_SIZE = 4096;
+    char tmp[BUFFER_SIZE];
+
+    if (howmuch <= 0 || howmuch > BUFFER_SIZE) {
+      howmuch = BUFFER_SIZE;
+    }
+    if (howmuch = SSL_read(ssl, tmp, howmuch); howmuch <= 0) {
+      return {Status::NotOK, fmt::format("failed to read from SSL connection: 
{}", fmt::streamed(SSLError(howmuch)))};
+    }
+
+    if (int ret = evbuffer_add(buf, tmp, howmuch); ret == -1) {
+      return {Status::NotOK, fmt::format("failed to add buffer: {}", 
strerror(errno))};
+    }
+
+    return howmuch;
+  }
+#endif
+  if (int ret = evbuffer_read(buf, fd, howmuch); ret > 0) {
+    return ret;
+  } else {
+    return {Status::NotOK, fmt::format("failed to read from socket: {}", 
strerror(errno))};
+  }
+}
+
 }  // namespace util
diff --git a/src/common/io_util.h b/src/common/io_util.h
index 9ab288ae..d30789ae 100644
--- a/src/common/io_util.h
+++ b/src/common/io_util.h
@@ -24,6 +24,11 @@
 
 #include "status.h"
 
+// forward declarations
+struct ssl_st;
+struct bufferevent;
+struct evbuffer;
+
 namespace util {
 
 sockaddr_in NewSockaddrInet(const std::string &host, uint32_t port);
@@ -46,4 +51,13 @@ int AeWait(int fd, int mask, int milliseconds);
 Status Write(int fd, const std::string &data);
 Status Pwrite(int fd, const std::string &data, off_t offset);
 
+Status SockSend(int fd, const std::string &data, ssl_st *ssl);
+Status SockSend(int fd, const std::string &data, bufferevent *bev);
+
+Status SockSendFile(int out_fd, int in_fd, size_t size, ssl_st *ssl);
+Status SockSendFile(int out_fd, int in_fd, size_t size, bufferevent *bev);
+
+StatusOr<int> SockConnect(const std::string &host, uint32_t port, ssl_st *ssl, 
int conn_timeout = 0, int timeout = 0);
+StatusOr<int> EvbufferRead(evbuffer *buf, int fd, int howmuch, ssl_st *ssl);
+
 }  // namespace util
diff --git a/src/config/config.cc b/src/config/config.cc
index e051b5e4..e12edc38 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -103,6 +103,7 @@ Config::Config() {
       {"tls-session-caching", false, new YesNoField(&tls_session_caching, 
true)},
       {"tls-session-cache-size", false, new IntField(&tls_session_cache_size, 
1024 * 20, 0, INT_MAX)},
       {"tls-session-cache-timeout", false, new 
IntField(&tls_session_cache_timeout, 300, 0, INT_MAX)},
+      {"tls-replication", true, new YesNoField(&tls_replication, false)},
 #endif
       {"workers", true, new IntField(&workers, 8, 1, 256)},
       {"timeout", false, new IntField(&timeout, 0, 0, INT_MAX)},
diff --git a/src/config/config.h b/src/config/config.h
index dba87c94..b93496b7 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -74,6 +74,7 @@ struct Config {
   Config();
   ~Config() = default;
   uint32_t port = 0;
+
   uint32_t tls_port = 0;
   std::string tls_cert_file;
   std::string tls_key_file;
@@ -88,6 +89,8 @@ struct Config {
   bool tls_session_caching = true;
   int tls_session_cache_size = 1024 * 20;
   int tls_session_cache_timeout = 300;
+  bool tls_replication = false;
+
   int workers = 0;
   int timeout = 0;
   int log_level = 0;
diff --git a/src/main.cc b/src/main.cc
index 1a2cefc5..a04f8aa8 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -355,7 +355,7 @@ int main(int argc, char *argv[]) {
 
 #ifdef ENABLE_OPENSSL
   // initialize OpenSSL
-  if (config.tls_port) {
+  if (config.tls_port || config.tls_replication) {
     InitSSL();
   }
 #endif
diff --git a/src/server/server.cc b/src/server/server.cc
index 254f3bd5..4d49bd5f 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -66,7 +66,7 @@ Server::Server(engine::Storage *storage, Config *config)
 
 #ifdef ENABLE_OPENSSL
   // init ssl context
-  if (config->tls_port) {
+  if (config->tls_port || config->tls_replication) {
     ssl_ctx = CreateSSLContext(config);
     if (!ssl_ctx) {
       exit(1);
diff --git a/src/server/worker.cc b/src/server/worker.cc
index 3383c407..9506d566 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -129,8 +129,8 @@ void Worker::newTCPConnection(evconnlistener *listener, 
evutil_socket_t fd, sock
       BEV_OPT_THREADSAFE | BEV_OPT_DEFER_CALLBACKS | BEV_OPT_UNLOCK_CALLBACKS 
| BEV_OPT_CLOSE_ON_FREE;
 
   bufferevent *bev = nullptr;
+  ssl_st *ssl = nullptr;
 #ifdef ENABLE_OPENSSL
-  SSL *ssl = nullptr;
   if (uint32_t(local_port) == svr->GetConfig()->tls_port) {
     ssl = SSL_new(svr->ssl_ctx.get());
     if (!ssl) {
@@ -168,7 +168,7 @@ void Worker::newTCPConnection(evconnlistener *listener, 
evutil_socket_t fd, sock
   s = AddConnection(conn);
   if (!s.IsOK()) {
     std::string err_msg = redis::Error("ERR " + s.Msg());
-    s = util::SockSend(fd, err_msg);
+    s = util::SockSend(fd, err_msg, ssl);
     if (!s.IsOK()) {
       LOG(WARNING) << "Failed to send error response to socket: " << s.Msg();
     }
diff --git a/tests/gocase/tls/tls_test.go b/tests/gocase/tls/tls_test.go
index 77df2202..460f53c4 100644
--- a/tests/gocase/tls/tls_test.go
+++ b/tests/gocase/tls/tls_test.go
@@ -22,7 +22,9 @@ package tls
 import (
        "context"
        "crypto/tls"
+       "fmt"
        "testing"
+       "time"
 
        "github.com/apache/kvrocks/tests/gocase/util"
        "github.com/redis/go-redis/v9"
@@ -136,3 +138,58 @@ func TestTLS(t *testing.T) {
                require.NoError(t, rdb.ConfigSet(ctx, "tls-ciphers", 
"DEFAULT").Err())
        })
 }
+
+func TestTLSReplica(t *testing.T) {
+       if !util.TLSEnable() {
+               t.Skip("TLS tests run only if tls enabled.")
+       }
+
+       ctx := context.Background()
+
+       srv := util.StartTLSServer(t, map[string]string{})
+       defer srv.Close()
+
+       defaultTLSConfig, err := util.DefaultTLSConfig()
+       require.NoError(t, err)
+
+       sc := srv.NewClientWithOption(&redis.Options{TLSConfig: 
defaultTLSConfig, Addr: srv.TLSAddr()})
+       defer func() { require.NoError(t, sc.Close()) }()
+
+       replica := util.StartTLSServer(t, map[string]string{
+               "tls-replication": "yes",
+               "slaveof":         fmt.Sprintf("%s %d", srv.Host(), 
srv.TLSPort()),
+       })
+       defer replica.Close()
+
+       rc := replica.NewClientWithOption(&redis.Options{TLSConfig: 
defaultTLSConfig, Addr: replica.TLSAddr()})
+       defer func() { require.NoError(t, rc.Close()) }()
+
+       t.Run("TLS: Replication (incremental)", func(t *testing.T) {
+               time.Sleep(1000 * time.Millisecond)
+               require.Equal(t, rc.Get(ctx, "a").Val(), "")
+               require.Equal(t, rc.Get(ctx, "b").Val(), "")
+               require.NoError(t, sc.Set(ctx, "a", "1", 0).Err())
+               require.NoError(t, sc.Set(ctx, "b", "2", 0).Err())
+               util.WaitForOffsetSync(t, sc, rc)
+               require.Equal(t, rc.Get(ctx, "a").Val(), "1")
+               require.Equal(t, rc.Get(ctx, "b").Val(), "2")
+       })
+
+       require.NoError(t, sc.Set(ctx, "c", "3", 0).Err())
+
+       replica2 := util.StartTLSServer(t, map[string]string{
+               "tls-replication": "yes",
+               "slaveof":         fmt.Sprintf("%s %d", srv.Host(), 
srv.TLSPort()),
+       })
+       defer replica2.Close()
+
+       rc2 := replica2.NewClientWithOption(&redis.Options{TLSConfig: 
defaultTLSConfig, Addr: replica2.TLSAddr()})
+       defer func() { require.NoError(t, rc2.Close()) }()
+
+       t.Run("TLS: Replication (full)", func(t *testing.T) {
+               util.WaitForOffsetSync(t, sc, rc2)
+               require.Equal(t, rc2.Get(ctx, "a").Val(), "1")
+               require.Equal(t, rc2.Get(ctx, "b").Val(), "2")
+               require.Equal(t, rc2.Get(ctx, "c").Val(), "3")
+       })
+}
diff --git a/tests/gocase/util/server.go b/tests/gocase/util/server.go
index 995d0e6b..5d190750 100644
--- a/tests/gocase/util/server.go
+++ b/tests/gocase/util/server.go
@@ -63,6 +63,10 @@ func (s *KvrocksServer) Port() uint64 {
        return uint64(s.addr.AddrPort().Port())
 }
 
+func (s *KvrocksServer) TLSPort() uint64 {
+       return uint64(s.tlsAddr.AddrPort().Port())
+}
+
 func (s *KvrocksServer) TLSAddr() string {
        return s.tlsAddr.String()
 }

Reply via email to