sxjscience commented on a change in pull request #7264: gluon conv rnns
URL: https://github.com/apache/incubator-mxnet/pull/7264#discussion_r140282753
 
 

 ##########
 File path: python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
 ##########
 @@ -0,0 +1,971 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=arguments-differ, too-many-lines
+# coding: utf-8
+"""Definition of various recurrent neural network cells."""
+
+from math import floor
+
+from ...rnn import HybridRecurrentCell
+
+
+def _get_conv_out_size(dimensions, kernels, paddings, strides, dilations):
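+    # Standard convolution output size per spatial dimension:
+    #     out = floor((x + 2*p - d*(k-1) - 1) / s) + 1
+    # A dimension given as 0 (shape not yet inferred) stays 0.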
+    return tuple(int(floor((x+2*p-d*(k-1)-1)/s)+1) if x else 0 for x, k, p, s, d in
+                 zip(dimensions, kernels, paddings, strides, dilations))
+
+
+class _BaseConvRNNCell(HybridRecurrentCell):
+    """Abstract base class for convolutional RNNs"""
+    def __init__(self, hidden_channels, input_shape, activation,
+                 i2h_kernel, i2h_stride, i2h_pad, i2h_dilate,
+                 h2h_kernel, h2h_dilate,
+                 i2h_weight_initializer, h2h_weight_initializer,
+                 i2h_bias_initializer, h2h_bias_initializer,
+                 dims,
+                 conv_layout='NCHW',
+                 prefix=None, params=None):
+        super(_BaseConvRNNCell, self).__init__(prefix=prefix, params=params)
+
+        self._hidden_channels = hidden_channels
+        self._input_shape = input_shape
+        self._conv_layout = conv_layout
+        self._activation = activation
+
+        # Convolution setting
+        assert all(isinstance(spec, int) or len(spec) == dims
+                   for spec in [i2h_kernel, i2h_stride, i2h_pad, i2h_dilate,
+                                h2h_kernel, h2h_dilate]), \
+               "For {dims}D convolution, the convolution settings can only be either int " \
+               "or tuple of length {dims}".format(dims=dims)
+
+        self._i2h_kernel = i2h_kernel if isinstance(i2h_kernel, tuple) else (i2h_kernel,) * dims
+        self._i2h_stride = i2h_stride if isinstance(i2h_stride, tuple) else (i2h_stride,) * dims
+        self._i2h_pad = i2h_pad if isinstance(i2h_pad, tuple) else (i2h_pad,) * dims
+        self._i2h_dilate = i2h_dilate if isinstance(i2h_dilate, tuple) else (i2h_dilate,) * dims
+        self._h2h_kernel = h2h_kernel if isinstance(h2h_kernel, tuple) else (h2h_kernel,) * dims
+        assert all(k % 2 == 1 for k in self._h2h_kernel), \
+            "Only odd kernel sizes are supported, got h2h_kernel=%s" % str(h2h_kernel)
+        self._h2h_dilate = h2h_dilate if isinstance(h2h_dilate, tuple) else (h2h_dilate,) * dims
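+        # With stride 1 and an odd kernel, this padding keeps the spatial size
+        # unchanged ('same' padding), which is why odd h2h kernels are required.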
+        self._h2h_pad = tuple(d*(k-1)//2 for d, k in zip(self._h2h_dilate, self._h2h_kernel))
+
+        self._channel_axis, \
+        self._in_channels, \
+        i2h_param_shape, \
+        h2h_param_shape, \
+        self._state_shape = self._decide_shapes(dims)
+
+        self.i2h_weight = self.params.get('i2h_weight', shape=i2h_param_shape,
+                                          init=i2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2h_weight = self.params.get('h2h_weight', shape=h2h_param_shape,
+                                          init=h2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.i2h_bias = self.params.get('i2h_bias', shape=(hidden_channels*self._num_gates,),
+                                        init=i2h_bias_initializer,
+                                        allow_deferred_init=True)
+        self.h2h_bias = self.params.get('h2h_bias', shape=(hidden_channels*self._num_gates,),
+                                        init=h2h_bias_initializer,
+                                        allow_deferred_init=True)
+
+    def _decide_shapes(self, dims):
+        channel_axis = self._conv_layout.find('C')
+        input_shape = self._input_shape
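+        # input_shape excludes the batch axis, so the channel index from the
+        # layout string shifts down by one.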
+        in_channels = input_shape[channel_axis - 1]
+        hidden_channels = self._hidden_channels
+        if channel_axis == 1:
+            dimensions = input_shape[1:]
+        else:
+            dimensions = input_shape[:-1]
+
+        h2h_strides = (1,) * dims
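+        # All gates are computed by a single convolution, so the output has
+        # hidden_channels * num_gates filters.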
+        total_out = hidden_channels * self._num_gates
+
+        i2h_param_shape = (total_out,)
+        h2h_param_shape = (total_out,)
+        state_shape = (hidden_channels,)
+        conv_out_size = _get_conv_out_size(dimensions,
+                                           self._h2h_kernel,
+                                           self._h2h_pad,
+                                           h2h_strides,
+                                           self._h2h_dilate)
+        if channel_axis == 1:
+            i2h_param_shape += (in_channels,) + self._i2h_kernel
+            h2h_param_shape += (hidden_channels,) + self._h2h_kernel
+            state_shape += conv_out_size
+        else:
+            i2h_param_shape += self._i2h_kernel + (in_channels,)
+            h2h_param_shape += self._h2h_kernel + (hidden_channels,)
+            state_shape = conv_out_size + state_shape
+
+        return channel_axis, in_channels, i2h_param_shape, h2h_param_shape, state_shape
+
+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ', {_conv_layout}'
+        s += ')'
+        attrs = self.__dict__
+        mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
+                   else self._hidden_channels)
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **attrs)
+
+    @property
+    def _num_gates(self):
+        return len(self._gate_names)
+
+    def _conv_forward(self, F, inputs, states,
+                      i2h_weight, h2h_weight, i2h_bias, h2h_bias,
+                      prefix):
+        i2h = F.Convolution(data=inputs,
+                            num_filter=self._hidden_channels*self._num_gates,
+                            kernel=self._i2h_kernel,
+                            stride=self._i2h_stride,
+                            pad=self._i2h_pad,
+                            dilate=self._i2h_dilate,
+                            weight=i2h_weight,
+                            bias=i2h_bias,
+                            layout=self._conv_layout,
+                            name=prefix+'i2h')
+        h2h = F.Convolution(data=states[0],
+                            num_filter=self._hidden_channels*self._num_gates,
+                            kernel=self._h2h_kernel,
+                            dilate=self._h2h_dilate,
+                            pad=self._h2h_pad,
+                            stride=(1,)*len(self._h2h_kernel),
+                            weight=h2h_weight,
+                            bias=h2h_bias,
+                            layout=self._conv_layout,
+                            name=prefix+'h2h')
+        return i2h, h2h
+
+    def state_info(self, batch_size=0):
+        raise NotImplementedError("_BaseConvRNNCell is an abstract class for convolutional RNNs")
+
+    def hybrid_forward(self, F, inputs, states):
+        raise NotImplementedError("_BaseConvRNNCell is an abstract class for convolutional RNNs")
+
+
+class _ConvRNNCell(_BaseConvRNNCell):
+    def __init__(self, hidden_channels, input_shape, activation,
+                 i2h_kernel, i2h_stride, i2h_pad, i2h_dilate, h2h_kernel, h2h_dilate,
+                 i2h_weight_initializer, h2h_weight_initializer,
+                 i2h_bias_initializer, h2h_bias_initializer,
+                 dims, conv_layout, prefix, params):
+        super(_ConvRNNCell, self).__init__(hidden_channels=hidden_channels,
+                                           input_shape=input_shape,
+                                           activation=activation,
+                                           i2h_kernel=i2h_kernel, i2h_stride=i2h_stride,
+                                           i2h_pad=i2h_pad, i2h_dilate=i2h_dilate,
+                                           h2h_kernel=h2h_kernel, h2h_dilate=h2h_dilate,
+                                           i2h_weight_initializer=i2h_weight_initializer,
+                                           h2h_weight_initializer=h2h_weight_initializer,
+                                           i2h_bias_initializer=i2h_bias_initializer,
+                                           h2h_bias_initializer=h2h_bias_initializer,
+                                           dims=dims,
+                                           conv_layout=conv_layout,
+                                           prefix=prefix, params=params)
+
+    def state_info(self, batch_size=0):
+        return [{'shape': (batch_size,)+self._state_shape, '__layout__': self._conv_layout}]
+
+    def _alias(self):
+        return 'conv_rnn'
+
+    @property
+    def _gate_names(self):
+        return ('',)
+
+    def hybrid_forward(self, F, inputs, states, i2h_weight,
+                       h2h_weight, i2h_bias, h2h_bias):
+        prefix = 't%d_'%self._counter
+        i2h, h2h = self._conv_forward(F, inputs, states,
+                                      i2h_weight, h2h_weight, i2h_bias, h2h_bias,
+                                      prefix)
+        output = self._get_activation(F, i2h + h2h, self._activation,
+                                      name=prefix+'out')
+        return output, [output]
+
+
+class Conv1DRNNCell(_ConvRNNCell):
+    r"""1D Convolutional RNN cell.
+
+    .. math::
+
+        h_t = tanh(W_i \ast x_t + R_i \ast h_{t-1} + b_i)
+
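+    where :math:`\ast` denotes the convolution operator.
+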
+    Parameters
+    ----------
+    hidden_channels : int
+        Number of output channels.
+    input_shape : tuple of int
+        Input tensor shape at each time step for each sample, excluding the batch size
+        and sequence length dimensions. Must be consistent with `conv_layout`.
+        For example, for layout 'NCW' the shape should be (C, W).
+    activation : str or Block, default 'tanh'
+        Type of activation function.
+        If argument type is string, it's equivalent to nn.Activation(act_type=str). See
+        :func:`~mxnet.ndarray.Activation` for available choices.
+        Alternatively, other activation blocks such as nn.LeakyReLU can be used.
+    i2h_kernel : int or tuple of int, default (3,)
+        Input convolution kernel sizes.
+    i2h_stride : int or tuple of int, default (1,)
+        Input convolution stride sizes.
+    i2h_pad : int or tuple of int, default (1,)
+        Pad for input convolution.
+    i2h_dilate : int or tuple of int, default (1,)
+        Input convolution dilate.
+    h2h_kernel : int or tuple of int, default (3,)
+        Recurrent convolution kernel sizes. Only odd-numbered sizes are supported.
+    h2h_dilate : int or tuple of int, default (1,)
+        Recurrent convolution dilate.
+    i2h_weight_initializer : str or Initializer
+        Initializer for the input weights matrix, used for the input convolutions.
+    h2h_weight_initializer : str or Initializer
+        Initializer for the recurrent weights matrix, used for the recurrent convolutions.
+    i2h_bias_initializer : str or Initializer, default zeros
+        Initializer for the input convolution bias vectors.
+    h2h_bias_initializer : str or Initializer, default zeros
+        Initializer for the recurrent convolution bias vectors.
+    conv_layout : str, default 'NCW'
+        Layout for all convolution inputs, outputs and weights. Options are 'NCW' and 'NWC'.
+    prefix : str, default 'conv_rnn_'
+        Prefix for name of layers (and name of weight if params is None).
+    params : RNNParams, default None
+        Container for weight sharing between cells. Created if None.
+    """
+    def __init__(self, hidden_channels, input_shape, activation='tanh',
+                 i2h_kernel=(3,), i2h_stride=(1,), i2h_pad=(1,), i2h_dilate=(1,),
 
 Review comment:
   I think we should remove the default values of `i2h_kernel` and `h2h_kernel`, and keep
   the default values of `i2h_pad`, `i2h_stride`, ... `h2h_dilate` the same as those of the
   convolution layer.
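
   For concreteness, a rough sketch of the `Conv1DRNNCell` signature this would give. The
   concrete defaults below are my reading of "same as that of the convolution layer", i.e.
   stride=1, pad=0, dilate=1 as in `gluon.nn.Conv1D`; this is a sketch, not the actual code:

       def __init__(self, hidden_channels, input_shape, i2h_kernel, h2h_kernel,
                    activation='tanh', i2h_stride=(1,), i2h_pad=(0,), i2h_dilate=(1,),
                    h2h_dilate=(1,), i2h_weight_initializer=None,
                    h2h_weight_initializer=None, i2h_bias_initializer='zeros',
                    h2h_bias_initializer='zeros', conv_layout='NCW',
                    prefix=None, params=None):
           # Kernels become required; the remaining conv settings fall back to
           # the (assumed) convolution-layer defaults above.
           pass

   This way a user can't silently get a 3-wide kernel they never asked for.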
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services