This is an automated email from the ASF dual-hosted git repository.
shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8c44a9ee80e Support custom id function in async_dofn (#36779)
8c44a9ee80e is described below
commit 8c44a9ee80ee554f08d727dd8b7468653ccc85ce
Author: Dustin Rhodes <[email protected]>
AuthorDate: Wed Nov 12 17:44:47 2025 -0800
Support custom id function in async_dofn (#36779)
* Allow for a custom id function other than the default hashing function.
* fix formatting errors
* Formatting Fix 2
* fix linter errors
* change element_ to _
---
sdks/python/apache_beam/transforms/async_dofn.py | 42 +++++++++++++---------
.../apache_beam/transforms/async_dofn_test.py | 34 ++++++++++++++++++
2 files changed, 59 insertions(+), 17 deletions(-)
diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py
index d2fa90c8508..5e1c6d219f4 100644
--- a/sdks/python/apache_beam/transforms/async_dofn.py
+++ b/sdks/python/apache_beam/transforms/async_dofn.py
@@ -77,6 +77,7 @@ class AsyncWrapper(beam.DoFn):
max_items_to_buffer=None,
timeout=1,
max_wait_time=0.5,
+ id_fn=None,
):
"""Wraps the sync_fn to create an asynchronous version.
@@ -101,6 +102,8 @@ class AsyncWrapper(beam.DoFn):
locally before it goes in the queue of waiting work.
max_wait_time: The maximum amount of sleep time while attempting to
schedule an item. Used in testing to ensure timeouts are met.
+ id_fn: A function that returns a hashable object from an element. This
+ will be used to track items instead of the element's default hash.
"""
self._sync_fn = sync_fn
self._uuid = uuid.uuid4().hex
@@ -108,6 +111,7 @@ class AsyncWrapper(beam.DoFn):
self._timeout = timeout
self._max_wait_time = max_wait_time
self._timer_frequency = callback_frequency
+ self._id_fn = id_fn or (lambda x: x)
if max_items_to_buffer is None:
self._max_items_to_buffer = max(parallelism * 2, 10)
else:
@@ -205,7 +209,8 @@ class AsyncWrapper(beam.DoFn):
True if the item was scheduled False otherwise.
"""
with AsyncWrapper._lock:
- if element in AsyncWrapper._processing_elements[self._uuid]:
+ element_id = self._id_fn(element[1])
+ if element_id in AsyncWrapper._processing_elements[self._uuid]:
logging.info('item %s already in processing elements', element)
return True
if self.accepting_items() or ignore_buffer:
@@ -214,7 +219,8 @@ class AsyncWrapper(beam.DoFn):
lambda: self.sync_fn_process(element, *args, **kwargs),
)
result.add_done_callback(self.decrement_items_in_buffer)
- AsyncWrapper._processing_elements[self._uuid][element] = result
+ AsyncWrapper._processing_elements[self._uuid][element_id] = (
+ element, result)
AsyncWrapper._items_in_buffer[self._uuid] += 1
return True
else:
@@ -345,9 +351,6 @@ class AsyncWrapper(beam.DoFn):
to_process_local = list(to_process.read())
- # For all elements that in local state but not processing state delete them
- # from local state and cancel their futures.
- to_remove = []
key = None
to_reschedule = []
if to_process_local:
@@ -362,27 +365,32 @@ class AsyncWrapper(beam.DoFn):
# given key. Skip items in processing_elements which are for a different
# key.
with AsyncWrapper._lock:
- for x in AsyncWrapper._processing_elements[self._uuid]:
- if x[0] == key and x not in to_process_local:
+ processing_elements = AsyncWrapper._processing_elements[self._uuid]
+ to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local}
+ to_remove_ids = []
+ for element_id, (element, future) in processing_elements.items():
+ if element[0] == key and element_id not in to_process_local_ids:
items_cancelled += 1
- AsyncWrapper._processing_elements[self._uuid][x].cancel()
- to_remove.append(x)
+ future.cancel()
+ to_remove_ids.append(element_id)
logging.info(
- 'cancelling item %s which is no longer in processing state', x)
- for x in to_remove:
- AsyncWrapper._processing_elements[self._uuid].pop(x)
+ 'cancelling item %s which is no longer in processing state',
+ element)
+ for element_id in to_remove_ids:
+ processing_elements.pop(element_id)
# For all elements which have finished processing output their result.
to_return = []
finished_items = []
for x in to_process_local:
items_in_se_state += 1
- if x in AsyncWrapper._processing_elements[self._uuid]:
- if AsyncWrapper._processing_elements[self._uuid][x].done():
- to_return.append(
- AsyncWrapper._processing_elements[self._uuid][x].result())
+ x_id = self._id_fn(x[1])
+ if x_id in processing_elements:
+ _, future = processing_elements[x_id]
+ if future.done():
+ to_return.append(future.result())
finished_items.append(x)
- AsyncWrapper._processing_elements[self._uuid].pop(x)
+ processing_elements.pop(x_id)
items_finished += 1
else:
items_not_yet_finished += 1
diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py
index 7577e215d1c..fe75de05ccd 100644
--- a/sdks/python/apache_beam/transforms/async_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/async_dofn_test.py
@@ -119,6 +119,40 @@ class AsyncTest(unittest.TestCase):
expected_count,
)
+ def test_custom_id_fn(self):
+ class CustomIdObject:
+ def __init__(self, element_id, value):
+ self.element_id = element_id
+ self.value = value
+
+ def __hash__(self):
+ return hash(self.element_id)
+
+ def __eq__(self, other):
+ return self.element_id == other.element_id
+
+ dofn = BasicDofn()
+ async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id)
+ async_dofn.setup()
+ fake_bag_state = FakeBagState([])
+ fake_timer = FakeTimer(0)
+ msg1 = ('key1', CustomIdObject(1, 'a'))
+ msg2 = ('key1', CustomIdObject(1, 'b'))
+
+ result = async_dofn.process(
+ msg1, to_process=fake_bag_state, timer=fake_timer)
+ self.assertEqual(result, [])
+
+ # The second message should be a no-op as it has the same id.
+ result = async_dofn.process(
+ msg2, to_process=fake_bag_state, timer=fake_timer)
+ self.assertEqual(result, [])
+
+ self.wait_for_empty(async_dofn)
+ result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)
+ self.check_output(result, [('key1', msg1[1])])
+ self.assertEqual(fake_bag_state.items, [])
+
def test_basic(self):
# Setup an async dofn and send a message in to process.
dofn = BasicDofn()