This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ccfef53 variational dropout cell (#7067)
ccfef53 is described below
commit ccfef5375805ced8ba2492d6deb542e32f12610b
Author: Sheng Zha <[email protected]>
AuthorDate: Fri Sep 29 10:35:27 2017 -0700
variational dropout cell (#7067)
* variational dropout cell
* update per comments
* rename, add tests
* add more tests
* add unroll for efficiency
* add comments
* move variational dropout to contrib
---
python/mxnet/gluon/contrib/rnn/__init__.py | 2 +
python/mxnet/gluon/contrib/rnn/rnn_cell.py | 182 ++++++++++++++++++++++++++++
tests/python/unittest/test_gluon_contrib.py | 38 ++++++
tests/python/unittest/test_gluon_rnn.py | 2 +-
4 files changed, 223 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/gluon/contrib/rnn/__init__.py
b/python/mxnet/gluon/contrib/rnn/__init__.py
index f27f205..e0a5220 100644
--- a/python/mxnet/gluon/contrib/rnn/__init__.py
+++ b/python/mxnet/gluon/contrib/rnn/__init__.py
@@ -20,3 +20,5 @@
"""Contrib recurrent neural network module."""
from .conv_rnn_cell import *
+
+from .rnn_cell import *
diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
new file mode 100644
index 0000000..df386a7
--- /dev/null
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Definition of various recurrent neural network cells."""
+
+from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell
+from ...rnn.rnn_cell import _format_sequence, _get_begin_state
+
+
class VariationalDropoutCell(ModifierCell):
    """
    Applies Variational Dropout on base cell.
    (https://arxiv.org/pdf/1512.05287.pdf,
    https://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf).

    Variational dropout uses the same dropout mask across time-steps. It can be applied to RNN
    inputs, outputs, and states. The masks for them are not shared.

    The dropout mask is initialized when stepping forward for the first time and will remain
    the same until .reset() is called. Thus, if using the cell and stepping manually without calling
    .unroll(), the .reset() should be called after each sequence.

    Parameters
    ----------
    base_cell : RecurrentCell
        The cell on which to perform variational dropout.
    drop_inputs : float, default 0.
        The dropout rate for inputs. Won't apply dropout if it equals 0.
    drop_states : float, default 0.
        The dropout rate for state inputs on the first state channel.
        Won't apply dropout if it equals 0.
    drop_outputs : float, default 0.
        The dropout rate for outputs. Won't apply dropout if it equals 0.
    """
    def __init__(self, base_cell, drop_inputs=0., drop_states=0., drop_outputs=0.):
        # A single state mask is reused across all time steps; a bidirectional
        # cell runs two opposing unrolls, so shared state dropout is rejected
        # up front (directly or when wrapped in a SequentialRNNCell).
        assert not drop_states or not isinstance(base_cell, BidirectionalCell), \
            "BidirectionalCell doesn't support variational state dropout. " \
            "Please add VariationalDropoutCell to the cells underneath instead."
        assert not drop_states \
            or not isinstance(base_cell, SequentialRNNCell) or not base_cell._bidirectional, \
            "Bidirectional SequentialRNNCell doesn't support variational state dropout. " \
            "Please add VariationalDropoutCell to the cells underneath instead."
        super(VariationalDropoutCell, self).__init__(base_cell)
        # Dropout rates for each of the three mask sites (0. disables the site).
        self.drop_inputs = drop_inputs
        self.drop_states = drop_states
        self.drop_outputs = drop_outputs
        # Masks are sampled lazily on the first forward step and cached here
        # until reset(), so every step of a sequence sees the same mask.
        self.drop_inputs_mask = None
        self.drop_states_mask = None
        self.drop_outputs_mask = None

    def _alias(self):
        # Prefix alias used for parameter/cell naming.
        return 'vardrop'

    def reset(self):
        """Reset the base cell and clear cached masks so the next forward
        step samples fresh dropout masks."""
        super(VariationalDropoutCell, self).reset()
        self.drop_inputs_mask = None
        self.drop_states_mask = None
        self.drop_outputs_mask = None

    def _initialize_input_masks(self, F, inputs, states):
        # Sample each mask once by applying Dropout to an all-ones tensor of
        # matching shape; the `is None` guard makes re-use across steps lazy.
        if self.drop_states and self.drop_states_mask is None:
            self.drop_states_mask = F.Dropout(F.ones_like(states[0]),
                                              p=self.drop_states)

        if self.drop_inputs and self.drop_inputs_mask is None:
            self.drop_inputs_mask = F.Dropout(F.ones_like(inputs),
                                              p=self.drop_inputs)

    def _initialize_output_mask(self, F, output):
        # Same lazy one-shot sampling as the input/state masks, but for the
        # output site (shape is only known after the base cell has run).
        if self.drop_outputs and self.drop_outputs_mask is None:
            self.drop_outputs_mask = F.Dropout(F.ones_like(output),
                                               p=self.drop_outputs)


    def hybrid_forward(self, F, inputs, states):
        """Perform one step: mask inputs/state, run the base cell, mask output."""
        cell = self.base_cell
        self._initialize_input_masks(F, inputs, states)

        if self.drop_states:
            # Copy before mutating so the caller's states list is untouched.
            states = list(states)
            # state dropout only needs to be applied on h, which is always the first state.
            states[0] = states[0] * self.drop_states_mask

        if self.drop_inputs:
            inputs = inputs * self.drop_inputs_mask

        next_output, next_states = cell(inputs, states)

        self._initialize_output_mask(F, next_output)
        if self.drop_outputs:
            next_output = next_output * self.drop_outputs_mask

        return next_output, next_states

    def __repr__(self):
        # NOTE(review): drop_inputs is not included in the repr — confirm
        # whether this omission is intentional.
        s = '{name}(p_out = {drop_outputs}, p_state = {drop_states})'
        return s.format(name=self.__class__.__name__,
                        **self.__dict__)

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
        """Unrolls an RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol, list of Symbol, or None
            If `inputs` is a single Symbol (usually the output
            of Embedding symbol), it should have shape
            (batch_size, length, ...) if `layout` is 'NTC',
            or (length, batch_size, ...) if `layout` is 'TNC'.

            If `inputs` is a list of symbols (usually output of
            previous unroll), they should all have shape
            (batch_size, ...).
        begin_state : nested list of Symbol, optional
            Input states created by `begin_state()`
            or output state of another cell.
            Created from `begin_state()` if `None`.
        layout : str, optional
            `layout` of input symbol. Only used if inputs
            is a single Symbol.
        merge_outputs : bool, optional
            If `False`, returns outputs as a list of Symbols.
            If `True`, concatenates output across time steps
            and returns a single symbol with shape
            (batch_size, length, ...) if layout is 'NTC',
            or (length, batch_size, ...) if layout is 'TNC'.
            If `None`, output whatever is faster.

        Returns
        -------
        outputs : list of Symbol or Symbol
            Symbol (if `merge_outputs` is True) or list of Symbols
            (if `merge_outputs` is False) corresponding to the output from
            the RNN from this unrolling.

        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.
        """

        # Dropout on inputs and outputs can be performed on the whole sequence
        # only when state dropout is not present.
        if self.drop_states:
            # Fall back to step-by-step unrolling from ModifierCell.
            return super(VariationalDropoutCell, self).unroll(length, inputs, begin_state,
                                                              layout, merge_outputs)

        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, True)
        states = _get_begin_state(self, F, begin_state, inputs, batch_size)

        if self.drop_inputs:
            # Sample the per-step mask from the first time slice, then
            # broadcast it along the time axis over the whole sequence.
            first_input = inputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
            self._initialize_input_masks(F, first_input, states)
            inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis))

        # Unroll the base cell in one shot (merged outputs) for efficiency.
        outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True)
        if self.drop_outputs:
            # Same broadcast trick as for inputs, applied to the outputs.
            first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True)
            self._initialize_output_mask(F, first_output)
            outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis))

        # Re-format into the caller-requested merged/list representation.
        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)

        return outputs, states
diff --git a/tests/python/unittest/test_gluon_contrib.py
b/tests/python/unittest/test_gluon_contrib.py
index 7921698..c99836c 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import mxnet as mx
from mxnet.gluon import contrib
+from mxnet.test_utils import almost_equal
import numpy as np
from numpy.testing import assert_allclose
@@ -93,6 +94,43 @@ def test_convgru():
check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50),
out_shape=(1, 100, 18, 28, 48))
def test_vardrop():
    """Check VariationalDropoutCell: masks are resampled between unrolls,
    symbolic shape inference works, and hybridized results stay stochastic."""
    # NOTE(review): indentation of this test was reconstructed from a
    # line-wrapped patch email; confirm the autograd.record() scopes against
    # the committed file.
    def check_vardrop(drop_inputs, drop_states, drop_outputs):
        cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
                                                  drop_outputs=drop_outputs,
                                                  drop_states=drop_states,
                                                  drop_inputs=drop_inputs)
        cell.collect_params().initialize(init='xavier')
        input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context())
        # Dropout is only active under autograd recording (training mode),
        # so both unrolls run inside record(). unroll() calls reset(), which
        # is what forces a fresh mask for the second unroll.
        with mx.autograd.record():
            outputs1, _ = cell.unroll(3, input_data, merge_outputs=True)
            mask1 = cell.drop_outputs_mask.asnumpy()
            mx.nd.waitall()
            outputs2, _ = cell.unroll(3, input_data, merge_outputs=True)
            mask2 = cell.drop_outputs_mask.asnumpy()
        # Masks (and hence outputs) must differ between the two unrolls.
        assert not almost_equal(mask1, mask2)
        assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy())

        # Symbolic unroll: verify output shapes via shape inference.
        inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
        outputs, _ = cell.unroll(3, inputs, merge_outputs=False)
        outputs = mx.sym.Group(outputs)

        args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
        assert outs == [(10, 100), (10, 100), (10, 100)]

        # Repeat the stochasticity check after hybridization.
        cell.reset()
        cell.hybridize()
        with mx.autograd.record():
            outputs3, _ = cell.unroll(3, input_data, merge_outputs=True)
            mx.nd.waitall()
            outputs4, _ = cell.unroll(3, input_data, merge_outputs=True)
        assert not almost_equal(outputs3.asnumpy(), outputs4.asnumpy())
        assert not almost_equal(outputs1.asnumpy(), outputs3.asnumpy())

    # Exercise all three mask sites, and the no-state-dropout fast path.
    check_vardrop(0.5, 0.5, 0.5)
    check_vardrop(0.5, 0, 0.5)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_rnn.py
b/tests/python/unittest/test_gluon_rnn.py
index 89da900..7d2842a 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -19,6 +19,7 @@ import mxnet as mx
from mxnet import gluon
import numpy as np
from numpy.testing import assert_allclose
+from mxnet.test_utils import almost_equal
def test_rnn():
@@ -205,7 +206,6 @@ def check_rnn_forward(layer, inputs, deterministic=True):
mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(),
rtol=1e-3, atol=1e-5)
-
def test_rnn_cells():
check_rnn_forward(gluon.rnn.LSTMCell(100, input_size=200), mx.nd.ones((8,
3, 200)))
check_rnn_forward(gluon.rnn.RNNCell(100, input_size=200), mx.nd.ones((8,
3, 200)))
--
To stop receiving notification emails like this one, please contact
[email protected].