This is an automated email from the ASF dual-hosted git repository.

ccondit pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-k8shim.git


The following commit(s) were added to refs/heads/master by this push:
     new e9b05ebb [YUNIKORN-2910] Fix data corruption due to insufficient shim context locking (#924)
e9b05ebb is described below

commit e9b05ebbd0ccab7a05724781f2896ba0cb197f02
Author: Craig Condit <[email protected]>
AuthorDate: Thu Oct 10 09:41:01 2024 -0600

    [YUNIKORN-2910] Fix data corruption due to insufficient shim context locking (#924)
    
    Restore context locking that was removed as part of YUNIKORN-2629. The
    locks are necessary to prevent logical data corruption due to concurrent
    processing of both pod and node events.
    
    Closes: #924
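
For readers skimming the patch below: the essence of the fix is that pod and node event handlers once again hold the shim context's write lock for the duration of the handler, while read-only predicate checks take the read lock, so the two event types can never interleave. The following is a minimal, self-contained Go sketch of that pattern, not the shim's actual code; the field and method names only loosely mirror Context in pkg/cache/context.go, the maps are invented stand-ins for the real caches, and the stdlib sync.RWMutex stands in for the shim's locking.RWMutex wrapper.

package main

import (
	"fmt"
	"sync"
)

// Context is a simplified stand-in: one RWMutex guards the cached state and
// also serializes event handlers so pod and node events never interleave.
type Context struct {
	lock  sync.RWMutex
	nodes map[string]string // nodeName -> state (placeholder for the real node cache)
	pods  map[string]string // podName -> assigned nodeName (placeholder)
}

func NewContext() *Context {
	return &Context{
		nodes: map[string]string{},
		pods:  map[string]string{},
	}
}

// updateNode takes the write lock for the whole handler, so no pod event can
// race with (or observe) a half-applied node update.
func (ctx *Context) updateNode(name, state string) {
	ctx.lock.Lock()
	defer ctx.lock.Unlock()
	ctx.nodes[name] = state
}

// UpdatePod is serialized against node events by the same lock.
func (ctx *Context) UpdatePod(pod, node string) {
	ctx.lock.Lock()
	defer ctx.lock.Unlock()
	ctx.pods[pod] = node
}

// IsPodFitNode only reads cached state, so the shared read lock is enough.
func (ctx *Context) IsPodFitNode(pod, node string) bool {
	ctx.lock.RLock()
	defer ctx.lock.RUnlock()
	_, ok := ctx.nodes[node]
	return ok && ctx.pods[pod] == node
}

func main() {
	ctx := NewContext()
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); ctx.updateNode("node-1", "Ready") }()
	go func() { defer wg.Done(); ctx.UpdatePod("pod-1", "node-1") }()
	wg.Wait()
	fmt.Println(ctx.IsPodFitNode("pod-1", "node-1")) // true
}

Because both handlers contend for the same lock, a node update can never run concurrently with a pod update, which is the logical data corruption the commit message refers to.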
---
 pkg/cache/context.go | 101 +++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 77 insertions(+), 24 deletions(-)

diff --git a/pkg/cache/context.go b/pkg/cache/context.go
index aaea690a..b4dca41e 100644
--- a/pkg/cache/context.go
+++ b/pkg/cache/context.go
@@ -72,7 +72,7 @@ type Context struct {
        pluginMode     bool                           // true if we are configured as a scheduler plugin
        namespace      string                         // yunikorn namespace
        configMaps     []*v1.ConfigMap                // cached yunikorn configmaps
-       lock           *locking.RWMutex               // lock
+       lock           *locking.RWMutex               // lock - used not only for context data but also to ensure that multiple event types are not executed concurrently
        txnID          atomic.Uint64                  // transaction ID counter
        klogger        klog.Logger
 }
@@ -166,6 +166,8 @@ func (ctx *Context) addNode(obj interface{}) {
 }
 
 func (ctx *Context) updateNode(_, obj interface{}) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        node, err := convertToNode(obj)
        if err != nil {
                log.Log(log.ShimContext).Error("node conversion failed", 
zap.Error(err))
@@ -227,6 +229,8 @@ func (ctx *Context) updateNodeInternal(node *v1.Node, 
register bool) {
 }
 
 func (ctx *Context) deleteNode(obj interface{}) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        var node *v1.Node
        switch t := obj.(type) {
        case *v1.Node:
@@ -246,6 +250,8 @@ func (ctx *Context) deleteNode(obj interface{}) {
 }
 
 func (ctx *Context) addNodesWithoutRegistering(nodes []*v1.Node) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        for _, node := range nodes {
                ctx.updateNodeInternal(node, false)
        }
@@ -281,6 +287,8 @@ func (ctx *Context) AddPod(obj interface{}) {
 }
 
 func (ctx *Context) UpdatePod(_, newObj interface{}) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        pod, err := utils.Convert2Pod(newObj)
        if err != nil {
                log.Log(log.ShimContext).Error("failed to update pod", 
zap.Error(err))
@@ -328,7 +336,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod, 
app *Application) {
                                zap.String("name", pod.Name))
                        return
                }
-               app = ctx.AddApplication(&AddApplicationRequest{
+               app = ctx.addApplication(&AddApplicationRequest{
                        Metadata: appMeta,
                })
        }
@@ -432,8 +440,10 @@ func (ctx *Context) DeletePod(obj interface{}) {
 }
 
 func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        if taskMeta, ok := getTaskMetadata(pod); ok {
-               ctx.notifyTaskComplete(ctx.GetApplication(taskMeta.ApplicationID), taskMeta.TaskID)
+               ctx.notifyTaskComplete(ctx.getApplication(taskMeta.ApplicationID), taskMeta.TaskID)
        }
 
        log.Log(log.ShimContext).Debug("removing pod from cache", 
zap.String("podName", pod.Name))
@@ -441,6 +451,8 @@ func (ctx *Context) deleteYuniKornPod(pod *v1.Pod) {
 }
 
 func (ctx *Context) deleteForeignPod(pod *v1.Pod) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        oldPod := ctx.schedulerCache.GetPod(string(pod.UID))
        if oldPod == nil {
                // if pod is not in scheduler cache, no node updates are needed
@@ -571,6 +583,8 @@ func (ctx *Context) addPriorityClass(obj interface{}) {
 }
 
 func (ctx *Context) updatePriorityClass(_, newObj interface{}) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        if priorityClass := utils.Convert2PriorityClass(newObj); priorityClass != nil {
                ctx.updatePriorityClassInternal(priorityClass)
        }
@@ -581,6 +595,8 @@ func (ctx *Context) updatePriorityClassInternal(priorityClass *schedulingv1.Prio
 }
 
 func (ctx *Context) deletePriorityClass(obj interface{}) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        log.Log(log.ShimContext).Debug("priorityClass deleted")
        var priorityClass *schedulingv1.PriorityClass
        switch t := obj.(type) {
@@ -646,6 +662,8 @@ func (ctx *Context) EventsToRegister(queueingHintFn framework.QueueingHintFn) []
 
 // IsPodFitNode evaluates given predicates based on current context
 func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
+       ctx.lock.RLock()
+       defer ctx.lock.RUnlock()
        pod := ctx.schedulerCache.GetPod(name)
        if pod == nil {
                return ErrorPodNotFound
@@ -666,6 +684,8 @@ func (ctx *Context) IsPodFitNode(name, node string, allocate bool) error {
 }
 
 func (ctx *Context) IsPodFitNodeViaPreemption(name, node string, allocations []string, startIndex int) (int, bool) {
+       ctx.lock.RLock()
+       defer ctx.lock.RUnlock()
        if pod := ctx.schedulerCache.GetPod(name); pod != nil {
                // if pod exists in cache, try to run predicates
                if targetNode := ctx.schedulerCache.GetNode(node); targetNode != nil {
@@ -774,6 +794,8 @@ func (ctx *Context) bindPodVolumes(pod *v1.Pod) error {
 // this way, the core can make allocation decisions with consideration of
 // other assumed pods before they are actually bound to the node (bound is slow).
 func (ctx *Context) AssumePod(name, node string) error {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        if pod := ctx.schedulerCache.GetPod(name); pod != nil {
                // when add assumed pod, we make a copy of the pod to avoid
                // modifying its original reference. otherwise, it may have
@@ -833,6 +855,8 @@ func (ctx *Context) AssumePod(name, node string) error {
 // forget pod must be called when a pod is assumed to be running on a node,
 // but then for some reason it is failed to bind or released.
 func (ctx *Context) ForgetPod(name string) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        if pod := ctx.schedulerCache.GetPod(name); pod != nil {
                log.Log(log.ShimContext).Debug("forget pod", zap.String("pod", 
pod.Name))
                ctx.schedulerCache.ForgetPod(pod)
@@ -949,6 +973,10 @@ func (ctx *Context) AddApplication(request *AddApplicationRequest) *Application
        ctx.lock.Lock()
        defer ctx.lock.Unlock()
 
+       return ctx.addApplication(request)
+}
+
+func (ctx *Context) addApplication(request *AddApplicationRequest) *Application {
        log.Log(log.ShimContext).Debug("AddApplication", zap.Any("Request", request))
        if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil {
                return app
@@ -1026,6 +1054,8 @@ func (ctx *Context) RemoveApplication(appID string) {
 
 // this implements ApplicationManagementProtocol
 func (ctx *Context) AddTask(request *AddTaskRequest) *Task {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        return ctx.addTask(request)
 }
 
@@ -1074,8 +1104,8 @@ func (ctx *Context) addTask(request *AddTaskRequest) *Task {
 }
 
 func (ctx *Context) RemoveTask(appID, taskID string) {
-       ctx.lock.RLock()
-       defer ctx.lock.RUnlock()
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
        app, ok := ctx.applications[appID]
        if !ok {
                log.Log(log.ShimContext).Debug("Attempted to remove task from 
non-existent application", zap.String("appID", appID))
@@ -1085,7 +1115,9 @@ func (ctx *Context) RemoveTask(appID, taskID string) {
 }
 
 func (ctx *Context) getTask(appID string, taskID string) *Task {
-       app := ctx.GetApplication(appID)
+       ctx.lock.RLock()
+       defer ctx.lock.RUnlock()
+       app := ctx.getApplication(appID)
        if app == nil {
                log.Log(log.ShimContext).Debug("application is not found in the 
context",
                        zap.String("appID", appID))
@@ -1354,7 +1386,7 @@ func (ctx *Context) InitializeState() error {
                log.Log(log.ShimContext).Error("failed to load nodes", 
zap.Error(err))
                return err
        }
-       acceptedNodes, err := ctx.registerNodes(nodes)
+       acceptedNodes, err := ctx.RegisterNodes(nodes)
        if err != nil {
                log.Log(log.ShimContext).Error("failed to register nodes", 
zap.Error(err))
                return err
@@ -1474,11 +1506,17 @@ func (ctx *Context) registerNode(node *v1.Node) error {
        return nil
 }
 
+func (ctx *Context) RegisterNodes(nodes []*v1.Node) ([]*v1.Node, error) {
+       ctx.lock.Lock()
+       defer ctx.lock.Unlock()
+       return ctx.registerNodes(nodes)
+}
+
+// registerNodes registers the nodes to the scheduler core.
+// This method must be called while holding the Context write lock.
 func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
        nodesToRegister := make([]*si.NodeInfo, 0)
        pendingNodes := make(map[string]*v1.Node)
-       acceptedNodes := make([]*v1.Node, 0)
-       rejectedNodes := make([]*v1.Node, 0)
 
        // Generate a NodeInfo object for each node and add to the registration request
        for _, node := range nodes {
@@ -1497,12 +1535,34 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
                pendingNodes[node.Name] = node
        }
 
-       var wg sync.WaitGroup
+       acceptedNodes, rejectedNodes, err := ctx.registerNodesInternal(nodesToRegister, pendingNodes)
+       if err != nil {
+               log.Log(log.ShimContext).Error("Failed to register nodes", zap.Error(err))
+               return nil, err
+       }
+
+       for _, node := range acceptedNodes {
+               // post a successful event to the node
+               events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
+                       fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
+       }
+       for _, node := range rejectedNodes {
+               // post a failure event to the node
+               events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
+                       fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
+       }
 
+       return acceptedNodes, nil
+}
+
+func (ctx *Context) registerNodesInternal(nodesToRegister []*si.NodeInfo, pendingNodes map[string]*v1.Node) ([]*v1.Node, []*v1.Node, error) {
+       acceptedNodes := make([]*v1.Node, 0)
+       rejectedNodes := make([]*v1.Node, 0)
+
+       var wg sync.WaitGroup
        // initialize wait group with the number of responses we expect
        wg.Add(len(pendingNodes))
 
-       // register with the dispatcher so that we can track our response
        handlerID := fmt.Sprintf("%s-%d", registerNodeContextHandler, ctx.txnID.Add(1))
        dispatcher.RegisterEventHandler(handlerID, dispatcher.EventTypeNode, func(event interface{}) {
                nodeEvent, ok := event.(CachedSchedulerNodeEvent)
@@ -1534,24 +1594,17 @@ func (ctx *Context) registerNodes(nodes []*v1.Node) ([]*v1.Node, error) {
                RmID:  schedulerconf.GetSchedulerConf().ClusterID,
        }); err != nil {
                log.Log(log.ShimContext).Error("Failed to register nodes", 
zap.Error(err))
-               return nil, err
+               return nil, nil, err
        }
 
+       // write lock must always be held at this point, releasing it while waiting to avoid any potential deadlocks
+       ctx.lock.Unlock()
+       defer ctx.lock.Lock()
+
        // wait for all responses to accumulate
        wg.Wait()
 
-       for _, node := range acceptedNodes {
-               // post a successful event to the node
-               events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeNormal, "NodeAccepted", "NodeAccepted",
-                       fmt.Sprintf("node %s is accepted by the scheduler", node.Name))
-       }
-       for _, node := range rejectedNodes {
-               // post a failure event to the node
-               events.GetRecorder().Eventf(node.DeepCopy(), nil, v1.EventTypeWarning, "NodeRejected", "NodeRejected",
-                       fmt.Sprintf("node %s is rejected by the scheduler", node.Name))
-       }
-
-       return acceptedNodes, nil
+       return acceptedNodes, rejectedNodes, nil
 }
 
 func (ctx *Context) decommissionNode(node *v1.Node) error {
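
Two further patterns appear in the registerNodes changes above: an exported, locking wrapper (RegisterNodes) that delegates to an internal method which assumes the write lock is already held, and temporarily dropping that lock around the WaitGroup wait so the node-response handlers, which also need the lock, can make progress. Below is a rough sketch of both under simplified assumptions; it is not the shim's code, the types are stand-ins, and the simulated goroutines replace the real dispatcher event handlers.

package main

import (
	"fmt"
	"sync"
	"time"
)

// Context is a simplified stand-in with a single mutex guarding its state.
type Context struct {
	lock  sync.Mutex
	nodes []string
}

// RegisterNodes is the exported entry point: it takes the lock and delegates.
func (ctx *Context) RegisterNodes(nodes []string) []string {
	ctx.lock.Lock()
	defer ctx.lock.Unlock()
	return ctx.registerNodes(nodes)
}

// registerNodes must be called with ctx.lock held.
func (ctx *Context) registerNodes(nodes []string) []string {
	var wg sync.WaitGroup
	wg.Add(len(nodes))

	accepted := make([]string, 0, len(nodes))
	for _, n := range nodes {
		n := n
		// Simulated asynchronous "node accepted" response; the real shim
		// receives these through dispatcher handlers that also lock the context.
		go func() {
			defer wg.Done()
			time.Sleep(10 * time.Millisecond) // stand-in for the scheduler-core round trip
			ctx.lock.Lock()
			defer ctx.lock.Unlock()
			ctx.nodes = append(ctx.nodes, n)
			accepted = append(accepted, n)
		}()
	}

	// The write lock is held on entry; release it while waiting so the
	// response handlers above can acquire it, then re-take it before returning.
	ctx.lock.Unlock()
	defer ctx.lock.Lock()
	wg.Wait()

	return accepted
}

func main() {
	ctx := &Context{}
	fmt.Println(ctx.RegisterNodes([]string{"node-1", "node-2"}))
}

Holding the write lock across wg.Wait() would deadlock in this sketch, because each simulated response must take the same lock before it can signal wg.Done(); releasing and re-acquiring around the wait is what avoids that, mirroring the comment in registerNodesInternal above.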


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
