This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch prismDisconnect
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 072da96afd03e2c176bf65e20672611bc670f928
Author: Robert Burke <[email protected]>
AuthorDate: Mon Aug 28 14:31:28 2023 -0700

    [prism] Fail jobs on SDK disconnect.
---
 .../prism/internal/engine/elementmanager.go        | 13 +++-
 sdks/go/pkg/beam/runners/prism/internal/execute.go |  4 +-
 .../beam/runners/prism/internal/jobservices/job.go | 10 +++-
 sdks/go/pkg/beam/runners/prism/internal/stage.go   | 11 ++--
 .../beam/runners/prism/internal/worker/bundle.go   |  8 +--
 .../beam/runners/prism/internal/worker/worker.go   | 70 +++++++++++++++-------
 6 files changed, 81 insertions(+), 35 deletions(-)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index fb9c9802502..dcfde041885 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -212,7 +212,7 @@ func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() string)
        ctx, cancelFn := context.WithCancel(ctx)
        go func() {
                em.pendingElements.Wait()
-               slog.Info("no more pending elements: terminating pipeline")
+               slog.Debug("no more pending elements: terminating pipeline")
                cancelFn()
                // Ensure the watermark evaluation goroutine exits.
                em.refreshCond.Broadcast()
@@ -394,6 +394,17 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol
        em.addRefreshAndClearBundle(stage.ID, rb.BundleID)
 }
 
+// FailBundle clears the extant data allowing the execution to shut down.
+func (em *ElementManager) FailBundle(rb RunBundle) {
+       stage := em.stages[rb.StageID]
+       stage.mu.Lock()
+       completed := stage.inprogress[rb.BundleID]
+       em.pendingElements.Add(-len(completed.es))
+       delete(stage.inprogress, rb.BundleID)
+       stage.mu.Unlock()
+       em.addRefreshAndClearBundle(rb.StageID, rb.BundleID)
+}
+
 // ReturnResiduals is called after a successful split, so the remaining work
 // can be re-assigned to a new bundle.
 func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, 
inputInfo PColInfo, residuals [][]byte) {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 42327a0209d..31b5dabb2b0 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -105,7 +105,7 @@ func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *wor
                        slog.Error("unmarshing environment payload", err, 
slog.String("envID", wk.ID))
                }
                externalEnvironment(ctx, ep, wk)
-               slog.Info("environment stopped", slog.String("envID", 
wk.String()), slog.String("job", j.String()))
+               slog.Debug("environment stopped", slog.String("envID", 
wk.String()), slog.String("job", j.String()))
        default:
                panic(fmt.Sprintf("environment %v with urn %v unimplemented", 
env, e.GetUrn()))
        }
@@ -304,7 +304,7 @@ func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) erro
                        s.Execute(ctx, j, wk, comps, em, rb)
                }(rb)
        }
-       slog.Info("pipeline done!", slog.String("job", j.String()))
+       slog.Debug("pipeline done!", slog.String("job", j.String()))
        return nil
 }
 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index fe4f18bd38e..720b56e87f5 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -137,9 +137,13 @@ func (j *Job) SendMsg(msg string) {
 func (j *Job) sendState(state jobpb.JobState_Enum) {
        j.streamCond.L.Lock()
        defer j.streamCond.L.Unlock()
-       j.stateTime = time.Now()
-       j.stateIdx++
-       j.state.Store(state)
+       old := j.state.Load()
+       // Never overwrite a failed state with another one.
+       if old != jobpb.JobState_FAILED {
+               j.state.Store(state)
+               j.stateTime = time.Now()
+               j.stateIdx++
+       }
        j.streamCond.Broadcast()
 }
 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 3f4451d7db3..8908e3d68ad 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -123,7 +123,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c
                slog.Debug("Execute: processing", "bundle", rb)
                defer b.Cleanup(wk)
                b.Fail = func(errMsg string) {
-                       slog.Error("job failed", "bundle", rb, "job", j)
+                       slog.Debug("job failed", "bundle", rb, "job", j)
                        err := fmt.Errorf("%v", errMsg)
                        j.Failed(err)
                }
@@ -145,20 +145,20 @@ progress:
                        progTick.Stop()
                        break progress // exit progress loop on close.
                case <-progTick.C:
-                       resp, err := b.Progress(wk)
+                       resp, err := b.Progress(ctx, wk)
                        if err != nil {
                                slog.Debug("SDK Error from progress, aborting 
progress", "bundle", rb, "error", err.Error())
                                break progress
                        }
                        index, unknownIDs := j.ContributeTentativeMetrics(resp)
                        if len(unknownIDs) > 0 {
-                               md := wk.MonitoringMetadata(unknownIDs)
+                               md := wk.MonitoringMetadata(ctx, unknownIDs)
                                j.AddMetricShortIDs(md)
                        }
                        slog.Debug("progress report", "bundle", rb, "index", 
index)
                        // Progress for the bundle hasn't advanced. Try 
splitting.
                        if previousIndex == index && !splitsDone {
-                               sr, err := b.Split(wk, 0.5 /* fraction of 
remainder */, nil /* allowed splits */)
+                               sr, err := b.Split(ctx, wk, 0.5 /* fraction of 
remainder */, nil /* allowed splits */)
                                if err != nil {
                                        slog.Warn("SDK Error from split, 
aborting splits", "bundle", rb, "error", err.Error())
                                        break progress
@@ -202,6 +202,7 @@ progress:
        case resp = <-b.Resp:
        case <-ctx.Done():
                // Ensures we clean up on failure, if the response is blocked.
+               em.FailBundle(rb) // Note: This should change if retries are 
added.
                return
        }
 
@@ -209,7 +210,7 @@ progress:
        // pipeline termination.
        unknownIDs := j.ContributeFinalMetrics(resp)
        if len(unknownIDs) > 0 {
-               md := wk.MonitoringMetadata(unknownIDs)
+               md := wk.MonitoringMetadata(ctx, unknownIDs)
                j.AddMetricShortIDs(md)
        }
        // TODO handle side input data properly.
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
index d17deedec8d..0c6ea2434f6 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
@@ -152,8 +152,8 @@ func (b *B) Cleanup(wk *W) {
 }
 
 // Progress sends a progress request for the given bundle to the passed in 
worker, blocking on the response.
-func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
-       resp := wk.sendInstruction(&fnpb.InstructionRequest{
+func (b *B) Progress(ctx context.Context, wk *W) 
(*fnpb.ProcessBundleProgressResponse, error) {
+       resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_ProcessBundleProgress{
                        ProcessBundleProgress: 
&fnpb.ProcessBundleProgressRequest{
                                InstructionId: b.InstID,
@@ -167,8 +167,8 @@ func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
 }
 
 // Split sends a split request for the given bundle to the passed in worker, 
blocking on the response.
-func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) 
(*fnpb.ProcessBundleSplitResponse, error) {
-       resp := wk.sendInstruction(&fnpb.InstructionRequest{
+func (b *B) Split(ctx context.Context, wk *W, fraction float64, allowedSplits 
[]int64) (*fnpb.ProcessBundleSplitResponse, error) {
+       resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_ProcessBundleSplit{
                        ProcessBundleSplit: &fnpb.ProcessBundleSplitRequest{
                                InstructionId: b.InstID,
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 405c1e812a4..8dbd97d24bb 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -255,27 +255,26 @@ func (wk *W) Connected() bool {
 // Requests come from the runner, and are sent to the client in the SDK.
 func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
        wk.connected.Store(true)
-       done := make(chan struct{})
+       done := make(chan error)
        go func() {
                for {
                        resp, err := ctrl.Recv()
                        if err == io.EOF {
                                slog.Debug("ctrl.Recv finished; marking done", 
"worker", wk)
-                               done <- struct{}{} // means stream is finished
+                               done <- nil // means stream is finished
                                return
                        }
                        if err != nil {
                                switch status.Code(err) {
                                case codes.Canceled:
-                                       done <- struct{}{} // means stream is 
finished
+                                       done <- err // means stream is finished
                                        return
                                default:
-                                       slog.Error("ctrl.Recv failed", err, 
"worker", wk)
+                                       slog.Error("ctrl.Recv failed", "error", 
err, "worker", wk)
                                        panic(err)
                                }
                        }
 
-                       // TODO: Do more than assume these are 
ProcessBundleResponses.
                        wk.mu.Lock()
                        if b, ok := 
wk.activeInstructions[resp.GetInstructionId()]; ok {
                                b.Respond(resp)
@@ -288,19 +287,34 @@ func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
 
        for {
                select {
-               case req := <-wk.InstReqs:
-                       err := ctrl.Send(req)
-                       if err != nil {
-                               go func() { <-done }()
+               case req, ok := <-wk.InstReqs:
+                       if !ok {
+                               slog.Debug("Worker shutting down.", "worker", 
wk)
+                               return nil
+                       }
+                       if err := ctrl.Send(req); err != nil {
                                return err
                        }
                case <-ctrl.Context().Done():
-                       slog.Debug("Control context canceled")
                        go func() { <-done }()
+                       wk.mu.Lock()
+                       // Fail extant instructions
+                       slog.Debug("SDK Disconnected", "worker", wk, 
"ctx_error", ctrl.Context().Err(), "outstanding_instructions", 
len(wk.activeInstructions))
+                       for instID, b := range wk.activeInstructions {
+                               b.Respond(&fnpb.InstructionResponse{
+                                       InstructionId: instID,
+                                       Error:         "SDK Disconnected",
+                               })
+                       }
+                       wk.mu.Unlock()
                        return ctrl.Context().Err()
-               case <-done:
-                       slog.Debug("Control done")
-                       return nil
+               case err := <-done:
+                       if err != nil {
+                               slog.Warn("Control done", "error", err, 
"worker", wk)
+                       } else {
+                               slog.Debug("Control done", "worker", wk)
+                       }
+                       return err
                }
        }
 }
@@ -490,7 +504,7 @@ func (cr *chanResponder) Respond(resp *fnpb.InstructionResponse) {
 
 // sendInstruction is a helper for creating and sending worker single RPCs, 
blocking
 // until the response returns.
-func (wk *W) sendInstruction(req *fnpb.InstructionRequest) 
*fnpb.InstructionResponse {
+func (wk *W) sendInstruction(ctx context.Context, req 
*fnpb.InstructionRequest) *fnpb.InstructionResponse {
        cr := chanResponderPool.Get().(*chanResponder)
        progInst := wk.NextInst()
        wk.mu.Lock()
@@ -506,15 +520,31 @@ func (wk *W) sendInstruction(req *fnpb.InstructionRequest) *fnpb.InstructionResp
 
        req.InstructionId = progInst
 
-       // Tell the SDK to start processing the bundle.
-       wk.InstReqs <- req
-       // Protos are safe as nil, so just return directly.
-       return <-cr.Resp
+       select {
+       case <-ctx.Done():
+               return &fnpb.InstructionResponse{
+                       InstructionId: progInst,
+                       Error:         "context canceled before send",
+               }
+       case wk.InstReqs <- req:
+               // Tell the SDK to start processing the Instruction.
+       }
+
+       select {
+       case <-ctx.Done():
+               return &fnpb.InstructionResponse{
+                       InstructionId: progInst,
+                       Error:         "context canceled before receive",
+               }
+       case resp := <-cr.Resp:
+               // Protos are safe as nil, so just return directly.
+               return resp
+       }
 }
 
 // MonitoringMetadata is a convenience method to request the metadata for 
monitoring shortIDs.
-func (wk *W) MonitoringMetadata(unknownIDs []string) 
*fnpb.MonitoringInfosMetadataResponse {
-       return wk.sendInstruction(&fnpb.InstructionRequest{
+func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) 
*fnpb.MonitoringInfosMetadataResponse {
+       return wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_MonitoringInfos{
                        MonitoringInfos: &fnpb.MonitoringInfosMetadataRequest{
                                MonitoringInfoId: unknownIDs,

Reply via email to