szha commented on a change in pull request #20262:
URL: https://github.com/apache/incubator-mxnet/pull/20262#discussion_r650168123
##########
File path: python/mxnet/gluon/rnn/rnn_layer.py
##########
@@ -182,65 +180,79 @@ def __call__(self, inputs, states=None,
sequence_length=None, **kwargs):
else:
return super(_RNNLayer, self).__call__(inputs, states, **kwargs)
- def hybrid_forward(self, F, inputs, states, sequence_length=None,
**kwargs):
- if F is ndarray:
- batch_size = inputs.shape[self._layout.find('N')]
+ def forward(self, inputs, states, sequence_length=None):
+ batch_size = inputs.shape[self._layout.find('N')]
- if F is ndarray:
- for state, info in zip(states, self.state_info(batch_size)):
- if state.shape != info['shape']:
- raise ValueError(
- "Invalid recurrent state shape. Expecting %s, got
%s."%(
- str(info['shape']), str(state.shape)))
- out = self._forward_kernel(F, inputs, states, sequence_length,
**kwargs)
+ for state, info in zip(states, self.state_info(batch_size)):
+ if state.shape != info['shape']:
+ raise ValueError(
+ "Invalid recurrent state shape. Expecting %s, got %s."%(
+ str(info['shape']), str(state.shape)))
+ out = self._forward_kernel(inputs, states, sequence_length)
# out is (output, state)
return out[0] if self.skip_states else out
- def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
+ def infer_shape(self, inputs, *args):
+ assert inputs.ndim == 3, \
+ "Input data should be rank-3 tensor of dim [sequence length, batch
size, input size]"
+ if not self._projection_size:
+ step = self._hidden_size
+ else:
+ step = self._projection_size
+ ni = inputs.shape[2]
+ for i in range(self._num_layers):
+ for j in ['l', 'r'][:self._dir]:
+ name = '{}{}_i2h_weight'.format(j, i)
+ getattr(self, name).shape = (self._gates*self._hidden_size, ni)
+ ni = step * self._dir
+
+ def _forward_kernel(self, inputs, states, sequence_length):
        """ forward using CUDNN or CPU kernel"""
- swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
+ ctx = inputs.ctx
if self._layout == 'NTC':
- inputs = swapaxes(inputs, 0, 1)
+ inputs = np.swapaxes(inputs, 0, 1)
if self._projection_size is None:
- params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+ params = (getattr(self, '{}{}_{}_{}'.format(d, l, g,
t)).data(ctx).reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
else:
- params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+ params = (getattr(self, '{}{}_{}_{}'.format(d, l, g,
t)).data(ctx).reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h', 'h2r']
if g != 'h2r' or t != 'bias')
- rnn_param_concat = F.np._internal.rnn_param_concat if is_np_array()\
- else F._internal._rnn_param_concat
- params = rnn_param_concat(*params, dim=0)
+ params = np.concatenate(params, axis=0)
    Review comment:
        Just to be clear, there is still a performance penalty for this call.
    It's OK to address it in a separate PR.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]