reminisce commented on a change in pull request #7067: variational dropout cell
URL: https://github.com/apache/incubator-mxnet/pull/7067#discussion_r140419392
 
 

 ##########
 File path: python/mxnet/gluon/rnn/rnn_cell.py
 ##########
 @@ -738,6 +736,164 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
         return outputs, states
 
 
+class VariationalDropoutCell(ModifierCell):
+    """
+    Applies Variational Dropout on base cell.
+    (https://arxiv.org/pdf/1512.05287.pdf,
+     https://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf).
+
+    Variational dropout uses the same dropout mask across time-steps. It can be applied to RNN
+    inputs, outputs, and states. The masks for them are not shared.
+
+    The dropout mask is initialized when stepping forward for the first time and will remain
+    the same until .reset() is called. Thus, if using the cell and stepping manually without calling
+    .unroll(), the .reset() should be called after each sequence.
+
+    Parameters
+    ----------
+    base_cell : RecurrentCell
+        The cell on which to perform variational dropout.
+    drop_inputs : float, default 0.
+        The dropout rate for inputs. Won't apply dropout if it equals 0.
+    drop_states : float, default 0.
+        The dropout rate for state inputs on the first state channel.
+        Won't apply dropout if it equals 0.
+    drop_outputs : float, default 0.
+        The dropout rate for outputs. Won't apply dropout if it equals 0.
+    """
+    def __init__(self, base_cell, drop_inputs=0., drop_states=0., drop_outputs=0.):
+        assert not isinstance(base_cell, BidirectionalCell), \
+            "BidirectionalCell doesn't support vardrop since it doesn't support step. " \
+            "Please add VariationalDropoutCell to the cells underneath instead."
+        assert not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \
+            "Bidirectional SequentialRNNCell doesn't support vardrop. " \
+            "Please add VariationalDropoutCell to the cells underneath instead."
+        super(VariationalDropoutCell, self).__init__(base_cell)
+        self.drop_inputs = drop_inputs
+        self.drop_states = drop_states
+        self.drop_outputs = drop_outputs
+        self.drop_inputs_mask = None
+        self.drop_states_mask = None
+        self.drop_outputs_mask = None
+
+    def _alias(self):
+        return 'vardrop'
+
+    def reset(self):
+        super(VariationalDropoutCell, self).reset()
+        self.drop_inputs_mask = None
+        self.drop_states_mask = None
+        self.drop_outputs_mask = None
+
+    def _initialize_input_masks(self, F, inputs, states):
+        if self.drop_states and self.drop_states_mask is None:
+            self.drop_states_mask = F.Dropout(F.ones_like(states[0]),
+                                              p=self.drop_states)
+
+        if self.drop_inputs and self.drop_inputs_mask is None:
+            self.drop_inputs_mask = F.Dropout(F.ones_like(inputs),
+                                              p=self.drop_inputs)
+
+    def _initialize_output_mask(self, F, output):
+        if self.drop_outputs and self.drop_outputs_mask is None:
+            self.drop_outputs_mask = F.Dropout(F.ones_like(output),
+                                               p=self.drop_outputs)
+
+
+    def hybrid_forward(self, F, inputs, states):
 
 Review comment:
   Since this method is exposed to users, it would be very helpful to add 
documentation here, as you did for the unroll function below.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to