This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new c55fc57  add Sequential compatibility to rnn layers (#7352)
c55fc57 is described below

commit c55fc571d22ec458365d87740b589827ddfd86cf
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Thu Aug 10 10:24:03 2017 -0700

    add Sequential compatibility to rnn layers (#7352)
---
 python/mxnet/gluon/rnn/rnn_layer.py     | 46 ++++++++++++++++++++++++---------
 tests/python/unittest/test_gluon_rnn.py | 29 +++++++++++++++++++++
 2 files changed, 63 insertions(+), 12 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index a9bcee5..86b7c61 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -168,10 +168,13 @@ class _RNNLayer(Block):
             states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states
 
-    def forward(self, inputs, states):
+    def forward(self, inputs, states=None):
+        batch_size = inputs.shape[self._layout.find('N')]
+        skip_states = states is None
+        if skip_states:
+            states = self.begin_state(batch_size)
         if isinstance(states, ndarray.NDArray):
             states = [states]
-        batch_size = inputs.shape[self._layout.find('N')]
         for state, info in zip(states, self.state_info(batch_size)):
             if state.shape != info['shape']:
                 raise ValueError(
@@ -182,8 +185,12 @@ class _RNNLayer(Block):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
         if inputs.context.device_type == 'gpu':
-            return self._forward_gpu(inputs, states)
-        return self._forward_cpu(inputs, states)
+            out = self._forward_gpu(inputs, states)
+        else:
+            out = self._forward_cpu(inputs, states)
+
+        # out is (output, state)
+        return out[0] if skip_states else out
 
     def _forward_cpu(self, inputs, states):
         ns = len(states)
@@ -282,10 +289,12 @@
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
-        The recurrent state's shape is `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+    Recurrent state:
+        The recurrent state is an NDArray with shape `(num_layers, batch_size, num_hidden)`.
+        If `bidirectional` is True, the recurrent state shape will instead be
         `(2*num_layers, batch_size, num_hidden)`
+        If input recurrent state is None, zeros are used as default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -293,6 +302,9 @@
     >>> layer = mx.gluon.rnn.RNN(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
     >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, h0)
     """
@@ -379,11 +391,13 @@ class LSTM(_RNNLayer):
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
+    Recurrent state:
         The recurrent state is a list of two NDArrays. Both have shape
         `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+        If `bidirectional` is True, each recurrent state will instead have shape
         `(2*num_layers, batch_size, num_hidden)`.
+        If input recurrent state is None, zeros are used as default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -391,6 +405,9 @@
     >>> layer = mx.gluon.rnn.LSTM(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
     >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> c0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, [h0, c0])
@@ -474,10 +491,12 @@
         If `bidirectional` is True, output shape will instead be
         `(sequence_length, batch_size, 2*num_hidden)`
 
-    Recurrent state shape:
-        The recurrent state's shape is `(num_layers, batch_size, num_hidden)`.
-        If `bidirectional` is True, state shape will instead be
+    Recurrent state:
+        The recurrent state is an NDArray with shape `(num_layers, batch_size, num_hidden)`.
+        If `bidirectional` is True, the recurrent state shape will instead be
         `(2*num_layers, batch_size, num_hidden)`
+        If input recurrent state is None, zeros are used as default begin states,
+        and the output recurrent state is omitted.
 
 
     Examples
@@ -485,6 +504,9 @@
     >>> layer = mx.gluon.rnn.GRU(100, 3)
     >>> layer.initialize()
     >>> input = mx.nd.random_uniform(shape=(5, 3, 10))
+    >>> # by default zeros are used as begin state
+    >>> output = layer(input)
+    >>> # manually specify begin state.
    >>> h0 = mx.nd.random_uniform(shape=(3, 3, 100))
     >>> output, hn = layer(input, h0)
     """
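The forward() change above gives every recurrent layer two calling conventions: with
states=None, zero begin states are fabricated via begin_state() and only the output is
returned; with explicit states, the (output, states) pair is returned as before. A
minimal usage sketch of both conventions, following the updated docstring examples
(assumes a build that includes this patch):

    import mxnet as mx

    layer = mx.gluon.rnn.LSTM(100, 3)   # 100 hidden units, 3 layers
    layer.initialize()
    # TNC layout: (sequence_length, batch_size, input_size)
    inputs = mx.nd.random_uniform(shape=(5, 3, 10))

    # states=None: zeros serve as begin states; only the output comes back
    output = layer(inputs)              # shape (5, 3, 100)

    # explicit begin states: the (output, states) pair comes back as before
    h0 = mx.nd.random_uniform(shape=(3, 3, 100))
    c0 = mx.nd.random_uniform(shape=(3, 3, 100))
    output, hn = layer(inputs, [h0, c0])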
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index ac671e5..4062013 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -209,6 +209,35 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
+def check_rnn_layer_forward(layer, inputs, states=None):
+    layer.collect_params().initialize()
+    with mx.autograd.record():
+        out = layer(inputs, states)
+        if states is not None:
+            assert isinstance(out, tuple) and len(out) == 2
+            out = out[0]
+        else:
+            assert isinstance(out, mx.nd.NDArray)
+        out.backward()
+    mx.nd.waitall()
+
+def test_rnn_layers():
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+
+    net = gluon.nn.Sequential()
+    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net.add(gluon.nn.BatchNorm(axis=2))
+    net.add(gluon.nn.Flatten())
+    net.add(gluon.nn.Dense(3, activation='relu'))
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        net(mx.nd.ones((2, 3, 10))).backward()
+
 
 if __name__ == '__main__':
     import nose
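The Sequential network in test_rnn_layers shows the headline feature in miniature: a
recurrent layer called without states now returns a single NDArray, so it can feed
ordinary blocks directly. A slimmed-down sketch of the same idea (assumes a build that
includes this patch; shape comments are illustrative):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Sequential()
    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))  # (2, 3, 10) -> (2, 3, 20)
    net.add(gluon.nn.Flatten())                         # (2, 3, 20) -> (2, 60)
    net.add(gluon.nn.Dense(3, activation='relu'))       # (2, 60)    -> (2, 3)
    net.collect_params().initialize()

    out = net(mx.nd.ones((2, 3, 10)))
    print(out.shape)  # (2, 3): Flatten keeps the leading axis as the batch axis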