This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0282abbff3b Allow multiprocessshared to spawn process and delete directly with obj (#37112)
0282abbff3b is described below

commit 0282abbff3b0933c47f3c4c21ae672ceee3f0d2c
Author: RuiLong J. <[email protected]>
AuthorDate: Mon Feb 2 11:34:07 2026 -0800

    Allow multiprocessshared to spawn process and delete directly with obj (#37112)
    
    * Allow multiprocessshared to spawn process and delete directly with obj
    
    * Remove oom protection
    
    * Resolve comments
    
    * Rename unsafe_hard_delete for the proxy object to prevent collision
    
    * Remove support for proxy on proxy to avoid complexity
    
    * Fix import order
    
    * Update reap test to be compatible with Windows
    
    * Update print to logging
    
    * Try to tearDown test in a cleaner way
    
    * Try patching atexit call to prevent hanging on Windows
    
    * Try weakref so Windows can GC the process
    
    * Try GC manually to make sure p is cleaned up
    
    * Use a different way to check if parent is alive
    
    * Close the pipe atexit as well
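    
    A minimal usage sketch of the new flag (the Counter class here is
    illustrative; MultiProcessShared, acquire, and unsafe_hard_delete are
    the APIs touched by this change):
    
        from apache_beam.utils import multi_process_shared
    
        class Counter:
          def __init__(self):
            self.value = 0
    
          def increment(self, value=1):
            self.value += value
            return self.value
    
        shared = multi_process_shared.MultiProcessShared(
            Counter, tag='my_counter', always_proxy=True, spawn_process=True)
        counter = shared.acquire()   # First acquire spawns a server process.
        counter.increment()          # Calls are proxied to that process.
        shared.unsafe_hard_delete()  # Tears down the object and its process.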
---
 .../apache_beam/utils/multi_process_shared.py      | 239 ++++++++++++++++++---
 .../apache_beam/utils/multi_process_shared_test.py | 225 ++++++++++++++++++-
 2 files changed, 436 insertions(+), 28 deletions(-)
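
Both the address file and the error file in this change are published
atomically (write to a '.tmp' path, then os.rename it into place), so a
polling reader never observes a partially written file. The pattern in
isolation (a generic sketch, not Beam code):

    import os

    def publish_atomically(path, contents):
      # os.rename is atomic on the same filesystem, so a concurrent poller
      # sees either no file at all or the complete contents.
      with open(path + '.tmp', 'w') as fout:
        fout.write(contents)
      os.rename(path + '.tmp', path)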

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index aecb1284a1d..de4b94bc5da 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -22,11 +22,14 @@ on it via rpc.
 """
 # pytype: skip-file
 
+import atexit
 import logging
 import multiprocessing.managers
 import os
 import tempfile
 import threading
+import time
+import traceback
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -79,6 +82,10 @@ class _SingletonProxy:
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
 
+  def singletonProxy_unsafe_hard_delete(self):
+    assert self._SingletonProxy_valid
+    self._SingletonProxy_entry.unsafe_hard_delete()
+
   def __getattr__(self, name):
     if not self._SingletonProxy_valid:
       raise RuntimeError('Entry was released.')
@@ -105,13 +112,16 @@ class _SingletonProxy:
     dir = self._SingletonProxy_entry.obj.__dir__()
     dir.append('singletonProxy_call__')
     dir.append('singletonProxy_release')
+    dir.append('singletonProxy_unsafe_hard_delete')
     return dir
 
 
 class _SingletonEntry:
   """Represents a single, refcounted entry in this process."""
-  def __init__(self, constructor, initialize_eagerly=True):
+  def __init__(
+      self, constructor, initialize_eagerly=True, hard_delete_callback=None):
     self.constructor = constructor
+    self._hard_delete_callback = hard_delete_callback
     self.refcount = 0
     self.lock = threading.Lock()
     if initialize_eagerly:
@@ -141,14 +151,28 @@ class _SingletonEntry:
       if self.initialied:
         del self.obj
         self.initialied = False
+      if self._hard_delete_callback:
+        self._hard_delete_callback()
 
 
 class _SingletonManager:
   entries: Dict[Any, Any] = {}
 
-  def register_singleton(self, constructor, tag, initialize_eagerly=True):
+  def __init__(self):
+    self._hard_delete_callback = None
+
+  def set_hard_delete_callback(self, callback):
+    self._hard_delete_callback = callback
+
+  def register_singleton(
+      self,
+      constructor,
+      tag,
+      initialize_eagerly=True,
+      hard_delete_callback=None):
     assert tag not in self.entries, tag
-    self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
+    self.entries[tag] = _SingletonEntry(
+        constructor, initialize_eagerly, hard_delete_callback)
 
   def has_singleton(self, tag):
     return tag in self.entries
@@ -160,7 +184,7 @@ class _SingletonManager:
     return self.entries[tag].release(obj)
 
   def unsafe_hard_delete_singleton(self, tag):
-    return self.entries[tag].unsafe_hard_delete()
+    self.entries[tag].unsafe_hard_delete()
 
 
 _process_level_singleton_manager = _SingletonManager()
@@ -203,6 +227,87 @@ class _AutoProxyWrapper:
   def get_auto_proxy_object(self):
     return self._proxyObject
 
+  def unsafe_hard_delete(self):
+    self._proxyObject.unsafe_hard_delete()
+
+
+def _run_server_process(address_file, tag, constructor, authkey, life_line):
+  """
+    Runs in a separate process.
+    Includes a 'Suicide Pact' monitor: If parent dies, I die.
+    """
+  parent_pid = os.getppid()
+
+  def cleanup_files():
+    logging.info("Server process exiting. Deleting files for %s", tag)
+    try:
+      if os.path.exists(address_file):
+        os.remove(address_file)
+      if os.path.exists(address_file + ".error"):
+        os.remove(address_file + ".error")
+    except Exception as e:
+      logging.warning('Failed to cleanup files for tag %s: %s', tag, e)
+
+  def handle_unsafe_hard_delete():
+    cleanup_files()
+    os._exit(0)
+
+  def _monitor_parent():
+    """Checks if parent is alive every second."""
+    while True:
+      try:
+        # recv_bytes() blocks here and raises once the parent's end closes.
+        life_line.recv_bytes()
+      except (EOFError, OSError, BrokenPipeError):
+        logging.warning(
+            "Process %s detected Parent %s died. Self-destructing.",
+            os.getpid(),
+            parent_pid)
+        cleanup_files()
+        os._exit(0)
+      time.sleep(0.5)
+
+  atexit.register(cleanup_files)
+
+  try:
+    t = threading.Thread(target=_monitor_parent, daemon=True)
+
+    logging.getLogger().setLevel(logging.INFO)
+    multiprocessing.current_process().authkey = authkey
+
+    serving_manager = _SingletonRegistrar(
+        address=('localhost', 0), authkey=authkey)
+    _process_level_singleton_manager.set_hard_delete_callback(
+        handle_unsafe_hard_delete)
+    _process_level_singleton_manager.register_singleton(
+        constructor,
+        tag,
+        initialize_eagerly=True,
+        hard_delete_callback=handle_unsafe_hard_delete)
+    # Start monitoring the parent only after initialization is done, to
+    # avoid potential race conditions.
+    t.start()
+
+    server = serving_manager.get_server()
+    logging.info(
+        'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)
+
+    with open(address_file + '.tmp', 'w') as fout:
+      fout.write('%s:%d' % server.address)
+    os.rename(address_file + '.tmp', address_file)
+
+    server.serve_forever()
+
+  except Exception:
+    tb = traceback.format_exc()
+    try:
+      with open(address_file + ".error.tmp", 'w') as fout:
+        fout.write(tb)
+      os.rename(address_file + ".error.tmp", address_file + ".error")
+    except Exception:
+      logging.error("CRITICAL ERROR IN SHARED SERVER:\n%s", tb)
+    os._exit(1)
+
 
 class MultiProcessShared(Generic[T]):
   """MultiProcessShared is used to share a single object across processes.
@@ -252,7 +357,8 @@ class MultiProcessShared(Generic[T]):
       tag: Any,
       *,
       path: str = tempfile.gettempdir(),
-      always_proxy: Optional[bool] = None):
+      always_proxy: Optional[bool] = None,
+      spawn_process: bool = False):
     self._constructor = constructor
     self._tag = tag
     self._path = path
@@ -262,6 +368,7 @@ class MultiProcessShared(Generic[T]):
     self._rpc_address = None
     self._cross_process_lock = fasteners.InterProcessLock(
         os.path.join(self._path, self._tag) + '.lock')
+    self._spawn_process = spawn_process
 
   def _get_manager(self):
     if self._manager is None:
@@ -301,6 +408,11 @@ class MultiProcessShared(Generic[T]):
     # Caveat: They must always agree, as they will be ignored if the object
     # is already constructed.
     singleton = self._get_manager().acquire_singleton(self._tag)
+    # Trigger a sweep of zombie processes.
+    # Calling active_children() has the side effect of joining any finished
+    # processes, effectively reaping zombies from previous unsafe_hard_deletes.
+    if self._spawn_process:
+      multiprocessing.active_children()
     return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
@@ -318,22 +430,101 @@ class MultiProcessShared(Generic[T]):
     self._get_manager().unsafe_hard_delete_singleton(self._tag)
 
   def _create_server(self, address_file):
-    # We need to be able to authenticate with both the manager and the process.
-    self._serving_manager = _SingletonRegistrar(
-        address=('localhost', 0), authkey=AUTH_KEY)
-    multiprocessing.current_process().authkey = AUTH_KEY
-    # Initialize eagerly to avoid acting as the server if there are issues.
-    # Note, however, that _create_server itself is called lazily.
-    _process_level_singleton_manager.register_singleton(
-        self._constructor, self._tag, initialize_eagerly=True)
-    self._server = self._serving_manager.get_server()
-    logging.info(
-        'Starting proxy server at %s for shared %s',
-        self._server.address,
-        self._tag)
-    with open(address_file + '.tmp', 'w') as fout:
-      fout.write('%s:%d' % self._server.address)
-    os.rename(address_file + '.tmp', address_file)
-    t = threading.Thread(target=self._server.serve_forever, daemon=True)
-    t.start()
-    logging.info('Done starting server')
+    if self._spawn_process:
+      error_file = address_file + ".error"
+
+      if os.path.exists(error_file):
+        try:
+          os.remove(error_file)
+        except OSError:
+          pass
+
+      # Create a pipe between the parent and the child process; the child
+      # uses it to detect the parent's death and clean itself up.
+      reader, writer = multiprocessing.Pipe(duplex=False)
+      self._life_line = writer
+
+      ctx = multiprocessing.get_context('spawn')
+      p = ctx.Process(
+          target=_run_server_process,
+          args=(address_file, self._tag, self._constructor, AUTH_KEY, reader),
+          daemon=False  # Must be False for nested proxies
+      )
+      p.start()
+      logging.info("Parent: Waiting for %s to write address file...", 
self._tag)
+
+      def cleanup_process():
+        if self._life_line:
+          self._life_line.close()
+        if p.is_alive():
+          logging.info(
+              "Parent: Terminating server process %s for %s", p.pid, self._tag)
+          p.terminate()
+          p.join()
+        try:
+          if os.path.exists(address_file):
+            os.remove(address_file)
+          if os.path.exists(error_file):
+            os.remove(error_file)
+        except Exception as e:
+          logging.warning(
+              'Failed to cleanup files for tag %s in atexit handler: %s',
+              self._tag,
+              e)
+
+      atexit.register(cleanup_process)
+
+      start_time = time.time()
+      last_log = start_time
+      while True:
+        if os.path.exists(address_file):
+          break
+
+        if os.path.exists(error_file):
+          with open(error_file, 'r') as f:
+            error_msg = f.read()
+          try:
+            os.remove(error_file)
+          except OSError:
+            pass
+
+          if p.is_alive():
+            p.terminate()
+          raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")
+
+        if not p.is_alive():
+          exit_code = p.exitcode
+          raise RuntimeError(
+              "Shared Server Process died unexpectedly"
+              f" with exit code {exit_code}")
+
+        if time.time() - last_log > 300:
+          logging.warning(
+              "Still waiting for %s to initialize... %ss elapsed)",
+              self._tag,
+              int(time.time() - start_time))
+          last_log = time.time()
+
+        time.sleep(0.05)
+
+      logging.info('External process successfully started for %s', self._tag)
+    else:
+      # We need to be able to authenticate with both the manager
+      # and the process.
+      self._serving_manager = _SingletonRegistrar(
+          address=('localhost', 0), authkey=AUTH_KEY)
+      multiprocessing.current_process().authkey = AUTH_KEY
+      # Initialize eagerly to avoid acting as the server if there are issues.
+      # Note, however, that _create_server itself is called lazily.
+      _process_level_singleton_manager.register_singleton(
+          self._constructor, self._tag, initialize_eagerly=True)
+      self._server = self._serving_manager.get_server()
+      logging.info(
+          'Starting proxy server at %s for shared %s',
+          self._server.address,
+          self._tag)
+      with open(address_file + '.tmp', 'w') as fout:
+        fout.write('%s:%d' % self._server.address)
+      os.rename(address_file + '.tmp', address_file)
+      t = threading.Thread(target=self._server.serve_forever, daemon=True)
+      t.start()
+      logging.info('Done starting server')
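
Aside: the parent-death detection in _run_server_process above uses a
standard multiprocessing idiom. A self-contained sketch of just that
mechanism (generic names, not Beam code):

    import multiprocessing
    import os
    import threading

    def serve(life_line):
      def monitor():
        try:
          # Blocks until the parent's end of the pipe closes, e.g. because
          # the parent process died.
          life_line.recv_bytes()
        except (EOFError, OSError):
          os._exit(0)  # Parent is gone; self-destruct.
      threading.Thread(target=monitor, daemon=True).start()
      # ... serve requests here ...

    if __name__ == '__main__':
      reader, writer = multiprocessing.Pipe(duplex=False)
      child = multiprocessing.get_context('spawn').Process(
          target=serve, args=(reader,))
      child.start()
      # If this parent exits for any reason, the OS closes `writer`, the
      # child's recv_bytes() raises, and the child exits promptly.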
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 0b795763236..7b2b11857bf 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -17,6 +17,9 @@
 # pytype: skip-file
 
 import logging
+import multiprocessing
+import os
+import tempfile
 import threading
 import unittest
 from typing import Any
@@ -178,8 +181,11 @@ class MultiProcessSharedTest(unittest.TestCase):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    multi_process_shared.MultiProcessShared(
-        Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+    try:
+      multi_process_shared.MultiProcessShared(
+          Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+    except Exception:
+      pass
 
     with self.assertRaises(Exception):
       counter1.get()
@@ -193,6 +199,37 @@ class MultiProcessSharedTest(unittest.TestCase):
 
     self.assertEqual(counter3.increment(), 1)
 
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True)
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    try:
+      counter2.singletonProxy_unsafe_hard_delete()
+    except Exception:
+      pass
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter,
+        tag='test_unsafe_hard_delete_autoproxywrapper',
+        always_proxy=True).acquire()
+    self.assertEqual(counter3.increment(), 1)
+
   def test_unsafe_hard_delete_no_op(self):
     shared1 = multi_process_shared.MultiProcessShared(
         Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)
@@ -204,8 +241,11 @@ class MultiProcessSharedTest(unittest.TestCase):
     self.assertEqual(counter1.increment(), 1)
     self.assertEqual(counter2.increment(), 2)
 
-    multi_process_shared.MultiProcessShared(
-        Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+    try:
+      multi_process_shared.MultiProcessShared(
+          Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+    except Exception:
+      pass
 
     self.assertEqual(counter1.increment(), 3)
     self.assertEqual(counter2.increment(), 4)
@@ -243,6 +283,183 @@ class MultiProcessSharedTest(unittest.TestCase):
       counter1.get()
 
 
+class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
+  def setUp(self):
+    tempdir = tempfile.gettempdir()
+    for tag in ['basic',
+                'main',
+                'to_delete',
+                'mix1',
+                'mix2',
+                'test_process_exit',
+                'thundering_herd_test']:
+      for ext in ['', '.address', '.address.error']:
+        try:
+          os.remove(os.path.join(tempdir, tag + ext))
+        except OSError:
+          pass
+
+  def tearDown(self):
+    for p in multiprocessing.active_children():
+      if p.is_alive():
+        try:
+          p.kill()
+          p.join(timeout=1.0)
+        except Exception:
+          pass
+
+  def test_call(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+    self.assertEqual(shared.get(), 0)
+    self.assertEqual(shared.increment(), 1)
+    self.assertEqual(shared.increment(10), 11)
+    self.assertEqual(shared.increment(value=10), 21)
+    self.assertEqual(shared.get(), 21)
+
+  def test_unsafe_hard_delete_autoproxywrapper(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+    counter3 = multi_process_shared.MultiProcessShared(
+        Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    try:
+      counter2.singletonProxy_unsafe_hard_delete()
+    except Exception:
+      pass
+
+    with self.assertRaises(Exception):
+      counter1.get()
+    with self.assertRaises(Exception):
+      counter2.get()
+
+    counter4 = multi_process_shared.MultiProcessShared(
+        Counter, tag='to_delete', always_proxy=True,
+        spawn_process=True).acquire()
+
+    self.assertEqual(counter3.increment(), 1)
+    self.assertEqual(counter4.increment(), 1)
+
+  def test_mix_usage(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire()
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire()
+
+    self.assertEqual(shared1.get(), 0)
+    self.assertEqual(shared1.increment(), 1)
+    self.assertEqual(shared2.get(), 0)
+    self.assertEqual(shared2.increment(), 1)
+
+  def test_process_exits_on_unsafe_hard_delete(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True,
+        spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    try:
+      obj.singletonProxy_unsafe_hard_delete()
+    except Exception:
+      pass
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard 
delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_process_exits_on_unsafe_hard_delete_with_manager(self):
+    shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True,
+        spawn_process=True)
+    obj = shared.acquire()
+
+    self.assertEqual(obj.increment(), 1)
+
+    children = multiprocessing.active_children()
+    server_process = None
+    for p in children:
+      if p.pid != os.getpid() and p.is_alive():
+        server_process = p
+        break
+
+    self.assertIsNotNone(
+        server_process, "Could not find spawned server process")
+    try:
+      shared.unsafe_hard_delete()
+    except Exception:
+      pass
+    server_process.join(timeout=5)
+
+    self.assertFalse(
+        server_process.is_alive(),
+        f"Server process {server_process.pid} is still alive after hard 
delete")
+    self.assertIsNotNone(
+        server_process.exitcode, "Process has no exit code (did not exit)")
+
+    with self.assertRaises(Exception):
+      obj.get()
+
+  def test_zombie_reaping_on_acquire(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True)
+    obj = shared1.acquire()
+
+    children = multiprocessing.active_children()
+    server_pid = next(
+        p.pid for p in children if p.is_alive() and p.pid != os.getpid())
+
+    try:
+      obj.singletonProxy_unsafe_hard_delete()
+    except Exception:
+      pass
+
+    try:
+      os.kill(server_pid, 0)
+      is_zombie = True
+    except OSError:
+      is_zombie = False
+    self.assertTrue(
+        is_zombie,
+        f"Server process {server_pid} was reaped too early before acquire()")
+
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
+    _ = shared2.acquire()
+
+    # If reaping worked, our old server_pid should NOT be in this list.
+    current_children_pids = [p.pid for p in multiprocessing.active_children()]
+
+    self.assertNotIn(
+        server_pid,
+        current_children_pids,
+        f"Old server process {server_pid} was not reaped by acquire() sweep")
+    try:
+      shared2.unsafe_hard_delete()
+    except Exception:
+      pass
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()
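
The zombie reaping that acquire() relies on is standard multiprocessing
behavior: active_children() joins any finished child processes as a side
effect. A standalone sketch (not Beam code):

    import multiprocessing
    import time

    def short_lived():
      pass

    if __name__ == '__main__':
      p = multiprocessing.Process(target=short_lived)
      p.start()
      time.sleep(0.5)  # The child has exited but has not been joined yet.
      # active_children() joins (reaps) finished children as a side effect,
      # so the exited child no longer lingers as a zombie on POSIX.
      assert p not in multiprocessing.active_children()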
