piiswrong closed pull request #9148: raise warning when we detect Block inside nested list/dict
URL: https://github.com/apache/incubator-mxnet/pull/9148
This is a PR merged from a forked repository. Because GitHub hides the
original diff of a merged pull request from a fork, the diff is reproduced
below for the sake of provenance:
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fd75e4b7b7..c84c528aa8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -200,6 +200,33 @@ def __setattr__(self, name, value):
super(Block, self).__setattr__(name, value)
+    def _check_container_with_block(self):
+        """Warn about list/tuple/dict attributes holding unregistered Blocks.
+
+        Blocks in plain containers are not registered automatically; Blocks
+        already present in self._children (e.g. via register_child()) are skipped."""
+        registered = set(id(child) for child in self._children)
+        def _find_unregistered_block(data, seen):
+            # Whether a nested list/tuple/dict holds a Block missing from _children.
+            if isinstance(data, Block):
+                return id(data) not in registered
+            if not isinstance(data, (list, tuple, dict)) or id(data) in seen:
+                return False
+            seen.add(id(data))  # stop on self-referencing containers
+            items = data.values() if isinstance(data, dict) else data
+            return any(_find_unregistered_block(item, seen) for item in items)
+
+        for k, v in self.__dict__.items():
+            if k.startswith('__') or k == '_children':
+                continue
+            if isinstance(v, (list, tuple, dict)) and _find_unregistered_block(v, set()):
+                warnings.warn('"{name}" is a container with unregistered Blocks. '
+                              'Note that Blocks inside the list, tuple or dict will not be '
+                              'registered automatically. Make sure to register them using '
+                              'register_child() or switching to '
+                              'nn.Sequential/nn.HybridSequential instead.'
+                              .format(name=self.__class__.__name__ + "." + k), stacklevel=3)
+
     def _alias(self):
         return self.__class__.__name__.lower()
@@ -252,6 +279,8 @@ def collect_params(self, select=None):
-------
The selected :py:class:`ParameterDict`
"""
+ # We need to check here because blocks inside containers are not
supported.
+ self._check_container_with_block()
ret = ParameterDict(self._params.prefix)
if not select:
ret.update(self.params)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index ab5d5e167f..a66cc22628 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -90,7 +90,7 @@ class HybridSequential(HybridBlock):
Example::
- net = nn.Sequential()
+ net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
net.add(nn.Dense(10, activation='relu'))
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 57bf5c97c7..80109cf99a 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -475,6 +475,50 @@ def test_block_attr_regular():
b.c = c2
assert b.c is c2 and b._children[0] is c2
+def test_block_attr_list_of_block():
+    # Blocks kept in plain list/dict attributes are not registered as children,
+    # so collect_params() must warn about them and stay silent otherwise.
+    class Model1(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model1, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = [nn.Dense(i * 10) for i in range(6)]
+
+    class Model2(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model2, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = dict()
+                self.layers['a'] = [nn.Dense(10), nn.Dense(10)]
+
+    class Model3(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model3, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = nn.Sequential()
+                self.layers.add(*[nn.Dense(i * 10) for i in range(6)])
+
+    class Model4(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model4, self).__init__(**kwargs)
+            with self.name_scope():
+                self.data = {'a': '4', 'b': 123}
+
+    def container_warnings(model_cls):
+        """Return the container warnings raised by model_cls().collect_params()."""
+        with warnings.catch_warnings(record=True) as w:
+            # 'always' so the result does not depend on global filters or on
+            # the warning having been emitted earlier in the test run.
+            warnings.simplefilter('always')
+            model_cls().collect_params()
+        # Ignore unrelated warnings (e.g. deprecations) emitted meanwhile.
+        return [x for x in w if 'container' in str(x.message)]
+
+    assert len(container_warnings(Model1)) > 0
+    assert len(container_warnings(Model2)) > 0
+    assert len(container_warnings(Model3)) == 0
+    assert len(container_warnings(Model4)) == 0
+
def test_sequential_warning():
with warnings.catch_warnings(record=True) as w:
b = gluon.nn.Sequential()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services