This is an automated email from the ASF dual-hosted git repository.
jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 79ea2e8562c Add support for sharding while avro write (#36933)
79ea2e8562c is described below
commit 79ea2e8562cdb748c56709292d1f42fd6491cf8e
Author: CherisPatelInfocusp <[email protected]>
AuthorDate: Wed Dec 3 01:54:07 2025 +0530
Add support for sharding while avro write (#36933)
---
sdks/go/pkg/beam/io/avroio/avroio.go | 118 +++++++++++++++++++++++++++---
sdks/go/pkg/beam/io/avroio/avroio_test.go | 108 ++++++++++++++++++++++++++-
2 files changed, 212 insertions(+), 14 deletions(-)
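For readers skimming the patch, a minimal usage sketch of the new options-based
Write API follows. It is illustrative only: the bucket path, the "records"
PCollection<string>, and the "schema" JSON string are placeholders, not part of
this commit.

    // Sketch: write JSON-encoded records to 10 Avro shards named
    // output-00000-of-00010.avro through output-00009-of-00010.avro.
    func writeSharded(s beam.Scope, records beam.PCollection, schema string) {
            avroio.Write(s, "gs://bucket/output", schema, records,
                    avroio.WithSuffix(".avro"),
                    avroio.WithNumShards(10))
    }

Calling Write with no options still produces a single output file, now named
<prefix>-00000-of-00001<suffix> per the defaults described in the doc comment
below.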
diff --git a/sdks/go/pkg/beam/io/avroio/avroio.go b/sdks/go/pkg/beam/io/avroio/avroio.go
index 809c9479f7a..3a116a74f55 100644
--- a/sdks/go/pkg/beam/io/avroio/avroio.go
+++ b/sdks/go/pkg/beam/io/avroio/avroio.go
@@ -19,6 +19,8 @@ package avroio
import (
"context"
"encoding/json"
+ "fmt"
+ "math/rand"
"reflect"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
@@ -32,7 +34,10 @@ import (
func init() {
register.DoFn3x1[context.Context, fileio.ReadableFile, func(beam.X), error]((*avroReadFn)(nil))
register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil))
+ register.DoFn2x0[string, func(int, string)]((*roundRobinKeyFn)(nil))
register.Emitter1[beam.X]()
+ register.Emitter1[string]()
+ register.Emitter2[int, string]()
register.Iter1[string]()
}
@@ -109,32 +114,121 @@ func (f *avroReadFn) ProcessElement(ctx context.Context, file fileio.ReadableFil
return ar.Err()
}
+type WriteOption func(*writeConfig)
+
+type writeConfig struct {
+ suffix string
+ numShards int
+}
+
+// WithSuffix sets the file suffix (default: ".avro")
+func WithSuffix(suffix string) WriteOption {
+ return func(c *writeConfig) {
+ c.suffix = suffix
+ }
+}
+
+// WithNumShards sets the number of output shards (default: 1)
+func WithNumShards(numShards int) WriteOption {
+ return func(c *writeConfig) {
+ c.numShards = numShards
+ }
+}
+
// Write writes a PCollection<string> to an AVRO file.
// Write expects a JSON string with a matching AVRO schema.
// the process will fail if the schema does not match the JSON
// provided
-func Write(s beam.Scope, filename, schema string, col beam.PCollection) {
- s = s.Scope("avroio.Write")
- filesystem.ValidateScheme(filename)
- pre := beam.AddFixedKey(s, col)
- post := beam.GroupByKey(s, pre)
- beam.ParDo0(s, &writeAvroFn{Schema: schema, Filename: filename}, post)
+//
+// Parameters:
+//
+// prefix: File path prefix (e.g., "gs://bucket/output")
+// suffix: File extension (e.g., ".avro")
+// numShards: Number of output files (0 or 1 for single file)
+// schema: AVRO schema as JSON string
+//
+// Files are named as: <prefix>-<shard>-of-<numShards><suffix>
+// Example: output-00000-of-00010.avro
+//
+// Examples:
+//
+// Write(s, "gs://bucket/output", schema, col)
// output-00000-of-00001.avro (defaults)
+// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"))
// output-00000-of-00001.avro (explicit)
+// Write(s, "gs://bucket/output", schema, col, WithNumShards(10))
// output-00000-of-00010.avro (10 shards)
+// Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"),
WithNumShards(10)) // full control
+func Write(s beam.Scope, prefix, schema string, col beam.PCollection, opts ...WriteOption) {
+ s = s.Scope("avroio.WriteSharded")
+ filesystem.ValidateScheme(prefix)
+
+ config := &writeConfig{
+ suffix: ".avro",
+ numShards: 1,
+ }
+
+ for _, opt := range opts {
+ opt(config)
+ }
+
+ // Default to single shard if not specified or 0
+ if config.numShards <= 0 {
+ config.numShards = 1
+ }
+
+	keyed := beam.ParDo(s, &roundRobinKeyFn{NumShards: config.numShards}, col)
+
+ grouped := beam.GroupByKey(s, keyed)
+
+ beam.ParDo0(s, &writeAvroFn{
+ Prefix: prefix,
+ NumShards: config.numShards,
+ Suffix: config.suffix,
+ Schema: schema,
+ }, grouped)
+}
+
+type roundRobinKeyFn struct {
+ NumShards int `json:"num_shards"`
+ counter int
+ initialized bool
+}
+
+func (f *roundRobinKeyFn) StartBundle(emit func(int, string)) {
+ f.initialized = false
+}
+
+func (f *roundRobinKeyFn) ProcessElement(element string, emit func(int, string)) {
+ if !f.initialized {
+ f.counter = rand.Intn(f.NumShards)
+ f.initialized = true
+ }
+ emit(f.counter, element)
+ f.counter = (f.counter + 1) % f.NumShards
+}
+
+// formatShardName creates filename: prefix-SSSSS-of-NNNNN.suffix
+func formatShardName(prefix, suffix string, shardNum, numShards int) string {
+ width := max(len(fmt.Sprintf("%d", numShards-1)), 5)
+	return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix)
}
type writeAvroFn struct {
- Schema string `json:"schema"`
- Filename string `json:"filename"`
+ Prefix string `json:"prefix"`
+ Suffix string `json:"suffix"`
+ NumShards int `json:"num_shards"`
+ Schema string `json:"schema"`
}
-func (w *writeAvroFn) ProcessElement(ctx context.Context, _ int, lines func(*string) bool) (err error) {
- log.Infof(ctx, "writing AVRO to %s", w.Filename)
- fs, err := filesystem.New(ctx, w.Filename)
+func (w *writeAvroFn) ProcessElement(ctx context.Context, shardNum int, lines func(*string) bool) (err error) {
+ filename := formatShardName(w.Prefix, w.Suffix, shardNum, w.NumShards)
+	log.Infof(ctx, "Writing AVRO shard %d/%d to %s", shardNum+1, w.NumShards, filename)
+
+ fs, err := filesystem.New(ctx, filename)
if err != nil {
return
}
defer fs.Close()
- fd, err := fs.OpenWrite(ctx, w.Filename)
+ fd, err := fs.OpenWrite(ctx, filename)
if err != nil {
return
}
diff --git a/sdks/go/pkg/beam/io/avroio/avroio_test.go b/sdks/go/pkg/beam/io/avroio/avroio_test.go
index 403a8187555..2e888b0e040 100644
--- a/sdks/go/pkg/beam/io/avroio/avroio_test.go
+++ b/sdks/go/pkg/beam/io/avroio/avroio_test.go
@@ -19,7 +19,9 @@ import (
"bytes"
"encoding/json"
"errors"
+ "fmt"
"os"
+ "path/filepath"
"reflect"
"testing"
@@ -141,15 +143,29 @@ const userSchema = `{
}`
func TestWrite(t *testing.T) {
- avroFile := "./user.avro"
+ testWriteDefaults(t)
+}
+
+func TestWriteWithOptions(t *testing.T) {
+ testWriteWithOptions(t, 3)
+}
+
+func testWriteDefaults(t *testing.T) {
+ avroPrefix := "./user"
+ numShards := 1
+ avroSuffix := ".avro"
testUsername := "user1"
testInfo := "userInfo"
+
p, s, sequence := ptest.CreateList([]TwitterUser{{
User: testUsername,
Info: testInfo,
}})
format := beam.ParDo(s, toJSONString, sequence)
- Write(s, avroFile, userSchema, format)
+
+ Write(s, avroPrefix, userSchema, format)
+
+	avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, 0, numShards, avroSuffix)
t.Cleanup(func() {
os.Remove(avroFile)
})
@@ -189,3 +205,91 @@ func TestWrite(t *testing.T) {
t.Fatalf("User.User=%v, want %v", got, want)
}
}
+
+func testWriteWithOptions(t *testing.T, numShards int) {
+ avroPrefix := "./users"
+ avroSuffix := ".avro"
+ users := []TwitterUser{
+ {User: "user1", Info: "info1"},
+ {User: "user2", Info: "info2"},
+ {User: "user3", Info: "info3"},
+ {User: "user4", Info: "info4"},
+ {User: "user5", Info: "info5"},
+ }
+
+ p, s, sequence := ptest.CreateList(users)
+ format := beam.ParDo(s, toJSONString, sequence)
+
+ Write(s, avroPrefix, userSchema, format, WithNumShards(numShards))
+
+ t.Cleanup(func() {
+		pattern := fmt.Sprintf("%s-*-of-%s%s", avroPrefix, fmt.Sprintf("%05d", numShards), avroSuffix)
+ files, err := filepath.Glob(pattern)
+ if err == nil {
+ for _, f := range files {
+ os.Remove(f)
+ }
+ }
+ })
+
+ ptest.RunAndValidate(t, p)
+
+ var allRecords []map[string]any
+ recordCounts := make(map[int]int)
+
+ for shardNum := 0; shardNum < numShards; shardNum++ {
+		avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, shardNum, numShards, avroSuffix)
+
+ if _, err := os.Stat(avroFile); errors.Is(err, os.ErrNotExist) {
+ continue
+ }
+
+ avroBytes, err := os.ReadFile(avroFile)
+ if err != nil {
+			t.Fatalf("Failed to read avro file %v: %v", avroFile, err)
+ }
+ ocf, err := goavro.NewOCFReader(bytes.NewReader(avroBytes))
+ if err != nil {
+			t.Fatalf("Failed to make OCF Reader for %v: %v", avroFile, err)
+ }
+ shardRecordCount := 0
+ for ocf.Scan() {
+ datum, err := ocf.Read()
+ if err != nil {
+ break
+ }
+ allRecords = append(allRecords, datum.(map[string]any))
+ shardRecordCount++
+ }
+
+ recordCounts[shardNum] = shardRecordCount
+
+ if err := ocf.Err(); err != nil {
+			t.Fatalf("Error decoding avro data from %v: %v", avroFile, err)
+ }
+ }
+
+ if got, want := len(allRecords), len(users); got != want {
+		t.Fatalf("Total records across all shards, got %v, want %v", got, want)
+ }
+
+ hasRecords := false
+ for _, count := range recordCounts {
+ if count > 0 {
+ hasRecords = true
+ }
+ }
+ if !hasRecords {
+ t.Fatal("No records found in any shard")
+ }
+ foundUsers := make(map[string]bool)
+ for _, record := range allRecords {
+ username := record["username"].(string)
+ foundUsers[username] = true
+ }
+ for _, user := range users {
+ if !foundUsers[user.User] {
+			t.Fatalf("Expected user %v not found in any shard", user.User)
+ }
+ }
+}
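Not part of this commit, but for completeness: because the sharded files follow
a fixed naming pattern, they can be read back with the package's existing
glob-based Read. A minimal sketch, assuming the TwitterUser type from the tests
and the current Read(scope, glob, type) signature:

    // Sketch: read the three shards produced by the test above back into a
    // PCollection<TwitterUser> using a filename glob.
    users := avroio.Read(s, "./users-*-of-00003.avro", reflect.TypeOf(TwitterUser{}))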