Copilot commented on code in PR #3295: URL: https://github.com/apache/kvrocks/pull/3295#discussion_r2623048693
########## src/cluster/cluster_failover.cc: ########## @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "cluster_failover.h" + +#include <unistd.h> + +#include "cluster/cluster.h" +#include "common/io_util.h" +#include "common/time_util.h" +#include "logging.h" +#include "server/redis_reply.h" +#include "server/server.h" + +ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) { + t_ = std::thread([this]() { loop(); }); +} + +ClusterFailover::~ClusterFailover() { + { + std::lock_guard<std::mutex> lock(mutex_); + stop_thread_ = true; + cv_.notify_all(); + } + if (t_.joinable()) t_.join(); +} + +Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) { + std::lock_guard<std::mutex> lock(mutex_); + if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) { + return {Status::NotOK, "Failover is already in progress"}; + } + Review Comment: There is trailing whitespace on this line after the opening brace. This should be removed for consistency with code formatting standards. ```suggestion ``` ########## src/cluster/cluster_failover.cc: ########## @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include "cluster_failover.h" + +#include <unistd.h> + +#include "cluster/cluster.h" +#include "common/io_util.h" +#include "common/time_util.h" +#include "logging.h" +#include "server/redis_reply.h" +#include "server/server.h" + +ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) { + t_ = std::thread([this]() { loop(); }); +} + +ClusterFailover::~ClusterFailover() { + { + std::lock_guard<std::mutex> lock(mutex_); + stop_thread_ = true; + cv_.notify_all(); + } + if (t_.joinable()) t_.join(); +} + +Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) { + std::lock_guard<std::mutex> lock(mutex_); + if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) { + return {Status::NotOK, "Failover is already in progress"}; + } + + if (srv_->IsSlave()) { + return {Status::NotOK, "Current node is a slave, can't failover"}; + } + + slave_node_id_ = std::move(slave_node_id); + timeout_ms_ = timeout_ms; + state_ = FailoverState::kStarted; + failover_job_triggered_ = true; + cv_.notify_all(); + return Status::OK(); +} + +void ClusterFailover::loop() { + while (true) { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, [this]() { return stop_thread_ || failover_job_triggered_; }); + + if (stop_thread_) return; + + if (failover_job_triggered_) { + failover_job_triggered_ = false; + lock.unlock(); + runFailoverProcess(); + } + } +} + +void ClusterFailover::runFailoverProcess() { + auto ip_port = srv_->cluster->GetNodeIPPort(slave_node_id_); + if (!ip_port.IsOK()) { + error("[Failover] slave node not found in cluster {}", slave_node_id_); + abortFailover("Slave node not found in cluster"); + return; + } + node_ip_port_ = ip_port.GetValue().first + ":" + std::to_string(ip_port.GetValue().second); + node_ip_ = ip_port.GetValue().first; + node_port_ = ip_port.GetValue().second; + info("[Failover] slave node {} {} failover state: {}", slave_node_id_, node_ip_port_, static_cast<int>(state_.load())); + state_ = FailoverState::kCheck; + + auto s = checkSlaveStatus(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + s = checkSlaveLag(); + if (!s.IsOK()) { + abortFailover("Slave lag check failed: " + s.Msg()); + return; + } + + info("[Failover] slave node {} {} check slave status success, enter pause state", slave_node_id_, node_ip_port_); + start_time_ms_ = util::GetTimeStampMS(); + // Enter Pause state (Stop writing) + state_ = FailoverState::kPause; + // Get current sequence + target_seq_ = srv_->storage->LatestSeqNumber(); + info("[Failover] slave node {} {} target sequence {}", slave_node_id_, node_ip_port_, target_seq_); + + state_ = FailoverState::kSyncWait; + s = waitReplicationSync(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + info("[Failover] slave node {} {} wait replication sync success, enter switch state, cost {} ms", slave_node_id_, + node_ip_port_, util::GetTimeStampMS() - start_time_ms_); + + state_ = FailoverState::kSwitch; + s = sendTakeoverCmd(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + // Redirect slots + srv_->cluster->SetMySlotsMigrated(node_ip_port_); + + state_ = FailoverState::kSuccess; + info("[Failover] success {} {}", slave_node_id_, node_ip_port_); +} Review Comment: The member variables (slave_node_id_, node_ip_port_, node_ip_, node_port_, timeout_ms_, target_seq_, start_time_ms_) are accessed by the Run() method (with mutex held) and the runFailoverProcess() method (without mutex). 
This creates a potential race condition where Run() could be called while runFailoverProcess() is still executing and reading these variables. Consider protecting access to these shared state variables with the mutex throughout runFailoverProcess(), or ensure they are only modified when no failover is in progress. ########## src/server/redis_connection.cc: ########## @@ -480,6 +481,25 @@ void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) { continue; } + // Get slot for imported_slots_ check (needed for failover scenario) + int slot = -1; + if (config->cluster_enabled && (cmd_flags & kCmdWrite)) { + std::vector<int> key_indexes; + attributes->ForEachKeyRange( + [&](const std::vector<std::string> &, redis::CommandKeyRange key_range) { + key_range.ForEachKeyIndex([&](int i) { key_indexes.push_back(i); }, cmd_tokens.size()); + }, + cmd_tokens); + if (!key_indexes.empty()) { + for (auto i : key_indexes) { + if (i < static_cast<int>(cmd_tokens.size())) { + slot = GetSlotIdFromKey(cmd_tokens[i]); + break; + } + } + } + } Review Comment: The slot calculation is performed for all write commands, but it is only needed when the node is a slave (for the imported_slots_ check). This adds unnecessary overhead for the common case when the node is a master. Consider moving the slot calculation inside the slave_readonly check block to improve performance. ########## src/cluster/cluster_failover.cc: ########## @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include "cluster_failover.h" + +#include <unistd.h> + +#include "cluster/cluster.h" +#include "common/io_util.h" +#include "common/time_util.h" +#include "logging.h" +#include "server/redis_reply.h" +#include "server/server.h" + +ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) { + t_ = std::thread([this]() { loop(); }); +} + +ClusterFailover::~ClusterFailover() { + { + std::lock_guard<std::mutex> lock(mutex_); + stop_thread_ = true; + cv_.notify_all(); + } + if (t_.joinable()) t_.join(); +} + +Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) { + std::lock_guard<std::mutex> lock(mutex_); + if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) { + return {Status::NotOK, "Failover is already in progress"}; + } + + if (srv_->IsSlave()) { + return {Status::NotOK, "Current node is a slave, can't failover"}; + } + + slave_node_id_ = std::move(slave_node_id); + timeout_ms_ = timeout_ms; + state_ = FailoverState::kStarted; + failover_job_triggered_ = true; + cv_.notify_all(); + return Status::OK(); +} + +void ClusterFailover::loop() { + while (true) { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, [this]() { return stop_thread_ || failover_job_triggered_; }); + + if (stop_thread_) return; + + if (failover_job_triggered_) { + failover_job_triggered_ = false; + lock.unlock(); + runFailoverProcess(); + } + } +} + +void ClusterFailover::runFailoverProcess() { + auto ip_port = srv_->cluster->GetNodeIPPort(slave_node_id_); + if (!ip_port.IsOK()) { + error("[Failover] slave node not found in cluster {}", slave_node_id_); + abortFailover("Slave node not found in cluster"); + return; + } + node_ip_port_ = ip_port.GetValue().first + ":" + std::to_string(ip_port.GetValue().second); + node_ip_ = ip_port.GetValue().first; + node_port_ = ip_port.GetValue().second; + info("[Failover] slave node {} {} failover state: {}", slave_node_id_, node_ip_port_, static_cast<int>(state_.load())); + state_ = FailoverState::kCheck; + + auto s = checkSlaveStatus(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + s = checkSlaveLag(); + if (!s.IsOK()) { + abortFailover("Slave lag check failed: " + s.Msg()); + return; + } + + info("[Failover] slave node {} {} check slave status success, enter pause state", slave_node_id_, node_ip_port_); + start_time_ms_ = util::GetTimeStampMS(); + // Enter Pause state (Stop writing) + state_ = FailoverState::kPause; + // Get current sequence + target_seq_ = srv_->storage->LatestSeqNumber(); + info("[Failover] slave node {} {} target sequence {}", slave_node_id_, node_ip_port_, target_seq_); + + state_ = FailoverState::kSyncWait; + s = waitReplicationSync(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + info("[Failover] slave node {} {} wait replication sync success, enter switch state, cost {} ms", slave_node_id_, + node_ip_port_, util::GetTimeStampMS() - start_time_ms_); + + state_ = FailoverState::kSwitch; + s = sendTakeoverCmd(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + // Redirect slots + srv_->cluster->SetMySlotsMigrated(node_ip_port_); + + state_ = FailoverState::kSuccess; + info("[Failover] success {} {}", slave_node_id_, node_ip_port_); +} + +Status ClusterFailover::checkSlaveLag() { + auto start_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!start_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + start_offset_status.Msg()}; + } + uint64_t start_offset = *start_offset_status; + int64_t start_sampling_ms = 
util::GetTimeStampMS(); + + // Wait 3s or half of timeout, but at least a bit to measure speed + int64_t wait_time = std::max(100, std::min(3000, timeout_ms_ / 2)); + std::this_thread::sleep_for(std::chrono::milliseconds(wait_time)); + + auto end_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!end_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + end_offset_status.Msg()}; + } + uint64_t end_offset = *end_offset_status; + int64_t end_sampling_ms = util::GetTimeStampMS(); + + double elapsed_sec = (end_sampling_ms - start_sampling_ms) / 1000.0; + if (elapsed_sec <= 0) elapsed_sec = 0.001; + + uint64_t bytes = 0; + if (end_offset > start_offset) bytes = end_offset - start_offset; + double speed = bytes / elapsed_sec; + + uint64_t master_seq = srv_->storage->LatestSeqNumber(); + uint64_t lag = 0; + if (master_seq > end_offset) lag = master_seq - end_offset; + + if (lag == 0) return Status::OK(); + + if (speed <= 0.1) { // Basically 0 + return {Status::NotOK, fmt::format("Slave is not replicating (lag: {})", lag)}; + } + + double required_sec = lag / speed; + int64_t required_ms = static_cast<int64_t>(required_sec * 1000); + + int64_t elapsed_total = end_sampling_ms - start_sampling_ms; + int64_t remaining = timeout_ms_ - elapsed_total; + + if (required_ms > remaining) { + return {Status::NotOK, fmt::format("Estimated catchup time {}ms > remaining time {}ms (lag: {}, speed: {:.2f}/s)", + required_ms, remaining, lag, speed)}; + } + + info("[Failover] check: lag={}, speed={:.2f}/s, estimated_time={}ms, remaining={}ms", lag, speed, required_ms, + remaining); + return Status::OK(); +} + +Status ClusterFailover::checkSlaveStatus() { + // We could try to connect, but GetSlaveReplicationOffset checks connection. + auto offset = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!offset.IsOK()) { + error("[Failover] slave node {} {} not connected or not syncing", slave_node_id_, node_ip_port_); + return {Status::NotOK, "Slave not connected or not syncing"}; + } + info("[Failover] slave node {} {} is connected and syncing offset {}", slave_node_id_, node_ip_port_, offset.Msg()); + return Status::OK(); +} + +Status ClusterFailover::waitReplicationSync() { + while (true) { + if (util::GetTimeStampMS() - start_time_ms_ > static_cast<uint64_t>(timeout_ms_)) { + return {Status::NotOK, "Timeout waiting for replication sync"}; + } + + auto offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + offset_status.Msg()}; + } + + if (*offset_status >= target_seq_) { + return Status::OK(); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } +} + +Status ClusterFailover::sendTakeoverCmd() { + auto s = util::SockConnect(node_ip_, node_port_); + if (!s.IsOK()) { + return {Status::NotOK, "Failed to connect to slave: " + s.Msg()}; + } + int fd = *s; + + std::string pass = srv_->GetConfig()->requirepass; + if (!pass.empty()) { + std::string auth_cmd = redis::ArrayOfBulkStrings({"AUTH", pass}); + auto s_auth = util::SockSend(fd, auth_cmd); + if (!s_auth.IsOK()) { + close(fd); + return {Status::NotOK, "Failed to send AUTH: " + s_auth.Msg()}; + } + auto s_line = util::SockReadLine(fd); + if (!s_line.IsOK() || s_line.GetValue().substr(0, 3) != "+OK") { + close(fd); + return {Status::NotOK, "AUTH failed"}; + } Review Comment: Potential resource leak: if SockReadLine fails or returns an unexpected response, the socket fd is closed. 
However, the close(fd) call has to be repeated by hand on every error path, so any future early return, or an exception raised between SockConnect and the final close, would leave the socket open (note that substr(0, 3) itself cannot throw here: with pos 0 it merely clamps short responses). Consider using RAII or ensuring close(fd) is called in all paths, including exception paths. ########## src/cluster/cluster.cc: ########## @@ -976,3 +990,51 @@ Status Cluster::Reset() { unlink(srv_->GetConfig()->NodesFilePath().data()); return Status::OK(); } + +StatusOr<std::pair<std::string, int>> Cluster::GetNodeIPPort(const std::string &node_id) { + auto it = nodes_.find(node_id); + if (it == nodes_.end()) { + return {Status::NotOK, "Node not found"}; + } + return std::make_pair(it->second->host, it->second->port); +} + +Status Cluster::OnTakeOver() { + info("[Failover] OnTakeOver received myself_: {}", myself_ ? myself_->id : "null"); + if (!myself_) { + return {Status::NotOK, "Cluster is not initialized"}; + } + if (myself_->role == kClusterMaster) { + info("[Failover] OnTakeOver myself_ is master, return"); + return Status::OK(); + } + + std::string old_master_id = myself_->master_id; + if (old_master_id.empty()) { + info("[Failover] OnTakeOver no master to takeover, return"); + return {Status::NotOK, "No master to takeover"}; + } + + for (int i = 0; i < kClusterSlots; i++) { + if (slots_nodes_[i] && slots_nodes_[i]->id == old_master_id) { + imported_slots_.insert(i); + } + } + info("[Failover] OnTakeOver Success "); + return Status::OK(); +} Review Comment: The OnTakeOver method modifies imported_slots_ without acquiring any lock or exclusivity guard. This data structure is accessed by IsSlotImported() from request processing threads without synchronization. This creates a potential race condition where imported_slots_ could be modified while being read by other threads, leading to undefined behavior. ########## tests/gocase/integration/failover/failover_test.go: ########## @@ -0,0 +1,972 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +package failover + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +// testNameWrapper wraps testing.TB to sanitize test names for MkdirTemp +// This is needed because subtest names contain "/" which causes MkdirTemp to fail +type testNameWrapper struct { + testing.TB + sanitizedName string +} + +func (w *testNameWrapper) Name() string { + return w.sanitizedName +} + +// sanitizeTestName replaces path separators in test names to avoid issues with MkdirTemp +func sanitizeTestName(tb testing.TB) testing.TB { + sanitizedName := strings.ReplaceAll(tb.Name(), "/", "_") + return &testNameWrapper{TB: tb, sanitizedName: sanitizedName} +} + +// startServerWithSanitizedName starts a server with a sanitized test name +func startServerWithSanitizedName(t testing.TB, configs map[string]string) *util.KvrocksServer { + return util.StartServer(sanitizeTestName(t), configs) +} + +type FailoverState string + +const ( + FailoverStateNone FailoverState = "none" + FailoverStateStarted FailoverState = "started" + FailoverStateCheckSlave FailoverState = "check_slave" + FailoverStatePauseWrite FailoverState = "pause_write" + FailoverStateWaitSync FailoverState = "wait_sync" + FailoverStateSwitching FailoverState = "switching" + FailoverStateSuccess FailoverState = "success" + FailoverStateFailed FailoverState = "failed" +) + +// TestFailoverBasicFlow tests the basic failover process and custom timeout parameter. +// Test Case 1.1: Basic Failover Flow - Master successfully transfers control to Slave +// Test Case 1.2: Failover with Custom Timeout - Using custom timeout parameter +func TestFailoverBasicFlow(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + // Wait for replication to establish + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 1.1: Basic Failover Flow + t.Run("FAILOVER - Basic failover flow", func(t *testing.T) { + // Write some data + require.NoError(t, masterClient.Set(ctx, "key1", "value1", 0).Err()) + require.NoError(t, masterClient.Set(ctx, "key2", "value2", 0).Err()) + + // Start failover + result := masterClient.Do(ctx, "clusterx", "failover", slaveID) + if result.Err() != nil { + t.Logf("FAILOVER command error: %v", 
result.Err()) + } + require.NoError(t, result.Err(), "FAILOVER command should succeed") + require.Equal(t, "OK", result.Val()) + + // Wait for failover to complete + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + + // Verify slots are migrated (MOVED response) + require.ErrorContains(t, masterClient.Set(ctx, "key1", "newvalue", 0).Err(), "MOVED") + require.ErrorContains(t, masterClient.Get(ctx, "key1").Err(), "MOVED") + + // Verify data is accessible on new master (slave) + require.Equal(t, "value1", slaveClient.Get(ctx, "key1").Val()) + require.Equal(t, "value2", slaveClient.Get(ctx, "key2").Val()) + }) + + // Test Case 1.2: Failover with Custom Timeout + t.Run("FAILOVER - Failover with custom timeout", func(t *testing.T) { + // Reset failover state by updating topology + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Start failover with custom timeout + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "5000").Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + }) +} + +// TestFailoverFailureCases tests various failure scenarios and timeout values. +// Test Case 2.1: Slave Node Not Found - Specified slave_node_id is not in cluster +// Test Case 2.2: Slave Not Connected - Slave node exists but no replication connection +// Test Case 3.5: Different Timeout Values - Testing various timeout values (0, 100, 10000) +func TestFailoverFailureCases(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + // Test Case 2.1: Slave Node Not Found + t.Run("FAILOVER - Failover to non-existent node", func(t *testing.T) { + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383", masterID, master.Host(), master.Port()) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + nonExistentID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx99" + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", nonExistentID).Val()) + waitForFailoverState(t, masterClient, FailoverStateFailed, 5*time.Second) + }) + + // Test Case 2.2: Slave Not Connected (node exists as master, not slave) + t.Run("FAILOVER - Failover to non-slave node", func(t *testing.T) { + slave := startServerWithSanitizedName(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + // Set slave as master (not slave) + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d master -", slaveID, slave.Host(), slave.Port()) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, 
"2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID).Val()) + waitForFailoverState(t, masterClient, FailoverStateFailed, 5*time.Second) + }) + + // Test Case 3.5: Invalid timeout value (negative) + t.Run("FAILOVER - Invalid timeout value", func(t *testing.T) { + slave := startServerWithSanitizedName(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "3").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "3").Err()) + + // Negative timeout should return error + require.Error(t, masterClient.Do(ctx, "clusterx", "failover", slaveID, "-1").Err()) + }) + + // Test Case 3.5: Different Timeout Values (0, 100, 10000) + t.Run("FAILOVER - Different timeout values", func(t *testing.T) { + slave := startServerWithSanitizedName(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "4").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "4").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test with timeout = 0 when slave is already synced (lag=0) + // When lag=0, failover should succeed because no waiting is needed + // But if slave has lag, it will fail + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "0").Val()) + // Wait for either success or failed state + require.Eventually(t, func() bool { + info := masterClient.ClusterInfo(ctx).Val() + return strings.Contains(info, "cluster_failover_state:success") || + strings.Contains(info, "cluster_failover_state:failed") + }, 5*time.Second, 100*time.Millisecond) + + // Check final state - can be success (if lag=0) or failed (if lag>0) + finalInfo := masterClient.ClusterInfo(ctx).Val() + if strings.Contains(finalInfo, "cluster_failover_state:success") { + t.Logf("timeout=0 with lag=0: failover succeeded as expected") + } else if strings.Contains(finalInfo, "cluster_failover_state:failed") { + t.Logf("timeout=0: failover failed (slave may have lag)") + } + + // Reset for next test + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "5").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "5").Err()) + require.Eventually(t, func() bool { + info := 
masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test with timeout = 0 when slave has lag + // Create lag by writing data to master without waiting for sync + for i := 0; i < 100; i++ { + require.NoError(t, masterClient.Set(ctx, fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 0).Err()) + } + // Don't wait for sync, start failover immediately to create lag scenario + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "0").Val()) + // With lag, timeout=0 should fail + waitForFailoverState(t, masterClient, FailoverStateFailed, 5*time.Second) + + // Reset and test with small timeout + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "6").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "6").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "100").Val()) + // Small timeout may fail, but should start + time.Sleep(200 * time.Millisecond) + info := masterClient.ClusterInfo(ctx).Val() + require.True(t, strings.Contains(info, "cluster_failover_state:failed") || + strings.Contains(info, "cluster_failover_state:success") || + strings.Contains(info, "cluster_failover_state:wait_sync") || + strings.Contains(info, "cluster_failover_state:switching")) + + // Reset and test with large timeout + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "7").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "7").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "10000").Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 15*time.Second) + }) +} + +// TestFailoverConcurrency tests concurrent failover scenarios. 
+// Test Case 3.1: Duplicate Failover - Cannot start failover when one is in progress +// Test Case 3.2: Restart After Failure - Can restart failover after previous failure +func TestFailoverConcurrency(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 3.1: Duplicate Failover + t.Run("FAILOVER - Cannot start failover when one is in progress", func(t *testing.T) { + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID).Val()) + + // Try to start another failover immediately - should return error + // Wait a bit to ensure first failover has started + time.Sleep(100 * time.Millisecond) + result := masterClient.Do(ctx, "clusterx", "failover", slaveID) + // second failover may return an error indicating a failover is already in progress. 
+ _, err := result.Result() + if err != nil { + require.Contains(t, err.Error(), "Failover is already in progress") + } else { + // should not reach here + require.Fail(t, "second failover should return error") + } + + // We verify the first one completes successfully + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + }) + + // Test Case 3.2: Restart After Failure + t.Run("FAILOVER - Can restart after failure", func(t *testing.T) { + // Reset state + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Start a failover with very short timeout + // If slave is synced (lag=0), it may succeed; if slave has lag, it will fail + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "1").Val()) + // Accept either success or failed state + require.Eventually(t, func() bool { + info := masterClient.ClusterInfo(ctx).Val() + return strings.Contains(info, "cluster_failover_state:success") || + strings.Contains(info, "cluster_failover_state:failed") + }, 5*time.Second, 100*time.Millisecond) + + // Can restart after failure + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "3").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "3").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "10000").Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 15*time.Second) + }) +} + +// TestFailoverWriteBlocking tests write and read request behavior during failover. 
+// Test Case 3.3: Write Requests During Failover - Write requests return TRYAGAIN in blocking states +// Test Case 3.4: Read Requests During Failover - Read requests are not blocked +func TestFailoverWriteBlocking(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 3.3: Write Requests During Failover + t.Run("FAILOVER - Write requests blocked during failover", func(t *testing.T) { + // Write initial data + require.NoError(t, masterClient.Set(ctx, "testkey", "testvalue", 0).Err()) + + // Start failover with long timeout to observe blocking + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "10000").Val()) + + // Try to write during failover - should return TRYAGAIN in blocking states + // Poll for blocking state (pause_write, wait_sync, or switching) + for i := 0; i < 50; i++ { + time.Sleep(50 * time.Millisecond) + err := masterClient.Set(ctx, "testkey", "newvalue", 0).Err() + if err != nil && (strings.Contains(err.Error(), "TRYAGAIN") || strings.Contains(err.Error(), "Failover in progress")) { + break + } + // Check if failover already completed + info := masterClient.ClusterInfo(ctx).Val() + if strings.Contains(info, "cluster_failover_state:success") { + break + } + } + // At least one write should have been blocked, or failover completed very quickly + waitForFailoverState(t, masterClient, FailoverStateSuccess, 15*time.Second) + + // After success, writes should return MOVED + require.ErrorContains(t, masterClient.Set(ctx, "testkey", "newvalue2", 0).Err(), "MOVED") + }) + + // Test Case 3.4: Read Requests During Failover + t.Run("FAILOVER - Read requests not blocked during failover", func(t *testing.T) { + // Reset state + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + require.NoError(t, masterClient.Set(ctx, "readkey", "readvalue", 0).Err()) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID, "10000").Val()) + + // Reads should work during 
failover (not blocked) + // Try reading multiple times during failover + for i := 0; i < 10; i++ { + time.Sleep(100 * time.Millisecond) + val := masterClient.Get(ctx, "readkey").Val() + require.Equal(t, "readvalue", val) + // Check if failover completed + info := masterClient.ClusterInfo(ctx).Val() + if strings.Contains(info, "cluster_failover_state:success") { + break + } + } + waitForFailoverState(t, masterClient, FailoverStateSuccess, 15*time.Second) + }) +} + +// TestFailoverWithAuth tests failover with password authentication. +// Test Case 1.4: Failover with Password Authentication - Cluster configured with requirepass +func TestFailoverWithAuth(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{ + "cluster-enabled": "yes", + "requirepass": "password123", + }) + defer func() { master.Close() }() + masterClient := master.NewClient() + masterClient = redis.NewClient(&redis.Options{ + Addr: master.HostPort(), + Password: "password123", + }) + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := startServerWithSanitizedName(t, map[string]string{ + "cluster-enabled": "yes", + "requirepass": "password123", + "masterauth": "password123", + }) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + slaveClient = redis.NewClient(&redis.Options{ + Addr: slave.HostPort(), + Password: "password123", + }) + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 1.4: Failover with Password Authentication + t.Run("FAILOVER - Failover with authentication", func(t *testing.T) { + require.NoError(t, masterClient.Set(ctx, "authkey", "authvalue", 0).Err()) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID).Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + + // Verify data on new master + require.Equal(t, "authvalue", slaveClient.Get(ctx, "authkey").Val()) + }) +} + +// TestFailoverStateQuery tests querying failover state information. 
+// Test Case 4.1: CLUSTER INFO State Output - Query failover state and verify all state transitions +func TestFailoverStateQuery(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 4.1: CLUSTER INFO State Output + t.Run("FAILOVER - Query failover state via CLUSTER INFO", func(t *testing.T) { + // Initial state should be none + info := masterClient.ClusterInfo(ctx).Val() + require.Contains(t, info, "cluster_failover_state:none") + + // Start failover + result := masterClient.Do(ctx, "clusterx", "failover", slaveID) + if result.Err() != nil { + t.Logf("FAILOVER command error: %v", result.Err()) + } + require.NoError(t, result.Err(), "FAILOVER command should succeed") + require.Equal(t, "OK", result.Val()) + + // Wait for success + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + + // Verify state is success + info = masterClient.ClusterInfo(ctx).Val() + require.Contains(t, info, "cluster_failover_state:success") + }) + + // Test Case 4.1: All State Transitions + t.Run("FAILOVER - All state transitions", func(t *testing.T) { + // Reset state + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Initial state: none + info := masterClient.ClusterInfo(ctx).Val() + require.Contains(t, info, "cluster_failover_state:none") + + // Start failover + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID).Val()) + + // We may catch intermediate states, but they're very fast + // The important thing is we transition through them and end at success + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + + // Verify final state + info = masterClient.ClusterInfo(ctx).Val() + require.Contains(t, info, "cluster_failover_state:success") + }) +} + +// TestFailoverTakeoverCommand tests the TAKEOVER command handling on slave. 
+// Test Case 5.2: TAKEOVER Command Processing - Slave receives and processes TAKEOVER command +func TestFailoverTakeoverCommand(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 5.2: TAKEOVER Command Processing + t.Run("FAILOVER - TAKEOVER command on slave", func(t *testing.T) { + // Slave should accept TAKEOVER command + require.Equal(t, "OK", slaveClient.Do(ctx, "clusterx", "takeover").Val()) + + // Verify imported slots are set + // After takeover, slave should be able to serve the slots + require.NoError(t, masterClient.Set(ctx, "takeoverkey", "takeovervalue", 0).Err()) + + // Start failover to test the full flow + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "2").Err()) + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + require.Equal(t, "OK", masterClient.Do(ctx, "clusterx", "failover", slaveID).Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 10*time.Second) + }) +} + +// TestFailoverDataConsistency tests data consistency after failover. 
+// Test Case 5.4: Data Consistency Verification - All data is replicated to new master without loss +func TestFailoverDataConsistency(t *testing.T) { + ctx := context.Background() + + master := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { master.Close() }() + masterClient := master.NewClient() + defer func() { require.NoError(t, masterClient.Close()) }() + masterID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx00" + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODEID", masterID).Err()) + + slave := util.StartServer(t, map[string]string{"cluster-enabled": "yes"}) + defer func() { slave.Close() }() + slaveClient := slave.NewClient() + defer func() { require.NoError(t, slaveClient.Close()) }() + slaveID := "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx01" + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODEID", slaveID).Err()) + + clusterNodes := fmt.Sprintf("%s %s %d master - 0-16383\n", masterID, master.Host(), master.Port()) + clusterNodes += fmt.Sprintf("%s %s %d slave %s", slaveID, slave.Host(), slave.Port(), masterID) + require.NoError(t, masterClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + require.NoError(t, slaveClient.Do(ctx, "clusterx", "SETNODES", clusterNodes, "1").Err()) + + require.Eventually(t, func() bool { + info := masterClient.Info(ctx, "replication").Val() + return strings.Contains(info, "connected_slaves:1") + }, 10*time.Second, 100*time.Millisecond) + + // Test Case 5.4: Data Consistency Verification + t.Run("FAILOVER - Data consistency after failover", func(t *testing.T) { + // Write various types of data + require.NoError(t, masterClient.Set(ctx, "string_key", "string_value", 0).Err()) + require.NoError(t, masterClient.LPush(ctx, "list_key", "item1", "item2", "item3").Err()) + require.NoError(t, masterClient.HSet(ctx, "hash_key", "field1", "value1", "field2", "value2").Err()) + require.NoError(t, masterClient.SAdd(ctx, "set_key", "member1", "member2").Err()) + require.NoError(t, masterClient.ZAdd(ctx, "zset_key", redis.Z{Score: 1.0, Member: "member1"}).Err()) + + // Start failover with longer timeout to ensure data sync + result := masterClient.Do(ctx, "clusterx", "failover", slaveID, "10000") + if result.Err() != nil { + t.Logf("FAILOVER command error: %v", result.Err()) + } + require.NoError(t, result.Err(), "FAILOVER command should succeed") + require.Equal(t, "OK", result.Val()) + waitForFailoverState(t, masterClient, FailoverStateSuccess, 15*time.Second) + + // Verify all data is on new master + require.Equal(t, "string_value", slaveClient.Get(ctx, "string_key").Val()) + require.EqualValues(t, []string{"item3", "item2", "item1"}, slaveClient.LRange(ctx, "list_key", 0, -1).Val()) + require.Equal(t, map[string]string{"field1": "value1", "field2": "value2"}, slaveClient.HGetAll(ctx, "hash_key").Val()) + require.EqualValues(t, []string{"member1", "member2"}, slaveClient.SMembers(ctx, "set_key").Val()) Review Comment: The SMembers result order is not deterministic in Redis - sets are unordered collections. The assertion comparing the result to a specific ordered slice may fail intermittently if Redis returns the members in a different order. Consider using a comparison that doesn't depend on order, such as converting both to sets or sorting before comparison. 
```suggestion require.ElementsMatch(t, []string{"member1", "member2"}, slaveClient.SMembers(ctx, "set_key").Val()) ``` ########## src/cluster/cluster_failover.cc: ########## @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "cluster_failover.h" + +#include <unistd.h> + +#include "cluster/cluster.h" +#include "common/io_util.h" +#include "common/time_util.h" +#include "logging.h" +#include "server/redis_reply.h" +#include "server/server.h" + +ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) { + t_ = std::thread([this]() { loop(); }); +} + +ClusterFailover::~ClusterFailover() { + { + std::lock_guard<std::mutex> lock(mutex_); + stop_thread_ = true; + cv_.notify_all(); + } + if (t_.joinable()) t_.join(); +} + +Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) { + std::lock_guard<std::mutex> lock(mutex_); + if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) { + return {Status::NotOK, "Failover is already in progress"}; + } + + if (srv_->IsSlave()) { + return {Status::NotOK, "Current node is a slave, can't failover"}; + } + + slave_node_id_ = std::move(slave_node_id); + timeout_ms_ = timeout_ms; + state_ = FailoverState::kStarted; + failover_job_triggered_ = true; + cv_.notify_all(); + return Status::OK(); +} + +void ClusterFailover::loop() { + while (true) { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, [this]() { return stop_thread_ || failover_job_triggered_; }); + + if (stop_thread_) return; + + if (failover_job_triggered_) { + failover_job_triggered_ = false; + lock.unlock(); + runFailoverProcess(); + } + } +} + +void ClusterFailover::runFailoverProcess() { + auto ip_port = srv_->cluster->GetNodeIPPort(slave_node_id_); + if (!ip_port.IsOK()) { + error("[Failover] slave node not found in cluster {}", slave_node_id_); + abortFailover("Slave node not found in cluster"); + return; + } + node_ip_port_ = ip_port.GetValue().first + ":" + std::to_string(ip_port.GetValue().second); + node_ip_ = ip_port.GetValue().first; + node_port_ = ip_port.GetValue().second; + info("[Failover] slave node {} {} failover state: {}", slave_node_id_, node_ip_port_, static_cast<int>(state_.load())); + state_ = FailoverState::kCheck; + + auto s = checkSlaveStatus(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + s = checkSlaveLag(); + if (!s.IsOK()) { + abortFailover("Slave lag check failed: " + s.Msg()); + return; + } + + info("[Failover] slave node {} {} check slave status success, enter pause state", slave_node_id_, node_ip_port_); + start_time_ms_ = util::GetTimeStampMS(); + // Enter Pause state (Stop writing) + state_ = FailoverState::kPause; + // Get current sequence + target_seq_ = srv_->storage->LatestSeqNumber(); + info("[Failover] slave 
node {} {} target sequence {}", slave_node_id_, node_ip_port_, target_seq_); + + state_ = FailoverState::kSyncWait; + s = waitReplicationSync(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + info("[Failover] slave node {} {} wait replication sync success, enter switch state, cost {} ms", slave_node_id_, + node_ip_port_, util::GetTimeStampMS() - start_time_ms_); + + state_ = FailoverState::kSwitch; + s = sendTakeoverCmd(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + // Redirect slots + srv_->cluster->SetMySlotsMigrated(node_ip_port_); + + state_ = FailoverState::kSuccess; + info("[Failover] success {} {}", slave_node_id_, node_ip_port_); +} + +Status ClusterFailover::checkSlaveLag() { + auto start_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!start_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + start_offset_status.Msg()}; + } + uint64_t start_offset = *start_offset_status; + int64_t start_sampling_ms = util::GetTimeStampMS(); + + // Wait 3s or half of timeout, but at least a bit to measure speed + int64_t wait_time = std::max(100, std::min(3000, timeout_ms_ / 2)); + std::this_thread::sleep_for(std::chrono::milliseconds(wait_time)); + + auto end_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!end_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + end_offset_status.Msg()}; + } + uint64_t end_offset = *end_offset_status; + int64_t end_sampling_ms = util::GetTimeStampMS(); + + double elapsed_sec = (end_sampling_ms - start_sampling_ms) / 1000.0; + if (elapsed_sec <= 0) elapsed_sec = 0.001; + + uint64_t bytes = 0; + if (end_offset > start_offset) bytes = end_offset - start_offset; + double speed = bytes / elapsed_sec; + + uint64_t master_seq = srv_->storage->LatestSeqNumber(); + uint64_t lag = 0; + if (master_seq > end_offset) lag = master_seq - end_offset; + + if (lag == 0) return Status::OK(); + + if (speed <= 0.1) { // Basically 0 + return {Status::NotOK, fmt::format("Slave is not replicating (lag: {})", lag)}; + } + + double required_sec = lag / speed; + int64_t required_ms = static_cast<int64_t>(required_sec * 1000); + + int64_t elapsed_total = end_sampling_ms - start_sampling_ms; + int64_t remaining = timeout_ms_ - elapsed_total; + + if (required_ms > remaining) { + return {Status::NotOK, fmt::format("Estimated catchup time {}ms > remaining time {}ms (lag: {}, speed: {:.2f}/s)", + required_ms, remaining, lag, speed)}; + } + + info("[Failover] check: lag={}, speed={:.2f}/s, estimated_time={}ms, remaining={}ms", lag, speed, required_ms, + remaining); + return Status::OK(); +} + +Status ClusterFailover::checkSlaveStatus() { + // We could try to connect, but GetSlaveReplicationOffset checks connection. 
+ auto offset = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!offset.IsOK()) { + error("[Failover] slave node {} {} not connected or not syncing", slave_node_id_, node_ip_port_); + return {Status::NotOK, "Slave not connected or not syncing"}; + } + info("[Failover] slave node {} {} is connected and syncing offset {}", slave_node_id_, node_ip_port_, offset.Msg()); + return Status::OK(); +} + +Status ClusterFailover::waitReplicationSync() { + while (true) { + if (util::GetTimeStampMS() - start_time_ms_ > static_cast<uint64_t>(timeout_ms_)) { + return {Status::NotOK, "Timeout waiting for replication sync"}; + } + + auto offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + offset_status.Msg()}; + } + + if (*offset_status >= target_seq_) { + return Status::OK(); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } +} + +Status ClusterFailover::sendTakeoverCmd() { + auto s = util::SockConnect(node_ip_, node_port_); + if (!s.IsOK()) { + return {Status::NotOK, "Failed to connect to slave: " + s.Msg()}; + } + int fd = *s; + + std::string pass = srv_->GetConfig()->requirepass; + if (!pass.empty()) { + std::string auth_cmd = redis::ArrayOfBulkStrings({"AUTH", pass}); + auto s_auth = util::SockSend(fd, auth_cmd); + if (!s_auth.IsOK()) { + close(fd); + return {Status::NotOK, "Failed to send AUTH: " + s_auth.Msg()}; + } + auto s_line = util::SockReadLine(fd); + if (!s_line.IsOK() || s_line.GetValue().substr(0, 3) != "+OK") { + close(fd); + return {Status::NotOK, "AUTH failed"}; + } + } + + std::string cmd = redis::ArrayOfBulkStrings({"CLUSTERX", "TAKEOVER"}); + auto s_send = util::SockSend(fd, cmd); + if (!s_send.IsOK()) { + close(fd); + return {Status::NotOK, "Failed to send TAKEOVER: " + s_send.Msg()}; + } + + auto s_resp = util::SockReadLine(fd); + close(fd); + + if (!s_resp.IsOK()) { + return {Status::NotOK, "Failed to read TAKEOVER response: " + s_resp.Msg()}; + } + + if (s_resp.GetValue().substr(0, 3) != "+OK") { + return {Status::NotOK, "TAKEOVER failed: " + s_resp.GetValue()}; + } Review Comment: Resource-safety concern: the raw socket fd has to be closed manually on every early-return path (AUTH send, AUTH read, TAKEOVER send), which is easy to break as this function evolves. Note that `substr(0, 3)` itself will not throw for a response shorter than 3 characters (std::string::substr only throws when the start position exceeds the string size), and close(fd) already runs before the final prefix check, so there is no leak today; the fragility is the pattern. Consider wrapping the descriptor in an RAII guard so it is released on all paths, including any future exception paths.
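A minimal sketch of such a guard, assuming only `<unistd.h>`'s `close()`; the `FdGuard` name is illustrative and not an existing kvrocks type:

```cpp
// Hypothetical RAII wrapper for a raw socket descriptor: the fd is closed
// exactly once, on every return or exception path, by the destructor.
#include <unistd.h>

#include <utility>

class FdGuard {
 public:
  explicit FdGuard(int fd) : fd_(fd) {}
  FdGuard(const FdGuard &) = delete;
  FdGuard &operator=(const FdGuard &) = delete;
  FdGuard(FdGuard &&other) noexcept : fd_(std::exchange(other.fd_, -1)) {}
  ~FdGuard() {
    if (fd_ >= 0) close(fd_);  // no-op if the fd was moved away
  }
  int Get() const { return fd_; }

 private:
  int fd_ = -1;
};

// Sketch of usage inside sendTakeoverCmd(), dropping every manual close(fd):
//   FdGuard guard(*s);
//   int fd = guard.Get();
//   ...each early return becomes `return {Status::NotOK, ...};` with no cleanup code.
```

With a guard like this, the manual `close(fd)` calls collapse into the destructor, and any code path added to the function later cannot leak the socket.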
########## src/commands/cmd_cluster.cc: ########## @@ -237,7 +238,29 @@ class CommandClusterX : public Commander { return Status::OK(); } - return {Status::RedisParseErr, "CLUSTERX command, CLUSTERX VERSION|MYID|SETNODEID|SETNODES|SETSLOT|MIGRATE"}; + if (subcommand_ == "failover") { + if (args.size() != 3 && args.size() != 4) return {Status::RedisParseErr, errWrongNumOfArguments}; + + slave_node_id_ = args_[2]; + + if (args.size() == 4) { + auto parse_result = ParseInt<int>(args_[3], 10); + if (!parse_result) return {Status::RedisParseErr, "Invalid timeout"}; + if (*parse_result < 0) return {Status::RedisParseErr, errTimeoutIsNegative}; + failover_timeout_ = *parse_result; + } else { + failover_timeout_ = 1000; + } + return Status::OK(); + } + + if (subcommand_ == "takeover") { + if (args.size() != 2) return {Status::RedisParseErr, errWrongNumOfArguments}; + return Status::OK(); + } + + return {Status::RedisParseErr, + "CLUSTERX command, CLUSTERX VERSION|MYID|SETNODEID|SETNODES|SETSLOT|MIGRATE|FAILOVER"}; Review Comment: The help message in the error does not include "TAKEOVER" which is a valid subcommand added in this PR. The message should be updated to include TAKEOVER in the list of valid subcommands. ```suggestion "CLUSTERX command, CLUSTERX VERSION|MYID|SETNODEID|SETNODES|SETSLOT|MIGRATE|FAILOVER|TAKEOVER"}; ``` ########## src/cluster/cluster_failover.cc: ########## @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include "cluster_failover.h" + +#include <unistd.h> + +#include "cluster/cluster.h" +#include "common/io_util.h" +#include "common/time_util.h" +#include "logging.h" +#include "server/redis_reply.h" +#include "server/server.h" + +ClusterFailover::ClusterFailover(Server *srv) : srv_(srv) { + t_ = std::thread([this]() { loop(); }); +} + +ClusterFailover::~ClusterFailover() { + { + std::lock_guard<std::mutex> lock(mutex_); + stop_thread_ = true; + cv_.notify_all(); + } + if (t_.joinable()) t_.join(); +} + +Status ClusterFailover::Run(std::string slave_node_id, int timeout_ms) { + std::lock_guard<std::mutex> lock(mutex_); + if (state_ != FailoverState::kNone && state_ != FailoverState::kFailed) { + return {Status::NotOK, "Failover is already in progress"}; + } + + if (srv_->IsSlave()) { + return {Status::NotOK, "Current node is a slave, can't failover"}; + } + + slave_node_id_ = std::move(slave_node_id); + timeout_ms_ = timeout_ms; + state_ = FailoverState::kStarted; + failover_job_triggered_ = true; + cv_.notify_all(); + return Status::OK(); +} + +void ClusterFailover::loop() { + while (true) { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, [this]() { return stop_thread_ || failover_job_triggered_; }); + + if (stop_thread_) return; + + if (failover_job_triggered_) { + failover_job_triggered_ = false; + lock.unlock(); + runFailoverProcess(); + } + } +} + +void ClusterFailover::runFailoverProcess() { + auto ip_port = srv_->cluster->GetNodeIPPort(slave_node_id_); + if (!ip_port.IsOK()) { + error("[Failover] slave node not found in cluster {}", slave_node_id_); + abortFailover("Slave node not found in cluster"); + return; + } + node_ip_port_ = ip_port.GetValue().first + ":" + std::to_string(ip_port.GetValue().second); + node_ip_ = ip_port.GetValue().first; + node_port_ = ip_port.GetValue().second; + info("[Failover] slave node {} {} failover state: {}", slave_node_id_, node_ip_port_, static_cast<int>(state_.load())); + state_ = FailoverState::kCheck; + + auto s = checkSlaveStatus(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + s = checkSlaveLag(); + if (!s.IsOK()) { + abortFailover("Slave lag check failed: " + s.Msg()); + return; + } + + info("[Failover] slave node {} {} check slave status success, enter pause state", slave_node_id_, node_ip_port_); + start_time_ms_ = util::GetTimeStampMS(); + // Enter Pause state (Stop writing) + state_ = FailoverState::kPause; + // Get current sequence + target_seq_ = srv_->storage->LatestSeqNumber(); + info("[Failover] slave node {} {} target sequence {}", slave_node_id_, node_ip_port_, target_seq_); + + state_ = FailoverState::kSyncWait; + s = waitReplicationSync(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + info("[Failover] slave node {} {} wait replication sync success, enter switch state, cost {} ms", slave_node_id_, + node_ip_port_, util::GetTimeStampMS() - start_time_ms_); + + state_ = FailoverState::kSwitch; + s = sendTakeoverCmd(); + if (!s.IsOK()) { + abortFailover(s.Msg()); + return; + } + + // Redirect slots + srv_->cluster->SetMySlotsMigrated(node_ip_port_); + + state_ = FailoverState::kSuccess; + info("[Failover] success {} {}", slave_node_id_, node_ip_port_); +} + +Status ClusterFailover::checkSlaveLag() { + auto start_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!start_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + start_offset_status.Msg()}; + } + uint64_t start_offset = *start_offset_status; + int64_t start_sampling_ms = 
util::GetTimeStampMS(); + + // Wait 3s or half of timeout, but at least a bit to measure speed + int64_t wait_time = std::max(100, std::min(3000, timeout_ms_ / 2)); + std::this_thread::sleep_for(std::chrono::milliseconds(wait_time)); + + auto end_offset_status = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!end_offset_status.IsOK()) { + return {Status::NotOK, "Failed to get slave offset: " + end_offset_status.Msg()}; + } + uint64_t end_offset = *end_offset_status; + int64_t end_sampling_ms = util::GetTimeStampMS(); + + double elapsed_sec = (end_sampling_ms - start_sampling_ms) / 1000.0; + if (elapsed_sec <= 0) elapsed_sec = 0.001; + + uint64_t bytes = 0; + if (end_offset > start_offset) bytes = end_offset - start_offset; + double speed = bytes / elapsed_sec; + + uint64_t master_seq = srv_->storage->LatestSeqNumber(); + uint64_t lag = 0; + if (master_seq > end_offset) lag = master_seq - end_offset; + + if (lag == 0) return Status::OK(); + + if (speed <= 0.1) { // Basically 0 + return {Status::NotOK, fmt::format("Slave is not replicating (lag: {})", lag)}; + } + + double required_sec = lag / speed; + int64_t required_ms = static_cast<int64_t>(required_sec * 1000); + + int64_t elapsed_total = end_sampling_ms - start_sampling_ms; + int64_t remaining = timeout_ms_ - elapsed_total; + + if (required_ms > remaining) { + return {Status::NotOK, fmt::format("Estimated catchup time {}ms > remaining time {}ms (lag: {}, speed: {:.2f}/s)", + required_ms, remaining, lag, speed)}; + } + + info("[Failover] check: lag={}, speed={:.2f}/s, estimated_time={}ms, remaining={}ms", lag, speed, required_ms, + remaining); + return Status::OK(); +} + +Status ClusterFailover::checkSlaveStatus() { + // We could try to connect, but GetSlaveReplicationOffset checks connection. + auto offset = srv_->GetSlaveReplicationOffset(node_ip_port_); + if (!offset.IsOK()) { + error("[Failover] slave node {} {} not connected or not syncing", slave_node_id_, node_ip_port_); + return {Status::NotOK, "Slave not connected or not syncing"}; + } + info("[Failover] slave node {} {} is connected and syncing offset {}", slave_node_id_, node_ip_port_, offset.Msg()); Review Comment: The log message uses offset.Msg() to log the slave's replication offset, but offset is a StatusOr containing a sequence number, not a Status object. The Msg() method is intended for error messages, not for logging numeric values. This should use the actual value, e.g., *offset or offset.GetValue(). ```suggestion info("[Failover] slave node {} {} is connected and syncing offset {}", slave_node_id_, node_ip_port_, *offset); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
