taliesinb opened a new issue #13264: RNN operator crashes when gradient is 
requested when is_train=False
URL: https://github.com/apache/incubator-mxnet/issues/13264
 
 
   In MXNet 1.3.0, the RNN operator uses the `is_train` flag to decide
   whether to allocate the space it needs for computing gradients. As a result,
   if you request gradients after a forward pass run with `is_train=False`, the
   backward pass fails when it finds that `init_space_` is `false`.
   
   ```
   import numpy as np
   import mxnet as mx
   
   X = mx.sym.Variable('x')
   Params = mx.sym.Variable('params')
   HX = mx.sym.Variable('state')
   T, N, I, H = 300, 20, 800, 800
   rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
                    state_size=H, num_layers=1, mode='gru',
                    state_outputs=True, name='GRU')
   
   out_grad = mx.nd.ones([T,N,H])
   state_grad = mx.nd.ones([1,N,H])
   
   exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I), grad_req='write')
   
   exe.forward(is_train=False)
   exe.backward(is_train=False, out_grads=[out_grad, state_grad])
   print(exe.grad_arrays[0].asnumpy())
   ```
   
   This yields the following stacktrace:
   
   ```
   Traceback (most recent call last):
     File "Untitled 2.py", line 22, in <module>
       print(exe.grad_arrays[0].asnumpy())
     File "/Users/taliesinb/anaconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 1972, in asnumpy
       ctypes.c_size_t(data.size)))
     File "/Users/taliesinb/anaconda3/lib/python3.7/site-packages/mxnet/base.py", line 252, in check_call
       raise MXNetError(py_str(_LIB.MXGetLastError()))
   mxnet.base.MXNetError: [17:59:42] src/operator/./rnn-inl.h:551: Check forward init error
   
   Stack trace returned 10 entries:
   [bt] (0) 0   libmxnet.so                         0x00000001182fab90 libmxnet.so + 15248
   [bt] (1) 1   libmxnet.so                         0x00000001182fa93f libmxnet.so + 14655
   [bt] (2) 2   libmxnet.so                         0x00000001182fa569 libmxnet.so + 13673
   [bt] (3) 3   libmxnet.so                         0x0000000119d05380 MXTVMBridge + 3787312
   [bt] (4) 4   libmxnet.so                         0x000000011995bbbf MXNDListFree + 1733503
   [bt] (5) 5   libmxnet.so                         0x00000001197d30c7 MXNDListFree + 125063
   [bt] (6) 6   libmxnet.so                         0x00000001197f93ae MXNDListFree + 281454
   [bt] (7) 7   libmxnet.so                         0x00000001197c6ec8 MXNDListFree + 75400
   [bt] (8) 8   libmxnet.so                         0x00000001197ca261 MXNDListFree + 88609
   [bt] (9) 9   libmxnet.so                         0x00000001197ca17f MXNDListFree + 88383
   ```
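
The failure mode described above can be illustrated with a toy model in plain Python. This is only a sketch of the pattern, not the MXNet source: `ToyRNNOp`, `try_backward`, and the doubling "computation" are hypothetical names and behavior invented for illustration. The only details taken from the report are the `init_space_` flag (modelled here as `init_space`) and the error text.

```python
class ToyRNNOp:
    """Hypothetical stand-in for an operator that allocates its
    backward workspace only when forward runs in training mode."""

    def __init__(self):
        self.init_space = False    # models the `init_space_` flag from the report
        self.reserve_space = None  # workspace that backward depends on

    def forward(self, x, is_train):
        if is_train:
            # Workspace is set up only in training mode.
            self.reserve_space = [0.0] * len(x)
            self.init_space = True
        return [2.0 * v for v in x]

    def backward(self, out_grad):
        if not self.init_space:
            # Same message as in the stack trace above.
            raise RuntimeError("Check forward init error")
        return [2.0 * g for g in out_grad]


def try_backward(is_train):
    """Run forward then backward; return the gradient or the error text."""
    op = ToyRNNOp()
    op.forward([1.0, 2.0, 3.0], is_train=is_train)
    try:
        return op.backward([1.0, 1.0, 1.0])
    except RuntimeError as err:
        return str(err)
```

In this toy, `try_backward(True)` returns a gradient, while `try_backward(False)` returns the error text, because backward finds that the workspace was never allocated.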

