This is an automated email from the ASF dual-hosted git repository. lostluck pushed a commit to branch timermgr in repository https://gitbox.apache.org/repos/asf/beam.git
commit 1fd1390b02d3a58457b5cce9be1696a8c93e9580 Author: lostluck <[email protected]> AuthorDate: Wed Mar 1 12:50:04 2023 -0800 Timers received! --- sdks/go/examples/streaming_wordcap/wordcap.go | 211 +++++++++++- sdks/go/pkg/beam/core/runtime/exec/data.go | 16 +- sdks/go/pkg/beam/core/runtime/exec/datasource.go | 161 +++++---- .../pkg/beam/core/runtime/exec/datasource_test.go | 9 +- sdks/go/pkg/beam/core/runtime/exec/fn.go | 6 +- sdks/go/pkg/beam/core/runtime/exec/timers.go | 18 +- sdks/go/pkg/beam/core/runtime/harness/datamgr.go | 374 +++++++++++++-------- .../pkg/beam/core/runtime/harness/datamgr_test.go | 141 -------- sdks/go/pkg/beam/core/timers/timers.go | 2 +- 9 files changed, 560 insertions(+), 378 deletions(-) diff --git a/sdks/go/examples/streaming_wordcap/wordcap.go b/sdks/go/examples/streaming_wordcap/wordcap.go index ddd9eab4e5f..441d7f6d324 100644 --- a/sdks/go/examples/streaming_wordcap/wordcap.go +++ b/sdks/go/examples/streaming_wordcap/wordcap.go @@ -26,16 +26,21 @@ package main import ( "context" "flag" + "fmt" "os" - "strings" + "time" "github.com/apache/beam/sdks/v2/go/pkg/beam" - "github.com/apache/beam/sdks/v2/go/pkg/beam/io/pubsubio" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" + "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange" "github.com/apache/beam/sdks/v2/go/pkg/beam/log" - "github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts" - "github.com/apache/beam/sdks/v2/go/pkg/beam/util/pubsubx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/x/beamx" "github.com/apache/beam/sdks/v2/go/pkg/beam/x/debug" + "golang.org/x/exp/slog" ) var ( @@ -50,32 +55,204 @@ var ( } ) +type Stateful struct { + ElementBag state.Bag[string] + TimerTime state.Value[int64] + MinTime state.Combining[int64, int64, int64] + + OutputState timers.ProcessingTime +} + +func NewStateful() *Stateful { + return &Stateful{ + ElementBag: state.MakeBagState[string]("elementBag"), + TimerTime: state.MakeValueState[int64]("timerTime"), + MinTime: state.MakeCombiningState[int64, int64, int64]("minTiInBag", func(a, b int64) int64 { + if a < b { + return a + } + return b + }), + + OutputState: timers.InProcessingTime("outputState"), + } +} + +func (s *Stateful) ProcessElement(ctx context.Context, ts beam.EventTime, sp state.Provider, tp timers.Provider, key, word string, emit func(string, string)) error { + log.Infof(ctx, "stateful dofn invoked key: %v word: %v", key, word) + + s.ElementBag.Add(sp, word) + s.MinTime.Add(sp, int64(ts)) + + toFire, ok, err := s.TimerTime.Read(sp) + if err != nil { + return err + } + if !ok { + toFire = int64(mtime.Now().Add(1 * time.Minute)) + } + minTime, _, err := s.MinTime.Read(sp) + if err != nil { + return err + } + + s.OutputState.SetWithOpts(tp, mtime.Time(toFire).ToTime(), timers.Opts{Hold: mtime.Time(minTime).ToTime()}) + s.TimerTime.Write(sp, toFire) + log.Infof(ctx, "stateful dofn key: %v word: %v, timer: %v, minTime: %v", key, word, toFire, minTime) + + // // Get the Value stored in our state + // val, ok, err := s.Val.Read(p) + // if err != nil { + // return err + // } + // log.Infof(ctx, "stateful dofn state read key: %v word: %v val: %v", key, word, val) + // if !ok { + // s.Val.Write(p, 1) + // } else { + // s.Val.Write(p, val+1) + // } + + // if val > 5 { + // log.Infof(ctx, "stateful dofn clearing key: %v word: %v val: %v", key, word, val) + // // Example of clearing and starting again with an empty bag + // s.Val.Clear(p) + // } + // fire := time.Now().Add(10 * time.Second) + + // log.Infof(ctx, "stateful dofn timer family: %v fire: %v now: %v key: %v word: %v", s.Fire.Family, fire, time.Now(), key, word) + // s.Fire.Set(tp, fire) + + // emit(key, word) + + return nil +} + +type eventtimeSDFStream struct { + RestSize, Mod, Fixed int64 + Sleep time.Duration +} + +func (fn *eventtimeSDFStream) Setup() error { + return nil +} + +func (fn *eventtimeSDFStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction { + return offsetrange.Restriction{Start: 0, End: fn.RestSize} +} + +func (fn *eventtimeSDFStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction { + // No split + return []offsetrange.Restriction{r} +} + +func (fn *eventtimeSDFStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 { + return r.Size() +} + +func (fn *eventtimeSDFStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker { + return sdf.NewLockRTracker(offsetrange.NewTracker(r)) +} + +func (fn *eventtimeSDFStream) ProcessElement(ctx context.Context, _ *CWE, rt *sdf.LockRTracker, v beam.T, emit func(beam.EventTime, int64)) sdf.ProcessContinuation { + r := rt.GetRestriction().(offsetrange.Restriction) + i := r.Start + if r.Size() < 1 { + log.Debugf(ctx, "size 0 restriction, stoping to process sentinel", slog.Any("value", v)) + return sdf.StopProcessing() + } + slog.Debug("emitting element to restriction", slog.Any("value", v), slog.Group("restriction", + slog.Any("value", v), + slog.Float64("size", r.Size()), + slog.Int64("pos", i), + )) + if rt.TryClaim(i) { + v := (i % fn.Mod) + fn.Fixed + emit(mtime.Now(), v) + } + return sdf.ResumeProcessingIn(fn.Sleep) +} + +func (fn *eventtimeSDFStream) InitialWatermarkEstimatorState(_ beam.EventTime, _ offsetrange.Restriction, _ beam.T) int64 { + return int64(mtime.MinTimestamp) +} + +func (fn *eventtimeSDFStream) CreateWatermarkEstimator(initialState int64) *CWE { + return &CWE{Watermark: initialState} +} + +func (fn *eventtimeSDFStream) WatermarkEstimatorState(e *CWE) int64 { + return e.Watermark +} + +type CWE struct { + Watermark int64 // uses int64, since the SDK prevent mtime.Time from serialization. +} + +func (e *CWE) CurrentWatermark() time.Time { + return mtime.Time(e.Watermark).ToTime() +} + +func (e *CWE) ObserveTimestamp(ts time.Time) { + // We add 10 milliseconds to allow window boundaries to + // progress after emitting + e.Watermark = int64(mtime.FromTime(ts.Add(-90 * time.Millisecond))) +} + +func init() { + register.DoFn7x1[context.Context, beam.EventTime, state.Provider, timers.Provider, string, string, func(string, string), error](&Stateful{}) + register.Emitter2[string, string]() + register.DoFn5x1[context.Context, *CWE, *sdf.LockRTracker, beam.T, func(beam.EventTime, int64), sdf.ProcessContinuation]((*eventtimeSDFStream)(nil)) + register.Emitter2[beam.EventTime, int64]() +} + func main() { flag.Parse() beam.Init() ctx := context.Background() - project := gcpopts.GetProject(ctx) + //project := gcpopts.GetProject(ctx) log.Infof(ctx, "Publishing %v messages to: %v", len(data), *input) - defer pubsubx.CleanupTopic(ctx, project, *input) - sub, err := pubsubx.Publish(ctx, project, *input, data...) - if err != nil { - log.Fatal(ctx, err) - } + // defer pubsubx.CleanupTopic(ctx, project, *input) + // sub, err := pubsubx.Publish(ctx, project, *input, data...) + // if err != nil { + // log.Fatal(ctx, err) + // } - log.Infof(ctx, "Running streaming wordcap with subscription: %v", sub.ID()) + //log.Infof(ctx, "Running streaming wordcap with subscription: %v", sub.ID()) p := beam.NewPipeline() s := p.Root() - col := pubsubio.Read(s, project, *input, &pubsubio.ReadOptions{Subscription: sub.ID()}) - str := beam.ParDo(s, func(b []byte) string { - return (string)(b) - }, col) - cap := beam.ParDo(s, strings.ToUpper, str) - debug.Print(s, cap) + //col := pubsubio.Read(s, project, *input, &pubsubio.ReadOptions{Subscription: sub.ID()}) + // col = beam.WindowInto(s, window.NewFixedWindows(60*time.Second), col) + + // str := beam.ParDo(s, func(b []byte) string { + // return (string)(b) + // }, col) + + imp := beam.Impulse(s) + elms := 100 + out := beam.ParDo(s, &eventtimeSDFStream{ + Sleep: time.Second, + RestSize: int64(elms), + Mod: int64(elms), + Fixed: 1, + }, imp) + // out = beam.WindowInto(s, window.NewFixedWindows(10*time.Second), out) + str := beam.ParDo(s, func(b int64) string { + return fmt.Sprintf("element%03d", b) + }, out) + + keyed := beam.ParDo(s, func(ctx context.Context, ts beam.EventTime, s string) (string, string) { + log.Infof(ctx, "adding key ts: %v now: %v word: %v", ts.ToTime(), time.Now(), s) + return "test", s + }, str) + debug.Printf(s, "pre stateful: %v", keyed) + + timed := beam.ParDo(s, NewStateful(), keyed) + debug.Printf(s, "post stateful: %v", timed) if err := beamx.Run(context.Background(), p); err != nil { log.Exitf(ctx, "Failed to execute job: %v", err) diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go index fdc1e368a52..9380bb8902b 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/data.go +++ b/sdks/go/pkg/beam/core/runtime/exec/data.go @@ -57,10 +57,12 @@ type SideCache interface { // DataManager manages external data byte streams. Each data stream can be // opened by one consumer only. type DataManager interface { - // OpenRead opens a closable byte stream for reading. - OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) - // OpenWrite opens a closable byte stream for writing. + // OpenElementChan opens a channel for data and timers. + OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) + // OpenWrite opens a closable byte stream for data writing. OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) + // OpenTimerWrite opens a byte stream for writing timers + OpenTimerWrite(ctx context.Context, id StreamID, family string) (io.WriteCloser, error) } // StateReader is the interface for reading side input data. @@ -91,4 +93,10 @@ type StateReader interface { GetSideInputCache() SideCache } -// TODO(herohde) 7/20/2018: user state management +// Elements holds data or timers sent across the data channel. +// If TimerFamilyID is populated, it's a timer, otherwise it's +// data elements. +type Elements struct { + Data, Timers []byte + TimerFamilyID string +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index a6347fc8d0e..12a5faeaf30 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -97,17 +97,60 @@ func (n *DataSource) StartBundle(ctx context.Context, id string, data DataContex n.source = data.Data n.state = data.State n.start = time.Now() - n.index = -1 + n.index = 0 n.splitIdx = math.MaxInt64 n.mu.Unlock() return n.Out.StartBundle(ctx, id, data) } +// process handles converting elements from the data source to timers. +func (n *DataSource) process(ctx context.Context, data func(bcr *byteCountReader) error, timer func(bcr *byteCountReader, timerFamilyID string) error) error { + elms, err := n.source.OpenElementChan(ctx, n.SID) + if err != nil { + return err + } + + n.PCol.resetSize() // initialize the size distribution for this bundle. + var r bytes.Reader + + var byteCount int + bcr := byteCountReader{reader: &r, count: &byteCount} + for { + var err error + select { + case e, ok := <-elms: + // Channel closed, so time to exit + if !ok { + return nil + } + if len(e.Data) > 0 { + r.Reset(e.Data) + log.Debugf(ctx, "%v: received %v", n, e.Data) + err = data(&bcr) + } + if len(e.Timers) > 0 { + r.Reset(e.Timers) + err = timer(&bcr, e.TimerFamilyID) + } + case <-ctx.Done(): + return nil + } + + if err != nil { + if err != io.EOF { + return errors.Wrap(err, "source failed") + } + // io.EOF means the reader successfully drained + // We're ready for a new buffer. + } + } +} + // ByteCountReader is a passthrough reader that counts all the bytes read through it. // It trusts the nested reader to return accurate byte information. type byteCountReader struct { count *int - reader io.ReadCloser + reader io.Reader } func (r *byteCountReader) Read(p []byte) (int, error) { @@ -117,7 +160,10 @@ func (r *byteCountReader) Read(p []byte) (int, error) { } func (r *byteCountReader) Close() error { - return r.reader.Close() + if c, ok := r.reader.(io.Closer); ok { + c.Close() + } + return nil } func (r *byteCountReader) reset() int { @@ -128,15 +174,6 @@ func (r *byteCountReader) reset() int { // Process opens the data source, reads and decodes data, kicking off element processing. func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { - r, err := n.source.OpenRead(ctx, n.SID) - if err != nil { - return nil, err - } - defer r.Close() - n.PCol.resetSize() // initialize the size distribution for this bundle. - var byteCount int - bcr := byteCountReader{reader: r, count: &byteCount} - c := coder.SkipW(n.Coder) wc := MakeWindowDecoder(n.Coder.Window) @@ -155,58 +192,68 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { } var checkpoints []*Checkpoint - for { - if n.incrementIndexAndCheckSplit() { - break - } - // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? - ws, t, pn, err := DecodeWindowedValueHeader(wc, r) - if err != nil { - if err == io.EOF { - break + err := n.process(ctx, func(bcr *byteCountReader) error { + for { + // TODO(lostluck) 2020/02/22: Should we include window headers or just count the element sizes? + ws, t, pn, err := DecodeWindowedValueHeader(wc, bcr.reader) + if err != nil { + return err } - return nil, errors.Wrap(err, "source failed") - } - - // Decode key or parallel element. - pe, err := cp.Decode(&bcr) - if err != nil { - return nil, errors.Wrap(err, "source decode failed") - } - pe.Timestamp = t - pe.Windows = ws - pe.Pane = pn - var valReStreams []ReStream - for _, cv := range cvs { - values, err := n.makeReStream(ctx, cv, &bcr, len(cvs) == 1 && n.singleIterate) + // Decode key or parallel element. + pe, err := cp.Decode(bcr) if err != nil { - return nil, err + return errors.Wrap(err, "source decode failed") } - valReStreams = append(valReStreams, values) - } + pe.Timestamp = t + pe.Windows = ws + pe.Pane = pn - if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { - return nil, err - } - // Collect the actual size of the element, and reset the bytecounter reader. - n.PCol.addSize(int64(bcr.reset())) - bcr.reader = r - - // Check if there's a continuation and return residuals - // Needs to be done immeadiately after processing to not lose the element. - if c := n.getProcessContinuation(); c != nil { - cp, err := n.checkpointThis(ctx, c) - if err != nil { - // Errors during checkpointing should fail a bundle. - return nil, err + log.Debugf(ctx, "%v: processing %v,%v", n, pe.Elm, pe.Elm2) + + var valReStreams []ReStream + for _, cv := range cvs { + values, err := n.makeReStream(ctx, cv, bcr, len(cvs) == 1 && n.singleIterate) + if err != nil { + return err + } + valReStreams = append(valReStreams, values) + } + + if err := n.Out.ProcessElement(ctx, pe, valReStreams...); err != nil { + return err + } + // Collect the actual size of the element, and reset the bytecounter reader. + n.PCol.addSize(int64(bcr.reset())) + + // Check if there's a continuation and return residuals + // Needs to be done immeadiately after processing to not lose the element. + if c := n.getProcessContinuation(); c != nil { + cp, err := n.checkpointThis(ctx, c) + if err != nil { + // Errors during checkpointing should fail a bundle. + return err + } + if cp != nil { + checkpoints = append(checkpoints, cp) + } } - if cp != nil { - checkpoints = append(checkpoints, cp) + // We've finished processing an element, check if we have finished a split. + if n.incrementIndexAndCheckSplit() { + break } } - } - return checkpoints, nil + // Signal data loop exit. + log.Debugf(ctx, "%v: exiting data loop", n) + return nil + }, + func(bcr *byteCountReader, timerFamilyID string) error { + tmap, err := decodeTimer(cp, wc, bcr) + log.Errorf(ctx, "timer received: %v - %+v err: %v", timerFamilyID, tmap, err) + return nil + }) + + return checkpoints, err } func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *byteCountReader, onlyStream bool) (ReStream, error) { @@ -313,7 +360,7 @@ func (n *DataSource) makeReStream(ctx context.Context, cv ElementDecoder, bcr *b } } -func readStreamToBuffer(cv ElementDecoder, r io.ReadCloser, size int64, buf []FullValue) ([]FullValue, error) { +func readStreamToBuffer(cv ElementDecoder, r io.Reader, size int64, buf []FullValue) ([]FullValue, error) { for i := int64(0); i < size; i++ { value, err := cv.Decode(r) if err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go index 2da3284f016..14e954d26af 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go @@ -1018,16 +1018,21 @@ func runOnRoots(ctx context.Context, t *testing.T, p *Plan, name string, mthd fu type TestDataManager struct { R io.ReadCloser + C chan Elements } -func (dm *TestDataManager) OpenRead(ctx context.Context, id StreamID) (io.ReadCloser, error) { - return dm.R, nil +func (dm *TestDataManager) OpenElementChan(ctx context.Context, id StreamID) (<-chan Elements, error) { + return dm.C, nil } func (dm *TestDataManager) OpenWrite(ctx context.Context, id StreamID) (io.WriteCloser, error) { return nil, nil } +func (dm *TestDataManager) OpenTimerWrite(ctx context.Context, id StreamID, key string) (io.WriteCloser, error) { + return nil, nil +} + // TestSideInputReader simulates state reads using channels. type TestStateReader struct { StateReader diff --git a/sdks/go/pkg/beam/core/runtime/exec/fn.go b/sdks/go/pkg/beam/core/runtime/exec/fn.go index d0fdb8e3630..c108627a52c 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/fn.go +++ b/sdks/go/pkg/beam/core/runtime/exec/fn.go @@ -28,7 +28,6 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" - "github.com/apache/beam/sdks/v2/go/pkg/beam/log" ) //go:generate specialize --input=fn_arity.tmpl @@ -243,11 +242,10 @@ func (n *invoker) invokeWithOpts(ctx context.Context, pn typex.PaneInfo, ws []ty } if n.tpIdx >= 0 { - log.Debugf(ctx, "timercall %+v", opts) - tp, err := opts.ta.NewTimerProvider(ctx, opts.tm, ws, opts.opt) + tp, err := opts.ta.NewTimerProvider(ctx, opts.tm, ts, ws, opts.opt) if err != nil { return nil, err - } + } /* */ n.tp = &tp args[n.tpIdx] = n.tp } diff --git a/sdks/go/pkg/beam/core/runtime/exec/timers.go b/sdks/go/pkg/beam/core/runtime/exec/timers.go index 0ceed0d2ebd..9e1d52d31a1 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/timers.go +++ b/sdks/go/pkg/beam/core/runtime/exec/timers.go @@ -23,10 +23,11 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/timers" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" ) type UserTimerAdapter interface { - NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) + NewTimerProvider(ctx context.Context, manager DataManager, inputTimestamp typex.EventTime, windows []typex.Window, element *MainInput) (timerProvider, error) } type userTimerAdapter struct { @@ -51,7 +52,7 @@ func NewUserTimerAdapter(sID StreamID, c *coder.Coder, timerCoders map[string]*c return &userTimerAdapter{SID: sID, wc: wc, kc: kc, C: c, timerIDToCoder: timerCoders} } -func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataManager, w []typex.Window, element *MainInput) (timerProvider, error) { +func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataManager, inputTs typex.EventTime, w []typex.Window, element *MainInput) (timerProvider, error) { if u.kc == nil { return timerProvider{}, fmt.Errorf("cannot make a state provider for an unkeyed input %v", element) } @@ -68,6 +69,7 @@ func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataMan ctx: ctx, tm: manager, elementKey: elementKey, + inputTimestamp: inputTs, SID: u.SID, window: w, writersByFamily: make(map[string]io.Writer), @@ -78,11 +80,12 @@ func (u *userTimerAdapter) NewTimerProvider(ctx context.Context, manager DataMan } type timerProvider struct { - ctx context.Context - tm DataManager - SID StreamID - elementKey []byte - window []typex.Window + ctx context.Context + tm DataManager + SID StreamID + inputTimestamp typex.EventTime + elementKey []byte + window []typex.Window pn typex.PaneInfo @@ -117,6 +120,7 @@ func (p *timerProvider) Set(t timers.TimerMap) { HoldTimestamp: t.HoldTimestamp, Pane: p.pn, } + log.Debugf(p.ctx, "timer set: %+v", tm) fv := FullValue{Elm: tm} enc := MakeElementEncoder(coder.SkipW(p.codersByFamily[t.Family])) if err := enc.Encode(&fv, w); err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go index 5a9b536b288..f15e71401eb 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr.go @@ -17,6 +17,7 @@ package harness import ( "context" + "fmt" "io" "sync" "time" @@ -47,22 +48,31 @@ func NewScopedDataManager(mgr *DataChannelManager, instID instructionID) *Scoped return &ScopedDataManager{mgr: mgr, instID: instID} } -// OpenRead opens an io.ReadCloser on the given stream. -func (s *ScopedDataManager) OpenRead(ctx context.Context, id exec.StreamID) (io.ReadCloser, error) { +// OpenWrite opens an io.WriteCloser on the given stream. +func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenRead(ctx, id.PtransformID, s.instID), nil + return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil } -// OpenWrite opens an io.WriteCloser on the given stream. -func (s *ScopedDataManager) OpenWrite(ctx context.Context, id exec.StreamID) (io.WriteCloser, error) { +// OpenElementChan returns a channel of exec.Elements on the given stream. +func (s *ScopedDataManager) OpenElementChan(ctx context.Context, id exec.StreamID) (<-chan exec.Elements, error) { ch, err := s.open(ctx, id.Port) if err != nil { return nil, err } - return ch.OpenWrite(ctx, id.PtransformID, s.instID), nil + return ch.OpenElementChan(ctx, id.PtransformID, s.instID), nil +} + +// OpenTimerWrite opens an io.WriteCloser on the given stream to write timers +func (s *ScopedDataManager) OpenTimerWrite(ctx context.Context, id exec.StreamID, family string) (io.WriteCloser, error) { + ch, err := s.open(ctx, id.Port) + if err != nil { + return nil, err + } + return ch.OpenTimerWrite(ctx, id.PtransformID, s.instID, family), nil } func (s *ScopedDataManager) open(ctx context.Context, port exec.Port) (*DataChannel, error) { @@ -134,8 +144,9 @@ func (m *DataChannelManager) closeInstruction(instID instructionID) { // clientID identifies a client of a connected channel. type clientID struct { - ptransformID string - instID instructionID + ptransformID string + instID instructionID + timerFamilyID string } // This is a reduced version of the full gRPC interface to help with testing. @@ -155,8 +166,9 @@ type DataChannel struct { id string client dataClient - writers map[instructionID]map[string]*dataWriter - readers map[instructionID]map[string]*dataReader + writers map[instructionID]map[string]*dataWriter + timerWriters map[instructionID]map[string]*timerWriter + channels map[instructionID]map[string]*elementsChan // recently terminated instructions endedInstructions map[instructionID]struct{} @@ -172,6 +184,19 @@ type DataChannel struct { mu sync.Mutex // guards mutable internal data, notably the maps and readErr. } +type elementsChan struct { + ch chan exec.Elements + complete bool +} + +func (ec *elementsChan) Close() error { + if !ec.complete { + ec.complete = true + close(ec.ch) + } + return nil +} + func newDataChannel(ctx context.Context, port exec.Port) (*DataChannel, error) { ctx, cancelFn := context.WithCancel(ctx) cc, err := dial(ctx, port.URL, "data", 15*time.Second) @@ -196,7 +221,8 @@ func makeDataChannel(ctx context.Context, id string, client dataClient, cancelFn id: id, client: client, writers: make(map[instructionID]map[string]*dataWriter), - readers: make(map[instructionID]map[string]*dataReader), + timerWriters: make(map[instructionID]map[string]*timerWriter), + channels: make(map[instructionID]map[string]*elementsChan), endedInstructions: make(map[instructionID]struct{}), cancelFn: cancelFn, } @@ -214,25 +240,56 @@ func (c *DataChannel) terminateStreamOnError(err error) { } } -// OpenRead returns an io.ReadCloser of the data elements for the given instruction and ptransform. -func (c *DataChannel) OpenRead(ctx context.Context, ptransformID string, instID instructionID) io.ReadCloser { +// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. +func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { + return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +} + +// OpenElementChan returns a channel of typex.Elements for the given instruction and ptransform. +func (c *DataChannel) OpenElementChan(ctx context.Context, ptransformID string, instID instructionID) <-chan exec.Elements { c.mu.Lock() defer c.mu.Unlock() cid := clientID{ptransformID: ptransformID, instID: instID} if c.readErr != nil { - log.Errorf(ctx, "opening a reader %v on a closed channel", cid) - return &errReader{c.readErr} + panic(fmt.Errorf("opening a reader %v on a closed channel", cid)) } - return c.makeReader(ctx, cid) + return c.makeChannel(ctx, cid).ch } -// OpenWrite returns an io.WriteCloser of the data elements for the given instruction and ptransform. -func (c *DataChannel) OpenWrite(ctx context.Context, ptransformID string, instID instructionID) io.WriteCloser { - return c.makeWriter(ctx, clientID{ptransformID: ptransformID, instID: instID}) +// makeChannel creates a channel of exec.Elements. It expects to be called while c.mu is held. +func (c *DataChannel) makeChannel(ctx context.Context, id clientID) *elementsChan { + var m map[string]*elementsChan + var ok bool + if m, ok = c.channels[id.instID]; !ok { + m = make(map[string]*elementsChan) + c.channels[id.instID] = m + } + + if r, ok := m[id.ptransformID]; ok { + return r + } + + r := &elementsChan{ch: make(chan exec.Elements, 20)} + // Just in case initial data for an instruction arrives *after* an instructon has ended. + // eg. it was blocked by another reader being slow, or the other instruction failed. + // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. + if _, ok := c.endedInstructions[id.instID]; ok { + close(r.ch) + r.complete = true + return r + } + + m[id.ptransformID] = r + return r +} + +// OpenTimerWrite returns io.WriteCloser for the given timerFamilyID, instruction and ptransform. +func (c *DataChannel) OpenTimerWrite(ctx context.Context, ptransformID string, instID instructionID, family string) io.WriteCloser { + return c.makeTimerWriter(ctx, clientID{timerFamilyID: family, ptransformID: ptransformID, instID: instID}) } func (c *DataChannel) read(ctx context.Context) { - cache := make(map[clientID]*dataReader) + cache := make(map[clientID]*elementsChan) for { msg, err := c.client.Recv() if err != nil { @@ -243,15 +300,11 @@ func (c *DataChannel) read(ctx context.Context) { // close the r.buf channels twice, or send on a closed channel. // Any other approach is racy, and may cause one of the above // panics. - for _, m := range c.readers { - for _, r := range m { - log.Errorf(ctx, "DataChannel.read %v reader %v closing due to error on channel", c.id, r.id) - if !r.completed { - r.completed = true - r.err = err - close(r.buf) - } - delete(cache, r.id) + for instID, m := range c.channels { + for tid, r := range m { + log.Errorf(ctx, "DataChannel.read %v channel inst: %v tid %v closing due to error on channel", c.id, instID, tid) + r.Close() + delete(cache, clientID{ptransformID: tid, instID: instID}) } } c.terminateStreamOnError(err) @@ -274,31 +327,28 @@ func (c *DataChannel) read(ctx context.Context) { for _, elm := range msg.GetData() { id := clientID{ptransformID: elm.TransformId, instID: instructionID(elm.GetInstructionId())} - var r *dataReader + var r *elementsChan if local, ok := cache[id]; ok { r = local } else { c.mu.Lock() - r = c.makeReader(ctx, id) + r = c.makeChannel(ctx, id) c.mu.Unlock() cache[id] = r } + // This send is deliberately blocking, if we exceed the buffering for + // a reader. We can't buffer the entire main input, if some user code + // is slow (or gets stuck). If the local side closes, the reader + // will be marked as completed and further remote data will be ignored. + select { + case r.ch <- exec.Elements{Data: elm.GetData()}: + case <-ctx.Done(): + // Technically, we need to close all the things here... to start. + r.Close() + } if elm.GetIsLast() { - // If this reader hasn't closed yet, do so now. - if !r.completed { - // Use the last segment if any. - if len(elm.GetData()) != 0 { - // In case of local side closing, send with select. - select { - case r.buf <- elm.GetData(): - case <-r.done: - } - } - // Close buffer to signal EOF. - r.completed = true - close(r.buf) - } + r.Close() // Clean up local bookkeeping. We'll never see another message // for it again. We have to be careful not to remove the real @@ -307,12 +357,32 @@ func (c *DataChannel) read(ctx context.Context) { delete(cache, id) continue } + } + for _, tim := range msg.GetTimers() { + id := clientID{ + ptransformID: tim.TransformId, + instID: instructionID(tim.GetInstructionId()), + // timerFamilyID: tim.GetTimerFamilyId(), + } + log.Infof(ctx, "timer received for %v, %v: %v", id, tim.GetTimerFamilyId(), tim.GetTimers()) + var r *elementsChan + if local, ok := cache[id]; ok { + r = local + } else { + c.mu.Lock() + r = c.makeChannel(ctx, id) + c.mu.Unlock() + cache[id] = r + } + if tim.GetIsLast() { + // If this reader hasn't closed yet, do so now. + r.Close() - if r.completed { - // The local reader has closed but the remote is still sending data. - // Just ignore it. We keep the reader config in the cache so we don't - // treat it as a new reader. Eventually the stream will finish and go - // through normal teardown. + // Clean up local bookkeeping. We'll never see another message + // for it again. We have to be careful not to remove the real + // one, because readers may be initialized after we've seen + // the full stream. + delete(cache, id) continue } @@ -321,64 +391,15 @@ func (c *DataChannel) read(ctx context.Context) { // is slow (or gets stuck). If the local side closes, the reader // will be marked as completed and further remote data will be ignored. select { - case r.buf <- elm.GetData(): - case <-r.done: - r.completed = true - close(r.buf) + case r.ch <- exec.Elements{Timers: tim.GetTimers(), TimerFamilyID: tim.GetTimerFamilyId()}: + case <-ctx.Done(): + // Technically, we need to close all the things here... to start. + r.Close() } } } } -type errReader struct { - err error -} - -func (r *errReader) Read(_ []byte) (int, error) { - return 0, r.err -} - -func (r *errReader) Close() error { - return r.err -} - -// makeReader creates a dataReader. It expects to be called while c.mu is held. -func (c *DataChannel) makeReader(ctx context.Context, id clientID) *dataReader { - var m map[string]*dataReader - var ok bool - if m, ok = c.readers[id.instID]; !ok { - m = make(map[string]*dataReader) - c.readers[id.instID] = m - } - - if r, ok := m[id.ptransformID]; ok { - return r - } - - r := &dataReader{id: id, buf: make(chan []byte, bufElements), done: make(chan bool, 1), channel: c} - - // Just in case initial data for an instruction arrives *after* an instructon has ended. - // eg. it was blocked by another reader being slow, or the other instruction failed. - // So we provide a pre-completed reader, and do not cache it, as there's no further cleanup for it. - if _, ok := c.endedInstructions[id.instID]; ok { - r.completed = true - close(r.buf) - r.err = io.EOF // In case of any actual data readers, so they terminate without error. - return r - } - - m[id.ptransformID] = r - return r -} - -func (c *DataChannel) removeReader(id clientID) { - c.mu.Lock() - if m, ok := c.readers[id.instID]; ok { - delete(m, id.ptransformID) - } - c.mu.Unlock() -} - const endedInstructionCap = 32 // removeInstruction closes all readers and writers registered for the instruction @@ -395,21 +416,25 @@ func (c *DataChannel) removeInstruction(instID instructionID) { c.endedInstructions[instID] = struct{}{} c.rmQueue = append(c.rmQueue, instID) - rs := c.readers[instID] ws := c.writers[instID] + tws := c.timerWriters[instID] + ecs := c.channels[instID] // Prevent other users while we iterate. - delete(c.readers, instID) delete(c.writers, instID) + delete(c.timerWriters, instID) + delete(c.channels, instID) c.mu.Unlock() - // Close grabs the channel lock, so this must be outside the critical section. - for _, r := range rs { - r.Close() - } for _, w := range ws { w.Close() } + for _, tw := range tws { + tw.Close() + } + for _, ec := range ecs { + ec.Close() + } } func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { @@ -423,7 +448,7 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { c.writers[id.instID] = m } - if w, ok := m[id.ptransformID]; ok { + if w, ok := m[makeID(id)]; ok { return w } @@ -432,50 +457,40 @@ func (c *DataChannel) makeWriter(ctx context.Context, id clientID) *dataWriter { // runner or user directed. w := &dataWriter{ch: c, id: id} - m[id.ptransformID] = w + m[makeID(id)] = w return w } -type dataReader struct { - id clientID - buf chan []byte - done chan bool - cur []byte - channel *DataChannel - completed bool - err error -} +func (c *DataChannel) makeTimerWriter(ctx context.Context, id clientID) *timerWriter { + c.mu.Lock() + defer c.mu.Unlock() -func (r *dataReader) Close() error { - r.done <- true - r.channel.removeReader(r.id) - return nil -} + var m map[string]*timerWriter + var ok bool + if m, ok = c.timerWriters[id.instID]; !ok { + m = make(map[string]*timerWriter) + c.timerWriters[id.instID] = m + } -func (r *dataReader) Read(buf []byte) (int, error) { - if r.cur == nil { - b, ok := <-r.buf - if !ok { - if r.err == nil { - return 0, io.EOF - } - return 0, r.err - } - r.cur = b + if w, ok := m[makeID(id)]; ok { + return w } - // We don't need to check for a 0 length copy from r.cur here, since that's - // checked before buffers are handed to the r.buf channel. - n := copy(buf, r.cur) + // We don't check for ended instructions for writers, as writers + // can only be created if an instruction is in scope, and aren't + // runner or user directed. - switch { - case len(r.cur) == n: - r.cur = nil - default: - r.cur = r.cur[n:] - } + w := &timerWriter{ch: c, id: id} + m[makeID(id)] = w + return w +} - return n, nil +func makeID(id clientID) string { + newID := id.ptransformID + if id.timerFamilyID != "" { + newID += ":" + id.timerFamilyID + } + return newID } type dataWriter struct { @@ -574,3 +589,72 @@ func (w *dataWriter) Write(p []byte) (n int, err error) { w.buf = append(w.buf, p...) return len(p), nil } + +type timerWriter struct { + id clientID + ch *DataChannel +} + +// send requires the ch.mu lock to be held. +func (w *timerWriter) send(msg *fnpb.Elements) error { + recordStreamSend(msg) + if err := w.ch.client.Send(msg); err != nil { + if err == io.EOF { + log.Warnf(context.TODO(), "dataWriter[%v;%v] EOF on send; fetching real error", w.id, w.ch.id) + err = nil + for err == nil { + // Per GRPC stream documentation, if there's an EOF, we must call Recv + // until a non-nil error is returned, to ensure resources are cleaned up. + // https://pkg.go.dev/google.golang.org/grpc#ClientConn.NewStream + _, err = w.ch.client.Recv() + } + } + log.Warnf(context.TODO(), "dataWriter[%v;%v] error on send: %v", w.id, w.ch.id, err) + w.ch.terminateStreamOnError(err) + return err + } + return nil +} + +func (w *timerWriter) Close() error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + delete(w.ch.timerWriters[w.id.instID], makeID(w.id)) + var msg *fnpb.Elements + msg = &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.id.timerFamilyID, + IsLast: true, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) writeTimers(p []byte) error { + w.ch.mu.Lock() + defer w.ch.mu.Unlock() + + msg := &fnpb.Elements{ + Timers: []*fnpb.Elements_Timers{ + { + InstructionId: string(w.id.instID), + TransformId: w.id.ptransformID, + TimerFamilyId: w.id.timerFamilyID, + Timers: p, + }, + }, + } + return w.send(msg) +} + +func (w *timerWriter) Write(p []byte) (n int, err error) { + // write timers directly without buffering. + if err := w.writeTimers(p); err != nil { + return 0, err + } + return len(p), nil +} diff --git a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go index f69d9abde49..3f77d69f173 100644 --- a/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go +++ b/sdks/go/pkg/beam/core/runtime/harness/datamgr_test.go @@ -101,147 +101,6 @@ func (f *fakeDataClient) Send(*fnpb.Elements) error { return nil } -func TestDataChannelTerminate_dataReader(t *testing.T) { - // The logging of channels closed is quite noisy for this test - log.SetOutput(io.Discard) - - expectedError := fmt.Errorf("EXPECTED ERROR") - - tests := []struct { - name string - expectedError error - caseFn func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) - }{ - { - name: "onClose", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // We don't read up all the buffered data, but immediately close the reader. - // Previously, since nothing was consuming the incoming gRPC data, the whole - // data channel would get stuck, and the client.Recv() call was eventually - // no longer called. - r.Close() - - // If done is signaled, that means client.Recv() has been called to flush the - // channel, meaning consumer code isn't stuck. - <-client.done - }, - }, { - name: "onSentinel", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // fakeDataClient eventually returns a sentinel element. - }, - }, { - name: "onIsLast_withData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the last call with data to use is_last. - client.isLastCall = 2 - }, - }, { - name: "onIsLast_withoutData", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // Set the call without data to use is_last. - client.isLastCall = 3 - }, - }, { - name: "onRecvError", - expectedError: expectedError, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - // The SDK starts reading in a goroutine immeadiately after open. - // Set the 2nd Recv call to have an error. - client.err = expectedError - }, - }, { - name: "onInstructionEnd", - expectedError: io.EOF, - caseFn: func(t *testing.T, r io.ReadCloser, client *fakeDataClient, c *DataChannel) { - c.removeInstruction("inst_ref") - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - - r := c.OpenRead(ctx, "ptr", "inst_ref") - - n, err := r.Read(make([]byte, 4)) - if err != nil { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) - } - test.caseFn(t, r, client, c) - // Drain the reader. - i := 1 // For the earlier Read. - for err == nil { - read := make([]byte, 4) - _, err = r.Read(read) - i++ - } - - if got, want := err, test.expectedError; got != want { - t.Errorf("Unexpected error from read %d: got %v, want %v", i, got, want) - } - // Verify that new readers return the same error on their reads after client.Recv is done. - if n, err := c.OpenRead(ctx, "ptr", "inst_ref").Read(make([]byte, 4)); err != test.expectedError { - t.Errorf("Unexpected error from read: got %v, want, %v read %d bytes.", err, test.expectedError, n) - } - - select { - case <-ctx.Done(): // Assert that the context must have been cancelled on read failures. - return - case <-time.After(time.Second * 5): - t.Fatal("context wasn't cancelled") - } - }) - } -} - -func TestDataChannelRemoveInstruction_dataAfterClose(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - client.blocked.Lock() - - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - c.removeInstruction("inst_ref") - - client.blocked.Unlock() - - r := c.OpenRead(ctx, "ptr", "inst_ref") - - dr := r.(*dataReader) - if !dr.completed || dr.err != io.EOF { - t.Errorf("Expected a closed reader, but was still open: completed: %v, err: %v", dr.completed, dr.err) - } - - n, err := r.Read(make([]byte, 4)) - if err != io.EOF { - t.Errorf("Unexpected error from read: %v, read %d bytes.", err, n) - } -} - -func TestDataChannelRemoveInstruction_limitInstructionCap(t *testing.T) { - done := make(chan bool, 1) - client := &fakeDataClient{t: t, done: done} - ctx, cancelFn := context.WithCancel(context.Background()) - c := makeDataChannel(ctx, "id", client, cancelFn) - - for i := 0; i < endedInstructionCap+10; i++ { - instID := instructionID(fmt.Sprintf("inst_ref%d", i)) - c.OpenRead(ctx, "ptr", instID) - c.removeInstruction(instID) - } - if got, want := len(c.endedInstructions), endedInstructionCap; got != want { - t.Errorf("unexpected len(endedInstructions) got %v, want %v,", got, want) - } -} - func TestDataChannelTerminate_Writes(t *testing.T) { // The logging of channels closed is quite noisy for this test log.SetOutput(io.Discard) diff --git a/sdks/go/pkg/beam/core/timers/timers.go b/sdks/go/pkg/beam/core/timers/timers.go index 130564790ca..afb5ddd98b8 100644 --- a/sdks/go/pkg/beam/core/timers/timers.go +++ b/sdks/go/pkg/beam/core/timers/timers.go @@ -102,7 +102,7 @@ func (t ProcessingTime) Set(p Provider, firingTimestamp time.Time) { func (t ProcessingTime) SetWithOpts(p Provider, firingTimestamp time.Time, opts Opts) { fire := mtime.FromTime(firingTimestamp) - // Hold timestamp must match fireing timestamp if not otherwise set. + // Hold timestamp must match input element timestamp if not otherwise set. tm := TimerMap{Family: t.Family, Tag: opts.Tag, FireTimestamp: fire, HoldTimestamp: fire} if !opts.Hold.IsZero() { tm.HoldTimestamp = mtime.FromTime(opts.Hold)
