This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 2b96716ef5f Implement mongodbio.Read with an SDF (#25160)
2b96716ef5f is described below
commit 2b96716ef5f1e575bb53cf3d23843d42faa45ce3
Author: Johanna Öjeling <[email protected]>
AuthorDate: Sat Jan 28 01:29:30 2023 +0100
Implement mongodbio.Read with an SDF (#25160)
---
sdks/go/pkg/beam/io/mongodbio/coder.go | 25 +-
sdks/go/pkg/beam/io/mongodbio/coder_test.go | 135 +++---
sdks/go/pkg/beam/io/mongodbio/common.go | 38 +-
.../pkg/beam/io/mongodbio/id_range_restriction.go | 206 +++++++++
.../beam/io/mongodbio/id_range_restriction_test.go | 179 ++++++++
sdks/go/pkg/beam/io/mongodbio/id_range_split.go | 248 +++++++++++
.../pkg/beam/io/mongodbio/id_range_split_test.go | 275 ++++++++++++
sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go | 194 +++++++++
.../pkg/beam/io/mongodbio/id_range_tracker_test.go | 461 +++++++++++++++++++++
sdks/go/pkg/beam/io/mongodbio/read.go | 402 +++++-------------
sdks/go/pkg/beam/io/mongodbio/read_test.go | 323 ---------------
11 files changed, 1774 insertions(+), 712 deletions(-)
diff --git a/sdks/go/pkg/beam/io/mongodbio/coder.go
b/sdks/go/pkg/beam/io/mongodbio/coder.go
index c140f9a8a25..3100f0fd93d 100644
--- a/sdks/go/pkg/beam/io/mongodbio/coder.go
+++ b/sdks/go/pkg/beam/io/mongodbio/coder.go
@@ -26,9 +26,14 @@ import (
func init() {
beam.RegisterCoder(
- reflect.TypeOf((*bson.M)(nil)).Elem(),
- encodeBSONMap,
- decodeBSONMap,
+ reflect.TypeOf((*idRangeRestriction)(nil)).Elem(),
+ encodeBSON[idRangeRestriction],
+ decodeBSON[idRangeRestriction],
+ )
+ beam.RegisterCoder(
+ reflect.TypeOf((*idRange)(nil)).Elem(),
+ encodeBSON[idRange],
+ decodeBSON[idRange],
)
beam.RegisterCoder(
reflect.TypeOf((*primitive.ObjectID)(nil)).Elem(),
@@ -37,19 +42,19 @@ func init() {
)
}
-func encodeBSONMap(m bson.M) ([]byte, error) {
- bytes, err := bson.Marshal(m)
+func encodeBSON[T any](in T) ([]byte, error) {
+ out, err := bson.Marshal(in)
if err != nil {
return nil, fmt.Errorf("error encoding BSON: %w", err)
}
- return bytes, nil
+ return out, nil
}
-func decodeBSONMap(bytes []byte) (bson.M, error) {
- var out bson.M
- if err := bson.Unmarshal(bytes, &out); err != nil {
- return nil, fmt.Errorf("error decoding BSON: %w", err)
+func decodeBSON[T any](in []byte) (T, error) {
+ var out T
+ if err := bson.Unmarshal(in, &out); err != nil {
+ return out, fmt.Errorf("error decoding BSON: %w", err)
}
return out, nil
diff --git a/sdks/go/pkg/beam/io/mongodbio/coder_test.go
b/sdks/go/pkg/beam/io/mongodbio/coder_test.go
index d5e3bb2974d..98f81c50ef2 100644
--- a/sdks/go/pkg/beam/io/mongodbio/coder_test.go
+++ b/sdks/go/pkg/beam/io/mongodbio/coder_test.go
@@ -23,137 +23,102 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive"
)
-func Test_encodeBSONMap(t *testing.T) {
+func Test_encodeDecodeBSONMap(t *testing.T) {
tests := []struct {
- name string
- m bson.M
- want []byte
- wantErr bool
+ name string
+ val bson.M
}{
{
- name: "Encode bson.M",
- m: bson.M{"key": "val"},
- want: []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0,
0, 0, 118, 97, 108, 0, 0},
- wantErr: false,
+ name: "Encode/decode bson.M",
+ val: bson.M{"key": "val"},
},
{
- name: "Encode empty bson.M",
- m: bson.M{},
- want: []byte{5, 0, 0, 0, 0},
- wantErr: false,
- },
- {
- name: "Encode nil bson.M",
- m: bson.M(nil),
- want: []byte{5, 0, 0, 0, 0},
- wantErr: false,
- },
- {
- name: "Error - invalid bson.M",
- m: bson.M{"key": make(chan int)},
- wantErr: true,
+ name: "Encode/decode empty bson.M",
+ val: bson.M{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := encodeBSONMap(tt.m)
- if (err != nil) != tt.wantErr {
- t.Fatalf("encodeBSONMap() error = %v, wantErr
%v", err, tt.wantErr)
+ encoded, err := encodeBSON[bson.M](tt.val)
+ if err != nil {
+ t.Fatalf("encodeBSON[bson.M]() error = %v", err)
+ }
+
+ decoded, err := decodeBSON[bson.M](encoded)
+ if err != nil {
+ t.Fatalf("decodeBSON[bson.M]() error = %v", err)
}
- if !cmp.Equal(got, tt.want) {
- t.Errorf("encodeBSONMap() got = %v, want %v",
got, tt.want)
+ if diff := cmp.Diff(tt.val, decoded); diff != "" {
+ t.Errorf("encode/decode mismatch (-want
+got):\n%s", diff)
}
})
}
}
-func Test_decodeBSONMap(t *testing.T) {
+func Test_encodeDecodeIDRangeRestriction(t *testing.T) {
tests := []struct {
- name string
- bytes []byte
- want bson.M
- wantErr bool
+ name string
+ rest idRangeRestriction
}{
{
- name: "Decode bson.M",
- bytes: []byte{18, 0, 0, 0, 2, 107, 101, 121, 0, 4, 0,
0, 0, 118, 97, 108, 0, 0},
- want: bson.M{"key": "val"},
- wantErr: false,
+ name: "Encode/decode idRangeRestriction",
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: objectIDFromHex(t,
"5f1b2c3d4e5f60708090a0b0"),
+ MinInclusive: true,
+ Max: objectIDFromHex(t,
"5f1b2c3d4e5f60708090a0b9"),
+ MaxInclusive: true,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 5,
+ },
},
{
- name: "Decode empty bson.M",
- bytes: []byte{5, 0, 0, 0, 0},
- want: bson.M{},
- wantErr: false,
- },
- {
- name: "Error - invalid bson.M",
- bytes: []byte{},
- wantErr: true,
+ name: "Encode/decode empty idRangeRestriction",
+ rest: idRangeRestriction{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := decodeBSONMap(tt.bytes)
- if (err != nil) != tt.wantErr {
- t.Fatalf("decodeBSONMap() error = %v, wantErr
%v", err, tt.wantErr)
+ encoded, err := encodeBSON[idRangeRestriction](tt.rest)
+ if err != nil {
+ t.Fatalf("encodeBSON[idRangeRestriction]()
error = %v", err)
+ }
+
+ decoded, err := decodeBSON[idRangeRestriction](encoded)
+ if err != nil {
+ t.Fatalf("decodeBSON[idRangeRestriction]()
error = %v", err)
}
- if !cmp.Equal(got, tt.want) {
- t.Errorf("decodeBSONMap() got = %v, want %v",
got, tt.want)
+ if diff := cmp.Diff(tt.rest, decoded); diff != "" {
+ t.Errorf("encode/decode mismatch (-want
+got):\n%s", diff)
}
})
}
}
-func Test_encodeObjectID(t *testing.T) {
+func Test_encodeDecodeObjectID(t *testing.T) {
tests := []struct {
name string
objectID primitive.ObjectID
- want []byte
}{
{
- name: "Encode object ID",
+ name: "Encode/decode object ID",
objectID: objectIDFromHex(t,
"5f1b2c3d4e5f60708090a0b0"),
- want: []byte{95, 27, 44, 61, 78, 95, 96, 112, 128,
144, 160, 176},
},
{
- name: "Encode nil object ID",
+ name: "Encode/decode nil object ID",
objectID: primitive.NilObjectID,
- want: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if got := encodeObjectID(tt.objectID); !cmp.Equal(got,
tt.want) {
- t.Errorf("encodeObjectID() = %v, want %v", got,
tt.want)
- }
- })
- }
-}
+ encoded := encodeObjectID(tt.objectID)
+ decoded := decodeObjectID(encoded)
-func Test_decodeObjectID(t *testing.T) {
- tests := []struct {
- name string
- bytes []byte
- want primitive.ObjectID
- }{
- {
- name: "Decode object ID",
- bytes: []byte{95, 27, 44, 61, 78, 95, 96, 112, 128,
144, 160, 176},
- want: objectIDFromHex(t, "5f1b2c3d4e5f60708090a0b0"),
- },
- {
- name: "Decode nil object ID",
- bytes: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
- want: primitive.NilObjectID,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := decodeObjectID(tt.bytes); !cmp.Equal(got,
tt.want) {
- t.Errorf("decodeObjectID() = %v, want %v", got,
tt.want)
+ if !cmp.Equal(decoded, tt.objectID) {
+ t.Errorf("decodeObjectID() = %v, want %v",
decoded, tt.objectID)
}
})
}
diff --git a/sdks/go/pkg/beam/io/mongodbio/common.go
b/sdks/go/pkg/beam/io/mongodbio/common.go
index 9d6ffbeaa95..e1d0657a206 100644
--- a/sdks/go/pkg/beam/io/mongodbio/common.go
+++ b/sdks/go/pkg/beam/io/mongodbio/common.go
@@ -20,6 +20,7 @@ import (
"context"
"fmt"
+ "go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
@@ -38,13 +39,16 @@ type mongoDBFn struct {
}
func (fn *mongoDBFn) Setup(ctx context.Context) error {
- client, err := newClient(ctx, fn.URI)
- if err != nil {
- return err
+ if fn.client == nil {
+ client, err := newClient(ctx, fn.URI)
+ if err != nil {
+ return err
+ }
+
+ fn.client = client
}
- fn.client = client
- fn.collection = client.Database(fn.Database).Collection(fn.Collection)
+ fn.collection =
fn.client.Database(fn.Database).Collection(fn.Collection)
return nil
}
@@ -71,3 +75,27 @@ func (fn *mongoDBFn) Teardown(ctx context.Context) error {
return nil
}
+
+type documentID struct {
+ ID any `bson:"_id"`
+}
+
+func findID(
+ ctx context.Context,
+ collection *mongo.Collection,
+ filter any,
+ order int,
+ skip int64,
+) (any, error) {
+ opts := options.FindOne().
+ SetProjection(bson.M{"_id": 1}).
+ SetSort(bson.M{"_id": order}).
+ SetSkip(skip)
+
+ var docID documentID
+ if err := collection.FindOne(ctx, filter, opts).Decode(&docID); err !=
nil {
+ return nil, err
+ }
+
+ return docID.ID, nil
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go
new file mode 100644
index 00000000000..c527631cd66
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction.go
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "reflect"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/mongo"
+)
+
+func init() {
+ beam.RegisterType(reflect.TypeOf((*idRangeRestriction)(nil)).Elem())
+ beam.RegisterType(reflect.TypeOf((*idRange)(nil)).Elem())
+}
+
+// idRangeRestriction represents a range of document IDs to read from MongoDB.
IDRange holds
+// information about the minimum and maximum IDs. CustomFilter is the custom
filter to apply when
+// reading from the collection. Count is the number of documents within the ID
range that match the
+// custom filter.
+type idRangeRestriction struct {
+ IDRange idRange
+ CustomFilter bson.M
+ Count int64
+}
+
+// newIDRangeRestriction creates a new idRangeRestriction and counts the
documents within the ID
+// range that match the custom filter.
+func newIDRangeRestriction(
+ ctx context.Context,
+ collection *mongo.Collection,
+ idRange idRange,
+ filter bson.M,
+) idRangeRestriction {
+ mergedFilter := mergeFilters(idRange.Filter(), filter)
+
+ count, err := collection.CountDocuments(ctx, mergedFilter)
+ if err != nil {
+ panic(err)
+ }
+
+ return idRangeRestriction{
+ IDRange: idRange,
+ CustomFilter: filter,
+ Count: count,
+ }
+}
+
+// Filter returns a bson.M filter based on the restriction's ID range and
custom filter.
+func (r idRangeRestriction) Filter() bson.M {
+ idFilter := r.IDRange.Filter()
+ return mergeFilters(idFilter, r.CustomFilter)
+}
+
+// mergeFilters merges the ID filter and the custom filter into a single
bson.M filter.
+func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M {
+ if len(idFilter) == 0 {
+ return customFilter
+ }
+
+ if len(customFilter) == 0 {
+ return idFilter
+ }
+
+ return bson.M{
+ "$and": []bson.M{idFilter, customFilter},
+ }
+}
+
+// SizedSplits divides the restriction into sub-restrictions based on the
desired bundle size in
+// bytes.
+func (r idRangeRestriction) SizedSplits(
+ ctx context.Context,
+ collection *mongo.Collection,
+ bundleSize int64,
+ useBucketAuto bool,
+) ([]idRangeRestriction, error) {
+ var idRanges []idRange
+ var err error
+
+ if useBucketAuto {
+ idRanges, err = bucketAutoSplits(ctx, collection, r.IDRange,
bundleSize)
+ } else {
+ idRanges, err = splitVectorSplits(ctx, collection, r.IDRange,
bundleSize)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return restrictionsFromIDRanges(ctx, collection, idRanges,
r.CustomFilter), err
+}
+
+// FractionSplits divides the restriction into a lower and higher ID
sub-restriction based on the
+// desired fraction of work the lower piece should be responsible for.
+func (r idRangeRestriction) FractionSplits(
+ ctx context.Context,
+ collection *mongo.Collection,
+ fraction float64,
+) (lower, higher idRangeRestriction, err error) {
+ skip := int64(math.Round(float64(r.Count) * fraction))
+
+ splitID, err := findID(ctx, collection, r.Filter(), 1, skip)
+ if err != nil {
+ if errors.Is(err, mongo.ErrNoDocuments) {
+ return idRangeRestriction{}, idRangeRestriction{}, nil
+ }
+
+ return idRangeRestriction{}, idRangeRestriction{}, fmt.Errorf(
+ "error finding document ID to split on: %w",
+ err,
+ )
+ }
+
+ lower = idRangeRestriction{
+ IDRange: idRange{
+ Min: r.IDRange.Min,
+ MinInclusive: r.IDRange.MinInclusive,
+ Max: splitID,
+ MaxInclusive: false,
+ },
+ CustomFilter: r.CustomFilter,
+ Count: skip,
+ }
+
+ higher = idRangeRestriction{
+ IDRange: idRange{
+ Min: splitID,
+ MinInclusive: true,
+ Max: r.IDRange.Max,
+ MaxInclusive: r.IDRange.MaxInclusive,
+ },
+ CustomFilter: r.CustomFilter,
+ Count: r.Count - skip,
+ }
+
+ return lower, higher, nil
+}
+
+// restrictionsFromIDRanges creates a slice of new restrictions based on the
ID ranges.
+func restrictionsFromIDRanges(
+ ctx context.Context,
+ collection *mongo.Collection,
+ idRanges []idRange,
+ customFilter bson.M,
+) []idRangeRestriction {
+ restrictions := make([]idRangeRestriction, len(idRanges))
+
+ for i := 0; i < len(idRanges); i++ {
+ rest := newIDRangeRestriction(
+ ctx,
+ collection,
+ idRanges[i],
+ customFilter,
+ )
+ restrictions[i] = rest
+ }
+
+ return restrictions
+}
+
+// idRange represents a range of document IDs in a MongoDB collection. It
stores information about
+// the minimum and maximum IDs, and whether they are inclusive or not.
+type idRange struct {
+ Min any
+ MinInclusive bool
+ Max any
+ MaxInclusive bool
+}
+
+// Filter creates a bson.M filter representation of the idRange.
+func (i idRange) Filter() bson.M {
+ filter := make(bson.M, 2)
+
+ if i.MinInclusive {
+ filter["$gte"] = i.Min
+ } else {
+ filter["$gt"] = i.Min
+ }
+
+ if i.MaxInclusive {
+ filter["$lte"] = i.Max
+ } else {
+ filter["$lt"] = i.Max
+ }
+
+ return bson.M{"_id": filter}
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go
new file mode 100644
index 00000000000..0534424ef05
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_restriction_test.go
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "go.mongodb.org/mongo-driver/bson"
+)
+
+func Test_mergeFilters(t *testing.T) {
+ tests := []struct {
+ name string
+ idFilter bson.M
+ filter bson.M
+ want bson.M
+ }{
+ {
+ name: "Merge ID filter and custom filter in an $and
filter",
+ idFilter: bson.M{
+ "_id": bson.M{
+ "$gte": 10,
+ },
+ },
+ filter: bson.M{
+ "key": bson.M{
+ "$ne": "value",
+ },
+ },
+ want: bson.M{
+ "$and": []bson.M{
+ {
+ "_id": bson.M{
+ "$gte": 10,
+ },
+ },
+ {
+ "key": bson.M{
+ "$ne": "value",
+ },
+ },
+ },
+ },
+ },
+ {
+ name: "Keep only ID filter when custom filter is empty",
+ idFilter: bson.M{
+ "_id": bson.M{
+ "$gte": 10,
+ },
+ },
+ filter: bson.M{},
+ want: bson.M{
+ "_id": bson.M{
+ "$gte": 10,
+ },
+ },
+ },
+ {
+ name: "Keep only custom filter when ID filter is
empty",
+ idFilter: bson.M{},
+ filter: bson.M{
+ "key": bson.M{
+ "$ne": "value",
+ },
+ },
+ want: bson.M{
+ "key": bson.M{
+ "$ne": "value",
+ },
+ },
+ },
+ {
+ name: "Empty filter when both ID filter and custom
filter are empty",
+ idFilter: bson.M{},
+ filter: bson.M{},
+ want: bson.M{},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := mergeFilters(tt.idFilter, tt.filter)
+ if diff := cmp.Diff(got, tt.want); diff != "" {
+ t.Errorf("mergeFilters() mismatch (-want +got):
%v", diff)
+ }
+ })
+ }
+}
+
+func Test_idRange_Filter(t *testing.T) {
+ tests := []struct {
+ name string
+ idRange idRange
+ want bson.M
+ }{
+ {
+ name: "ID filter with $gte when min is inclusive",
+ idRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 10,
+ MaxInclusive: false,
+ },
+ want: bson.M{
+ "_id": bson.M{
+ "$gte": 0,
+ "$lt": 10,
+ },
+ },
+ },
+ {
+ name: "ID filter with $gt when min is exclusive",
+ idRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 10,
+ MaxInclusive: false,
+ },
+ want: bson.M{
+ "_id": bson.M{
+ "$gt": 0,
+ "$lt": 10,
+ },
+ },
+ },
+ {
+ name: "ID filter with $lte when max is inclusive",
+ idRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 10,
+ MaxInclusive: true,
+ },
+ want: bson.M{
+ "_id": bson.M{
+ "$gte": 0,
+ "$lte": 10,
+ },
+ },
+ },
+ {
+ name: "ID filter with $lt when max is exclusive",
+ idRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 10,
+ MaxInclusive: false,
+ },
+ want: bson.M{
+ "_id": bson.M{
+ "$gte": 0,
+ "$lt": 10,
+ },
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.idRange.Filter()
+ if diff := cmp.Diff(tt.want, got); diff != "" {
+ t.Errorf("Filter() mismatch (-want +got): %v",
diff)
+ }
+ })
+ }
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go
new file mode 100644
index 00000000000..87a6d952866
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split.go
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/mongo"
+ "go.mongodb.org/mongo-driver/mongo/options"
+ "go.mongodb.org/mongo-driver/mongo/readpref"
+)
+
+const (
+ maxBucketCount = math.MaxInt32
+ minSplitVectorChunkSize = 1024 * 1024
+ maxSplitVectorChunkSize = 1024 * 1024 * 1024
+)
+
+func bucketAutoSplits(
+ ctx context.Context,
+ collection *mongo.Collection,
+ outerRange idRange,
+ bundleSize int64,
+) ([]idRange, error) {
+ collSize, err := getCollectionSize(ctx, collection)
+ if err != nil {
+ return nil, err
+ }
+
+ bucketCount := calculateBucketCount(collSize, bundleSize)
+
+ buckets, err := getBuckets(ctx, collection, outerRange.Filter(),
bucketCount)
+ if err != nil {
+ return nil, err
+ }
+
+ return idRangesFromBuckets(buckets, outerRange), nil
+}
+
+func getCollectionSize(ctx context.Context, collection *mongo.Collection)
(int64, error) {
+ cmd := bson.M{"collStats": collection.Name()}
+ opts := options.RunCmd().SetReadPreference(readpref.Primary())
+
+ var stats struct {
+ Size int64 `bson:"size"`
+ }
+ if err := collection.Database().RunCommand(ctx, cmd,
opts).Decode(&stats); err != nil {
+ return 0, fmt.Errorf("error executing collStats command: %w",
err)
+ }
+
+ return stats.Size, nil
+}
+
+func calculateBucketCount(totalSize int64, bundleSize int64) int32 {
+ if bundleSize < 0 {
+ panic("mongodbio.calculateBucketCount: bundle size must be
greater than 0")
+ }
+
+ count := totalSize / bundleSize
+ if totalSize%bundleSize != 0 {
+ count++
+ }
+
+ if count > int64(maxBucketCount) {
+ count = maxBucketCount
+ }
+
+ return int32(count)
+}
+
+type bucket struct {
+ ID minMax `bson:"_id"`
+}
+
+type minMax struct {
+ Min any `bson:"min"`
+ Max any `bson:"max"`
+}
+
+func getBuckets(
+ ctx context.Context,
+ collection *mongo.Collection,
+ filter bson.M,
+ count int32,
+) ([]bucket, error) {
+ pipeline := mongo.Pipeline{
+ bson.D{{
+ Key: "$match",
+ Value: filter,
+ }},
+ bson.D{{
+ Key: "$bucketAuto",
+ Value: bson.M{
+ "groupBy": "$_id",
+ "buckets": count,
+ },
+ }},
+ }
+
+ opts := options.Aggregate().SetAllowDiskUse(true)
+
+ cursor, err := collection.Aggregate(ctx, pipeline, opts)
+ if err != nil {
+ return nil, fmt.Errorf("error executing bucketAuto aggregation:
%w", err)
+ }
+
+ var buckets []bucket
+ if err := cursor.All(ctx, &buckets); err != nil {
+ return nil, fmt.Errorf("error decoding buckets: %w", err)
+ }
+
+ return buckets, nil
+}
+
+func idRangesFromBuckets(buckets []bucket, outerRange idRange) []idRange {
+ if len(buckets) == 0 {
+ return nil
+ }
+
+ ranges := make([]idRange, len(buckets))
+
+ for i := 0; i < len(buckets); i++ {
+ subRange := idRange{}
+
+ if i == 0 {
+ subRange.MinInclusive = outerRange.MinInclusive
+ subRange.Min = outerRange.Min
+ } else {
+ subRange.Min = buckets[i].ID.Min
+ subRange.MinInclusive = true
+ }
+
+ if i == len(buckets)-1 {
+ subRange.Max = outerRange.Max
+ subRange.MaxInclusive = outerRange.MaxInclusive
+ } else {
+ subRange.Max = buckets[i].ID.Max
+ subRange.MaxInclusive = false
+ }
+
+ ranges[i] = subRange
+ }
+
+ return ranges
+}
+
+func splitVectorSplits(
+ ctx context.Context,
+ collection *mongo.Collection,
+ outerRange idRange,
+ bundleSize int64,
+) ([]idRange, error) {
+ chunkSize := getChunkSize(bundleSize)
+
+ splitKeys, err := getSplitKeys(ctx, collection, outerRange, chunkSize)
+ if err != nil {
+ return nil, err
+ }
+
+ return idRangesFromSplits(splitKeys, outerRange), nil
+}
+
+func getChunkSize(bundleSize int64) int64 {
+ var chunkSize int64
+
+ if bundleSize < minSplitVectorChunkSize {
+ chunkSize = minSplitVectorChunkSize
+ } else if bundleSize > maxSplitVectorChunkSize {
+ chunkSize = maxSplitVectorChunkSize
+ } else {
+ chunkSize = bundleSize
+ }
+
+ return chunkSize
+}
+
+func getSplitKeys(
+ ctx context.Context,
+ collection *mongo.Collection,
+ outerRange idRange,
+ maxChunkSizeBytes int64,
+) ([]documentID, error) {
+ database := collection.Database()
+ namespace := fmt.Sprintf("%s.%s", database.Name(), collection.Name())
+
+ cmd := bson.D{
+ {Key: "splitVector", Value: namespace},
+ {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}},
+ {Key: "min", Value: bson.D{{Key: "_id", Value:
outerRange.Min}}},
+ {Key: "max", Value: bson.D{{Key: "_id", Value:
outerRange.Max}}},
+ {Key: "maxChunkSizeBytes", Value: maxChunkSizeBytes},
+ }
+
+ opts := options.RunCmd().SetReadPreference(readpref.Primary())
+
+ var result struct {
+ SplitKeys []documentID `bson:"splitKeys"`
+ }
+ if err := database.RunCommand(ctx, cmd, opts).Decode(&result); err !=
nil {
+ return nil, fmt.Errorf("error executing splitVector command:
%w", err)
+ }
+
+ return result.SplitKeys, nil
+}
+
+func idRangesFromSplits(splitKeys []documentID, outerRange idRange) []idRange {
+ subRanges := make([]idRange, len(splitKeys)+1)
+
+ for i := 0; i < len(splitKeys)+1; i++ {
+ subRange := idRange{}
+
+ if i == 0 {
+ subRange.Min = outerRange.Min
+ subRange.MinInclusive = outerRange.MinInclusive
+ } else {
+ subRange.Min = splitKeys[i-1].ID
+ subRange.MinInclusive = true
+ }
+
+ if i == len(splitKeys) {
+ subRange.Max = outerRange.Max
+ subRange.MaxInclusive = outerRange.MaxInclusive
+ } else {
+ subRange.Max = splitKeys[i].ID
+ subRange.MaxInclusive = false
+ }
+
+ subRanges[i] = subRange
+ }
+
+ return subRanges
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go
new file mode 100644
index 00000000000..d3e867f6be2
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_split_test.go
@@ -0,0 +1,275 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "math"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func Test_calculateBucketCount(t *testing.T) {
+ tests := []struct {
+ name string
+ totalSize int64
+ bundleSize int64
+ want int32
+ }{
+ {
+ name: "Return ceiling of total size / bundle
size",
+ totalSize: 3 * 1024 * 1024,
+ bundleSize: 2 * 1024 * 1024,
+ want: 2,
+ },
+ {
+ name: "Return max int32 when calculated count is
greater than max int32",
+ totalSize: 1024 * 1024 * 1024 * 1024,
+ bundleSize: 1,
+ want: math.MaxInt32,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := calculateBucketCount(tt.totalSize,
tt.bundleSize); got != tt.want {
+ t.Errorf("calculateBucketCount() = %v, want
%v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_calculateBucketCountPanic(t *testing.T) {
+ t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T)
{
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("calculateBucketCount() does not
panic")
+ }
+ }()
+
+ calculateBucketCount(1024, 0)
+ })
+}
+
+func Test_idRangesFromBuckets(t *testing.T) {
+ tests := []struct {
+ name string
+ buckets []bucket
+ outerRange idRange
+ want []idRange
+ }{
+ {
+ name: "ID ranges with first element having min ID
configuration from outer range, and last element " +
+ "having max ID configuration from outer range",
+ buckets: []bucket{
+ {
+ ID: minMax{
+ Min: 5,
+ Max: 100,
+ },
+ },
+ {
+ ID: minMax{
+ Min: 100,
+ Max: 200,
+ },
+ },
+ {
+ ID: minMax{
+ Min: 200,
+ Max: 295,
+ },
+ },
+ },
+ outerRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 300,
+ MaxInclusive: false,
+ },
+ want: []idRange{
+ {
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ {
+ Min: 100,
+ MinInclusive: true,
+ Max: 200,
+ MaxInclusive: false,
+ },
+ {
+ Min: 200,
+ MinInclusive: true,
+ Max: 300,
+ MaxInclusive: false,
+ },
+ },
+ },
+ {
+ name: "ID ranges with one element having the same
configuration as outer range when there is one " +
+ "element in buckets",
+ buckets: []bucket{
+ {
+ ID: minMax{
+ Min: 5,
+ Max: 95,
+ },
+ },
+ },
+ outerRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ want: []idRange{
+ {
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ },
+ },
+ {
+ name: "Empty ID ranges when there are no elements in
buckets",
+ buckets: nil,
+ outerRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ want: nil,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := idRangesFromBuckets(tt.buckets, tt.outerRange)
+ if diff := cmp.Diff(got, tt.want); diff != "" {
+ t.Errorf("idRangesFromBuckets() mismatch (-want
+got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func Test_getChunkSize(t *testing.T) {
+ tests := []struct {
+ name string
+ bundleSize int64
+ want int64
+ }{
+ {
+ name: "Return 1 MB if bundle size is less than 1
MB",
+ bundleSize: 1024,
+ want: 1024 * 1024,
+ },
+ {
+ name: "Return 1 GB if bundle size is greater than
1 GB",
+ bundleSize: 2 * 1024 * 1024 * 1024,
+ want: 1024 * 1024 * 1024,
+ },
+ {
+ name: "Return bundle size if bundle size is
between 1 MB and 1 GB",
+ bundleSize: 4 * 1024 * 1024,
+ want: 4 * 1024 * 1024,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := getChunkSize(tt.bundleSize); got != tt.want {
+ t.Errorf("getChunkSize() = %v, want %v", got,
tt.want)
+ }
+ })
+ }
+}
+
+func Test_idRangesFromSplits(t *testing.T) {
+ tests := []struct {
+ name string
+ splitKeys []documentID
+ outerRange idRange
+ want []idRange
+ }{
+ {
+ name: "ID ranges with first element having min ID
configuration from outer range, and last element " +
+ "having max ID configuration from outer range",
+ splitKeys: []documentID{
+ {
+ ID: 100,
+ },
+ {
+ ID: 200,
+ },
+ },
+ outerRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 300,
+ MaxInclusive: false,
+ },
+ want: []idRange{
+ {
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ {
+ Min: 100,
+ MinInclusive: true,
+ Max: 200,
+ MaxInclusive: false,
+ },
+ {
+ Min: 200,
+ MinInclusive: true,
+ Max: 300,
+ MaxInclusive: false,
+ },
+ },
+ },
+ {
+ name: "ID ranges with one element having the same
configuration as outer range when there are no " +
+ "elements in key splits",
+ splitKeys: nil,
+ outerRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: true,
+ },
+ want: []idRange{
+ {
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: true,
+ },
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := idRangesFromSplits(tt.splitKeys, tt.outerRange)
+ if diff := cmp.Diff(got, tt.want); diff != "" {
+ t.Errorf("idRangesFromSplits() mismatch (-want
+got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go
new file mode 100644
index 00000000000..6b92ab57d49
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker.go
@@ -0,0 +1,194 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "reflect"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "go.mongodb.org/mongo-driver/mongo"
+)
+
+func init() {
+ beam.RegisterType(reflect.TypeOf((*idRangeTracker)(nil)))
+}
+
+// idRangeTracker is a tracker of an idRangeRestriction.
+type idRangeTracker struct {
+ rest idRangeRestriction
+ collection *mongo.Collection
+ claimed int64
+ claimedID any
+ stopped bool
+ err error
+}
+
+// newIDRangeTracker creates a new idRangeTracker tracking the provided
idRangeRestriction.
+func newIDRangeTracker(rest idRangeRestriction, collection *mongo.Collection)
*idRangeTracker {
+ return &idRangeTracker{
+ rest: rest,
+ collection: collection,
+ }
+}
+
+// cursorResult holds information about the next document to process from
MongoDB. nextID is the ID
+// of the document. isExhausted is whether the cursor has been exhausted.
+type cursorResult struct {
+ nextID any
+ isExhausted bool
+}
+
+// TryClaim accepts a position representing a cursorResult of a document to
read from MongoDB. The
+// position is successfully claimed if the tracker has not yet completed the
work within its
+// restriction and the cursor has not been exhausted.
+func (rt *idRangeTracker) TryClaim(pos any) (ok bool) {
+ result, ok := pos.(cursorResult)
+ if !ok {
+ rt.err = fmt.Errorf("invalid pos type: %T", pos)
+ return false
+ }
+
+ if rt.IsDone() {
+ return false
+ }
+
+ if result.isExhausted {
+ rt.stopped = true
+ return false
+ }
+
+ rt.claimed++
+ rt.claimedID = result.nextID
+
+ return true
+}
+
+// GetError returns the error associated with the tracker, if any.
+func (rt *idRangeTracker) GetError() error {
+ return rt.err
+}
+
+// TrySplit splits the underlying restriction into a primary and residual
restriction based on the
+// fraction of remaining work the primary should be responsible for. The
restriction may be modified
+// as a result of the split. The primary is a copy of the tracker's
restriction after the split.
+// If the fraction is 1 or all work has already been claimed, returns the full
restriction as the
+// primary and nil as the residual. If the fraction is 0, stops the tracker,
cuts off any remaining
+// work from its underlying restriction, and returns a residual representing
all remaining work.
+// If the fraction is between 0 and 1, attempts to split the remaining work of
the underlying
+// restriction into two sub-restrictions based on the fraction and assigns
them to the primary and
+// residual respectively. Returns an error if the split cannot be performed.
+func (rt *idRangeTracker) TrySplit(fraction float64) (primary, residual any,
err error) {
+ if fraction < 0 || fraction > 1 {
+ return nil, nil, errors.New("fraction must be between 0 and 1")
+ }
+
+ done, remaining := rt.cutRestriction()
+
+ if fraction == 1 || remaining.Count == 0 {
+ return rt.rest, nil, nil
+ }
+
+ if fraction == 0 {
+ rt.rest = done
+ return rt.rest, remaining, nil
+ }
+
+ ctx := context.Background()
+
+ primaryRem, resid, err := remaining.FractionSplits(ctx, rt.collection,
fraction)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if resid.Count == 0 {
+ return rt.rest, nil, nil
+ }
+
+ if primaryRem.Count == 0 {
+ rt.rest = done
+ return rt.rest, remaining, nil
+ }
+
+ rt.rest.IDRange.Max = primaryRem.IDRange.Max
+ rt.rest.IDRange.MaxInclusive = primaryRem.IDRange.MaxInclusive
+ rt.rest.Count -= resid.Count
+
+ return rt.rest, resid, nil
+}
+
+// cutRestriction returns two restrictions: done represents the amount of work
from the underlying
+// restriction that has already been completed, and remaining represents the
amount that remains to
+// be processed. Does not modify the underlying restriction.
+func (rt *idRangeTracker) cutRestriction() (done idRangeRestriction, remaining
idRangeRestriction) {
+ minRem := rt.claimedID
+ minInclusiveRem := false
+ maxInclusiveDone := true
+
+ if minRem == nil {
+ minRem = rt.rest.IDRange.Min
+ minInclusiveRem = rt.rest.IDRange.MinInclusive
+ maxInclusiveDone = false
+ }
+
+ done = idRangeRestriction{
+ IDRange: idRange{
+ Min: rt.rest.IDRange.Min,
+ MinInclusive: rt.rest.IDRange.MinInclusive,
+ Max: minRem,
+ MaxInclusive: maxInclusiveDone,
+ },
+ CustomFilter: rt.rest.CustomFilter,
+ Count: rt.claimed,
+ }
+
+ remaining = idRangeRestriction{
+ IDRange: idRange{
+ Min: minRem,
+ MinInclusive: minInclusiveRem,
+ Max: rt.rest.IDRange.Max,
+ MaxInclusive: rt.rest.IDRange.MaxInclusive,
+ },
+ CustomFilter: rt.rest.CustomFilter,
+ Count: rt.rest.Count - rt.claimed,
+ }
+
+ return done, remaining
+}
+
+// GetProgress returns the amount of done and remaining work, represented by
the count of documents.
+func (rt *idRangeTracker) GetProgress() (done float64, remaining float64) {
+ done = float64(rt.claimed)
+ remaining = float64(rt.rest.Count - rt.claimed)
+ return
+}
+
+// IsDone returns true if all work within the tracker's restriction has been
completed.
+func (rt *idRangeTracker) IsDone() bool {
+ return rt.stopped || rt.claimed == rt.rest.Count
+}
+
+// GetRestriction returns a copy of the restriction the tracker is tracking.
+func (rt *idRangeTracker) GetRestriction() any {
+ return rt.rest
+}
+
+// IsBounded returns whether the tracker is tracking a restriction with a
finite amount of work.
+func (*idRangeTracker) IsBounded() bool {
+ return true
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
new file mode 100644
index 00000000000..a4484ee1007
--- /dev/null
+++ b/sdks/go/pkg/beam/io/mongodbio/id_range_tracker_test.go
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mongodbio
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "go.mongodb.org/mongo-driver/bson"
+)
+
+func Test_idRangeTracker_TryClaim(t *testing.T) {
+ tests := []struct {
+ name string
+ tracker *idRangeTracker
+ pos any
+ wantOk bool
+ wantClaimed int64
+ wantClaimedID any
+ wantDone bool
+ wantErr bool
+ }{
+ {
+ name: "Return true when claimed count < total count -
1",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 5,
+ claimedID: 123,
+ },
+ pos: cursorResult{nextID: 124},
+ wantOk: true,
+ wantClaimed: 6,
+ wantClaimedID: 124,
+ wantDone: false,
+ },
+ {
+ name: "Return true and set to done when claimed count
== total count - 1",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 9,
+ claimedID: 123,
+ },
+ pos: cursorResult{nextID: 124},
+ wantOk: true,
+ wantClaimed: 10,
+ wantClaimedID: 124,
+ wantDone: true,
+ },
+ {
+ name: "Return false and set to done when cursor is
exhausted",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 5,
+ claimedID: 123,
+ },
+ pos: cursorResult{nextID: 124, isExhausted:
true},
+ wantOk: false,
+ wantClaimed: 5,
+ wantClaimedID: 123,
+ wantDone: true,
+ },
+ {
+ name: "Return false when claimed count == total count",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 10,
+ claimedID: 123,
+ },
+ pos: cursorResult{nextID: 124},
+ wantOk: false,
+ wantClaimed: 10,
+ wantClaimedID: 123,
+ wantDone: true,
+ },
+ {
+ name: "Return false when tracker is stopped",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 10,
+ claimedID: 123,
+ stopped: true,
+ },
+ pos: cursorResult{nextID: 124},
+ wantOk: false,
+ wantClaimed: 10,
+ wantClaimedID: 123,
+ wantDone: true,
+ },
+ {
+ name: "Return false and set error when pos is of
invalid type",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 10,
+ },
+ claimed: 5,
+ claimedID: 123,
+ },
+ pos: "invalid",
+ wantOk: false,
+ wantClaimed: 5,
+ wantClaimedID: 123,
+ wantDone: false,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if gotOk := tt.tracker.TryClaim(tt.pos); gotOk !=
tt.wantOk {
+ t.Errorf("TryClaim() = %v, want %v", gotOk,
tt.wantOk)
+ }
+ if gotClaimed := tt.tracker.claimed; gotClaimed !=
tt.wantClaimed {
+ t.Errorf("claimed = %v, want %v", gotClaimed,
tt.wantClaimed)
+ }
+ if gotClaimedID := tt.tracker.claimedID;
!cmp.Equal(gotClaimedID, tt.wantClaimedID) {
+ t.Errorf("claimedID = %v, want %v",
gotClaimedID, tt.wantClaimedID)
+ }
+ if gotDone := tt.tracker.IsDone(); gotDone !=
tt.wantDone {
+ t.Errorf("IsDone() = %v, want %v", gotDone,
tt.wantDone)
+ }
+ if gotErr := tt.tracker.GetError(); (gotErr != nil) !=
tt.wantErr {
+ t.Errorf("GetError() error = %v, wantErr %v",
gotErr, tt.wantErr)
+ }
+ })
+ }
+}
+
+func Test_idRangeTracker_TrySplit(t *testing.T) {
+ tests := []struct {
+ name string
+ tracker *idRangeTracker
+ fraction float64
+ wantPrimary any
+ wantResidual any
+ wantErr bool
+ wantDone bool
+ }{
+ {
+ name: "Primary contains no more work and residual
contains all remaining work when fraction is 0",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ claimed: 70,
+ claimedID: 69,
+ },
+ fraction: 0,
+ wantPrimary: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 69,
+ MaxInclusive: true,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 70,
+ },
+ wantResidual: idRangeRestriction{
+ IDRange: idRange{
+ Min: 69,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 30,
+ },
+ wantDone: true,
+ },
+ {
+ name: "Primary contains all original work and residual
is nil when fraction is 1",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ claimed: 70,
+ claimedID: 69,
+ },
+ fraction: 1,
+ wantPrimary: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ wantResidual: nil,
+ wantDone: false,
+ },
+ {
+ name: "Primary contains all original work and residual
is nil when the total count has been claimed",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ claimed: 100,
+ claimedID: 99,
+ },
+ fraction: 1,
+ wantPrimary: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ wantResidual: nil,
+ wantDone: true,
+ },
+ {
+ name: "Error - fraction is less than 0",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{Count: 100},
+ },
+ fraction: -0.1,
+ wantErr: true,
+ },
+ {
+ name: "Error - fraction is greater than 1",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{Count: 100},
+ },
+ fraction: 1.1,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotPrimary, gotResidual, err :=
tt.tracker.TrySplit(tt.fraction)
+ if (err != nil) != tt.wantErr {
+ t.Fatalf("TrySplit() error = %v, wantErr %v",
err, tt.wantErr)
+ }
+ if diff := cmp.Diff(gotPrimary, tt.wantPrimary); diff
!= "" {
+ t.Errorf("TrySplit() gotPrimary mismatch (-want
+got):\n%s", diff)
+ }
+ if diff := cmp.Diff(gotResidual, tt.wantResidual); diff
!= "" {
+ t.Errorf("TrySplit() gotResidual mismatch
(-want +got):\n%s", diff)
+ }
+ if tt.tracker.IsDone() != tt.wantDone {
+ t.Errorf("IsDone() = %v, want %v",
tt.tracker.IsDone(), tt.wantDone)
+ }
+ })
+ }
+}
+
+func Test_idRangeTracker_cutRestriction(t *testing.T) {
+ tests := []struct {
+ name string
+ tracker *idRangeTracker
+ wantDone idRangeRestriction
+ wantRemaining idRangeRestriction
+ }{
+ {
+ name: "The tracker's claimedID is used as the max
(inclusive) in the done restriction " +
+ "and as the min (exclusive) in the remaining
restriction when claimedID is not nil",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ claimed: 70,
+ claimedID: 69,
+ },
+ wantDone: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: true,
+ Max: 69,
+ MaxInclusive: true,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 70,
+ },
+ wantRemaining: idRangeRestriction{
+ IDRange: idRange{
+ Min: 69,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 30,
+ },
+ },
+ {
+ name: "The tracker's restriction's min ID is used as
the max (exclusive) in the done restriction " +
+ "and as the min in the remaining restriction
when claimedID is nil",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ claimed: 0,
+ claimedID: nil,
+ },
+ wantDone: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 0,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 0,
+ },
+ wantRemaining: idRangeRestriction{
+ IDRange: idRange{
+ Min: 0,
+ MinInclusive: false,
+ Max: 100,
+ MaxInclusive: false,
+ },
+ CustomFilter: bson.M{"key": "val"},
+ Count: 100,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotDone, gotRemaining := tt.tracker.cutRestriction()
+ if diff := cmp.Diff(gotDone, tt.wantDone); diff != "" {
+ t.Errorf("cutRestriction() gotDone mismatch
(-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(gotRemaining, tt.wantRemaining);
diff != "" {
+ t.Errorf("cutRestriction() gotRemaining
mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func Test_idRangeTracker_GetProgress(t *testing.T) {
+ tracker := &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 100,
+ },
+ claimed: 30,
+ }
+ wantDone := float64(30)
+ wantRemaining := float64(70)
+
+ t.Run(
+ "Done is represented by claimed count, and remaining by total
count - claimed count",
+ func(t *testing.T) {
+ gotDone, gotRemaining := tracker.GetProgress()
+ if gotDone != wantDone {
+ t.Errorf("GetProgress() gotDone = %v, want %v",
gotDone, wantDone)
+ }
+ if gotRemaining != wantRemaining {
+ t.Errorf("GetProgress() gotRemaining = %v, want
%v", gotRemaining, wantRemaining)
+ }
+ },
+ )
+}
+
+func Test_idRangeTracker_IsDone(t *testing.T) {
+ tests := []struct {
+ name string
+ tracker *idRangeTracker
+ want bool
+ }{
+ {
+ name: "True when the tracker's claimed count is equal
to the total count",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 100,
+ },
+ claimed: 100,
+ },
+ want: true,
+ },
+ {
+ name: "True when the tracker is stopped",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 100,
+ },
+ claimed: 95,
+ stopped: true,
+ },
+ want: true,
+ },
+ {
+ name: "False when the tracker is not stopped and its
claimed count is less than the total count",
+ tracker: &idRangeTracker{
+ rest: idRangeRestriction{
+ Count: 100,
+ },
+ claimed: 95,
+ },
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.tracker.IsDone(); got != tt.want {
+ t.Errorf("IsDone() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/sdks/go/pkg/beam/io/mongodbio/read.go
b/sdks/go/pkg/beam/io/mongodbio/read.go
index b12a6d7738b..59d8cf6aef9 100644
--- a/sdks/go/pkg/beam/io/mongodbio/read.go
+++ b/sdks/go/pkg/beam/io/mongodbio/read.go
@@ -17,35 +17,28 @@ package mongodbio
import (
"context"
+ "errors"
"fmt"
- "math"
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
"github.com/apache/beam/sdks/v2/go/pkg/beam/util/structx"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
- "go.mongodb.org/mongo-driver/mongo/readpref"
)
const (
defaultReadBundleSize = 64 * 1024 * 1024
-
- minSplitVectorChunkSize = 1024 * 1024
- maxSplitVectorChunkSize = 1024 * 1024 * 1024
-
- maxBucketCount = math.MaxInt32
)
func init() {
- register.DoFn3x1[context.Context, []byte, func(bson.M),
error](&bucketAutoFn{})
- register.DoFn3x1[context.Context, []byte, func(bson.M),
error](&splitVectorFn{})
- register.Emitter1[bson.M]()
-
- register.DoFn3x1[context.Context, bson.M, func(beam.Y),
error](&readFn{})
+ register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte,
func(beam.Y), error](
+ &readFn{},
+ )
register.Emitter1[beam.Y]()
}
@@ -92,338 +85,169 @@ func Read(
imp := beam.Impulse(s)
- var bundled beam.PCollection
-
- if option.BucketAuto {
- bundled = beam.ParDo(s, newBucketAutoFn(uri, database,
collection, option), imp)
- } else {
- bundled = beam.ParDo(s, newSplitVectorFn(uri, database,
collection, option), imp)
- }
-
return beam.ParDo(
s,
newReadFn(uri, database, collection, t, option),
- bundled,
+ imp,
beam.TypeDefinition{Var: beam.YType, T: t},
)
}
-type bucketAutoFn struct {
+type readFn struct {
mongoDBFn
+ BucketAuto bool
BundleSize int64
+ Filter []byte
+ Type beam.EncodedType
+ filter bson.M
+ projection bson.D
}
-func newBucketAutoFn(
+func newReadFn(
uri string,
database string,
collection string,
+ t reflect.Type,
option *ReadOption,
-) *bucketAutoFn {
- return &bucketAutoFn{
+) *readFn {
+ filter, err := encodeBSON[bson.M](option.Filter)
+ if err != nil {
+ panic(fmt.Sprintf("mongodbio.newReadFn: %v", err))
+ }
+
+ return &readFn{
mongoDBFn: mongoDBFn{
URI: uri,
Database: database,
Collection: collection,
},
+ BucketAuto: option.BucketAuto,
BundleSize: option.BundleSize,
+ Filter: filter,
+ Type: beam.EncodedType{T: t},
}
}
-func (fn *bucketAutoFn) ProcessElement(
- ctx context.Context,
- _ []byte,
- emit func(bson.M),
-) error {
- collectionSize, err := fn.getCollectionSize(ctx)
- if err != nil {
+func (fn *readFn) Setup(ctx context.Context) error {
+ var err error
+ if err = fn.mongoDBFn.Setup(ctx); err != nil {
return err
}
- if collectionSize == 0 {
- return nil
- }
-
- bucketCount := calculateBucketCount(collectionSize, fn.BundleSize)
-
- buckets, err := fn.getBuckets(ctx, bucketCount)
+ fn.filter, err = decodeBSON[bson.M](fn.Filter)
if err != nil {
return err
}
- idFilters := idFiltersFromBuckets(buckets)
-
- for _, filter := range idFilters {
- emit(filter)
- }
+ fn.projection = inferProjection(fn.Type.T, bsonTag)
return nil
}
-type collStats struct {
- Size int64 `bson:"size"`
-}
-
-func (fn *bucketAutoFn) getCollectionSize(ctx context.Context) (int64, error) {
- cmd := bson.M{"collStats": fn.Collection}
- opts := options.RunCmd().SetReadPreference(readpref.Primary())
-
- var stats collStats
- if err := fn.collection.Database().RunCommand(ctx, cmd,
opts).Decode(&stats); err != nil {
- return 0, fmt.Errorf("error executing collStats command: %w",
err)
- }
-
- return stats.Size, nil
-}
-
-func calculateBucketCount(collectionSize int64, bundleSize int64) int32 {
- if bundleSize < 0 {
- panic("monogdbio.calculateBucketCount: bundle size must be
greater than 0")
+func inferProjection(t reflect.Type, tagKey string) bson.D {
+ names := structx.InferFieldNames(t, tagKey)
+ if len(names) == 0 {
+ panic("mongodbio.inferProjection: no names to infer projection
from")
}
- count := collectionSize / bundleSize
- if collectionSize%bundleSize != 0 {
- count++
- }
+ projection := make(bson.D, len(names))
- if count > int64(maxBucketCount) {
- count = maxBucketCount
+ for i, name := range names {
+ projection[i] = bson.E{Key: name, Value: 1}
}
- return int32(count)
-}
-
-type bucket struct {
- ID minMax `bson:"_id"`
-}
-
-type minMax struct {
- Min any `bson:"min"`
- Max any `bson:"max"`
+ return projection
}
-func (fn *bucketAutoFn) getBuckets(ctx context.Context, count int32)
([]bucket, error) {
- pipeline := mongo.Pipeline{bson.D{{
- Key: "$bucketAuto",
- Value: bson.M{
- "groupBy": "$_id",
- "buckets": count,
- },
- }}}
-
- opts := options.Aggregate().SetAllowDiskUse(true)
-
- cursor, err := fn.collection.Aggregate(ctx, pipeline, opts)
- if err != nil {
- return nil, fmt.Errorf("error executing bucketAuto aggregation:
%w", err)
+func (fn *readFn) CreateInitialRestriction(_ []byte) idRangeRestriction {
+ ctx := context.Background()
+ if err := fn.Setup(ctx); err != nil {
+ panic(err)
}
- var buckets []bucket
- if err = cursor.All(ctx, &buckets); err != nil {
- return nil, fmt.Errorf("error decoding buckets: %w", err)
- }
-
- return buckets, nil
-}
-
-func idFiltersFromBuckets(buckets []bucket) []bson.M {
- idFilters := make([]bson.M, len(buckets))
-
- for i := 0; i < len(buckets); i++ {
- filter := bson.M{}
-
- if i != 0 {
- filter["$gt"] = buckets[i].ID.Min
- }
-
- if i != len(buckets)-1 {
- filter["$lte"] = buckets[i].ID.Max
+ outerRange, err := findOuterIDRange(ctx, fn.collection, fn.filter)
+ if err != nil {
+ if errors.Is(err, mongo.ErrNoDocuments) {
+ log.Infof(
+ ctx,
+ "No documents in collection %s.%s match the
provided filter",
+ fn.Database,
+ fn.Collection,
+ )
+ return idRangeRestriction{}
}
- if len(filter) == 0 {
- idFilters[i] = filter
- } else {
- idFilters[i] = bson.M{"_id": filter}
- }
+ panic(err)
}
- return idFilters
-}
-
-type splitVectorFn struct {
- mongoDBFn
- BundleSize int64
-}
-
-func newSplitVectorFn(
- uri string,
- database string,
- collection string,
- option *ReadOption,
-) *splitVectorFn {
- return &splitVectorFn{
- mongoDBFn: mongoDBFn{
- URI: uri,
- Database: database,
- Collection: collection,
- },
- BundleSize: option.BundleSize,
- }
+ return newIDRangeRestriction(
+ ctx,
+ fn.collection,
+ outerRange,
+ fn.filter,
+ )
}
-func (fn *splitVectorFn) ProcessElement(
+func findOuterIDRange(
ctx context.Context,
- _ []byte,
- emit func(bson.M),
-) error {
- chunkSize := getChunkSize(fn.BundleSize)
-
- splitKeys, err := fn.getSplitKeys(ctx, chunkSize)
+ collection *mongo.Collection,
+ filter bson.M,
+) (idRange, error) {
+ minID, err := findID(ctx, collection, filter, 1, 0)
if err != nil {
- return err
+ return idRange{}, err
}
- idFilters := idFiltersFromSplits(splitKeys)
-
- for _, filter := range idFilters {
- emit(filter)
+ maxID, err := findID(ctx, collection, filter, -1, 0)
+ if err != nil {
+ return idRange{}, err
}
- return nil
-}
-
-func getChunkSize(bundleSize int64) int64 {
- var chunkSize int64
-
- if bundleSize < minSplitVectorChunkSize {
- chunkSize = minSplitVectorChunkSize
- } else if bundleSize > maxSplitVectorChunkSize {
- chunkSize = maxSplitVectorChunkSize
- } else {
- chunkSize = bundleSize
+ outerRange := idRange{
+ Min: minID,
+ MinInclusive: true,
+ Max: maxID,
+ MaxInclusive: true,
}
- return chunkSize
-}
-
-type splitVector struct {
- SplitKeys []splitKey `bson:"splitKeys"`
-}
-
-type splitKey struct {
- ID any `bson:"_id"`
+ return outerRange, nil
}
-func (fn *splitVectorFn) getSplitKeys(ctx context.Context, chunkSize int64)
([]splitKey, error) {
- cmd := bson.D{
- {Key: "splitVector", Value: fmt.Sprintf("%s.%s", fn.Database,
fn.Collection)},
- {Key: "keyPattern", Value: bson.D{{Key: "_id", Value: 1}}},
- {Key: "maxChunkSizeBytes", Value: chunkSize},
+func (fn *readFn) SplitRestriction(_ []byte, rest idRangeRestriction)
[]idRangeRestriction {
+ if rest.Count == 0 {
+ return []idRangeRestriction{rest}
}
- opts := options.RunCmd().SetReadPreference(readpref.Primary())
-
- var vector splitVector
- if err := fn.collection.Database().RunCommand(ctx, cmd,
opts).Decode(&vector); err != nil {
- return nil, fmt.Errorf("error executing splitVector command:
%w", err)
+ ctx := context.Background()
+ if err := fn.Setup(ctx); err != nil {
+ panic(err)
}
- return vector.SplitKeys, nil
-}
-
-func idFiltersFromSplits(splitKeys []splitKey) []bson.M {
- idFilters := make([]bson.M, len(splitKeys)+1)
-
- for i := 0; i < len(splitKeys)+1; i++ {
- filter := bson.M{}
-
- if i > 0 {
- filter["$gt"] = splitKeys[i-1].ID
- }
-
- if i < len(splitKeys) {
- filter["$lte"] = splitKeys[i].ID
- }
-
- if len(filter) == 0 {
- idFilters[i] = filter
- } else {
- idFilters[i] = bson.M{"_id": filter}
- }
- }
-
- return idFilters
-}
-
-type readFn struct {
- mongoDBFn
- Filter []byte
- Type beam.EncodedType
- projection bson.D
- filter bson.M
-}
-
-func newReadFn(
- uri string,
- database string,
- collection string,
- t reflect.Type,
- option *ReadOption,
-) *readFn {
- filter, err := encodeBSONMap(option.Filter)
+ splits, err := rest.SizedSplits(ctx, fn.collection, fn.BundleSize,
fn.BucketAuto)
if err != nil {
- panic(fmt.Sprintf("mongodbio.newReadFn: %v", err))
+ panic(err)
}
- return &readFn{
- mongoDBFn: mongoDBFn{
- URI: uri,
- Database: database,
- Collection: collection,
- },
- Filter: filter,
- Type: beam.EncodedType{T: t},
- }
+ return splits
}
-func (fn *readFn) Setup(ctx context.Context) error {
- if err := fn.mongoDBFn.Setup(ctx); err != nil {
- return err
- }
-
- filter, err := decodeBSONMap(fn.Filter)
- if err != nil {
- return err
- }
-
- fn.filter = filter
- fn.projection = inferProjection(fn.Type.T, bsonTag)
-
- return nil
+func (fn *readFn) CreateTracker(rest idRangeRestriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(newIDRangeTracker(rest, fn.collection))
}
-func inferProjection(t reflect.Type, tagKey string) bson.D {
- names := structx.InferFieldNames(t, tagKey)
- if len(names) == 0 {
- panic("mongodbio.inferProjection: no names to infer projection
from")
- }
-
- projection := make(bson.D, len(names))
-
- for i, name := range names {
- projection[i] = bson.E{Key: name, Value: 1}
- }
-
- return projection
+func (fn *readFn) RestrictionSize(_ []byte, rest idRangeRestriction) float64 {
+ return float64(rest.Count)
}
func (fn *readFn) ProcessElement(
ctx context.Context,
- elem bson.M,
+ rt *sdf.LockRTracker,
+ _ []byte,
emit func(beam.Y),
) (err error) {
- mergedFilter := mergeFilters(elem, fn.filter)
+ rest := rt.GetRestriction().(idRangeRestriction)
- cursor, err := fn.findDocuments(ctx, fn.projection, mergedFilter)
+ cursor, err := fn.getCursor(ctx, rest.Filter())
if err != nil {
return err
}
@@ -442,53 +266,53 @@ func (fn *readFn) ProcessElement(
}()
for cursor.Next(ctx) {
- value, err := decodeDocument(cursor, fn.Type.T)
+ id, value, err := decodeDocument(cursor, fn.Type.T)
if err != nil {
return err
}
- emit(value)
- }
-
- return cursor.Err()
-}
+ result := cursorResult{nextID: id}
+ if !rt.TryClaim(result) {
+ return cursor.Err()
+ }
-func mergeFilters(idFilter bson.M, customFilter bson.M) bson.M {
- if len(idFilter) == 0 {
- return customFilter
+ emit(value)
}
- if len(customFilter) == 0 {
- return idFilter
- }
+ result := cursorResult{isExhausted: true}
+ rt.TryClaim(result)
- return bson.M{
- "$and": []bson.M{idFilter, customFilter},
- }
+ return cursor.Err()
}
-func (fn *readFn) findDocuments(
+func (fn *readFn) getCursor(
ctx context.Context,
- projection bson.D,
filter bson.M,
) (*mongo.Cursor, error) {
- opts := options.Find().SetProjection(projection)
+ opts := options.Find().
+ SetProjection(fn.projection).
+ SetSort(bson.M{"_id": 1})
cursor, err := fn.collection.Find(ctx, filter, opts)
if err != nil {
- return nil, fmt.Errorf("error finding documents: %w", err)
+ return nil, fmt.Errorf("error executing find command: %w", err)
}
return cursor, nil
}
-func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (any, error) {
+func decodeDocument(cursor *mongo.Cursor, t reflect.Type) (id any, value any,
err error) {
+ var docID documentID
+ if err := cursor.Decode(&docID); err != nil {
+ return nil, nil, fmt.Errorf("error decoding document ID: %w",
err)
+ }
+
out := reflect.New(t).Interface()
if err := cursor.Decode(out); err != nil {
- return nil, fmt.Errorf("error decoding document: %w", err)
+ return nil, nil, fmt.Errorf("error decoding document: %w", err)
}
- value := reflect.ValueOf(out).Elem().Interface()
+ value = reflect.ValueOf(out).Elem().Interface()
- return value, nil
+ return docID.ID, value, nil
}
diff --git a/sdks/go/pkg/beam/io/mongodbio/read_test.go
b/sdks/go/pkg/beam/io/mongodbio/read_test.go
index 5899457d5a8..666960b17b9 100644
--- a/sdks/go/pkg/beam/io/mongodbio/read_test.go
+++ b/sdks/go/pkg/beam/io/mongodbio/read_test.go
@@ -16,7 +16,6 @@
package mongodbio
import (
- "math"
"reflect"
"testing"
@@ -24,250 +23,6 @@ import (
"go.mongodb.org/mongo-driver/bson"
)
-func Test_calculateBucketCount(t *testing.T) {
- tests := []struct {
- name string
- collectionSize int64
- bundleSize int64
- want int32
- }{
- {
- name: "Return ceiling of collection size /
bundle size",
- collectionSize: 3 * 1024 * 1024,
- bundleSize: 2 * 1024 * 1024,
- want: 2,
- },
- {
- name: "Return max int32 when calculated count
is greater than max int32",
- collectionSize: 1024 * 1024 * 1024 * 1024,
- bundleSize: 1,
- want: math.MaxInt32,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := calculateBucketCount(tt.collectionSize,
tt.bundleSize); got != tt.want {
- t.Errorf("calculateBucketCount() = %v, want
%v", got, tt.want)
- }
- })
- }
-}
-
-func Test_calculateBucketCountPanic(t *testing.T) {
- t.Run("Panic when bundleSize is not greater than 0", func(t *testing.T)
{
- defer func() {
- if r := recover(); r == nil {
- t.Errorf("calculateBucketCount() does not
panic")
- }
- }()
-
- calculateBucketCount(1024, 0)
- })
-}
-
-func Test_idFiltersFromBuckets(t *testing.T) {
- tests := []struct {
- name string
- buckets []bucket
- want []bson.M
- }{
- {
- name: "Create one $lte filter for start range, one $gt
filter for end range, and filters with both " +
- "$lte and $gt for ranges in between when there
are three or more bucket elements",
- buckets: []bucket{
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5378"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5384"),
- },
- },
- },
- want: []bson.M{
- {
- "_id": bson.M{
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- },
- },
- {
- name: "Create one $lte filter for start range and one
$gt filter for end range when there are two " +
- "bucket elements",
- buckets: []bucket{
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5378"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- },
- want: []bson.M{
- {
- "_id": bson.M{
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- },
- },
- {
- name: "Create an empty filter when there is one bucket
element",
- buckets: []bucket{
- {
- ID: minMax{
- Min: objectIDFromHex(t,
"6384e03f24f854c1a8ce5378"),
- Max: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- },
- want: []bson.M{{}},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := idFiltersFromBuckets(tt.buckets);
!cmp.Equal(got, tt.want) {
- t.Errorf("idFiltersFromBuckets() = %v, want
%v", got, tt.want)
- }
- })
- }
-}
-
-func Test_getChunkSize(t *testing.T) {
- tests := []struct {
- name string
- bundleSize int64
- want int64
- }{
- {
- name: "Return 1 MB if bundle size is less than 1
MB",
- bundleSize: 1024,
- want: 1024 * 1024,
- },
- {
- name: "Return 1 GB if bundle size is greater than
1 GB",
- bundleSize: 2 * 1024 * 1024 * 1024,
- want: 1024 * 1024 * 1024,
- },
- {
- name: "Return bundle size if bundle size is
between 1 MB and 1 GB",
- bundleSize: 4 * 1024 * 1024,
- want: 4 * 1024 * 1024,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := getChunkSize(tt.bundleSize); got != tt.want {
- t.Errorf("getChunkSize() = %v, want %v", got,
tt.want)
- }
- })
- }
-}
-
-func Test_idFiltersFromSplits(t *testing.T) {
- tests := []struct {
- name string
- splitKeys []splitKey
- want []bson.M
- }{
- {
- name: "Create one $lte filter for start range, one $gt
filter for end range, and filters with both " +
- "$lte and $gt for ranges in between when there
are two or more splitKey elements",
- splitKeys: []splitKey{
- {
- ID: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- {
- ID: objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- want: []bson.M{
- {
- "_id": bson.M{
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5382"),
- },
- },
- },
- },
- {
- name: "Create one $lte filter for start range and one
$gt filter for end range when there is one " +
- "splitKey element",
- splitKeys: []splitKey{
- {
- ID: objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- want: []bson.M{
- {
- "_id": bson.M{
- "$lte": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- {
- "_id": bson.M{
- "$gt": objectIDFromHex(t,
"6384e03f24f854c1a8ce5380"),
- },
- },
- },
- },
- {
- name: "Create an empty filter when there are no
splitKey elements",
- splitKeys: nil,
- want: []bson.M{{}},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := idFiltersFromSplits(tt.splitKeys);
!cmp.Equal(got, tt.want) {
- t.Errorf("idFiltersFromSplits() = %v, want %v",
got, tt.want)
- }
- })
- }
-}
-
func Test_inferProjection(t *testing.T) {
type doc struct {
Field1 string `bson:"field1"`
@@ -313,81 +68,3 @@ func Test_inferProjectionPanic(t *testing.T) {
inferProjection(reflect.TypeOf(doc{}), "bson")
})
}
-
-func Test_mergeFilters(t *testing.T) {
- tests := []struct {
- name string
- idFilter bson.M
- filter bson.M
- want bson.M
- }{
- {
- name: "Returned merged ID filter and custom filter in
an $and filter",
- idFilter: bson.M{
- "_id": bson.M{
- "$gte": 10,
- },
- },
- filter: bson.M{
- "key": bson.M{
- "$ne": "value",
- },
- },
- want: bson.M{
- "$and": []bson.M{
- {
- "_id": bson.M{
- "$gte": 10,
- },
- },
- {
- "key": bson.M{
- "$ne": "value",
- },
- },
- },
- },
- },
- {
- name: "Return only ID filter when custom filter is
empty",
- idFilter: bson.M{
- "_id": bson.M{
- "$gte": 10,
- },
- },
- filter: bson.M{},
- want: bson.M{
- "_id": bson.M{
- "$gte": 10,
- },
- },
- },
- {
- name: "Return only custom filter when ID filter is
empty",
- idFilter: bson.M{},
- filter: bson.M{
- "key": bson.M{
- "$ne": "value",
- },
- },
- want: bson.M{
- "key": bson.M{
- "$ne": "value",
- },
- },
- },
- {
- name: "Return empty filter when both ID filter and
custom filter are empty",
- idFilter: bson.M{},
- filter: bson.M{},
- want: bson.M{},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := mergeFilters(tt.idFilter, tt.filter);
!cmp.Equal(got, tt.want) {
- t.Errorf("mergeFilters() = %v, want %v", got,
tt.want)
- }
- })
- }
-}