ThomasDelteil commented on a change in pull request #11254: add blocklist
URL: https://github.com/apache/incubator-mxnet/pull/11254#discussion_r195175883
##########
File path: python/mxnet/gluon/block.py
##########
@@ -841,6 +842,140 @@ def hybrid_forward(self, F, x, *args, **kwargs):
# pylint: disable= invalid-name
raise NotImplementedError
+
+class BlockList(Block):
+ r"""Holds sub-blocks in a list.
+
+ BlockList can be indexed like a regular Python list, but blocks it
+ contains are properly registered, and will be visible by all Block methods::
+
+ class MyBlock(gluon.Block):
+ def __init__(self):
+ super(MyBlock, self).__init__()
+ self.blocks = gluon.BlockList([nn.Dense(10) for i in range(10)])
+
+ def forward(self, x):
+ # BlockList can act as an iterable, or be indexed using ints
+ for i, l in enumerate(self.blocks):
+ x = self.blocks[i // 2](x) + l(x)
+ return x
+
+ Parameters
+ ----------
+ blocks : iterable, optional
+ an iterable of blocks to add
+ """
+ def __init__(self, blocks=None, **kwargs):
+ super(BlockList, self).__init__(**kwargs)
+ if blocks is not None:
+ self += blocks
+
+ def __getitem__(self, idx):
+ if isinstance(idx, slice):
+ return BlockList(list(self._children.values())[idx])
+ else:
+ return self._children.values()[idx]
+
+ def __setitem__(self, idx, block):
+ self.register_child(block, str(idx))
+
+ def __len__(self):
+ return len(self._children)
+
+ def __iter__(self):
+ return iter(self._children.values())
+
+ def __iadd__(self, blocks):
+ return self.extend(blocks)
+
+ def append(self, block):
+ r"""Appends a given block to the end of the list.
+ Arguments:
+ block (nn.Module): block to append
+ """
+ self.register_child(block, str(len(self)))
+ return self
+
+ def extend(self, blocks):
+ r"""Appends blocks from a Python iterable to the end of the list.
+ Arguments:
+ blocks (iterable): iterable of blocks to append
+ """
+ if not isinstance(blocks, Iterable):
+ raise TypeError("BlockList.extend should be called with an "
+ "iterable, but got " + type(blocks).__name__)
+ offset = len(self)
+ for i, block in enumerate(blocks):
+ self.register_child(block, str(offset + i))
+ return self
+
+
+class HybridBlockList(HybridBlock):
+ r"""Holds hybrid sub-blocks in a list.
+
+ HybridBlockList can be indexed like a regular Python list, but blocks it
Review comment:
but the* blocks
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services