This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 14731066599 Fix multi_process_shared.py release with AutoProxy (#27607)
14731066599 is described below
commit 14731066599cd4fdb9971faba4c3995cf68b030d
Author: Danny McCormick <[email protected]>
AuthorDate: Mon Jul 24 12:29:22 2023 -0400
Fix multi_process_shared.py release with AutoProxy (#27607)
* Failing test to demonstrate the problem
* Allow release call to be proxied correctly
* clean up patch
* Cleanup comment
* Format
* unused imports
* Ignore bad lint rule
---
.../apache_beam/utils/multi_process_shared.py | 30 ++++++++++++++++++--
.../apache_beam/utils/multi_process_shared_test.py | 32 ++++++++++++++++++++++
2 files changed, 59 insertions(+), 3 deletions(-)
diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index b9414d017f3..17bdf2c8489 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -36,6 +36,26 @@ from typing import TypeVar
import fasteners
+# In some python versions, there is a bug where AutoProxy doesn't handle
+# the kwarg 'manager_owned'. We implement our own backup here to make sure
+# we avoid this problem. More info here:
+# https://stackoverflow.com/questions/46779860/multiprocessing-managers-and-custom-classes
+autoproxy = multiprocessing.managers.AutoProxy # type: ignore[attr-defined]
+
+
+def patched_autoproxy(
+ token,
+ serializer,
+ manager=None,
+ authkey=None,
+ exposed=None,
+ incref=True,
+ manager_owned=True):
+ return autoproxy(token, serializer, manager, authkey, exposed, incref)
+
+
+multiprocessing.managers.AutoProxy = patched_autoproxy  # type: ignore[attr-defined]
+
T = TypeVar('T')
AUTH_KEY = b'mps'
@@ -55,7 +75,7 @@ class _SingletonProxy:
raise RuntimeError('Entry was released.')
return self._SingletonProxy_entry.obj.__call__(*args, **kwargs)
- def _SingletonProxy_release(self):
+ def singletonProxy_release(self):
assert self._SingletonProxy_valid
self._SingletonProxy_valid = False
@@ -68,6 +88,7 @@ class _SingletonProxy:
# Needed for multiprocessing.managers's proxying.
dir = self._SingletonProxy_entry.obj.__dir__()
dir.append('singletonProxy_call__')
+ dir.append('singletonProxy_release')
return dir
@@ -92,7 +113,7 @@ class _SingletonEntry:
return _SingletonProxy(self)
def release(self, proxy):
- proxy._SingletonProxy_release()
+ proxy.singletonProxy_release()
with self.lock:
self.refcount -= 1
if self.refcount == 0:
@@ -151,6 +172,9 @@ class _AutoProxyWrapper:
def __getattr__(self, name):
return getattr(self._proxyObject, name)
+ def get_auto_proxy_object(self):
+ return self._proxyObject
+
class MultiProcessShared(Generic[T]):
"""MultiProcessShared is used to share a single object across processes.
@@ -252,7 +276,7 @@ class MultiProcessShared(Generic[T]):
return _AutoProxyWrapper(singleton)
def release(self, obj):
- self._manager.release_singleton(self._tag, obj)
+ self._manager.release_singleton(self._tag, obj.get_auto_proxy_object())
def _create_server(self, address_file):
# We need to be able to authenticate with both the manager and the process.
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index b5ae90da915..aeae9159d32 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -133,6 +133,38 @@ class MultiProcessSharedTest(unittest.TestCase):
with self.assertRaisesRegex(Exception, 'released'):
counter1.get()
+ def test_release_always_proxy(self):
+ shared1 = multi_process_shared.MultiProcessShared(
+ Counter, tag='test_release_always_proxy', always_proxy=True)
+ shared2 = multi_process_shared.MultiProcessShared(
+ Counter, tag='test_release_always_proxy', always_proxy=True)
+
+ counter1 = shared1.acquire()
+ counter2 = shared2.acquire()
+ self.assertEqual(counter1.increment(), 1)
+ self.assertEqual(counter2.increment(), 2)
+
+ counter1again = shared1.acquire()
+ self.assertEqual(counter1again.increment(), 3)
+
+ shared1.release(counter1)
+ shared2.release(counter2)
+
+ with self.assertRaisesRegex(Exception, 'released'):
+ counter1.get()
+ with self.assertRaisesRegex(Exception, 'released'):
+ counter2.get()
+
+ self.assertEqual(counter1again.get(), 3)
+
+ shared1.release(counter1again)
+
+ counter1New = shared1.acquire()
+ self.assertEqual(counter1New.get(), 0)
+
+ with self.assertRaisesRegex(Exception, 'released'):
+ counter1.get()
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)