This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new f3d796dd Support TLS for replication (#1630)
f3d796dd is described below
commit f3d796dd4a25de4e8a0e3bb14bbd1043aed5c9c6
Author: Twice <[email protected]>
AuthorDate: Tue Sep 12 13:42:27 2023 +0900
Support TLS for replication (#1630)
---
kvrocks.conf | 7 +++
src/cluster/replication.cc | 74 +++++++++++++++++-----
src/cluster/replication.h | 7 ++-
src/commands/cmd_replication.cc | 10 +--
src/common/io_util.cc | 133 +++++++++++++++++++++++++++++++++++++---
src/common/io_util.h | 14 +++++
src/config/config.cc | 1 +
src/config/config.h | 3 +
src/main.cc | 2 +-
src/server/server.cc | 2 +-
src/server/worker.cc | 4 +-
tests/gocase/tls/tls_test.go | 57 +++++++++++++++++
tests/gocase/util/server.go | 4 ++
13 files changed, 279 insertions(+), 39 deletions(-)
diff --git a/kvrocks.conf b/kvrocks.conf
index 6a015319..d2b027cb 100644
--- a/kvrocks.conf
+++ b/kvrocks.conf
@@ -381,6 +381,13 @@ redis-cursor-compatible no
#
# tls-session-cache-timeout 60
+# By default, a replica does not attempt to establish a TLS connection
+# with its master.
+#
+# Use the following directive to enable TLS on replication links. When it is
+# enabled, point the replica (e.g. via slaveof) at the master's tls-port.
+#
+# tls-replication yes
+
################################## SLOW LOG ###################################
# The Kvrocks Slow Log is a mechanism to log queries that exceeded a specified
diff --git a/src/cluster/replication.cc b/src/cluster/replication.cc
index d6d8d281..38548418 100644
--- a/src/cluster/replication.cc
+++ b/src/cluster/replication.cc
@@ -37,6 +37,7 @@
#include "fmt/format.h"
#include "io_util.h"
#include "rocksdb_crc32c.h"
+#include "scope_exit.h"
#include "server/redis_reply.h"
#include "server/server.h"
#include "status.h"
@@ -45,6 +46,12 @@
#include "time_util.h"
#include "unique_fd.h"
+#ifdef ENABLE_OPENSSL
+#include <event2/bufferevent_ssl.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#endif
+
Status FeedSlaveThread::Start() {
auto s = util::CreateThread("feed-replica", [this] {
sigset_t mask, omask;
@@ -54,7 +61,7 @@ Status FeedSlaveThread::Start() {
sigaddset(&mask, SIGHUP);
sigaddset(&mask, SIGPIPE);
pthread_sigmask(SIG_BLOCK, &mask, &omask);
- auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"));
+ auto s = util::SockSend(conn_->GetFD(), redis::SimpleString("OK"),
conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "failed to send OK response to the replica: " << s.Msg();
return;
@@ -85,7 +92,7 @@ void FeedSlaveThread::Join() {
void FeedSlaveThread::checkLivenessIfNeed() {
if (++interval_ % 1000) return;
const auto ping_command = redis::BulkString("ping");
- auto s = util::SockSend(conn_->GetFD(), ping_command);
+ auto s = util::SockSend(conn_->GetFD(), ping_command,
conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "Ping slave[" << conn_->GetAddr() << "] err: " << s.Msg() <<
", would stop the thread";
Stop();
@@ -134,7 +141,7 @@ void FeedSlaveThread::loop() {
if (is_first_repl_batch || batches_bulk.size() >= kMaxDelayBytes ||
updates_in_batches >= kMaxDelayUpdates ||
srv_->storage->LatestSeqNumber() - batch.sequence <= kMaxDelayUpdates)
{
// Send entire bulk which contain multiple batches
- auto s = util::SockSend(conn_->GetFD(), batches_bulk);
+ auto s = util::SockSend(conn_->GetFD(), batches_bulk,
conn_->GetBufferEvent());
if (!s.IsOK()) {
LOG(ERROR) << "Write error while sending batch to slave: " << s.Msg()
<< ". batches: 0x"
<< util::StringToHex(batches_bulk);
@@ -257,12 +264,35 @@ void ReplicationThread::CallbacksStateMachine::Start() {
LOG(ERROR) << "[replication] Failed to connect the master, err: " <<
cfd.Msg();
continue;
}
+#ifdef ENABLE_OPENSSL
+ SSL *ssl = nullptr;
+ if (repl_->srv_->GetConfig()->tls_replication) {
+ ssl = SSL_new(repl_->srv_->ssl_ctx.get());
+ if (!ssl) {
+ LOG(ERROR) << "Failed to construct SSL structure for new connection: "
<< SSLErrors{};
+ evutil_closesocket(*cfd);
+ return;
+ }
+ bev = bufferevent_openssl_socket_new(repl_->base_, *cfd, ssl,
BUFFEREVENT_SSL_CONNECTING, BEV_OPT_CLOSE_ON_FREE);
+ } else {
+ bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
+ }
+#else
bev = bufferevent_socket_new(repl_->base_, *cfd, BEV_OPT_CLOSE_ON_FREE);
+#endif
if (bev == nullptr) {
+#ifdef ENABLE_OPENSSL
+ if (ssl) SSL_free(ssl);
+#endif
close(*cfd);
LOG(ERROR) << "[replication] Failed to create the event socket";
continue;
}
+#ifdef ENABLE_OPENSSL
+ if (repl_->srv_->GetConfig()->tls_replication) {
+ bufferevent_openssl_set_allow_dirty_shutdown(bev, 1);
+ }
+#endif
}
if (bev == nullptr) { // failed to connect the master and received the stop
signal
return;
@@ -728,9 +758,19 @@ Status ReplicationThread::parallelFetchFile(const
std::string &dir,
if (this->stop_flag_) {
return {Status::NotOK, "replication thread was stopped"};
}
- int sock_fd = GET_OR_RET(util::SockConnect(this->host_,
this->port_).Prefixed("connect the server err"));
+ ssl_st *ssl = nullptr;
+#ifdef ENABLE_OPENSSL
+ if (this->srv_->GetConfig()->tls_replication) {
+ ssl = SSL_new(this->srv_->ssl_ctx.get());
+ }
+ auto exit = MakeScopeExit([ssl] { SSL_free(ssl); });
+#endif
+ int sock_fd = GET_OR_RET(util::SockConnect(this->host_, this->port_,
ssl).Prefixed("connect the server err"));
+#ifdef ENABLE_OPENSSL
+ exit.Disable();
+#endif
UniqueFD unique_fd{sock_fd};
- auto s = this->sendAuth(sock_fd);
+ auto s = this->sendAuth(sock_fd, ssl);
if (!s.IsOK()) {
return s.Prefixed("send the auth command err");
}
@@ -770,12 +810,12 @@ Status ReplicationThread::parallelFetchFile(const
std::string &dir,
// command, so we need to fetch all files by multiple command
interactions.
if (srv_->GetConfig()->master_use_repl_port) {
for (unsigned i = 0; i < fetch_files.size(); i++) {
- s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]},
fn);
+ s = this->fetchFiles(sock_fd, dir, {fetch_files[i]}, {crcs[i]},
fn, ssl);
if (!s.IsOK()) break;
}
} else {
if (!fetch_files.empty()) {
- s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn);
+ s = this->fetchFiles(sock_fd, dir, fetch_files, crcs, fn, ssl);
}
}
return s;
@@ -790,13 +830,13 @@ Status ReplicationThread::parallelFetchFile(const
std::string &dir,
return Status::OK();
}
-Status ReplicationThread::sendAuth(int sock_fd) {
+Status ReplicationThread::sendAuth(int sock_fd, ssl_st *ssl) {
// Send auth when needed
std::string auth = srv_->GetConfig()->masterauth;
if (!auth.empty()) {
UniqueEvbuf evbuf;
const auto auth_command = redis::MultiBulkString({"AUTH", auth});
- auto s = util::SockSend(sock_fd, auth_command);
+ auto s = util::SockSend(sock_fd, auth_command, ssl);
if (!s.IsOK()) return s.Prefixed("send auth command err");
while (true) {
if (evbuffer_read(evbuf.get(), sock_fd, -1) <= 0) {
@@ -814,15 +854,15 @@ Status ReplicationThread::sendAuth(int sock_fd) {
}
Status ReplicationThread::fetchFile(int sock_fd, evbuffer *evbuf, const
std::string &dir, const std::string &file,
- uint32_t crc, const FetchFileCallback &fn)
{
+ uint32_t crc, const FetchFileCallback &fn,
ssl_st *ssl) {
size_t file_size = 0;
// Read file size line
while (true) {
UniqueEvbufReadln line(evbuf, EVBUFFER_EOL_CRLF_STRICT);
if (!line) {
- if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
- return {Status::NotOK, fmt::format("read size: {}", strerror(errno))};
+ if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
+ return std::move(s).Prefixed("read size");
}
continue;
}
@@ -854,8 +894,8 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer
*evbuf, const std::str
tmp_crc = rocksdb::crc32c::Extend(tmp_crc, data, data_len);
remain -= data_len;
} else {
- if (evbuffer_read(evbuf, sock_fd, -1) <= 0) {
- return {Status::NotOK, fmt::format("read sst file: {}",
strerror(errno))};
+ if (auto s = util::EvbufferRead(evbuf, sock_fd, -1, ssl); !s) {
+ return std::move(s).Prefixed("read sst file");
}
}
}
@@ -873,7 +913,7 @@ Status ReplicationThread::fetchFile(int sock_fd, evbuffer
*evbuf, const std::str
}
Status ReplicationThread::fetchFiles(int sock_fd, const std::string &dir,
const std::vector<std::string> &files,
- const std::vector<uint32_t> &crcs, const
FetchFileCallback &fn) {
+ const std::vector<uint32_t> &crcs, const
FetchFileCallback &fn, ssl_st *ssl) {
std::string files_str;
for (const auto &file : files) {
files_str += file;
@@ -882,13 +922,13 @@ Status ReplicationThread::fetchFiles(int sock_fd, const
std::string &dir, const
files_str.pop_back();
const auto fetch_command = redis::MultiBulkString({"_fetch_file",
files_str});
- auto s = util::SockSend(sock_fd, fetch_command);
+ auto s = util::SockSend(sock_fd, fetch_command, ssl);
if (!s.IsOK()) return s.Prefixed("send fetch file command");
UniqueEvbuf evbuf;
for (unsigned i = 0; i < files.size(); i++) {
DLOG(INFO) << "[fetch] Start to fetch file " << files[i];
- s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn);
+ s = fetchFile(sock_fd, evbuf.get(), dir, files[i], crcs[i], fn, ssl);
if (!s.IsOK()) {
s = Status(Status::NotOK, "fetch file err: " + s.Msg());
LOG(WARNING) << "[fetch] Fail to fetch file " << files[i] << ", err: "
<< s.Msg();
diff --git a/src/cluster/replication.h b/src/cluster/replication.h
index 6bf5954b..b7f49717 100644
--- a/src/cluster/replication.h
+++ b/src/cluster/replication.h
@@ -32,6 +32,7 @@
#include <vector>
#include "event_util.h"
+#include "io_util.h"
#include "server/redis_connection.h"
#include "status.h"
#include "storage/storage.h"
@@ -197,11 +198,11 @@ class ReplicationThread : private
EventCallbackBase<ReplicationThread> {
CBState fullSyncReadCB(bufferevent *bev);
// Synchronized-Blocking ops
- Status sendAuth(int sock_fd);
+ Status sendAuth(int sock_fd, ssl_st *ssl);
Status fetchFile(int sock_fd, evbuffer *evbuf, const std::string &dir, const
std::string &file, uint32_t crc,
- const FetchFileCallback &fn);
+ const FetchFileCallback &fn, ssl_st *ssl);
Status fetchFiles(int sock_fd, const std::string &dir, const
std::vector<std::string> &files,
- const std::vector<uint32_t> &crcs, const FetchFileCallback
&fn);
+ const std::vector<uint32_t> &crcs, const FetchFileCallback
&fn, ssl_st *ssl);
Status parallelFetchFile(const std::string &dir, const
std::vector<std::pair<std::string, uint32_t>> &files);
static bool isRestoringError(const char *err);
static bool isWrongPsyncNum(const char *err);
diff --git a/src/commands/cmd_replication.cc b/src/commands/cmd_replication.cc
index 8ccdfc19..23ae269a 100644
--- a/src/commands/cmd_replication.cc
+++ b/src/commands/cmd_replication.cc
@@ -102,7 +102,7 @@ class CommandPSync : public Commander {
s = svr->AddSlave(conn, next_repl_seq_);
if (!s.IsOK()) {
std::string err = "-ERR " + s.Msg() + "\r\n";
- s = util::SockSend(conn->GetFD(), err);
+ s = util::SockSend(conn->GetFD(), err, conn->GetBufferEvent());
if (!s.IsOK()) {
LOG(WARNING) << "failed to send error message to the replica: " <<
s.Msg();
}
@@ -229,7 +229,7 @@ class CommandFetchMeta : public Commander {
std::string files;
auto s =
engine::Storage::ReplDataManager::GetFullReplDataInfo(svr->storage, &files);
if (!s.IsOK()) {
- s = util::SockSend(repl_fd, "-ERR can't create db checkpoint");
+ s = util::SockSend(repl_fd, "-ERR can't create db checkpoint", bev);
if (!s.IsOK()) {
LOG(WARNING) << "[replication] Failed to send error response: " <<
s.Msg();
}
@@ -237,7 +237,7 @@ class CommandFetchMeta : public Commander {
return;
}
// Send full data file info
- if (util::SockSend(repl_fd, files + CRLF).IsOK()) {
+ if (util::SockSend(repl_fd, files + CRLF, bev).IsOK()) {
LOG(INFO) << "[replication] Succeed sending full data file info to "
<< ip;
} else {
LOG(WARNING) << "[replication] Fail to send full data file info " <<
ip << ", error: " << strerror(errno);
@@ -291,8 +291,8 @@ class CommandFetchFile : public Commander {
if (!fd) break;
// Send file size and content
- if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF).IsOK() &&
- util::SockSendFile(repl_fd, *fd, file_size).IsOK()) {
+ if (util::SockSend(repl_fd, std::to_string(file_size) + CRLF,
bev).IsOK() &&
+ util::SockSendFile(repl_fd, *fd, file_size, bev).IsOK()) {
LOG(INFO) << "[replication] Succeed sending file " << file << " to "
<< ip;
} else {
LOG(WARNING) << "[replication] Fail to send file " << file << " to "
<< ip << ", error: " << strerror(errno);
diff --git a/src/common/io_util.cc b/src/common/io_util.cc
index b4779f86..2c273580 100644
--- a/src/common/io_util.cc
+++ b/src/common/io_util.cc
@@ -29,10 +29,19 @@
#include <poll.h>
#include <sys/types.h>
+#include "fmt/ostream.h"
+#include "server/tls_util.h"
+
#ifdef __linux__
#include <sys/sendfile.h>
#endif
+#ifdef ENABLE_OPENSSL
+#include <openssl/ssl.h>
+
+#include "event2/bufferevent_ssl.h"
+#endif
+
#include "event_util.h"
#include "scope_exit.h"
#include "unique_fd.h"
@@ -194,7 +203,7 @@ StatusOr<int> SockConnect(const std::string &host, uint32_t
port, int conn_timeo
// Sends all of `data` on `fd` by delegating to Write(), which keeps writing until every byte is out.
// NOTE: fd should be blocking here
Status SockSend(int fd, const std::string &data) { return Write(fd, data); }
-// Implements SockSendFileCore to transfer data between file descriptors and
+// Implements SendFileImpl to transfer data between file descriptors and
// avoid transferring data to and from user space.
//
// The function prototype is just like sendfile(2) on Linux. in_fd is a file
@@ -204,7 +213,7 @@ Status SockSend(int fd, const std::string &data) { return
Write(fd, data); }
//
// The return value is the number of bytes written to out_fd, if the transfer
// was successful. On error, -1 is returned, and errno is set appropriately.
-ssize_t SockSendFileCore(int out_fd, int in_fd, off_t offset, size_t count) {
+ssize_t SendFileImpl(int out_fd, int in_fd, off_t offset, size_t count) {
#if defined(__linux__)
return sendfile(out_fd, in_fd, &offset, count);
@@ -215,18 +224,37 @@ ssize_t SockSendFileCore(int out_fd, int in_fd, off_t
offset, size_t count) {
else
return (ssize_t)len;
-#endif
+#else
errno = ENOSYS;
return -1;
+
+#endif
}
-// Send file by sendfile actually according to different operation systems,
-// please note that, the out socket fd should be in blocking mode.
-Status SockSendFile(int out_fd, int in_fd, size_t size) {
+#ifdef ENABLE_OPENSSL
+// TLS counterpart of SendFileImpl. sendfile(2) cannot be used here because the payload must be encrypted in
+// user space. Instead, read up to 16 KiB of `in_fd` starting at `offset` and push it through SSL_write().
+//
+// Follows the SendFileImpl contract: returns the number of bytes written, or -1 with errno set on failure.
+ssize_t SendFileSSLImpl(ssl_st *ssl, int in_fd, off_t offset, size_t count) {
+  constexpr size_t BUFFER_SIZE = 16 * 1024;
+  char buf[BUFFER_SIZE];
+  if (lseek(in_fd, offset, SEEK_SET) == -1) {
+    return -1;
+  }
+  count = count <= BUFFER_SIZE ? count : BUFFER_SIZE;
+  ssize_t nread = read(in_fd, buf, count);
+  if (nread == -1) {
+    return -1;
+  }
+  if (nread == 0) {
+    // The file is shorter than the caller asked for. Passing a zero length to SSL_write() would yield 0,
+    // which SockSendFileImpl's `nwritten == -1` check does not treat as an error.
+    errno = EIO;
+    return -1;
+  }
+  int nwritten = SSL_write(ssl, buf, static_cast<int>(nread));
+  if (nwritten <= 0) {
+    // SSL_write() reports failure with any value <= 0, and errno is only meaningful for SSL_ERROR_SYSCALL.
+    // Normalise to -1 with a usable errno. EINTR is remapped too: the caller retries on EINTR, but a TLS
+    // session must not be used again after a fatal SSL_write() error.
+    if (SSL_get_error(ssl, nwritten) != SSL_ERROR_SYSCALL || errno == 0 || errno == EINTR) {
+      errno = EIO;
+    }
+    return -1;
+  }
+  return nwritten;
+}
+#endif
+
+template <auto F, typename FD, typename... Args>
+Status SockSendFileImpl(FD out_fd, int in_fd, size_t size, Args... args) {
+ constexpr size_t BUFFER_SIZE = 16 * 1024;
off_t offset = 0;
while (size != 0) {
- size_t n = size <= 16 * 1024 ? size : 16 * 1024;
- ssize_t nwritten = SockSendFileCore(out_fd, in_fd, offset, n);
+ size_t n = size <= BUFFER_SIZE ? size : BUFFER_SIZE;
+ ssize_t nwritten = F(out_fd, in_fd, offset, n, args...);
if (nwritten == -1) {
if (errno == EINTR)
continue;
@@ -239,6 +267,27 @@ Status SockSendFile(int out_fd, int in_fd, size_t size) {
return Status::OK();
}
+// Sends `size` bytes of `in_fd`, starting at offset 0, to `out_fd` using the platform's sendfile-style
+// primitive (see SendFileImpl), so the data never passes through user space.
+// The out socket fd must be in blocking mode.
+Status SockSendFile(int out_fd, int in_fd, size_t size) {
+  return SockSendFileImpl<SendFileImpl>(out_fd, in_fd, size);
+}
+
+// Like SockSendFile(out_fd, in_fd, size), but when `ssl` is non-null (and the build has ENABLE_OPENSSL) the
+// file is sent encrypted through that TLS session instead of via sendfile on `out_fd`.
+Status SockSendFile(int out_fd, int in_fd, size_t size, ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl != nullptr) return SockSendFileImpl<SendFileSSLImpl>(ssl, in_fd, size);
+#endif
+  return SockSendFile(out_fd, in_fd, size);
+}
+
+// Overload for connections driven by a bufferevent. The TLS session, if any, is taken from `bev`.
+// bufferevent_openssl_get_ssl() yields nullptr for a non-OpenSSL bufferevent, which selects the plain path.
+Status SockSendFile(int out_fd, int in_fd, size_t size, bufferevent *bev) {
+#ifndef ENABLE_OPENSSL
+  return SockSendFile(out_fd, in_fd, size);
+#else
+  ssl_st *session = bufferevent_openssl_get_ssl(bev);
+  return SockSendFile(out_fd, in_fd, size, session);
+#endif
+}
+
Status SockSetBlocking(int fd, int blocking) {
int flags = 0;
// Old flags
@@ -384,8 +433,8 @@ std::vector<std::string> GetLocalIPAddresses() {
return ip_addresses;
}
-template <auto syscall, typename... Args>
-Status WriteImpl(int fd, std::string_view data, Args &&...args) {
+template <auto syscall, typename FD, typename... Args>
+Status WriteImpl(FD fd, std::string_view data, Args &&...args) {
ssize_t n = 0;
while (n < static_cast<ssize_t>(data.size())) {
ssize_t nwritten = syscall(fd, data.data() + n, data.size() - n,
std::forward<Args>(args)...);
@@ -401,4 +450,68 @@ Status Write(int fd, const std::string &data) { return
WriteImpl<write>(fd, data
// Writes `data` to `fd` at file offset `offset` by running WriteImpl with pwrite as the syscall.
// NOTE(review): WriteImpl forwards `offset` unchanged on every loop iteration. After a short pwrite() the
// remaining bytes would therefore be written at `offset` again rather than at `offset + n`. Confirm and
// consider advancing the offset.
Status Pwrite(int fd, const std::string &data, off_t offset) { return
WriteImpl<pwrite>(fd, data, offset); }
+// Sends all of `data`. When `ssl` is non-null (and the build has ENABLE_OPENSSL) the bytes go through that
+// TLS session; otherwise they are written directly to the blocking socket `fd`.
+Status SockSend(int fd, const std::string &data, ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl != nullptr) return WriteImpl<SSL_write>(ssl, data);
+#endif
+  return SockSend(fd, data);
+}
+
+// Sends all of `data` for a connection driven by the bufferevent `bev`. If `bev` wraps a TLS session the
+// data goes through it; otherwise it is written straight to the socket `fd`.
+Status SockSend(int fd, const std::string &data, bufferevent *bev) {
+#ifndef ENABLE_OPENSSL
+  return SockSend(fd, data);
+#else
+  ssl_st *session = bufferevent_openssl_get_ssl(bev);
+  return SockSend(fd, data, session);
+#endif
+}
+
+// Connects to host:port exactly like SockConnect(host, port, conn_timeout, timeout). When `ssl` is
+// non-null it also performs a client-side TLS handshake over the new socket.
+//
+// On success the returned fd is attached to `ssl` but is not owned by it: the caller still closes the fd
+// and frees `ssl`. On any TLS setup failure the fd is closed here. `ssl` is always left for the caller
+// to free.
+StatusOr<int> SockConnect(const std::string &host, uint32_t port, ssl_st *ssl, int conn_timeout, int timeout) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    auto fd = GET_OR_RET(SockConnect(host, port, conn_timeout, timeout));
+
+    // SSL_set_fd() wraps `fd` in a socket BIO created with BIO_NOCLOSE and hands that BIO to `ssl`.
+    // The BIO is therefore released by SSL_free() and must never be freed separately; doing so would
+    // double-free it when the caller frees `ssl`.
+    if (SSL_set_fd(ssl, fd) != 1) {
+      close(fd);
+      return {Status::NotOK, "failed to attach socket to the SSL structure"};
+    }
+
+    if (int err = SSL_connect(ssl); err != 1) {
+      // The caller never receives the fd on this path, so close it here to avoid leaking it.
+      close(fd);
+      return {Status::NotOK, fmt::format("socket failed to do SSL handshake: {}", fmt::streamed(SSLError(err)))};
+    }
+
+    return fd;
+  }
+#endif
+  return SockConnect(host, port, conn_timeout, timeout);
+}
+
+// Reads from a connection and appends the bytes to `buf`.
+// - With a non-null `ssl` (ENABLE_OPENSSL builds), reads through the TLS session, at most 4 KiB per call.
+//   A `howmuch` that is non-positive or larger than that is clamped to 4 KiB.
+// - Otherwise, uses evbuffer_read() on `fd` with `howmuch` passed through unchanged.
+//
+// Returns the number of bytes appended. End-of-stream and read failures are both reported as errors.
+StatusOr<int> EvbufferRead(evbuffer *buf, evutil_socket_t fd, int howmuch, ssl_st *ssl) {
+#ifdef ENABLE_OPENSSL
+  if (ssl) {
+    constexpr int BUFFER_SIZE = 4096;
+    char tmp[BUFFER_SIZE];
+
+    if (howmuch <= 0 || howmuch > BUFFER_SIZE) {
+      howmuch = BUFFER_SIZE;
+    }
+    int nread = SSL_read(ssl, tmp, howmuch);
+    if (nread <= 0) {
+      return {Status::NotOK, fmt::format("failed to read from SSL connection: {}", fmt::streamed(SSLError(nread)))};
+    }
+
+    if (int ret = evbuffer_add(buf, tmp, nread); ret == -1) {
+      return {Status::NotOK, fmt::format("failed to add buffer: {}", strerror(errno))};
+    }
+
+    return nread;
+  }
+#endif
+  int nread = evbuffer_read(buf, fd, howmuch);
+  if (nread > 0) {
+    return nread;
+  }
+  if (nread == 0) {
+    // evbuffer_read() returns 0 on a clean end-of-stream without setting errno, so strerror(errno) would
+    // print an unrelated, stale message here.
+    return {Status::NotOK, "failed to read from socket: connection closed by peer"};
+  }
+  return {Status::NotOK, fmt::format("failed to read from socket: {}", strerror(errno))};
+}
+
} // namespace util
diff --git a/src/common/io_util.h b/src/common/io_util.h
index 9ab288ae..d30789ae 100644
--- a/src/common/io_util.h
+++ b/src/common/io_util.h
@@ -24,6 +24,11 @@
#include "status.h"
+// forward declarations
+struct ssl_st;
+struct bufferevent;
+struct evbuffer;
+
namespace util {
sockaddr_in NewSockaddrInet(const std::string &host, uint32_t port);
@@ -46,4 +51,13 @@ int AeWait(int fd, int mask, int milliseconds);
Status Write(int fd, const std::string &data);
Status Pwrite(int fd, const std::string &data, off_t offset);
+Status SockSend(int fd, const std::string &data, ssl_st *ssl);
+Status SockSend(int fd, const std::string &data, bufferevent *bev);
+
+Status SockSendFile(int out_fd, int in_fd, size_t size, ssl_st *ssl);
+Status SockSendFile(int out_fd, int in_fd, size_t size, bufferevent *bev);
+
+StatusOr<int> SockConnect(const std::string &host, uint32_t port, ssl_st *ssl,
int conn_timeout = 0, int timeout = 0);
+StatusOr<int> EvbufferRead(evbuffer *buf, int fd, int howmuch, ssl_st *ssl);
+
} // namespace util
diff --git a/src/config/config.cc b/src/config/config.cc
index e051b5e4..e12edc38 100644
--- a/src/config/config.cc
+++ b/src/config/config.cc
@@ -103,6 +103,7 @@ Config::Config() {
{"tls-session-caching", false, new YesNoField(&tls_session_caching,
true)},
{"tls-session-cache-size", false, new IntField(&tls_session_cache_size,
1024 * 20, 0, INT_MAX)},
{"tls-session-cache-timeout", false, new
IntField(&tls_session_cache_timeout, 300, 0, INT_MAX)},
+ {"tls-replication", true, new YesNoField(&tls_replication, false)},
#endif
{"workers", true, new IntField(&workers, 8, 1, 256)},
{"timeout", false, new IntField(&timeout, 0, 0, INT_MAX)},
diff --git a/src/config/config.h b/src/config/config.h
index dba87c94..b93496b7 100644
--- a/src/config/config.h
+++ b/src/config/config.h
@@ -74,6 +74,7 @@ struct Config {
Config();
~Config() = default;
uint32_t port = 0;
+
uint32_t tls_port = 0;
std::string tls_cert_file;
std::string tls_key_file;
@@ -88,6 +89,8 @@ struct Config {
bool tls_session_caching = true;
int tls_session_cache_size = 1024 * 20;
int tls_session_cache_timeout = 300;
+ bool tls_replication = false;
+
int workers = 0;
int timeout = 0;
int log_level = 0;
diff --git a/src/main.cc b/src/main.cc
index 1a2cefc5..a04f8aa8 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -355,7 +355,7 @@ int main(int argc, char *argv[]) {
#ifdef ENABLE_OPENSSL
// initialize OpenSSL
- if (config.tls_port) {
+ if (config.tls_port || config.tls_replication) {
InitSSL();
}
#endif
diff --git a/src/server/server.cc b/src/server/server.cc
index 254f3bd5..4d49bd5f 100644
--- a/src/server/server.cc
+++ b/src/server/server.cc
@@ -66,7 +66,7 @@ Server::Server(engine::Storage *storage, Config *config)
#ifdef ENABLE_OPENSSL
// init ssl context
- if (config->tls_port) {
+ if (config->tls_port || config->tls_replication) {
ssl_ctx = CreateSSLContext(config);
if (!ssl_ctx) {
exit(1);
diff --git a/src/server/worker.cc b/src/server/worker.cc
index 3383c407..9506d566 100644
--- a/src/server/worker.cc
+++ b/src/server/worker.cc
@@ -129,8 +129,8 @@ void Worker::newTCPConnection(evconnlistener *listener,
evutil_socket_t fd, sock
BEV_OPT_THREADSAFE | BEV_OPT_DEFER_CALLBACKS | BEV_OPT_UNLOCK_CALLBACKS
| BEV_OPT_CLOSE_ON_FREE;
bufferevent *bev = nullptr;
+ ssl_st *ssl = nullptr;
#ifdef ENABLE_OPENSSL
- SSL *ssl = nullptr;
if (uint32_t(local_port) == svr->GetConfig()->tls_port) {
ssl = SSL_new(svr->ssl_ctx.get());
if (!ssl) {
@@ -168,7 +168,7 @@ void Worker::newTCPConnection(evconnlistener *listener,
evutil_socket_t fd, sock
s = AddConnection(conn);
if (!s.IsOK()) {
std::string err_msg = redis::Error("ERR " + s.Msg());
- s = util::SockSend(fd, err_msg);
+ s = util::SockSend(fd, err_msg, ssl);
if (!s.IsOK()) {
LOG(WARNING) << "Failed to send error response to socket: " << s.Msg();
}
diff --git a/tests/gocase/tls/tls_test.go b/tests/gocase/tls/tls_test.go
index 77df2202..460f53c4 100644
--- a/tests/gocase/tls/tls_test.go
+++ b/tests/gocase/tls/tls_test.go
@@ -22,7 +22,9 @@ package tls
import (
"context"
"crypto/tls"
+ "fmt"
"testing"
+ "time"
"github.com/apache/kvrocks/tests/gocase/util"
"github.com/redis/go-redis/v9"
@@ -136,3 +138,58 @@ func TestTLS(t *testing.T) {
require.NoError(t, rdb.ConfigSet(ctx, "tls-ciphers",
"DEFAULT").Err())
})
}
+
+func TestTLSReplica(t *testing.T) {
+	if !util.TLSEnable() {
+		t.Skip("TLS tests run only if tls enabled.")
+	}
+
+	ctx := context.Background()
+
+	srv := util.StartTLSServer(t, map[string]string{})
+	defer srv.Close()
+
+	defaultTLSConfig, err := util.DefaultTLSConfig()
+	require.NoError(t, err)
+
+	sc := srv.NewClientWithOption(&redis.Options{TLSConfig: defaultTLSConfig, Addr: srv.TLSAddr()})
+	defer func() { require.NoError(t, sc.Close()) }()
+
+	// With tls-replication enabled the replica speaks TLS on the replication link.
+	// It is therefore pointed at the master's TLS port rather than its plain port.
+	replica := util.StartTLSServer(t, map[string]string{
+		"tls-replication": "yes",
+		"slaveof":         fmt.Sprintf("%s %d", srv.Host(), srv.TLSPort()),
+	})
+	defer replica.Close()
+
+	rc := replica.NewClientWithOption(&redis.Options{TLSConfig: defaultTLSConfig, Addr: replica.TLSAddr()})
+	defer func() { require.NoError(t, rc.Close()) }()
+
+	// NOTE: testify's require.Equal takes (expected, actual). The assertions below keep that order so that
+	// failure messages label the two values correctly.
+	t.Run("TLS: Replication (incremental)", func(t *testing.T) {
+		// Presumably gives the replica time to finish its initial sync, so the writes below reach it
+		// through the incremental replication stream.
+		time.Sleep(1000 * time.Millisecond)
+		require.Equal(t, "", rc.Get(ctx, "a").Val())
+		require.Equal(t, "", rc.Get(ctx, "b").Val())
+		require.NoError(t, sc.Set(ctx, "a", "1", 0).Err())
+		require.NoError(t, sc.Set(ctx, "b", "2", 0).Err())
+		util.WaitForOffsetSync(t, sc, rc)
+		require.Equal(t, "1", rc.Get(ctx, "a").Val())
+		require.Equal(t, "2", rc.Get(ctx, "b").Val())
+	})
+
+	// Written before replica2 is started, so replica2 must pick it up during its initial sync.
+	require.NoError(t, sc.Set(ctx, "c", "3", 0).Err())
+
+	replica2 := util.StartTLSServer(t, map[string]string{
+		"tls-replication": "yes",
+		"slaveof":         fmt.Sprintf("%s %d", srv.Host(), srv.TLSPort()),
+	})
+	defer replica2.Close()
+
+	rc2 := replica2.NewClientWithOption(&redis.Options{TLSConfig: defaultTLSConfig, Addr: replica2.TLSAddr()})
+	defer func() { require.NoError(t, rc2.Close()) }()
+
+	t.Run("TLS: Replication (full)", func(t *testing.T) {
+		util.WaitForOffsetSync(t, sc, rc2)
+		require.Equal(t, "1", rc2.Get(ctx, "a").Val())
+		require.Equal(t, "2", rc2.Get(ctx, "b").Val())
+		require.Equal(t, "3", rc2.Get(ctx, "c").Val())
+	})
+}
diff --git a/tests/gocase/util/server.go b/tests/gocase/util/server.go
index 995d0e6b..5d190750 100644
--- a/tests/gocase/util/server.go
+++ b/tests/gocase/util/server.go
@@ -63,6 +63,10 @@ func (s *KvrocksServer) Port() uint64 {
return uint64(s.addr.AddrPort().Port())
}
+// TLSPort returns the port number of the server's TLS listener address.
+func (s *KvrocksServer) TLSPort() uint64 {
+	port := s.tlsAddr.AddrPort().Port()
+	return uint64(port)
+}
+
// TLSAddr returns the server's TLS listener address formatted as a string, suitable for redis.Options.Addr.
func (s *KvrocksServer) TLSAddr() string {
return s.tlsAddr.String()
}