This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f4d0290  Fix races in block scope (#17749)
f4d0290 is described below

commit f4d0290fc2cd5763aa5a9c890e4d3dcd4ea6ec6b
Author: Haozheng Fan <[email protected]>
AuthorDate: Thu May 21 07:16:22 2020 +0800

    Fix races in block scope (#17749)
    
    * Add tests
    
    * Fix block_scope
    
    Co-authored-by: Haibin Lin <[email protected]>
    Co-authored-by: Lin <[email protected]>
---
 python/mxnet/gluon/block.py                | 25 +++++++++++----------
 python/mxnet/name.py                       |  9 ++++----
 tests/python/unittest/test_thread_local.py | 36 ++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 6d9ea9a..ded66a7 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -52,8 +52,9 @@ class _BlockScope(object):
     def __init__(self, block):
         self._block = weakref.ref(block) if block is not None else None
         self._counter = {}
-        self._old_scope = None
-        self._name_scope = None
+        self._local = threading.local()
+        self._local._old_scope = None
+        self._local._name_scope = None
 
     @staticmethod
     def create(prefix, params, hint):
@@ -96,23 +97,23 @@ class _BlockScope(object):
         block = self._block()
         if block is None or block._empty_prefix:
             return self
-        self._old_scope = getattr(_BlockScope._current, "value", None)
+        self._local._old_scope = getattr(_BlockScope._current, "value", None)
         _BlockScope._current.value = self
-        self._name_scope = _name.Prefix(block.prefix)
-        self._name_scope.__enter__()
-        self._profiler_scope = _profiler.Scope(block._profiler_scope_name)
-        self._profiler_scope.__enter__()
+        self._local._name_scope = _name.Prefix(block.prefix)
+        self._local._name_scope.__enter__()
+        self._local._profiler_scope = _profiler.Scope(block._profiler_scope_name)
+        self._local._profiler_scope.__enter__()
         return self
 
     def __exit__(self, ptype, value, trace):
         block = self._block()
         if block is None or block._empty_prefix:
             return
-        self._name_scope.__exit__(ptype, value, trace)
-        self._name_scope = None
-        self._profiler_scope.__exit__(ptype, value, trace)
-        self._profiler_scope = None
-        _BlockScope._current.value = self._old_scope
+        self._local._name_scope.__exit__(ptype, value, trace)
+        self._local._name_scope = None
+        self._local._profiler_scope.__exit__(ptype, value, trace)
+        self._local._profiler_scope = None
+        _BlockScope._current.value = self._local._old_scope
 
 
 def _gather_type_ctx_info(args):
diff --git a/python/mxnet/name.py b/python/mxnet/name.py
index b276c72..e39752e 100644
--- a/python/mxnet/name.py
+++ b/python/mxnet/name.py
@@ -30,7 +30,8 @@ class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
 
     def __init__(self):
         self._counter = {}
-        self._old_manager = None
+        self._local = threading.local()
+        self._local._old_manager = None
 
     def get(self, name, hint):
         """Get the canonical name for a symbol.
@@ -66,13 +67,13 @@ class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
     def __enter__(self):
         if not hasattr(NameManager._current, "value"):
             NameManager._current.value = NameManager()
-        self._old_manager = NameManager._current.value
+        self._local._old_manager = NameManager._current.value
         NameManager._current.value = self
         return self
 
     def __exit__(self, ptype, value, trace):
-        assert self._old_manager
-        NameManager._current.value = self._old_manager
+        assert self._local._old_manager
+        NameManager._current.value = self._local._old_manager
 
     #pylint: disable=no-self-argument
     @classproperty
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index 5423249..975ad2a 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -222,3 +222,39 @@ def test_np_global_shape():
     finally:
         set_np_shape(0)
 
+def test_blockscope_multithread():
+    event = threading.Event()
+    status = [False]
+
+    class dummy_block(object):
+        def __init__(self, prefix):
+            self.prefix = prefix
+            self._profiler_scope_name = prefix
+            self._empty_prefix = False
+    
+    def f(scope):
+        try:
+            with scope:
+                event.wait()
+        except:
+            status[0] = True
+
+    def g(scope):
+        with scope:
+            pass
+        event.set()
+
+    scope = block._BlockScope(dummy_block("scope_"))
+    count = 2
+    threads = [threading.Thread(target=f, args=(scope,)),
+               threading.Thread(target=g, args=(scope,))]
+    for i in range(count):
+        threads[i].start()
+    for i in range(count):
+        threads[i].join()
+    assert status[0] is False, "_BlockScope does not work with multithread"
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

Reply via email to