This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b627a8e8419 [BEAM-14536] Handle 0.0 splits in offsetrange restriction (#17782)
b627a8e8419 is described below

commit b627a8e84194427e6dccc9b6f519068777f70b83
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Jun 1 13:51:16 2022 -0400

    [BEAM-14536] Handle 0.0 splits in offsetrange restriction (#17782)
---
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     | 41 ++-----------
 .../beam/io/rtrackers/offsetrange/offsetrange.go   |  8 ++-
 .../io/rtrackers/offsetrange/offsetrange_test.go   | 71 +++++++++++++++++++++-
 3 files changed, 81 insertions(+), 39 deletions(-)

diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 85cccea270e..278ab07c8f6 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -1204,7 +1204,6 @@ func TestAsSplittableUnit(t *testing.T) {
                        fn            *graph.DoFn
                        in            FullValue
                        finishPrimary bool
-                       expErr        bool
                        wantResiduals []*FullValue
                }{
                        {
@@ -1223,7 +1222,6 @@ func TestAsSplittableUnit(t *testing.T) {
                                        Windows:   testWindows,
                                },
                                finishPrimary: true,
-                               expErr:        false,
                                wantResiduals: []*FullValue{{
                                        Elm: &FullValue{
                                                Elm: 1,
@@ -1237,25 +1235,6 @@ func TestAsSplittableUnit(t *testing.T) {
                                        Windows:   testWindows,
                                }},
                        },
-                       {
-                               name: "unfinished primary",
-                               fn:   dfn,
-                               in: FullValue{
-                                       Elm: &FullValue{
-                                               Elm: 1,
-                                               Elm2: &FullValue{
-                                                       Elm:  
&VetRestriction{ID: "Sdf"},
-                                                       Elm2: false,
-                                               },
-                                       },
-                                       Elm2:      1.0,
-                                       Timestamp: testTimestamp,
-                                       Windows:   testWindows,
-                               },
-                               finishPrimary: false,
-                               expErr:        true,
-                               wantResiduals: []*FullValue{},
-                       },
                }
                for _, test := range tests {
                        t.Run(test.name, func(t *testing.T) {
@@ -1276,20 +1255,12 @@ func TestAsSplittableUnit(t *testing.T) {
                                        
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
                                }
                                gotResiduals, err := su.Checkpoint()
-                               if test.expErr {
-                                       if err == nil {
-                                               
t.Errorf("SplittableUnit.Checkpoint() succeeded when it should have failed")
-                                       }
-                                       if len(gotResiduals) != 0 {
-                                               
t.Errorf("SplittableUnit.Checkpoint() got residuals %v, want none", 
gotResiduals)
-                                       }
-                               } else {
-                                       if err != nil {
-                                               
t.Fatalf("SplittableUnit.Checkpoint() returned error, got %v", err)
-                                       }
-                                       if diff := cmp.Diff(gotResiduals, 
test.wantResiduals); diff != "" {
-                                               
t.Errorf("SplittableUnit.Checkpoint() has incorrect residual (-got, 
+want)\n%v", diff)
-                                       }
+
+                               if err != nil {
+                                       t.Fatalf("SplittableUnit.Checkpoint() 
returned error, got %v", err)
+                               }
+                               if diff := cmp.Diff(gotResiduals, 
test.wantResiduals); diff != "" {
+                                       t.Errorf("SplittableUnit.Checkpoint() 
has incorrect residual (-got, +want)\n%v", diff)
                                }
                        })
                }
diff --git a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
index cd3ea3f9faa..0eef03d6f43 100644
--- a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
+++ b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange.go
@@ -192,7 +192,9 @@ func (tracker *Tracker) TrySplit(fraction float64) (primary, residual interface{
        }
 
        // Use Ceil to always round up from float split point.
-       splitPt := tracker.claimed + 
int64(math.Ceil(fraction*float64(tracker.rest.End-tracker.claimed)))
+       // Use Max to make sure the split point is greater than the current 
claimed work since
+       // claimed work belongs to the primary.
+       splitPt := tracker.claimed + 
int64(math.Max(math.Ceil(fraction*float64(tracker.rest.End-tracker.claimed)), 
1))
        if splitPt >= tracker.rest.End {
                return tracker.rest, nil, nil
        }
@@ -208,9 +210,9 @@ func (tracker *Tracker) GetProgress() (done, remaining float64) {
        return
 }
 
-// IsDone returns true if the most recent claimed element is past the end of 
the restriction.
+// IsDone returns true if the most recent claimed element is at or past the 
end of the restriction
 func (tracker *Tracker) IsDone() bool {
-       return tracker.err == nil && tracker.claimed >= tracker.rest.End
+       return tracker.err == nil && (tracker.claimed+1) >= tracker.rest.End
 }
 
 // GetRestriction returns a copy of the tracker's underlying 
offsetrange.Restriction.
diff --git a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange_test.go b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange_test.go
index ba7da64d8fb..4147dec97ac 100644
--- a/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange_test.go
+++ b/sdks/go/pkg/beam/io/rtrackers/offsetrange/offsetrange_test.go
@@ -263,7 +263,7 @@ func TestTracker_TrySplit(t *testing.T) {
                        rest:     Restriction{Start: 0, End: 10},
                        claimed:  5,
                        fraction: -0.5,
-                       splitPt:  5,
+                       splitPt:  6,
                },
                {
                        rest:     Restriction{Start: 0, End: 10},
@@ -299,3 +299,72 @@ func TestTracker_TrySplit(t *testing.T) {
                })
        }
 }
+
+// TestTracker_TrySplit_WithoutClaiming tests that TrySplit follows its 
contract
+// when no TryClaim calls have been made, meaning that
+// splits don't lose any elements, split fractions are clamped to 0 or 1, and
+// that TrySplit always splits at the nearest integer greater than the given
+// fraction.
+func TestTracker_TrySplit_WithoutClaiming(t *testing.T) {
+       tests := []struct {
+               rest     Restriction
+               claimed  int64
+               fraction float64
+               // Index where we want the split to happen. This will be the end
+               // (exclusive) of the primary and first element of the residual.
+               splitPt int64
+       }{
+               {
+                       rest:     Restriction{Start: 0, End: 1},
+                       fraction: 0.5,
+                       splitPt:  0,
+               },
+               {
+                       rest:     Restriction{Start: 0, End: 1},
+                       fraction: 0.0,
+                       splitPt:  0,
+               },
+               {
+                       rest:     Restriction{Start: 0, End: 5},
+                       fraction: 0.5,
+                       splitPt:  2,
+               },
+               {
+                       rest:     Restriction{Start: 5, End: 10},
+                       fraction: 0.5,
+                       splitPt:  7,
+               },
+               {
+                       rest:     Restriction{Start: 5, End: 10},
+                       fraction: -0.5,
+                       splitPt:  5,
+               },
+               {
+                       rest:     Restriction{Start: 5, End: 10},
+                       fraction: 1.5,
+                       splitPt:  10,
+               },
+       }
+       for _, test := range tests {
+               test := test
+               t.Run(fmt.Sprintf("(split at %v of [%v, %v])",
+                       test.fraction, test.rest.Start, test.rest.End), func(t 
*testing.T) {
+                       rt := NewTracker(test.rest)
+                       gotP, gotR, err := rt.TrySplit(test.fraction)
+                       if err != nil {
+                               t.Fatalf("tracker failed on split: %v", err)
+                       }
+                       var wantP interface{} = Restriction{Start: 
test.rest.Start, End: test.splitPt}
+                       var wantR interface{} = Restriction{Start: 
test.splitPt, End: test.rest.End}
+                       if test.splitPt == test.rest.End {
+                               wantR = nil // When residuals are empty we 
should get nil.
+                       }
+                       if !cmp.Equal(gotP, wantP) {
+                               t.Errorf("split got incorrect primary: got: %v, 
want: %v", gotP, wantP)
+                       }
+                       if !cmp.Equal(gotR, wantR) {
+                               t.Errorf("split got incorrect residual: got: 
%v, want: %v", gotR, wantR)
+                       }
+               })
+       }
+}

Reply via email to