This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 32c588f fix cached op (#8031)
32c588f is described below
commit 32c588ff7f04a0a6173f649df188d8804db7c1d6
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Mon Sep 25 14:48:31 2017 -0700
fix cached op (#8031)
---
src/imperative/cached_op.cc | 6 ++++++
tests/python/unittest/test_gluon_rnn.py | 8 ++++----
2 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index afaa32c..224f088 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -267,6 +267,12 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
stypes[i] = state.buff[i].storage_type();
}
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ shapes[bwd_input_eid_[i]] = inputs[i]->shape();
+ dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
+ stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
+ }
+
std::pair<uint32_t, uint32_t> node_range, entry_range;
node_range = {num_forward_nodes, idx.num_nodes()};
entry_range = {num_forward_entries, idx.num_node_entries()};
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 079b337..89da900 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -201,8 +201,8 @@ def check_rnn_forward(layer, inputs, deterministic=True):
out.backward()
if deterministic:
- mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3)
- mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(),
rtol=1e-3)
+ mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3,
atol=1e-5)
+ mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(),
rtol=1e-3, atol=1e-5)
@@ -253,8 +253,8 @@ def check_rnn_layer_forward(layer, inputs, states=None):
assert isinstance(out, mx.nd.NDArray)
out.backward()
- mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3)
- mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3)
+ mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3,
atol=1e-5)
+ mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3,
atol=1e-5)
def test_rnn_layers():
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].