This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e36a945 raise warning when we detect Block inside nested list/dict (#9148)
e36a945 is described below
commit e36a945db481bc3d9c6d674c8ca404a02d00dc93
Author: Xingjian Shi <[email protected]>
AuthorDate: Mon Jan 15 12:18:09 2018 -0800
raise warning when we detect Block inside nested list/dict (#9148)
* add warning for block in nested list dict
Blocks inside containers are not supported
* test ci again
---
python/mxnet/gluon/block.py | 29 +++++++++++++++++++++++
python/mxnet/gluon/nn/basic_layers.py | 2 +-
tests/python/unittest/test_gluon.py | 44 +++++++++++++++++++++++++++++++++++
3 files changed, 74 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fd75e4b..c84c528 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -200,6 +200,33 @@ class Block(object):
super(Block, self).__setattr__(name, value)
+ def _check_container_with_block(self):
+ def _find_block_in_container(data):
+ # Find whether a nested container structure contains Blocks
+ if isinstance(data, (list, tuple)):
+ for ele in data:
+ if _find_block_in_container(ele):
+ return True
+ return False
+ elif isinstance(data, dict):
+ for _, v in data.items():
+ if _find_block_in_container(v):
+ return True
+ return False
+ elif isinstance(data, Block):
+ return True
+ else:
+ return False
+ for k, v in self.__dict__.items():
+ if isinstance(v, (list, tuple, dict)) and not (k.startswith('__')
or k == '_children'):
+ if _find_block_in_container(v):
+ warnings.warn('"{name}" is a container with Blocks. '
+ 'Note that Blocks inside the list, tuple or
dict will not be '
+ 'registered automatically. Make sure to
register them using '
+ 'register_child() or switching to '
+ 'nn.Sequential/nn.HybridSequential instead. '
+ .format(name=self.__class__.__name__ + "." +
k))
+
def _alias(self):
return self.__class__.__name__.lower()
@@ -252,6 +279,8 @@ class Block(object):
-------
The selected :py:class:`ParameterDict`
"""
+ # We need to check here because blocks inside containers are not
supported.
+ self._check_container_with_block()
ret = ParameterDict(self._params.prefix)
if not select:
ret.update(self.params)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index ab5d5e1..a66cc22 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -90,7 +90,7 @@ class HybridSequential(HybridBlock):
Example::
- net = nn.Sequential()
+ net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 57bf5c9..80109cf 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -475,6 +475,50 @@ def test_block_attr_regular():
b.c = c2
assert b.c is c2 and b._children[0] is c2
+def test_block_attr_list_of_block():
+ class Model1(gluon.Block):
+ def __init__(self, **kwargs):
+ super(Model1, self).__init__(**kwargs)
+ with self.name_scope():
+ self.layers = [nn.Dense(i * 10) for i in range(6)]
+
+ class Model2(gluon.Block):
+ def __init__(self, **kwargs):
+ super(Model2, self).__init__(**kwargs)
+ with self.name_scope():
+ self.layers = dict()
+ self.layers['a'] = [nn.Dense(10), nn.Dense(10)]
+
+ class Model3(gluon.Block):
+ def __init__(self, **kwargs):
+ super(Model3, self).__init__(**kwargs)
+ with self.name_scope():
+ self.layers = nn.Sequential()
+ self.layers.add(*[nn.Dense(i * 10) for i in range(6)])
+
+ class Model4(gluon.Block):
+ def __init__(self, **kwargs):
+ super(Model4, self).__init__(**kwargs)
+ with self.name_scope():
+ self.data = {'a': '4', 'b': 123}
+
+ with warnings.catch_warnings(record=True) as w:
+ model = Model1()
+ model.collect_params()
+ assert len(w) > 0
+ with warnings.catch_warnings(record=True) as w:
+ model = Model2()
+ model.collect_params()
+ assert len(w) > 0
+ with warnings.catch_warnings(record=True) as w:
+ model = Model3()
+ model.collect_params()
+ assert len(w) == 0
+ with warnings.catch_warnings(record=True) as w:
+ model = Model4()
+ model.collect_params()
+ assert len(w) == 0
+
def test_sequential_warning():
with warnings.catch_warnings(record=True) as w:
b = gluon.nn.Sequential()
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].