This is an automated email from the ASF dual-hosted git repository.

pbacsko pushed a commit to branch branch-1.5
in repository https://gitbox.apache.org/repos/asf/yunikorn-k8shim.git


The following commit(s) were added to refs/heads/branch-1.5 by this push:
     new a95495db [YUNIKORN-2520] PVC errors in AssumePod() are not handled properly (#810)
a95495db is described below

commit a95495db05bd8dea94f8c00289d3b03467a04780
Author: Peter Bacsko <[email protected]>
AuthorDate: Thu Apr 4 19:46:59 2024 +1100

    [YUNIKORN-2520] PVC errors in AssumePod() are not handled properly (#810)
    
    Handle PV/PVC volume errors in Context.AssumePod() properly.
    If volume errors occur, fail the task, as the chance that they will fix
    themselves without outside changes is negligible.
    
    Enhance test coverage.
    
    Closes: #810
    
    Signed-off-by: Wilfred Spiegelenburg <[email protected]>
    (cherry picked from commit b0cfdf83d414aa66630f241198774abe5d07eb62)
---
 pkg/cache/context.go                  |  72 ++++++++---------
 pkg/cache/context_test.go             | 147 ++++++++++++++++++++++++++++++++--
 pkg/cache/external/scheduler_cache.go |   8 +-
 pkg/cache/scheduler_callback.go       |  11 ++-
 pkg/cache/task.go                     |  37 +++++----
 pkg/cache/task_state.go               |   2 +-
 pkg/cache/task_test.go                |  41 ++++++++++
 pkg/client/apifactory_mock.go         |   7 +-
 pkg/common/si_helper.go               |  48 ++++-------
 pkg/common/si_helper_test.go          |  10 +--
 pkg/common/test/volumebinder_mock.go  | 108 +++++++++++++++++++++++++
 pkg/shim/scheduler_test.go            |  46 +++++++++++
 12 files changed, 435 insertions(+), 102 deletions(-)

diff --git a/pkg/cache/context.go b/pkg/cache/context.go
index 5f8acb8a..ff6e89ed 100644
--- a/pkg/cache/context.go
+++ b/pkg/cache/context.go
@@ -780,7 +780,7 @@ func (ctx *Context) bindPodVolumes(pod *v1.Pod) error {
 // be running on it. And we keep this cache in-sync between core and the shim.
 // this way, the core can make allocation decisions with consideration of
 // other assumed pods before they are actually bound to the node (bound is 
slow).
-func (ctx *Context) AssumePod(name string, node string) error {
+func (ctx *Context) AssumePod(name, node string) error {
        ctx.lock.Lock()
        defer ctx.lock.Unlock()
        if pod, ok := ctx.schedulerCache.GetPod(name); ok {
@@ -793,45 +793,43 @@ func (ctx *Context) AssumePod(name string, node string) 
error {
                        // assume pod volumes before assuming the pod
                        // this will update scheduler cache with essential 
PV/PVC binding info
                        var allBound = true
-                       // volume builder might be null in UTs
-                       if ctx.apiProvider.GetAPIs().VolumeBinder != nil {
-                               var err error
-                               // retrieve the volume claims
-                               podVolumeClaims, err := 
ctx.apiProvider.GetAPIs().VolumeBinder.GetPodVolumeClaims(ctx.klogger, pod)
-                               if err != nil {
-                                       log.Log(log.ShimContext).Error("Failed 
to get pod volume claims",
-                                               zap.String("podName", 
assumedPod.Name),
-                                               zap.Error(err))
-                                       return err
-                               }
+                       var err error
+                       // retrieve the volume claims
+                       podVolumeClaims, err := 
ctx.apiProvider.GetAPIs().VolumeBinder.GetPodVolumeClaims(ctx.klogger, pod)
+                       if err != nil {
+                               log.Log(log.ShimContext).Error("Failed to get 
pod volume claims",
+                                       zap.String("podName", assumedPod.Name),
+                                       zap.Error(err))
+                               return err
+                       }
 
-                               // retrieve volumes
-                               volumes, reasons, err := 
ctx.apiProvider.GetAPIs().VolumeBinder.FindPodVolumes(ctx.klogger, pod, 
podVolumeClaims, targetNode.Node())
-                               if err != nil {
-                                       log.Log(log.ShimContext).Error("Failed 
to find pod volumes",
-                                               zap.String("podName", 
assumedPod.Name),
-                                               zap.String("nodeName", 
assumedPod.Spec.NodeName),
-                                               zap.Error(err))
-                                       return err
-                               }
-                               if len(reasons) > 0 {
-                                       sReasons := make([]string, 0)
-                                       for _, reason := range reasons {
-                                               sReasons = append(sReasons, 
string(reason))
-                                       }
-                                       sReason := strings.Join(sReasons, ", ")
-                                       err = fmt.Errorf("pod %s has 
conflicting volume claims: %s", pod.Name, sReason)
-                                       log.Log(log.ShimContext).Error("Pod has 
conflicting volume claims",
-                                               zap.String("podName", 
assumedPod.Name),
-                                               zap.String("nodeName", 
assumedPod.Spec.NodeName),
-                                               zap.Error(err))
-                                       return err
-                               }
-                               allBound, err = 
ctx.apiProvider.GetAPIs().VolumeBinder.AssumePodVolumes(ctx.klogger, pod, node, 
volumes)
-                               if err != nil {
-                                       return err
+                       // retrieve volumes
+                       volumes, reasons, err := 
ctx.apiProvider.GetAPIs().VolumeBinder.FindPodVolumes(ctx.klogger, pod, 
podVolumeClaims, targetNode.Node())
+                       if err != nil {
+                               log.Log(log.ShimContext).Error("Failed to find 
pod volumes",
+                                       zap.String("podName", assumedPod.Name),
+                                       zap.String("nodeName", 
assumedPod.Spec.NodeName),
+                                       zap.Error(err))
+                               return err
+                       }
+                       if len(reasons) > 0 {
+                               sReasons := make([]string, 0)
+                               for _, reason := range reasons {
+                                       sReasons = append(sReasons, 
string(reason))
                                }
+                               sReason := strings.Join(sReasons, ", ")
+                               err = fmt.Errorf("pod %s has conflicting volume 
claims: %s", pod.Name, sReason)
+                               log.Log(log.ShimContext).Error("Pod has 
conflicting volume claims",
+                                       zap.String("podName", assumedPod.Name),
+                                       zap.String("nodeName", 
assumedPod.Spec.NodeName),
+                                       zap.Error(err))
+                               return err
                        }
+                       allBound, err = 
ctx.apiProvider.GetAPIs().VolumeBinder.AssumePodVolumes(ctx.klogger, pod, node, 
volumes)
+                       if err != nil {
+                               return err
+                       }
+
                        // assign the node name for pod
                        assumedPod.Spec.NodeName = node
                        ctx.schedulerCache.AssumePod(assumedPod, allBound)
diff --git a/pkg/cache/context_test.go b/pkg/cache/context_test.go
index cd3a23fb..d67c3b82 100644
--- a/pkg/cache/context_test.go
+++ b/pkg/cache/context_test.go
@@ -33,6 +33,7 @@ import (
        "k8s.io/apimachinery/pkg/types"
        "k8s.io/client-go/tools/cache"
        k8sEvents "k8s.io/client-go/tools/events"
+       "k8s.io/kubernetes/pkg/scheduler/framework/plugins/volumebinding"
 
        schedulercache "github.com/apache/yunikorn-k8shim/pkg/cache/external"
        "github.com/apache/yunikorn-k8shim/pkg/client"
@@ -53,6 +54,11 @@ const (
        appID1 = "app00001"
        appID2 = "app00002"
        appID3 = "app00003"
+
+       pod1UID      = "task00001"
+       taskUID1     = "task00001"
+       pod1Name     = "my-pod-1"
+       fakeNodeName = "fake-node"
 )
 
 var (
@@ -71,6 +77,11 @@ func initContextAndAPIProviderForTest() (*Context, 
*client.MockedAPIProvider) {
        return context, apis
 }
 
+func setVolumeBinder(ctx *Context, binder volumebinding.SchedulerVolumeBinder) 
{
+       mockedAPI := ctx.apiProvider.(*client.MockedAPIProvider) 
//nolint:errcheck
+       mockedAPI.SetVolumeBinder(binder)
+}
+
 func newPodHelper(name, namespace, podUID, nodeName string, appID string, 
podPhase v1.PodPhase) *v1.Pod {
        return &v1.Pod{
                TypeMeta: apis.TypeMeta{
@@ -2092,13 +2103,6 @@ func TestTaskRemoveOnCompletion(t *testing.T) {
        defer dispatcher.UnregisterAllEventHandlers()
        defer dispatcher.Stop()
 
-       const (
-               pod1UID      = "task00001"
-               taskUID1     = "task00001"
-               pod1Name     = "my-pod-1"
-               fakeNodeName = "fake-node"
-       )
-
        app := context.AddApplication(&AddApplicationRequest{
                Metadata: ApplicationMetadata{
                        ApplicationID: appID,
@@ -2138,6 +2142,135 @@ func TestTaskRemoveOnCompletion(t *testing.T) {
        assert.Error(t, err, "task task00001 doesn't exist in application 
app01")
 }
 
+func TestAssumePod(t *testing.T) {
+       context := initAssumePodTest(test.NewVolumeBinderMock())
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod(pod1UID, fakeNodeName)
+       assert.NilError(t, err)
+       assert.Assert(t, context.schedulerCache.ArePodVolumesAllBound(pod1UID))
+       assumedPod, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok, "pod not found in cache")
+       assert.Equal(t, assumedPod.Spec.NodeName, fakeNodeName)
+       assert.Assert(t, context.schedulerCache.IsAssumedPod(pod1UID))
+}
+
+func TestAssumePod_GetPodVolumeClaimsError(t *testing.T) {
+       binder := test.NewVolumeBinderMock()
+       const errMsg = "error getting volume claims"
+       binder.EnableVolumeClaimsError(errMsg)
+       context := initAssumePodTest(binder)
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod(pod1UID, fakeNodeName)
+       assert.Error(t, err, errMsg)
+       assert.Assert(t, !context.schedulerCache.IsAssumedPod(pod1UID))
+       podInCache, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok, "pod not found in cache")
+       assert.Equal(t, podInCache.Spec.NodeName, "", "NodeName in pod spec was 
set unexpectedly")
+}
+
+func TestAssumePod_FindPodVolumesError(t *testing.T) {
+       binder := test.NewVolumeBinderMock()
+       const errMsg = "error getting pod volumes"
+       binder.EnableFindPodVolumesError(errMsg)
+       context := initAssumePodTest(binder)
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod(pod1UID, fakeNodeName)
+       assert.Error(t, err, errMsg)
+       assert.Assert(t, !context.schedulerCache.IsAssumedPod(pod1UID))
+       podInCache, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok, "pod not found in cache")
+       assert.Equal(t, podInCache.Spec.NodeName, "", "NodeName in pod spec was 
set unexpectedly")
+}
+
+func TestAssumePod_ConflictingVolumes(t *testing.T) {
+       binder := test.NewVolumeBinderMock()
+       binder.SetConflictReasons("reason1", "reason2")
+       context := initAssumePodTest(binder)
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod(pod1UID, fakeNodeName)
+       assert.Error(t, err, "pod my-pod-1 has conflicting volume claims: 
reason1, reason2")
+       assert.Assert(t, !context.schedulerCache.IsAssumedPod(pod1UID))
+       podInCache, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok, "pod not found in cache")
+       assert.Equal(t, podInCache.Spec.NodeName, "", "NodeName in pod spec was 
set unexpectedly")
+}
+
+func TestAssumePod_AssumePodVolumesError(t *testing.T) {
+       binder := test.NewVolumeBinderMock()
+       const errMsg = "error assuming pod volumes"
+       binder.SetAssumePodVolumesError(errMsg)
+       context := initAssumePodTest(binder)
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod(pod1UID, fakeNodeName)
+       assert.Error(t, err, errMsg)
+       assert.Assert(t, !context.schedulerCache.IsAssumedPod(pod1UID))
+       podInCache, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok, "pod not found in cache")
+       assert.Equal(t, podInCache.Spec.NodeName, "", "NodeName in pod spec was 
set unexpectedly")
+}
+
+func TestAssumePod_PodNotFound(t *testing.T) {
+       context := initAssumePodTest(nil)
+       defer dispatcher.UnregisterAllEventHandlers()
+       defer dispatcher.Stop()
+
+       err := context.AssumePod("nonexisting", fakeNodeName)
+       assert.NilError(t, err)
+       assert.Assert(t, !context.schedulerCache.IsAssumedPod(pod1UID))
+       podInCache, ok := context.schedulerCache.GetPod(pod1UID)
+       assert.Assert(t, ok)
+       assert.Equal(t, podInCache.Spec.NodeName, "", "NodeName in pod spec was 
set unexpectedly")
+}
+
+func initAssumePodTest(binder *test.VolumeBinderMock) *Context {
+       context, apiProvider := initContextAndAPIProviderForTest()
+       if binder != nil {
+               setVolumeBinder(context, binder)
+       }
+       dispatcher.Start()
+       dispatcher.RegisterEventHandler("TestAppHandler", 
dispatcher.EventTypeApp, context.ApplicationEventHandler())
+       dispatcher.RegisterEventHandler("TestTaskHandler", 
dispatcher.EventTypeTask, context.TaskEventHandler())
+       apiProvider.MockSchedulerAPIUpdateNodeFn(func(request *si.NodeRequest) 
error {
+               for _, node := range request.Nodes {
+                       dispatcher.Dispatch(CachedSchedulerNodeEvent{
+                               NodeID: node.NodeID,
+                               Event:  NodeAccepted,
+                       })
+               }
+               return nil
+       })
+       context.AddApplication(&AddApplicationRequest{
+               Metadata: ApplicationMetadata{
+                       ApplicationID: appID,
+                       QueueName:     queue,
+                       User:          "test-user",
+                       Tags:          nil,
+               },
+       })
+       pod := newPodHelper(pod1Name, namespace, pod1UID, "", appID, 
v1.PodRunning)
+       context.AddPod(pod)
+       node := v1.Node{
+               ObjectMeta: apis.ObjectMeta{
+                       Name:      fakeNodeName,
+                       Namespace: "default",
+                       UID:       "uid_0001",
+               },
+       }
+       context.addNode(&node)
+
+       return context
+}
+
 func waitForNodeAcceptedEvent(recorder *k8sEvents.FakeRecorder) error {
        // fetch the "node accepted" event
        err := utils.WaitForCondition(func() bool {
diff --git a/pkg/cache/external/scheduler_cache.go 
b/pkg/cache/external/scheduler_cache.go
index 642b380a..43958495 100644
--- a/pkg/cache/external/scheduler_cache.go
+++ b/pkg/cache/external/scheduler_cache.go
@@ -449,7 +449,13 @@ func (cache *SchedulerCache) StartPodAllocation(podKey 
string, nodeID string) bo
        return false
 }
 
-// return if pod is assumed in cache, avoid nil
+// IsAssumedPod returns if pod is assumed in cache, avoid nil
+func (cache *SchedulerCache) IsAssumedPod(podKey string) bool {
+       cache.lock.RLock()
+       defer cache.lock.RUnlock()
+       return cache.isAssumedPod(podKey)
+}
+
 func (cache *SchedulerCache) isAssumedPod(podKey string) bool {
        _, ok := cache.assumedPods[podKey]
        return ok
diff --git a/pkg/cache/scheduler_callback.go b/pkg/cache/scheduler_callback.go
index b85a4cf1..81163572 100644
--- a/pkg/cache/scheduler_callback.go
+++ b/pkg/cache/scheduler_callback.go
@@ -56,11 +56,20 @@ func (callback *AsyncRMCallback) UpdateAllocation(response 
*si.AllocationRespons
                        zap.String("nodeID", alloc.NodeID))
 
                // update cache
+               task := callback.context.getTask(alloc.ApplicationID, 
alloc.AllocationKey)
+               if task != nil {
+                       task.setAllocationID(alloc.AllocationID)
+               } else {
+                       log.Log(log.ShimRMCallback).Warn("Unable to get task", 
zap.String("taskID", alloc.AllocationKey))
+               }
                if err := callback.context.AssumePod(alloc.AllocationKey, 
alloc.NodeID); err != nil {
+                       if task != nil {
+                               task.failWithEvent(err.Error(), 
"AssumePodError")
+                       }
                        return err
                }
                if app := callback.context.GetApplication(alloc.ApplicationID); 
app != nil {
-                       if task := 
callback.context.getTask(app.GetApplicationID(), alloc.AllocationKey); task != 
nil {
+                       if task != nil {
                                if utils.IsAssignedPod(task.GetTaskPod()) {
                                        // task is already bound, fixup state 
and continue
                                        
task.MarkPreviouslyAllocated(alloc.AllocationID, alloc.NodeID)
diff --git a/pkg/cache/task.go b/pkg/cache/task.go
index d0243c3c..fc398478 100644
--- a/pkg/cache/task.go
+++ b/pkg/cache/task.go
@@ -379,14 +379,10 @@ func (task *Task) postTaskAllocated() {
                        log.Log(log.ShimCacheTask).Debug("bind pod volumes",
                                zap.String("podName", task.pod.Name),
                                zap.String("podUID", string(task.pod.UID)))
-                       if task.context.apiProvider.GetAPIs().VolumeBinder != 
nil {
-                               if err := 
task.context.bindPodVolumes(task.pod); err != nil {
-                                       errorMessage := fmt.Sprintf("bind 
volumes to pod failed, name: %s, %s", task.alias, err.Error())
-                                       
dispatcher.Dispatch(NewFailTaskEvent(task.applicationID, task.taskID, 
errorMessage))
-                                       
events.GetRecorder().Eventf(task.pod.DeepCopy(),
-                                               nil, v1.EventTypeWarning, 
"PodVolumesBindFailure", "PodVolumesBindFailure", errorMessage)
-                                       return
-                               }
+                       if err := task.context.bindPodVolumes(task.pod); err != 
nil {
+                               log.Log(log.ShimCacheTask).Error("bind volumes 
to pod failed", zap.String("taskID", task.taskID), zap.Error(err))
+                               task.failWithEvent(fmt.Sprintf("bind volumes to 
pod failed, name: %s, %s", task.alias, err.Error()), "PodVolumesBindFailure")
+                               return
                        }
 
                        log.Log(log.ShimCacheTask).Debug("bind pod",
@@ -394,11 +390,8 @@ func (task *Task) postTaskAllocated() {
                                zap.String("podUID", string(task.pod.UID)))
 
                        if err := 
task.context.apiProvider.GetAPIs().KubeClient.Bind(task.pod, task.nodeName); 
err != nil {
-                               errorMessage := fmt.Sprintf("bind pod to node 
failed, name: %s, %s", task.alias, err.Error())
-                               log.Log(log.ShimCacheTask).Error(errorMessage)
-                               
dispatcher.Dispatch(NewFailTaskEvent(task.applicationID, task.taskID, 
errorMessage))
-                               
events.GetRecorder().Eventf(task.pod.DeepCopy(), nil,
-                                       v1.EventTypeWarning, "PodBindFailure", 
"PodBindFailure", errorMessage)
+                               log.Log(log.ShimCacheTask).Error("bind pod to 
node failed", zap.String("taskID", task.taskID), zap.Error(err))
+                               task.failWithEvent(fmt.Sprintf("bind pod to 
node failed, name: %s, %s", task.alias, err.Error()), "PodBindFailure")
                                return
                        }
 
@@ -523,8 +516,7 @@ func (task *Task) releaseAllocation() {
                s := TaskStates()
                switch task.GetTaskState() {
                case s.New, s.Pending, s.Scheduling, s.Rejected:
-                       releaseRequest = common.CreateReleaseAskRequestForTask(
-                               task.applicationID, task.taskID, 
task.application.partition)
+                       releaseRequest = 
common.CreateReleaseRequestForTask(task.applicationID, task.taskID, 
task.allocationID, task.application.partition, task.terminationType)
                default:
                        if task.allocationID == "" {
                                log.Log(log.ShimCacheTask).Warn("BUG: task 
allocation allocationID is empty on release",
@@ -532,9 +524,8 @@ func (task *Task) releaseAllocation() {
                                        zap.String("taskID", task.taskID),
                                        zap.String("taskAlias", task.alias),
                                        zap.String("task", task.GetTaskState()))
-                               return
                        }
-                       releaseRequest = 
common.CreateReleaseAllocationRequestForTask(
+                       releaseRequest = common.CreateReleaseRequestForTask(
                                task.applicationID, task.taskID, 
task.allocationID, task.application.partition, task.terminationType)
                }
 
@@ -596,3 +587,15 @@ func (task *Task) UpdatePodCondition(podCondition 
*v1.PodCondition) (bool, *v1.P
 
        return false, pod
 }
+
+func (task *Task) setAllocationID(allocationID string) {
+       task.lock.Lock()
+       defer task.lock.Unlock()
+       task.allocationID = allocationID
+}
+
+func (task *Task) failWithEvent(errorMessage, actionReason string) {
+       dispatcher.Dispatch(NewFailTaskEvent(task.applicationID, task.taskID, 
errorMessage))
+       events.GetRecorder().Eventf(task.pod.DeepCopy(),
+               nil, v1.EventTypeWarning, actionReason, actionReason, 
errorMessage)
+}
diff --git a/pkg/cache/task_state.go b/pkg/cache/task_state.go
index 808df8a0..f0371bd6 100644
--- a/pkg/cache/task_state.go
+++ b/pkg/cache/task_state.go
@@ -365,7 +365,7 @@ func newTaskState() *fsm.FSM {
                        },
                        {
                                Name: TaskFail.String(),
-                               Src:  []string{states.Rejected, 
states.Allocated},
+                               Src:  []string{states.New, states.Pending, 
states.Scheduling, states.Rejected, states.Allocated},
                                Dst:  states.Failed,
                        },
                },
diff --git a/pkg/cache/task_test.go b/pkg/cache/task_test.go
index 2df89955..c08c4f03 100644
--- a/pkg/cache/task_test.go
+++ b/pkg/cache/task_test.go
@@ -218,6 +218,12 @@ func TestReleaseTaskAllocation(t *testing.T) {
                assert.Assert(t, request.Releases.AllocationsToRelease != nil)
                assert.Equal(t, 
request.Releases.AllocationsToRelease[0].ApplicationID, app.applicationID)
                assert.Equal(t, 
request.Releases.AllocationsToRelease[0].PartitionName, "default")
+               assert.Equal(t, 
request.Releases.AllocationsToRelease[0].AllocationID, "UID-00001")
+               assert.Assert(t, request.Releases.AllocationAsksToRelease != 
nil)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].ApplicationID, app.applicationID)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].AllocationKey, "task01")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].PartitionName, "default")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].TerminationType, 
si.TerminationType_UNKNOWN_TERMINATION_TYPE)
                return nil
        })
 
@@ -228,6 +234,41 @@ func TestReleaseTaskAllocation(t *testing.T) {
        assert.Equal(t, task.GetTaskState(), TaskStates().Completed)
        // 2 updates call, 1 for submit, 1 for release
        assert.Equal(t, 
mockedApiProvider.GetSchedulerAPIUpdateAllocationCount(), int32(2))
+
+       // New to Failed, no AllocationID is set (only ask is released)
+       task = NewTask("task01", app, mockedContext, pod)
+       mockedApiProvider.MockSchedulerAPIUpdateAllocationFn(func(request 
*si.AllocationRequest) error {
+               assert.Assert(t, request.Releases != nil)
+               assert.Assert(t, request.Releases.AllocationsToRelease == nil)
+               assert.Assert(t, request.Releases.AllocationAsksToRelease != 
nil)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].ApplicationID, app.applicationID)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].AllocationKey, "task01")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].PartitionName, "default")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].TerminationType, 
si.TerminationType_UNKNOWN_TERMINATION_TYPE)
+               return nil
+       })
+       err = task.handle(NewFailTaskEvent(app.applicationID, "task01", "test 
failure"))
+       assert.NilError(t, err, "failed to handle FailTask event")
+
+       // Scheduling to Failed, AllocationID is set (ask+allocation are both 
released)
+       task = NewTask("task01", app, mockedContext, pod)
+       task.setAllocationID("alloc-0")
+       task.sm.SetState(TaskStates().Scheduling)
+       mockedApiProvider.MockSchedulerAPIUpdateAllocationFn(func(request 
*si.AllocationRequest) error {
+               assert.Assert(t, request.Releases != nil)
+               assert.Assert(t, request.Releases.AllocationsToRelease != nil)
+               assert.Equal(t, 
request.Releases.AllocationsToRelease[0].ApplicationID, app.applicationID)
+               assert.Equal(t, 
request.Releases.AllocationsToRelease[0].PartitionName, "default")
+               assert.Equal(t, 
request.Releases.AllocationsToRelease[0].AllocationID, "alloc-0")
+               assert.Assert(t, request.Releases.AllocationAsksToRelease != 
nil)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].ApplicationID, app.applicationID)
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].AllocationKey, "task01")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].PartitionName, "default")
+               assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].TerminationType, 
si.TerminationType_UNKNOWN_TERMINATION_TYPE)
+               return nil
+       })
+       err = task.handle(NewFailTaskEvent(app.applicationID, "task01", "test 
failure"))
+       assert.NilError(t, err, "failed to handle FailTask event")
 }
 
 func TestReleaseTaskAsk(t *testing.T) {
diff --git a/pkg/client/apifactory_mock.go b/pkg/client/apifactory_mock.go
index 8e27ec4a..35decd4c 100644
--- a/pkg/client/apifactory_mock.go
+++ b/pkg/client/apifactory_mock.go
@@ -29,6 +29,7 @@ import (
        corev1 "k8s.io/client-go/listers/core/v1"
        storagev1 "k8s.io/client-go/listers/storage/v1"
        "k8s.io/client-go/tools/cache"
+       "k8s.io/kubernetes/pkg/scheduler/framework/plugins/volumebinding"
 
        "github.com/apache/yunikorn-k8shim/pkg/common/test"
        "github.com/apache/yunikorn-k8shim/pkg/conf"
@@ -87,7 +88,7 @@ func NewMockedAPIProvider(showError bool) *MockedAPIProvider {
                        PVInformer:            
&MockedPersistentVolumeInformer{},
                        PVCInformer:           
&MockedPersistentVolumeClaimInformer{},
                        StorageInformer:       &MockedStorageClassInformer{},
-                       VolumeBinder:          nil,
+                       VolumeBinder:          test.NewVolumeBinderMock(),
                        NamespaceInformer:     
test.NewMockNamespaceInformer(false),
                        PriorityClassInformer: 
test.NewMockPriorityClassInformer(),
                        InformerFactory:       
informers.NewSharedInformerFactory(k8fake.NewSimpleClientset(), time.Second*60),
@@ -446,3 +447,7 @@ func (m *MockedStorageClassInformer) Informer() 
cache.SharedIndexInformer {
 func (m *MockedStorageClassInformer) Lister() storagev1.StorageClassLister {
        return nil
 }
+
+func (m *MockedAPIProvider) SetVolumeBinder(binder 
volumebinding.SchedulerVolumeBinder) {
+       m.clients.VolumeBinder = binder
+}
diff --git a/pkg/common/si_helper.go b/pkg/common/si_helper.go
index 4434a82b..3387fa80 100644
--- a/pkg/common/si_helper.go
+++ b/pkg/common/si_helper.go
@@ -116,25 +116,6 @@ func CreateAllocationForTask(appID, taskID, nodeID string, 
resource *si.Resource
        }
 }
 
-func CreateReleaseAskRequestForTask(appID, taskID, partition string) 
*si.AllocationRequest {
-       toReleases := make([]*si.AllocationAskRelease, 0)
-       toReleases = append(toReleases, &si.AllocationAskRelease{
-               ApplicationID: appID,
-               AllocationKey: taskID,
-               PartitionName: partition,
-               Message:       "task request is canceled",
-       })
-
-       releaseRequest := si.AllocationReleasesRequest{
-               AllocationAsksToRelease: toReleases,
-       }
-
-       return &si.AllocationRequest{
-               Releases: &releaseRequest,
-               RmID:     conf.GetSchedulerConf().ClusterID,
-       }
-}
-
 func GetTerminationTypeFromString(terminationTypeStr string) 
si.TerminationType {
        if v, ok := si.TerminationType_value[terminationTypeStr]; ok {
                return si.TerminationType(v)
@@ -142,18 +123,21 @@ func GetTerminationTypeFromString(terminationTypeStr 
string) si.TerminationType
        return si.TerminationType_STOPPED_BY_RM
 }
 
-func CreateReleaseAllocationRequestForTask(appID, taskID, allocationID, 
partition, terminationType string) *si.AllocationRequest {
-       toReleases := make([]*si.AllocationRelease, 0)
-       toReleases = append(toReleases, &si.AllocationRelease{
-               ApplicationID:   appID,
-               AllocationID:    allocationID,
-               PartitionName:   partition,
-               TerminationType: GetTerminationTypeFromString(terminationType),
-               Message:         "task completed",
-       })
+func CreateReleaseRequestForTask(appID, taskID, allocationID, partition, 
terminationType string) *si.AllocationRequest {
+       var allocToRelease []*si.AllocationRelease
+       if allocationID != "" {
+               allocToRelease = make([]*si.AllocationRelease, 1)
+               allocToRelease[0] = &si.AllocationRelease{
+                       ApplicationID:   appID,
+                       AllocationID:    allocationID,
+                       PartitionName:   partition,
+                       TerminationType: 
GetTerminationTypeFromString(terminationType),
+                       Message:         "task completed",
+               }
+       }
 
-       toReleaseAsk := make([]*si.AllocationAskRelease, 1)
-       toReleaseAsk[0] = &si.AllocationAskRelease{
+       askToRelease := make([]*si.AllocationAskRelease, 1)
+       askToRelease[0] = &si.AllocationAskRelease{
                ApplicationID: appID,
                AllocationKey: taskID,
                PartitionName: partition,
@@ -161,8 +145,8 @@ func CreateReleaseAllocationRequestForTask(appID, taskID, 
allocationID, partitio
        }
 
        releaseRequest := si.AllocationReleasesRequest{
-               AllocationsToRelease:    toReleases,
-               AllocationAsksToRelease: toReleaseAsk,
+               AllocationsToRelease:    allocToRelease,
+               AllocationAsksToRelease: askToRelease,
        }
 
        return &si.AllocationRequest{
diff --git a/pkg/common/si_helper_test.go b/pkg/common/si_helper_test.go
index f3a5fcf6..fc642eca 100644
--- a/pkg/common/si_helper_test.go
+++ b/pkg/common/si_helper_test.go
@@ -32,8 +32,9 @@ import (
 
 const nodeID = "node-01"
 
-func TestCreateReleaseAllocationRequest(t *testing.T) {
-       request := CreateReleaseAllocationRequestForTask("app01", "task01", 
"alloc01", "default", "STOPPED_BY_RM")
+func TestCreateReleaseRequestForTask(t *testing.T) {
+       // with "allocationID"
+       request := CreateReleaseRequestForTask("app01", "task01", "alloc01", 
"default", "STOPPED_BY_RM")
        assert.Assert(t, request.Releases != nil)
        assert.Assert(t, request.Releases.AllocationsToRelease != nil)
        assert.Assert(t, request.Releases.AllocationAsksToRelease != nil)
@@ -45,10 +46,9 @@ func TestCreateReleaseAllocationRequest(t *testing.T) {
        assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].ApplicationID, "app01")
        assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].AllocationKey, "task01")
        assert.Equal(t, 
request.Releases.AllocationAsksToRelease[0].PartitionName, "default")
-}
 
-func TestCreateReleaseAskRequestForTask(t *testing.T) {
-       request := CreateReleaseAskRequestForTask("app01", "task01", "default")
+       // without allocationID
+       request = CreateReleaseRequestForTask("app01", "task01", "", "default", 
"STOPPED_BY_RM")
        assert.Assert(t, request.Releases != nil)
        assert.Assert(t, request.Releases.AllocationsToRelease == nil)
        assert.Assert(t, request.Releases.AllocationAsksToRelease != nil)
diff --git a/pkg/common/test/volumebinder_mock.go 
b/pkg/common/test/volumebinder_mock.go
new file mode 100644
index 00000000..3bc38276
--- /dev/null
+++ b/pkg/common/test/volumebinder_mock.go
@@ -0,0 +1,108 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package test
+
+import (
+       "context"
+       "errors"
+
+       v1 "k8s.io/api/core/v1"
+       "k8s.io/apimachinery/pkg/util/sets"
+       "k8s.io/klog/v2"
+       "k8s.io/kubernetes/pkg/scheduler/framework/plugins/volumebinding"
+)
+
+var _ volumebinding.SchedulerVolumeBinder = &VolumeBinderMock{}
+
+type VolumeBinderMock struct {
+       volumeClaimError    error
+       findPodVolumesError error
+       assumeVolumeError   error
+       bindError           error
+       conflictReasons     volumebinding.ConflictReasons
+
+       podVolumeClaim *volumebinding.PodVolumeClaims
+       podVolumes     *volumebinding.PodVolumes
+       allBound       bool
+}
+
+func NewVolumeBinderMock() *VolumeBinderMock {
+       return &VolumeBinderMock{
+               allBound: true,
+       }
+}
+
+func (v *VolumeBinderMock) GetPodVolumeClaims(_ klog.Logger, _ *v1.Pod) 
(podVolumeClaims *volumebinding.PodVolumeClaims, err error) {
+       if v.volumeClaimError != nil {
+               return nil, v.volumeClaimError
+       }
+
+       return v.podVolumeClaim, nil
+}
+
+func (v *VolumeBinderMock) GetEligibleNodes(_ klog.Logger, _ 
[]*v1.PersistentVolumeClaim) (eligibleNodes sets.Set[string]) {
+       return nil
+}
+
+func (v *VolumeBinderMock) FindPodVolumes(_ klog.Logger, _ *v1.Pod, _ 
*volumebinding.PodVolumeClaims, _ *v1.Node) (podVolumes 
*volumebinding.PodVolumes, reasons volumebinding.ConflictReasons, err error) {
+       if v.findPodVolumesError != nil {
+               return nil, nil, v.findPodVolumesError
+       }
+
+       if len(v.conflictReasons) > 0 {
+               return nil, v.conflictReasons, nil
+       }
+
+       return v.podVolumes, nil, nil
+}
+
+func (v *VolumeBinderMock) AssumePodVolumes(_ klog.Logger, _ *v1.Pod, _ 
string, _ *volumebinding.PodVolumes) (allFullyBound bool, err error) {
+       if v.assumeVolumeError != nil {
+               return false, v.assumeVolumeError
+       }
+
+       return v.allBound, nil
+}
+
+func (v *VolumeBinderMock) RevertAssumedPodVolumes(_ 
*volumebinding.PodVolumes) {
+}
+
+func (v *VolumeBinderMock) BindPodVolumes(_ context.Context, _ *v1.Pod, _ 
*volumebinding.PodVolumes) error {
+       return v.bindError
+}
+
+func (v *VolumeBinderMock) EnableVolumeClaimsError(message string) {
+       v.volumeClaimError = errors.New(message)
+}
+
+func (v *VolumeBinderMock) EnableFindPodVolumesError(message string) {
+       v.findPodVolumesError = errors.New(message)
+}
+
+func (v *VolumeBinderMock) SetConflictReasons(reasons ...string) {
+       var conflicts []volumebinding.ConflictReason
+       for _, r := range reasons {
+               conflicts = append(conflicts, volumebinding.ConflictReason(r))
+       }
+       v.conflictReasons = conflicts
+}
+
+func (v *VolumeBinderMock) SetAssumePodVolumesError(message string) {
+       v.assumeVolumeError = errors.New(message)
+}
diff --git a/pkg/shim/scheduler_test.go b/pkg/shim/scheduler_test.go
index 84bb52ba..e7ee19f3 100644
--- a/pkg/shim/scheduler_test.go
+++ b/pkg/shim/scheduler_test.go
@@ -259,6 +259,52 @@ partitions:
        assert.NilError(t, err, "number of allocations is not expected, error")
 }
 
+// simulate PVC error during Context.AssumePod() call
+func TestAssumePodError(t *testing.T) {
+       configData := `
+partitions:
+  - name: default
+    queues:
+      - name: root
+        submitacl: "*"
+        queues:
+          - name: a
+            resources:
+              guaranteed:
+                memory: 100000000
+                vcore: 10
+              max:
+                memory: 150000000
+                vcore: 20
+`
+       cluster := MockScheduler{}
+       cluster.init()
+       binder := test.NewVolumeBinderMock()
+       binder.EnableVolumeClaimsError("unable to get volume claims")
+       cluster.apiProvider.SetVolumeBinder(binder)
+       assert.NilError(t, cluster.start(), "failed to start cluster")
+       defer cluster.stop()
+
+       err := cluster.updateConfig(configData, nil)
+       assert.NilError(t, err, "update config failed")
+       addNode(&cluster, "node-1")
+
+       // create app and task which will fail due to simulated volume error
+       taskResource := common.NewResourceBuilder().
+               AddResource(siCommon.Memory, 1000).
+               AddResource(siCommon.CPU, 1).
+               Build()
+       pod1 := createTestPod("root.a", "app0001", "task0001", taskResource)
+       cluster.AddPod(pod1)
+
+       // expect app to enter Completing state with allocation+ask removed
+       err = cluster.waitForApplicationStateInCore("app0001", partitionName, 
"Completing")
+       assert.NilError(t, err)
+       app := cluster.getApplicationFromCore("app0001", partitionName)
+       assert.Equal(t, 0, len(app.GetAllRequests()), "asks were not removed 
from the application")
+       assert.Equal(t, 0, len(app.GetAllAllocations()), "allocations were not 
removed from the application")
+}
+
 func createTestPod(queue string, appID string, taskID string, taskResource 
*si.Resource) *v1.Pod {
        containers := make([]v1.Container, 0)
        c1Resources := make(map[v1.ResourceName]resource.Quantity)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to