This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks-controller.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 83355d5  Add support of changing raft peers via HTTP api (#232)
83355d5 is described below

commit 83355d54727509e13952e67645f476768248f266
Author: hulk <[email protected]>
AuthorDate: Tue Dec 17 23:48:10 2024 +0800

    Add support of changing raft peers via HTTP api (#232)
---
 consts/context_key.go            |   2 +-
 controller/controller.go         |   1 +
 server/api/handler.go            |   6 ++-
 server/api/raft.go               | 110 +++++++++++++++++++++++++++++++++++++++
 server/middleware/middleware.go  |  19 ++++++-
 server/route.go                  |   7 +++
 store/engine/engine.go           |   1 +
 store/engine/raft/config.go      |  19 +++++--
 store/engine/raft/config_test.go |   7 +++
 store/engine/raft/node.go        |  46 ++++++++++++----
 store/engine/raft/node_test.go   |  21 ++++++--
 store/engine/raft/store.go       |  18 ++++---
 store/engine/raft/store_test.go  |   6 +--
 store/store.go                   |   4 ++
 14 files changed, 236 insertions(+), 31 deletions(-)

diff --git a/consts/context_key.go b/consts/context_key.go
index d0d2080..cfb0e3e 100644
--- a/consts/context_key.go
+++ b/consts/context_key.go
@@ -23,7 +23,7 @@ const (
        ContextKeyStore        = "_context_key_storage"
        ContextKeyCluster      = "_context_key_cluster"
        ContextKeyClusterShard = "_context_key_cluster_shard"
-       ContextKeyHost         = "_context_key_host"
+       ContextKeyRaftNode     = "_context_key_raft_node"
 )
 
 const (
diff --git a/controller/controller.go b/controller/controller.go
index cd302d7..30bcdfb 100644
--- a/controller/controller.go
+++ b/controller/controller.go
@@ -104,6 +104,7 @@ func (c *Controller) resume(ctx context.Context) error {
                }
                for _, cluster := range clusters {
                        c.addCluster(ns, cluster)
+                       logger.Get().Debug("Resume the cluster", zap.String("namespace", ns), zap.String("cluster", cluster))
                }
        }
        return nil
diff --git a/server/api/handler.go b/server/api/handler.go
index 262b4a2..d80b3ea 100644
--- a/server/api/handler.go
+++ b/server/api/handler.go
@@ -20,13 +20,16 @@
 
 package api
 
-import "github.com/apache/kvrocks-controller/store"
+import (
+       "github.com/apache/kvrocks-controller/store"
+)
 
 type Handler struct {
        Namespace *NamespaceHandler
        Cluster   *ClusterHandler
        Shard     *ShardHandler
        Node      *NodeHandler
+       Raft      *RaftHandler
 }
 
 func NewHandler(s *store.ClusterStore) *Handler {
@@ -35,5 +38,6 @@ func NewHandler(s *store.ClusterStore) *Handler {
                Cluster:   &ClusterHandler{s: s},
                Shard:     &ShardHandler{s: s},
                Node:      &NodeHandler{s: s},
+               Raft:      &RaftHandler{},
        }
 }
diff --git a/server/api/raft.go b/server/api/raft.go
new file mode 100644
index 0000000..fde1d58
--- /dev/null
+++ b/server/api/raft.go
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+package api
+
+import (
+       "errors"
+       "fmt"
+       "strings"
+
+       "github.com/apache/kvrocks-controller/consts"
+       "github.com/apache/kvrocks-controller/logger"
+       "github.com/apache/kvrocks-controller/server/helper"
+       "github.com/apache/kvrocks-controller/store/engine/raft"
+
+       "github.com/gin-gonic/gin"
+       "go.uber.org/zap"
+)
+
+const (
+       OperationAdd    = "add"
+       OperationRemove = "remove"
+)
+
+type RaftHandler struct{}
+
+type MemberRequest struct {
+       ID        uint64 `json:"id" validate:"required,gt=0"`
+       Operation string `json:"operation" validate:"required"`
+       Peer      string `json:"peer"`
+}
+
+func (r *MemberRequest) validate() error {
+       r.Operation = strings.ToLower(r.Operation)
+       if r.Operation != OperationAdd && r.Operation != OperationRemove {
+               return fmt.Errorf("operation must be one of [%s]",
+                       strings.Join([]string{OperationAdd, OperationRemove}, ","))
+       }
+       if r.Operation == OperationAdd && r.Peer == "" {
+               return fmt.Errorf("peer should NOT be empty")
+       }
+       return nil
+}
+
+func (handler *RaftHandler) ListPeers(c *gin.Context) {
+       raftNode, _ := c.MustGet(consts.ContextKeyRaftNode).(*raft.Node)
+       helper.ResponseOK(c, gin.H{
+               "leader": raftNode.GetRaftLead(),
+               "peers":  raftNode.ListPeers(),
+       })
+}
+
+func (handler *RaftHandler) UpdatePeer(c *gin.Context) {
+       var req MemberRequest
+       if err := c.BindJSON(&req); err != nil {
+               helper.ResponseBadRequest(c, err)
+               return
+       }
+       if err := req.validate(); err != nil {
+               helper.ResponseBadRequest(c, err)
+               return
+       }
+
+       raftNode, _ := c.MustGet(consts.ContextKeyRaftNode).(*raft.Node)
+       peers := raftNode.ListPeers()
+
+       var err error
+       if req.Operation == OperationAdd {
+               for _, peer := range peers {
+                       if peer == req.Peer {
+                               helper.ResponseError(c, fmt.Errorf("peer '%s' already exists", req.Peer))
+                               return
+                       }
+               }
+               err = raftNode.AddPeer(c, req.ID, req.Peer)
+       } else {
+               if _, ok := peers[req.ID]; !ok {
+                       helper.ResponseBadRequest(c, errors.New("peer not exists"))
+                       return
+               }
+               if len(peers) == 1 {
+                       helper.ResponseBadRequest(c, errors.New("can't remove the last peer"))
+                       return
+               }
+               err = raftNode.RemovePeer(c, req.ID)
+       }
+       if err != nil {
+               helper.ResponseError(c, err)
+       } else {
+               logger.Get().With(zap.Any("request", req)).Info("Update peer success")
+               helper.ResponseOK(c, nil)
+       }
+}
diff --git a/server/middleware/middleware.go b/server/middleware/middleware.go
index 9e87939..36f4056 100644
--- a/server/middleware/middleware.go
+++ b/server/middleware/middleware.go
@@ -26,6 +26,7 @@ import (
        "strconv"
        "time"
 
+       "github.com/apache/kvrocks-controller/store/engine/raft"
        "github.com/gin-gonic/gin"
        "github.com/prometheus/client_golang/prometheus"
 
@@ -66,7 +67,11 @@ func RedirectIfNotLeader(c *gin.Context) {
                c.Abort()
                return
        }
-       if !storage.IsLeader() {
+
+       _, isRaftMode := storage.GetEngine().(*raft.Node)
+       // Raft engine will forward the request to the leader node under the hood,
+       // so we don't need to do the redirect.
+       if !storage.IsLeader() && !isRaftMode {
                if !c.GetBool(consts.HeaderIsRedirect) {
                        c.Set(consts.HeaderIsRedirect, true)
                        peerAddr := helper.ExtractAddrFromSessionID(storage.Leader())
@@ -131,3 +136,15 @@ func RequiredClusterShard(c *gin.Context) {
        c.Set(consts.ContextKeyClusterShard, shard)
        c.Next()
 }
+
+func RequiredRaftEngine(c *gin.Context) {
+       storage, _ := c.MustGet(consts.ContextKeyStore).(*store.ClusterStore)
+       raftNode, ok := storage.GetEngine().(*raft.Node)
+       if !ok {
+               helper.ResponseBadRequest(c, errors.New("raft engine is not enabled"))
+               c.Abort()
+               return
+       }
+       c.Set(consts.ContextKeyRaftNode, raftNode)
+       c.Next()
+}
diff --git a/server/route.go b/server/route.go
index 90d0455..5e44ebe 100644
--- a/server/route.go
+++ b/server/route.go
@@ -47,6 +47,13 @@ func (srv *Server) initHandlers() {
 
        apiV1 := engine.Group("/api/v1/")
        {
+               raftAPI := apiV1.Group("raft")
+               {
+                       raftAPI.Use(middleware.RequiredRaftEngine)
+                       raftAPI.POST("/peers", handler.Raft.UpdatePeer)
+                       raftAPI.GET("/peers", handler.Raft.ListPeers)
+               }
+
                namespaces := apiV1.Group("namespaces")
                {
                        namespaces.GET("", handler.Namespace.List)
diff --git a/store/engine/engine.go b/store/engine/engine.go
index b25685c..e5f9259 100644
--- a/store/engine/engine.go
+++ b/store/engine/engine.go
@@ -17,6 +17,7 @@
  * under the License.
  *
  */
+
 package engine
 
 import (
diff --git a/store/engine/raft/config.go b/store/engine/raft/config.go
index 7a3b52e..d545f5c 100644
--- a/store/engine/raft/config.go
+++ b/store/engine/raft/config.go
@@ -20,15 +20,23 @@
 
 package raft
 
-import "errors"
+import (
+       "errors"
+       "strings"
+)
+
+const (
+       ClusterStateNew      = "new"
+       ClusterStateExisting = "existing"
+)
 
 type Config struct {
        // ID is the identity of the local raft. ID cannot be 0.
        ID uint64 `yaml:"id"`
        // DataDir is the directory to store the raft data which includes snapshot and WALs.
        DataDir string `yaml:"data_dir"`
-       // Join should be set to true if the node is joining an existing cluster.
-       Join bool `yaml:"join"`
+       // ClusterState is the state of the cluster, can be one of "new" and "existing".
+       ClusterState string `yaml:"cluster_state"`
        // Peers is the list of raft peers.
        Peers []string `yaml:"peers"`
        // HeartbeatSeconds is the interval to send heartbeat message. Default is 2 seconds.
@@ -47,10 +55,15 @@ func (c *Config) validate() error {
        if c.ID > uint64(len(c.Peers)) {
                return errors.New("ID cannot be greater than the number of peers")
        }
+       clusterState := strings.ToLower(c.ClusterState)
+       if clusterState != ClusterStateNew && clusterState != ClusterStateExisting {
+               return errors.New("cluster state must be one of [new, existing]")
+       }
        return nil
 }
 
 func (c *Config) init() {
+       c.ClusterState = ClusterStateNew
        if c.DataDir == "" {
                c.DataDir = "."
        }
diff --git a/store/engine/raft/config_test.go b/store/engine/raft/config_test.go
index 88ea00d..e2f61b0 100644
--- a/store/engine/raft/config_test.go
+++ b/store/engine/raft/config_test.go
@@ -28,6 +28,7 @@ import (
 
 func TestConfig_Validate(t *testing.T) {
        c := &Config{}
+       c.init()
 
        // missing ID
        require.ErrorContains(t, c.validate(), "ID cannot be 0")
@@ -40,6 +41,12 @@ func TestConfig_Validate(t *testing.T) {
        // ID greater than the number of peers
        c.ID = 2
        require.ErrorContains(t, c.validate(), "ID cannot be greater than the number of peers")
+
+       c.ID = 1
+       c.ClusterState = "invalid"
+       require.ErrorContains(t, c.validate(), "cluster state must be one of [new, existing]")
+       c.ClusterState = ClusterStateNew
+       require.NoError(t, c.validate())
 }
 
 func TestConfig_Init(t *testing.T) {
diff --git a/store/engine/raft/node.go b/store/engine/raft/node.go
index dcfa859..3786f58 100644
--- a/store/engine/raft/node.go
+++ b/store/engine/raft/node.go
@@ -115,11 +115,12 @@ func (n *Node) Addr() string {
        return n.addr
 }
 
-func (n *Node) Peers() []string {
-       peers := make([]string, 0)
+func (n *Node) ListPeers() map[uint64]string {
+       peers := make(map[uint64]string)
        n.peers.Range(func(key, value interface{}) bool {
+               id, _ := key.(uint64)
                peer, _ := value.(string)
-               peers = append(peers, peer)
+               peers[id] = peer
                return true
        })
        return peers
@@ -165,7 +166,7 @@ func (n *Node) run() error {
        n.snapshotIndex = snapshot.Metadata.Index
        n.confState = snapshot.Metadata.ConfState
 
-       if n.config.Join || walExists {
+       if n.config.ClusterState == ClusterStateExisting || walExists {
                n.raftNode = raft.RestartNode(raftConfig)
        } else {
                n.raftNode = raft.StartNode(raftConfig, peers)
@@ -174,6 +175,7 @@ func (n *Node) run() error {
        if err := n.runTransport(); err != nil {
                return err
        }
+       n.watchLeaderChange()
        return n.runRaftMessages()
 }
 
@@ -224,10 +226,33 @@ func (n *Node) runTransport() error {
        return nil
 }
 
+func (n *Node) watchLeaderChange() {
+       n.wg.Add(1)
+       go func() {
+               defer n.wg.Done()
+
+               ticker := time.NewTicker(time.Second)
+               defer ticker.Stop()
+               for {
+                       select {
+                       case <-n.shutdown:
+                               return
+                       case <-ticker.C:
+                               lead := n.GetRaftLead()
+                               if lead != n.leader {
+                                       n.leader = lead
+                                       n.leaderChanged <- true
+                                       n.logger.Info("Found leader changed", zap.Uint64("leader", lead))
+                               }
+                       }
+               }
+       }()
+}
+
 func (n *Node) runRaftMessages() error {
        n.wg.Add(1)
        go func() {
-               ticker := time.NewTicker(100 * time.Millisecond)
+               ticker := time.NewTicker(time.Second)
                defer func() {
                        ticker.Stop()
                        n.wg.Done()
@@ -252,9 +277,7 @@ func (n *Node) runRaftMessages() error {
                                if err := n.applySnapshot(rd.Snapshot); err != nil {
                                        n.logger.Error("Failed to apply snapshot", zap.Error(err))
                                }
-                               if len(rd.Entries) > 0 {
-                                       _ = n.dataStore.raftStorage.Append(rd.Entries)
-                               }
+                               _ = n.dataStore.raftStorage.Append(rd.Entries)
 
                                for _, msg := range rd.Messages {
                                        if msg.Type == raftpb.MsgApp {
@@ -465,13 +488,14 @@ func (n *Node) applyEntry(entry raftpb.Entry) error {
                                n.peers.Store(cc.NodeID, string(cc.Context))
                        }
                case raftpb.ConfChangeRemoveNode:
-                       n.peers.Delete(cc.NodeID)
-                       n.transport.RemovePeer(types.ID(cc.NodeID))
                        if cc.NodeID == n.config.ID {
+                               n.logger.Info("I have been removed from the cluster, will shutdown")
                                n.Close()
-                               n.logger.Info("Node removed from the cluster")
                                return nil
                        }
+                       n.transport.RemovePeer(types.ID(cc.NodeID))
+                       n.peers.Delete(cc.NodeID)
+                       n.logger.Info("Remove the peer", zap.Uint64("node_id", cc.NodeID))
                case raftpb.ConfChangeUpdateNode:
                        n.transport.UpdatePeer(types.ID(cc.NodeID), []string{string(cc.Context)})
                        if _, ok := n.peers.Load(cc.NodeID); ok {
diff --git a/store/engine/raft/node_test.go b/store/engine/raft/node_test.go
index 1960161..15643b6 100644
--- a/store/engine/raft/node_test.go
+++ b/store/engine/raft/node_test.go
@@ -54,6 +54,11 @@ func NewTestCluster(n int) *TestCluster {
                        HeartbeatSeconds: 1,
                        ElectionSeconds:  2,
                })
+               // drain leader change events
+               go func() {
+                       for range nodes[i].LeaderChange() {
+                       }
+               }()
        }
        return &TestCluster{nodes: nodes}
 }
@@ -68,7 +73,11 @@ func (c *TestCluster) createNode(peers []string) (*Node, error) {
                HeartbeatSeconds: 1,
                ElectionSeconds:  2,
        })
-       return node, err
+       if err != nil {
+               return nil, err
+       }
+       c.nodes = append(c.nodes, node)
+       return node, nil
 }
 
 func (c *TestCluster) AddNode(ctx context.Context, nodeID uint64, peer string) error {
@@ -79,16 +88,18 @@ func (c *TestCluster) AddNode(ctx context.Context, nodeID uint64, peer string) e
 }
 
 func (c *TestCluster) RemoveNode(ctx context.Context, nodeID uint64) error {
+       var node *Node
        for i, n := range c.nodes {
                if n.config.ID == nodeID {
+                       node = n
                        c.nodes = append(c.nodes[:i], c.nodes[i+1:]...)
                        break
                }
        }
-       if len(c.nodes) == 0 {
+       if len(c.nodes) == 0 || node == nil {
                return nil
        }
-       return c.nodes[0].RemovePeer(ctx, nodeID)
+       return node.RemovePeer(ctx, nodeID)
 }
 
 func (c *TestCluster) SetSnapshotThreshold(threshold uint64) {
@@ -234,13 +245,13 @@ func TestCluster_AddRemovePeer(t *testing.T) {
                        got, _ := n1.Get(ctx, "foo")
                        return string(got) == "bar-1"
                }, 1*time.Second, 100*time.Millisecond)
-               require.Len(t, n1.Peers(), 4)
+               require.Len(t, n1.ListPeers(), 4)
        })
 
        t.Run("remove a peer node", func(t *testing.T) {
                cluster.RemoveNode(ctx, 4)
                require.Eventually(t, func() bool {
-                       return len(n1.Peers()) == 3
+                       return len(n1.ListPeers()) == 3
                }, 10*time.Second, 100*time.Millisecond)
        })
 }
diff --git a/store/engine/raft/store.go b/store/engine/raft/store.go
index fd3f773..b549c8f 100644
--- a/store/engine/raft/store.go
+++ b/store/engine/raft/store.go
@@ -235,13 +235,19 @@ func (ds *DataStore) List(prefix string) []engine.Entry {
        ds.mu.RLock()
        defer ds.mu.RUnlock()
        entries := make([]engine.Entry, 0)
-       for k := range ds.kvs {
-               if strings.HasPrefix(k, prefix) {
-                       entries = append(entries, engine.Entry{
-                               Key:   strings.TrimLeft(strings.TrimPrefix(k, 
prefix), "/"),
-                               Value: ds.kvs[k],
-                       })
+       for key := range ds.kvs {
+               if !strings.HasPrefix(key, prefix) || key == prefix {
+                       continue
                }
+               trimmedKey := strings.TrimLeft(key[len(prefix)+1:], "/")
+               if strings.ContainsRune(trimmedKey, '/') {
+                       continue
+               }
+
+               entries = append(entries, engine.Entry{
+                       Key:   trimmedKey,
+                       Value: ds.kvs[trimmedKey],
+               })
        }
        slices.SortFunc(entries, func(i, j engine.Entry) int {
                return strings.Compare(i.Key, j.Key)
diff --git a/store/engine/raft/store_test.go b/store/engine/raft/store_test.go
index de378dd..b679603 100644
--- a/store/engine/raft/store_test.go
+++ b/store/engine/raft/store_test.go
@@ -97,10 +97,10 @@ func TestDataStore(t *testing.T) {
                entries = store.List("ba")
                require.Len(t, entries, 4)
 
-               entries = store.List("bar-2")
-               require.Len(t, entries, 1)
+               entries = store.List("bar")
+               require.Len(t, entries, 2)
 
-               entries = store.List("foo")
+               entries = store.List("fo")
                require.Len(t, entries, 1)
 
                store.Delete("bar-2")
diff --git a/store/store.go b/store/store.go
index ce730b4..6c64980 100644
--- a/store/store.go
+++ b/store/store.go
@@ -310,6 +310,10 @@ func (s *ClusterStore) EmitEvent(event EventPayload) {
        s.eventNotifyCh <- event
 }
 
+func (s *ClusterStore) GetEngine() engine.Engine {
+       return s.e
+}
+
 func (s *ClusterStore) LeaderChange() <-chan bool {
        return s.e.LeaderChange()
 }

Reply via email to