This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0282abbff3b Allow multiprocessshared to spawn process and delete
directly with obj (#37112)
0282abbff3b is described below
commit 0282abbff3b0933c47f3c4c21ae672ceee3f0d2c
Author: RuiLong J. <[email protected]>
AuthorDate: Mon Feb 2 11:34:07 2026 -0800
Allow multiprocessshared to spawn process and delete directly with obj
(#37112)
* Allow multiprocessshared to spawn process and delete directly with obj
* Remove oom protection
* Resolve comments
* Rename unsafe_hard_delete for the proxy object to prevent collision
* Remove support for proxy on proxy to avoid complexity
* Fix import order
* Update reap test to be compatible with Windows
* Update print to logging
* Try to tearDown test in a cleaner way
* Try patching atexit call to prevent hanging on Windows
* Try weakref so Windows can GC the process
* Try GC manually to make sure p is cleaned up
* Use a different way to check if parent is alive
* Close the pipe atexit as well
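
A minimal usage sketch of the new option (illustrative only: Counter mirrors
the helper class in the tests below, and the tag name is arbitrary):

    from apache_beam.utils import multi_process_shared

    class Counter:
      def __init__(self):
        self._value = 0

      def increment(self, value=1):
        self._value += value
        return self._value

    # spawn_process=True hosts the singleton in a dedicated spawned process
    # instead of the first worker process that constructs it.
    shared = multi_process_shared.MultiProcessShared(
        Counter, tag='my_counter', always_proxy=True, spawn_process=True)
    counter = shared.acquire()
    counter.increment()

    # Delete directly with the proxy object; this tears down the singleton
    # (and its hosting process) for all users. It may raise, since the
    # server exits mid-call, which is why the tests wrap it in try/except.
    counter.singletonProxy_unsafe_hard_delete()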
---
.../apache_beam/utils/multi_process_shared.py | 239 ++++++++++++++++++---
.../apache_beam/utils/multi_process_shared_test.py | 225 ++++++++++++++++++-
2 files changed, 436 insertions(+), 28 deletions(-)
diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py
b/sdks/python/apache_beam/utils/multi_process_shared.py
index aecb1284a1d..de4b94bc5da 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -22,11 +22,14 @@ on it via rpc.
"""
# pytype: skip-file
+import atexit
import logging
import multiprocessing.managers
import os
import tempfile
import threading
+import time
+import traceback
from typing import Any
from typing import Callable
from typing import Dict
@@ -79,6 +82,10 @@ class _SingletonProxy:
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False
+ def singletonProxy_unsafe_hard_delete(self):
+ assert self._SingletonProxy_valid
+ self._SingletonProxy_entry.unsafe_hard_delete()
+
def __getattr__(self, name):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
@@ -105,13 +112,16 @@ class _SingletonProxy:
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
dir.append('singletonProxy_release')
+ dir.append('singletonProxy_unsafe_hard_delete')
return dir
class _SingletonEntry:
"""Represents a single, refcounted entry in this process."""
- def __init__(self, constructor, initialize_eagerly=True):
+ def __init__(
+ self, constructor, initialize_eagerly=True, hard_delete_callback=None):
self.constructor = constructor
+ self._hard_delete_callback = hard_delete_callback
self.refcount = 0
self.lock = threading.Lock()
if initialize_eagerly:
@@ -141,14 +151,28 @@ class _SingletonEntry:
if self.initialied:
del self.obj
self.initialied = False
+ if self._hard_delete_callback:
+ self._hard_delete_callback()
class _SingletonManager:
entries: Dict[Any, Any] = {}
- def register_singleton(self, constructor, tag, initialize_eagerly=True):
+ def __init__(self):
+ self._hard_delete_callback = None
+
+ def set_hard_delete_callback(self, callback):
+ self._hard_delete_callback = callback
+
+ def register_singleton(
+ self,
+ constructor,
+ tag,
+ initialize_eagerly=True,
+ hard_delete_callback=None):
assert tag not in self.entries, tag
- self.entries[tag] = _SingletonEntry(constructor, initialize_eagerly)
+ self.entries[tag] = _SingletonEntry(
+ constructor, initialize_eagerly, hard_delete_callback)
def has_singleton(self, tag):
return tag in self.entries
@@ -160,7 +184,7 @@ class _SingletonManager:
return self.entries[tag].release(obj)
def unsafe_hard_delete_singleton(self, tag):
- return self.entries[tag].unsafe_hard_delete()
+ self.entries[tag].unsafe_hard_delete()
_process_level_singleton_manager = _SingletonManager()
@@ -203,6 +227,87 @@ class _AutoProxyWrapper:
def get_auto_proxy_object(self):
return self._proxyObject
+ def unsafe_hard_delete(self):
+ self._proxyObject.unsafe_hard_delete()
+
+
+def _run_server_process(address_file, tag, constructor, authkey, life_line):
+ """
+ Runs in a separate process.
+  Includes a 'suicide pact' monitor: if the parent dies, this process dies.
+ """
+ parent_pid = os.getppid()
+
+ def cleanup_files():
+ logging.info("Server process exiting. Deleting files for %s", tag)
+ try:
+ if os.path.exists(address_file):
+ os.remove(address_file)
+ if os.path.exists(address_file + ".error"):
+ os.remove(address_file + ".error")
+ except Exception as e:
+ logging.warning('Failed to cleanup files for tag %s: %s', tag, e)
+
+ def handle_unsafe_hard_delete():
+ cleanup_files()
+ os._exit(0)
+
+ def _monitor_parent():
+ """Checks if parent is alive every second."""
+ while True:
+ try:
+ # This will break if parent dies.
+ life_line.recv_bytes()
+ except (EOFError, OSError, BrokenPipeError):
+ logging.warning(
+ "Process %s detected Parent %s died. Self-destructing.",
+ os.getpid(),
+ parent_pid)
+ cleanup_files()
+ os._exit(0)
+ time.sleep(0.5)
+
+ atexit.register(cleanup_files)
+
+ try:
+ t = threading.Thread(target=_monitor_parent, daemon=True)
+
+ logging.getLogger().setLevel(logging.INFO)
+ multiprocessing.current_process().authkey = authkey
+
+ serving_manager = _SingletonRegistrar(
+ address=('localhost', 0), authkey=authkey)
+ _process_level_singleton_manager.set_hard_delete_callback(
+ handle_unsafe_hard_delete)
+ _process_level_singleton_manager.register_singleton(
+ constructor,
+ tag,
+ initialize_eagerly=True,
+ hard_delete_callback=handle_unsafe_hard_delete)
+  # Start monitoring the parent after initialization is done to avoid
+ # potential race conditions.
+ t.start()
+
+ server = serving_manager.get_server()
+ logging.info(
+ 'Process %s: Proxy serving %s at %s', os.getpid(), tag, server.address)
+
+ with open(address_file + '.tmp', 'w') as fout:
+ fout.write('%s:%d' % server.address)
+ os.rename(address_file + '.tmp', address_file)
+
+ server.serve_forever()
+
+ except Exception:
+ tb = traceback.format_exc()
+ try:
+ with open(address_file + ".error.tmp", 'w') as fout:
+ fout.write(tb)
+ os.rename(address_file + ".error.tmp", address_file + ".error")
+ except Exception:
+ logging.error("CRITICAL ERROR IN SHARED SERVER:\n%s", tb)
+ os._exit(1)
+
class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
@@ -252,7 +357,8 @@ class MultiProcessShared(Generic[T]):
tag: Any,
*,
path: str = tempfile.gettempdir(),
- always_proxy: Optional[bool] = None):
+ always_proxy: Optional[bool] = None,
+ spawn_process: bool = False):
self._constructor = constructor
self._tag = tag
self._path = path
@@ -262,6 +368,7 @@ class MultiProcessShared(Generic[T]):
self._rpc_address = None
self._cross_process_lock = fasteners.InterProcessLock(
os.path.join(self._path, self._tag) + '.lock')
+ self._spawn_process = spawn_process
def _get_manager(self):
if self._manager is None:
@@ -301,6 +408,11 @@ class MultiProcessShared(Generic[T]):
# Caveat: They must always agree, as they will be ignored if the object
# is already constructed.
singleton = self._get_manager().acquire_singleton(self._tag)
+ # Trigger a sweep of zombie processes.
+    # Calling active_children() has the side effect of joining any finished
+ # processes, effectively reaping zombies from previous unsafe_hard_deletes.
+ if self._spawn_process:
+ multiprocessing.active_children()
return _AutoProxyWrapper(singleton)
def release(self, obj):
@@ -318,22 +430,101 @@ class MultiProcessShared(Generic[T]):
self._get_manager().unsafe_hard_delete_singleton(self._tag)
def _create_server(self, address_file):
- # We need to be able to authenticate with both the manager and the process.
- self._serving_manager = _SingletonRegistrar(
- address=('localhost', 0), authkey=AUTH_KEY)
- multiprocessing.current_process().authkey = AUTH_KEY
- # Initialize eagerly to avoid acting as the server if there are issues.
- # Note, however, that _create_server itself is called lazily.
- _process_level_singleton_manager.register_singleton(
- self._constructor, self._tag, initialize_eagerly=True)
- self._server = self._serving_manager.get_server()
- logging.info(
- 'Starting proxy server at %s for shared %s',
- self._server.address,
- self._tag)
- with open(address_file + '.tmp', 'w') as fout:
- fout.write('%s:%d' % self._server.address)
- os.rename(address_file + '.tmp', address_file)
- t = threading.Thread(target=self._server.serve_forever, daemon=True)
- t.start()
- logging.info('Done starting server')
+ if self._spawn_process:
+ error_file = address_file + ".error"
+
+ if os.path.exists(error_file):
+ try:
+ os.remove(error_file)
+ except OSError:
+ pass
+
+      # Create a pipe to the child process, used by the child to detect
+      # when the parent dies so it can clean itself up.
+ reader, writer = multiprocessing.Pipe(duplex=False)
+ self._life_line = writer
+
+ ctx = multiprocessing.get_context('spawn')
+ p = ctx.Process(
+ target=_run_server_process,
+ args=(address_file, self._tag, self._constructor, AUTH_KEY, reader),
+ daemon=False # Must be False for nested proxies
+ )
+ p.start()
+ logging.info("Parent: Waiting for %s to write address file...",
self._tag)
+
+ def cleanup_process():
+ if self._life_line:
+ self._life_line.close()
+ if p.is_alive():
+ logging.info(
+ "Parent: Terminating server process %s for %s", p.pid, self._tag)
+ p.terminate()
+ p.join()
+ try:
+ if os.path.exists(address_file):
+ os.remove(address_file)
+ if os.path.exists(error_file):
+ os.remove(error_file)
+ except Exception as e:
+ logging.warning(
+ 'Failed to cleanup files for tag %s in atexit handler: %s',
+ self._tag,
+ e)
+
+ atexit.register(cleanup_process)
+
+ start_time = time.time()
+ last_log = start_time
+ while True:
+ if os.path.exists(address_file):
+ break
+
+ if os.path.exists(error_file):
+ with open(error_file, 'r') as f:
+ error_msg = f.read()
+ try:
+ os.remove(error_file)
+ except OSError:
+ pass
+
+          if p.is_alive():
+            p.terminate()
+ raise RuntimeError(f"Shared Server Process crashed:\n{error_msg}")
+
+ if not p.is_alive():
+ exit_code = p.exitcode
+ raise RuntimeError(
+ "Shared Server Process died unexpectedly"
+ f" with exit code {exit_code}")
+
+ if time.time() - last_log > 300:
+ logging.warning(
+ "Still waiting for %s to initialize... %ss elapsed)",
+ self._tag,
+ int(time.time() - start_time))
+ last_log = time.time()
+
+ time.sleep(0.05)
+
+ logging.info('External process successfully started for %s', self._tag)
+ else:
+ # We need to be able to authenticate with both the manager
+ # and the process.
+ self._serving_manager = _SingletonRegistrar(
+ address=('localhost', 0), authkey=AUTH_KEY)
+ multiprocessing.current_process().authkey = AUTH_KEY
+ # Initialize eagerly to avoid acting as the server if there are issues.
+ # Note, however, that _create_server itself is called lazily.
+ _process_level_singleton_manager.register_singleton(
+ self._constructor, self._tag, initialize_eagerly=True)
+ self._server = self._serving_manager.get_server()
+ logging.info(
+ 'Starting proxy server at %s for shared %s',
+ self._server.address,
+ self._tag)
+ with open(address_file + '.tmp', 'w') as fout:
+ fout.write('%s:%d' % self._server.address)
+ os.rename(address_file + '.tmp', address_file)
+ t = threading.Thread(target=self._server.serve_forever, daemon=True)
+ t.start()
+ logging.info('Done starting server')
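
(Aside: the life-line mechanism above is the standard pipe trick, sketched
standalone here with illustrative names. The parent holds the write end open
and never sends anything; the child's monitor thread blocks in recv_bytes()
and only unblocks with EOFError/OSError once the parent exits and the OS
closes the write end.)

    import multiprocessing
    import os
    import threading

    def child_main(life_line):
      def monitor():
        try:
          # Blocks forever: the parent never writes. Raises once the
          # parent exits and its end of the pipe is closed by the OS.
          life_line.recv_bytes()
        except (EOFError, OSError):
          os._exit(0)  # Parent is gone; self-destruct.

      threading.Thread(target=monitor, daemon=True).start()
      # ... serve requests here ...

    if __name__ == '__main__':
      reader, writer = multiprocessing.Pipe(duplex=False)
      p = multiprocessing.get_context('spawn').Process(
          target=child_main, args=(reader,))
      p.start()
      # Keep `writer` alive for as long as the child should keep running.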
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py
b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index 0b795763236..7b2b11857bf 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -17,6 +17,9 @@
# pytype: skip-file
import logging
+import multiprocessing
+import os
+import tempfile
import threading
import unittest
from typing import Any
@@ -178,8 +181,11 @@ class MultiProcessSharedTest(unittest.TestCase):
self.assertEqual(counter1.increment(), 1)
self.assertEqual(counter2.increment(), 2)
- multi_process_shared.MultiProcessShared(
- Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+ try:
+ multi_process_shared.MultiProcessShared(
+ Counter, tag='test_unsafe_hard_delete').unsafe_hard_delete()
+ except Exception:
+ pass
with self.assertRaises(Exception):
counter1.get()
@@ -193,6 +199,37 @@ class MultiProcessSharedTest(unittest.TestCase):
self.assertEqual(counter3.increment(), 1)
+ def test_unsafe_hard_delete_autoproxywrapper(self):
+ shared1 = multi_process_shared.MultiProcessShared(
+ Counter,
+ tag='test_unsafe_hard_delete_autoproxywrapper',
+ always_proxy=True)
+ shared2 = multi_process_shared.MultiProcessShared(
+ Counter,
+ tag='test_unsafe_hard_delete_autoproxywrapper',
+ always_proxy=True)
+
+ counter1 = shared1.acquire()
+ counter2 = shared2.acquire()
+ self.assertEqual(counter1.increment(), 1)
+ self.assertEqual(counter2.increment(), 2)
+
+ try:
+ counter2.singletonProxy_unsafe_hard_delete()
+ except Exception:
+ pass
+
+ with self.assertRaises(Exception):
+ counter1.get()
+ with self.assertRaises(Exception):
+ counter2.get()
+
+ counter3 = multi_process_shared.MultiProcessShared(
+ Counter,
+ tag='test_unsafe_hard_delete_autoproxywrapper',
+ always_proxy=True).acquire()
+ self.assertEqual(counter3.increment(), 1)
+
def test_unsafe_hard_delete_no_op(self):
shared1 = multi_process_shared.MultiProcessShared(
Counter, tag='test_unsafe_hard_delete_no_op', always_proxy=True)
@@ -204,8 +241,11 @@ class MultiProcessSharedTest(unittest.TestCase):
self.assertEqual(counter1.increment(), 1)
self.assertEqual(counter2.increment(), 2)
- multi_process_shared.MultiProcessShared(
- Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+ try:
+ multi_process_shared.MultiProcessShared(
+ Counter, tag='no_tag_to_delete').unsafe_hard_delete()
+ except Exception:
+ pass
self.assertEqual(counter1.increment(), 3)
self.assertEqual(counter2.increment(), 4)
@@ -243,6 +283,183 @@ class MultiProcessSharedTest(unittest.TestCase):
counter1.get()
+class MultiProcessSharedSpawnProcessTest(unittest.TestCase):
+ def setUp(self):
+ tempdir = tempfile.gettempdir()
+ for tag in ['basic',
+ 'main',
+ 'to_delete',
+ 'mix1',
+ 'mix2',
+ 'test_process_exit',
+ 'thundering_herd_test']:
+ for ext in ['', '.address', '.address.error']:
+ try:
+ os.remove(os.path.join(tempdir, tag + ext))
+ except OSError:
+ pass
+
+ def tearDown(self):
+ for p in multiprocessing.active_children():
+ if p.is_alive():
+ try:
+ p.kill()
+ p.join(timeout=1.0)
+ except Exception:
+ pass
+
+ def test_call(self):
+ shared = multi_process_shared.MultiProcessShared(
+ Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+ self.assertEqual(shared.get(), 0)
+ self.assertEqual(shared.increment(), 1)
+ self.assertEqual(shared.increment(10), 11)
+ self.assertEqual(shared.increment(value=10), 21)
+ self.assertEqual(shared.get(), 21)
+
+ def test_unsafe_hard_delete_autoproxywrapper(self):
+ shared1 = multi_process_shared.MultiProcessShared(
+ Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+ shared2 = multi_process_shared.MultiProcessShared(
+ Counter, tag='to_delete', always_proxy=True, spawn_process=True)
+ counter3 = multi_process_shared.MultiProcessShared(
+ Counter, tag='basic', always_proxy=True, spawn_process=True).acquire()
+
+ counter1 = shared1.acquire()
+ counter2 = shared2.acquire()
+ self.assertEqual(counter1.increment(), 1)
+ self.assertEqual(counter2.increment(), 2)
+
+ try:
+ counter2.singletonProxy_unsafe_hard_delete()
+ except Exception:
+ pass
+
+ with self.assertRaises(Exception):
+ counter1.get()
+ with self.assertRaises(Exception):
+ counter2.get()
+
+ counter4 = multi_process_shared.MultiProcessShared(
+ Counter, tag='to_delete', always_proxy=True,
+ spawn_process=True).acquire()
+
+ self.assertEqual(counter3.increment(), 1)
+ self.assertEqual(counter4.increment(), 1)
+
+ def test_mix_usage(self):
+ shared1 = multi_process_shared.MultiProcessShared(
+ Counter, tag='mix1', always_proxy=True, spawn_process=False).acquire()
+ shared2 = multi_process_shared.MultiProcessShared(
+ Counter, tag='mix2', always_proxy=True, spawn_process=True).acquire()
+
+ self.assertEqual(shared1.get(), 0)
+ self.assertEqual(shared1.increment(), 1)
+ self.assertEqual(shared2.get(), 0)
+ self.assertEqual(shared2.increment(), 1)
+
+ def test_process_exits_on_unsafe_hard_delete(self):
+ shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True,
+        spawn_process=True)
+ obj = shared.acquire()
+
+ self.assertEqual(obj.increment(), 1)
+
+ children = multiprocessing.active_children()
+ server_process = None
+ for p in children:
+ if p.pid != os.getpid() and p.is_alive():
+ server_process = p
+ break
+
+ self.assertIsNotNone(
+ server_process, "Could not find spawned server process")
+ try:
+ obj.singletonProxy_unsafe_hard_delete()
+ except Exception:
+ pass
+ server_process.join(timeout=5)
+
+ self.assertFalse(
+ server_process.is_alive(),
+ f"Server process {server_process.pid} is still alive after hard
delete")
+ self.assertIsNotNone(
+ server_process.exitcode, "Process has no exit code (did not exit)")
+
+ with self.assertRaises(Exception):
+ obj.get()
+
+ def test_process_exits_on_unsafe_hard_delete_with_manager(self):
+ shared = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_process_exit', always_proxy=True,
+        spawn_process=True)
+ obj = shared.acquire()
+
+ self.assertEqual(obj.increment(), 1)
+
+ children = multiprocessing.active_children()
+ server_process = None
+ for p in children:
+ if p.pid != os.getpid() and p.is_alive():
+ server_process = p
+ break
+
+ self.assertIsNotNone(
+ server_process, "Could not find spawned server process")
+ try:
+ shared.unsafe_hard_delete()
+ except Exception:
+ pass
+ server_process.join(timeout=5)
+
+ self.assertFalse(
+ server_process.is_alive(),
+ f"Server process {server_process.pid} is still alive after hard
delete")
+ self.assertIsNotNone(
+ server_process.exitcode, "Process has no exit code (did not exit)")
+
+ with self.assertRaises(Exception):
+ obj.get()
+
+ def test_zombie_reaping_on_acquire(self):
+ shared1 = multi_process_shared.MultiProcessShared(
+ Counter, tag='test_zombie_reap', always_proxy=True, spawn_process=True)
+ obj = shared1.acquire()
+
+ children = multiprocessing.active_children()
+ server_pid = next(
+ p.pid for p in children if p.is_alive() and p.pid != os.getpid())
+
+ try:
+ obj.singletonProxy_unsafe_hard_delete()
+ except Exception:
+ pass
+
+ try:
+ os.kill(server_pid, 0)
+ is_zombie = True
+ except OSError:
+ is_zombie = False
+ self.assertTrue(
+ is_zombie,
+ f"Server process {server_pid} was reaped too early before acquire()")
+
+ shared2 = multi_process_shared.MultiProcessShared(
+ Counter, tag='unrelated_tag', always_proxy=True, spawn_process=True)
+ _ = shared2.acquire()
+
+ # If reaping worked, our old server_pid should NOT be in this list.
+ current_children_pids = [p.pid for p in multiprocessing.active_children()]
+
+ self.assertNotIn(
+ server_pid,
+ current_children_pids,
+ f"Old server process {server_pid} was not reaped by acquire() sweep")
+ try:
+ shared2.unsafe_hard_delete()
+ except Exception:
+ pass
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
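
(Footnote on the reap-on-acquire behavior exercised by
test_zombie_reaping_on_acquire: multiprocessing.active_children() joins
already-finished children as a side effect, which is what clears zombies
left behind by unsafe_hard_delete. A self-contained illustration, not part
of the change:)

    import multiprocessing
    import time

    if __name__ == '__main__':
      p = multiprocessing.Process(target=time.sleep, args=(0,))
      p.start()
      time.sleep(0.5)  # Let the child exit; on POSIX it is now a zombie.
      # active_children() joins finished children as a side effect,
      # reaping the zombie, and returns only processes still running.
      assert p not in multiprocessing.active_children()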