marcoabreu closed pull request #10458: [Don't merge] A debug pr of https://github.com/apache/incubator-mxnet/pull/10104
URL: https://github.com/apache/incubator-mxnet/pull/10458
 
 
   

This is a pull request from a forked repository. As GitHub hides the original diff once the PR is closed, it is reproduced below for the sake of provenance:

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 2aaaeb25d76..7b4d3393e59 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,6 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from . import rnn_cell
@@ -186,8 +185,7 @@ def forward(self, inputs, states=None):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu' or \
-            (not is_training() and self._mode == 'lstm'):
+        if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
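
With the change above, a gluon LSTM on a CPU context dispatches to the fused _forward_kernel path in both inference and training (previously the training pass on CPU fell back to the unrolled _forward cell path). A minimal sketch of the call pattern this affects, assuming a build with this patch applied:

    import mxnet as mx
    from mxnet import autograd, gluon

    layer = gluon.rnn.LSTM(10, num_layers=2)   # 'lstm' mode on the CPU context
    layer.initialize(ctx=mx.cpu())
    x = mx.nd.ones((8, 3, 20))                 # TNC layout: (seq_len, batch, input_size)

    y_infer = layer(x)                         # inference: fused kernel, as before
    with autograd.record():                    # training: now also takes the fused kernel
        y_train = layer(x)
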
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 4bd170cfac7..744fb3fe1bf 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -38,7 +38,7 @@ namespace mxnet {
 namespace op {
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 template<typename DType>
-class CuDNNRNNOp : public Operator {
+class CuDNNRNNOp : public Operator{
  public:
   explicit CuDNNRNNOp(RNNParam param) {
     this->param_ = param;
@@ -100,6 +100,7 @@ class CuDNNRNNOp : public Operator {
       CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
       Storage::Get()->Free(dropout_states_);
       Storage::Get()->Free(reserve_space_);
+      init_cudnn_ = false;
     }
   }
 
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 13c077dd9e3..0fc11c71cb3 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file rnn-inl.h
  * \brief
- * \author Sebastian Bodenstein
+ * \author Sebastian Bodenstein, Shu Zhang(shu.zh...@intel.com)
 */
 #ifndef MXNET_OPERATOR_RNN_INL_H_
 #define MXNET_OPERATOR_RNN_INL_H_
@@ -29,6 +29,7 @@
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
+#include <mxnet/storage.h>
 #include <algorithm>
 #include <map>
 #include <vector>
@@ -37,8 +38,7 @@
 #include "./math.h"
 #include "./math_functions-inl.h"
 #include "./operator_common.h"
-#include "./mshadow_op.h"
-#include "./linalg.h"
+#include "./rnn_impl.hpp"
 
 namespace mxnet {
 namespace op {
@@ -50,18 +50,37 @@ namespace rnn_enum {
   enum RNNOpResource {kTempSpace};
 }
 
-// A utility function to calculate input size
-inline int rnn_single_param_size(int inputSize,
-                                int hiddenSize,
-                                int mode) {
-  int size = hiddenSize * (hiddenSize + inputSize + 2);
-  // Different RNN's have different num weights
+inline int GetRnnParamSize(int num_layer,
+                           int input_size,
+                           int state_size,
+                           int direction,
+                           int mode) {
+  int size = state_size * direction;
   switch (mode) {
     case rnn_enum::kRnnRelu:
-      size *= 1;
+    case rnn_enum::kRnnTanh:
       break;
+    case rnn_enum::kLstm:
+      size *= 4;
+      break;
+    case rnn_enum::kGru:
+      size *= 3;
+      break;
+  }
+  int size1 = (input_size + state_size + 2) * size;  // first layer size
+  int size2 = (state_size * direction + state_size + 2) * size;  // other layers size
+  int param_size = size1 + (num_layer - 1) * size2;
+  return param_size;
+}
+
+inline int GetRnnBiasSize(int num_layer,
+                           int state_size,
+                           int direction,
+                           int mode) {
+  int size = 2 * state_size * direction * num_layer;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-      size *= 1;
       break;
     case rnn_enum::kLstm:
       size *= 4;
@@ -73,19 +92,48 @@ inline int rnn_single_param_size(int inputSize,
   return size;
 }
 
-inline int rnn_param_size(int layerNum,
-                          int inputSize,
-                          int hiddenSize,
-                          bool bidirectional,
-                          int mode) {
-  // get size of first layer
-  int size = rnn_single_param_size(inputSize, hiddenSize, mode);
-  // get size of remaining layers
-  if (bidirectional) {
-    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
-    size *= 2;
-  } else {
-    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
+inline size_t GetRNNWorkspaceSize(int seq_length,
+                                  int batch_size,
+                                  int hidden_size,
+                                  int direction,
+                                  int mode) {
+  size_t size = 0;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kGru:
+      LOG(FATAL) << "Only LSTM is supported at the moment";
+      break;
+    case rnn_enum::kLstm:
+      size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 3
+             + seq_length * batch_size * hidden_size * direction;
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
+  }
+  return size;
+}
+
+inline size_t GetRNNReserveSpaceSize(int num_layer,
+                                     int direction,
+                                     int seq_length,
+                                     int batch_size,
+                                     int hidden_size,
+                                     int mode) {
+  size_t size = 0;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kGru:
+      LOG(FATAL) << "Only LSTM is supported at the moment";
+      break;
+    case rnn_enum::kLstm:
+      size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
   }
   return size;
 }
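
For mode == kLstm the helpers above reduce the fused parameter layout to a closed form: 4 * state_size gate rows per direction, a weight block over the input and hidden activations, and an i2h plus an h2h bias vector per layer (the "+ 2" term). A minimal Python sketch of the same arithmetic (illustrative only; the function name below is not part of the patch):

    def lstm_param_size(num_layers, input_size, state_size, direction=1):
        # mirrors GetRnnParamSize for mode == kLstm
        gates = 4 * state_size * direction
        first = (input_size + state_size + 2) * gates              # layer 0
        other = (state_size * direction + state_size + 2) * gates  # layers 1..num_layers-1
        return first + (num_layers - 1) * other

    # one unidirectional layer: 4*H*(I + H) weights plus 2 * 4*H biases
    assert lstm_param_size(1, 800, 800) == 4 * 800 * (800 + 800) + 2 * 4 * 800
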
@@ -125,51 +173,152 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   }
 };
 
-template<typename xpu, typename DType>
-class RNNOp : public Operator {
- public:
-  explicit RNNOp(RNNParam p) {
-  }
 
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO(sbodenstein): add MShadow implementation
+/**
+ * @params: ws: Temp workspace for gemm's output storage.
+ *          rs: Reserve space of forward intermediate data used for training.
+ *          num_layers: The number of recurrent layers.
+ *          direction: direction is 2 if use bidirectional recurrent layers, else is 1;
+ *          seq_length: The number of iterations to unroll over.
+ *          batch_size: size of batch.
+ *          input_size: The number of expected input features.
+ *          state_size: The number of hidden state features.
+ *          x_ptr: Pointer of tensor x containing the features of the input sequence.
+ *                 x's shape is [seq_length, batch_size, input_size]
+ *          hx_ptr: Pointer of tensor hx containing the initial hidden state.
+ *                  hx's shape is [num_layers, batch_size, state_size]
+ *          cx_ptr: Only used in lstm mode. pointer of tensor cx containing the initial cell state.
+ *                  cx's shape is [num_layers, batch_size, state_size]
+ *          w_ptr: Pointer of tensor w containing weights.
+ *          b_ptr: Pointer of tensor w containing bias.
+ *          y_ptr: Pointer of tensor y containing the features of the output features from the
+ *                 last layers of the RNN. y's shape is [seq_length, batch_size, state_size]
+ *          hy_ptr: Pointer of tensor hy containing the hidden state for t=seq_length.
+ *                  hy's shape is [num_layers, batch_size, state_size]
+ *          cy_ptr: Only used in lstm mode. pointer of tensor cy  containing the cell state
+ *                  for t=seq_length. cy' shape is [num_layers, batch_size, state_size]
+ *          mode: Specifies the type of RNN to compute.
+ */
+template <typename DType>
+void RNNForwardTraining(DType* ws,
+                        DType* rs,
+                        bool state_outputs,
+                        const int num_layers,
+                        const int direction,
+                        const int seq_length,
+                        const int batch_size,
+                        const int input_size,
+                        const int state_size,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* cx_ptr,
+                        DType* w_ptr,
+                        DType* b_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr,
+                        DType* cy_ptr,
+                        int mode) {
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kGru:
+      LOG(FATAL) << "Only LSTM is supported at the moment";
+      break;
+    case rnn_enum::kLstm:
+      LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                 batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
+                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
   }
+}
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO(sbodenstein): add MShadow implementation
+template <typename DType>
+void RNNForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int num_layers,
+                         const int direction,
+                         const int seq_length,
+                         const int batch_size,
+                         const int input_size,
+                         const int state_size,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* cx_ptr,
+                         DType* w_ptr,
+                         DType* b_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr,
+                         DType* cy_ptr,
+                         int mode) {
+  switch (mode) {
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kGru:
+      LOG(FATAL) << "Only LSTM is supported at the moment";
+      break;
+    case rnn_enum::kLstm:
+      LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
+                                  w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode" << mode;
+      break;
   }
+}
 
- private:
-  RNNParam param_;
-};  // class RNNOp
+template <typename DType>
+void RNNBackward(DType* ws,
+                 DType* rs,
+                 const int num_layers,
+                 const int direction,
+                 const int seq_length,
+                 const int batch_size,
+                 const int input_size,
+                 const int state_size,
+                 DType* x_ptr,
+                 DType* hx_ptr,
+                 DType* cx_ptr,
+                 DType* w_ptr,
+                 DType* y_ptr,
+                 DType* dy_ptr,
+                 DType* dhy_ptr,
+                 DType* dcy_ptr,
+                 DType* dx_ptr,
+                 DType* dhx_ptr,
+                 DType* dcx_ptr,
+                 DType* dw_ptr,
+                 DType* db_ptr,
+                 int mode) {
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+      break;
+    case rnn_enum::kRnnTanh:
+      break;
+    case rnn_enum::kLstm:
+      LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                          input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
+                          dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr);
+      break;
+    case rnn_enum::kGru:
+      break;
+  }
+}
 
 template<typename DType>
-class RNNOp<cpu, DType> : public Operator {
+class RNNOp : public Operator{
  public:
-  explicit RNNOp(RNNParam param) {
-    this->param_ = param;
-    // RNN Mode
-    param_.lstm_q_ = false;
-    switch (param_.mode) {
-      case rnn_enum::kLstm:
-        param_.lstm_q_ = true;
-        break;
-      default:
-        LOG(FATAL) << "only LSTM is implmented on CPU";
+  explicit RNNOp(RNNParam p)
+    :param_(p), init_space_(false), reserve_space_size_(0)
+  {}
+
+  ~RNNOp() {
+    if (init_space_) {
+      Storage::Get()->Free(reserve_space_);
+      init_space_ = false;
     }
   }
 
@@ -178,189 +327,219 @@ class RNNOp<cpu, DType> : public Operator {
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
-    // Layout TNC
-    CHECK(!ctx.is_train) << "only inference mode is available"
-      "for cpu at the moment.";
-    size_t in_expected = param_.lstm_q_ ? 4 : 3;
-    size_t out_expected = param_.lstm_q_ ? 3 : 2;
-
-    if (!param_.state_outputs)
-      LOG(FATAL) << "no state outputs is currently not supported for cpu.";
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
 
-    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
+    if (!param_.state_outputs) {
+      out_expected = 1;
+    }
     CHECK_EQ(in_data.size(), in_expected);
     CHECK_EQ(out_data.size(), out_expected);
-
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-    // get input + output tensors
-    // w layout i2h_w, h2h_w, i2h_b, h2h_b
-    Tensor<cpu, 3, DType> x =
-        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensor
+    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
     Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx =
-        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
-    Tensor<cpu, 3, DType> y =
-        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
-    int64_t seq_len = x.shape_[0];
-    int64_t num_layers = hx.shape_[0];
-    int64_t batch_size = x.shape_[1];
-    int64_t h_channel = hx.shape_[2];
-    int64_t in_channel = x.shape_[2];
-    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
-      .get_with_shape<cpu, 2, DType>(
-          mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
-    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
-      .get_with_shape<cpu, 2, DType>(
-          mshadow::Shape2(
-              y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
-
+    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
     CHECK(x.CheckContiguous());
     CHECK(w.CheckContiguous());
     CHECK(hx.CheckContiguous());
     CHECK(y.CheckContiguous());
+    param_.seq_length_ = x.shape_[0];
+    param_.batch_size_ = x.shape_[1];
+    param_.input_size_ = x.shape_[2];
+
+    const int direction = param_.bidirectional ? 2 : 1;
+    const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+    DType* b_ptr = w.dptr_ + w.shape_[0] - bsize;
+
+    DType* hy_ptr = NULL;
+    if (param_.state_outputs) {
+      hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
+    }
+    DType* cx_ptr = NULL;
+    DType* cy_ptr = NULL;
+
+    if (param_.mode == rnn_enum::kLstm) {
+      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
+      if (param_.state_outputs) {
+        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
+      }
+    }
+
+    // allocate temp space
+    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                                      param_.state_size, direction, param_.mode);
+    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+
+    if (ctx.is_train) {
+      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                   param_.seq_length_, param_.batch_size_,
+                                                   param_.state_size, param_.mode);
+      if (init_space_ && reserve_space_size_ < r_size) {
+        Storage::Get()->Free(reserve_space_);
+        init_space_ = false;
+      }
 
-    if (param_.lstm_q_) {
-      const size_t kNumMat = 4;
-      int64_t fused_h_ch = kNumMat * h_channel;
-      int64_t h_size = batch_size * fused_h_ch;
-      int64_t num_dir = 1 + param_.bidirectional;
-      int64_t h2h_w_size = h_channel * fused_h_ch;
-
-      Tensor<cpu, 3, DType> cx =
-          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
-      CHECK(cx.CheckContiguous());
-
-      Tensor<cpu, 3, DType> cy =
-          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
-      Tensor<cpu, 3, DType> hy =
-          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
-      CHECK(cy.CheckContiguous());
-      CHECK(hy.CheckContiguous());
-
-      DType* workspace_addr =
-      static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
-          .get_host_space_internal(sizeof(DType) *
-                                  (seq_len * h_size + h_size
-                                  + y.shape_[0] * y.shape_[1] * y.shape_[2])));
-      Tensor<cpu, 3, DType> i2h_y(
-          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> i2h_y_flatten(
-          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> h2h_y(workspace_addr
-          + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
-      Tensor<cpu, 3, DType> y_tmp(workspace_addr
-          + (seq_len + 1) * h_size, y.shape_);
-      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
-          + (seq_len + 1) * h_size, y_flatten.shape_);
-      CHECK(i2h_y.CheckContiguous());
-      CHECK(h2h_y.CheckContiguous());
-      CHECK(y_tmp.CheckContiguous());
-
-      for (int64_t layer = 0; layer < num_layers; layer++) {
-        int reverse_dir = 0;
-        int out_tmp = 0;
-        if (param_.bidirectional && layer % 2)
-          reverse_dir = 1;
-        if (layer / num_dir % 2 == 0)
-          out_tmp = 1;
-        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
-            (layer < num_dir) ? in_channel : num_dir * h_channel);
-        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
-        int64_t start = layer < num_dir ?
-            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
-              (num_dir * (in_channel * fused_h_ch + h2h_w_size)
-              + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
-        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
-        start += layer < num_dir ?
-            in_channel * fused_h_ch : h2h_w_size * num_dir;
-        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
-        start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
-            + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
-              + layer * fused_h_ch * 2;
-        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
-        start += fused_h_ch;
-        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
-        if (out_tmp) {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
-              i2h_y_flatten, false, true, s);
-        } else {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
-              i2h_y_flatten, false, true, s);
-        }
-        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
-        for (int64_t t = 0; t < seq_len; t++) {
-          int64_t timestep = t;
-          if (reverse_dir)
-            timestep = seq_len - 1 - t;
-          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
-              false, true, s);
-          h2h_y += repmat(h2h_b, batch_size);
-          // fused element-wise ops
-          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
-              y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
-                hy[layer], cy[layer], batch_size, h_channel, t,
-                reverse_dir, out_tmp && (layer == num_layers - 1));
-        }
+      if (!init_space_) {
+        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+        reserve_space_size_ = r_size;
+        init_space_ = true;
       }
+
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
+      RNNForwardTraining<DType>(workspace.dptr_,
+                                reserve_space_ptr,
+                                param_.state_outputs,
+                                param_.num_layers,
+                                direction,
+                                param_.seq_length_,
+                                param_.batch_size_,
+                                param_.input_size_,
+                                param_.state_size,
+                                x.dptr_,
+                                hx.dptr_,
+                                cx_ptr,
+                                w.dptr_,
+                                b_ptr,
+                                y.dptr_,
+                                hy_ptr,
+                                cy_ptr,
+                                param_.mode);
     } else {
-      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
+      RNNForwardInference<DType>(workspace.dptr_,
+                                 param_.state_outputs,
+                                 param_.num_layers,
+                                 direction,
+                                 param_.seq_length_,
+                                 param_.batch_size_,
+                                 param_.input_size_,
+                                 param_.state_size,
+                                 x.dptr_,
+                                 hx.dptr_,
+                                 cx_ptr,
+                                 w.dptr_,
+                                 b_ptr,
+                                 y.dptr_,
+                                 hy_ptr,
+                                 cy_ptr,
+                                 param_.mode);
     }
   }
 
   virtual void Backward(const OpContext &ctx,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<TBlob> &in_data,
-      const std::vector<TBlob> &out_data,
+                        const std::vector<TBlob> &out_data,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &in_grad,
                         const std::vector<TBlob> &aux_args) {
-    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
-  }
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
+    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
+    if (!param_.state_outputs) {
+      out_expected = 1;
+    }
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+    CHECK_EQ(in_grad.size(), in_expected);
+    CHECK_EQ(out_grad.size(), out_expected);
+    CHECK_EQ(req.size(), in_expected);
+    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
+    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
+    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
+    CHECK(x.CheckContiguous());
+    CHECK(w.CheckContiguous());
+    CHECK(hx.CheckContiguous());
+    CHECK(y.CheckContiguous());
+    CHECK(dx.CheckContiguous());
+    CHECK(dw.CheckContiguous());
+    CHECK(dhx.CheckContiguous());
+    CHECK(dy.CheckContiguous());
+    param_.seq_length_ = x.shape_[0];
+    param_.batch_size_ = x.shape_[1];
+    param_.input_size_ = x.shape_[2];
+
+    const int direction = param_.bidirectional ? 2 : 1;
+    const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+    DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
+
+    DType * dhy_ptr = NULL;
+    if (param_.state_outputs) {
+      dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
+    }
 
- private:
-  RNNParam param_;
+    DType * cx_ptr = NULL;
+    DType * dcx_ptr = NULL;
+    DType * dcy_ptr = NULL;
 
-  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
-                                  const Tensor<cpu, 2, DType> &cx,
-                                  const Tensor<cpu, 2, DType> &h2h_y,
-                                  const Tensor<cpu, 2, DType> &y,
-                                  // holding intermediate layer output
-                                  const Tensor<cpu, 2, DType> &tmp,
-                                  const Tensor<cpu, 2, DType> &hy,
-                                  const Tensor<cpu, 2, DType> &cy,
-                                  const int64_t batch_size,
-                                  const int64_t h_channel,
-                                  const int64_t t,
-                                  const int reverse_dir,
-                                  const int copy_tmp2y) {
-    int64_t length = batch_size * h_channel;
-    #pragma omp parallel for
-    for (int64_t ji = 0; ji < length; ++ji) {
-      int64_t j = ji / h_channel;  // batch dim
-      int64_t i = ji % h_channel;
-      int64_t f = i + h_channel;
-      int64_t c = i + h_channel * 2;
-      int64_t o = i + h_channel * 3;
-      int64_t j_pos = j * h_channel * 4;
-      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
-      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
-      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
-      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
-      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
-      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
-      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
-      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
-      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i])
-          + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
-      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
-      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
-      if (copy_tmp2y) {
-        y[j][i] = tmp[j][i];
-        if (reverse_dir)
-          y[j][i + h_channel] = tmp[j][i + h_channel];
+    if (param_.mode == rnn_enum::kLstm) {
+      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
+      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
+      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
+      if (param_.state_outputs) {
+        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
       }
     }
+
+    // allocate temp space
+    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                                      param_.state_size, direction, param_.mode);
+    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+
+    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                           param_.seq_length_, param_.batch_size_,
+                                           param_.state_size, param_.mode);
+    if (!init_space_ || reserve_space_size_ != r_size) {
+      LOG(FATAL) << "Check forward init error";
+    }
+
+    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
+    RNNBackward<DType>(workspace.dptr_,
+                       reserve_space_ptr,
+                       param_.num_layers,
+                       direction,
+                       param_.seq_length_,
+                       param_.batch_size_,
+                       param_.input_size_,
+                       param_.state_size,
+                       x.dptr_,
+                       hx.dptr_,
+                       cx_ptr,
+                       w.dptr_,
+                       y.dptr_,
+                       dy.dptr_,
+                       dhy_ptr,
+                       dcy_ptr,
+                       dx.dptr_,
+                       dhx.dptr_,
+                       dcx_ptr,
+                       dw.dptr_,
+                       db_ptr,
+                       param_.mode);
   }
+
+ private:
+  RNNParam param_;
+  bool init_space_;
+  size_t reserve_space_size_;
+  Storage::Handle reserve_space_;
 };  // class RNNOp
 
 template<typename xpu>
@@ -429,10 +608,10 @@ class RNNProp : public OperatorProperty {
                         Shape3(total_layers, batch_size, param_.state_size));
 
     // calculate parameter vector length
-    int param_size = rnn_param_size(param_.num_layers,
+    int param_size = GetRnnParamSize(param_.num_layers,
                                     input_size,
                                     param_.state_size,
-                                    param_.bidirectional,
+                                    numDirections,
                                     param_.mode);
     SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
 
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index a60adbcd2fb..a8bc9e1e3fb 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -21,9 +21,8 @@
  * Copyright (c) 2015 by Contributors
  * \file rnn.cc
  * \brief
- * \author Sebastian Bodenstein
+ * \author Sebastian Bodenstein, Shu Zhang(shu.zh...@intel.com)
 */
-
 #include "./rnn-inl.h"
 
 namespace mxnet {
@@ -32,7 +31,7 @@ template<>
 Operator *CreateOp<cpu>(RNNParam param, int dtype) {
   Operator *op = NULL;
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new RNNOp<cpu, DType>(param);
+    op = new RNNOp<DType>(param);
   });
   return op;
 }
diff --git a/src/operator/rnn_impl.hpp b/src/operator/rnn_impl.hpp
new file mode 100644
index 00000000000..765b54ad1ca
--- /dev/null
+++ b/src/operator/rnn_impl.hpp
@@ -0,0 +1,454 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file    rnn_impl.hpp
+ * \brief
+ * \author  Shu Zhang(shu.zh...@intel.com)
+*/
+#ifndef MXNET_OPERATOR_RNN_IMPL_HPP_
+#define MXNET_OPERATOR_RNN_IMPL_HPP_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <algorithm>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
+#include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"
+
+template<typename DType>
+inline DType sigmoid(DType x) {
+  return 1.0f / (1.0f + exp(-x));
+}
+
+template<typename DType>
+void LstmForwardTrainingSingleLayer(DType* ws,
+                                    DType* rs,
+                                    bool state_outputs,
+                                    bool bid,
+                                    const int T,
+                                    const int N,
+                                    const int I,
+                                    const int H,
+                                    const Tensor<cpu, 2, DType> &x,
+                                    const Tensor<cpu, 2, DType> &hx,
+                                    const Tensor<cpu, 2, DType> &cx,
+                                    const Tensor<cpu, 3, DType> &y,
+                                    DType* w_ptr,
+                                    DType* b_ptr,
+                                    DType* hy_ptr,
+                                    DType* cy_ptr) {
+  using namespace mshadow;
+  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
+  const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H));
+  const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H));
+  const Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, 4 * H));
+  const Tensor<cpu, 2, DType> yh_flat(ws + T * N * H * 4, Shape2(N, 4 * H));
+  const Tensor<cpu, 4, DType> yx(yx_flat.dptr_, Shape4(T, N, 4, H));
+  const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H));
+  Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H));
+  DType *c_ptr = bid ? rs + T * N * H * 7 : rs;
+  Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H));
+  Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4));
+
+  const int offset = bid ? H : 0;
+  const DType alpha = 1.0;
+  const DType beta = 0.0;
+  const int cell_size = N * H;
+  linalg_gemm(x, wx, yx_flat, alpha, beta, false, true);
+
+  for (int i = 0; i < T; ++i) {
+    int t = bid ? T - 1 - i : i;
+    linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
+    #pragma omp parallel for
+    for (int jk = 0; jk < cell_size; ++jk) {
+      int j = jk / H;
+      int k = jk % H;
+      DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]);
+      DType ft = sigmoid<DType>(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]);
+      DType gt =           tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]);
+      DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]);
+      DType ct = (i ? c[i-1][j][k] : cx[j][k]) * ft + it * gt;
+      DType ht = ot * tanh(ct);
+      h[j][k] = ht;
+      // reserve
+      y[t][j][k + offset] = ht;
+      c[i][j][k] = ct;
+      ifgo[i][j][k][0] = it;
+      ifgo[i][j][k][1] = ft;
+      ifgo[i][j][k][2] = gt;
+      ifgo[i][j][k][3] = ot;
+      if (i == T - 1 && state_outputs) {
+        hy_ptr[jk] = ht;
+        cy_ptr[jk] = ct;
+      }
+    }
+  }
+}
+
+template <typename DType>
+void LstmForwardTraining(DType* ws,
+                         DType* rs,
+                         bool state_outputs,
+                         const int L,
+                         const int D,
+                         const int T,
+                         const int N,
+                         const int I,
+                         const int H,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* cx_ptr,
+                         DType* w_ptr,
+                         DType* b_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr,
+                         DType* cy_ptr) {
+  const int total_layers = D * L;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
+  const int b_size = 2 * H * 4;
+  const int r_size = D * T * N * H * 6;
+  const int y_offset = T * N * H * 5;
+  const int cell_size = N * H;
+  int idx = 0;  // state & cell state's idx;
+  for (int i = 0; i < L; ++i) {
+    const int input_size = i ? H * D : I;
+    const int w_size = (input_size + H) * H * 4;
+    Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
+    Tensor<cpu, 3, DType> y(rs + y_offset, Shape3(T, N, H * D));
+    LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, false, T, N, input_size, H, x,
+                                          hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
+    if (D == 2) {
+      w_ptr += w_size;
+      b_ptr += b_size;
+      ++idx;
+      if (state_outputs) {
+        hy_ptr += cell_size;
+        cy_ptr += cell_size;
+      }
+      LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, true, T, N, input_size, H, x,
+                                            hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
+    }
+    if (i != L - 1) {
+      w_ptr += w_size;
+      b_ptr += b_size;
+      x_ptr = y.dptr_;
+      rs += r_size;
+      ++idx;
+      if (state_outputs) {
+        hy_ptr += cell_size;
+        cy_ptr += cell_size;
+      }
+    }
+  }
+  memcpy(y_ptr, rs + y_offset, T * N * H * D * sizeof(DType));
+}
+
+template<typename DType>
+void LstmForwardInferenceSingleLayer(DType* ws,
+                                     bool state_outputs,
+                                     bool bid,
+                                     const int T,
+                                     const int N,
+                                     const int I,
+                                     const int H,
+                                     const Tensor<cpu, 2, DType> &x,
+                                     const Tensor<cpu, 2, DType> &hx,
+                                     const Tensor<cpu, 2, DType> &cx,
+                                     const Tensor<cpu, 3, DType> &y,
+                                     DType* w_ptr,
+                                     DType* b_ptr,
+                                     DType* hy_ptr,
+                                     DType* cy_ptr) {
+  using namespace mshadow;
+  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
+  const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H));
+  const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H));
+  Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, H * 4));
+  Tensor<cpu, 2, DType> yh_flat(ws + T * N * H * 4, Shape2(N, H * 4));
+  const Tensor<cpu, 4, DType> yx(yx_flat.dptr_, Shape4(T, N, 4, H));
+  const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H));
+  Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H));
+  Tensor<cpu, 2, DType> c(h.dptr_ + N * H, Shape2(N, H));
+  const int offset = bid ? H : 0;
+  const DType alpha = 1.0;
+  const DType beta = 0.0;
+  const int cell_size = N * H;
+  linalg_gemm(x, wx, yx_flat, alpha, beta, false, true);
+
+  for (int i = 0; i < T; ++i) {
+    int t = bid ? T - 1 - i : i;
+    linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
+    #pragma omp parallel for
+    for (int jk = 0; jk < cell_size; ++jk) {
+      int j = jk / H;
+      int k = jk % H;
+      DType it = sigmoid<DType>(yx[t][j][0][k] + yh[j][0][k] + bx[0][k] + bh[0][k]);
+      DType ft = sigmoid<DType>(yx[t][j][1][k] + yh[j][1][k] + bx[1][k] + bh[1][k]);
+      DType gt =           tanh(yx[t][j][2][k] + yh[j][2][k] + bx[2][k] + bh[2][k]);
+      DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]);
+      DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt;
+      DType ht = ot * tanh(ct);
+      y[t][j][k + offset] = ht;
+      if (i == T - 1 && state_outputs) {
+        hy_ptr[jk] = ht;
+        cy_ptr[jk] = ct;
+      } else {
+        h[j][k] = ht;
+        c[j][k] = ct;
+      }
+    }
+  }
+}
+
+template <typename DType>
+void LstmForwardInference(DType* ws,
+                          bool state_outputs,
+                          const int L,
+                          const int D,
+                          const int T,
+                          const int N,
+                          const int I,
+                          const int H,
+                          DType* x_ptr,
+                          DType* hx_ptr,
+                          DType* cx_ptr,
+                          DType* w_ptr,
+                          DType* b_ptr,
+                          DType* y_ptr,
+                          DType* hy_ptr,
+                          DType* cy_ptr) {
+  const int total_layers = D * L;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
+  const int b_size = 2 * H * 4;
+  const int cell_size = N * H;
+  DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 3;
+  DType* y_cur_ptr = y_ptr;
+  int idx = 0;  // state & cell state's idx;
+  bool flag = L % 2 ? false : true;
+  for (int i = 0; i < L; ++i) {
+    const int input_size = i ? H * D : I;
+    const int w_size = (input_size + H) * H * 4;
+    // If bidirectional, need space to save current layer output y.
+    if (D == 2) {
+      y_cur_ptr = flag ? y_tmp_ptr : y_ptr;
+      flag = !flag;
+    }
+    Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
+    Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, H * D));
+    LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H,
+                                           x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
+    // If bidirectional, then calculate the reverse direction's forward result.
+    if (D == 2) {
+      w_ptr += w_size;
+      b_ptr += b_size;
+      ++idx;
+      if (state_outputs) {
+        hy_ptr += cell_size;
+        cy_ptr += cell_size;
+      }
+      LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H,
+                                             x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
+    }
+    // Don't need to move pointer in the last layer.
+    if (i != L - 1) {
+      w_ptr += w_size;
+      b_ptr += b_size;
+      x_ptr = y_cur_ptr;
+      ++idx;
+      if (state_outputs) {
+        hy_ptr += cell_size;
+        cy_ptr += cell_size;
+      }
+    }
+  }
+}
+
+template <typename DType>
+void LstmBackwardSingleLayer(DType* ws,
+                             DType* rs,
+                             bool bid,
+                             const int T,
+                             const int N,
+                             const int I,
+                             const int H,
+                             const Tensor<cpu, 2, DType> &x,
+                             const Tensor<cpu, 2, DType> &hx,
+                             const Tensor<cpu, 2, DType> &cx,
+                             const Tensor<cpu, 3, DType> &y,
+                             const Tensor<cpu, 3, DType> &dy,
+                             const Tensor<cpu, 2, DType> &dx,
+                             const Tensor<cpu, 2, DType> &dhx,
+                             const Tensor<cpu, 2, DType> &dcx,
+                             DType* dhy_ptr,
+                             DType* dcy_ptr,
+                             DType* w_ptr,
+                             DType* dw_ptr,
+                             DType* db_ptr) {
+  using namespace mshadow;
+  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
+  Tensor<cpu, 2, DType> dwx(dw_ptr, Shape2(H * 4, I));
+  Tensor<cpu, 2, DType> dwh(dw_ptr + I * H * 4, Shape2(H * 4, H));
+  Tensor<cpu, 1, DType> dbx(db_ptr, Shape1(H * 4));
+  Tensor<cpu, 1, DType> dbh(dbx.dptr_ + H * 4, Shape1(H * 4));
+  DType *c_ptr = bid ? rs + T * N * H * 7 : rs;
+  const Tensor<cpu, 3, DType> c(c_ptr, Shape3(T, N, H));
+  const Tensor<cpu, 4, DType> ifgo(c_ptr + T * N * H, Shape4(T, N, H, 4));
+  memset(dwh.dptr_, 0, H * H * 4 * sizeof(DType));
+  memset(dbx.dptr_, 0, H * 4 * sizeof(DType));
+  memset(dbh.dptr_, 0, H * 4 * sizeof(DType));
+  Tensor<cpu, 4, DType> difgo(ws, Shape4(T, N, 4, H));
+  Tensor<cpu, 2, DType> dh(ws + T * N * H * 4, Shape2(N, H));
+  Tensor<cpu, 2, DType> dc(dh.dptr_ + N * H, Shape2(N, H));
+  Tensor<cpu, 2, DType> htmp(dc.dptr_ + N * H, Shape2(N, H));
+  const int offset = bid ? H : 0;
+  const DType alpha = 1.0;
+  const DType beta0 = 0.0;
+  const DType beta1 = 1.0;
+  const int cell_size = N * H;
+  if (dhy_ptr != NULL) {
+    memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType));
+  }
+  if (dcy_ptr != NULL) {
+    memcpy(dc.dptr_, dcy_ptr, cell_size * sizeof(DType));
+  }
+  for (int i = T - 1; i >= 0; --i) {
+    int t = bid ? T - 1 - i : i;
+    int tnext = bid ? t + 1 : t - 1;
+    const Tensor<cpu, 2, DType>& dhnext = i ? dh : dhx;
+    const Tensor<cpu, 2, DType>& dcnext = i ? dc : dcx;
+    const Tensor<cpu, 2, DType>& hnext = i ? htmp : hx;
+    const Tensor<cpu, 2, DType>& cnext = i ? c[i - 1] : cx;
+    #pragma omp parallel for
+    for (int jk = 0; jk < cell_size; ++jk) {
+      int j = jk / H;
+      int k = jk % H;
+      DType tc = tanh(c[i][j][k]);
+      DType it = ifgo[i][j][k][0];
+      DType ft = ifgo[i][j][k][1];
+      DType gt = ifgo[i][j][k][2];
+      DType ot = ifgo[i][j][k][3];
+      dh[j][k] += dy[t][j][k + offset];
+      dc[j][k] += dh[j][k] * ot * (1 - tc * tc);
+      difgo[t][j][0][k] = dc[j][k] * gt * it * (1 - it);
+      difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft);
+      difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt);
+      difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot);
+      dcnext[j][k] = dc[j][k] * ft;
+      if (i) {
+        htmp[j][k] = y[tnext][j][k + offset];
+      }
+    }
+    Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
+    linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
+    linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
+  }
+  Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4));
+  linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
+  linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
+  const int row = T * N;
+  const int col = H * 4;
+  for (int i = 0; i < row; ++i) {
+    #pragma omp parallel for
+    for (int j = 0; j < col; ++j) {
+      dbx[j] += dyx[i][j];
+      dbh[j] = dbx[j];
+    }
+  }
+}
+
+template <typename DType>
+void LstmBackward(DType* ws,
+                  DType* rs,
+                  const int L,
+                  const int D,
+                  const int T,
+                  const int N,
+                  const int I,
+                  const int H,
+                  DType* x_ptr,
+                  DType* hx_ptr,
+                  DType* cx_ptr,
+                  DType* w_ptr,
+                  DType* y_ptr,
+                  DType* dy_ptr,
+                  DType* dhy_ptr,
+                  DType* dcy_ptr,
+                  DType* dx_ptr,
+                  DType* dhx_ptr,
+                  DType* dcx_ptr,
+                  DType* dw_ptr,
+                  DType* db_ptr) {
+  const int total_layers = D * L;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> dhx(dhx_ptr, Shape3(total_layers, N, H));
+  Tensor<cpu, 3, DType> dcx(dcx_ptr, Shape3(total_layers, N, H));
+  const int b_size = 2 * H * 4;
+  const int r_size = D * T * N * H * 6;
+  const int y_offset = T * N * H * 5;
+  const int w_size1 = (I + H) * H * 4;      // first layer
+  const int w_size2 = (D * H + H) * H * 4;  // other layers
+  const int cell_size = N * H;
+  DType* dy_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 3;
+  for (int i = L - 1; i >= 0; --i) {
+    const int input_size = i ? H * D : I;
+    const int w_size = i ? w_size2 : w_size1;
+    int idx = i * D;
+    DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr;
+    DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr;
+    DType* db_cur_ptr = db_ptr + i * b_size * D;
+    DType* rs_cur_ptr = rs + i * r_size;
+    DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : NULL;
+    DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : NULL;
+    Tensor<cpu, 3, DType> y(rs_cur_ptr + y_offset, Shape3(T, N, H * D));
+    Tensor<cpu, 3, DType> dy(dy_ptr, Shape3(T, N, H * D));
+    Tensor<cpu, 2, DType> x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size));
+    Tensor<cpu, 2, DType> dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size));
+    LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, false, T, N, input_size, H,
+                                   x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
+                                   dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+    if (D == 2) {
+      w_cur_ptr += w_size;
+      dw_cur_ptr += w_size;
+      db_cur_ptr += b_size;
+      ++idx;
+      dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : NULL;
+      dcy_cur_ptr = dcy_ptr ? dcy_cur_ptr + cell_size : NULL;
+      LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, true, T, N, input_size, H,
+                                     x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
+                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+    }
+    dy_ptr = dx.dptr_;
+  }
+}
+#endif  // MXNET_OPERATOR_RNN_IMPL_HPP_
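
For reference, the per-timestep recurrence that LstmForwardTrainingSingleLayer and LstmForwardInferenceSingleLayer fuse into a single OpenMP loop (gate order i, f, g, o, matching the ifgo reserve layout) corresponds to the following NumPy sketch; it illustrates the math only and is not code from the patch:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(x_t, h_prev, c_prev, wx, wh, bx, bh):
        # x_t: (N, I), h_prev/c_prev: (N, H), wx: (4H, I), wh: (4H, H), bx/bh: (4H,)
        H = h_prev.shape[1]
        gates = x_t.dot(wx.T) + h_prev.dot(wh.T) + bx + bh   # (N, 4H)
        i = sigmoid(gates[:, 0 * H:1 * H])
        f = sigmoid(gates[:, 1 * H:2 * H])
        g = np.tanh(gates[:, 2 * H:3 * H])
        o = sigmoid(gates[:, 3 * H:4 * H])
        c = f * c_prev + i * g                               # new cell state
        h = o * np.tanh(c)                                   # new hidden state
        return h, c
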
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 2dd66ee2d10..ae0ce459a7b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1239,7 +1239,6 @@ def test_rnn():
     check_rnn_consistency(fused, stack)
     check_rnn_consistency(stack, fused)
 
-
 @with_seed()
 def test_lstm():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='lstm', prefix='')
@@ -1251,7 +1250,6 @@ def test_lstm():
     check_rnn_consistency(fused, stack)
     check_rnn_consistency(stack, fused)
 
-
 @with_seed()
 def test_lstm_forget_bias():
     forget_bias = 2.0
@@ -1273,7 +1271,6 @@ def test_lstm_forget_bias():
     expected_bias = forget_bias * np.ones(10, )
     assert_allclose(args[bias_name].asnumpy(), expected_bias)
 
-
 @with_seed()
 def test_gru():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='')
@@ -1285,7 +1282,6 @@ def test_gru():
     check_rnn_consistency(fused, stack)
     check_rnn_consistency(stack, fused)
 
-
 @with_seed()
 def test_bidirectional():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='',
@@ -1304,7 +1300,6 @@ def test_bidirectional():
     check_rnn_consistency(fused, stack)
     check_rnn_consistency(stack, fused)
 
-
 @with_seed()
 def test_unfuse():
     for mode in ['rnn_tanh', 'rnn_relu', 'lstm', 'gru']:
@@ -1486,7 +1481,6 @@ def test_deformable_convolution_options():
     sym = mx.sym.contrib.DeformableConvolution(num_filter=4, kernel=(3,3), num_deformable_group=2,
                                                name='deformable_conv')
 
-
 @with_seed()
 def test_residual_fused():
     cell = mx.rnn.ResidualCell(
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f22b13d6575..aea071e1044 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,7 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
+@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. 
Tracked at https://github.com/apache/incubator-mxnet/issues/10104";)
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
@@ -272,6 +273,7 @@ def check_rnn_layer_forward(layer, inputs, states=None):
     mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
+@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. 
Tracked at https://github.com/apache/incubator-mxnet/issues/10104";)
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
@@ -370,6 +372,7 @@ def test_cell_fill_shape():
     check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
 
+@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. 
Tracked at https://github.com/apache/incubator-mxnet/issues/10104";)
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 00394496868..1886241b5b9 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -28,6 +28,69 @@
 from common import setup_module, with_seed
 import unittest
 
+def check_rnn_consistency(cell1, cell2, T, N, I, H):
+    dshape = (N, T, I)
+    data = mx.sym.Variable('data')
+
+    Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True)
+    mod1 = mx.mod.Module(Y1, label_names=None, context=mx.cpu())
+    mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+
+    Y2, _ = cell2.unroll(T, data, layout='NTC', merge_outputs=True)
+    mod2 = mx.mod.Module(Y2, label_names=None, context=mx.cpu())
+    mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+
+    mod1.init_params()
+    args, auxs = mod1.get_params()
+    args = cell1.unpack_weights(args)
+    args = cell2.pack_weights(args)
+    mod2.set_params(args, auxs)
+
+    x = mx.random.uniform(shape=dshape)
+    batch=mx.io.DataBatch(data=[x])
+    # check inference
+    mod1.forward(batch, is_train=False)
+    mod2.forward(batch, is_train=False)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    
+    # check training
+    mod1.forward(batch, is_train=True)
+    mod2.forward(batch, is_train=True)
+    assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+
+    dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape)
+    mod1.backward(out_grads=[dy])
+    mod2.backward(out_grads=[dy])
+    assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+
+@with_seed(0)
+def test_lstm():
+    T, N, I, H = 5, 32, 800, 800
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
+    stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
+    stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
+    check_rnn_consistency(fused, stack, T, N, I, H)
+
+@with_seed(0)
+def test_lstm_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
+                                bidirectional=True, get_next_state=True, prefix='')
+
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.LSTMCell(H, prefix='l0_'),
+                mx.rnn.LSTMCell(H, prefix='r0_'),
+                output_prefix='bi_lstm_0_'))
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.LSTMCell(H, prefix='l1_'),
+                mx.rnn.LSTMCell(H, prefix='r1_'),
+                output_prefix='bi_lstm_1_'))
+
+    check_rnn_consistency(stack, fused, T, N, I, H)
+
 
 def np_softmax(x, axis=-1):
     # fix for old numpy on Travis not supporting keepdims


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
