This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new 4a830db  [1.7.x] Backport of LSTM and GRU fix (#17898) and RNN op (#17632) (#18316)
4a830db is described below

commit 4a830db980f5d09b3065a267eae6a2950b456831
Author: bgawrych <[email protected]>
AuthorDate: Wed Jun 3 03:57:50 2020 +0200

    [1.7.x] Backport of LSTM and GRU fix (#17898) and RNN op (#17632) (#18316)
    
    * [Large Tensor] Backport of Fixed RNN op (#17632)
    
    * Changed relevant function args to index_t (see the overflow sketch below)
    
    * Added nightly test for RNN
    
    * Added fix for LSTM, GRU, RNN-ReLU, RNN-tanh
    
    * Using const instead of literals
    
    * Added nightly test for RNN ReLU & tanh, LSTM, GRU
    
    * Type assertion to force evaluation of output NDArray
    
    * Incorporated latest round of comments
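
    A minimal sketch (not part of the patch) of the 32-bit overflow that
    the index_t changes above guard against; int64_t stands in for mxnet's
    index_t, and the shape matches the nightly RNN test added below:

        #include <cstdint>
        #include <cstdio>

        int main() {
          // Shape (2^28, 4, 4) from the nightly test: 2^32 elements total.
          const int64_t T = int64_t{1} << 28, N = 4, I = 4;
          const int64_t size = T * N * I;          // index_t-style arithmetic
          // What a 32-bit int would hold after the same multiplication
          // (shown via explicit truncation to avoid signed-overflow UB):
          const int32_t truncated = static_cast<int32_t>(size);
          std::printf("64-bit size: %lld, as 32-bit int: %d\n",
                      static_cast<long long>(size), truncated);
          return 0;
        }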
    
    * [v1.7.x] Backport of Fix LSTM and GRU layers gradient calculations (#18203)
    
    * Fix input gradient calculation for bidirectional LSTM
    
    For bidirectional LSTM with number of layers > 2, the input gradient
    calculation was incorrect. The wrong results were caused by overwriting
    the y derivative (dy) tensor with the calculated x derivative (dx)
    tensor before the right2left layer could use dy for its own gradient
    calculations.
    The proposed fix uses additional workspace to avoid the overwrite; a
    toy illustration of the aliasing problem follows below.
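
    A minimal aliasing sketch (not the operator code; shapes and names are
    invented for illustration):

        #include <cstdio>
        #include <vector>

        // Stand-in for one direction of the backward pass: reads dy, writes dx.
        void toy_backward(const std::vector<double>& dy, std::vector<double>& dx) {
          for (size_t t = 0; t < dy.size(); ++t) dx[t] = 2.0 * dy[t];
        }

        int main() {
          // Buggy pattern: left2right writes dx straight over dy, so the
          // right2left direction afterwards reads clobbered values.
          std::vector<double> dy = {1, 2, 3, 4};
          toy_backward(dy, dy);                    // dy is now {2, 4, 6, 8}

          // Fixed pattern: a separate scratch buffer keeps dy intact for
          // the right2left pass, as the extra workspace in this patch does.
          std::vector<double> dy2 = {1, 2, 3, 4}, dy_tmp(4);
          toy_backward(dy2, dy_tmp);               // dy2 is still {1, 2, 3, 4}
          std::printf("clobbered dy[0]=%g, preserved dy[0]=%g\n", dy[0], dy2[0]);
          return 0;
        }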
    
    * Fix gradient calculation for GRU
    
    For GRU with number of layers > 2, the i2h_weight gradient for the
    middle layers (all except the first and the last) was incorrect. The
    wrong calculations were caused by assigning the output pointer to the
    input instead of computing a new input pointer; a pointer sketch
    follows below.
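
    A schematic sketch of the pointer bookkeeping (hypothetical sizes and
    a simplified flat layout; the buggy line mirrors the removed
    "y_tmp = y_l;" in the diff):

        #include <cstdio>

        int main() {
          // Flat buffer of per-layer outputs, one block of y_size per layer.
          const int L = 4, y_size = 4;
          double reserve[L * y_size] = {};
          double* y_l   = reserve + (L - 1) * y_size;  // output of current layer
          double* y_tmp = y_l - y_size;  // its input = previous layer's output

          for (int l = L - 1; l > 1; --l) {
            // ... the i2h_weight gradient of layer l would be computed
            // from y_tmp (the layer's input) here ...
            y_l = y_l - y_size;
            // Buggy: y_tmp = y_l;   // input aliases output for middle layers
            y_tmp = y_tmp - y_size;  // fixed: step the input pointer itself
          }
          std::printf("layer-1 input starts at offset %td\n", y_tmp - reserve);
          return 0;
        }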
    
    * Enable tests for GRU and LSTM gradients
    
    * Fix comments
    
    * Change loop iteration deduction
    
    * Add more test cases for fused rnn layers
    
    Co-authored-by: Connor Goggins <[email protected]>
---
 src/operator/rnn-inl.h                  |  44 ++--
 src/operator/rnn_impl.h                 | 415 ++++++++++++++++----------------
 tests/nightly/test_large_array.py       |  36 ++-
 tests/python/unittest/test_gluon_rnn.py |  90 +++----
 4 files changed, 298 insertions(+), 287 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 557c111..b9cee10 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -63,7 +63,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   bool bidirectional, state_outputs;
   int mode;
   float p;
-  int seq_length_, batch_size_, input_size_;
+  index_t seq_length_, batch_size_, input_size_;
 
   bool use_sequence_length;
   dmlc::optional<int> projection_size;
@@ -122,8 +122,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   }
 };
 
-inline int GetRnnParamSize(int num_layer,
-                           int input_size,
+inline index_t GetRnnParamSize(int num_layer,
+                           index_t input_size,
                            int state_size,
                            int direction,
                            int mode,
@@ -140,14 +140,14 @@ inline int GetRnnParamSize(int num_layer,
       size *= 3;
       break;
   }
-  int size1 = (input_size + state_size + 2) * size;  // first layer size
-  int size2 = (state_size * direction + state_size + 2) * size;  // other layers size
+  index_t size1 = (input_size + state_size + 2) * size;  // first layer size
+  index_t size2 = (state_size * direction + state_size + 2) * size;  // other layers size
   if (projection_size.has_value()) {
-    int proj_size = projection_size.value();
+    index_t proj_size = projection_size.value();
     size1 = (input_size + proj_size + 2) * size;
     size2 = (proj_size * direction + proj_size + 2) * size;
   }
-  int param_size = size1 + (num_layer - 1) * size2;
+  index_t param_size = size1 + (num_layer - 1) * size2;
   if (projection_size.has_value()) {
     param_size += projection_size.value() * state_size * num_layer * direction;
   }
@@ -182,8 +182,8 @@ inline int GetRnnBiasSize(int num_layer,
  *  - output -> h[t](, c[t] additionally with Lstm) time by time(sz: NxH(x2))
  *  - intermediate y[1...T] as next layer's inputs(sz: TxNxHxD)
  */
-inline size_t GetRNNWorkspaceSize(int seq_length,
-                                  int batch_size,
+inline size_t GetRNNWorkspaceSize(index_t seq_length,
+                                  index_t batch_size,
                                   int hidden_size,
                                   int projection_size,
                                   int direction,
@@ -193,7 +193,9 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
     case rnn_enum::kLstm:
      size = seq_length * batch_size * hidden_size * (4 + direction) +  // wx*x + inter-y
          batch_size * hidden_size * 6 +                                // wh*h + h + c
-          seq_length * hidden_size * 8;                    // Used in Backward, Δbx, Δbh
+          seq_length * hidden_size * 8 +                   // Used in Backward, Δbx, Δbh
+          // temporary dy in backward computation for bidirectional layers
+          seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0);
       break;
     case rnn_enum::kGru:
       // Differs with Lstm, the outputs of three gates are also held in memory
@@ -214,8 +216,8 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
 
 inline size_t GetRNNReserveSpaceSize(int num_layer,
                                      int direction,
-                                     int seq_length,
-                                     int batch_size,
+                                     index_t seq_length,
+                                     index_t batch_size,
                                      int hidden_size,
                                      int mode) {
   size_t size = 0;
@@ -279,9 +281,9 @@ void RNNForwardTraining(DType* ws,
                         bool state_outputs,
                         const int num_layers,
                         const int direction,
-                        const int seq_length,
-                        const int batch_size,
-                        const int input_size,
+                        const index_t seq_length,
+                        const index_t batch_size,
+                        const index_t input_size,
                         const int state_size,
                         DType* x_ptr,
                         DType* hx_ptr,
@@ -321,9 +323,9 @@ void RNNForwardInference(DType* ws,
                          bool state_outputs,
                          const int num_layers,
                          const int direction,
-                         const int seq_length,
-                         const int batch_size,
-                         const int input_size,
+                         const index_t seq_length,
+                         const index_t batch_size,
+                         const index_t input_size,
                          const int state_size,
                          const int projection_size,
                          DType* x_ptr,
@@ -363,9 +365,9 @@ void RNNBackward(DType* ws,
                  DType* rs,
                  const int num_layers,
                  const int direction,
-                 const int seq_length,
-                 const int batch_size,
-                 const int input_size,
+                 const index_t seq_length,
+                 const index_t batch_size,
+                 const index_t input_size,
                  const int state_size,
                  DType* x_ptr,
                  DType* hx_ptr,
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index 008ba7d..9f95185 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -59,9 +59,9 @@ void LstmForwardTrainingSingleLayer(DType* ws,
                                     DType* rs,
                                     bool state_outputs,
                                     bool bid,
-                                    const int T,
-                                    const int N,
-                                    const int I,
+                                    const index_t T,
+                                    const index_t N,
+                                    const index_t I,
                                     const int H,
                                     const Tensor<cpu, 2, DType> &x,
                                     const Tensor<cpu, 2, DType> &hx,
@@ -88,17 +88,17 @@ void LstmForwardTrainingSingleLayer(DType* ws,
   const int offset = bid ? H : 0;
   const DType alpha = 1.0;
   const DType beta = 0.0;
-  const int cell_size = N * H;
+  const index_t cell_size = N * H;
   linalg_gemm(x, wx, yx_flat, alpha, beta, false, true);
 
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-  for (int i = 0; i < T; ++i) {
-    int t = bid ? T - 1 - i : i;
+  for (index_t i = 0; i < T; ++i) {
+    index_t t = bid ? T - 1 - i : i;
     linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
     #pragma omp parallel for num_threads(omp_threads)
-    for (int jk = 0; jk < cell_size; ++jk) {
-      int j = jk / H;
-      int k = jk % H;
+    for (index_t jk = 0; jk < cell_size; ++jk) {
+      index_t j = jk / H;
+      index_t k = jk % H;
      DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]);
      DType ft = sigmoid<DType>(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]);
      DType gt =           tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]);
@@ -127,9 +127,9 @@ void LstmForwardTraining(DType* ws,
                          bool state_outputs,
                          const int L,
                          const int D,
-                         const int T,
-                         const int N,
-                         const int I,
+                         const index_t T,
+                         const index_t N,
+                         const index_t I,
                          const int H,
                          DType* x_ptr,
                          DType* hx_ptr,
@@ -145,16 +145,16 @@ void LstmForwardTraining(DType* ws,
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
-  const int b_size = 2 * H * 4;
-  const int r_size = D * T * N * H * 6;
-  const int y_offset = T * N * H * 5;
-  const int cell_size = N * H;
+  const index_t b_size = 2 * H * 4;
+  const index_t r_size = D * T * N * H * 6;
+  const index_t y_offset = T * N * H * 5;
+  const index_t cell_size = N * H;
   unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   int idx = 0;  // state & cell state's idx;
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int i = 0; i < L; ++i) {
-    const int input_size = i ? H * D : I;
-    const int w_size = (input_size + H) * H * 4;
+    const index_t input_size = i ? H * D : I;
+    const index_t w_size = (input_size + H) * H * 4;
     Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
     Tensor<cpu, 3, DType> y(rs2 + y_offset, Shape3(T, N, H * D));
    LstmForwardTrainingSingleLayer<DType>(ws, rs2, state_outputs, false, T, N, input_size, H, x,
@@ -175,7 +175,7 @@ void LstmForwardTraining(DType* ws,
       b_ptr += b_size;
       if (dropout > 0.0f) {
         #pragma omp parallel for num_threads(omp_threads)
-        for (int j = 0; j < T * N * H * D; j++) {
+        for (index_t j = 0; j < T * N * H * D; j++) {
           int rand_data = rand_r(&seed_);
          if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
             dropout_random[i * T * N * H * D + j] = 0;
@@ -196,7 +196,7 @@ void LstmForwardTraining(DType* ws,
     }
   }
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < T * N * H * D; ++i) {
+  for (index_t i = 0; i < T * N * H * D; ++i) {
     y_ptr[i] = (rs2 + y_offset)[i];
   }
 }
@@ -205,9 +205,9 @@ template<typename DType>
 void LstmForwardInferenceSingleLayer(DType* ws,
                                      bool state_outputs,
                                      bool bid,
-                                     const int T,
-                                     const int N,
-                                     const int I,
+                                     const index_t T,
+                                     const index_t N,
+                                     const index_t I,
                                      const int H,
                                      const int P,
                                      const Tensor<cpu, 2, DType> &x,
@@ -237,19 +237,19 @@ void LstmForwardInferenceSingleLayer(DType* ws,
   const int proj_offset = bid ? P : 0;
   const DType alpha = 1.0;
   const DType beta = 0.0;
-  const int cell_size = N * H;
+  const index_t cell_size = N * H;
   linalg_gemm(x, wx, yx_flat, alpha, beta, false, true);
 
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-  for (int i = 0; i < T; ++i) {
-    int t = bid ? T - 1 - i : i;
+  for (index_t i = 0; i < T; ++i) {
+    index_t t = bid ? T - 1 - i : i;
     if (P > 0) {
       linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true);
     } else {
       linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
     }
     #pragma omp parallel for num_threads(omp_threads)
-    for (int jk = 0; jk < cell_size; ++jk) {
+    for (index_t jk = 0; jk < cell_size; ++jk) {
       int j = jk / H;
       int k = jk % H;
      DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]);
@@ -282,9 +282,9 @@ void LstmForwardInference(DType* ws,
                           bool state_outputs,
                           const int L,
                           const int D,
-                          const int T,
-                          const int N,
-                          const int I,
+                          const index_t T,
+                          const index_t N,
+                          const index_t I,
                           const int H,
                           const int P,
                           DType* x_ptr,
@@ -298,16 +298,16 @@ void LstmForwardInference(DType* ws,
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, P ? P : H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
-  const int b_size = 2 * H * 4;
-  const int cell_size = N * H;
-  const int projection_size = (P ? P : H) * N;
+  const index_t b_size = 2 * H * 4;
+  const index_t cell_size = N * H;
+  const index_t projection_size = (P ? P : H) * N;
   DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2;
   DType* y_cur_ptr = y_ptr;
   int idx = 0;  // state & cell state's idx;
   bool flag = L % 2 ? false : true;
   for (int i = 0; i < L; ++i) {
-    const int input_size = i ? (P ? P : H) * D : I;
-    int w_size = (input_size + (P ? P : H)) * H * 4;
+    const index_t input_size = i ? (P ? P : H) * D : I;
+    index_t w_size = (input_size + (P ? P : H)) * H * 4;
     if (P > 0) {
       w_size += P * H;
     }
@@ -351,9 +351,9 @@ void LstmBackwardSingleLayer(DType* ws,
                              DType* rs,
                              DType* tmp_buf,
                              bool bid,
-                             const int T,
-                             const int N,
-                             const int I,
+                             const index_t T,
+                             const index_t N,
+                             const index_t I,
                              const int H,
                              const Tensor<cpu, 2, DType> &x,
                              const Tensor<cpu, 2, DType> &hx,
@@ -403,41 +403,41 @@ void LstmBackwardSingleLayer(DType* ws,
   const DType beta0 = 0.0;
   const DType beta1 = 1.0;
   const DType beta2 = 2.0;
-  const int cell_size = N * H;
+  const index_t cell_size = N * H;
   if (dhy_ptr != nullptr) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < cell_size; ++i) {
+    for (index_t i = 0; i < cell_size; ++i) {
       dh.dptr_[i] = dhy_ptr[i];
     }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < cell_size; ++i) {
+    for (index_t i = 0; i < cell_size; ++i) {
       dh.dptr_[i] = 0;
     }
   }
   if (dcy_ptr != nullptr) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < cell_size; ++i) {
+    for (index_t i = 0; i < cell_size; ++i) {
       dc.dptr_[i] = dcy_ptr[i];
     }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < cell_size; ++i) {
+    for (index_t i = 0; i < cell_size; ++i) {
       dc.dptr_[i] = 0;
     }
   }
 
-  for (int i = T - 1; i >= 0; --i) {
-    int t = bid ? T - 1 - i : i;
-    int tnext = bid ? t + 1 : t - 1;
+  for (index_t i = T - 1; i >= 0; --i) {
+    index_t t = bid ? T - 1 - i : i;
+    index_t tnext = bid ? t + 1 : t - 1;
     const Tensor<cpu, 2, DType>& dhnext = i ? dh : dhx;
     const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx;
     const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx;
     const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int jk = 0; jk < cell_size; ++jk) {
-      int j = jk / H;
-      int k = jk % H;
+    for (index_t jk = 0; jk < cell_size; ++jk) {
+      index_t j = jk / H;
+      index_t k = jk % H;
       DType tc = tanh(c[i][j][k]);
       DType it = ifgo[i][j][k][0];
       DType ft = ifgo[i][j][k][1];
@@ -480,13 +480,13 @@ void LstmBackwardSingleLayer(DType* ws,
   if (req_params != kNullOp && req_params != kAddTo) {
     linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
   }
-  const int row = T * N;
-  const int col = H * 4;
+  const index_t row = T * N;
+  const index_t col = H * 4;
   if (req_params != kNullOp) {
     if (req_params != kAddTo) {
-      for (int i = 0; i < row; ++i) {
+      for (index_t i = 0; i < row; ++i) {
         #pragma omp parallel for num_threads(omp_threads)
-        for (int j = 0; j < col; ++j) {
+        for (index_t j = 0; j < col; ++j) {
           dbx[j] += dyx[i][j];
           dbh[j] = dbx[j];
         }
@@ -495,20 +495,20 @@ void LstmBackwardSingleLayer(DType* ws,
       const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T));
       const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T));
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < col * T; ++i) {
+      for (index_t i = 0; i < col * T; ++i) {
         tmp_dbx.dptr_[i] = 0;
         tmp_dbh.dptr_[i] = 0;
       }
-      for (int t = T - 1; t >= 0; --t) {
+      for (index_t t = T - 1; t >= 0; --t) {
         #pragma omp parallel for num_threads(omp_threads)
-        for (int j = 0; j < col; ++j) {
-          for (int i = 0; i < N; ++i) {
+        for (index_t j = 0; j < col; ++j) {
+          for (index_t i = 0; i < N; ++i) {
             tmp_dbx[j][t] += dyx[t * N + i][j];
             tmp_dbh[j][t] = tmp_dbx[j][t];
           }
         }
         #pragma omp parallel for num_threads(omp_threads)
-        for (int j = 0; j < col; ++j) {
+        for (index_t j = 0; j < col; ++j) {
           dbx[j] += tmp_dbx[j][t] + dbx[j];
           dbh[j] += tmp_dbh[j][t] + dbh[j];
         }
@@ -522,9 +522,9 @@ void LstmBackward(DType* ws,
                   DType* rs,
                   const int L,
                   const int D,
-                  const int T,
-                  const int N,
-                  const int I,
+                  const index_t T,
+                  const index_t N,
+                  const index_t I,
                   const int H,
                   DType* x_ptr,
                   DType* hx_ptr,
@@ -553,16 +553,17 @@ void LstmBackward(DType* ws,
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> dhx(dhx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> dcx(dcx_ptr, Shape3(total_layers, N, H));
-  const int b_size = 2 * H * 4;
-  const int r_size = D * T * N * H * 6;
-  const int y_offset = T * N * H * 5;
-  const int w_size1 = (I + H) * H * 4;      // first layer
-  const int w_size2 = (D * H + H) * H * 4;  // other layers
-  const int cell_size = N * H;
+  const index_t b_size = 2 * H * 4;
+  const index_t r_size = D * T * N * H * 6;
+  const index_t y_offset = T * N * H * 5;
+  const index_t w_size1 = (I + H) * H * 4;      // first layer
+  const index_t w_size2 = (D * H + H) * H * 4;  // other layers
+  const index_t cell_size = N * H;
+  const index_t y_size = T * N * H * D;
   DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3;
   for (int i = L - 1; i >= 0; --i) {
-    const int input_size = i ? H * D : I;
-    const int w_size = i ? w_size2 : w_size1;
+    const index_t input_size = i ? H * D : I;
+    const index_t w_size = i ? w_size2 : w_size1;
     int idx = i * D;
     DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr;
    DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr;
@@ -589,12 +590,16 @@ void LstmBackward(DType* ws,
                                     x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
                                     req_data, req_params, req_state, req_statecell);
+
+      // Prevent overwriting dy while calculating dx in left2right layer
+      const int loop_iteration = (L - 1) - i;
+      dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size : dy_tmp_ptr + y_size;
     }
     if (dropout > 0.0f && i > 0 && req_data != kNullOp) {
       dropout_random = dropout_random - T * N * D * H;
      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
       #pragma omp parallel for num_threads(omp_threads)
-      for (int j = 0; j < T * N * D * H; j++) {
+      for (index_t j = 0; j < T * N * D * H; j++) {
         if (dropout_random[j] == 0) {
           dx.dptr_[j] = 0;
         } else {
@@ -611,9 +616,9 @@ void GruForwardInferenceSingleLayer(DType* ws,
                                     DType* tmp_buf,
                                     bool state_outputs,
                                     const int D,
-                                    const int T,
-                                    const int N,
-                                    const int I,
+                                    const index_t T,
+                                    const index_t N,
+                                    const index_t I,
                                     const int H,
                                     const Tensor<cpu, 2, DType> &x,
                                     const Tensor<cpu, 2, DType> &hx,
@@ -650,13 +655,13 @@ void GruForwardInferenceSingleLayer(DType* ws,
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   if (D == 1) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * H + j] = hx[i][j];
       }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * D * H + j] = hx[i][j];
         back_ht_1[i * D * H + j] = hx[N + i][j];
@@ -674,7 +679,7 @@ void GruForwardInferenceSingleLayer(DType* ws,
     linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
   }
 
-  for (int t = 0; t < T; t++) {
+  for (index_t t = 0; t < T; t++) {
     //  perform the first direction, X * wx and H * wh for each step
     //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
     Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
@@ -688,11 +693,11 @@ void GruForwardInferenceSingleLayer(DType* ws,
     }
     gemmC1_t = gemmC1 + t * N * 3 * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
-        int rtb = i * 3 * H;
-        int ztb = i * 3 * H + H;
-        int ntb = i * 3 * H + 2 * H;
+        index_t rtb = i * 3 * H;
+        index_t ztb = i * 3 * H + H;
+        index_t ntb = i * 3 * H + 2 * H;
         rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
             + bx[0][j] + bh[0][j]);
         zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
@@ -715,11 +720,11 @@ void GruForwardInferenceSingleLayer(DType* ws,
      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
 
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int rtb = i * 3 * H;
-          int ztb = i * 3 * H + H;
-          int ntb = i * 3 * H + 2 * H;
+          index_t rtb = i * 3 * H;
+          index_t ztb = i * 3 * H + H;
+          index_t ntb = i * 3 * H + 2 * H;
           rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
               gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
           zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
@@ -739,7 +744,7 @@ void GruForwardInferenceSingleLayer(DType* ws,
     if (D == 1) {
       DType* y_start = y_ptr + (T - 1) * N * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * H + j];
         }
@@ -747,7 +752,7 @@ void GruForwardInferenceSingleLayer(DType* ws,
       DType* y_start = y_ptr + (T - 1) * N * H * D;
       DType* y_back_start = y_ptr + H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * D * H + j];
           hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
@@ -761,9 +766,9 @@ void GruForwardInference(DType* ws,
                          bool state_outputs,
                          const int L,
                          const int D,
-                         const int T,
-                         const int N,
-                         int I,
+                         const index_t T,
+                         const index_t N,
+                         index_t I,
                          const int H,
                          DType* x_ptr,
                          DType* hx_ptr,
@@ -814,9 +819,9 @@ void GruForwardTrainingSingleLayer(DType* ws,
                                    DType* tmp_buf,
                                    bool state_outputs,
                                    const int D,
-                                   const int T,
-                                   const int N,
-                                   const int I,
+                                   const index_t T,
+                                   const index_t N,
+                                   const index_t I,
                                    const int H,
                                    const Tensor<cpu, 2, DType> &x,
                                    const Tensor<cpu, 2, DType> &hx,
@@ -862,13 +867,13 @@ void GruForwardTrainingSingleLayer(DType* ws,
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   if (D == 1) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * H + j] = hx[i][j];
       }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * D * H + j] = hx[i][j];
         back_ht_1[i * D * H + j] = hx[N + i][j];
@@ -887,7 +892,7 @@ void GruForwardTrainingSingleLayer(DType* ws,
     linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
   }
 
-  for (int t = 0; t < T; t++) {
+  for (index_t t = 0; t < T; t++) {
     //  perform the first direction, X * wx and H * wh for each step
     //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
     Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
@@ -905,11 +910,11 @@ void GruForwardTrainingSingleLayer(DType* ws,
     gemmC1_t = gemmC1 + t * N * 3 * H;
     DType* Mnht = Mnh + t * N * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
-        int rtb = i * 3 * H;
-        int ztb = i * 3 * H + H;
-        int ntb = i * 3 * H + 2 * H;
+        index_t rtb = i * 3 * H;
+        index_t ztb = i * 3 * H + H;
+        index_t ntb = i * 3 * H + 2 * H;
         Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j];
         rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
             + bx[0][j] + bh[0][j]);
@@ -937,11 +942,11 @@ void GruForwardTrainingSingleLayer(DType* ws,
 
       DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int rtb = i * 3 * H;
-          int ztb = i * 3 * H + H;
-          int ntb = i * 3 * H + 2 * H;
+          index_t rtb = i * 3 * H;
+          index_t ztb = i * 3 * H + H;
+          index_t ntb = i * 3 * H + 2 * H;
           back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j];
           rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
               gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
@@ -963,7 +968,7 @@ void GruForwardTrainingSingleLayer(DType* ws,
     if (D == 1) {
       DType* y_start = y_ptr + (T - 1) * N * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * H + j];
         }
@@ -971,7 +976,7 @@ void GruForwardTrainingSingleLayer(DType* ws,
       DType* y_start = y_ptr + (T - 1) * N * H * D;
       DType* y_back_start = y_ptr + H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * D * H + j];
           hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
@@ -986,9 +991,9 @@ void GruForwardTraining(DType* ws,
                         bool state_outputs,
                         const int L,
                         const int D,
-                        const int T,
-                        const int N,
-                        int I,
+                        const index_t T,
+                        const index_t N,
+                        index_t I,
                         const int H,
                         DType* x_ptr,
                         DType* hx_ptr,
@@ -1025,7 +1030,7 @@ void GruForwardTraining(DType* ws,
     if (dropout > 0.0f && l > 0) {
      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * I; i++) {
+      for (index_t i = 0; i < T * N * I; i++) {
         int rand_data = rand_r(&seed_);
        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
           dropout_random[(l - 1) * T * N * I + i] = 0;
@@ -1057,7 +1062,7 @@ void GruForwardTraining(DType* ws,
   }
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < T * N * H * D; ++i) {
+  for (index_t i = 0; i < T * N * H * D; ++i) {
     y_ptr[i] = y_l[i];
   }
 }
@@ -1066,9 +1071,9 @@ template <typename DType>
 void GruBackwardSingleLayer(DType* ws,
                             DType* tmp_buf,
                             const int D,
-                            const int T,
-                            const int N,
-                            const int I,
+                            const index_t T,
+                            const index_t N,
+                            const index_t I,
                             const int H,
                             const Tensor<cpu, 2, DType> &x,
                             const Tensor<cpu, 2, DType> &hx,
@@ -1134,7 +1139,7 @@ void GruBackwardSingleLayer(DType* ws,
     }
   }
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < N * H; ++i) {
+  for (index_t i = 0; i < N * H; ++i) {
     if (dhy_ptr) {
       dht1[i] = dhy_ptr[i];
     } else {
@@ -1143,7 +1148,7 @@ void GruBackwardSingleLayer(DType* ws,
   }
 
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < N; ++i) {
+  for (index_t i = 0; i < N; ++i) {
     for (int j = 0; j < H; ++j) {
       hx_[i * D * H + j] = hx[i][j];
     }
@@ -1151,7 +1156,7 @@ void GruBackwardSingleLayer(DType* ws,
 
   if (D == 2) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N * H; ++i) {
+    for (index_t i = 0; i < N * H; ++i) {
       if (dhy_ptr) {
         back_dht1[i] = dhy_ptr[N * H + i];
       } else {
@@ -1159,13 +1164,13 @@ void GruBackwardSingleLayer(DType* ws,
       }
     }
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
         hx_[i * D * H + H + j] = hx[N + i][j];
       }
     }
   }
-  for (int t = T - 1; t >= 0; --t) {
+  for (index_t t = T - 1; t >= 0; --t) {
     if (t) {
       ht1 = y_ptr + (t - 1) * N * D * H;
     } else {
@@ -1175,7 +1180,7 @@ void GruBackwardSingleLayer(DType* ws,
     dyt = dy_ptr + t * N * D * H;
 
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
         dht1[i * H + j] += dyt[i * D * H + j];
       }
@@ -1188,7 +1193,7 @@ void GruBackwardSingleLayer(DType* ws,
     dat = da + t * N * 3 * H;
     dart = dar + t * N * 3 * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
         int nid = i * 3 * H + 2 * H + j;
         int zid = i * 3 * H + H + j;
@@ -1234,7 +1239,7 @@ void GruBackwardSingleLayer(DType* ws,
     if (req_params != kAddTo) {
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < 3 * H; ++i) {
-        for (int j = 0; j < N * T; ++j) {
+        for (index_t j = 0; j < N * T; ++j) {
           dbx[i] += da[j * 3 * H + i];
           dbh[i] += dar[j * 3 * H + i];
         }
@@ -1243,15 +1248,15 @@ void GruBackwardSingleLayer(DType* ws,
      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < H * T * 3; ++i) {
+      for (index_t i = 0; i < H * T * 3; ++i) {
         tmp_dbx.dptr_[i] = 0;
         tmp_dbh.dptr_[i] = 0;
       }
 
-      for (int t = T - 1; t >= 0; --t) {
+      for (index_t t = T - 1; t >= 0; --t) {
         #pragma omp parallel for num_threads(omp_threads)
         for (int i = 0; i < 3 * H; ++i) {
-          for (int j = 0; j < N; ++j) {
+          for (index_t j = 0; j < N; ++j) {
             tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
             tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
           }
@@ -1281,7 +1286,7 @@ void GruBackwardSingleLayer(DType* ws,
   }
 
   if (D == 2) {
-    for (int t = 0; t < T; ++t) {
+    for (index_t t = 0; t < T; ++t) {
       if (t == T-1) {
         back_ht1 = hx_;
       } else {
@@ -1291,7 +1296,7 @@ void GruBackwardSingleLayer(DType* ws,
       //  add dy[T, N, D, H] to dhy[D, N, H]
       dyt = dy_ptr + t * N * D * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
           back_dht1[i * H + j] += dyt[i * D * H + H + j];
         }
@@ -1305,12 +1310,12 @@ void GruBackwardSingleLayer(DType* ws,
       dart = dar + t * N * 3 * H;
 
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int nid = i * 3 * H + 2 * H + j;
-          int zid = i * 3 * H + H + j;
-          int rid = i * 3 * H + j;
-          int id = i * H + j;
+          index_t nid = i * 3 * H + 2 * H + j;
+          index_t zid = i * 3 * H + H + j;
+          index_t rid = i * 3 * H + j;
+          index_t id = i * H + j;
           dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
           dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] -
               nt[id]) * zt[id] * (1 - zt[id]);
@@ -1352,7 +1357,7 @@ void GruBackwardSingleLayer(DType* ws,
       if (req_params != kAddTo) {
         #pragma omp parallel for num_threads(omp_threads)
         for (int i = 0; i < 3 * H; ++i) {
-          for (int j = 0; j < N * T; ++j) {
+          for (index_t j = 0; j < N * T; ++j) {
             back_dbx[i] += da[j * 3 * H + i];
             back_dbh[i] += dar[j * 3 * H + i];
           }
@@ -1361,14 +1366,14 @@ void GruBackwardSingleLayer(DType* ws,
        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
         #pragma omp parallel for num_threads(omp_threads)
-        for (int i = 0; i < H * T * 3; ++i) {
+        for (index_t i = 0; i < H * T * 3; ++i) {
           tmp_dbx.dptr_[i] = 0;
           tmp_dbh.dptr_[i] = 0;
         }
-        for (int t = T - 1; t >= 0; --t) {
+        for (index_t t = T - 1; t >= 0; --t) {
           #pragma omp parallel for num_threads(omp_threads)
           for (int i = 0; i < 3 * H; ++i) {
-            for (int j = 0; j < N; ++j) {
+            for (index_t j = 0; j < N; ++j) {
               tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
               tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
             }
@@ -1399,7 +1404,7 @@ void GruBackwardSingleLayer(DType* ws,
   }
   if (req_state != kNullOp) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N * H * D; ++i) {
+    for (index_t i = 0; i < N * H * D; ++i) {
       dhx[i] = dht1[i];
     }
   }
@@ -1410,9 +1415,9 @@ void GruBackward(DType* ws,
                  DType* rs,
                  const int L,
                  const int D,
-                 const int T,
-                 const int N,
-                 int I,
+                 const index_t T,
+                 const index_t N,
+                 index_t I,
                  const int H,
                  DType* x_ptr,
                  DType* hx_ptr,
@@ -1464,7 +1469,7 @@ void GruBackward(DType* ws,
   DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
   DType* dy_l = dy_ptr;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
-  int inputsize = I;
+  index_t inputsize = I;
   DType* y_tmp = y_l - T * N * H * D;
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int l = L - 1; l >= 0; --l) {
@@ -1483,7 +1488,7 @@ void GruBackward(DType* ws,
     if (dropout > 0.0f && l > 0 && req_data != kNullOp) {
       dropout_random = dropout_random - T * N * D * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * I; i++) {
+      for (index_t i = 0; i < T * N * I; i++) {
         if (dropout_random[i] == 0) {
           dx_l[i] = 0;
         } else {
@@ -1493,7 +1498,7 @@ void GruBackward(DType* ws,
     }
     if (l > 0) {
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * H * D; ++i) {
+      for (index_t i = 0; i < T * N * H * D; ++i) {
         dy_l[i] = dx_l[i];
       }
       gateR_l = gateR_l - T * D * N * H;
@@ -1504,7 +1509,7 @@ void GruBackward(DType* ws,
       if (dhy_l)
         dhy_l = dhy_l - D * N * H;
       y_l = y_l - T * N * H * D;
-      y_tmp = y_l;
+      y_tmp = y_tmp - T * N * H * D;
       if (l == 1) {
         wx_l = wx_l - (inputsize + H) * H * 3 * D;
         wh_l = wx_l + inputsize * 3 * H;
@@ -1527,9 +1532,9 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
                                            DType* tmp_buf,
                                            bool state_outputs,
                                            const int D,
-                                           const int T,
-                                           const int N,
-                                           const int I,
+                                           const index_t T,
+                                           const index_t N,
+                                           const index_t I,
                                            const int H,
                                            const Tensor<cpu, 2, DType> &x,
                                            const Tensor<cpu, 2, DType> &hx,
@@ -1564,13 +1569,13 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   if (D == 1) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * H + j] = hx[i][j];
       }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * D * H + j] = hx[i][j];
         back_ht_1[i * D * H + j] = hx[N + i][j];
@@ -1588,7 +1593,7 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
     linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
   }
 
-  for (int t = 0; t < T; t++) {
+  for (index_t t = 0; t < T; t++) {
     //  perform the first direction, X * wx and H * wh for each step
     //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
     Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
@@ -1602,9 +1607,9 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
     }
     gemmC1_t = gemmC1 + t * N * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
-        int tb = i * H;
+        index_t tb = i * H;
         if (mode == 1) {
           ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
               gemmC2[tb + j] + bh[0][j]);
@@ -1626,9 +1631,9 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
 
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int tb = i * H;
+          index_t tb = i * H;
           if (mode == 1) {
             back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
                 + gemmC2[tb + j] + back_bh[0][j]);
@@ -1647,7 +1652,7 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
     if (D == 1) {
       DType* y_start = y_ptr + (T - 1) * N * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * H + j];
         }
@@ -1655,7 +1660,7 @@ void VanillaRNNForwardInferenceSingleLayer(DType* ws,
       DType* y_start = y_ptr + (T - 1) * N * H * D;
       DType* y_back_start = y_ptr + H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * D * H + j];
           hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
@@ -1669,9 +1674,9 @@ void VanillaRNNForwardInference(DType* ws,
                                 bool state_outputs,
                                 const int L,
                                 const int D,
-                                const int T,
-                                const int N,
-                                int I,
+                                const index_t T,
+                                const index_t N,
+                                index_t I,
                                 const int H,
                                 DType* x_ptr,
                                 DType* hx_ptr,
@@ -1724,9 +1729,9 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
                                        DType* tmp_buf,
                                        bool state_outputs,
                                        const int D,
-                                       const int T,
-                                       const int N,
-                                       const int I,
+                                       const index_t T,
+                                       const index_t N,
+                                       const index_t I,
                                        const int H,
                                        const Tensor<cpu, 2, DType> &x,
                                        const Tensor<cpu, 2, DType> &hx,
@@ -1765,13 +1770,13 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   if (D == 1) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * H + j] = hx[i][j];
       }
   } else {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; i++)
+    for (index_t i = 0; i < N; i++)
       for (int j = 0; j < H; j++) {
         y_ptr[i * D * H + j] = hx[i][j];
         back_ht_1[i * D * H + j] = hx[N + i][j];
@@ -1790,7 +1795,7 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
     linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
   }
 
-  for (int t = 0; t < T; t++) {
+  for (index_t t = 0; t < T; t++) {
     //  perform the first direction, X * wx and H * wh for each step
     //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
     Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
@@ -1805,9 +1810,9 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
     nt = gateN + t * N * H;
     gemmC1_t = gemmC1 + t * N * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
-        int tb = i * H;
+        index_t tb = i * H;
         if (mode == 1) {
           nt[tb + j] = ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
               gemmC2[tb + j] + bh[0][j]);
@@ -1829,9 +1834,9 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
       dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int tb = i * H;
+          index_t tb = i * H;
           if (mode == 1) {
            nt[tb + j] = back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
                 + gemmC2[tb + j] + back_bh[0][j]);
@@ -1851,7 +1856,7 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
     if (D == 1) {
       DType* y_start = y_ptr + (T - 1) * N * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * H + j];
         }
@@ -1859,7 +1864,7 @@ void VanillaRNNForwardTrainingSingleLayer(DType* ws,
       DType* y_start = y_ptr + (T - 1) * N * H * D;
       DType* y_back_start = y_ptr + H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; i++)
+      for (index_t i = 0; i < N; i++)
         for (int j = 0; j < H; j++) {
           hy_ptr[i * H + j] = y_start[i * D * H + j];
           hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
@@ -1874,9 +1879,9 @@ void VanillaRNNForwardTraining(DType* ws,
                                bool state_outputs,
                                const int L,
                                const int D,
-                               const int T,
-                               const int N,
-                               int I,
+                               const index_t T,
+                               const index_t N,
+                               index_t I,
                                const int H,
                                DType* x_ptr,
                                DType* hx_ptr,
@@ -1911,7 +1916,7 @@ void VanillaRNNForwardTraining(DType* ws,
     }
     if (dropout > 0.0f && l > 0) {
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * I; i++) {
+      for (index_t i = 0; i < T * N * I; i++) {
         int rand_data = rand_r(&seed_);
        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
           dropout_random[(l - 1) * T * N * I + i] = 0;
@@ -1939,7 +1944,7 @@ void VanillaRNNForwardTraining(DType* ws,
     wh_l = wx_l + I * H;
   }
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < T * N * H * D; ++i) {
+  for (index_t i = 0; i < T * N * H * D; ++i) {
     y_ptr[i] = y_l[i];
   }
 }
@@ -1948,9 +1953,9 @@ template <typename DType>
 void VanillaRNNBackwardSingleLayer(DType* ws,
                                    DType* tmp_buf,
                                    const int D,
-                                   const int T,
-                                   const int N,
-                                   const int I,
+                                   const index_t T,
+                                   const index_t N,
+                                   const index_t I,
                                    const int H,
                                    const Tensor<cpu, 2, DType> &x,
                                    const Tensor<cpu, 2, DType> &hx,
@@ -2008,7 +2013,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
   }
 
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < N * H; ++i) {
+  for (index_t i = 0; i < N * H; ++i) {
     if (dhy_ptr) {
       dht1[i] = dhy_ptr[i];
     } else {
@@ -2017,7 +2022,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
   }
 
   #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < N; ++i) {
+  for (index_t i = 0; i < N; ++i) {
     for (int j = 0; j < H; ++j) {
       hx_[i * D * H + j] = hx[i][j];
     }
@@ -2025,7 +2030,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
 
   if (D == 2) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N * H; ++i) {
+    for (index_t i = 0; i < N * H; ++i) {
       if (dhy_ptr) {
         back_dht1[i] = dhy_ptr[N * H + i];
       } else {
@@ -2033,13 +2038,13 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
       }
     }
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
         hx_[i * D * H + H + j] = hx[N + i][j];
       }
     }
   }
-  for (int t = T - 1; t >= 0; --t) {
+  for (index_t t = T - 1; t >= 0; --t) {
     if (t) {
       ht1 = y_ptr + (t - 1) * N * D * H;
     } else {
@@ -2049,7 +2054,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
     dyt = dy_ptr + t * N * D * H;
 
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
         dht1[i * H + j] += dyt[i * D * H + j];
       }
@@ -2058,9 +2063,9 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
     nt = gateN + t * N * H;
     dart = dar + t * N * H;
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N; ++i) {
+    for (index_t i = 0; i < N; ++i) {
       for (int j = 0; j < H; ++j) {
-        int id = i * H + j;
+        index_t id = i * H + j;
         if (mode == 1) {
           dart[id] = dht1[id] * (1 - nt[id] * nt[id]);
         } else {
@@ -2099,7 +2104,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
     if (req_params != kAddTo) {
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < H; ++i) {
-        for (int j = 0; j < N * T; ++j) {
+        for (index_t j = 0; j < N * T; ++j) {
           dbx[i] += dar[j * H + i];
           dbh[i] = dbx[i];
         }
@@ -2108,15 +2113,15 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < H * T; ++i) {
+      for (index_t i = 0; i < H * T; ++i) {
         tmp_dbx.dptr_[i] = 0;
         tmp_dbh.dptr_[i] = 0;
       }
 
-      for (int t = T - 1; t >= 0; --t) {
+      for (index_t t = T - 1; t >= 0; --t) {
         #pragma omp parallel for num_threads(omp_threads)
         for (int i = 0; i < H; ++i) {
-          for (int j = 0; j < N; ++j) {
+          for (index_t j = 0; j < N; ++j) {
             tmp_dbx[i][t] += dar[t * N * H + j * H + i];
             tmp_dbh[i][t] = tmp_dbx[i][t];
           }
@@ -2146,7 +2151,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
   }
 
   if (D == 2) {
-    for (int t = 0; t < T; ++t) {
+    for (index_t t = 0; t < T; ++t) {
       if (t == T-1) {
         back_ht1 = hx_;
       } else {
@@ -2156,7 +2161,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
       //  add dy[T, N, D, H] to dhy[D, N, H]
       dyt = dy_ptr + t * N * D * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
           back_dht1[i * H + j] += dyt[i * D * H + H + j];
         }
@@ -2166,9 +2171,9 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
       dart = dar + t * N * H;
 
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < N; ++i) {
+      for (index_t i = 0; i < N; ++i) {
         for (int j = 0; j < H; ++j) {
-          int id = i * H + j;
+          index_t id = i * H + j;
           if (mode == 1) {
             dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]);
           } else {
@@ -2208,7 +2213,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
       if (req_params != kAddTo) {
         #pragma omp parallel for num_threads(omp_threads)
         for (int i = 0; i < H; ++i) {
-          for (int j = 0; j < N * T; ++j) {
+          for (index_t j = 0; j < N * T; ++j) {
             back_dbx[i] += dar[j * H + i];
             back_dbh[i] = back_dbx[i];
           }
@@ -2217,15 +2222,15 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
         #pragma omp parallel for num_threads(omp_threads)
-        for (int i = 0; i < H * T; ++i) {
+        for (index_t i = 0; i < H * T; ++i) {
           tmp_dbx.dptr_[i] = 0;
           tmp_dbh.dptr_[i] = 0;
         }
 
-        for (int t = T - 1; t >= 0; --t) {
+        for (index_t t = T - 1; t >= 0; --t) {
           #pragma omp parallel for num_threads(omp_threads)
           for (int i = 0; i < H; ++i) {
-            for (int j = 0; j < N; ++j) {
+            for (index_t j = 0; j < N; ++j) {
               tmp_dbx[i][t] += dar[t * N * H + j * H + i];
               tmp_dbh[i][t] = tmp_dbx[i][t];
             }
@@ -2256,7 +2261,7 @@ void VanillaRNNBackwardSingleLayer(DType* ws,
   }
   if (req_state != kNullOp) {
     #pragma omp parallel for num_threads(omp_threads)
-    for (int i = 0; i < N * H * D; ++i) {
+    for (index_t i = 0; i < N * H * D; ++i) {
       dhx[i] = dht1[i];
     }
   }
@@ -2267,9 +2272,9 @@ void VanillaRNNBackward(DType* ws,
                         DType* rs,
                         const int L,
                         const int D,
-                        const int T,
-                        const int N,
-                        int I,
+                        const index_t T,
+                        const index_t N,
+                        index_t I,
                         const int H,
                         DType* x_ptr,
                         DType* hx_ptr,
@@ -2319,7 +2324,7 @@ void VanillaRNNBackward(DType* ws,
   DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
   DType* dy_l = dy_ptr;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
-  int inputsize = I;
+  index_t inputsize = I;
   DType* y_tmp = y_l - T * N * H * D;
   const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int l = L - 1; l >= 0; --l) {
@@ -2338,7 +2343,7 @@ void VanillaRNNBackward(DType* ws,
     if (dropout > 0.0f && l > 0 && req_data != kNullOp) {
       dropout_random = dropout_random - T * N * D * H;
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * I; i++) {
+      for (index_t i = 0; i < T * N * I; i++) {
         if (dropout_random[i] == 0) {
           dx_l[i] = 0;
         } else {
@@ -2348,7 +2353,7 @@ void VanillaRNNBackward(DType* ws,
     }
     if (l > 0) {
       #pragma omp parallel for num_threads(omp_threads)
-      for (int i = 0; i < T * N * H * D; ++i) {
+      for (index_t i = 0; i < T * N * H * D; ++i) {
         dy_l[i] = dx_l[i];
       }
       gateN_l = gateN_l -  T * D * N * H;
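
Note on the int -> index_t widening above: with 32-bit `int` loop bounds,
products such as T * N * H * D overflow once a tensor holds more than
2^31 - 1 elements, and every offset derived from them is corrupted. A
minimal sketch of the failure mode (standalone Python, using numpy's
fixed-width integers to mimic C `int` arithmetic; illustrative only, not
part of the patch):

    import numpy as np

    # Element count exercised by the nightly test added below:
    # seq_len * batch * input_size = 2**28 * 4 * 4 = 2**32.
    assert 2**28 * 4 * 4 > np.iinfo(np.int32).max

    # 32-bit products wrap just as the old `int` index math would:
    t, n, c = np.int32(2**28), np.int32(4), np.int32(4)
    print(t * n * c)  # wraps to 0, with a numpy RuntimeWarning
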
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 1528bc0..39fee72 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -39,6 +39,7 @@ SMALL_X = 100
 SMALL_Y = 50
 LARGE_SIZE = LARGE_X * SMALL_Y
 LARGE_TENSOR_SHAPE = 2**32
+RNN_LARGE_TENSOR = 2**28
 
 
 def test_nn():
@@ -479,7 +480,6 @@ def test_nn():
 
         assert out.shape[0] == LARGE_TENSOR_SHAPE
         assert out.shape[1] == 1
-        assert out.shape[2] == 1
         
     def check_spatial_transformer():
         data = nd.random_normal(shape=(2, 2**29, 1, 6))
@@ -504,6 +504,39 @@ def test_nn():
 
         assert out.shape[0] == LARGE_TENSOR_SHAPE
 
+    def check_rnn():
+        data = nd.random_normal(shape=(RNN_LARGE_TENSOR, 4, 4))
+        parameters_relu_tanh = nd.random_normal(shape=(7,))
+        parameters_lstm = nd.random_normal(shape=(28,))
+        parameters_gru = nd.random_normal(shape=(21,))
+        state = nd.random_normal(shape=(1, 4, 1))
+        state_cell = nd.random_normal(shape=(1, 4, 1))
+        mode_relu = 'rnn_relu'
+        mode_tanh = 'rnn_tanh'
+        mode_lstm = 'lstm'
+        mode_gru = 'gru'
+        state_size = 1
+        num_layers = 1
+
+        out_relu = nd.RNN(data=data, parameters=parameters_relu_tanh, state=state, mode=mode_relu,
+                          state_size=state_size, num_layers=num_layers)
+        
+        out_tanh = nd.RNN(data=data, parameters=parameters_relu_tanh, state=state, mode=mode_tanh,
+                          state_size=state_size, num_layers=num_layers)
+        
+        out_lstm = nd.RNN(data=data, parameters=parameters_lstm, state=state, mode=mode_lstm,
+                          state_cell=state_cell, state_size=state_size, num_layers=num_layers)
+
+        out_gru = nd.RNN(data=data, parameters=parameters_gru, state=state, mode=mode_gru,
+                         state_size=state_size, num_layers=num_layers)
+
+        for out in [out_relu, out_tanh, out_lstm, out_gru]:
+            assert out.shape[0] == RNN_LARGE_TENSOR
+            assert out.shape[1] == 4
+            assert out.shape[2] == 1
+
+            assert type(out[0, 0, 0].asscalar()).__name__ == 'float32'
+
     check_gluon_embedding()
     check_fully_connected()
     check_dense()
@@ -527,6 +560,7 @@ def test_nn():
     check_embedding()
     check_spatial_transformer()
     check_ravel()
+    check_rnn()
 
 
 def test_tensor():
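
The flat `parameters` shapes in check_rnn follow from GetRnnParamSize in
rnn-inl.h: a single unidirectional layer needs
(input_size + state_size + 2) * state_size * gates weights, with one gate
for rnn_relu/rnn_tanh, three for GRU and four for LSTM. A quick sanity
check (plain Python; `rnn_param_size` is a hypothetical helper written for
this note, not an MXNet API):

    def rnn_param_size(input_size, state_size, gates):
        # First-layer size from GetRnnParamSize: (I + H + 2) * H * gates.
        return (input_size + state_size + 2) * state_size * gates

    I, H = 4, 1  # data is (2**28, 4, 4) in TNC layout; state_size = 1
    assert rnn_param_size(I, H, gates=1) == 7    # rnn_relu / rnn_tanh
    assert rnn_param_size(I, H, gates=3) == 21   # gru
    assert rnn_param_size(I, H, gates=4) == 28   # lstm
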
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f2a220b..6f9308b 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -685,15 +685,10 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz
     stack_input_grad = sx.grad.asnumpy()
 
     assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol)
-    if mx.context.current_context().device_type == 'cpu' and \
-            not mx.runtime.Features().is_enabled('MKLDNN') and \
-            'rnn' not in fused_layer.prefix:
-        print("LSTM and GRU on native CPU give wrong gradients. "
-              "Tracking issue: https://github.com/apache/incubator-mxnet/issues/17898.")
-    else:
-        assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
-        for key, value in fused_grads.items():
-            assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)
+    assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
+    for key, value in fused_grads.items():
+        assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)
+
     num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1)
     check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2)
 
@@ -719,61 +714,32 @@ def create_op_by_mode(mode):
     return fused_op, stack_op, recurrent_block_prefix
 
 
-def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
-
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-    stack_layer.initialize()
 
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
     fused_layer.initialize()
 
     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l1_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l2_'))
+        for n in range(num_layers):
+            stack_layer.add(stack_op(hidden_size, prefix="l{}_".format(n)))
     stack_layer.initialize()
-
     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
 
 
-def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
-
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-    stack_layer.initialize()
 
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
     fused_layer.initialize()
 
     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'),
-                                                    stack_op(hidden_size, prefix='r1_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'),
-                                                    stack_op(hidden_size, prefix='r2_')))
-    stack_layer.initialize()
-
+        for n in range(num_layers):
+            stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix="l{}_".format(n)),
+                                                stack_op(hidden_size, prefix="r{}_".format(n))))
+        stack_layer.initialize()
     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
 
 
@@ -782,10 +748,11 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
 def test_fused_lstm_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -793,10 +760,11 @@ def test_fused_lstm_layer():
 def test_fused_gru_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -804,10 +772,11 @@ def test_fused_gru_layer():
 def test_fused_rnnrelu_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -815,10 +784,11 @@ def test_fused_rnnrelu_layer():
 def test_fused_rnntanh_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
 
 
 def test_rnn_unroll_variant_length():

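The re-enabled assertions in check_rnn_consistency are the point of the
#17898 fix: fused LSTM/GRU gradients on native (non-MKLDNN) CPU now match
the unfused stacked cells, including multi-layer bidirectional networks
whose dy buffer was previously overwritten before the right-to-left pass
consumed it. A minimal way to exercise the fixed code path (MXNet 1.x
Gluon API; shapes and variable names are illustrative, not taken from the
patch):

    import mxnet as mx
    from mxnet import autograd, gluon

    # A 3-layer bidirectional fused LSTM: the configuration whose input
    # gradient was computed from a clobbered dy tensor before the fix.
    fused = gluon.rnn.LSTM(8, num_layers=3, bidirectional=True, layout='NTC')
    fused.initialize()

    x = mx.nd.random.uniform(shape=(2, 5, 8))  # (batch, seq, feature)
    x.attach_grad()
    with autograd.record():
        y = fused(x)
    y.backward()
    print(x.grad.shape)  # (2, 5, 8); the tests compare this against an unfused stack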