taliesinb opened a new issue #13264: RNN operator crashes when gradient is requested when is_train=False
URL: https://github.com/apache/incubator-mxnet/issues/13264

In MXNet release version 1.3.0, the RNN operator uses the `is_train` parameter to decide whether to allocate the space it needs for calculating gradients. Therefore, if you request gradients after a forward pass run in inference mode, the backward pass fails when it notices that `init_space_` is `false`.

```
import numpy as np
import mxnet as mx

X = mx.sym.Variable('x')
Params = mx.sym.Variable('params')
HX = mx.sym.Variable('state')

T, N, I, H = 300, 20, 800, 800

rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_size=H,
                 num_layers=1, mode='gru', state_outputs=True, name='GRU')

# Gradients fed into the backward pass for the output and the final state
out_grad = mx.nd.ones([T,N,H])
state_grad = mx.nd.ones([1,N,H])

exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I), grad_req='write')

# Forward in inference mode, then request gradients anyway
exe.forward(is_train=False)
exe.backward(is_train=False, out_grads=[out_grad, state_grad])

print(exe.grad_arrays[0].asnumpy())
```

This yields the following stacktrace:

```
Traceback (most recent call last):
  File "Untitled 2.py", line 22, in <module>
    print(exe.grad_arrays[0].asnumpy())
  File "/Users/taliesinb/anaconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 1972, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/Users/taliesinb/anaconda3/lib/python3.7/site-packages/mxnet/base.py", line 252, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [17:59:42] src/operator/./rnn-inl.h:551: Check forward init error

Stack trace returned 10 entries:
[bt] (0) 0   libmxnet.so   0x00000001182fab90 libmxnet.so + 15248
[bt] (1) 1   libmxnet.so   0x00000001182fa93f libmxnet.so + 14655
[bt] (2) 2   libmxnet.so   0x00000001182fa569 libmxnet.so + 13673
[bt] (3) 3   libmxnet.so   0x0000000119d05380 MXTVMBridge + 3787312
[bt] (4) 4   libmxnet.so   0x000000011995bbbf MXNDListFree + 1733503
[bt] (5) 5   libmxnet.so   0x00000001197d30c7 MXNDListFree + 125063
[bt] (6) 6   libmxnet.so   0x00000001197f93ae MXNDListFree + 281454
[bt] (7) 7   libmxnet.so   0x00000001197c6ec8 MXNDListFree + 75400
[bt] (8) 8   libmxnet.so   0x00000001197ca261 MXNDListFree + 88609
[bt] (9) 9   libmxnet.so   0x00000001197ca17f MXNDListFree + 88383
```
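
To make the failure mode easier to see without MXNet installed, here is a minimal pure-Python sketch of the behaviour described above. It is a toy model, not MXNet's implementation: the class `ToyRNNOp`, its `reserve` buffer, and the squaring "layer" are all hypothetical, and only the `init_space` flag and the error message mirror names from this report.

```python
class ToyRNNOp:
    """Hypothetical stand-in for an operator that allocates its
    gradient workspace only when forward() runs with is_train=True."""

    def __init__(self):
        self.init_space = False  # mirrors the `init_space_` flag from the report
        self.reserve = None      # hypothetical workspace read by backward()

    def forward(self, xs, is_train):
        if is_train:
            # Training mode: save the inputs that backward() will need.
            self.reserve = list(xs)
            self.init_space = True
        # Toy computation: y = x ** 2
        return [x * x for x in xs]

    def backward(self, out_grads):
        if not self.init_space:
            # Same symptom as the report: no workspace was ever allocated.
            raise RuntimeError("Check forward init error")
        # d(x ** 2)/dx = 2 * x, scaled by the incoming gradient
        return [2.0 * x * g for x, g in zip(self.reserve, out_grads)]


op = ToyRNNOp()

# Inference-mode forward followed by backward reproduces the failure.
op.forward([1.0, 2.0], is_train=False)
try:
    op.backward([1.0, 1.0])
except RuntimeError as err:
    print("backward failed:", err)

# After a training-mode forward the workspace exists and backward succeeds.
op.forward([1.0, 2.0], is_train=True)
print("gradients:", op.backward([1.0, 1.0]))
```

In this toy model, running the forward pass with `is_train=True` before requesting gradients avoids the error. Whether the same workaround is acceptable for the real operator is an assumption, since training mode may change other behaviour.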
