This is an automated email from the ASF dual-hosted git repository.
shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 990b5ffb30b [Prism] Support injecting triggered bundle for a batch of
elements. (#36219)
990b5ffb30b is described below
commit 990b5ffb30bc754390849ce5aaab638c08b255f3
Author: Shunping Huang <[email protected]>
AuthorDate: Mon Sep 22 18:56:06 2025 -0400
[Prism] Support injecting triggered bundle for a batch of elements. (#36219)
* Support injecting triggered bundle for a batch of elements.
* Override streaming mode if there is an unbounded pcollection.
* Refactor some code.
* Enable prism on failed pipelines and rebench.
* Add tests for streaming and batch mode on data trigger for prism.
* Revert "Enable prism on failed pipelines and rebench."
This reverts commit bc648d5d40db86c672a107358c64018bcec351c7.
* Fix the newly added tests.
---
.../prism/internal/engine/elementmanager.go | 95 ++++++++++++++--------
sdks/go/pkg/beam/runners/prism/internal/execute.go | 13 +++
.../runners/portability/prism_runner_test.py | 70 +++++++++++++++-
3 files changed, 143 insertions(+), 35 deletions(-)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index cba4774dd3f..0ef7ed4ea44 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -184,6 +184,8 @@ type Config struct {
MaxBundleSize int
// Whether to use real-time clock as processing time
EnableRTC bool
+ // Whether to process the data in a streaming mode
+ StreamingMode bool
}
// ElementManager handles elements, watermarks, and related errata to determine
@@ -1296,6 +1298,43 @@ func (ss *stageState) AddPending(em *ElementManager,
newPending []element) int {
return ss.kind.addPending(ss, em, newPending)
}
+func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window
typex.Window, key string) int {
+ // Check on triggers for this key.
+ // We use an empty linkID as the key into state for aggregations.
+ count := 0
+ if ss.state == nil {
+ ss.state =
make(map[LinkID]map[typex.Window]map[string]StateData)
+ }
+ lv, ok := ss.state[LinkID{}]
+ if !ok {
+ lv = make(map[typex.Window]map[string]StateData)
+ ss.state[LinkID{}] = lv
+ }
+ wv, ok := lv[window]
+ if !ok {
+ wv = make(map[string]StateData)
+ lv[window] = wv
+ }
+ state := wv[key]
+ endOfWindowReached := window.MaxTimestamp() < ss.input
+ ready := ss.strat.IsTriggerReady(triggerInput{
+ newElementCount: 1,
+ endOfWindowReached: endOfWindowReached,
+ }, &state)
+
+ if ready {
+ state.Pane = computeNextTriggeredPane(state.Pane,
endOfWindowReached)
+ }
+ // Store the state as triggers may have changed it.
+ ss.state[LinkID{}][window][key] = state
+
+ // If we're ready, it's time to fire!
+ if ready {
+ count += ss.buildTriggeredBundle(em, key, window)
+ }
+ return count
+}
+
// addPending for aggregate stages behaves like stateful stages, but doesn't
need to handle timers or a separate window
// expiration condition.
func (*aggregateStageKind) addPending(ss *stageState, em *ElementManager,
newPending []element) int {
@@ -1315,6 +1354,13 @@ func (*aggregateStageKind) addPending(ss *stageState, em
*ElementManager, newPen
if ss.pendingByKeys == nil {
ss.pendingByKeys = map[string]*dataAndTimers{}
}
+
+ type windowKey struct {
+ window typex.Window
+ key string
+ }
+ pendingWindowKeys := set[windowKey]{}
+
count := 0
for _, e := range newPending {
count++
@@ -1327,37 +1373,18 @@ func (*aggregateStageKind) addPending(ss *stageState,
em *ElementManager, newPen
ss.pendingByKeys[string(e.keyBytes)] = dnt
}
heap.Push(&dnt.elements, e)
- // Check on triggers for this key.
- // We use an empty linkID as the key into state for
aggregations.
- if ss.state == nil {
- ss.state =
make(map[LinkID]map[typex.Window]map[string]StateData)
- }
- lv, ok := ss.state[LinkID{}]
- if !ok {
- lv = make(map[typex.Window]map[string]StateData)
- ss.state[LinkID{}] = lv
- }
- wv, ok := lv[e.window]
- if !ok {
- wv = make(map[string]StateData)
- lv[e.window] = wv
- }
- state := wv[string(e.keyBytes)]
- endOfWindowReached := e.window.MaxTimestamp() < ss.input
- ready := ss.strat.IsTriggerReady(triggerInput{
- newElementCount: 1,
- endOfWindowReached: endOfWindowReached,
- }, &state)
- if ready {
- state.Pane = computeNextTriggeredPane(state.Pane,
endOfWindowReached)
+ if em.config.StreamingMode {
+ // In streaming mode, we check trigger readiness on
each element
+ count += ss.injectTriggeredBundlesIfReady(em, e.window,
string(e.keyBytes))
+ } else {
+ // In batch mode, we store key + window pairs here and
check trigger readiness for each of them later.
+ pendingWindowKeys.insert(windowKey{window: e.window,
key: string(e.keyBytes)})
}
- // Store the state as triggers may have changed it.
- ss.state[LinkID{}][e.window][string(e.keyBytes)] = state
-
- // If we're ready, it's time to fire!
- if ready {
- count += ss.buildTriggeredBundle(em, e.keyBytes,
e.window)
+ }
+ if !em.config.StreamingMode {
+ for wk := range pendingWindowKeys {
+ count += ss.injectTriggeredBundlesIfReady(em,
wk.window, wk.key)
}
}
return count
@@ -1493,9 +1520,9 @@ func (ss *stageState) savePanes(bundID string,
panesInBundle []bundlePane) {
// buildTriggeredBundle must be called with the stage.mu lock held.
// When in discarding mode, returns 0.
// When in accumulating mode, returns the number of fired elements to maintain
a correct pending count.
-func (ss *stageState) buildTriggeredBundle(em *ElementManager, key []byte, win
typex.Window) int {
+func (ss *stageState) buildTriggeredBundle(em *ElementManager, key string, win
typex.Window) int {
var toProcess []element
- dnt := ss.pendingByKeys[string(key)]
+ dnt := ss.pendingByKeys[key]
var notYet []element
rb := RunBundle{StageID: ss.ID, BundleID: "agg-" + em.nextBundID(),
Watermark: ss.input}
@@ -1524,7 +1551,7 @@ func (ss *stageState) buildTriggeredBundle(em
*ElementManager, key []byte, win t
}
dnt.elements = append(dnt.elements, notYet...)
if dnt.elements.Len() == 0 {
- delete(ss.pendingByKeys, string(key))
+ delete(ss.pendingByKeys, key)
} else {
// Ensure the heap invariants are maintained.
heap.Init(&dnt.elements)
@@ -1537,7 +1564,7 @@ func (ss *stageState) buildTriggeredBundle(em
*ElementManager, key []byte, win t
{
win: win,
key: string(key),
- pane: ss.state[LinkID{}][win][string(key)].Pane,
+ pane: ss.state[LinkID{}][win][key].Pane,
},
}
@@ -1545,7 +1572,7 @@ func (ss *stageState) buildTriggeredBundle(em
*ElementManager, key []byte, win t
func() string { return rb.BundleID },
toProcess,
ss.input,
- singleSet(string(key)),
+ singleSet(key),
nil,
panesInBundle,
)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 772c3a9ebb8..d0daa991fd2 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -153,6 +153,7 @@ func executePipeline(ctx context.Context, wks
map[string]*worker.W, j *jobservic
topo := prepro.preProcessGraph(comps, j)
ts := comps.GetTransforms()
+ pcols := comps.GetPcollections()
config := engine.Config{}
m := j.PipelineOptions().AsMap()
@@ -167,6 +168,18 @@ func executePipeline(ctx context.Context, wks
map[string]*worker.W, j *jobservic
}
}
+ if streaming, ok := m["beam:option:streaming:v1"].(bool); ok {
+ config.StreamingMode = streaming
+ }
+
+ // Set StreamingMode to true if there is any unbounded PCollection.
+ for _, pcoll := range pcols {
+ if pcoll.GetIsBounded() == pipepb.IsBounded_UNBOUNDED {
+ config.StreamingMode = true
+ break
+ }
+ }
+
em := engine.NewElementManager(config)
// TODO move this loop and code into the preprocessor instead.
diff --git a/sdks/python/apache_beam/runners/portability/prism_runner_test.py
b/sdks/python/apache_beam/runners/portability/prism_runner_test.py
index 00116e123ce..4c4c77c83cd 100644
--- a/sdks/python/apache_beam/runners/portability/prism_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/prism_runner_test.py
@@ -35,10 +35,14 @@ from parameterized import parameterized
import apache_beam as beam
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PortableOptions
+from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.runners.portability import portable_runner_test
from apache_beam.runners.portability import prism_runner
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms import trigger
+from apache_beam.transforms import window
from apache_beam.utils import shared
# Run as
@@ -64,6 +68,8 @@ class
PrismRunnerTest(portable_runner_test.PortableRunnerTest):
self.environment_type = None
self.environment_config = None
self.enable_commit = False
+ self.streaming = False
+ self.allow_unsafe_triggers = False
def setUp(self):
self.enable_commit = False
@@ -175,6 +181,9 @@ class
PrismRunnerTest(portable_runner_test.PortableRunnerTest):
options.view_as(
PortableOptions).environment_options = self.environment_options
+ options.view_as(StandardOptions).streaming = self.streaming
+ options.view_as(
+ TypeOptions).allow_unsafe_triggers = self.allow_unsafe_triggers
return options
# Can't read host files from within docker, read a "local" file there.
@@ -225,7 +234,66 @@ class
PrismRunnerTest(portable_runner_test.PortableRunnerTest):
def test_metrics(self):
super().test_metrics(check_bounded_trie=False)
- # Inherits all other tests.
+ def construct_timestamped(k, t):
+ return window.TimestampedValue((k, t), t)
+
+ def format_result(k, vs):
+ return ('%s-%s' % (k, len(list(vs))), set(vs))
+
+ def test_after_count_trigger_batch(self):
+ self.allow_unsafe_triggers = True
+ with self.create_pipeline() as p:
+ result = (
+ p
+ | beam.Create([1, 2, 3, 4, 5, 10, 11])
+ | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
+ #A1, A2, A3, A4, A5, A10, A11, B6, B7, B8, B9, B10, B15, B16
+ | beam.MapTuple(PrismRunnerTest.construct_timestamped)
+ | beam.WindowInto(
+ window.FixedWindows(10),
+ trigger=trigger.AfterCount(3),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING,
+ )
+ | beam.GroupByKey()
+ | beam.MapTuple(PrismRunnerTest.format_result))
+ assert_that(
+ result,
+ equal_to(
+ list([
+ ('A-5', {1, 2, 3, 4, 5}),
+ ('A-2', {10, 11}),
+ ('B-4', {6, 7, 8, 9}),
+ ('B-3', {10, 15, 16}),
+ ])))
+
+ def test_after_count_trigger_streaming(self):
+ self.allow_unsafe_triggers = True
+ self.streaming = True
+ with self.create_pipeline() as p:
+ result = (
+ p
+ | beam.Create([1, 2, 3, 4, 5, 10, 11])
+ | beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
+ #A1, A2, A3, A4, A5, A10, A11, B6, B7, B8, B9, B10, B15, B16
+ | beam.MapTuple(PrismRunnerTest.construct_timestamped)
+ | beam.WindowInto(
+ window.FixedWindows(10),
+ trigger=trigger.AfterCount(3),
+ accumulation_mode=trigger.AccumulationMode.DISCARDING,
+ )
+ | beam.GroupByKey()
+ | beam.MapTuple(PrismRunnerTest.format_result))
+ assert_that(
+ result,
+ equal_to(
+ list([
+ ('A-3', {1, 2, 3}),
+ ('A-2', {4, 5}),
+ ('A-2', {10, 11}),
+ ('B-3', {6, 7, 8}),
+ ('B-1', {9}),
+ ('B-3', {10, 15, 16}),
+ ])))
class PrismJobServerTest(unittest.TestCase):