This is an automated email from the ASF dual-hosted git repository.

pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-k8shim.git


The following commit(s) were added to refs/heads/master by this push:
     new 97b29d64 [YUNIKORN-2770] Simplify Application.GetTask() (#882)
97b29d64 is described below

commit 97b29d64f94c168da5c119a6f9bc84f11566f34e
Author: Peter Bacsko <[email protected]>
AuthorDate: Fri Jul 26 09:16:37 2024 +0200

    [YUNIKORN-2770] Simplify Application.GetTask() (#882)
    
    Closes: #882
    
    Signed-off-by: Peter Bacsko <[email protected]>
---
 pkg/cache/application.go        |  8 ++------
 pkg/cache/application_test.go   |  4 +---
 pkg/cache/context.go            | 10 +++++-----
 pkg/cache/context_test.go       |  6 ++----
 pkg/plugin/scheduler_plugin.go  |  2 +-
 pkg/shim/scheduler_mock_test.go |  3 +--
 6 files changed, 12 insertions(+), 21 deletions(-)

diff --git a/pkg/cache/application.go b/pkg/cache/application.go
index 3bedb132..a6ec1f3f 100644
--- a/pkg/cache/application.go
+++ b/pkg/cache/application.go
@@ -115,14 +115,10 @@ func (app *Application) canHandle(ev events.ApplicationEvent) bool {
        return app.sm.Can(ev.GetEvent())
 }
 
-func (app *Application) GetTask(taskID string) (*Task, error) {
+func (app *Application) GetTask(taskID string) *Task {
        app.lock.RLock()
        defer app.lock.RUnlock()
-       if task, ok := app.taskMap[taskID]; ok {
-               return task, nil
-       }
-       return nil, fmt.Errorf("task %s doesn't exist in application %s",
-               taskID, app.applicationID)
+       return app.taskMap[taskID]
 }
 
 func (app *Application) GetApplicationID() string {
diff --git a/pkg/cache/application_test.go b/pkg/cache/application_test.go
index 5d9a3ce9..da6f085d 100644
--- a/pkg/cache/application_test.go
+++ b/pkg/cache/application_test.go
@@ -1184,9 +1184,7 @@ func TestPlaceholderTimeoutEvents(t *testing.T) {
        })
        assert.Assert(t, task1 != nil)
        assert.Equal(t, task1.GetTaskID(), "task02")
-
-       _, taskErr := app.GetTask("task02")
-       assert.NilError(t, taskErr, "Task should exist")
+       assert.Assert(t, app.GetTask("task02") != nil, "Task should exist")
 
        task1.allocationKey = allocationKey
 
diff --git a/pkg/cache/context.go b/pkg/cache/context.go
index f9abaed6..0f7764ad 100644
--- a/pkg/cache/context.go
+++ b/pkg/cache/context.go
@@ -351,7 +351,7 @@ func (ctx *Context) ensureAppAndTaskCreated(pod *v1.Pod) {
        }
 
        // add task if it doesn't already exist
-       if _, taskErr := app.GetTask(string(pod.UID)); taskErr != nil {
+       if task := app.GetTask(string(pod.UID)); task == nil {
                ctx.addTask(&AddTaskRequest{
                        Metadata: taskMeta,
                })
@@ -1097,8 +1097,8 @@ func (ctx *Context) addTask(request *AddTaskRequest) *Task {
                zap.String("appID", request.Metadata.ApplicationID),
                zap.String("taskID", request.Metadata.TaskID))
        if app := ctx.getApplication(request.Metadata.ApplicationID); app != nil {
-               existingTask, err := app.GetTask(request.Metadata.TaskID)
-               if err != nil {
+               existingTask := app.GetTask(request.Metadata.TaskID)
+               if existingTask == nil {
                        var originator bool
 
                        // Is this task the originator of the application?
@@ -1156,8 +1156,8 @@ func (ctx *Context) getTask(appID string, taskID string) *Task {
                        zap.String("appID", appID))
                return nil
        }
-       task, err := app.GetTask(taskID)
-       if err != nil {
+       task := app.GetTask(taskID)
+       if task == nil {
                log.Log(log.ShimContext).Debug("task is not found in applications",
                        zap.String("taskID", taskID),
                        zap.String("appID", appID))
diff --git a/pkg/cache/context_test.go b/pkg/cache/context_test.go
index 23f68f2d..002b3c81 100644
--- a/pkg/cache/context_test.go
+++ b/pkg/cache/context_test.go
@@ -1007,8 +1007,7 @@ func TestRecoverTask(t *testing.T) {
        for _, tt := range taskInfoVerifiers {
                t.Run(tt.taskID, func(t *testing.T) {
                        // verify the info for the recovered task
-                       rt, err := app.GetTask(tt.taskID)
-                       assert.NilError(t, err)
+                       rt := app.GetTask(tt.taskID)
                        assert.Equal(t, rt.GetTaskState(), tt.expectedState)
                        assert.Equal(t, rt.allocationKey, tt.expectedAllocationKey)
                        assert.Equal(t, rt.pod.Name, tt.expectedPodName)
@@ -2142,9 +2141,8 @@ func TestTaskRemoveOnCompletion(t *testing.T) {
 
        // check removal
        app.Schedule()
-       appTask, err := app.GetTask(taskUID1)
+       appTask := app.GetTask(taskUID1)
        assert.Assert(t, appTask == nil)
-       assert.Error(t, err, "task task00001 doesn't exist in application app01")
 }
 
 func TestAssumePod(t *testing.T) {
diff --git a/pkg/plugin/scheduler_plugin.go b/pkg/plugin/scheduler_plugin.go
index 6d0351ca..7b46b619 100644
--- a/pkg/plugin/scheduler_plugin.go
+++ b/pkg/plugin/scheduler_plugin.go
@@ -302,7 +302,7 @@ func NewSchedulerPlugin(_ context.Context, _ runtime.Object, handle framework.Ha
 
 func (sp *YuniKornSchedulerPlugin) getTask(appID, taskID string) (app *cache.Application, task *cache.Task, ok bool) {
        if app := sp.context.GetApplication(appID); app != nil {
-               if task, err := app.GetTask(taskID); err == nil {
+               if task := app.GetTask(taskID); task != nil {
                        return app, task, true
                }
        }
diff --git a/pkg/shim/scheduler_mock_test.go b/pkg/shim/scheduler_mock_test.go
index 1bbe5f02..1e8f1910 100644
--- a/pkg/shim/scheduler_mock_test.go
+++ b/pkg/shim/scheduler_mock_test.go
@@ -167,8 +167,7 @@ func (fc *MockScheduler) waitAndAssertTaskState(t *testing.T, appID, taskID, exp
        assert.Equal(t, app != nil, true)
        assert.Equal(t, app.GetApplicationID(), appID)
 
-       task, err := app.GetTask(taskID)
-       assert.NilError(t, err, "Task retrieval failed")
+       task := app.GetTask(taskID)
        deadline := time.Now().Add(10 * time.Second)
        for {
                if task.GetTaskState() == expectedState {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to