This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e36a945  raise warning when we detect Block inside nested list/dict (#9148)
e36a945 is described below

commit e36a945db481bc3d9c6d674c8ca404a02d00dc93
Author: Xingjian Shi <[email protected]>
AuthorDate: Mon Jan 15 12:18:09 2018 -0800

    raise warning when we detect Block inside nested list/dict (#9148)
    
    * add warning for block in nested list dict
    
    Blocks inside containers are not supported
    
    * test ci again
---
 python/mxnet/gluon/block.py           | 29 +++++++++++++++++++++++
 python/mxnet/gluon/nn/basic_layers.py |  2 +-
 tests/python/unittest/test_gluon.py   | 44 +++++++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fd75e4b..c84c528 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -200,6 +200,33 @@ class Block(object):
 
         super(Block, self).__setattr__(name, value)
 
+    def _check_container_with_block(self):
+        def _find_block_in_container(data):
+            # Find whether a nested container structure contains Blocks
+            if isinstance(data, (list, tuple)):
+                for ele in data:
+                    if _find_block_in_container(ele):
+                        return True
+                return False
+            elif isinstance(data, dict):
+                for _, v in data.items():
+                    if _find_block_in_container(v):
+                        return True
+                return False
+            elif isinstance(data, Block):
+                return True
+            else:
+                return False
+        for k, v in self.__dict__.items():
+            if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
+                if _find_block_in_container(v):
+                    warnings.warn('"{name}" is a container with Blocks. '
+                                  'Note that Blocks inside the list, tuple or dict will not be '
+                                  'registered automatically. Make sure to register them using '
+                                  'register_child() or switching to '
+                                  'nn.Sequential/nn.HybridSequential instead. '
+                                  .format(name=self.__class__.__name__ + "." + k))
+
     def _alias(self):
         return self.__class__.__name__.lower()
 
@@ -252,6 +279,8 @@ class Block(object):
         -------
         The selected :py:class:`ParameterDict`
         """
+        # We need to check here because blocks inside containers are not supported.
+        self._check_container_with_block()
         ret = ParameterDict(self._params.prefix)
         if not select:
             ret.update(self.params)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index ab5d5e1..a66cc22 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -90,7 +90,7 @@ class HybridSequential(HybridBlock):
 
     Example::
 
-        net = nn.Sequential()
+        net = nn.HybridSequential()
         # use net's name_scope to give child Blocks appropriate names.
         with net.name_scope():
             net.add(nn.Dense(10, activation='relu'))
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 57bf5c9..80109cf 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -475,6 +475,50 @@ def test_block_attr_regular():
     b.c = c2
     assert b.c is c2 and b._children[0] is c2
 
+def test_block_attr_list_of_block():
+    class Model1(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model1, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = [nn.Dense(i * 10) for i in range(6)]
+
+    class Model2(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model2, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = dict()
+                self.layers['a'] = [nn.Dense(10), nn.Dense(10)]
+
+    class Model3(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model3, self).__init__(**kwargs)
+            with self.name_scope():
+                self.layers = nn.Sequential()
+                self.layers.add(*[nn.Dense(i * 10) for i in range(6)])
+
+    class Model4(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Model4, self).__init__(**kwargs)
+            with self.name_scope():
+                self.data = {'a': '4', 'b': 123}
+
+    with warnings.catch_warnings(record=True) as w:
+        model = Model1()
+        model.collect_params()
+        assert len(w) > 0
+    with warnings.catch_warnings(record=True) as w:
+        model = Model2()
+        model.collect_params()
+        assert len(w) > 0
+    with warnings.catch_warnings(record=True) as w:
+        model = Model3()
+        model.collect_params()
+        assert len(w) == 0
+    with warnings.catch_warnings(record=True) as w:
+        model = Model4()
+        model.collect_params()
+        assert len(w) == 0
+
 def test_sequential_warning():
     with warnings.catch_warnings(record=True) as w:
         b = gluon.nn.Sequential()

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to