This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 32c588f  fix cached op (#8031)
32c588f is described below

commit 32c588ff7f04a0a6173f649df188d8804db7c1d6
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Mon Sep 25 14:48:31 2017 -0700

    fix cached op (#8031)
---
 src/imperative/cached_op.cc             | 6 ++++++
 tests/python/unittest/test_gluon_rnn.py | 8 ++++----
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index afaa32c..224f088 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -267,6 +267,12 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
     stypes[i] = state.buff[i].storage_type();
   }
 
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    shapes[bwd_input_eid_[i]] = inputs[i]->shape();
+    dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
+    stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
+  }
+
   std::pair<uint32_t, uint32_t> node_range, entry_range;
   node_range = {num_forward_nodes, idx.num_nodes()};
   entry_range = {num_forward_entries, idx.num_node_entries()};
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 079b337..89da900 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -201,8 +201,8 @@ def check_rnn_forward(layer, inputs, deterministic=True):
         out.backward()
 
     if deterministic:
-        mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3)
-        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), 
rtol=1e-3)
+        mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, 
atol=1e-5)
+        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), 
rtol=1e-3, atol=1e-5)
 
 
 
@@ -253,8 +253,8 @@ def check_rnn_layer_forward(layer, inputs, states=None):
             assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
-    mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3)
-    mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3)
+    mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, 
atol=1e-5)
+    mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, 
atol=1e-5)
 
 
 def test_rnn_layers():

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to