This is an automated email from the ASF dual-hosted git repository.

damondouglas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a47b1faa527 [Prism] Implement jobservices.Server Cancel (#30178)
a47b1faa527 is described below

commit a47b1faa5276cdbf05c356b60ed8c4494ee622aa
Author: Damon <[email protected]>
AuthorDate: Mon Feb 5 17:23:10 2024 +0000

    [Prism] Implement jobservices.Server Cancel (#30178)
    
    * Implement jobservices.Server Cancel
    
    * Small code cleanup
    
    * Fix test err; canceled state after complete
---
 sdks/go/pkg/beam/runners/prism/internal/execute.go |  8 +++
 .../beam/runners/prism/internal/jobservices/job.go | 10 ++++
 .../prism/internal/jobservices/management.go       | 30 +++++++++++
 .../prism/internal/jobservices/management_test.go  | 34 ++++++++++++
 .../prism/internal/jobservices/server_test.go      | 61 ++++++++++++++++++++++
 5 files changed, 143 insertions(+)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index b8bc68dcd1b..1aa95bc6ee1 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -17,6 +17,7 @@ package internal
 
 import (
        "context"
+       "errors"
        "fmt"
        "io"
        "sort"
@@ -70,6 +71,13 @@ func RunPipeline(j *jobservices.Job) {
                j.Failed(err)
                return
        }
+
+       if errors.Is(context.Cause(j.RootCtx), jobservices.ErrCancel) {
+               j.SendMsg("pipeline canceled " + j.String())
+               j.Canceled()
+               return
+       }
+
        j.SendMsg("pipeline completed " + j.String())
 
        j.SendMsg("terminating " + j.String())
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index bb5eb88c919..6cde48ded9a 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -177,6 +177,16 @@ func (j *Job) Done() {
        j.sendState(jobpb.JobState_DONE)
 }
 
+// Canceling marks the job CANCELLING: cancellation was requested but the job has not yet stopped.
+func (j *Job) Canceling() {
+       j.sendState(jobpb.JobState_CANCELLING)
+}
+
+// Canceled marks the job CANCELLED, the terminal state once a requested cancellation has taken effect.
+func (j *Job) Canceled() {
+       j.sendState(jobpb.JobState_CANCELLED)
+}
+
 // Failed indicates that the job completed unsuccessfully.
 func (j *Job) Failed(err error) {
        slog.Error("job failed", slog.Any("job", j), slog.Any("error", err))
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 323d8c46efb..0da37ef0bd7 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -17,6 +17,7 @@ package jobservices
 
 import (
        "context"
+       "errors"
        "fmt"
        "sync"
        "sync/atomic"
@@ -30,6 +31,10 @@ import (
        "google.golang.org/protobuf/types/known/timestamppb"
 )
 
+// ErrCancel is the cause Server.Cancel passes to a job's CancelFn; code that
+// needs to tell cancellation apart from failure tests errors.Is(context.Cause(ctx), ErrCancel).
+var ErrCancel = errors.New("pipeline canceled")
+
 func (s *Server) nextId() string {
        v := atomic.AddUint32(&s.index, 1)
        return fmt.Sprintf("job-%03d", v)
@@ -215,6 +220,31 @@ func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJ
        }, nil
 }
 
+// Cancel requests cancellation of the Job named by req.GetJobId() and reports the resulting state.
+// It returns (nil, nil) for an unknown job, and the unchanged current state for a job that is already terminal or canceling.
+func (s *Server) Cancel(_ context.Context, req *jobpb.CancelJobRequest) (*jobpb.CancelJobResponse, error) {
+       s.mu.Lock()
+       job, ok := s.jobs[req.GetJobId()]
+       s.mu.Unlock()
+       if !ok {
+               return nil, nil
+       }
+       // NOTE(review): check-then-act is not atomic; a job that finishes after this Load may still be sent CANCELLING below.
+       state := job.state.Load().(jobpb.JobState_Enum)
+       switch state {
+       case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_UPDATED, jobpb.JobState_FAILED, jobpb.JobState_CANCELLING:
+               // Terminal, or a cancel is already in flight: report the current state without resending CANCELLING.
+               return &jobpb.CancelJobResponse{State: state}, nil
+       }
+       job.SendMsg("canceling " + job.String())
+       job.Canceling()
+       // RunPipeline checks context.Cause(j.RootCtx) for ErrCancel to report the job as CANCELLED.
+       job.CancelFn(ErrCancel)
+       return &jobpb.CancelJobResponse{
+               State: jobpb.JobState_CANCELLING,
+       }, nil
+}
+
 // GetMessageStream subscribes to a stream of state changes and messages from 
the job. If throughput
 // is high, this may cause losses of messages.
 func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream 
jobpb.JobService_GetMessageStreamServer) error {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
index 5813e6ef73e..176abb8543a 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
@@ -169,6 +169,40 @@ func TestServer(t *testing.T) {
                                }
                        },
                },
+               {
+                       name: "Canceling",
+                       noJobsCheck: func(ctx context.Context, t *testing.T, 
undertest *Server) {
+                               resp, err := undertest.Cancel(ctx, 
&jobpb.CancelJobRequest{JobId: "job-001"})
+                               if resp != nil {
+                                       t.Errorf("Canceling(\"job-001\") = %s, 
want nil", resp)
+                               }
+                               if err != nil {
+                                       t.Errorf("Canceling(\"job-001\") = %v, 
want nil", err)
+                               }
+                       },
+                       postPrepCheck: func(ctx context.Context, t *testing.T, 
undertest *Server) {
+                               resp, err := undertest.Cancel(ctx, 
&jobpb.CancelJobRequest{JobId: "job-001"})
+                               if err != nil {
+                                       t.Errorf("Canceling(\"job-001\") = %v, 
want nil", err)
+                               }
+                               if diff := cmp.Diff(&jobpb.CancelJobResponse{
+                                       State: jobpb.JobState_CANCELLING,
+                               }, resp, cmpOpts...); diff != "" {
+                                       t.Errorf("Canceling(\"job-001\") 
(-want, +got):\n%v", diff)
+                               }
+                       },
+                       postRunCheck: func(ctx context.Context, t *testing.T, 
undertest *Server, jobID string) {
+                               resp, err := undertest.Cancel(ctx, 
&jobpb.CancelJobRequest{JobId: jobID})
+                               if err != nil {
+                                       t.Errorf("Canceling(\"%s\") = %v, want 
nil", jobID, err)
+                               }
+                               if diff := cmp.Diff(&jobpb.CancelJobResponse{
+                                       State: jobpb.JobState_DONE,
+                               }, resp, cmpOpts...); diff != "" {
+                                       t.Errorf("Canceling(\"%s\") (-want, 
+got):\n%v", jobID, diff)
+                               }
+                       },
+               },
        }
        for _, test := range tests {
                var called sync.WaitGroup
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
index 2223f030ce1..473c84f958e 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server_test.go
@@ -17,6 +17,7 @@ package jobservices
 
 import (
        "context"
+       "errors"
        "sync"
        "testing"
 
@@ -77,3 +78,63 @@ func TestServer_JobLifecycle(t *testing.T) {
        t.Log("success!")
        // Nothing to cleanup because we didn't start the server.
 }
+
+// Validates that invoking Cancel on a running job leaves it in the CANCELLED state.
+func TestServer_RunThenCancel(t *testing.T) {
+       var called sync.WaitGroup
+       called.Add(1)
+       undertest := NewServer(0, func(j *Job) {
+               defer called.Done()
+               // The job may start before Cancel is issued below, so block until its root context ends.
+               <-j.RootCtx.Done()
+               if errors.Is(context.Cause(j.RootCtx), ErrCancel) {
+                       j.state.Store(jobpb.JobState_CANCELLED)
+               }
+       })
+       ctx := context.Background()
+       wantPipeline := &pipepb.Pipeline{
+               Requirements: []string{urns.RequirementSplittableDoFn},
+       }
+       wantName := "testJob"
+
+       resp, err := undertest.Prepare(ctx, &jobpb.PrepareJobRequest{
+               Pipeline: wantPipeline,
+               JobName:  wantName,
+       })
+       if err != nil {
+               t.Fatalf("server.Prepare() = %v, want nil", err)
+       }
+       if got := resp.GetPreparationId(); got == "" {
+               t.Fatalf("server.Prepare() = returned empty preparation ID, want non-empty: %v", prototext.Format(resp))
+       }
+
+       runResp, err := undertest.Run(ctx, &jobpb.RunJobRequest{
+               PreparationId: resp.GetPreparationId(),
+       })
+       if err != nil {
+               t.Fatalf("server.Run() = %v, want nil", err)
+       }
+       if got := runResp.GetJobId(); got == "" {
+               t.Fatalf("server.Run() = returned empty job ID, want non-empty")
+       }
+
+       cancelResp, err := undertest.Cancel(ctx, &jobpb.CancelJobRequest{
+               JobId: runResp.GetJobId(),
+       })
+       if err != nil {
+               t.Fatalf("server.Cancel() = %v, want nil", err)
+       }
+       if cancelResp.State != jobpb.JobState_CANCELLING {
+               t.Fatalf("server.Cancel() = %v, want %v", cancelResp.State, jobpb.JobState_CANCELLING)
+       }
+
+       called.Wait()
+
+       stateResp, err := undertest.GetState(ctx, &jobpb.GetJobStateRequest{JobId: runResp.GetJobId()})
+       if err != nil {
+               t.Fatalf("server.GetState() = %v, want nil", err)
+       }
+       if stateResp.State != jobpb.JobState_CANCELLED {
+               t.Fatalf("server.GetState() = %v, want %v", stateResp.State, jobpb.JobState_CANCELLED)
+       }
+}

Reply via email to