This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f4d0290 Fix races in block scope (#17749)
f4d0290 is described below
commit f4d0290fc2cd5763aa5a9c890e4d3dcd4ea6ec6b
Author: Haozheng Fan <[email protected]>
AuthorDate: Thu May 21 07:16:22 2020 +0800
Fix races in block scope (#17749)
* Add tests
* Fix block_scope
Co-authored-by: Haibin Lin <[email protected]>
Co-authored-by: Lin <[email protected]>
---
python/mxnet/gluon/block.py | 25 +++++++++++----------
python/mxnet/name.py | 9 ++++----
tests/python/unittest/test_thread_local.py | 36 ++++++++++++++++++++++++++++++
3 files changed, 54 insertions(+), 16 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 6d9ea9a..ded66a7 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -52,8 +52,9 @@ class _BlockScope(object):
def __init__(self, block):
self._block = weakref.ref(block) if block is not None else None
self._counter = {}
- self._old_scope = None
- self._name_scope = None
+ self._local = threading.local()
+ self._local._old_scope = None
+ self._local._name_scope = None
@staticmethod
def create(prefix, params, hint):
@@ -96,23 +97,23 @@ class _BlockScope(object):
block = self._block()
if block is None or block._empty_prefix:
return self
- self._old_scope = getattr(_BlockScope._current, "value", None)
+ self._local._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
- self._name_scope = _name.Prefix(block.prefix)
- self._name_scope.__enter__()
- self._profiler_scope = _profiler.Scope(block._profiler_scope_name)
- self._profiler_scope.__enter__()
+ self._local._name_scope = _name.Prefix(block.prefix)
+ self._local._name_scope.__enter__()
+ self._local._profiler_scope =
_profiler.Scope(block._profiler_scope_name)
+ self._local._profiler_scope.__enter__()
return self
def __exit__(self, ptype, value, trace):
block = self._block()
if block is None or block._empty_prefix:
return
- self._name_scope.__exit__(ptype, value, trace)
- self._name_scope = None
- self._profiler_scope.__exit__(ptype, value, trace)
- self._profiler_scope = None
- _BlockScope._current.value = self._old_scope
+ self._local._name_scope.__exit__(ptype, value, trace)
+ self._local._name_scope = None
+ self._local._profiler_scope.__exit__(ptype, value, trace)
+ self._local._profiler_scope = None
+ _BlockScope._current.value = self._local._old_scope
def _gather_type_ctx_info(args):
diff --git a/python/mxnet/name.py b/python/mxnet/name.py
index b276c72..e39752e 100644
--- a/python/mxnet/name.py
+++ b/python/mxnet/name.py
@@ -30,7 +30,8 @@ class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
def __init__(self):
self._counter = {}
- self._old_manager = None
+ self._local = threading.local()
+ self._local._old_manager = None
def get(self, name, hint):
"""Get the canonical name for a symbol.
@@ -66,13 +67,13 @@ class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
def __enter__(self):
if not hasattr(NameManager._current, "value"):
NameManager._current.value = NameManager()
- self._old_manager = NameManager._current.value
+ self._local._old_manager = NameManager._current.value
NameManager._current.value = self
return self
def __exit__(self, ptype, value, trace):
- assert self._old_manager
- NameManager._current.value = self._old_manager
+ assert self._local._old_manager
+ NameManager._current.value = self._local._old_manager
#pylint: disable=no-self-argument
@classproperty
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index 5423249..975ad2a 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -222,3 +222,39 @@ def test_np_global_shape():
finally:
set_np_shape(0)
+def test_blockscope_multithread():
+    """Check that _BlockScope keeps its enter/exit bookkeeping per thread.
+
+    Thread ``f`` enters the shared scope and stays inside it until thread
+    ``g`` has entered and exited the same scope object. If the saved name
+    and profiler scopes were stored on the instance instead of per thread,
+    ``g``'s exit would reset them to None and ``f``'s exit would fail.
+    """
+    timeout = 10  # seconds; keeps a broken scope from hanging the suite
+    f_entered = threading.Event()
+    g_done = threading.Event()
+    errors = []
+
+    class dummy_block(object):
+        def __init__(self, prefix):
+            self.prefix = prefix
+            self._profiler_scope_name = prefix
+            self._empty_prefix = False
+
+    def f(scope):
+        try:
+            with scope:
+                f_entered.set()
+                g_done.wait(timeout)
+        except Exception as err:  # pylint: disable=broad-except
+            errors.append(err)
+        finally:
+            # Unblock g even if entering the scope failed.
+            f_entered.set()
+
+    def g(scope):
+        try:
+            # Enter only once f is inside the scope, so the two threads'
+            # enter/exit calls really overlap on the shared instance.
+            f_entered.wait(timeout)
+            with scope:
+                pass
+        except Exception as err:  # pylint: disable=broad-except
+            errors.append(err)
+        finally:
+            # Release f even if g failed, so join() cannot hang.
+            g_done.set()
+
+    # _BlockScope keeps only a weak reference to its block. Hold a strong
+    # reference for the whole test; otherwise the dummy can be collected
+    # immediately and __enter__/__exit__ return early without doing anything.
+    dummy = dummy_block("scope_")
+    scope = block._BlockScope(dummy)
+    threads = [threading.Thread(target=f, args=(scope,)),
+               threading.Thread(target=g, args=(scope,))]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+    assert not errors, \
+        "_BlockScope does not work with multithread: %r" % (errors,)
+
+
+# When this file is executed as a script, run its tests through nose.
+if __name__ == '__main__':
+import nose
+nose.runmodule()