This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 14731066599 Fix multi_process_shared.py release with AutoProxy (#27607)
14731066599 is described below

commit 14731066599cd4fdb9971faba4c3995cf68b030d
Author: Danny McCormick <[email protected]>
AuthorDate: Mon Jul 24 12:29:22 2023 -0400

    Fix multi_process_shared.py release with AutoProxy (#27607)
    
    * Failing test to demonstrate the problem
    
    * Allow release call to be proxied correctly
    
    * clean up patch
    
    * Cleanup comment
    
    * Format
    
    * unused imports
    
    * Ignore bad lint rule
---
 .../apache_beam/utils/multi_process_shared.py      | 30 ++++++++++++++++++--
 .../apache_beam/utils/multi_process_shared_test.py | 32 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/utils/multi_process_shared.py b/sdks/python/apache_beam/utils/multi_process_shared.py
index b9414d017f3..17bdf2c8489 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared.py
@@ -36,6 +36,26 @@ from typing import TypeVar
 
 import fasteners
 
+# In some python versions, there is a bug where AutoProxy doesn't handle
+# the kwarg 'manager_owned'. We implement our own backup here to make sure
+# we avoid this problem. More info here:
+# https://stackoverflow.com/questions/46779860/multiprocessing-managers-and-custom-classes
+autoproxy = multiprocessing.managers.AutoProxy  # type: ignore[attr-defined]
+
+
+def patched_autoproxy(
+    token,
+    serializer,
+    manager=None,
+    authkey=None,
+    exposed=None,
+    incref=True,
+    manager_owned=True):
+  return autoproxy(token, serializer, manager, authkey, exposed, incref)
+
+
+multiprocessing.managers.AutoProxy = patched_autoproxy  # type: ignore[attr-defined]
+
 T = TypeVar('T')
 AUTH_KEY = b'mps'
 
@@ -55,7 +75,7 @@ class _SingletonProxy:
       raise RuntimeError('Entry was released.')
     return self._SingletonProxy_entry.obj.__call__(*args, **kwargs)
 
-  def _SingletonProxy_release(self):
+  def singletonProxy_release(self):
     assert self._SingletonProxy_valid
     self._SingletonProxy_valid = False
 
@@ -68,6 +88,7 @@ class _SingletonProxy:
     # Needed for multiprocessing.managers's proxying.
     dir = self._SingletonProxy_entry.obj.__dir__()
     dir.append('singletonProxy_call__')
+    dir.append('singletonProxy_release')
     return dir
 
 
@@ -92,7 +113,7 @@ class _SingletonEntry:
       return _SingletonProxy(self)
 
   def release(self, proxy):
-    proxy._SingletonProxy_release()
+    proxy.singletonProxy_release()
     with self.lock:
       self.refcount -= 1
       if self.refcount == 0:
@@ -151,6 +172,9 @@ class _AutoProxyWrapper:
   def __getattr__(self, name):
     return getattr(self._proxyObject, name)
 
+  def get_auto_proxy_object(self):
+    return self._proxyObject
+
 
 class MultiProcessShared(Generic[T]):
   """MultiProcessShared is used to share a single object across processes.
@@ -252,7 +276,7 @@ class MultiProcessShared(Generic[T]):
     return _AutoProxyWrapper(singleton)
 
   def release(self, obj):
-    self._manager.release_singleton(self._tag, obj)
+    self._manager.release_singleton(self._tag, obj.get_auto_proxy_object())
 
   def _create_server(self, address_file):
     # We need to be able to authenticate with both the manager and the process.
diff --git a/sdks/python/apache_beam/utils/multi_process_shared_test.py b/sdks/python/apache_beam/utils/multi_process_shared_test.py
index b5ae90da915..aeae9159d32 100644
--- a/sdks/python/apache_beam/utils/multi_process_shared_test.py
+++ b/sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -133,6 +133,38 @@ class MultiProcessSharedTest(unittest.TestCase):
     with self.assertRaisesRegex(Exception, 'released'):
       counter1.get()
 
+  def test_release_always_proxy(self):
+    shared1 = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_release_always_proxy', always_proxy=True)
+    shared2 = multi_process_shared.MultiProcessShared(
+        Counter, tag='test_release_always_proxy', always_proxy=True)
+
+    counter1 = shared1.acquire()
+    counter2 = shared2.acquire()
+    self.assertEqual(counter1.increment(), 1)
+    self.assertEqual(counter2.increment(), 2)
+
+    counter1again = shared1.acquire()
+    self.assertEqual(counter1again.increment(), 3)
+
+    shared1.release(counter1)
+    shared2.release(counter2)
+
+    with self.assertRaisesRegex(Exception, 'released'):
+      counter1.get()
+    with self.assertRaisesRegex(Exception, 'released'):
+      counter2.get()
+
+    self.assertEqual(counter1again.get(), 3)
+
+    shared1.release(counter1again)
+
+    counter1New = shared1.acquire()
+    self.assertEqual(counter1New.get(), 0)
+
+    with self.assertRaisesRegex(Exception, 'released'):
+      counter1.get()
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to