lostluck commented on code in PR #38220: URL: https://github.com/apache/beam/pull/38220#discussion_r3197523125
########## sdks/go/pkg/beam/transforms/batch/batch.go: ########## @@ -0,0 +1,685 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package batch provides transforms that group elements of a KV-keyed +// PCollection into batches of a target size for downstream per-batch +// processing (rate-limited API calls, bulk sinks, etc.). +// +// GroupIntoBatches mirrors the behavior of the Java and Python +// transforms of the same name. GroupIntoBatchesWithShardedKey adds +// opaque per-element shard identifiers to the keys so the processing +// of a single hot logical key spreads across multiple workers. +// +// # Behavior +// +// Given a PCollection<KV<K, V>>, GroupIntoBatches buffers values per +// key and emits batches as KV<K, []V> whenever one of the following +// limits is reached: +// +// - len(batch) reaches BatchSize, OR +// - sum of byte sizes reaches BatchSizeBytes, OR +// - MaxBufferingDuration elapses in processing time since the first +// element of the current batch (if set), OR +// - the window advances past MaxTimestamp + AllowedLateness of the +// input PCollection's WindowingStrategy. +// +// Elements of different windows are never combined into the same +// batch. +// +// # Determinism requirement +// +// The key coder MUST be deterministic. State keying depends on +// byte-stable encodings: a non-deterministic key coder would silently +// split the logical key across multiple physical keys, producing +// corrupt batches. The transform panics at pipeline build time if the +// key coder is not known to be deterministic. For user-defined key +// types, register the type's coder via +// coder.RegisterDeterministicCoder. +// +// # Differences from Java/Python +// +// - BatchSize / BatchSizeBytes are int64 (parity with proto and Java +// long, avoiding overflow on 32-bit platforms). +// - BatchSizeBytes is limited to primitive value types ([]byte, +// string, numeric, bool) in this release; opaque V types panic at +// build time if BatchSizeBytes > 0. +// - GroupIntoBatchesWithShardedKey returns PCollection<KV<K, []V>> +// (same shape as GroupIntoBatches), with sharding applied +// internally. The Java/Python variants expose ShardedKey<K> to the +// user; Go does not because the SDK's type-binding engine does not Review Comment: I'll note that as implemented GroupIntoBatchesWithShardedKey doesn't match this description. The doc comment on GroupIntoBatchesWithShardedKey says it returns the ShardedKey wrapped type. ########## sdks/go/pkg/beam/transforms/batch/batch.go: ########## @@ -0,0 +1,685 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package batch provides transforms that group elements of a KV-keyed +// PCollection into batches of a target size for downstream per-batch +// processing (rate-limited API calls, bulk sinks, etc.). +// +// GroupIntoBatches mirrors the behavior of the Java and Python +// transforms of the same name. GroupIntoBatchesWithShardedKey adds +// opaque per-element shard identifiers to the keys so the processing +// of a single hot logical key spreads across multiple workers. +// +// # Behavior +// +// Given a PCollection<KV<K, V>>, GroupIntoBatches buffers values per +// key and emits batches as KV<K, []V> whenever one of the following +// limits is reached: +// +// - len(batch) reaches BatchSize, OR +// - sum of byte sizes reaches BatchSizeBytes, OR +// - MaxBufferingDuration elapses in processing time since the first +// element of the current batch (if set), OR +// - the window advances past MaxTimestamp + AllowedLateness of the +// input PCollection's WindowingStrategy. +// +// Elements of different windows are never combined into the same +// batch. +// +// # Determinism requirement +// +// The key coder MUST be deterministic. State keying depends on +// byte-stable encodings: a non-deterministic key coder would silently +// split the logical key across multiple physical keys, producing +// corrupt batches. The transform panics at pipeline build time if the +// key coder is not known to be deterministic. For user-defined key +// types, register the type's coder via +// coder.RegisterDeterministicCoder. +// +// # Differences from Java/Python +// +// - BatchSize / BatchSizeBytes are int64 (parity with proto and Java +// long, avoiding overflow on 32-bit platforms). +// - BatchSizeBytes is limited to primitive value types ([]byte, +// string, numeric, bool) in this release; opaque V types panic at +// build time if BatchSizeBytes > 0. +// - GroupIntoBatchesWithShardedKey returns PCollection<KV<K, []V>> +// (same shape as GroupIntoBatches), with sharding applied +// internally. The Java/Python variants expose ShardedKey<K> to the +// user; Go does not because the SDK's type-binding engine does not +// accept custom generic structs as DoFn output types. The +// cross-SDK beam:coder:sharded_key:v1 coder is nevertheless wired +// in typex + core/graph/coder so cross-language pipelines can +// round-trip ShardedKey values. +package batch + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx" + beamcoder "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" + "github.com/google/uuid" +) + +// ShardedKey pairs a user key with an opaque shard identifier. It is +// the key type of the PCollection produced by +// GroupIntoBatchesWithShardedKey. +type ShardedKey[K any] struct { + Key K + ShardID []byte +} + +// RegisterShardedKeyType registers a ShardedKey[K] instantiation so +// its coder survives cross-worker serialization. Common key types +// (string, []byte, int, int64) are registered automatically at init. +// Users of other K types must call this at init time. +func RegisterShardedKeyType[K any]() { + var zero K + keyT := reflect.TypeOf(zero) + skT := reflect.TypeOf(ShardedKey[K]{}) + + register.DoFn3x0[K, typex.V, func(ShardedKey[K], typex.V)](&wrapShardedKeyFn[K]{}) + register.Emitter2[ShardedKey[K], typex.V]() + beam.RegisterType(skT) + + keyEnc := beam.NewElementEncoder(keyT) + keyDec := beam.NewElementDecoder(keyT) + + enc := func(sk ShardedKey[K]) []byte { + var buf bytes.Buffer + writeVarInt(&buf, int64(len(sk.ShardID))) + buf.Write(sk.ShardID) + if err := keyEnc.Encode(sk.Key, &buf); err != nil { + panic(err) + } + return buf.Bytes() + } + dec := func(b []byte) ShardedKey[K] { + r := bytes.NewReader(b) + n := readVarInt(r) + shardID := make([]byte, n) + if n > 0 { + if _, err := r.Read(shardID); err != nil { + panic(err) + } + } + k, err := keyDec.Decode(r) + if err != nil { + panic(err) + } + return ShardedKey[K]{Key: k.(K), ShardID: shardID} + } + + // Closures inside generic functions share the same compiler + // symbol name for every type instantiation. We wrap them with a + // type-qualified name so the cross-worker deserializer resolves + // the correct enc/dec for each ShardedKey[K]. + encName := fmt.Sprintf("batch.encShardedKey[%v]", keyT) + decName := fmt.Sprintf("batch.decShardedKey[%v]", keyT) + + encFn := reflectx.MakeFuncWithName(encName, enc) + decFn := reflectx.MakeFuncWithName(decName, dec) + + // Register in the runtime cache under the qualified name so + // ResolveFunction finds them at deserialization time. + runtime.RegisterFunctionWithName(encName, enc) + runtime.RegisterFunctionWithName(decName, dec) + + encWrapped, err := funcx.New(encFn) + if err != nil { + panic(fmt.Sprintf("RegisterShardedKeyType: bad enc for %v: %v", skT, err)) + } + decWrapped, err := funcx.New(decFn) + if err != nil { + panic(fmt.Sprintf("RegisterShardedKeyType: bad dec for %v: %v", skT, err)) + } + + beamcoder.RegisterDeterministicCoderWithFuncs(skT, encWrapped, decWrapped) +} + +// Params configures GroupIntoBatches and +// GroupIntoBatchesWithShardedKey. +// +// At least one of BatchSize or BatchSizeBytes must be > 0. +type Params struct { + // BatchSize is the target maximum number of elements per batch. A + // batch is emitted as soon as it holds BatchSize elements. Zero + // disables the count-based trigger. + BatchSize int64 + + // BatchSizeBytes is the target maximum cumulative byte size per + // batch. A batch is emitted as soon as adding another element + // would exceed BatchSizeBytes. Zero disables the byte-based + // trigger. + BatchSizeBytes int64 + + // MaxBufferingDuration, when > 0, triggers emission of a partial + // batch after this much processing time has elapsed since the + // first element of the current batch was buffered. + MaxBufferingDuration time.Duration +} + +func (p Params) validate() error { + if p.BatchSize < 0 { + return fmt.Errorf("Params.BatchSize must be >= 0; got %d", p.BatchSize) + } + if p.BatchSizeBytes < 0 { + return fmt.Errorf("Params.BatchSizeBytes must be >= 0; got %d", p.BatchSizeBytes) + } + if p.BatchSize == 0 && p.BatchSizeBytes == 0 { + return fmt.Errorf("Params: at least one of BatchSize or BatchSizeBytes must be > 0") + } + if p.MaxBufferingDuration < 0 { + return fmt.Errorf("Params.MaxBufferingDuration must be >= 0; got %s", p.MaxBufferingDuration) + } + return nil +} + +const ( + sizerNone int32 = 0 + sizerPrimitive int32 = 1 +) + +// codecCache keeps a per-value-type ElementEncoder/Decoder pair. +type codecCache struct { + once sync.Once + enc beam.ElementEncoder + dec beam.ElementDecoder +} + +func (c *codecCache) init(t reflect.Type) { + c.once.Do(func() { + c.enc = beam.NewElementEncoder(t) + c.dec = beam.NewElementDecoder(t) + }) +} + +func (c *codecCache) encode(v any) []byte { + var buf bytes.Buffer + if err := c.enc.Encode(v, &buf); err != nil { + panic(err) + } + return buf.Bytes() +} + +func (c *codecCache) decode(b []byte) any { + v, err := c.dec.Decode(bytes.NewReader(b)) + if err != nil { + panic(err) + } + return v +} + +// groupIntoBatchesFn is the stateful DoFn without a processing-time +// buffering timer. +type groupIntoBatchesFn struct { + Buffer state.Bag[[]byte] + Count state.Value[int64] + ByteSize state.Value[int64] + WindowEnd timers.EventTime + + ValueType beam.EncodedType + + BatchSize int64 + BatchSizeBytes int64 + AllowedLatenessMs int64 + SizerKind int32 + + codec codecCache +} + +func (fn *groupIntoBatchesFn) ProcessElement( + w beam.Window, sp state.Provider, tp timers.Provider, + key typex.T, value typex.V, emit func(typex.T, []typex.V), +) { + fn.codec.init(fn.ValueType.T) + + count, _, err := fn.Count.Read(sp) + if err != nil { + panic(err) + } + + if w.MaxTimestamp() < mtime.MaxTimestamp { + windowEnd := w.MaxTimestamp().ToTime() + if fn.AllowedLatenessMs > 0 { + windowEnd = windowEnd.Add(time.Duration(fn.AllowedLatenessMs) * time.Millisecond) + } + fn.WindowEnd.Set(tp, windowEnd, timers.WithNoOutputTimestamp()) + } + + if err := fn.Buffer.Add(sp, fn.codec.encode(value)); err != nil { + panic(err) + } + count++ + if err := fn.Count.Write(sp, count); err != nil { + panic(err) + } + + newBytes := int64(0) + if fn.BatchSizeBytes > 0 { + cur, _, err := fn.ByteSize.Read(sp) + if err != nil { + panic(err) + } + cur += sizeOf(fn.SizerKind, value) + if err := fn.ByteSize.Write(sp, cur); err != nil { + panic(err) + } + newBytes = cur + } + + if fn.BatchSize > 0 && count >= fn.BatchSize { + fn.flush(sp, key, emit) + return + } + if fn.BatchSizeBytes > 0 && newBytes >= fn.BatchSizeBytes { + fn.flush(sp, key, emit) + return + } +} + +func (fn *groupIntoBatchesFn) OnTimer( + ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, + key typex.T, timer timers.Context, emit func(typex.T, []typex.V), +) { + if timer.Family != fn.WindowEnd.Family { + panic(fmt.Sprintf("batch.groupIntoBatchesFn: unexpected timer family %q", timer.Family)) + } + fn.codec.init(fn.ValueType.T) + fn.flush(sp, key, emit) +} + +func (fn *groupIntoBatchesFn) flush( + sp state.Provider, key typex.T, emit func(typex.T, []typex.V), +) { + buf, ok, err := fn.Buffer.Read(sp) + if err != nil { + panic(err) + } + if !ok || len(buf) == 0 { + return + } + + out := make([]typex.V, len(buf)) + for i, b := range buf { + out[i] = fn.codec.decode(b) + } + emit(key, out) + + if err := fn.Buffer.Clear(sp); err != nil { + panic(err) + } + if err := fn.Count.Clear(sp); err != nil { + panic(err) + } + if fn.BatchSizeBytes > 0 { + if err := fn.ByteSize.Clear(sp); err != nil { + panic(err) + } + } +} + +// groupIntoBatchesBufferedFn adds a processing-time buffering timer. +type groupIntoBatchesBufferedFn struct { + Buffer state.Bag[[]byte] + Count state.Value[int64] + ByteSize state.Value[int64] + TimerSet state.Value[bool] + Buffering timers.ProcessingTime + WindowEnd timers.EventTime + + ValueType beam.EncodedType + + BatchSize int64 + BatchSizeBytes int64 + MaxBufferingMs int64 + AllowedLatenessMs int64 + SizerKind int32 + + codec codecCache +} + +func (fn *groupIntoBatchesBufferedFn) ProcessElement( + w beam.Window, sp state.Provider, tp timers.Provider, + key typex.T, value typex.V, emit func(typex.T, []typex.V), +) { + fn.codec.init(fn.ValueType.T) + + count, _, err := fn.Count.Read(sp) + if err != nil { + panic(err) + } + + if w.MaxTimestamp() < mtime.MaxTimestamp { + windowEnd := w.MaxTimestamp().ToTime() + if fn.AllowedLatenessMs > 0 { + windowEnd = windowEnd.Add(time.Duration(fn.AllowedLatenessMs) * time.Millisecond) + } + fn.WindowEnd.Set(tp, windowEnd, timers.WithNoOutputTimestamp()) + } + + if err := fn.Buffer.Add(sp, fn.codec.encode(value)); err != nil { + panic(err) + } + count++ + if err := fn.Count.Write(sp, count); err != nil { + panic(err) + } + + newBytes := int64(0) + if fn.BatchSizeBytes > 0 { + cur, _, err := fn.ByteSize.Read(sp) + if err != nil { + panic(err) + } + cur += sizeOf(fn.SizerKind, value) + if err := fn.ByteSize.Write(sp, cur); err != nil { + panic(err) + } + newBytes = cur + } + + if count == 1 { + fn.Buffering.Set(tp, time.Now().Add(time.Duration(fn.MaxBufferingMs)*time.Millisecond)) + if err := fn.TimerSet.Write(sp, true); err != nil { + panic(err) + } + } + + if fn.BatchSize > 0 && count >= fn.BatchSize { + fn.flush(sp, tp, key, emit) + return + } + if fn.BatchSizeBytes > 0 && newBytes >= fn.BatchSizeBytes { + fn.flush(sp, tp, key, emit) + return + } +} + +func (fn *groupIntoBatchesBufferedFn) OnTimer( + ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, + key typex.T, timer timers.Context, emit func(typex.T, []typex.V), +) { + fn.codec.init(fn.ValueType.T) + switch timer.Family { + case fn.Buffering.Family, fn.WindowEnd.Family: + fn.flush(sp, tp, key, emit) + default: + panic(fmt.Sprintf( + "batch.groupIntoBatchesBufferedFn: unexpected timer family %q", timer.Family)) + } +} + +func (fn *groupIntoBatchesBufferedFn) flush( + sp state.Provider, tp timers.Provider, key typex.T, emit func(typex.T, []typex.V), +) { + buf, ok, err := fn.Buffer.Read(sp) + if err != nil { + panic(err) + } + if !ok || len(buf) == 0 { + return + } + + out := make([]typex.V, len(buf)) + for i, b := range buf { + out[i] = fn.codec.decode(b) + } + emit(key, out) + + if err := fn.Buffer.Clear(sp); err != nil { + panic(err) + } + if err := fn.Count.Clear(sp); err != nil { + panic(err) + } + if fn.BatchSizeBytes > 0 { + if err := fn.ByteSize.Clear(sp); err != nil { + panic(err) + } + } + setBool, _, err := fn.TimerSet.Read(sp) + if err != nil { + panic(err) + } + if setBool { + fn.Buffering.Clear(tp) + if err := fn.TimerSet.Clear(sp); err != nil { + panic(err) + } + } +} + +func sizeOf(kind int32, v any) int64 { + switch kind { + case sizerNone: + return 0 + case sizerPrimitive: + if size, ok := defaultElementByteSize(v); ok { + return size + } + panic(fmt.Sprintf("batch: sizerPrimitive cannot size value of type %T", v)) + default: + panic(fmt.Sprintf("batch: unknown sizer kind %d", kind)) + } +} + +// wrapShardedKeyFn maps KV<K, V> → KV<ShardedKey[K], V>. +type wrapShardedKeyFn[K any] struct{} + +func (*wrapShardedKeyFn[K]) ProcessElement( + key K, value typex.V, emit func(ShardedKey[K], typex.V), +) { + emit(ShardedKey[K]{Key: key, ShardID: makeShardID()}, value) +} + +var ( + workerUUIDOnce sync.Once + workerUUIDVal [16]byte + shardCounter atomic.Uint64 +) + +// makeShardID returns a 24-byte shard identifier: a 16-byte worker +// UUID fixed per process plus an 8-byte atomic counter, big-endian. +// The layout mirrors the Java and Python shapes exactly so the wire +// bytes of cross-language round-trips remain aligned. +func makeShardID() []byte { + workerUUIDOnce.Do(func() { + b, err := uuid.New().MarshalBinary() + if err != nil { + panic(fmt.Sprintf("batch: failed to marshal worker UUID: %v", err)) + } + copy(workerUUIDVal[:], b) + }) + out := make([]byte, 24) + copy(out[:16], workerUUIDVal[:]) + counter := shardCounter.Add(1) + binary.BigEndian.PutUint64(out[16:24], counter) + return out +} + +// writeVarInt writes a varint-encoded int64 to buf (unsigned, +// little-endian base-128). +func writeVarInt(buf *bytes.Buffer, v int64) { Review Comment: I assume these are duplicated here (rather than using the ones in https://github.com/apache/beam/blob/6fc8fdebb2349d9efbb1b562407e08d787d01d8f/sdks/go/pkg/beam/core/graph/coder/varint.go#L89) to avoid the (largely unnecessary) error checks, and just directly use the bytes.Buffer methods? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
