This is an automated email from the ASF dual-hosted git repository.

hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks-controller.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 4455442  Fix data race while updating the cluster info (#217)
4455442 is described below

commit 4455442857685e47c4ca71eaebc0001037cc6046
Author: hulk <[email protected]>
AuthorDate: Thu Nov 14 11:35:11 2024 +0800

    Fix data race while updating the cluster info (#217)
---
 controller/cluster.go |  2 +-
 store/store.go        | 64 +++++++++++++++++++++++++++++++++++++++++++++++----
 store/store_test.go   |  3 +--
 3 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/controller/cluster.go b/controller/cluster.go
index 340dd3b..5b8d539 100644
--- a/controller/cluster.go
+++ b/controller/cluster.go
@@ -305,7 +305,7 @@ func (c *ClusterChecker) tryUpdateMigrationStatus(ctx context.Context, cluster *
                        cluster.Shards[shard.TargetShardIndex].SlotRanges = store.AddSlotToSlotRanges(
                                cluster.Shards[shard.TargetShardIndex].SlotRanges, shard.MigratingSlot)
                        cluster.Shards[i].ClearMigrateState()
-                       if err := c.clusterStore.SetCluster(ctx, c.namespace, cluster); err != nil {
+                       if err := c.clusterStore.UpdateCluster(ctx, c.namespace, cluster); err != nil {
                                log.Error("Failed to update the cluster", zap.Error(err))
                        } else {
                                log.Info("Migrate the slot successfully", zap.Int("slot", shard.MigratingSlot))
diff --git a/store/store.go b/store/store.go
index 5afeeef..860df19 100644
--- a/store/store.go
+++ b/store/store.go
@@ -23,6 +23,7 @@ import (
        "context"
        "encoding/json"
        "fmt"
+       "sync"
 
        "github.com/apache/kvrocks-controller/consts"
        "github.com/apache/kvrocks-controller/store/engine"
@@ -51,6 +52,7 @@ var _ Store = (*ClusterStore)(nil)
 type ClusterStore struct {
        e engine.Engine
 
+       locks         sync.Map
        eventNotifyCh chan EventPayload
        quitCh        chan struct{}
 }
@@ -124,6 +126,12 @@ func (s *ClusterStore) RemoveNamespace(ctx context.Context, ns string) error {
        return nil
 }
 
+func (s *ClusterStore) getLock(ns, cluster string) *sync.RWMutex {
+       value, _ := s.locks.LoadOrStore(fmt.Sprintf("%s/%s", ns, cluster), &sync.RWMutex{})
+       lock, _ := value.(*sync.RWMutex)
+       return lock
+}
+
 // ListCluster return the list of name of cluster under the specified namespace
 func (s *ClusterStore) ListCluster(ctx context.Context, ns string) ([]string, error) {
        entries, err := s.e.List(ctx, buildClusterPrefix(ns))
@@ -142,6 +150,14 @@ func (s *ClusterStore) existsCluster(ctx context.Context, ns, cluster string) (b
 }
 
 func (s *ClusterStore) GetCluster(ctx context.Context, ns, cluster string) (*Cluster, error) {
+       lock := s.getLock(ns, cluster)
+       lock.RLock()
+       defer lock.RUnlock()
+
+       return s.getClusterWithoutLock(ctx, ns, cluster)
+}
+
+func (s *ClusterStore) getClusterWithoutLock(ctx context.Context, ns, cluster string) (*Cluster, error) {
        value, err := s.e.Get(ctx, buildClusterKey(ns, cluster))
        if err != nil {
                return nil, fmt.Errorf("cluster: %w", err)
@@ -155,10 +171,27 @@ func (s *ClusterStore) GetCluster(ctx context.Context, ns, cluster string) (*Clu
 
 // UpdateCluster update the Name to store under the specified namespace
 func (s *ClusterStore) UpdateCluster(ctx context.Context, ns string, clusterInfo *Cluster) error {
+       lock := s.getLock(ns, clusterInfo.Name)
+       lock.Lock()
+       defer lock.Unlock()
+
+       oldCluster, err := s.getClusterWithoutLock(ctx, ns, clusterInfo.Name)
+       if err != nil {
+               return err
+       }
+       if oldCluster.Version.Load() != clusterInfo.Version.Load() {
+               return fmt.Errorf("the cluster has been updated by others")
+       }
+
        clusterInfo.Version.Inc()
-       if err := s.SetCluster(ctx, ns, clusterInfo); err != nil {
+       clusterBytes, err := json.Marshal(clusterInfo)
+       if err != nil {
                return fmt.Errorf("cluster: %w", err)
        }
+       if err := s.e.Set(ctx, buildClusterKey(ns, clusterInfo.Name), clusterBytes); err != nil {
+               return err
+       }
+
        s.EmitEvent(EventPayload{
                Namespace: ns,
                Cluster:   clusterInfo.Name,
@@ -168,10 +201,20 @@ func (s *ClusterStore) UpdateCluster(ctx context.Context, ns string, clusterInfo
        return nil
 }
 
+// SetCluster set the cluster to store under the specified namespace but won't increase the version.
 func (s *ClusterStore) SetCluster(ctx context.Context, ns string, clusterInfo *Cluster) error {
-       if len(clusterInfo.Shards) == 0 {
-               return fmt.Errorf("%w: required at least one shard", consts.ErrInvalidArgument)
+       lock := s.getLock(ns, clusterInfo.Name)
+       lock.Lock()
+       defer lock.Unlock()
+
+       oldCluster, err := s.getClusterWithoutLock(ctx, ns, clusterInfo.Name)
+       if err != nil {
+               return err
+       }
+       if oldCluster.Version.Load() != clusterInfo.Version.Load() {
+               return fmt.Errorf("the cluster has been updated by others")
        }
+
        value, err := json.Marshal(clusterInfo)
        if err != nil {
                return fmt.Errorf("cluster: %w", err)
@@ -180,10 +223,18 @@ func (s *ClusterStore) SetCluster(ctx context.Context, ns string, clusterInfo *C
 }
 
 func (s *ClusterStore) CreateCluster(ctx context.Context, ns string, clusterInfo *Cluster) error {
+       lock := s.getLock(ns, clusterInfo.Name)
+       lock.Lock()
+       defer lock.Unlock()
+
        if exists, _ := s.existsCluster(ctx, ns, clusterInfo.Name); exists {
                return fmt.Errorf("cluster: %w", consts.ErrAlreadyExists)
        }
-       if err := s.SetCluster(ctx, ns, clusterInfo); err != nil {
+       clusterBytes, err := json.Marshal(clusterInfo)
+       if err != nil {
+               return fmt.Errorf("cluster: %w", err)
+       }
+       if err := s.e.Set(ctx, buildClusterKey(ns, clusterInfo.Name), clusterBytes); err != nil {
                return err
        }
        s.EmitEvent(EventPayload{
@@ -196,12 +247,17 @@ func (s *ClusterStore) CreateCluster(ctx context.Context, ns string, clusterInfo
 }
 
 func (s *ClusterStore) RemoveCluster(ctx context.Context, ns, cluster string) error {
+       lock := s.getLock(ns, cluster)
+       lock.Lock()
+       defer lock.Unlock()
+
        if exists, _ := s.existsCluster(ctx, ns, cluster); !exists {
                return consts.ErrNotFound
        }
        if err := s.e.Delete(ctx, buildClusterKey(ns, cluster)); err != nil {
                return err
        }
+
        s.EmitEvent(EventPayload{
                Namespace: ns,
                Cluster:   cluster,
diff --git a/store/store_test.go b/store/store_test.go
index 59b0fdc..f75115c 100644
--- a/store/store_test.go
+++ b/store/store_test.go
@@ -81,12 +81,11 @@ func TestClusterStore(t *testing.T) {
                require.NoError(t, err)
                require.ElementsMatch(t, []string{"cluster0", "cluster1"}, gotClusters)
 
-               cluster0.Version.Store(4)
                require.NoError(t, store.UpdateCluster(ctx, ns, cluster0))
                gotCluster, err = store.GetCluster(ctx, ns, "cluster0")
                require.NoError(t, err)
                require.Equal(t, cluster0.Name, gotCluster.Name)
-               require.Equal(t, cluster0.Version, gotCluster.Version)
+               require.EqualValues(t, 3, gotCluster.Version.Load())
 
                for _, name := range []string{"cluster0", "cluster1"} {
                        require.NoError(t, store.RemoveCluster(ctx, ns, name))

Reply via email to