This is an automated email from the ASF dual-hosted git repository.

shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c44a9ee80e Support custom id function in async_dofn (#36779)
8c44a9ee80e is described below

commit 8c44a9ee80ee554f08d727dd8b7468653ccc85ce
Author: Dustin Rhodes <[email protected]>
AuthorDate: Wed Nov 12 17:44:47 2025 -0800

    Support custom id function in async_dofn (#36779)
    
    * Allow for a custom id function other than the default hashing function.
    
    * fix formatting errors
    
    * Formatting Fix 2
    
    * fix linter errors
    
    * change element_ to _
---
 sdks/python/apache_beam/transforms/async_dofn.py   | 42 +++++++++++++---------
 .../apache_beam/transforms/async_dofn_test.py      | 34 ++++++++++++++++++
 2 files changed, 59 insertions(+), 17 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py
index d2fa90c8508..5e1c6d219f4 100644
--- a/sdks/python/apache_beam/transforms/async_dofn.py
+++ b/sdks/python/apache_beam/transforms/async_dofn.py
@@ -77,6 +77,7 @@ class AsyncWrapper(beam.DoFn):
       max_items_to_buffer=None,
       timeout=1,
       max_wait_time=0.5,
+      id_fn=None,
   ):
     """Wraps the sync_fn to create an asynchronous version.
 
@@ -101,6 +102,8 @@ class AsyncWrapper(beam.DoFn):
         locally before it goes in the queue of waiting work.
       max_wait_time: The maximum amount of sleep time while attempting to
         schedule an item.  Used in testing to ensure timeouts are met.
+      id_fn: A function that returns a hashable object from an element. This
+        will be used to track items instead of the element's default hash.
     """
     self._sync_fn = sync_fn
     self._uuid = uuid.uuid4().hex
@@ -108,6 +111,7 @@ class AsyncWrapper(beam.DoFn):
     self._timeout = timeout
     self._max_wait_time = max_wait_time
     self._timer_frequency = callback_frequency
+    self._id_fn = id_fn or (lambda x: x)
     if max_items_to_buffer is None:
       self._max_items_to_buffer = max(parallelism * 2, 10)
     else:
@@ -205,7 +209,8 @@ class AsyncWrapper(beam.DoFn):
       True if the item was scheduled False otherwise.
     """
     with AsyncWrapper._lock:
-      if element in AsyncWrapper._processing_elements[self._uuid]:
+      element_id = self._id_fn(element[1])
+      if element_id in AsyncWrapper._processing_elements[self._uuid]:
         logging.info('item %s already in processing elements', element)
         return True
       if self.accepting_items() or ignore_buffer:
@@ -214,7 +219,8 @@ class AsyncWrapper(beam.DoFn):
                 lambda: self.sync_fn_process(element, *args, **kwargs),
             )
         result.add_done_callback(self.decrement_items_in_buffer)
-        AsyncWrapper._processing_elements[self._uuid][element] = result
+        AsyncWrapper._processing_elements[self._uuid][element_id] = (
+            element, result)
         AsyncWrapper._items_in_buffer[self._uuid] += 1
         return True
       else:
@@ -345,9 +351,6 @@ class AsyncWrapper(beam.DoFn):
 
     to_process_local = list(to_process.read())
 
-    # For all elements that in local state but not processing state delete them
-    # from local state and cancel their futures.
-    to_remove = []
     key = None
     to_reschedule = []
     if to_process_local:
@@ -362,27 +365,32 @@ class AsyncWrapper(beam.DoFn):
     # given key.  Skip items in processing_elements which are for a different
     # key.
     with AsyncWrapper._lock:
-      for x in AsyncWrapper._processing_elements[self._uuid]:
-        if x[0] == key and x not in to_process_local:
+      processing_elements = AsyncWrapper._processing_elements[self._uuid]
+      to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local}
+      to_remove_ids = []
+      for element_id, (element, future) in processing_elements.items():
+        if element[0] == key and element_id not in to_process_local_ids:
           items_cancelled += 1
-          AsyncWrapper._processing_elements[self._uuid][x].cancel()
-          to_remove.append(x)
+          future.cancel()
+          to_remove_ids.append(element_id)
           logging.info(
-              'cancelling item %s which is no longer in processing state', x)
-      for x in to_remove:
-        AsyncWrapper._processing_elements[self._uuid].pop(x)
+              'cancelling item %s which is no longer in processing state',
+              element)
+      for element_id in to_remove_ids:
+        processing_elements.pop(element_id)
 
       # For all elements which have finished processing output their result.
       to_return = []
       finished_items = []
       for x in to_process_local:
         items_in_se_state += 1
-        if x in AsyncWrapper._processing_elements[self._uuid]:
-          if AsyncWrapper._processing_elements[self._uuid][x].done():
-            to_return.append(
-                AsyncWrapper._processing_elements[self._uuid][x].result())
+        x_id = self._id_fn(x[1])
+        if x_id in processing_elements:
+          _, future = processing_elements[x_id]
+          if future.done():
+            to_return.append(future.result())
             finished_items.append(x)
-            AsyncWrapper._processing_elements[self._uuid].pop(x)
+            processing_elements.pop(x_id)
             items_finished += 1
           else:
             items_not_yet_finished += 1
diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py
index 7577e215d1c..fe75de05ccd 100644
--- a/sdks/python/apache_beam/transforms/async_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/async_dofn_test.py
@@ -119,6 +119,40 @@ class AsyncTest(unittest.TestCase):
         expected_count,
     )
 
+  def test_custom_id_fn(self):
+    class CustomIdObject:
+      def __init__(self, element_id, value):
+        self.element_id = element_id
+        self.value = value
+
+      def __hash__(self):
+        return hash(self.element_id)
+
+      def __eq__(self, other):
+        return self.element_id == other.element_id
+
+    dofn = BasicDofn()
+    async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id)
+    async_dofn.setup()
+    fake_bag_state = FakeBagState([])
+    fake_timer = FakeTimer(0)
+    msg1 = ('key1', CustomIdObject(1, 'a'))
+    msg2 = ('key1', CustomIdObject(1, 'b'))
+
+    result = async_dofn.process(
+        msg1, to_process=fake_bag_state, timer=fake_timer)
+    self.assertEqual(result, [])
+
+    # The second message should be a no-op as it has the same id.
+    result = async_dofn.process(
+        msg2, to_process=fake_bag_state, timer=fake_timer)
+    self.assertEqual(result, [])
+
+    self.wait_for_empty(async_dofn)
+    result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)
+    self.check_output(result, [('key1', msg1[1])])
+    self.assertEqual(fake_bag_state.items, [])
+
   def test_basic(self):
     # Setup an async dofn and send a message in to process.
     dofn = BasicDofn()

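For readers trying out the new option, here is a minimal sketch of how the id_fn argument added above might be wired up. The import path follows the file touched in this commit; MyRecord and SlowLookupDoFn are hypothetical names used only for illustration, not part of the change.

    import apache_beam as beam
    from apache_beam.transforms.async_dofn import AsyncWrapper

    class MyRecord:
      def __init__(self, record_id, payload):
        self.record_id = record_id  # stable identifier used for deduplication
        self.payload = payload      # may differ between duplicates/retries

    class SlowLookupDoFn(beam.DoFn):
      def process(self, element):
        key, record = element
        # Slow or blocking work on record.payload would go here.
        yield (key, record.record_id)

    # Elements are (key, value) pairs; id_fn receives the value, and its return
    # value is used to track in-flight work instead of the value's own hash,
    # so two records carrying the same record_id are treated as one item.
    wrapped = AsyncWrapper(SlowLookupDoFn(), id_fn=lambda r: r.record_id)

As the new test above shows, a second element whose id_fn value matches one already being processed is treated as already scheduled rather than queued again.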