This is an automated email from the ASF dual-hosted git repository.

xqhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7cee7bb9f4b [Python] Fix: Propagate resource hints through with_exception_handling (#36090)
7cee7bb9f4b is described below

commit 7cee7bb9f4b72974a233eec29b522688db09d762
Author: Ian Liao <[email protected]>
AuthorDate: Tue Sep 16 13:20:26 2025 -0700

    [Python] Fix: Propagate resource hints through with_exception_handling (#36090)
    
    * Implement two-way propagation for resource hint, fix Python with_exception_handling + JAX-on-Beam = pipeline failure
    
    * Add unit test for resource hint propagation in ParDo.
    
    * Propagate resource hint in a more intuitive way
---
 sdks/python/apache_beam/transforms/core.py       | 21 ++++--
 sdks/python/apache_beam/transforms/core_test.py  | 89 ++++++++++++++++++++++++
 sdks/python/apache_beam/transforms/ptransform.py |  4 ++
 3 files changed, 108 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 1bfc732d13a..2304faf478f 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1678,7 +1678,8 @@ class ParDo(PTransformWithSideInputs):
         timeout,
         error_handler,
         on_failure_callback,
-        allow_unsafe_userstate_in_process)
+        allow_unsafe_userstate_in_process,
+        self.get_resource_hints())
 
   def with_error_handler(self, error_handler, **exception_handling_kwargs):
     """An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -2284,7 +2285,8 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
       timeout,
       error_handler,
       on_failure_callback,
-      allow_unsafe_userstate_in_process):
+      allow_unsafe_userstate_in_process,
+      resource_hints):
     if partial and use_subprocess:
       raise ValueError('partial and use_subprocess are mutually incompatible.')
     self._fn = fn
@@ -2301,6 +2303,7 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
     self._error_handler = error_handler
     self._on_failure_callback = on_failure_callback
     self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
+    self._resource_hints = resource_hints
 
   def expand(self, pcoll):
     if self._allow_unsafe_userstate_in_process:
@@ -2317,17 +2320,23 @@ class _ExceptionHandlingWrapper(ptransform.PTransform):
       wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout)
     else:
       wrapped_fn = self._fn
-    result = pcoll | ParDo(
+    pardo = ParDo(
         _ExceptionHandlingWrapperDoFn(
             wrapped_fn,
             self._dead_letter_tag,
             self._exc_class,
             self._partial,
             self._on_failure_callback,
-            self._allow_unsafe_userstate_in_process),
+            self._allow_unsafe_userstate_in_process,
+        ),
         *self._args,
-        **self._kwargs).with_outputs(
-            self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
+        **self._kwargs,
+    )
+    # This is the fix: propagate hints.
+    pardo.get_resource_hints().update(self._resource_hints)
+
+    result = pcoll | pardo.with_outputs(
+        self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
     #TODO(BEAM-18957): Fix when type inference supports tagged outputs.
     result[self._main_tag].element_type = self._fn.infer_output_type(
         pcoll.element_type)
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 3e5e7670bf5..0d680c969c9 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -30,6 +30,7 @@ import apache_beam as beam
 from apache_beam.coders import coders
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms.resources import ResourceHint
 from apache_beam.transforms.userstate import BagStateSpec
 from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
 from apache_beam.transforms.userstate import TimerSpec
@@ -416,6 +417,94 @@ class ExceptionHandlingTest(unittest.TestCase):
       assert_that(good, equal_to([0, 1, 2]), 'good')
       assert_that(bad_elements, equal_to([(1, 5), (1, 10)]), 'bad')
 
+  def test_tags_with_exception_handling_then_resource_hint(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_exception_handling()
+        .with_resource_hints(tags='test_tag')
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_exception_handling_timeout_then_resource_hint(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_exception_handling(timeout=1)
+        .with_resource_hints(tags='test_tag')
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_resource_hint_then_exception_handling(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_resource_hints(tags='test_tag')
+        .with_exception_handling()
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
+  def test_tags_with_resource_hint_then_exception_handling_timeout(self):
+    class TagHint(ResourceHint):
+      urn = 'beam:resources:tags:v1'
+
+    ResourceHint.register_resource_hint('tags', TagHint)
+    with beam.Pipeline() as pipeline:
+      ok, unused_errors = (
+        pipeline
+        | beam.Create([1])
+        | beam.Map(lambda x: x)
+        .with_resource_hints(tags='test_tag')
+        .with_exception_handling(timeout=1)
+      )
+    pd = ok.producer.transform
+    self.assertIsInstance(pd, beam.transforms.core.ParDo)
+    while hasattr(pd.fn, 'fn'):
+      pd = pd.fn
+    self.assertEqual(
+        pd.get_resource_hints(),
+        {'beam:resources:tags:v1': b'test_tag'},
+    )
+
 
 def test_callablewrapper_typehint():
   T = TypeVar("T")
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index d2cf836713f..cac8a8fbd95 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -1164,6 +1164,10 @@ class _NamedPTransform(PTransform):
   def __rrshift__(self, label):
     return _NamedPTransform(self.transform, label)
 
+  def with_resource_hints(self, **kwargs):
+    self.transform.with_resource_hints(**kwargs)
+    return self
+
   def __getattr__(self, attr):
     transform_attr = getattr(self.transform, attr)
     if callable(transform_attr):

Reply via email to