This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ccfef53  variational dropout cell (#7067)
ccfef53 is described below

commit ccfef5375805ced8ba2492d6deb542e32f12610b
Author: Sheng Zha <[email protected]>
AuthorDate: Fri Sep 29 10:35:27 2017 -0700

    variational dropout cell (#7067)
    
    * variational dropout cell
    
    * update per comments
    
    * rename, add tests
    
    * add more tests
    
    * add unroll for efficiency
    
    * add comments
    
    * move variational dropout to contrib
---
 python/mxnet/gluon/contrib/rnn/__init__.py  |   2 +
 python/mxnet/gluon/contrib/rnn/rnn_cell.py  | 182 ++++++++++++++++++++++++++++
 tests/python/unittest/test_gluon_contrib.py |  38 ++++++
 tests/python/unittest/test_gluon_rnn.py     |   2 +-
 4 files changed, 223 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/contrib/rnn/__init__.py b/python/mxnet/gluon/contrib/rnn/__init__.py
index f27f205..e0a5220 100644
--- a/python/mxnet/gluon/contrib/rnn/__init__.py
+++ b/python/mxnet/gluon/contrib/rnn/__init__.py
@@ -20,3 +20,5 @@
 """Contrib recurrent neural network module."""
 
 from .conv_rnn_cell import *
+
+from .rnn_cell import *
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
new file mode 100644
index 0000000..df386a7
--- /dev/null
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Definition of various recurrent neural network cells."""
+
+from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
+from ...rnn.rnn_cell import _format_sequence, _get_begin_state
+
+
+class VariationalDropoutCell(ModifierCell):
+    """
+    Applies Variational Dropout on base cell.
+    (https://arxiv.org/pdf/1512.05287.pdf,
+     https://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf).
+
+    Variational dropout uses the same dropout mask across time-steps. It can be applied to RNN
+    inputs, outputs, and states. The masks for them are not shared.
+
+    The dropout mask is initialized when stepping forward for the first time and will remain
+    the same until .reset() is called. Thus, if using the cell and stepping manually without calling
+    .unroll(), the .reset() should be called after each sequence.
+
+    Parameters
+    ----------
+    base_cell : RecurrentCell
+        The cell on which to perform variational dropout.
+    drop_inputs : float, default 0.
+        The dropout rate for inputs. Won't apply dropout if it equals 0.
+    drop_states : float, default 0.
+        The dropout rate for state inputs on the first state channel.
+        Won't apply dropout if it equals 0.
+    drop_outputs : float, default 0.
+        The dropout rate for outputs. Won't apply dropout if it equals 0.
+    """
+    def __init__(self, base_cell, drop_inputs=0., drop_states=0., drop_outputs=0.):
+        assert not drop_states or not isinstance(base_cell, BidirectionalCell), \
+            "BidirectionalCell doesn't support variational state dropout. " \
+            "Please add VariationalDropoutCell to the cells underneath instead."
+        assert not drop_states \
+               or not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \
+            "Bidirectional SequentialRNNCell doesn't support variational state dropout. " \
+            "Please add VariationalDropoutCell to the cells underneath instead."
+        super(VariationalDropoutCell, self).__init__(base_cell)
+        self.drop_inputs = drop_inputs
+        self.drop_states = drop_states
+        self.drop_outputs = drop_outputs
+        self.drop_inputs_mask = None
+        self.drop_states_mask = None
+        self.drop_outputs_mask = None
+
+    def _alias(self):
+        return 'vardrop'
+
+    def reset(self):
+        super(VariationalDropoutCell, self).reset()
+        self.drop_inputs_mask = None
+        self.drop_states_mask = None
+        self.drop_outputs_mask = None
+
+    def _initialize_input_masks(self, F, inputs, states):
+        if self.drop_states and self.drop_states_mask is None:
+            self.drop_states_mask = F.Dropout(F.ones_like(states[0]),
+                                              p=self.drop_states)
+
+        if self.drop_inputs and self.drop_inputs_mask is None:
+            self.drop_inputs_mask = F.Dropout(F.ones_like(inputs),
+                                              p=self.drop_inputs)
+
+    def _initialize_output_mask(self, F, output):
+        if self.drop_outputs and self.drop_outputs_mask is None:
+            self.drop_outputs_mask = F.Dropout(F.ones_like(output),
+                                               p=self.drop_outputs)
+
+
+    def hybrid_forward(self, F, inputs, states):
+        cell = self.base_cell
+        self._initialize_input_masks(F, inputs, states)
+
+        if self.drop_states:
+            states = list(states)
+            # state dropout only needs to be applied on h, which is always the first state.
+            states[0] = states[0] * self.drop_states_mask
+
+        if self.drop_inputs:
+            inputs = inputs * self.drop_inputs_mask
+
+        next_output, next_states = cell(inputs, states)
+
+        self._initialize_output_mask(F, next_output)
+        if self.drop_outputs:
+            next_output = next_output * self.drop_outputs_mask
+
+        return next_output, next_states
+
+    def __repr__(self):
+        s = '{name}(p_out = {drop_outputs}, p_state = {drop_states})'
+        return s.format(name=self.__class__.__name__,
+                        **self.__dict__)
+
+    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
+        """Unrolls an RNN cell across time steps.
+
+        Parameters
+        ----------
+        length : int
+            Number of steps to unroll.
+        inputs : Symbol, list of Symbol, or None
+            If `inputs` is a single Symbol (usually the output
+            of Embedding symbol), it should have shape
+            (batch_size, length, ...) if `layout` is 'NTC',
+            or (length, batch_size, ...) if `layout` is 'TNC'.
+
+            If `inputs` is a list of symbols (usually output of
+            previous unroll), they should all have shape
+            (batch_size, ...).
+        begin_state : nested list of Symbol, optional
+            Input states created by `begin_state()`
+            or output state of another cell.
+            Created from `begin_state()` if `None`.
+        layout : str, optional
+            `layout` of input symbol. Only used if inputs
+            is a single Symbol.
+        merge_outputs : bool, optional
+            If `False`, returns outputs as a list of Symbols.
+            If `True`, concatenates output across time steps
+            and returns a single symbol with shape
+            (batch_size, length, ...) if layout is 'NTC',
+            or (length, batch_size, ...) if layout is 'TNC'.
+            If `None`, output whatever is faster.
+
+        Returns
+        -------
+        outputs : list of Symbol or Symbol
+            Symbol (if `merge_outputs` is True) or list of Symbols
+            (if `merge_outputs` is False) corresponding to the output from
+            the RNN from this unrolling.
+
+        states : list of Symbol
+            The new state of this RNN after this unrolling.
+            The type of this symbol is same as the output of `begin_state()`.
+        """
+
+        # Dropout on inputs and outputs can be performed on the whole sequence
+        # only when state dropout is not present.
+        if self.drop_states:
+            return super(VariationalDropoutCell, self).unroll(length, inputs, begin_state,
+                                                              layout, merge_outputs)
+
+        self.reset()
+
+        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, True)
+        states = _get_begin_state(self, F, begin_state, inputs, batch_size)
+
+        if self.drop_inputs:
+            first_input = inputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
+            self._initialize_input_masks(F, first_input, states)
+            inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis))
+
+        outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True)
+        if self.drop_outputs:
+            first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
+            self._initialize_output_mask(F, first_output)
+            outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis))
+
+        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)
+
+        return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index 7921698..c99836c 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 import mxnet as mx
 from mxnet.gluon import contrib
+from mxnet.test_utils import almost_equal
 import numpy as np
 from numpy.testing import assert_allclose
 
@@ -93,6 +94,43 @@ def test_convgru():
     check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))
 
 
+def test_vardrop():
+    def check_vardrop(drop_inputs, drop_states, drop_outputs):
+        cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
+                                                  drop_outputs=drop_outputs,
+                                                  drop_states=drop_states,
+                                                  drop_inputs=drop_inputs)
+        cell.collect_params().initialize(init='xavier')
+        input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
+        with mx.autograd.record():
+            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
+            mask1 = cell.drop_outputs_mask.asnumpy()
+            mx.nd.waitall()
+            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
+            mask2 = cell.drop_outputs_mask.asnumpy()
+        assert not almost_equal(mask1, mask2)
+        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())
+
+        inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
+        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
+        outputs = mx.sym.Group(outputs)
+
+        args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
+        assert outs == [(10, 100), (10, 100), (10, 100)]
+
+        cell.reset()
+        cell.hybridize()
+        with mx.autograd.record():
+            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
+            mx.nd.waitall()
+            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
+        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
+        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())
+
+    check_vardrop(0.5, 0.5, 0.5)
+    check_vardrop(0.5, 0, 0.5)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 89da900..7d2842a 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -19,6 +19,7 @@ import mxnet as mx
 from mxnet import gluon
 import numpy as np
 from numpy.testing import assert_allclose
+from mxnet.test_utils import almost_equal
 
 
 def test_rnn():
@@ -205,7 +206,6 @@ def check_rnn_forward(layer, inputs, deterministic=True):
         mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
-
 def test_rnn_cells():
     check_rnn_forward(gluon.rnn.LSTMCell(100, input_size=200), mx.nd.ones((8, 3, 200)))
     check_rnn_forward(gluon.rnn.RNNCell(100, input_size=200), mx.nd.ones((8, 3, 200)))

-- 
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].

Reply via email to