anirudh2290 closed pull request #13043: Fix test failure due to hybridize call in test_gluon_rnn.test_layer_fill_shape URL: https://github.com/apache/incubator-mxnet/pull/13043
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a690fb11d8a..a2e19c5e1be 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -190,6 +190,7 @@ List of Contributors * [Denisa Roberts](https://github.com/D-Roberts) * [Dick Carter](https://github.com/DickJC123) * [Rahul Padmanabhan](https://github.com/rahul3) +* [Yuxi Hu](https://github.com/yuxihu) Label Bot --------- diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 1f115cd64ad..a836765f51a 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -855,10 +855,15 @@ OpStatePtr CachedOp::Forward( int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); OpStatePtr op_state; - if (config_.static_alloc) { - op_state = StaticForward(default_ctx, inputs, outputs); - } else { - op_state = DynamicForward(default_ctx, inputs, outputs); + try { + if (config_.static_alloc) { + op_state = StaticForward(default_ctx, inputs, outputs); + } else { + op_state = DynamicForward(default_ctx, inputs, outputs); + } + } catch (const dmlc::Error& e) { + Engine::Get()->set_bulk_size(prev_bulk_size); + throw e; } Engine::Get()->set_bulk_size(prev_bulk_size); @@ -1058,10 +1063,15 @@ void CachedOp::Backward( int prev_bulk_size = Engine::Get()->set_bulk_size(config_.backward_bulk_size); - if (config_.static_alloc) { - StaticBackward(retain_graph, state, inputs, reqs, outputs); - } else { - DynamicBackward(retain_graph, state, inputs, reqs, outputs); + try { + if (config_.static_alloc) { + StaticBackward(retain_graph, state, inputs, reqs, outputs); + } else { + DynamicBackward(retain_graph, state, inputs, reqs, outputs); + } + } catch (const dmlc::Error& e) { + Engine::Get()->set_bulk_size(prev_bulk_size); + 
throw e; } Engine::Get()->set_bulk_size(prev_bulk_size); diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 0c5ff841775..32ff8d33813 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -494,9 +494,16 @@ std::vector<NDArray*> Imperative::Backward( bool prev_training = set_is_training(is_train); int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); - RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, - is_recording()); + try { + RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + is_recording()); + } catch (const dmlc::Error& e) { + Engine::Get()->set_bulk_size(prev_bulk_size); + set_is_recording(prev_recording); + set_is_training(prev_training); + throw e; + } Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index bfe9592e5d0..eee3adda2c6 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -594,6 +594,7 @@ def test_cell_fill_shape(): @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_layer_fill_shape(): layer = gluon.rnn.LSTM(10) + layer.hybridize() check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7))) print(layer) assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1] ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
