piiswrong commented on a change in pull request #7067: variational dropout cell
URL: https://github.com/apache/incubator-mxnet/pull/7067#discussion_r130946499
##########
File path: python/mxnet/gluon/rnn/rnn_cell.py
##########
@@ -719,6 +719,68 @@ def unroll(self, length, inputs, begin_state=None,
layout='NTC', merge_outputs=N
return outputs, states
+class VariationalDropoutCell(ModifierCell):
+ """
+ Applies Variational Dropout on base cell.
+ (https://arxiv.org/pdf/1512.05287.pdf,
+ https://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf).
+ """
+ def __init__(self, base_cell, vardrop_inputs=0., vardrop_outputs=0.,
vardrop_states=0.):
+ assert not isinstance(base_cell, BidirectionalCell), \
+ "BidirectionalCell doesn't support vardrop since it doesn't
support step. " \
+ "Please add VariationalDropoutCell to the cells underneath
instead."
+ assert not isinstance(base_cell, SequentialRNNCell) or not
base_cell._bidirectional, \
+ "Bidirectional SequentialRNNCell doesn't support vardrop. " \
+ "Please add VariationalDropoutCell to the cells underneath
instead."
+ super(VariationalDropoutCell, self).__init__(base_cell)
+ self.vardrop_inputs = vardrop_inputs
+ self.vardrop_states = vardrop_states
+ self.vardrop_outputs = vardrop_outputs
+ self.vardrop_inputs_mask = None
+ self.vardrop_states_mask = None
+ self.vardrop_outputs_mask = None
+
+ def _alias(self):
+ return 'vardrop'
+
+ def reset(self):
+ super(VariationalDropoutCell, self).reset()
Review comment:
It looks like reset() must be called before each sequence; otherwise the mask
will always be the same. Please note this in the documentation.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services