This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 79ea2e8562c Add support for sharding while avro write (#36933)
79ea2e8562c is described below

commit 79ea2e8562cdb748c56709292d1f42fd6491cf8e
Author: CherisPatelInfocusp <[email protected]>
AuthorDate: Wed Dec 3 01:54:07 2025 +0530

    Add support for sharding while avro write (#36933)
---
 sdks/go/pkg/beam/io/avroio/avroio.go      | 118 +++++++++++++++++++++++++++---
 sdks/go/pkg/beam/io/avroio/avroio_test.go | 108 ++++++++++++++++++++++++++-
 2 files changed, 212 insertions(+), 14 deletions(-)

diff --git a/sdks/go/pkg/beam/io/avroio/avroio.go b/sdks/go/pkg/beam/io/avroio/avroio.go
index 809c9479f7a..3a116a74f55 100644
--- a/sdks/go/pkg/beam/io/avroio/avroio.go
+++ b/sdks/go/pkg/beam/io/avroio/avroio.go
@@ -19,6 +19,8 @@ package avroio
 import (
        "context"
        "encoding/json"
+       "fmt"
+       "math/rand"
        "reflect"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam"
@@ -32,7 +34,10 @@ import (
 func init() {
        register.DoFn3x1[context.Context, fileio.ReadableFile, func(beam.X), error]((*avroReadFn)(nil))
        register.DoFn3x1[context.Context, int, func(*string) bool, error]((*writeAvroFn)(nil))
+       register.DoFn2x0[string, func(int, string)]((*roundRobinKeyFn)(nil))
        register.Emitter1[beam.X]()
+       register.Emitter1[string]()
+       register.Emitter2[int, string]()
        register.Iter1[string]()
 }
 
@@ -109,32 +114,121 @@ func (f *avroReadFn) ProcessElement(ctx context.Context, file fileio.ReadableFil
        return ar.Err()
 }
 
+type WriteOption func(*writeConfig)
+
+type writeConfig struct {
+       suffix    string
+       numShards int
+}
+
+// WithSuffix sets the file suffix (default: ".avro")
+func WithSuffix(suffix string) WriteOption {
+       return func(c *writeConfig) {
+               c.suffix = suffix
+       }
+}
+
+// WithNumShards sets the number of output shards (default: 1)
+func WithNumShards(numShards int) WriteOption {
+       return func(c *writeConfig) {
+               c.numShards = numShards
+       }
+}
+
 // Write writes a PCollection<string> to an AVRO file.
 // Write expects a JSON string with a matching AVRO schema.
 // the process will fail if the schema does not match the JSON
 // provided
-func Write(s beam.Scope, filename, schema string, col beam.PCollection) {
-       s = s.Scope("avroio.Write")
-       filesystem.ValidateScheme(filename)
-       pre := beam.AddFixedKey(s, col)
-       post := beam.GroupByKey(s, pre)
-       beam.ParDo0(s, &writeAvroFn{Schema: schema, Filename: filename}, post)
+//
+// Parameters:
+//
+//     prefix: File path prefix (e.g., "gs://bucket/output")
+//     suffix: File extension (e.g., ".avro")
+//     numShards: Number of output files (0 or 1 for single file)
+//     schema: AVRO schema as JSON string
+//
+// Files are named as: <prefix>-<shard>-of-<numShards><suffix>
+// Example: output-00000-of-00010.avro
+//
+// Examples:
+//
+//     Write(s, "gs://bucket/output", schema, col)                                          // output-00000-of-00001.avro (defaults)
+//     Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"))                     // output-00000-of-00001.avro (explicit)
+//     Write(s, "gs://bucket/output", schema, col, WithNumShards(10))                       // output-00000-of-00010.avro (10 shards)
+//     Write(s, "gs://bucket/output", schema, col, WithSuffix(".avro"), WithNumShards(10))  // full control
+func Write(s beam.Scope, prefix, schema string, col beam.PCollection, opts ...WriteOption) {
+       s = s.Scope("avroio.WriteSharded")
+       filesystem.ValidateScheme(prefix)
+
+       config := &writeConfig{
+               suffix:    ".avro",
+               numShards: 1,
+       }
+
+       for _, opt := range opts {
+               opt(config)
+       }
+
+       // Default to single shard if not specified or 0
+       if config.numShards <= 0 {
+               config.numShards = 1
+       }
+
+       keyed := beam.ParDo(s, &roundRobinKeyFn{NumShards: config.numShards}, col)
+
+       grouped := beam.GroupByKey(s, keyed)
+
+       beam.ParDo0(s, &writeAvroFn{
+               Prefix:    prefix,
+               NumShards: config.numShards,
+               Suffix:    config.suffix,
+               Schema:    schema,
+       }, grouped)
+}
+
+type roundRobinKeyFn struct {
+       NumShards   int `json:"num_shards"`
+       counter     int
+       initialized bool
+}
+
+func (f *roundRobinKeyFn) StartBundle(emit func(int, string)) {
+       f.initialized = false
+}
+
+func (f *roundRobinKeyFn) ProcessElement(element string, emit func(int, string)) {
+       if !f.initialized {
+               f.counter = rand.Intn(f.NumShards)
+               f.initialized = true
+       }
+       emit(f.counter, element)
+       f.counter = (f.counter + 1) % f.NumShards
+}
+
+// formatShardName creates filename: prefix-SSSSS-of-NNNNN.suffix
+func formatShardName(prefix, suffix string, shardNum, numShards int) string {
+       width := max(len(fmt.Sprintf("%d", numShards-1)), 5)
+       return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix)
 }
 
 type writeAvroFn struct {
-       Schema   string `json:"schema"`
-       Filename string `json:"filename"`
+       Prefix    string `json:"prefix"`
+       Suffix    string `json:"suffix"`
+       NumShards int    `json:"num_shards"`
+       Schema    string `json:"schema"`
 }
 
-func (w *writeAvroFn) ProcessElement(ctx context.Context, _ int, lines func(*string) bool) (err error) {
-       log.Infof(ctx, "writing AVRO to %s", w.Filename)
-       fs, err := filesystem.New(ctx, w.Filename)
+func (w *writeAvroFn) ProcessElement(ctx context.Context, shardNum int, lines func(*string) bool) (err error) {
+       filename := formatShardName(w.Prefix, w.Suffix, shardNum, w.NumShards)
+       log.Infof(ctx, "Writing AVRO shard %d/%d to %s", shardNum+1, w.NumShards, filename)
+
+       fs, err := filesystem.New(ctx, filename)
        if err != nil {
                return
        }
        defer fs.Close()
 
-       fd, err := fs.OpenWrite(ctx, w.Filename)
+       fd, err := fs.OpenWrite(ctx, filename)
        if err != nil {
                return
        }
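
For readers who want to see the new API in context, here is a minimal illustrative
pipeline using the options added above. The output prefix, schema, and records are
placeholders, and the direct runner is just one way to execute it; this sketch is
not part of the patch:

    package main

    import (
        "context"

        "github.com/apache/beam/sdks/v2/go/pkg/beam"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/io/avroio"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/direct"
    )

    // Placeholder Avro schema matching the JSON records created below.
    const schema = `{"type":"record","name":"user","fields":[{"name":"username","type":"string"}]}`

    func main() {
        beam.Init()
        p, s := beam.NewPipelineWithRoot()
        // Each element is a JSON string that must match the Avro schema.
        records := beam.CreateList(s, []string{`{"username":"user1"}`, `{"username":"user2"}`})
        // Write the records across 3 shards:
        //   ./users-00000-of-00003.avro ... ./users-00002-of-00003.avro
        avroio.Write(s, "./users", schema, records, avroio.WithNumShards(3))
        if _, err := direct.Execute(context.Background(), p); err != nil {
            panic(err)
        }
    }
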
diff --git a/sdks/go/pkg/beam/io/avroio/avroio_test.go b/sdks/go/pkg/beam/io/avroio/avroio_test.go
index 403a8187555..2e888b0e040 100644
--- a/sdks/go/pkg/beam/io/avroio/avroio_test.go
+++ b/sdks/go/pkg/beam/io/avroio/avroio_test.go
@@ -19,7 +19,9 @@ import (
        "bytes"
        "encoding/json"
        "errors"
+       "fmt"
        "os"
+       "path/filepath"
        "reflect"
        "testing"
 
@@ -141,15 +143,29 @@ const userSchema = `{
 }`
 
 func TestWrite(t *testing.T) {
-       avroFile := "./user.avro"
+       testWriteDefaults(t)
+}
+
+func TestWriteWithOptions(t *testing.T) {
+       testWriteWithOptions(t, 3)
+}
+
+func testWriteDefaults(t *testing.T) {
+       avroPrefix := "./user"
+       numShards := 1
+       avroSuffix := ".avro"
        testUsername := "user1"
        testInfo := "userInfo"
+
        p, s, sequence := ptest.CreateList([]TwitterUser{{
                User: testUsername,
                Info: testInfo,
        }})
        format := beam.ParDo(s, toJSONString, sequence)
-       Write(s, avroFile, userSchema, format)
+
+       Write(s, avroPrefix, userSchema, format)
+
+       avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, 0, numShards, avroSuffix)
        t.Cleanup(func() {
                os.Remove(avroFile)
        })
@@ -189,3 +205,91 @@ func TestWrite(t *testing.T) {
                t.Fatalf("User.User=%v, want %v", got, want)
        }
 }
+
+func testWriteWithOptions(t *testing.T, numShards int) {
+       avroPrefix := "./users"
+       avroSuffix := ".avro"
+       users := []TwitterUser{
+               {User: "user1", Info: "info1"},
+               {User: "user2", Info: "info2"},
+               {User: "user3", Info: "info3"},
+               {User: "user4", Info: "info4"},
+               {User: "user5", Info: "info5"},
+       }
+
+       p, s, sequence := ptest.CreateList(users)
+       format := beam.ParDo(s, toJSONString, sequence)
+
+       Write(s, avroPrefix, userSchema, format, WithNumShards(numShards))
+
+       t.Cleanup(func() {
+               pattern := fmt.Sprintf("%s-*-of-%s%s", avroPrefix, fmt.Sprintf("%05d", numShards), avroSuffix)
+               files, err := filepath.Glob(pattern)
+               if err == nil {
+                       for _, f := range files {
+                               os.Remove(f)
+                       }
+               }
+       })
+
+       ptest.RunAndValidate(t, p)
+
+       var allRecords []map[string]any
+       recordCounts := make(map[int]int)
+
+       for shardNum := 0; shardNum < numShards; shardNum++ {
+               avroFile := fmt.Sprintf("%s-%05d-of-%05d%s", avroPrefix, shardNum, numShards, avroSuffix)
+
+               if _, err := os.Stat(avroFile); errors.Is(err, os.ErrNotExist) {
+                       continue
+               }
+
+               avroBytes, err := os.ReadFile(avroFile)
+               if err != nil {
+                       t.Fatalf("Failed to read avro file %v: %v", avroFile, err)
+               }
+               ocf, err := goavro.NewOCFReader(bytes.NewReader(avroBytes))
+               if err != nil {
+                       t.Fatalf("Failed to make OCF Reader for %v: %v", avroFile, err)
+               }
+               shardRecordCount := 0
+               for ocf.Scan() {
+                       datum, err := ocf.Read()
+                       if err != nil {
+                               break
+                       }
+                       allRecords = append(allRecords, datum.(map[string]any))
+                       shardRecordCount++
+               }
+
+               recordCounts[shardNum] = shardRecordCount
+
+               if err := ocf.Err(); err != nil {
+                       t.Fatalf("Error decoding avro data from %v: %v", avroFile, err)
+               }
+       }
+
+       if got, want := len(allRecords), len(users); got != want {
+               t.Fatalf("Total records across all shards, got %v, want %v", got, want)
+       }
+
+       hasRecords := false
+       for _, count := range recordCounts {
+               if count > 0 {
+                       hasRecords = true
+               }
+       }
+       if !hasRecords {
+               t.Fatal("No records found in any shard")
+       }
+       foundUsers := make(map[string]bool)
+       for _, record := range allRecords {
+               username := record["username"].(string)
+               foundUsers[username] = true
+       }
+       for _, user := range users {
+               if !foundUsers[user.User] {
+                       t.Fatalf("Expected user %v not found in any shard", user.User)
+               }
+       }
+}
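
As a quick illustration of the shard-naming formula used by formatShardName in this
patch, the same logic can be restated as a standalone snippet. The helper name
shardName and the example values are illustrative, not from the commit: the
zero-padded width is at least five digits and grows once numShards-1 needs more
digits. The patch itself uses the built-in max (Go 1.21+); the sketch spells it out.

    package main

    import "fmt"

    // shardName mirrors the formula in formatShardName above.
    func shardName(prefix, suffix string, shardNum, numShards int) string {
        width := len(fmt.Sprintf("%d", numShards-1))
        if width < 5 {
            width = 5
        }
        return fmt.Sprintf("%s-%0*d-of-%0*d%s", prefix, width, shardNum, width, numShards, suffix)
    }

    func main() {
        fmt.Println(shardName("output", ".avro", 0, 10))      // output-00000-of-00010.avro
        fmt.Println(shardName("output", ".avro", 7, 1000000)) // output-000007-of-1000000.avro
    }
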
