This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ccffde  Gluon LSTM Projection and Clipping Support (#13056)
1ccffde is described below

commit 1ccffde1d2bf2557801690eb82b7a06356183623
Author: Sheng Zha <[email protected]>
AuthorDate: Thu Nov 1 21:11:44 2018 -0700

    Gluon LSTM Projection and Clipping Support (#13056)
    
    * support projection in LSTM
    
    * add tests
    
    * update rnn to use the cuDNN *Ex RNN APIs
    
    * extend cudnn test to handle different versions
    
    * add lstm clip
    
    * use CUDNN_VERSION
    
    * merge USE_CUDNN_LSTM_CLIP and USE_CUDNN_LSTM_PROJ
    
    * assign false value to clip nan explicitly to RNN and GRU
    
    * update test
---
 ci/docker/runtime_functions.sh          |   6 +
 python/mxnet/gluon/rnn/rnn_layer.py     | 130 ++++++++++----
 python/mxnet/test_utils.py              |   2 +-
 src/operator/cudnn_rnn-inl.h            | 305 +++++++++++++++++++++++++++++---
 src/operator/nn/concat.cc               |  21 ++-
 src/operator/rnn-inl.h                  |  84 +++++++--
 src/operator/rnn.cu                     |   2 +-
 tests/python/gpu/test_gluon_gpu.py      | 138 ++++++++++++++-
 tests/python/gpu/test_operator_gpu.py   |  16 +-
 tests/python/unittest/common.py         |  11 +-
 tests/python/unittest/test_gluon.py     |   7 +-
 tests/python/unittest/test_gluon_rnn.py |  15 +-
 tests/python/unittest/test_operator.py  |  20 +--
 13 files changed, 638 insertions(+), 119 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 43006f2..0adec07 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -702,6 +702,7 @@ unittest_ubuntu_python2_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1  # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_gpu.xml --verbose tests/python/gpu
 }
 
@@ -734,6 +735,7 @@ unittest_ubuntu_python3_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_gpu.xml --verbose tests/python/gpu
 }
 
@@ -750,6 +752,7 @@ unittest_ubuntu_tensorrt_gpu() {
     export PYTHONPATH=./python/
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
     export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
+    export CUDNN_VERSION=7.0.3
     python tests/python/tensorrt/lenet5_train.py
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_trt_gpu.xml --verbose --nocapture tests/python/tensorrt/
 }
@@ -761,6 +764,7 @@ unittest_ubuntu_python2_quantization_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1  # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
 }
 
@@ -771,6 +775,7 @@ unittest_ubuntu_python3_quantization_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
 }
 
@@ -865,6 +870,7 @@ unittest_centos7_cpu() {
 unittest_centos7_gpu() {
     set -ex
     cd /work/mxnet
+    export CUDNN_VERSION=7.0.3
     python3.6 -m "nose" $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file 
nosetests_gpu.xml --verbose tests/python/gpu
 }
 
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py 
b/python/mxnet/gluon/rnn/rnn_layer.py
index e44b360..c43dc85 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -35,11 +35,14 @@ class _RNNLayer(HybridBlock):
                  dropout, bidirectional, input_size,
                  i2h_weight_initializer, h2h_weight_initializer,
                  i2h_bias_initializer, h2h_bias_initializer,
-                 mode, **kwargs):
+                 mode, projection_size, h2r_weight_initializer,
+                 lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
+                 **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
         self._hidden_size = hidden_size
+        self._projection_size = projection_size if projection_size else None
         self._num_layers = num_layers
         self._mode = mode
         self._layout = layout
@@ -50,25 +53,50 @@ class _RNNLayer(HybridBlock):
         self._h2h_weight_initializer = h2h_weight_initializer
         self._i2h_bias_initializer = i2h_bias_initializer
         self._h2h_bias_initializer = h2h_bias_initializer
+        self._h2r_weight_initializer = h2r_weight_initializer
+        self._lstm_state_clip_min = lstm_state_clip_min
+        self._lstm_state_clip_max = lstm_state_clip_max
+        self._lstm_state_clip_nan = lstm_state_clip_nan
 
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
         ng, ni, nh = self._gates, input_size, hidden_size
-        for i in range(num_layers):
-            for j in ['l', 'r'][:self._dir]:
-                self._register_param('{}{}_i2h_weight'.format(j, i),
-                                     shape=(ng*nh, ni),
-                                     init=i2h_weight_initializer)
-                self._register_param('{}{}_h2h_weight'.format(j, i),
-                                     shape=(ng*nh, nh),
-                                     init=h2h_weight_initializer)
-                self._register_param('{}{}_i2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=i2h_bias_initializer)
-                self._register_param('{}{}_h2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=h2h_bias_initializer)
-            ni = nh * self._dir
+        if not projection_size:
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, nh),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                ni = nh * self._dir
+        else:
+            np = self._projection_size
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, np),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                    self._register_param('{}{}_h2r_weight'.format(j, i),
+                                         shape=(np, nh),
+                                         init=h2r_weight_initializer)
+                ni = np * self._dir
 
     def _register_param(self, name, shape, init):
         p = self.params.get(name, shape=shape, init=init,
@@ -114,6 +142,9 @@ class _RNNLayer(HybridBlock):
 
     def _unfuse(self):
         """Unfuses the fused RNN in to a stack of rnn cells."""
+        assert not self._projection_size, "_unfuse does not support projection 
layer yet!"
+        assert not self._lstm_state_clip_min and not 
self._lstm_state_clip_max, \
+                "_unfuse does not support state clipping yet!"
         get_cell = {'rnn_relu': lambda **kwargs: 
rnn_cell.RNNCell(self._hidden_size,
                                                                   
activation='relu',
                                                                   **kwargs),
@@ -189,7 +220,7 @@ class _RNNLayer(HybridBlock):
         skip_states = states is None
         if skip_states:
             if F is ndarray:
-                states = self.begin_state(batch_size, ctx=inputs.context)
+                states = self.begin_state(batch_size, ctx=inputs.context, 
dtype=inputs.dtype)
             else:
                 states = self.begin_state(0, func=symbol.zeros)
         if isinstance(states, tensor_types):
@@ -209,16 +240,29 @@ class _RNNLayer(HybridBlock):
         """ forward using CUDNN or CPU kenrel"""
         if self._layout == 'NTC':
             inputs = F.swapaxes(inputs, dim1=0, dim2=1)
-        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
-                  for t in ['weight', 'bias']
-                  for l in range(self._num_layers)
-                  for d in ['l', 'r'][:self._dir]
-                  for g in ['i2h', 'h2h'])
+        if self._projection_size is None:
+            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h'])
+        else:
+            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h', 'h2r']
+                      if g != 'h2r' or t != 'bias')
+
         params = F._internal._rnn_param_concat(*params, dim=0)
 
         rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    projection_size=self._projection_size,
                     num_layers=self._num_layers, bidirectional=self._dir == 2,
-                    p=self._dropout, state_outputs=True, mode=self._mode)
+                    p=self._dropout, state_outputs=True, mode=self._mode,
+                    lstm_state_clip_min=self._lstm_state_clip_min,
+                    lstm_state_clip_max=self._lstm_state_clip_max,
+                    lstm_state_clip_nan=self._lstm_state_clip_nan)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -318,7 +362,8 @@ class RNN(_RNNLayer):
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, 
h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'rnn_'+activation, **kwargs)
+                                  'rnn_'+activation, None, None, None, None, 
False,
+                                  **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
@@ -373,6 +418,20 @@ class LSTM(_RNNLayer):
         to zero.
     h2h_bias_initializer : str or Initializer
         Initializer for the bias vector.
+    projection_size: int, default None
+        The number of features after projection.
+    h2r_weight_initializer : str or Initializer, default None
+        Initializer for the projected recurrent weights matrix, used for the 
linear
+        transformation of the recurrent state to the projected space.
+    state_clip_min : float or None, default None
+        Minimum clip value of LSTM states. This option must be used together 
with
+        state_clip_max. If None, clipping is not applied.
+    state_clip_max : float or None, default None
+        Maximum clip value of LSTM states. This option must be used together 
with
+        state_clip_min. If None, clipping is not applied.
+    state_clip_nan : boolean, default False
+        Whether to stop NaN from propagating in state by clipping it to 
min/max.
+        If the clipping range is not specified, this option is ignored.
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -416,18 +475,28 @@ class LSTM(_RNNLayer):
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
+                 projection_size=None, h2r_weight_initializer=None,
+                 state_clip_min=None, state_clip_max=None, 
state_clip_nan=False,
                  **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, 
h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
-                                   'lstm', **kwargs)
+                                   'lstm', projection_size, 
h2r_weight_initializer,
+                                   state_clip_min, state_clip_max, 
state_clip_nan,
+                                   **kwargs)
 
     def state_info(self, batch_size=0):
-        return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                 '__layout__': 'LNC'},
-                {'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
-                 '__layout__': 'LNC'}]
+        if self._projection_size is None:
+            return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
+                     '__layout__': 'LNC'}]
+        else:
+            return [{'shape': (self._num_layers * self._dir, batch_size, 
self._projection_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
+                     '__layout__': 'LNC'}]
 
 
 class GRU(_RNNLayer):
@@ -522,7 +591,8 @@ class GRU(_RNNLayer):
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, 
h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'gru', **kwargs)
+                                  'gru', None, None, None, None, False,
+                                  **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, 
self._hidden_size),
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 0bb28a0..5487e35 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -336,7 +336,7 @@ def rand_sparse_ndarray(shape, stype, density=None, 
dtype=None, distribution=Non
         assert(False), "unknown storage type"
         return False
 
-def rand_ndarray(shape, stype, density=None, dtype=None,
+def rand_ndarray(shape, stype='default', density=None, dtype=None,
                  modifier_func=None, shuffle_csr_indices=False, 
distribution=None):
     if stype == 'default':
         arr = mx.nd.array(random_arrays(shape), dtype=dtype)
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 077428f..7c450b7 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -26,6 +26,8 @@
 #ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_
 #define MXNET_OPERATOR_CUDNN_RNN_INL_H_
 
+#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
+
 #include <mxnet/storage.h>
 #include <vector>
 #include <map>
@@ -38,7 +40,7 @@ namespace mxnet {
 namespace op {
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 template<typename DType>
-class CuDNNRNNOp : public Operator{
+class CuDNNRNNOp : public Operator {
  public:
   explicit CuDNNRNNOp(RNNParam param) {
     this->param_ = param;
@@ -69,6 +71,32 @@ class CuDNNRNNOp : public Operator{
       default:
         LOG(FATAL) << "Not implmented";
     }
+#if USE_CUDNN_LSTM_PROJ
+    if (param_.projection_size.has_value()) {
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "Projection is only supported for LSTM.";
+      CHECK_GE(param_.state_size, param_.projection_size.value())
+        << "State size must be larger than projection size.";
+    }
+#else
+    CHECK(!param_.projection_size.has_value())
+      << "Projection is only supported for LSTM with CuDNN version later than 
7.1.1.";
+#endif
+#if USE_CUDNN_LSTM_PROJ
+    if (param_.lstm_state_clip_min.has_value()
+        || param_.lstm_state_clip_max.has_value()) {
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "State clipping is only supported for LSTM.";
+      CHECK(param_.lstm_state_clip_min.has_value() && 
param_.lstm_state_clip_max.has_value())
+        << "lstm_state_clip_min and lstm_state_clip_max must be specified 
together.";
+      CHECK_GE(param_.lstm_state_clip_max.value(), 
param_.lstm_state_clip_min.value())
+        << "lstm_state_clip_max must be greater or equal to 
lstm_state_clip_min";
+    }
+#else
+    CHECK(!param_.lstm_state_clip_min.has_value()
+          && !param_.lstm_state_clip_max.has_value())
+      << "State clipping is only supported for LSTM with CuDNN version later 
than 7.2.1.";
+#endif
     // RNN Direction
     direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : 
CUDNN_UNIDIRECTIONAL;
     // Other
@@ -92,6 +120,13 @@ class CuDNNRNNOp : public Operator{
 
     CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
     CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
+    #endif
   }
 
   ~CuDNNRNNOp() {
@@ -123,6 +158,12 @@ class CuDNNRNNOp : public Operator{
         Storage::Get()->Free(dropout_states_);
       }
     }
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
+    #endif
   }
 
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
@@ -169,7 +210,89 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+      (param_.projection_size.has_value()) ? param_.projection_size.value() : 
param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
     if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           
CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : 
CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : 
CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() 
: 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() 
: 0.0));
+    #endif
+
+    if (ctx.is_train) {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
       CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
                                          rnn_desc_,
                                          param_.seq_length_,
@@ -191,8 +314,36 @@ class CuDNNRNNOp : public Operator{
                                          workspace_byte_,
                                          reserve_space_.dptr,
                                          reserve_space_byte_));
+      #endif
     } else {
-      // inference mode
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #else
       CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
                                           rnn_desc_,
                                           param_.seq_length_,
@@ -212,6 +363,7 @@ class CuDNNRNNOp : public Operator{
                                           cy_ptr,
                                           temp_space.dptr_,
                                           workspace_byte_));
+      #endif
     }
   }
 
@@ -283,6 +435,52 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space.dptr_,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+    #else
     CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
                                     rnn_desc_,
                                     param_.seq_length_,
@@ -325,6 +523,7 @@ class CuDNNRNNOp : public Operator{
                                        dw.dptr_,
                                        reserve_space_.dptr,
                                        reserve_space_byte_));
+    #endif
   }
 
  private:
@@ -367,8 +566,6 @@ class CuDNNRNNOp : public Operator{
         dimA[0] = param_.batch_size_;
         dimA[1] = param_.input_size_;
         dimA[2] = 1;
-        dimA[0] = param_.batch_size_;
-        dimA[1] = param_.input_size_;
         strideA[0] = dimA[2] * dimA[1];
         strideA[1] = dimA[2];
         strideA[2] = 1;
@@ -391,10 +588,10 @@ class CuDNNRNNOp : public Operator{
         strideA[2] = 1;
 
         CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
-                                             dtype_,
-                                             3,
-                                             dimA,
-                                             strideA));
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
         CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
                                               dtype_,
                                               3,
@@ -413,42 +610,85 @@ class CuDNNRNNOp : public Operator{
       strideA[0] = dimA[2] * dimA[1];
       strideA[1] = dimA[2];
       strideA[2] = 1;
+      #if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+      #endif
 
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
                                             dtype_,
                                             3,
@@ -470,26 +710,26 @@ class CuDNNRNNOp : public Operator{
                                            seed_));
       // RNN descriptors
       #if CUDNN_MAJOR >= 6
-        cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
-        CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
-                                            rnn_desc_,
-                                            param_.state_size,
-                                            param_.num_layers,
-                                            dropout_desc_,
-                                            input_mode_,
-                                            direction_,
-                                            mode_,
-                                            rnn_algo,
-                                            dtype_));
+      cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+      CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.state_size,
+                                          param_.num_layers,
+                                          dropout_desc_,
+                                          input_mode_,
+                                          direction_,
+                                          mode_,
+                                          rnn_algo,
+                                          dtype_));
       #else
-        CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
-                                         param_.state_size,
-                                         param_.num_layers,
-                                         dropout_desc_,
-                                         input_mode_,
-                                         direction_,
-                                         mode_,
-                                         dtype_));
+      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
+                                       param_.state_size,
+                                       param_.num_layers,
+                                       dropout_desc_,
+                                       input_mode_,
+                                       direction_,
+                                       mode_,
+                                       dtype_));
       #endif
       #if CUDNN_MAJOR >= 7
         cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
@@ -503,6 +743,14 @@ class CuDNNRNNOp : public Operator{
       #endif
         CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
       #endif
+      #if USE_CUDNN_LSTM_PROJ
+      if (param_.projection_size.has_value()) {
+        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
+                                               rnn_desc_,
+                                               param_.projection_size.value(),
+                                               0));
+      }
+      #endif
       // Get temp space sizes
       CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
                                           rnn_desc_,
@@ -591,6 +839,9 @@ class CuDNNRNNOp : public Operator{
   size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
   int workspace_size_, dropout_size_;
   std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, 
dy_desc_vec_;
+  #if USE_CUDNN_LSTM_PROJ
+  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, 
dy_data_desc_;
+  #endif
   cudnnTensorDescriptor_t hx_desc_, cx_desc_;
   cudnnTensorDescriptor_t hy_desc_, cy_desc_;
   cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index ac8a814..544b253 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -86,14 +86,17 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& 
attrs,
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
   TShape dshape;
   index_t size = 0;
-  int num_zero = 0;
+  std::vector<int> zero_indices;
   int axis = -1;
   for (int i = 0; i < param_.num_args; ++i) {
     TShape tmp = (*in_shape)[i];
     if (tmp.ndim()) {
       axis = CheckAxis(param_.dim, tmp.ndim());
-      num_zero += tmp[axis] == 0;
-      size += tmp[axis];
+      if (tmp[axis] == 0) {
+        zero_indices.emplace_back(i);
+      } else {
+        size += tmp[axis];
+      }
       tmp[axis] = 0;
       shape_assign(&dshape, tmp);
     }
@@ -113,18 +116,18 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& 
attrs,
         << "Incompatible input shape: expected " << dshape << ", got " << 
(*in_shape)[i];
   }
 
-  if (!num_zero) dshape[axis] = size;
+  if (zero_indices.empty()) dshape[axis] = size;
   CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << 
(*out_shape)[0];
-  if ((*out_shape)[0][axis] != 0 && num_zero) {
+  if ((*out_shape)[0][axis] != 0 && !zero_indices.empty()) {
     int residual = (*out_shape)[0][axis] - size;
     CHECK_GE(residual, 0)
         << "Input size already exceeds output size. Residual: " << residual;
-    CHECK(num_zero <= 2 && num_zero >= 0)
-        << "Expecting 1 or 2 inputs that need shape inference. Got: " << 
num_zero;
+    CHECK(zero_indices.size() <= 2 && zero_indices.size() >= 0)
+        << "Expecting 1 or 2 inputs that need shape inference. Got: " << 
zero_indices.size();
     bool need_infer = !(*out_shape)[0].Size();
-    for (int i = 0; i < num_zero; i++) {
-      (*in_shape)[i*2][axis] = residual / num_zero;
+    for (int i : zero_indices) {
+      (*in_shape)[i][axis] = residual / zero_indices.size();
       need_infer = need_infer || !(*in_shape)[i].Size();
     }
     return !need_infer;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 11ca066..545e31b 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -54,7 +54,8 @@ inline int GetRnnParamSize(int num_layer,
                            int input_size,
                            int state_size,
                            int direction,
-                           int mode) {
+                           int mode,
+                           const dmlc::optional<int>& projection_size) {
   int size = state_size * direction;
   switch (mode) {
     case rnn_enum::kRnnRelu:
@@ -69,7 +70,15 @@ inline int GetRnnParamSize(int num_layer,
   }
   int size1 = (input_size + state_size + 2) * size;  // first layer size
   int size2 = (state_size * direction + state_size + 2) * size;  // other 
layers size
+  if (projection_size.has_value()) {
+    int proj_size = projection_size.value();
+    size1 = (input_size + proj_size + 2) * size;
+    size2 = (proj_size * direction + proj_size + 2) * size;
+  }
   int param_size = size1 + (num_layer - 1) * size2;
+  if (projection_size.has_value()) {
+    param_size += projection_size.value() * state_size * num_layer * direction;
+  }
   return param_size;
 }
 
@@ -154,6 +163,9 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   float p, pkeep_;
   int seq_length_, batch_size_, input_size_;
   bool lstm_q_;  // whether type is lstm
+  dmlc::optional<int> projection_size;
+  dmlc::optional<double> lstm_state_clip_min, lstm_state_clip_max;
+  bool lstm_state_clip_nan;
 
   DMLC_DECLARE_PARAMETER(RNNParam) {
     DMLC_DECLARE_FIELD(state_size)
@@ -174,10 +186,29 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
 
     DMLC_DECLARE_FIELD(p).set_default(0.)
     .set_range(0, 1)
-    .describe("Dropout probability, fraction of the input that gets dropped 
out at training time");
+    .describe("drop rate of the dropout on the outputs of each RNN layer, 
except the last layer.");
 
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
+
+    DMLC_DECLARE_FIELD(projection_size)
+    .set_default(dmlc::optional<int>())
+    .describe("size of project size");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_min)
+    .set_default(dmlc::optional<double>())
+    .describe("Minimum clip value of LSTM states. This option must be used 
together with "
+              "lstm_state_clip_max.");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_max)
+    .set_default(dmlc::optional<double>())
+    .describe("Maximum clip value of LSTM states. This option must be used 
together with "
+              "lstm_state_clip_min.");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_nan)
+    .set_default(false)
+    .describe("Whether to stop NaN from propagating in state by clipping it to 
min/max. "
+              "If clipping range is not specified, this option is ignored.");
   }
 };
 
@@ -349,8 +380,15 @@ template<typename DType>
 class RNNOp : public Operator{
  public:
   explicit RNNOp(RNNParam p)
-    :param_(p), init_space_(false), reserve_space_size_(0)
-  {}
+    :param_(p), init_space_(false), reserve_space_size_(0) {
+    if (param_.projection_size.has_value()) {
+      LOG(FATAL) << "hidden layer projection is only supported for GPU with 
CuDNN later than 7.1.1";
+    }
+    if (param_.lstm_state_clip_min.has_value()
+        || param_.lstm_state_clip_max.has_value()) {
+      LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN 
later than 7.2.1";
+    }
+  }
 
   ~RNNOp() {
     if (init_space_) {
@@ -646,26 +684,33 @@ class RNNProp : public OperatorProperty {
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
     int total_layers = numDirections * param_.num_layers;  // double for 
bidirectional
+    int layer_size = (param_.projection_size.has_value()) ?
+                     param_.projection_size.value() : param_.state_size;
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
-                       Shape3(total_layers, batch_size, param_.state_size));
+                       Shape3(total_layers, batch_size, layer_size));
     if (param_.mode == rnn_enum::kLstm)
       SHAPE_ASSIGN_CHECK(*in_shape,
-                        rnn_enum::kStateCell,
-                        Shape3(total_layers, batch_size, param_.state_size));
+                         rnn_enum::kStateCell,
+                         Shape3(total_layers, batch_size, param_.state_size));
 
     // calculate parameter vector length
     int param_size = GetRnnParamSize(param_.num_layers,
-                                    input_size,
-                                    param_.state_size,
-                                    numDirections,
-                                    param_.mode);
+                                     input_size,
+                                     param_.state_size,
+                                     numDirections,
+                                     param_.mode,
+                                     param_.projection_size);
     SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
 
     out_shape->clear();
     // output: [sequence len, batch, output size]
     TShape oshape = dshape;
-    oshape[2] = numDirections * param_.state_size;
+    if (param_.projection_size.has_value()) {
+      oshape[2] = numDirections * param_.projection_size.value();
+    } else {
+      oshape[2] = numDirections * param_.state_size;
+    }
     out_shape->push_back(oshape);
     if (!param_.state_outputs) {
       return true;
@@ -674,11 +719,20 @@ class RNNProp : public OperatorProperty {
       TShape outStateShape = dshape;
       outStateShape[0] = total_layers;
       outStateShape[1] = batch_size;
-      outStateShape[2] = param_.state_size;
+      if (param_.projection_size.has_value()) {
+        outStateShape[2] = param_.projection_size.value();
+      } else {
+        outStateShape[2] = param_.state_size;
+      }
       out_shape->push_back(outStateShape);
       // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_shape->push_back(outStateShape);
+      if (param_.mode == rnn_enum::kLstm) {
+        TShape cellStateShape = dshape;
+        cellStateShape[0] = total_layers;
+        cellStateShape[1] = batch_size;
+        cellStateShape[2] = param_.state_size;
+        out_shape->push_back(cellStateShape);
+      }
       return true;
     }
   }
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 5951793..402a8cf 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -40,7 +40,7 @@ Operator* CreateOp<gpu>(RNNParam param, int dtype) {
     op = new CuDNNRNNOp<DType>(param);
   })
 #else
-  LOG(FATAL) << "RNN is only available for cuDNN at the moment.";
+  LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
 #endif  // MXNET_USE_CUDNN && CUDNN_MAJOR
   return op;
 }
diff --git a/tests/python/gpu/test_gluon_gpu.py 
b/tests/python/gpu/test_gluon_gpu.py
index 8ada95b..54bfcee 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -22,6 +22,7 @@ import tempfile
 import time
 import multiprocessing as mp
 import unittest
+import random
 import mxnet as mx
 import numpy as np
 import unittest
@@ -31,11 +32,12 @@ from mxnet.test_utils import check_consistency, 
set_default_context, assert_almo
 from mxnet.base import MXNetError
 from mxnet import autograd
 from numpy.testing import assert_allclose
+from mxnet.test_utils import rand_ndarray
 
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_not_satisfied
 from test_gluon import *
 from test_loss import *
 from test_gluon_rnn import *
@@ -79,7 +81,81 @@ def check_rnn_layer_w_rand_inputs(layer):
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_lstmp():
+    hidden_size, projection_size = 3, 2
+    rtol, atol = 1e-2, 1e-2
+    batch_size, seq_len = 7, 11
+    input_size = 5
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), 
ctx=mx.gpu(0))
+    shapes = {'i2h_weight': (hidden_size*4, input_size),
+              'h2h_weight': (hidden_size*4, projection_size),
+              'i2h_bias': (hidden_size*4,),
+              'h2h_bias': (hidden_size*4,),
+              'h2r_weight': (projection_size, hidden_size)}
+    weights = {k: rand_ndarray(v) for k, v in shapes.items()}
+    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                input_size=input_size, prefix='lstm0_')
+    lstm_cell = gluon.contrib.rnn.LSTMPCell(hidden_size=hidden_size,
+                                            projection_size=projection_size,
+                                            input_size=input_size,
+                                            prefix='lstm0_l0_')
+    lstm_layer.initialize(ctx=mx.gpu(0))
+    lstm_cell.initialize(ctx=mx.gpu(0))
+    layer_params = lstm_layer.collect_params()
+    cell_params = lstm_cell.collect_params()
+    for k, v in weights.items():
+        layer_params['lstm0_l0_'+k].set_data(v.copy())
+        cell_params['lstm0_l0_'+k].set_data(v.copy())
+    with autograd.record():
+        layer_output = lstm_layer(lstm_input.copy())
+        cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), 
layout='TNC',
+                                       merge_outputs=True)[0]
+    assert_almost_equal(layer_output.asnumpy(), cell_output.asnumpy(), 
rtol=rtol, atol=atol)
+    layer_output.backward()
+    cell_output.backward()
+    for k, v in weights.items():
+        layer_grad = layer_params['lstm0_l0_'+k].grad()
+        cell_grad = cell_params['lstm0_l0_'+k].grad()
+        print('checking gradient for {}'.format('lstm0_l0_'+k))
+        assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
+                            rtol=rtol, atol=atol)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), 
mx.nd.ones((8, 3, 20)))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, 
bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), 
mx.nd.ones((4, 3, 10))])
+
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, 
projection_size=5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, 
dropout=0.5, projection_size=5),
+                            mx.nd.ones((8, 3, 20)),
+                            [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], 
run_only=True)
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_lstm_clip():
+    hidden_size, projection_size = 4096, 2048
+    batch_size, seq_len = 32, 80
+    input_size = 50
+    clip_min, clip_max, clip_nan = -5, 5, True
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), 
ctx=mx.gpu(0))
+    lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), 
ctx=mx.gpu(0)),
+                   mx.nd.uniform(shape=(2, batch_size, hidden_size), 
ctx=mx.gpu(0))]
+    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                input_size=input_size, prefix='lstm0_',
+                                bidirectional=True,
+                                state_clip_min=clip_min,
+                                state_clip_max=clip_max,
+                                state_clip_nan=clip_nan)
+    lstm_layer.initialize(ctx=mx.gpu(0))
+    with autograd.record():
+        _, layer_output_states = lstm_layer(lstm_input, lstm_states)
+    cell_states = layer_output_states[0].asnumpy()
+    assert (cell_states >= clip_min).all() and (cell_states <= clip_max).all()
+    assert not np.isnan(cell_states).any()
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn_layer():
     check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
     check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
@@ -90,7 +166,65 @@ def test_rnn_layer():
     check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, 
bidirectional=True))
 
 
+def check_layer_bidirectional(size, in_size, proj_size):
+    class RefBiLSTM(gluon.Block):
+        def __init__(self, size, proj_size, **kwargs):
+            super(RefBiLSTM, self).__init__(**kwargs)
+            with self.name_scope():
+                self._lstm_fwd = gluon.rnn.LSTM(size, 
projection_size=proj_size, bidirectional=False, prefix='l0')
+                self._lstm_bwd = gluon.rnn.LSTM(size, 
projection_size=proj_size, bidirectional=False, prefix='r0')
+
+        def forward(self, inpt):
+            fwd = self._lstm_fwd(inpt)
+            bwd_inpt = nd.flip(inpt, 0)
+            bwd = self._lstm_bwd(bwd_inpt)
+            bwd = nd.flip(bwd, 0)
+            return nd.concat(fwd, bwd, dim=2)
+    weights = {}
+    for d in ['l', 'r']:
+        weights['lstm_{}0_i2h_weight'.format(d)] = 
mx.random.uniform(shape=(size*4, in_size))
+        if proj_size:
+            weights['lstm_{}0_h2h_weight'.format(d)] = 
mx.random.uniform(shape=(size*4, proj_size))
+            weights['lstm_{}0_h2r_weight'.format(d)] = 
mx.random.uniform(shape=(proj_size, size))
+        else:
+            weights['lstm_{}0_h2h_weight'.format(d)] = 
mx.random.uniform(shape=(size*4, size))
+        weights['lstm_{}0_i2h_bias'.format(d)] = 
mx.random.uniform(shape=(size*4,))
+        weights['lstm_{}0_h2h_bias'.format(d)] = 
mx.random.uniform(shape=(size*4,))
+
+    net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True, 
prefix='lstm_')
+    ref_net = RefBiLSTM(size, proj_size, prefix='lstm_')
+    net.initialize()
+    ref_net.initialize()
+    net_params = net.collect_params()
+    ref_net_params = ref_net.collect_params()
+    for k in weights:
+        net_params[k].set_data(weights[k])
+        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 
'r0l0')].set_data(weights[k])
+
+    data = mx.random.uniform(shape=(11, 10, in_size))
+    assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())
+
 @with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_layer_bidirectional():
+    check_layer_bidirectional(7, 5, 0)
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_layer_bidirectional_proj():
+    check_layer_bidirectional(7, 5, 3)
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_rnn_layer_begin_state_type():
+    fake_data = nd.random.uniform(shape=(3, 5, 7), dtype='float16')
+    modeling_layer = gluon.rnn.LSTM(hidden_size=11, num_layers=2, dropout=0.2, 
bidirectional=True)
+    modeling_layer.cast('float16')
+    modeling_layer.initialize()
+    modeling_layer(fake_data)
+
+
 def test_gluon_ctc_consistency():
     loss = mx.gluon.loss.CTCLoss()
     data = mx.nd.arange(0, 4, repeat=40, 
ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index 02895cd..e329968 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -32,7 +32,7 @@ from numpy.testing import assert_allclose
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_not_satisfied
 from test_operator import *
 from test_optimizer import *
 from test_random import *
@@ -411,7 +411,7 @@ def test_batchnorm_versions():
 
 
 @with_seed(1234)
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_convolution_with_type():
     sym1 = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')
 
@@ -1364,7 +1364,7 @@ def check_rnn_consistency(cell1, cell2):
     assert_allclose(mod1.get_outputs()[0].asnumpy(), 
mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='rnn_relu', prefix='')
 
@@ -1376,7 +1376,7 @@ def test_rnn():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_forget_bias():
     forget_bias = 2.0
     fused = mx.rnn.FusedRNNCell(10, forget_bias=forget_bias, num_layers=2, 
mode='lstm', prefix='')
@@ -1398,7 +1398,7 @@ def test_lstm_forget_bias():
     assert_allclose(args[bias_name].asnumpy(), expected_bias)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='')
 
@@ -1410,7 +1410,7 @@ def test_gru():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_bidirectional():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='',
             bidirectional=True)
@@ -1429,7 +1429,7 @@ def test_bidirectional():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_unfuse():
     for mode in ['rnn_tanh', 'rnn_relu', 'lstm', 'gru']:
         fused = mx.rnn.FusedRNNCell(
@@ -1605,7 +1605,7 @@ def test_deformable_convolution_options():
                                                name='deformable_conv')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_residual_fused():
     cell = mx.rnn.ResidualCell(
             mx.rnn.FusedRNNCell(50, num_layers=3, mode='lstm',
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index f98bb79..abfba73 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -95,16 +95,17 @@ def random_seed(seed=None):
         random.seed(next_seed)
 
 
-def assert_raises_cudnn_disabled():
+def assert_raises_cudnn_not_satisfied(min_version):
     def test_helper(orig_test):
         @make_decorator(orig_test)
         def test_new(*args, **kwargs):
-            cudnn_disabled = (os.getenv('CUDNN_OFF_TEST_ONLY') == "true")
-            if not cudnn_disabled or mx.context.current_context().device_type 
== 'cpu':
+            cudnn_off = os.getenv('CUDNN_OFF_TEST_ONLY') == 'true'
+            cudnn_env_version = os.getenv('CUDNN_VERSION', None if cudnn_off 
else '7.3.1')
+            cudnn_test_disabled = cudnn_off or cudnn_env_version < min_version
+            if not cudnn_test_disabled or 
mx.context.current_context().device_type == 'cpu':
                 orig_test(*args, **kwargs)
             else:
-                errors = (MXNetError, RuntimeError)
-                assert_raises(errors, orig_test, *args, **kwargs)
+                assert_raises((MXNetError, RuntimeError), orig_test, *args, 
**kwargs)
         return test_new
     return test_helper
 
diff --git a/tests/python/unittest/test_gluon.py 
b/tests/python/unittest/test_gluon.py
index 02dc6ce..3049674 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,7 +23,8 @@ from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
-from common import setup_module, with_seed, assertRaises, teardown, 
assert_raises_cudnn_disabled
+from common import (setup_module, with_seed, assertRaises, teardown,
+                    assert_raises_cudnn_not_satisfied)
 import numpy as np
 from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
@@ -339,7 +340,7 @@ def test_symbol_block():
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
-    # Test case to verify if initializing the SymbolBlock from a model with 
params 
+    # Test case to verify if initializing the SymbolBlock from a model with 
params
     # other than fp32 param dtype.
 
     # 1. Load a resnet model, cast it to fp64 and export
@@ -1320,7 +1321,7 @@ def test_apply():
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_summary():
     net = gluon.model_zoo.vision.resnet50_v1()
     net.initialize()
diff --git a/tests/python/unittest/test_gluon_rnn.py 
b/tests/python/unittest/test_gluon_rnn.py
index c1d5f6a..bfe9592 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -22,7 +22,7 @@ import copy
 from numpy.testing import assert_allclose
 import unittest
 from mxnet.test_utils import almost_equal, assert_almost_equal
-from common import assert_raises_cudnn_disabled
+from common import assert_raises_cudnn_not_satisfied
 
 
 def test_rnn():
@@ -71,7 +71,7 @@ def test_lstm_forget_bias():
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), 
expected_bias)
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 
0.95215213],
@@ -243,7 +243,7 @@ def test_bidirectional():
     assert outs == [(10, 200), (10, 200), (10, 200)]
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_bidirectional():
     class RefBiLSTM(gluon.Block):
         def __init__(self, size, **kwargs):
@@ -278,7 +278,7 @@ def test_layer_bidirectional():
         net_params[k].set_data(weights[k])
         ref_net_params[k.replace('l0', 'l0l0').replace('r0', 
'r0l0')].set_data(weights[k])
 
-    data = mx.random.uniform(shape=(3, 10, in_size))
+    data = mx.random.uniform(shape=(11, 10, in_size))
     assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())
 
 
@@ -467,7 +467,7 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
         mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), 
rtol=1e-3, atol=1e-5)
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), 
mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
@@ -490,12 +490,11 @@ def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, 
dropout=0.5),
                             mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), 
run_only=True)
 
-    net = gluon.nn.HybridSequential()
+    net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
-    net.hybridize()
     net.collect_params().initialize()
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
@@ -592,7 +591,7 @@ def test_cell_fill_shape():
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 27d75d1..80a83df 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@ from distutils.version import LooseVersion
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_disabled, assertRaises
+from common import setup_module, with_seed, teardown, 
assert_raises_cudnn_not_satisfied, assertRaises
 import unittest
 
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, 
atol=1e-4):
@@ -72,7 +72,7 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_sym():
     T, N, I, H = 5, 32, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', 
get_next_state=True, prefix='')
@@ -86,7 +86,7 @@ def test_lstm_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_bidirectional():
     T, N, I, H = 5, 20, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
@@ -107,7 +107,7 @@ def test_lstm_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_sym():
     T, N, I, H = 5, 32, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', 
get_next_state=True, prefix='')
@@ -121,7 +121,7 @@ def test_gru_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_bidirectional():
     T, N, I, H = 5, 20, 800, 800
 
@@ -144,7 +144,7 @@ def test_gru_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_sym():
     T, N, I, H = 5, 32, 800, 800
 
@@ -159,7 +159,7 @@ def test_rnntanh_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_bidirectional():
     T, N, I, H = 5, 20, 800, 800
 
@@ -181,7 +181,7 @@ def test_rnntanh_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_sym():
     T, N, I, H = 5, 32, 200, 200
 
@@ -196,7 +196,7 @@ def test_rnnrelu_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_bidirectional():
     T, N, I, H = 5, 20, 200, 200
 
@@ -4775,7 +4775,7 @@ def test_index_copy():
 
     with mx.autograd.record():
         out = mx.nd.contrib.index_copy(x, index, t)
-    out.backward() 
+    out.backward()
 
     tensor = mx.nd.array([[1,2,3],[0,0,0],[7,8,9],[0,0,0],[4,5,6]])
     x_grad = mx.nd.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]])

Reply via email to