This is an automated email from the ASF dual-hosted git repository.

shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new fa0eef96567 Add retry in connecting manager in MultiProcessShared (#38456)
fa0eef96567 is described below

commit fa0eef96567ee886cdd710421879b8c7aa3e03d3
Author: Shunping Huang <[email protected]>
AuthorDate: Tue May 12 15:36:09 2026 -0400

    Add retry in connecting manager in MultiProcessShared (#38456)
---
 .../apache_beam/utils/multi_process_shared.py      | 17 ++++++++++--
 .../apache_beam/utils/multi_process_shared_test.py | 30 +++++++++++++++++++++-
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index b05fdd305a6..13a0f13e617 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -38,6 +38,8 @@ from typing import TypeVar
 
 import fasteners
 
+from apache_beam.utils import retry
+
 # In some python versions, there is a bug where AutoProxy doesn't handle
 # the kwarg 'manager_owned'. We implement our own backup here to make sure
 # we avoid this problem. More info here:
@@ -391,10 +393,21 @@ class MultiProcessShared(Generic[T]):
               manager = _SingletonRegistrar(
                   address=(host, int(port)), authkey=AUTH_KEY)
               multiprocessing.current_process().authkey = AUTH_KEY
-              try:
+
+              retryable_exceptions = (ConnectionError, EOFError)
+
+              @retry.with_exponential_backoff(
+                  num_retries=5,
+                  initial_delay_secs=0.1,
+                  retry_filter=lambda exn: isinstance(
+                      exn, retryable_exceptions))
+              def connect_manager():
                 manager.connect()
+
+              try:
+                connect_manager()
                 self._manager = manager
-              except ConnectionError:
+              except retryable_exceptions:
                 # The server is no longer good, assume it died.
                 os.unlink(address_file)
 
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 3c74903b8d9..18ed49c6fa1 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -23,6 +23,7 @@ import tempfile
 import threading
 import unittest
 from typing import Any
+from unittest.mock import patch
 
 from apache_beam.utils import multi_process_shared
 
@@ -293,7 +294,8 @@ class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
                 'mix1',
                 'mix2',
                 'test_process_exit',
-                'thundering_herd_test']:
+                'thundering_herd_test',
+                'transient_test']:
       for ext in ['', '.address', '.address.error']:
         try:
           os.remove(os.path.join(tempdir, tag + ext))
@@ -461,6 +463,32 @@ class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
     except Exception:
       pass
 
+  def test_transient_connection_error_recovery(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='transient_test', always_proxy=True, spawn_process=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='transient_test', always_proxy=True, spawn_process=True)
+
+    counter1 = shared1.acquire()
+
+    orig_connect = multi_process_shared._SingletonRegistrar.connect
+    connect_calls = [0]
+
+    def side_effect_connect(self_mgr, *args, **kwargs):
+      connect_calls[0] += 1
+      if connect_calls[0] == 1:
+        raise ConnectionError("Simulated transient connection failure")
+      return orig_connect(self_mgr, *args, **kwargs)
+
+    with patch.object(multi_process_shared._SingletonRegistrar,
+                      'connect',
+                      autospec=True,
+                      side_effect=side_effect_connect):
+      counter2 = shared2.acquire()
+
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to