This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 65061dc  fix rnn (#10954)
65061dc is described below

commit 65061dc93710afce92edfe548aa3352473da3cdb
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Tue May 15 14:30:08 2018 -0700

    fix rnn (#10954)
---
 python/mxnet/gluon/rnn/rnn_layer.py     |  5 +++--
 tests/python/unittest/test_gluon_rnn.py | 29 ++++++++++++++++++++++-------
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 34ad05d..89224cf 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,7 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import ndarray
+from ... import ndarray, autograd
 from .. import Block
 from . import rnn_cell
 
@@ -185,7 +185,8 @@ class _RNNLayer(Block):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
+        if inputs.context.device_type == 'gpu' or \
+           self._mode == 'lstm' and not (self._dropout and autograd.is_training()):
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
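
The condition above selects between the fused RNN kernel and the unfused, cell-by-cell fallback. Since Python's "and" binds tighter than "or", the new test groups as: use the fused kernel on GPU, or for LSTM on CPU as long as dropout does not need to be applied (i.e. outside of training), apparently because the fused CPU path does not implement dropout. A minimal sketch of the dispatch logic, using a hypothetical standalone helper (not part of this patch) rather than the actual layer method:

    def use_fused_kernel(device_type, mode, dropout, is_training):
        # GPU: the fused kernel covers every mode.
        # CPU: only the LSTM has a fused implementation, and it is skipped
        # whenever dropout would have to be applied, falling back to the
        # unfused cell-based _forward.
        return device_type == 'gpu' or \
               (mode == 'lstm' and not (dropout and is_training))

    assert use_fused_kernel('gpu', 'gru', 0.5, True)       # GPU: always fused
    assert use_fused_kernel('cpu', 'lstm', 0.0, True)      # no dropout: fused
    assert use_fused_kernel('cpu', 'lstm', 0.5, False)     # dropout inactive at inference: fused
    assert not use_fused_kernel('cpu', 'lstm', 0.5, True)  # dropout while training: unfused
    assert not use_fused_kernel('cpu', 'gru', 0.0, True)   # CPU non-LSTM: unfused
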
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f22b13d..24d5a93 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -80,7 +80,7 @@ def test_lstm_cpu_inference():
 
     mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                       rtol=1e-3, atol=1e-5)
-    
+
 
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')
@@ -242,7 +242,7 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
-def check_rnn_layer_forward(layer, inputs, states=None):
+def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
     with mx.autograd.record():
@@ -268,17 +268,32 @@ def check_rnn_layer_forward(layer, inputs, states=None):
             assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
-    mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
-    mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
+    if not run_only:
+        mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
+        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
+
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)),
+                            [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
 
     net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))

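The bidirectional cases pass begin states shaped (num_layers * num_directions, batch_size, hidden_size), hence (4, 3, 10) for two layers in two directions, and the run_only=True cases only verify that the dropout paths execute, since dropout makes the compared outputs non-deterministic. A minimal standalone sketch of the state-shape convention (assuming a default CPU context and the MXNet 1.x Gluon API):

    import mxnet as mx
    from mxnet import gluon

    # Layout 'TNC': inputs are (seq_len, batch, input_size); begin states
    # carry one slice per layer per direction: 2 layers * 2 directions = 4.
    layer = gluon.rnn.LSTM(10, num_layers=2, bidirectional=True)
    layer.collect_params().initialize()

    x = mx.nd.ones((8, 3, 20))
    h0 = mx.nd.ones((4, 3, 10))
    c0 = mx.nd.ones((4, 3, 10))
    out, state = layer(x, [h0, c0])
    print(out.shape)    # (8, 3, 20): last dim is hidden_size * num_directions
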
-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
