[ 
https://issues.apache.org/jira/browse/MXNET-31?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16391710#comment-16391710
 ] 

ASF GitHub Bot commented on MXNET-31:
-------------------------------------

szha closed pull request #9934: [MXNET-31] Support variable sequence length in 
gluon.RecurrentCell 
URL: https://github.com/apache/incubator-mxnet/pull/9934
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py 
b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index d74c107df56..d6402b769cb 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -20,7 +20,8 @@
 __all__ = ['VariationalDropoutCell']
 
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
-from ...rnn.rnn_cell import _format_sequence, _get_begin_state
+from ...rnn.rnn_cell import _format_sequence, _get_begin_state, 
_mask_sequence_variable_length
+from ... import tensor_types
 
 
 class VariationalDropoutCell(ModifierCell):
@@ -113,7 +114,8 @@ def __repr__(self):
         return s.format(name=self.__class__.__name__,
                         **self.__dict__)
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         """Unrolls an RNN cell across time steps.
 
         Parameters
@@ -143,6 +145,15 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
             (batch_size, length, ...) if layout is 'NTC',
             or (length, batch_size, ...) if layout is 'TNC'.
             If `None`, output whatever is faster.
+        valid_length : Symbol, NDArray or None
+            `valid_length` specifies the length of the sequences in the batch 
without padding.
+            This option is especially useful for building sequence-to-sequence 
models where
+            the input and output sequences would potentially be padded.
+            If `valid_length` is None, all sequences are assumed to have the 
same length.
+            If `valid_length` is a Symbol or NDArray, it should have shape 
(batch_size,).
+            The ith element will be the length of the ith sequence in the 
batch.
+            The last valid state will be return and the padded outputs will be 
masked with 0.
+            Note that `valid_length` must be smaller or equal to `length`.
 
         Returns
         -------
@@ -160,7 +171,8 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
         # only when state dropout is not present.
         if self.drop_states:
             return super(VariationalDropoutCell, self).unroll(length, inputs, 
begin_state,
-                                                              layout, 
merge_outputs)
+                                                              layout, 
merge_outputs,
+                                                              
valid_length=valid_length)
 
         self.reset()
 
@@ -172,12 +184,16 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
             self._initialize_input_masks(F, first_input, states)
             inputs = F.broadcast_mul(inputs, 
self.drop_inputs_mask.expand_dims(axis=axis))
 
-        outputs, states = self.base_cell.unroll(length, inputs, states, 
layout, merge_outputs=True)
+        outputs, states = self.base_cell.unroll(length, inputs, states, 
layout, merge_outputs=True,
+                                                valid_length=valid_length)
         if self.drop_outputs:
             first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, 
squeeze_axis=True)
             self._initialize_output_mask(F, first_output)
             outputs = F.broadcast_mul(outputs, 
self.drop_outputs_mask.expand_dims(axis=axis))
-
+        merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is 
None else \
+            merge_outputs
         outputs, _, _, _ = _format_sequence(length, outputs, layout, 
merge_outputs)
-
+        if valid_length is not None:
+            outputs = _mask_sequence_variable_length(F, outputs, length, 
valid_length, axis,
+                                                     merge_outputs)
         return outputs, states
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py 
b/python/mxnet/gluon/rnn/rnn_cell.py
index ea0e32faebc..61bf24e8cd1 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -83,8 +83,7 @@ def _format_sequence(length, inputs, layout, merge, 
in_layout=None):
             F = ndarray
             batch_size = inputs[0].shape[batch_axis]
         if merge is True:
-            inputs = [F.expand_dims(i, axis=axis) for i in inputs]
-            inputs = F.concat(*inputs, dim=axis)
+            inputs = F.stack(*inputs, axis=axis)
             in_axis = axis
 
     if isinstance(inputs, tensor_types) and axis != in_axis:
@@ -92,6 +91,16 @@ def _format_sequence(length, inputs, layout, merge, 
in_layout=None):
 
     return inputs, axis, F, batch_size
 
+def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, 
merge):
+    assert valid_length is not None
+    if not isinstance(data, tensor_types):
+        data = F.stack(*data, axis=time_axis)
+    outputs = F.SequenceMask(data, sequence_length=valid_length, 
use_sequence_length=True,
+                             axis=time_axis)
+    if not merge:
+        outputs = _as_list(F.split(outputs, num_outputs=length, axis=time_axis,
+                                   squeeze_axis=True))
+    return outputs
 
 class RecurrentCell(Block):
     """Abstract base class for RNN cells
@@ -163,7 +172,8 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, 
**kwargs):
             states.append(state)
         return states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         """Unrolls an RNN cell across time steps.
 
         Parameters
@@ -193,6 +203,15 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
             (batch_size, length, ...) if layout is 'NTC',
             or (length, batch_size, ...) if layout is 'TNC'.
             If `None`, output whatever is faster.
+        valid_length : Symbol, NDArray or None
+            `valid_length` specifies the length of the sequences in the batch 
without padding.
+            This option is especially useful for building sequence-to-sequence 
models where
+            the input and output sequences would potentially be padded.
+            If `valid_length` is None, all sequences are assumed to have the 
same length.
+            If `valid_length` is a Symbol or NDArray, it should have shape 
(batch_size,).
+            The ith element will be the length of the ith sequence in the 
batch.
+            The last valid state will be return and the padded outputs will be 
masked with 0.
+            Note that `valid_length` must be smaller or equal to `length`.
 
         Returns
         -------
@@ -207,15 +226,24 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
         """
         self.reset()
 
-        inputs, _, F, batch_size = _format_sequence(length, inputs, layout, 
False)
+        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, 
False)
         begin_state = _get_begin_state(self, F, begin_state, inputs, 
batch_size)
 
         states = begin_state
         outputs = []
+        all_states = []
         for i in range(length):
             output, states = self(inputs[i], states)
             outputs.append(output)
-
+            if valid_length is not None:
+                all_states.append(states)
+        if valid_length is not None:
+            states = [F.SequenceLast(F.stack(*ele_list, axis=0),
+                                     sequence_length=valid_length,
+                                     use_sequence_length=True,
+                                     axis=0)
+                      for ele_list in zip(*all_states)]
+            outputs = _mask_sequence_variable_length(F, outputs, length, 
valid_length, axis, True)
         outputs, _, _, _ = _format_sequence(length, outputs, layout, 
merge_outputs)
 
         return outputs, states
@@ -645,7 +673,8 @@ def __call__(self, inputs, states):
             next_states.append(state)
         return inputs, sum(next_states, [])
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, _, F, batch_size = _format_sequence(length, inputs, layout, 
None)
@@ -658,8 +687,10 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
             n = len(cell.state_info())
             states = begin_state[p:p+n]
             p += n
-            inputs, states = cell.unroll(length, inputs=inputs, 
begin_state=states, layout=layout,
-                                         merge_outputs=None if i < num_cells-1 
else merge_outputs)
+            inputs, states = cell.unroll(length, inputs=inputs, 
begin_state=states,
+                                         layout=layout,
+                                         merge_outputs=None if i < num_cells-1 
else merge_outputs,
+                                         valid_length=valid_length)
             next_states.extend(states)
 
         return inputs, next_states
@@ -713,7 +744,8 @@ def hybrid_forward(self, F, inputs, states):
             inputs = F.Dropout(data=inputs, p=self.rate, 
name='t%d_fwd'%self._counter)
         return inputs, states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, _, F, _ = _format_sequence(length, inputs, layout, 
merge_outputs)
@@ -722,7 +754,7 @@ def unroll(self, length, inputs, begin_state=None, 
layout='NTC', merge_outputs=N
         else:
             return super(DropoutCell, self).unroll(
                 length, inputs, begin_state=begin_state, layout=layout,
-                merge_outputs=merge_outputs)
+                merge_outputs=merge_outputs, valid_length=None)
 
 
 class ModifierCell(HybridRecurrentCell):
@@ -827,17 +859,23 @@ def hybrid_forward(self, F, inputs, states):
         output = F.elemwise_add(output, inputs, name='t%d_fwd'%self._counter)
         return output, states
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         self.base_cell._modified = False
         outputs, states = self.base_cell.unroll(length, inputs=inputs, 
begin_state=begin_state,
-                                                layout=layout, 
merge_outputs=merge_outputs)
+                                                layout=layout, 
merge_outputs=merge_outputs,
+                                                valid_length=valid_length)
         self.base_cell._modified = True
 
         merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is 
None else \
                         merge_outputs
-        inputs, _, F, _ = _format_sequence(length, inputs, layout, 
merge_outputs)
+        inputs, axis, F, _ = _format_sequence(length, inputs, layout, 
merge_outputs)
+        if valid_length is not None:
+            # mask the padded inputs to zero
+            inputs = _mask_sequence_variable_length(F, inputs, length, 
valid_length, axis,
+                                                    merge_outputs)
         if merge_outputs:
             outputs = F.elemwise_add(outputs, inputs)
         else:
@@ -880,34 +918,57 @@ def begin_state(self, **kwargs):
             "cell cannot be called directly. Call the modifier cell instead."
         return _cells_begin_state(self._children, **kwargs)
 
-    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None):
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', 
merge_outputs=None,
+               valid_length=None):
         self.reset()
 
         inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, 
False)
+        if valid_length is None:
+            reversed_inputs = list(reversed(inputs))
+        else:
+            reversed_inputs = F.SequenceReverse(F.stack(*inputs, axis=0),
+                                                sequence_length=valid_length,
+                                                use_sequence_length=True)
+            reversed_inputs = _as_list(F.split(reversed_inputs, axis=0, 
num_outputs=length,
+                                               squeeze_axis=True))
         begin_state = _get_begin_state(self, F, begin_state, inputs, 
batch_size)
 
         states = begin_state
         l_cell, r_cell = self._children
         l_outputs, l_states = l_cell.unroll(length, inputs=inputs,
                                             
begin_state=states[:len(l_cell.state_info(batch_size))],
-                                            layout=layout, 
merge_outputs=merge_outputs)
+                                            layout=layout, 
merge_outputs=merge_outputs,
+                                            valid_length=valid_length)
         r_outputs, r_states = r_cell.unroll(length,
-                                            inputs=list(reversed(inputs)),
+                                            inputs=reversed_inputs,
                                             
begin_state=states[len(l_cell.state_info(batch_size)):],
-                                            layout=layout, 
merge_outputs=merge_outputs)
-
+                                            layout=layout, merge_outputs=False,
+                                            valid_length=valid_length)
+        if valid_length is None:
+            reversed_r_outputs = list(reversed(r_outputs))
+        else:
+            reversed_r_outputs = F.SequenceReverse(F.stack(*r_outputs, axis=0),
+                                                   
sequence_length=valid_length,
+                                                   use_sequence_length=True,
+                                                   axis=0)
+            reversed_r_outputs = _as_list(F.split(reversed_r_outputs, axis=0, 
num_outputs=length,
+                                                  squeeze_axis=True))
         if merge_outputs is None:
-            merge_outputs = (isinstance(l_outputs, tensor_types)
-                             and isinstance(r_outputs, tensor_types))
+            merge_outputs = isinstance(l_outputs, tensor_types)
             l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, 
merge_outputs)
-            r_outputs, _, _, _ = _format_sequence(None, r_outputs, layout, 
merge_outputs)
+            reversed_r_outputs, _, _, _ = _format_sequence(None, 
reversed_r_outputs, layout,
+                                                           merge_outputs)
 
         if merge_outputs:
-            r_outputs = F.reverse(r_outputs, axis=axis)
-            outputs = F.concat(l_outputs, r_outputs, dim=2, 
name='%sout'%self._output_prefix)
+            reversed_r_outputs = F.stack(*reversed_r_outputs, axis=axis)
+            outputs = F.concat(l_outputs, reversed_r_outputs, dim=2,
+                               name='%sout'%self._output_prefix)
+
         else:
             outputs = [F.concat(l_o, r_o, dim=1, 
name='%st%d'%(self._output_prefix, i))
-                       for i, (l_o, r_o) in enumerate(zip(l_outputs, 
reversed(r_outputs)))]
-
+                       for i, (l_o, r_o) in enumerate(zip(l_outputs, 
reversed_r_outputs))]
+        if valid_length is not None:
+            outputs = _mask_sequence_variable_length(F, outputs, length, 
valid_length, axis,
+                                                     merge_outputs)
         states = l_states + r_states
         return outputs, states
diff --git a/tests/python/unittest/test_gluon_rnn.py 
b/tests/python/unittest/test_gluon_rnn.py
index 22888421925..871deeb26c4 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -257,6 +257,7 @@ def check_rnn_layer_forward(layer, inputs, states=None):
     mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, 
atol=1e-5)
     mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, 
atol=1e-5)
 
+
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), 
mx.nd.ones((2, 3, 10)))
@@ -274,6 +275,81 @@ def test_rnn_layers():
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
+
+def test_rnn_unroll_variant_length():
+    # Test for imperative usage
+    cell_list = []
+    for base_cell_class in [gluon.rnn.RNNCell, gluon.rnn.LSTMCell, 
gluon.rnn.GRUCell]:
+        cell_list.append(base_cell_class(20))
+        cell_list.append(gluon.rnn.BidirectionalCell(
+                         l_cell=base_cell_class(20),
+                         r_cell=base_cell_class(20)))
+        
cell_list.append(gluon.contrib.rnn.VariationalDropoutCell(base_cell=base_cell_class(20)))
+    stack_res_rnn_cell = gluon.rnn.SequentialRNNCell()
+    
stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+    
stack_res_rnn_cell.add(gluon.rnn.ResidualCell(base_cell=gluon.rnn.RNNCell(20)))
+    cell_list.append(stack_res_rnn_cell)
+    batch_size = 4
+    max_length = 10
+    valid_length = [3, 10, 5, 6]
+    valid_length_nd = mx.nd.array(valid_length)
+    for cell in cell_list:
+        cell.collect_params().initialize()
+        cell.hybridize()
+        # Test for NTC layout
+        data_nd = mx.nd.random.normal(0, 1, shape=(batch_size, max_length, 20))
+        outs, states = cell.unroll(length=max_length, inputs=data_nd,
+                                   valid_length=valid_length_nd,
+                                   merge_outputs=True,
+                                   layout='NTC')
+        for i, ele_length in enumerate(valid_length):
+            # Explicitly unroll each sequence and compare the final states and 
output
+            ele_out, ele_states = cell.unroll(length=ele_length,
+                                              inputs=data_nd[i:(i+1), 
:ele_length, :],
+                                              merge_outputs=True,
+                                              layout='NTC')
+            assert_allclose(ele_out.asnumpy(), outs[i:(i+1), :ele_length, 
:].asnumpy(),
+                            atol=1E-4, rtol=1E-4)
+            if ele_length < max_length:
+                # Check the padded outputs are all zero
+                assert_allclose(outs[i:(i+1), ele_length:max_length, 
:].asnumpy(), 0)
+            for valid_out_state, gt_state in zip(states, ele_states):
+                assert_allclose(valid_out_state[i:(i+1)].asnumpy(), 
gt_state.asnumpy(),
+                                atol=1E-4, rtol=1E-4)
+
+        # Test for TNC layout
+        data_nd = mx.nd.random.normal(0, 1, shape=(max_length, batch_size, 20))
+        outs, states = cell.unroll(length=max_length, inputs=data_nd,
+                                   valid_length=valid_length_nd,
+                                   layout='TNC')
+        for i, ele_length in enumerate(valid_length):
+            # Explicitly unroll each sequence and compare the final states and 
output
+            ele_out, ele_states = cell.unroll(length=ele_length,
+                                              inputs=data_nd[:ele_length, 
i:(i+1), :],
+                                              merge_outputs=True,
+                                              layout='TNC')
+            assert_allclose(ele_out.asnumpy(), outs[:ele_length, i:(i + 1), 
:].asnumpy(),
+                            atol=1E-4, rtol=1E-4)
+            if ele_length < max_length:
+                # Check the padded outputs are all zero
+                assert_allclose(outs[ele_length:max_length, i:(i+1), 
:].asnumpy(), 0)
+            for valid_out_state, gt_state in zip(states, ele_states):
+                assert_allclose(valid_out_state[i:(i+1)].asnumpy(), 
gt_state.asnumpy(),
+                                atol=1E-4, rtol=1E-4)
+    # For symbolic test, we need to make sure that it can be binded and run
+    data = mx.sym.var('data', shape=(4, 10, 2))
+    cell = gluon.rnn.RNNCell(100)
+    valid_length = mx.sym.var('valid_length', shape=(4,))
+    outs, states = cell.unroll(length=10, inputs=data, 
valid_length=valid_length,
+                               merge_outputs=True, layout='NTC')
+    mod = mx.mod.Module(states[0], data_names=('data', 'valid_length'), 
label_names=None,
+                        context=mx.cpu())
+    mod.bind(data_shapes=[('data', (4, 10, 2)), ('valid_length', (4,))], 
label_shapes=None)
+    mod.init_params()
+    mod.forward(mx.io.DataBatch([mx.random.normal(0, 1, (4, 10, 2)), 
mx.nd.array([3, 6, 10, 2])]))
+    mod.get_outputs()[0].asnumpy()
+
+
 def test_cell_fill_shape():
     cell = gluon.rnn.LSTMCell(10)
     cell.hybridize()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> Support variable sequence length in gluon.RecurrentCell
> -------------------------------------------------------
>
>                 Key: MXNET-31
>                 URL: https://issues.apache.org/jira/browse/MXNET-31
>             Project: Apache MXNet
>          Issue Type: New Feature
>            Reporter: Xingjian Shi
>            Priority: Major
>
> When the input sequences have different lengths, the common approach is to 
> pad them to the same length and feed the padded data into the recurrent 
> neural network. To handle this scenario, this PR adds a new 
> {{valid_length}} option to {{unroll}}. {{valid_length}} gives the real 
> length of each sequence before padding. When {{valid_length}} is given, 
> the last valid state is returned and the padded portion of the output 
> is masked to zero. This feature is essential for implementing an NMT 
> (neural machine translation) model.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to