This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e0f463c69d0 [#24931][Go SDK] Make element checkpoints independent (#24932)
e0f463c69d0 is described below

commit e0f463c69d01d68a0c0be3cbdda711d5de6a297c
Author: Robert Burke <[email protected]>
AuthorDate: Mon Jan 9 09:53:17 2023 -0800

    [#24931][Go SDK] Make element checkpoints independent (#24932)
---
 sdks/go/pkg/beam/core/runtime/exec/datasource.go   |  49 +++++--
 .../pkg/beam/core/runtime/exec/datasource_test.go  | 155 ++++++++++++++++++++-
 sdks/go/pkg/beam/core/runtime/exec/plan.go         |  25 ++--
 sdks/go/pkg/beam/core/runtime/exec/plan_test.go    | 141 +++++++++++++++++++
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          |  24 ++--
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     |  30 ++++
 sdks/go/pkg/beam/core/runtime/exec/unit.go         |   2 +-
 sdks/go/pkg/beam/core/runtime/exec/unit_test.go    |  14 +-
 sdks/go/pkg/beam/core/runtime/harness/harness.go   |  32 ++---
 sdks/go/pkg/beam/core/sdf/lock.go                  |   9 +-
 .../beam/io/rtrackers/offsetrange/offsetrange.go   |   4 +
 sdks/go/pkg/beam/runners/direct/impulse.go         |   4 +-
 .../test/integration/primitives/checkpointing.go   |   2 -
 13 files changed, 419 insertions(+), 72 deletions(-)
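
The shape of the change: Root.Process now returns the checkpoints it collected
while processing, one per element that self-checkpointed, instead of the plan
taking a single checkpoint after the bundle completes. A minimal caller-side
sketch of the new contract (not part of the patch; `roots` and `ctx` are
assumed in scope and error handling is elided):

    var checkpoints []*exec.Checkpoint
    for _, root := range roots {
        cps, err := root.Process(ctx) // was `error`, now `([]*Checkpoint, error)`
        if err != nil {
            return err
        }
        checkpoints = append(checkpoints, cps...)
    }
    // Each Checkpoint is later expanded into DelayedBundleApplications in
    // the ProcessBundle response, one per residual root (see harness.go below).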

diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
index 55b0b0a5cad..9c4de0564c8 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
@@ -127,10 +127,10 @@ func (r *byteCountReader) reset() int {
 }
 
 // Process opens the data source, reads and decodes data, kicking off element processing.
-func (n *DataSource) Process(ctx context.Context) error {
+func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) {
        r, err := n.source.OpenRead(ctx, n.SID)
        if err != nil {
-               return err
+               return nil, err
        }
        defer r.Close()
        n.PCol.resetSize() // initialize the size distribution for this bundle.
@@ -154,23 +154,24 @@ func (n *DataSource) Process(ctx context.Context) error {
                cp = MakeElementDecoder(c)
        }
 
+       var checkpoints []*Checkpoint
        for {
                if n.incrementIndexAndCheckSplit() {
-                       return nil
+                       break
                }
                // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes?
                ws, t, pn, err := DecodeWindowedValueHeader(wc, r)
                if err != nil {
                        if err == io.EOF {
-                               return nil
+                               break
                        }
-                       return errors.Wrap(err, "source failed")
+                       return nil, errors.Wrap(err, "source failed")
                }
 
                // Decode key or parallel element.
                pe, err := cp.Decode(&bcr)
                if err != nil {
-                       return errors.Wrap(err, "source decode failed")
+                       return nil, errors.Wrap(err, "source decode failed")
                }
                pe.Timestamp = t
                pe.Windows = ws
@@ -180,18 +181,32 @@ func (n *DataSource) Process(ctx context.Context) error {
                for _, cv := range cvs {
                        values, err := n.makeReStream(ctx, cv, &bcr, len(cvs) == 1 && n.singleIterate)
                        if err != nil {
-                               return err
+                               return nil, err
                        }
                        valReStreams = append(valReStreams, values)
                }
 
                if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil {
-                       return err
+                       return nil, err
                }
                // Collect the actual size of the element, and reset the bytecounter reader.
                n.PCol.addSize(int64(bcr.reset()))
                bcr.reader = r
+
+               // Check if there's a continuation and return residuals.
+               // Needs to be done immediately after processing to not lose the element.
+               if c := n.getProcessContinuation(); c != nil {
+                       cp, err := n.checkpointThis(c)
+                       if err != nil {
+                               // Errors during checkpointing should fail a bundle.
+                               return nil, err
+                       }
+                       if cp != nil {
+                               checkpoints = append(checkpoints, cp)
+                       }
+               }
        }
+       return checkpoints, nil
 }
 
 func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *byteCountReader, onlyStream bool) (ReStream, error) {
@@ -397,18 +412,22 @@ func (n *DataSource) makeEncodeElms() func([]*FullValue) ([][]byte, error) {
        return encodeElms
 }
 
+// Checkpoint holds the split result for a self-checkpointed element and the
+// delay requested before the residual should be reapplied.
+type Checkpoint struct {
+       SR      SplitResult
+       Reapply time.Duration
+}
+
 // Checkpoint attempts to split an SDF that has self-checkpointed (e.g. returned a
 // ProcessContinuation) and needs to be resumed later. If the underlying DoFn is not
 // splittable or has not returned a resuming continuation, the function returns an empty
 // SplitResult, a negative resumption time, and a false boolean to indicate that no split
 // occurred.
-func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
+func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) {
        n.mu.Lock()
        defer n.mu.Unlock()
 
-       pc := n.getProcessContinuation()
        if pc == nil || !pc.ShouldResume() {
-               return SplitResult{}, -1 * time.Minute, false, nil
+               return nil, nil
        }
 
        su := SplittableUnit(n.Out.(*ProcessSizedElementsAndRestrictions))
@@ -418,17 +437,17 @@ func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
        // Checkpointing is functionally a split at fraction 0.0
        rs, err := su.Checkpoint()
        if err != nil {
-               return SplitResult{}, -1 * time.Minute, false, err
+               return nil, err
        }
        if len(rs) == 0 {
-               return SplitResult{}, -1 * time.Minute, false, nil
+               return nil, nil
        }
 
        encodeElms := n.makeEncodeElms()
 
        rsEnc, err := encodeElms(rs)
        if err != nil {
-               return SplitResult{}, -1 * time.Minute, false, err
+               return nil, err
        }
 
        res := SplitResult{
@@ -437,7 +456,7 @@ func (n *DataSource) Checkpoint() (SplitResult, time.Duration, bool, error) {
                InId: su.GetInputId(),
                OW:   ow,
        }
-       return res, pc.ResumeDelay(), true, nil
+       return &Checkpoint{SR: res, Reapply: pc.ResumeDelay()}, nil
 }
 
 // Split takes a sorted set of potential split indices and a fraction of the
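
A Checkpoint is only produced when the user's SDF returns a resuming
ProcessContinuation from ProcessElement. A hedged sketch of such a method
(hypothetical type name, assuming the usual sdf and time imports; the
CheckpointingSdf fixture added to sdf_test.go below has the same shape):

    // ProcessElement emits the element, then asks the runner to resume it
    // later. DataSource.Process checkpoints this continuation immediately,
    // so the residual survives even when later elements also checkpoint.
    func (fn *resumingSdf) ProcessElement(rt *sdf.LockRTracker, elm int, emit func(int)) sdf.ProcessContinuation {
        emit(elm)
        return sdf.ResumeProcessingIn(30 * time.Second)
    }
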
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
index 19f639b31d0..64a37739b24 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
@@ -16,18 +16,25 @@
 package exec
 
 import (
+       "bytes"
        "context"
        "fmt"
        "io"
        "math"
+       "reflect"
        "testing"
        "time"
 
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange"
        "google.golang.org/protobuf/types/known/timestamppb"
 )
 
@@ -314,7 +321,10 @@ func TestDataSource_Split(t *testing.T) {
                                t.Fatalf("error in Split: got primary index = %v, want %v ", got, want)
                        }
 
-                       runOnRoots(ctx, t, p, "Process", Root.Process)
+                       runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
+                               _, err := root.Process(ctx)
+                               return err
+                       })
                        runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)
 
                        validateSource(t, out, source, makeValues(test.expected...))
@@ -449,7 +459,10 @@ func TestDataSource_Split(t *testing.T) {
                if got, want := splitRes.PI, test.splitIdx-1; got != want {
                        t.Fatalf("error in Split: got primary index = %v, want %v ", got, want)
                }
-               runOnRoots(ctx, t, p, "Process", Root.Process)
+               runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
+                       _, err := root.Process(ctx)
+                       return err
+               })
                runOnRoots(ctx, t, p, "FinishBundle", Root.FinishBundle)
 
                validateSource(t, out, source, makeValues(test.expected...))
@@ -582,7 +595,10 @@ func TestDataSource_Split(t *testing.T) {
                if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0}, Frac: -1}); err != nil || !sr.Unsuccessful {
                        t.Fatalf("p.Split(active) = %v,%v want unsuccessful split & nil err", sr, err)
                }
-               runOnRoots(ctx, t, p, "Process", Root.Process)
+               runOnRoots(ctx, t, p, "Process", func(root Root, ctx context.Context) error {
+                       _, err := root.Process(ctx)
+                       return err
+               })
                if sr, err := p.Split(ctx, SplitPoints{Splits: []int64{0}, Frac: -1}); err != nil || !sr.Unsuccessful {
                        t.Fatalf("p.Split(active, unable to get desired split) = %v,%v want unsuccessful split & nil err", sr, err)
                }
@@ -858,6 +874,139 @@ func TestSplitHelper(t *testing.T) {
        })
 }
 
+func TestCheckpointing(t *testing.T) {
+       t.Run("nil", func(t *testing.T) {
+               cps, err := (&DataSource{}).checkpointThis(nil)
+               if err != nil {
+                       t.Fatalf("checkpointThis() = %v, %v", cps, err)
+               }
+       })
+       t.Run("Stop", func(t *testing.T) {
+               cps, err := (&DataSource{}).checkpointThis(sdf.StopProcessing())
+               if err != nil {
+                       t.Fatalf("checkpointThis() = %v, %v", cps, err)
+               }
+       })
+       t.Run("Delay_no_residuals", func(t *testing.T) {
+               wesInv, _ := newWatermarkEstimatorStateInvoker(nil)
+               root := &DataSource{
+                       Out: &ProcessSizedElementsAndRestrictions{
+                               PDo:    &ParDo{},
+                               wesInv: wesInv,
+                       rt:     offsetrange.NewTracker(offsetrange.Restriction{}),
+                               elm: &FullValue{
+                                       Windows: window.SingleGlobalWindow,
+                               },
+                       },
+               }
+               cp, err := root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13))
+               if err != nil {
+                       t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err)
+               }
+               if cp != nil {
+                       t.Fatalf("checkpointThis() = %v, want nil", cp)
+               }
+       })
+       dfn, err := graph.NewDoFn(&CheckpointingSdf{delay: time.Minute}, graph.NumMainInputs(graph.MainSingle))
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+
+       intCoder, _ := coderx.NewVarIntZ(reflectx.Int)
+       ERSCoder := coder.NewKV([]*coder.Coder{
+               coder.NewKV([]*coder.Coder{
+                       coder.CoderFrom(intCoder), // Element
+                       coder.NewKV([]*coder.Coder{
+                               coder.NewR(typex.New(reflect.TypeOf((*offsetrange.Restriction)(nil)).Elem())), // Restriction
+                               coder.NewBool(), // Watermark State
+                       }),
+               }),
+               coder.NewDouble(), // Size
+       })
+       wvERSCoder := coder.NewW(
+               ERSCoder,
+               coder.NewGlobalWindow(),
+       )
+
+       rest := offsetrange.Restriction{Start: 1, End: 10}
+       value := &FullValue{
+               Elm: &FullValue{
+                       Elm: 42,
+                       Elm2: &FullValue{
+                               Elm:  rest,  // Restriction
+                               Elm2: false, // Watermark State falsie
+                       },
+               },
+               Elm2:      rest.Size(),
+               Windows:   window.SingleGlobalWindow,
+               Timestamp: mtime.MaxTimestamp,
+               Pane:      typex.NoFiringPane(),
+       }
+       t.Run("Delay_residuals_Process", func(t *testing.T) {
+               ctx := context.Background()
+               wesInv, _ := newWatermarkEstimatorStateInvoker(nil)
+               rest := offsetrange.Restriction{Start: 1, End: 10}
+               root := &DataSource{
+                       Coder: wvERSCoder,
+                       Out: &ProcessSizedElementsAndRestrictions{
+                               PDo: &ParDo{
+                                       Fn:  dfn,
+                                       Out: []Node{&Discard{}},
+                               },
+                               TfId:   "testTransformID",
+                               wesInv: wesInv,
+                               rt:     offsetrange.NewTracker(rest),
+                       },
+               }
+               if err := root.Up(ctx); err != nil {
+                       t.Fatalf("invalid function: %v", err)
+               }
+               if err := root.Out.Up(ctx); err != nil {
+                       t.Fatalf("invalid function: %v", err)
+               }
+
+               enc := MakeElementEncoder(wvERSCoder)
+               var buf bytes.Buffer
+
+               // We encode the element several times to ensure we don't
+               // drop any residuals, the root of issue #24931.
+               wantCount := 3
+               for i := 0; i < wantCount; i++ {
+                       if err := enc.Encode(value, &buf); err != nil {
+                               t.Fatalf("couldn't encode value: %v", err)
+                       }
+               }
+
+               if err := root.StartBundle(ctx, "testBund", DataContext{
+                       Data: &TestDataManager{
+                               R: io.NopCloser(&buf),
+                       },
+               },
+               ); err != nil {
+                       t.Fatalf("invalid function: %v", err)
+               }
+               cps, err := root.Process(ctx)
+               if err != nil {
+                       t.Fatalf("Process() = %v, %v, want nil", cps, err)
+               }
+               if got, want := len(cps), wantCount; got != want {
+                       t.Fatalf("Process() = len %v checkpoints, want %v", got, want)
+               }
+               // Check each checkpoint has the expected values.
+               for _, cp := range cps {
+                       if got, want := cp.Reapply, time.Minute; got != want {
+                               t.Errorf("Process(delay(%v)) delay = %v, want %v", want, got, want)
+                       }
+                       if got, want := cp.SR.TId, root.Out.(*ProcessSizedElementsAndRestrictions).TfId; got != want {
+                               t.Errorf("Process() transformID = %v, want %v", got, want)
+                       }
+                       if got, want := cp.SR.InId, "i0"; got != want {
+                               t.Errorf("Process() inputID = %v, want %v", got, want)
+                       }
+               }
+       })
+}
+
 func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd func(Root, context.Context) error) {
        t.Helper()
        for i, root := range p.roots {
diff --git a/sdks/go/pkg/beam/core/runtime/exec/plan.go b/sdks/go/pkg/beam/core/runtime/exec/plan.go
index 0189de51c7f..7958cf38238 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/plan.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/plan.go
@@ -31,11 +31,12 @@ import (
 // from a part of a pipeline. A plan can be used to process multiple bundles
 // serially.
 type Plan struct {
-       id    string // id of the bundle descriptor for this plan
-       roots []Root
-       units []Unit
-       pcols []*PCollection
-       bf    *bundleFinalizer
+       id          string // id of the bundle descriptor for this plan
+       roots       []Root
+       units       []Unit
+       pcols       []*PCollection
+       bf          *bundleFinalizer
+       checkpoints []*Checkpoint
 
        status Status
 
@@ -126,7 +127,11 @@ func (p *Plan) Execute(ctx context.Context, id string, manager DataContext) erro
                }
        }
        for _, root := range p.roots {
-               if err := callNoPanic(ctx, root.Process); err != nil {
+               if err := callNoPanic(ctx, func(ctx context.Context) error {
+                       cps, err := root.Process(ctx)
+                       p.checkpoints = cps
+                       return err
+               }); err != nil {
                        p.status = Broken
                        return errors.Wrapf(err, "while executing Process for %v", p)
                }
@@ -281,9 +286,7 @@ func (p *Plan) Split(ctx context.Context, s SplitPoints) (SplitResult, error) {
 }
 
 // Checkpoint attempts to split an SDF if the DoFn self-checkpointed.
-func (p *Plan) Checkpoint() (SplitResult, time.Duration, bool, error) {
-       if p.source != nil {
-               return p.source.Checkpoint()
-       }
-       return SplitResult{}, -1 * time.Minute, false, nil
+func (p *Plan) Checkpoint() []*Checkpoint {
+       defer func() { p.checkpoints = nil }()
+       return p.checkpoints
 }
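
Plan.Checkpoint now simply drains whatever Execute collected, and the deferred
reset means each bundle's checkpoints are returned at most once. A small
consumption sketch (assumed caller, not part of the patch):

    cps := plan.Checkpoint() // checkpoints gathered during the last Execute
    for _, cp := range cps {
        _ = cp.SR      // SplitResult: encoded residuals plus transform/input IDs
        _ = cp.Reapply // requested delay before the residual is retried
    }
    // A second Checkpoint() call returns nil until the next bundle executes.
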
diff --git a/sdks/go/pkg/beam/core/runtime/exec/plan_test.go b/sdks/go/pkg/beam/core/runtime/exec/plan_test.go
new file mode 100644
index 00000000000..348e0a44734
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/plan_test.go
@@ -0,0 +1,141 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package exec
+
+import (
+       "fmt"
+       "testing"
+       "time"
+
+       "github.com/google/go-cmp/cmp"
+)
+
+func TestPlan_Checkpoint(t *testing.T) {
+       var p Plan
+       want := []*Checkpoint{{Reapply: time.Hour}}
+       p.checkpoints = want
+       if got := p.Checkpoint(); !cmp.Equal(got, want) {
+               t.Errorf("p.Checkpoint() = %v, want %v", got, want)
+       }
+       if p.checkpoints != nil {
+               t.Errorf("p.Checkpoint() didn't nil checkpoints field")
+       }
+}
+
+func TestPlan_BundleFinalizers(t *testing.T) {
+       newPlan := func() Plan {
+               var p Plan
+               p.status = Up
+               return p
+       }
+       t.Run("NoCallbacks", func(t *testing.T) {
+               p := newPlan()
+               p.bf = &bundleFinalizer{}
+               if err := p.Finalize(); err != nil {
+                       t.Errorf("p.Finalize() = %v, want nil", err)
+               }
+               // Expiration time is no longer set to default
+               if got, want := p.GetExpirationTime(), (time.Time{}); want.Equal(got) {
+                       t.Errorf("p.GetExpirationTime() = %v, want %v", got, want)
+               }
+       })
+
+       t.Run("Callbacks", func(t *testing.T) {
+               p := newPlan()
+               initialDeadline := time.Now().Add(time.Hour)
+
+               var callCount int
+               inc := func() error {
+                       callCount++
+                       return nil
+               }
+               p.bf = &bundleFinalizer{
+                       callbacks: []bundleFinalizationCallback{
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                       },
+               }
+               if err := p.Finalize(); err != nil {
+                       t.Errorf("p.Finalize() = %v, want nil", err)
+               }
+               // Expiration time is no longer set to default
+               if got, want := p.GetExpirationTime(), (time.Time{}); want.Equal(got) {
+                       t.Errorf("p.GetExpirationTime() = %v, want %v", got, want)
+               }
+               if got, want := callCount, 3; got != want {
+                       t.Errorf("p.Finalize() didn't call all finalizers, got %v, want %v", got, want)
+               }
+       })
+       t.Run("Callbacks_expired", func(t *testing.T) {
+               p := newPlan()
+               initialDeadline := time.Now().Add(-time.Hour)
+
+               var callCount int
+               inc := func() error {
+                       callCount++
+                       return nil
+               }
+               p.bf = &bundleFinalizer{
+                       callbacks: []bundleFinalizationCallback{
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                       },
+               }
+               if err := p.Finalize(); err != nil {
+                       t.Errorf("p.Finalize() = %v, want nil", err)
+               }
+               // Expiration time is no longer set to default
+               if got, want := p.GetExpirationTime(), (time.Time{}); want.Equal(got) {
+                       t.Errorf("p.GetExpirationTime() = %v, want %v", got, want)
+               }
+               if got, want := callCount, 0; got != want {
+                       t.Errorf("p.Finalize() didn't call all finalizers, got %v, want %v", got, want)
+               }
+       })
+
+       t.Run("Callbacks_failures", func(t *testing.T) {
+               p := newPlan()
+               initialDeadline := time.Now().Add(time.Hour)
+
+               var callCount int
+               inc := func() error {
+                       callCount++
+                       if callCount == 1 {
+                               return fmt.Errorf("unable to call")
+                       }
+                       return nil
+               }
+               p.bf = &bundleFinalizer{
+                       callbacks: []bundleFinalizationCallback{
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                               {callback: inc, validUntil: initialDeadline},
+                       },
+               }
+               if err := p.Finalize(); err == nil {
+                       t.Errorf("p.Finalize() = %v, want an error", err)
+               }
+               if got, want := callCount, 3; got != want {
+                       t.Errorf("p.Finalize() didn't call all finalizers, got %v, want %v", got, want)
+               }
+               if len(p.bf.callbacks) != 1 {
+                       t.Errorf("p.Finalize() didn't preserve failed callbacks")
+               }
+       })
+
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index d4b1b32d257..e22496eae6e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -459,18 +459,18 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i
 //
 // Input Diagram:
 //
-//       *FullValue {
-//         Elm: *FullValue {
-//           Elm:  *FullValue (KV input) or InputType (single-element input)
-//                      Elm2: *FullValue {
-//                        Elm: Restriction
-//             Elm2: Watermark estimator state
-//                      }
-//         }
-//         Elm2: float64 (size)
-//         Windows
-//         Timestamps
-//       }
+//     *FullValue {
+//             Elm: *FullValue {
+//                     Elm:  *FullValue (KV input) or InputType (single-element input)
+//                     Elm2: *FullValue {
+//                             Elm: Restriction
+//                             Elm2: Watermark estimator state
+//                     }
+//             }
+//             Elm2: float64 (size)
+//             Windows
+//             Timestamps
+//     }
 //
 // ProcessElement then creates a restriction tracker from the stored restriction
 // and processes each element using the underlying ParDo and adding the
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 6520b9dfe59..414f28553a8 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -1524,6 +1524,36 @@ func (fn *WindowBlockingSdf) ProcessElement(w typex.Window, rt *sdf.LockRTracker
        emit(elm)
 }
 
+// CheckpointingSdf is a very basic checkpointing DoFn that always
+// returns a processing continuation.
+type CheckpointingSdf struct {
+       delay time.Duration
+}
+
+// CreateInitialRestriction creates a four-element offset range.
+func (fn *CheckpointingSdf) CreateInitialRestriction(_ int) offsetrange.Restriction {
+       return offsetrange.Restriction{Start: 0, End: 4}
+}
+
+// SplitRestriction is a no-op, and does not split.
+func (fn *CheckpointingSdf) SplitRestriction(_ int, rest offsetrange.Restriction) []offsetrange.Restriction {
+       return []offsetrange.Restriction{rest}
+}
+
+// RestrictionSize defers to the default offset range restriction size.
+func (fn *CheckpointingSdf) RestrictionSize(_ int, rest offsetrange.Restriction) float64 {
+       return rest.Size()
+}
+
+// CreateTracker creates a LockRTracker wrapping an offset range RTracker.
+func (fn *CheckpointingSdf) CreateTracker(rest offsetrange.Restriction) *sdf.LockRTracker {
+       return sdf.NewLockRTracker(offsetrange.NewTracker(rest))
+}
+
+// ProcessElement does no per-element work; it always returns a resuming
+// continuation with fn's configured delay.
+func (fn *CheckpointingSdf) ProcessElement(rt *sdf.LockRTracker, elm int, emit func(int)) sdf.ProcessContinuation {
+       return sdf.ResumeProcessingIn(fn.delay)
+}
+
 // SplittableUnitRTracker is a VetRTracker with some added behavior needed for
 // TestAsSplittableUnit.
 type SplittableUnitRTracker struct {
diff --git a/sdks/go/pkg/beam/core/runtime/exec/unit.go b/sdks/go/pkg/beam/core/runtime/exec/unit.go
index 04635086ab1..ac907d98f75 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/unit.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/unit.go
@@ -57,7 +57,7 @@ type Root interface {
 
        // Process processes the entire source, notably emitting elements to
        // downstream nodes.
-       Process(ctx context.Context) error
+       Process(ctx context.Context) ([]*Checkpoint, error)
 }
 
 // ElementProcessor presents a component that can process an element.
diff --git a/sdks/go/pkg/beam/core/runtime/exec/unit_test.go b/sdks/go/pkg/beam/core/runtime/exec/unit_test.go
index 9d0137a0edd..885294ed397 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/unit_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/unit_test.go
@@ -125,6 +125,8 @@ type FixedRoot struct {
        Out      Node
 }
 
+var _ Root = (*FixedRoot)(nil)
+
 func (n *FixedRoot) ID() UnitID {
        return n.UID
 }
@@ -137,13 +139,13 @@ func (n *FixedRoot) StartBundle(ctx context.Context, id string, data DataContext
        return n.Out.StartBundle(ctx, id, data)
 }
 
-func (n *FixedRoot) Process(ctx context.Context) error {
+func (n *FixedRoot) Process(ctx context.Context) ([]*Checkpoint, error) {
        for _, elm := range n.Elements {
                if err := n.Out.ProcessElement(ctx, &elm.Key, elm.Values...); err != nil {
-                       return err
+                       return nil, err
                }
        }
-       return nil
+       return nil, nil
 }
 
 func (n *FixedRoot) FinishBundle(ctx context.Context) error {
@@ -186,13 +188,13 @@ func (n *BenchRoot) StartBundle(ctx context.Context, id string, data DataContext
        return n.Out.StartBundle(ctx, id, data)
 }
 
-func (n *BenchRoot) Process(ctx context.Context) error {
+func (n *BenchRoot) Process(ctx context.Context) ([]*Checkpoint, error) {
        for elm := range n.Elements {
                if err := n.Out.ProcessElement(ctx, &elm.Key, elm.Values...); err != nil {
-                       return err
+                       return nil, err
                }
        }
-       return nil
+       return nil, nil
 }
 
 func (n *BenchRoot) FinishBundle(ctx context.Context) error {
diff --git a/sdks/go/pkg/beam/core/runtime/harness/harness.go b/sdks/go/pkg/beam/core/runtime/harness/harness.go
index 6b2386cc2e3..c260a46c80e 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/harness.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/harness.go
@@ -415,6 +415,7 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe
 
                mons, pylds := monitoring(plan, store, c.runnerCapabilities[URNMonitoringInfoShortID])
 
+               checkpoints := plan.Checkpoint()
                requiresFinalization := false
                // Move the plan back to the candidate state
                c.mu.Lock()
@@ -447,21 +448,19 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe
                        }
                }
 
-               // Check if the underlying DoFn self-checkpointed.
-               sr, delay, checkpointed, checkErr := plan.Checkpoint()
-
                var rRoots []*fnpb.DelayedBundleApplication
-               if checkpointed {
-                       rRoots = make([]*fnpb.DelayedBundleApplication, len(sr.RS))
-                       for i, r := range sr.RS {
-                               rRoots[i] = &fnpb.DelayedBundleApplication{
-                                       Application: &fnpb.BundleApplication{
-                                               TransformId:      sr.TId,
-                                               InputId:          sr.InId,
-                                               Element:          r,
-                                               OutputWatermarks: sr.OW,
-                                       },
-                                       RequestedTimeDelay: durationpb.New(delay),
+               if len(checkpoints) > 0 {
+                       for _, cp := range checkpoints {
+                               for _, r := range cp.SR.RS {
+                                       rRoots = append(rRoots, &fnpb.DelayedBundleApplication{
+                                               Application: &fnpb.BundleApplication{
+                                                       TransformId:      cp.SR.TId,
+                                                       InputId:          cp.SR.InId,
+                                                       Element:          r,
+                                                       OutputWatermarks: cp.SR.OW,
+                                               },
+                                               RequestedTimeDelay: durationpb.New(cp.Reapply),
+                                       })
                                }
                        }
                }
@@ -477,11 +476,6 @@ func (c *control) handleInstruction(ctx context.Context, req *fnpb.InstructionRe
                if err != nil {
                        return fail(ctx, instID, "process bundle failed for instruction %v using plan %v : %v", instID, bdID, err)
                }
-
-               if checkErr != nil {
-                       return fail(ctx, instID, "process bundle failed at checkpointing for instruction %v using plan %v : %v", instID, bdID, checkErr)
-               }
-
                return &fnpb.InstructionResponse{
                        InstructionId: string(instID),
                        Response: &fnpb.InstructionResponse_ProcessBundle{
diff --git a/sdks/go/pkg/beam/core/sdf/lock.go b/sdks/go/pkg/beam/core/sdf/lock.go
index ea26129d34c..eecae0ef4fd 100644
--- a/sdks/go/pkg/beam/core/sdf/lock.go
+++ b/sdks/go/pkg/beam/core/sdf/lock.go
@@ -15,7 +15,10 @@
 
 package sdf
 
-import "sync"
+import (
+       "fmt"
+       "sync"
+)
 
 // NewLockRTracker creates a LockRTracker initialized with the specified
 // restriction tracker as its underlying restriction tracker.
@@ -92,3 +95,7 @@ func (rt *LockRTracker) IsBounded() bool {
        }
        return true
 }
+
+// String formats the wrapped tracker for debugging.
+func (rt *LockRTracker) String() string {
+       return fmt.Sprintf("LockRTracker(%v)", rt.Rt)
+}
diff --git a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
index 3aaf48bc0af..ad8da0fa90b 100644
--- a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
+++ b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
@@ -131,6 +131,10 @@ type Tracker struct {
        err       error
 }
 
+// String formats the tracker's internal state for debugging.
+func (tracker *Tracker) String() string {
+       return fmt.Sprintf("[%v,%v) c: %v, a.: %v, stopped: %v, err: %v", tracker.rest.Start, tracker.rest.End, tracker.claimed, tracker.attempted, tracker.stopped, tracker.err)
+}
+
 // NewTracker is a constructor for a Tracker given a start and end range.
 func NewTracker(rest Restriction) *Tracker {
        return &Tracker{
diff --git a/sdks/go/pkg/beam/runners/direct/impulse.go b/sdks/go/pkg/beam/runners/direct/impulse.go
index 1d6f78d0620..c2348474a0b 100644
--- a/sdks/go/pkg/beam/runners/direct/impulse.go
+++ b/sdks/go/pkg/beam/runners/direct/impulse.go
@@ -43,13 +43,13 @@ func (n *Impulse) StartBundle(ctx context.Context, id string, data exec.DataCont
        return n.Out.StartBundle(ctx, id, data)
 }
 
-func (n *Impulse) Process(ctx context.Context) error {
+func (n *Impulse) Process(ctx context.Context) ([]*exec.Checkpoint, error) {
        value := &exec.FullValue{
                Windows:   window.SingleGlobalWindow,
                Timestamp: mtime.Now(),
                Elm:       n.Value,
        }
-       return n.Out.ProcessElement(ctx, value)
+       return nil, n.Out.ProcessElement(ctx, value)
 }
 
 func (n *Impulse) FinishBundle(ctx context.Context) error {
diff --git a/sdks/go/test/integration/primitives/checkpointing.go b/sdks/go/test/integration/primitives/checkpointing.go
index 1a52f9d8aeb..ae61f318629 100644
--- a/sdks/go/test/integration/primitives/checkpointing.go
+++ b/sdks/go/test/integration/primitives/checkpointing.go
@@ -97,8 +97,6 @@ func (fn *selfCheckpointingDoFn) ProcessElement(rt *sdf.LockRTracker, _ []byte,
 
 // Checkpoints is a small test pipeline to establish the correctness of the simple test case.
 func Checkpoints(s beam.Scope) {
-       beam.Init()
-
        s.Scope("checkpoint")
        out := beam.ParDo(s, &selfCheckpointingDoFn{}, beam.Impulse(s))
        passert.Count(s, out, "num ints", 10)
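
The beam.Init call is removed because initialization belongs at the process
entry point, not inside pipeline-construction helpers like Checkpoints. A
hedged sketch of the conventional placement (hypothetical main package using
the beamx wrapper):

    package main

    import (
        "context"
        "log"

        "github.com/apache/beam/sdks/v2/go/pkg/beam"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx"
        "github.com/apache/beam/sdks/v2/go/test/integration/primitives"
    )

    func main() {
        beam.Init() // parse flags and register functions once per process
        p, s := beam.NewPipelineWithRoot()
        primitives.Checkpoints(s)
        if err := beamx.Run(context.Background(), p); err != nil {
            log.Fatal(err)
        }
    }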
