This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c55fc57  add Sequential compatibility to rnn layers (#7352)
c55fc57 is described below

commit c55fc571d22ec458365d87740b589827ddfd86cf
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Thu Aug 10 10:24:03 2017 -0700

    add Sequential compatibility to rnn layers (#7352)
---
 python/mxnet/gluon/rnn/rnn_layer.py     | 46 ++++++++++++++++++++++++---------
 tests/python/unittest/test_gluon_rnn.py | 29 +++++++++++++++++++++
 2 files changed, 63 insertions(+), 12 deletions(-)
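
The gist of the change: the Gluon recurrent layers' forward() now takes an
optional `states` argument, so a layer can be called with a single input and
chained inside gluon.nn.Sequential. A minimal usage sketch of what the patch
below enables, mirroring the new unit test (shapes are illustrative only):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Sequential()
    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))  # input layout (T, N, C)
    net.add(gluon.nn.BatchNorm(axis=2))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(3, activation='relu'))
    net.collect_params().initialize()

    # Sequential calls each block as block(x); the LSTM now falls back to zero
    # begin states and returns only its output, so the chain composes cleanly.
    out = net(mx.nd.ones((2, 3, 10)))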

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index a9bcee5..86b7c61 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -168,10 +168,13 @@ class _RNNLayer(Block):
             states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states
 
-    def forward(self, inputs, states):
+    def forward(self, inputs, states=None):
+        batch_size = inputs.shape[self._layout.find('N')]
+        skip_states = states is None
+        if skip_states:
+            states = self.begin_state(batch_size)
         if isinstance(states, ndarray.NDArray):
             states = [states]
-        batch_size = inputs.shape[self._layout.find('N')]
         for state, info in zip(states, self.state_info(batch_size)):
             if state.shape != info['shape']:
                 raise ValueError(
@@ -182,8 +185,12 @@ class _RNNLayer(Block):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
         if inputs.context.device_type == 'gpu':
-            return self._forward_gpu(inputs, states)
-        return self._forward_cpu(inputs, states)
+            out = self._forward_gpu(inputs, states)
+        else:
+            out = self._forward_cpu(inputs, states)
+
+        # out is (output, state)
+        return out[0] if skip_states else out
 
     def _forward_cpu(self, inputs, states):
         ns = len(states)
@@ -282,10 +289,12 @@ class RNN(_RNNLayer):
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
-        The recurrent state's shape is `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+    Recurrent state:
+        The recurrent state is an NDArray with shape `(num_layers, batch_size, num_hidden)`.
+        If `bidirectional` is True, the recurrent state shape will instead be
         `(2*num_layers, batch_size, num_hidden)`
+        If the input recurrent state is None, zeros are used as the default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -293,6 +302,9 @@ class RNN(_RNNLayer):
     >>> layer = mx.gluon.rnn.RNN(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
     >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, h0)
     """
@@ -379,11 +391,13 @@ class LSTM(_RNNLayer):
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
+    Recurrent state:
         The recurrent state is a list of two NDArrays. Both have shape
         `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+        If `bidirectional` is True, each recurrent state will instead have shape
         `(2*num_layers, batch_size, num_hidden)`.
+        If the input recurrent state is None, zeros are used as the default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -391,6 +405,9 @@ class LSTM(_RNNLayer):
     >>> layer = mx.gluon.rnn.LSTM(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
     >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> c0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, [h0, c0])
@@ -474,10 +491,12 @@ class GRU(_RNNLayer):
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
-        The recurrent state's shape is `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+    Recurrent state:
+        The recurrent state is an NDArray with shape `(num_layers, batch_size, num_hidden)`.
+        If `bidirectional` is True, the recurrent state shape will instead be
         `(2*num_layers, batch_size, num_hidden)`
+        If the input recurrent state is None, zeros are used as the default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -485,6 +504,9 @@ class GRU(_RNNLayer):
     >>> layer = mx.gluon.rnn.GRU(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
     >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, h0)
     """
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index ac671e5..4062013 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -209,6 +209,35 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
+def check_rnn_layer_forward(layer, inputs, states=None):
+    layer.collect_params().initialize()
+    with mx.autograd.record():
+        out = layer(inputs, states)
+        if states is not None:
+            assert isinstance(out, tuple) and len(out) == 2
+            out = out[0]
+        else:
+            assert isinstance(out, mx.nd.NDArray)
+        out.backward()
+    mx.nd.waitall()
+
+def test_rnn_layers():
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+
+    net = gluon.nn.Sequential()
+    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net.add(gluon.nn.BatchNorm(axis=2))
+    net.add(gluon.nn.Flatten())
+    net.add(gluon.nn.Dense(3, activation='relu'))
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        net(mx.nd.ones((2, 3, 10))).backward()
+
 
 if __name__ == '__main__':
     import nose
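
For readers skimming the patch: the behavioral change is in forward(). When no
begin state is passed, zeros are used and only the output is returned; when a
begin state is passed, the layer returns an (output, states) pair, which is what
the new check_rnn_layer_forward helper asserts. A small sketch of both call
forms, adapted from the updated docstring examples (shapes are illustrative):

    import mxnet as mx

    layer = mx.gluon.rnn.GRU(100, 3)              # hidden_size=100, num_layers=3
    layer.initialize()
    x = mx.nd.random_uniform(shape=(5, 3, 10))    # (seq_len, batch, input_size)

    out = layer(x)                                # zeros as begin state; output only
    assert isinstance(out, mx.nd.NDArray)

    h0 = mx.nd.random_uniform(shape=(3, 3, 100))  # (num_layers, batch, num_hidden)
    out, hn = layer(x, h0)                        # explicit begin state; (output, state) pair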
