This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new ad1ff3a  [v1.6.x] Cherry-pick MKL-DNN Rnn operator enhancements to v1.6.x (#17225)
ad1ff3a is described below

commit ad1ff3aa8532f2f7e42a732f0d35dfb0574ca05c
Author: Zixuan Wei <zixuan....@intel.com>
AuthorDate: Tue Jan 7 09:50:53 2020 +0800

    [v1.6.x] Cherry-pick MKL-DNN Rnn operator enhancements to v1.6.x (#17225)
    
    * [MKLDNN] mkldnn RNN operator enhancement (#17075)
    
    * mkldnn rnn operator enhancement
    
    `add` operation support
    
    Rename AddTo
    
    Add MXNET_USE_MKLDNN_RNN env
    
    Add Env var for switching to naive RNN impl and naive add/copy impl
    
    * Re-run CI, op:test_reduce failed on Unix-CPU
    
    * Rerun CI, Python2 CPU on Unix-CPU timeout
    
    * MKL-DNN RNN backward path enhancement (#17183)
    
    * Flush memory before RNN backward primitive
    
    * Add gluon rnn unit test for gradients check
    
    * Cache reorder
    
    * Rewrite the RNN support check
    
    * Update OpSignature.AddSign to avoid potential hash collision for
    rnn-packed memory
    
    Get the data type from mkldnn memory descriptor when setting grad handle
---
 docs/static_site/src/pages/api/faq/env_var.md |  12 +-
 src/common/utils.h                            |  20 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h      |   9 +-
 src/operator/nn/mkldnn/mkldnn_rnn-inl.h       |  38 ++-
 src/operator/nn/mkldnn/mkldnn_rnn.cc          | 430 +++++++++++++++-----------
 src/operator/operator_common.h                |  18 ++
 src/operator/rnn.cc                           |  13 +-
 tests/python/unittest/test_gluon_rnn.py       | 124 ++++++++
 tests/python/unittest/test_operator.py        |  21 +-
 9 files changed, 466 insertions(+), 219 deletions(-)

diff --git a/docs/static_site/src/pages/api/faq/env_var.md 
b/docs/static_site/src/pages/api/faq/env_var.md
index e4fe58a..d63da61 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -283,11 +283,11 @@ If ctypes is used, it must be 
`mxnet._ctypes.ndarray.NDArrayBase`.
   If no such algorithm exists given other constraints, MXNet will error out. 
This variable affects the choice
   of CUDNN convolution algorithms. Please see [CUDNN developer 
guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html)
 for more details.
 
-* MXNET_CPU_PARALLEL_COPY_SIZE
+* MXNET_CPU_PARALLEL_SIZE
   - Values: Int ```(default=200000)```
-  - The minimum size to call parallel copy by OpenMP in CPU2CPU mode.
-  - When the array size is bigger than or equal to  this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count.
-  - When the array size is less than this threshold, NDArray::Copy(from , to)) is implemented by memcpy in single thread.
+  - The minimum size to call parallel operations by OpenMP for CPU context.
+  - When the array size is bigger than or equal to this threshold, the operation implemented by OpenMP is executed with the Recommended OMP Thread Count.
+  - When the array size is less than this threshold, the operation is implemented naively in a single thread.
 
 * MXNET_OPTIMIZER_AGGREGATION_SIZE
   - Values: Int ```(default=4)```
@@ -343,6 +343,10 @@ If ctypes is used, it must be 
`mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: 0(false) or 1(true) ```(default=1)```
   - If this variable is set, MXNet will simplify the computation graph, 
eliminating duplicated operations on the same inputs.
 
+* MXNET_USE_MKLDNN_RNN
+  - Values: 0(false) or 1(true) ```(default=1)```
+  - This variable controls whether to use the MKL-DNN backend in the fused RNN operator for CPU context. There are two fused implementations of the RNN operator in MXNet. The MKL-DNN implementation has better performance than the naive one, but the latter is currently more stable in the backward operation.
+
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
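
For context on how the two variables documented above are consumed: MXNET_USE_MKLDNN_RNN defaults to 1 and selects the MKL-DNN fused RNN backend, while MXNET_CPU_PARALLEL_SIZE is the element-count threshold above which the shared parallel copy/add helpers switch to OpenMP. Below is a minimal standalone sketch of the read-with-default plus threshold dispatch, using plain std::getenv and an illustrative GetEnvInt helper in place of dmlc::GetEnv (a sketch, not MXNet code):

```cpp
#include <cstdlib>
#include <iostream>
#include <vector>

// Illustrative stand-in for dmlc::GetEnv(name, default_value): read an
// integer environment variable, falling back to a default when unset.
static int GetEnvInt(const char* name, int default_value) {
  const char* val = std::getenv(name);
  return val ? std::atoi(val) : default_value;
}

int main() {
  // MXNET_USE_MKLDNN_RNN defaults to 1 (use the MKL-DNN fused RNN backend).
  const bool use_mkldnn_rnn = GetEnvInt("MXNET_USE_MKLDNN_RNN", 1) != 0;
  // MXNET_CPU_PARALLEL_SIZE defaults to 200000: arrays at least this large
  // take the OpenMP-parallel copy/add path, smaller ones a naive loop.
  const int parallel_threshold = GetEnvInt("MXNET_CPU_PARALLEL_SIZE", 200000);

  std::vector<float> dst(300000, 1.0f), src(300000, 2.0f);
  if (static_cast<int>(dst.size()) >= parallel_threshold) {
    // In MXNet this branch is an OpenMP `parallel for` loop.
    for (size_t i = 0; i < dst.size(); ++i) dst[i] += src[i];
  } else {
    for (size_t i = 0; i < dst.size(); ++i) dst[i] += src[i];
  }
  std::cout << "use_mkldnn_rnn=" << use_mkldnn_rnn
            << ", threshold=" << parallel_threshold << std::endl;
  return 0;
}
```
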
diff --git a/src/common/utils.h b/src/common/utils.h
index 0e3e354..fcb61b7 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -760,7 +760,7 @@ inline void EmplaceBackZeros(const NDArrayStorageType 
stype, const mxnet::TShape
  */
 template<typename DType>
 inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
-  static index_t copy_block_size = 
dmlc::GetEnv("MXNET_CPU_PARALLEL_COPY_SIZE", 200000);
+  static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 
200000);
   if (size >= copy_block_size) {
     #pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (index_t i = 0; i < size; ++i) {
@@ -772,6 +772,24 @@ inline void ParallelCopy(DType* dst, const DType* src, 
index_t size) {
 }
 
 /*!
+ * \brief parallelize add by OpenMP
+ */
+template<typename DType>
+inline void ParallelAdd(DType* dst, const DType* src, index_t size) {
+  static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 
200000);
+  if (size >= add_block_size) {
+    #pragma omp parallel for 
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+    for (index_t i = 0; i < size; ++i) {
+      dst[i] += src[i];
+    }
+  } else {
+    for (index_t i = 0; i < size; ++i) {
+      dst[i] += src[i];
+    }
+  }
+}
+
+/*!
  * \brief If numpy compatibility is turned off (default), the shapes passed in
  * by users follow the legacy shape definition:
  * 1. 0 ndim means the shape is completely unknown.
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h 
b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 9bfc20c..9763c42 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -132,9 +132,12 @@ static inline bool SupportMKLDNN(int dtype, const 
mxnet::TShape &shape) {
   return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
-static inline bool SupportMKLDNNRNN(const NDArray &input) {
-  int ndim = input.shape().ndim();
-  return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
+static inline bool SupportMKLDNNRnn(const NDArray &input) {
+  if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
+      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+    return true;
+  }
+  return false;
 }
 
 static inline bool SupportMKLDNNQuantize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h 
b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index ad3f733..a4104bf 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -120,25 +120,24 @@ class RnnPrimitive {
   template<typename rnn_fwd, typename... Args>
   static RnnPrimitive Create(Args&&... args) {
     RnnPrimitive rnn_fwd_prim;
-    rnn_fwd_prim.pd_.reset(
-      new typename rnn_fwd::desc(std::forward<Args>(args)...),
-      [](typename rnn_fwd::desc* pd) {
-        delete reinterpret_cast<typename rnn_fwd::desc*>(pd);
+    auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
+    rnn_fwd_prim.fwd_pd_.reset(
+      new typename rnn_fwd::primitive_desc(fwd_desc, 
CpuEngine::Get()->get_engine()),
+      [](typename rnn_fwd::primitive_desc* pd) {
+        delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
       });
-    const typename rnn_fwd::desc& fwd_desc =
-        *(reinterpret_cast<typename rnn_fwd::desc*>(rnn_fwd_prim.pd_.get()));
-    typename rnn_fwd::primitive_desc fwd_pd(fwd_desc, 
CpuEngine::Get()->get_engine());
-    rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc();
-    rnn_fwd_prim.weights_iter_desc_  = fwd_pd.weights_iter_desc();
-    rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc();
+    auto fwd_pd = reinterpret_cast<typename 
rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
+    rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
+    rnn_fwd_prim.weights_iter_desc_  = fwd_pd->weights_iter_desc();
+    rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
 
-    rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new 
rnn_fwd(fwd_pd));
+    rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new 
rnn_fwd(*fwd_pd));
 
     return rnn_fwd_prim;
   }
 
   RnnPrimitive() {
-    this->pd_ = nullptr;
+    this->fwd_pd_ = nullptr;
     this->primitive_ = nullptr;
     this->weights_layer_desc_ = mkldnn::memory::desc();
     this->weights_iter_desc_ = mkldnn::memory::desc();
@@ -146,7 +145,7 @@ class RnnPrimitive {
   }
 
   RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
-    this->pd_ = rnn_fwd_prim.pd_;
+    this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
     this->primitive_ = rnn_fwd_prim.primitive_;
     this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
     this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -155,7 +154,7 @@ class RnnPrimitive {
 
   RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
     if (this != &rnn_fwd_prim) {
-      this->pd_ = rnn_fwd_prim.pd_;
+      this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
       this->primitive_ = rnn_fwd_prim.primitive_;
       this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
       this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -165,7 +164,7 @@ class RnnPrimitive {
     return *this;
   }
 
-  const void* GetPrimDesc() const { return pd_.get(); }
+  const void* GetPrimDesc() const { return fwd_pd_.get(); }
   const mkldnn::primitive& GetPrim() const { return *primitive_; }
 
   const mkldnn::memory::desc& GetLayerDesc() const {
@@ -181,7 +180,7 @@ class RnnPrimitive {
   }
 
  private:
-  std::shared_ptr<void> pd_;
+  std::shared_ptr<void> fwd_pd_;
   std::shared_ptr<mkldnn::primitive> primitive_;
   mkldnn::memory::desc weights_layer_desc_;
   mkldnn::memory::desc weights_iter_desc_;
@@ -370,7 +369,10 @@ class MKLDNNRnnBackward {
   void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell,
                        void* diff_out, void* diff_state_out, void* 
diff_statecell_out,
                        const int dtype = mshadow::kFloat32);
-  void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype 
= mshadow::kFloat32);
+  void SetNativeWeightsGrads() const;
+  void CommitWeightsGrads(void* diff_weights, void* diff_bias,
+                          const OpReqType req,
+                          const int dtype = mshadow::kFloat32);
 
   const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; }
   const mkldnn_args_map_t& GetArgsMap() const { return net_args_; }
@@ -385,6 +387,8 @@ class MKLDNNRnnBackward {
 
   mkldnn_shared_mem_t diff_weights_layer_;
   mkldnn_shared_mem_t diff_weights_iter_;
+  mkldnn_shared_mem_t diff_weights_layer_r_;
+  mkldnn_shared_mem_t diff_weights_iter_r_;
   mkldnn_shared_mem_t diff_bias_;
 
   mkldnn_args_map_t net_args_;
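
The refactoring above stores the forward primitive_desc (fwd_pd_) instead of the raw operation descriptor, but keeps it behind a type-erased std::shared_ptr<void> with a typed deleter, so one RnnPrimitive class can wrap LSTM, GRU and vanilla RNN descriptors alike and GetPrimDesc() can later hand it back to the backward-primitive builder. A minimal standalone sketch of that type-erasure pattern follows (the struct and names are illustrative stand-ins, not oneDNN types):

```cpp
#include <iostream>
#include <memory>

// Stand-in for a concrete forward descriptor type such as
// mkldnn::lstm_forward::primitive_desc; the member is illustrative only.
struct LstmForwardPd {
  int num_gates = 4;
};

int main() {
  // Type-erased ownership: the shared_ptr<void> keeps the object alive and
  // the deleter captured at construction restores the concrete type.
  std::shared_ptr<void> fwd_pd(
      new LstmForwardPd(),
      [](void* p) { delete static_cast<LstmForwardPd*>(p); });

  // Later (e.g. when creating the backward primitive), the opaque pointer is
  // cast back to the concrete primitive_desc type before use.
  const auto* pd = static_cast<const LstmForwardPd*>(fwd_pd.get());
  std::cout << "num_gates=" << pd->num_gates << std::endl;
  return 0;
}
```
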
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc 
b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index e797b64..8af0e99 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -213,13 +213,13 @@ RnnBwdPrimitive GetRnnBwdPrim(const 
MKLDNNRnnForwardTraining &fwd,
   auto dst_state_desc = layer_param.state_outputs ? memory::desc(
       layer_param.state_dims, data_type, tag::ldnc) : memory::desc();
 
-  const void* fwd_desc = fwd.GetPrimDesc();
+  const void* fwd_pd = fwd.GetPrimDesc();
   auto bwd = RnnBwdPrimitive();
   switch (mode) {
     case rnn_enum::kLstm: {
-      const lstm_forward::primitive_desc* desc =
-          reinterpret_cast<const lstm_forward::primitive_desc*>(fwd_desc);
-      bwd = RnnBwdPrimitive::Create<lstm_forward, lstm_backward>(*desc,
+      const lstm_forward::primitive_desc* pd =
+          reinterpret_cast<const lstm_forward::primitive_desc*>(fwd_pd);
+      bwd = RnnBwdPrimitive::Create<lstm_forward, lstm_backward>(*pd,
           prop, mkldnn_rnn_direction,
           // data desc
           src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc,
@@ -231,9 +231,9 @@ RnnBwdPrimitive GetRnnBwdPrim(const 
MKLDNNRnnForwardTraining &fwd,
           dst_state_desc);
     } break;
     case rnn_enum::kGru: {
-      const lbr_gru_forward::primitive_desc* desc =
-          reinterpret_cast<const lbr_gru_forward::primitive_desc*>(fwd_desc);
-      bwd = RnnBwdPrimitive::Create<lbr_gru_forward, lbr_gru_backward>(*desc,
+      const lbr_gru_forward::primitive_desc* pd =
+          reinterpret_cast<const lbr_gru_forward::primitive_desc*>(fwd_pd);
+      bwd = RnnBwdPrimitive::Create<lbr_gru_forward, lbr_gru_backward>(*pd,
           prop, mkldnn_rnn_direction,
           // data desc
           src_layer_desc, src_state_desc, weight_layer_desc,
@@ -244,10 +244,10 @@ RnnBwdPrimitive GetRnnBwdPrim(const 
MKLDNNRnnForwardTraining &fwd,
     } break;
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh: {
-      const vanilla_rnn_forward::primitive_desc* desc =
-          reinterpret_cast<const 
vanilla_rnn_forward::primitive_desc*>(fwd_desc);
+      const vanilla_rnn_forward::primitive_desc* pd =
+          reinterpret_cast<const vanilla_rnn_forward::primitive_desc*>(fwd_pd);
       bwd = RnnBwdPrimitive::Create<vanilla_rnn_forward, vanilla_rnn_backward>(
-          *desc, prop,
+          *pd, prop,
           mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : 
algorithm::eltwise_relu,
           mkldnn_rnn_direction,
           // data desc
@@ -364,18 +364,38 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, 
void* cx,
   }
 }
 
+inline void MKLDNNMemoryReorder(const mkldnn::memory& src,
+                                const mkldnn::memory& dst) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature,
+      mkldnn::reorder, OpHash> reorderPrimitives;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+      mkldnn::reorder, OpHash> reorderPrimitives;
+#endif
+  OpSignature key{};
+  key.AddSign(src);
+  key.AddSign(dst);
+
+  auto it = reorderPrimitives.find(key);
+  if (it == reorderPrimitives.end()) {
+    auto reorder = mkldnn::reorder(src, dst);
+    it = AddToCache(&reorderPrimitives, key, reorder);
+  }
+
+  mkldnn_args_map_t net_args;
+  net_args.emplace(MKLDNN_ARG_SRC, src);
+  net_args.emplace(MKLDNN_ARG_DST, dst);
+  MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
+}
+
 /*
  * Reorder the concatenated weights memory to an efficient memory block
  * with primitive-preferred format.
  */
 void MKLDNNRnnForward::ReorderWeights() {
-  auto& cpu_engine = CpuEngine::Get()->get_engine();
-  mkldnn::stream s(cpu_engine);
-  mkldnn::reorder(*weights_layer_r_, *weights_layer_)
-      .execute(s, *weights_layer_r_, *weights_layer_);
-  mkldnn::reorder(*weights_iter_r_, *weights_iter_)
-      .execute(s, *weights_iter_r_, *weights_iter_);
-  s.wait();
+  MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
+  MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
 }
 
 void AdjustGruGateOrder(char* weight,
@@ -394,7 +414,7 @@ void AdjustGruGateOrder(char* weight,
  * Fuse uni-directional bias among single layer.
  */
 template <typename DType>
-void FuseBias(DType* fuse_bias, DType* naive_bias,
+void FuseBias(DType* fuse_bias, DType* native_bias,
               const int mode, const size_t state_size) {
   const size_t ngates = GetRnnGatesNum(mode);
   const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
@@ -403,8 +423,8 @@ void FuseBias(DType* fuse_bias, DType* naive_bias,
   // OpenMP 'for' statement.
   const int state_size_ = static_cast<int>(state_size);
   const int single_b_sz = static_cast<int>(nbias * state_size);
-  DType* bx = naive_bias;
-  DType* bh = naive_bias + state_size * ngates;
+  DType* bx = native_bias;
+  DType* bh = native_bias + state_size * ngates;
   if (mode == rnn_enum::kGru) {
     // While mxnet gru gate order is reset, update and new gates,
     // mkldnn gru gate order is update, reset and new gates. So
@@ -528,12 +548,6 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, 
void *w_ptr, void *b_
       }
     }
   }
-  // Reorder after adjustment only when is_train == false. When is_train == 
true, i.e.
-  // in forward training path, we use plain memory (ldxxx) as the space for 
weights and
-  // their gradients. Then, forward training primitives could fetch them from 
the scope
-  // of forward inference. And from there, we don't need to reorder the plain 
memory to
-  // the optimal rnn-packed memory for forward inference.
-  ReorderWeights();
 
   // Process bias
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -553,7 +567,15 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, 
void *w_ptr, void *b_
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER,  
this->weights_iter_);
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS,          this->bias_);
 
-  initialized_ = true;
+  if (!is_train) {
+    // Reorder after adjustment only when is_train == false. When is_train == 
true, i.e.
+    // in forward training path, we use plain memory (ldxxx) as the space for 
weights and
+    // their gradients. Then, forward training primitives could fetch them 
from the scope
+    // of forward inference. And from there, we don't need to reorder the 
plain memory to
+    // the optimal rnn-packed memory for forward inference.
+    ReorderWeights();
+    initialized_ = true;
+  }
 }
 
 void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) {
@@ -572,17 +594,14 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const 
MKLDNNRnnForward& fwd) {
   if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) {
     weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle());
   } else {
-    mkldnn::reorder(*fwd.weights_layer_r_, *weights_layer_)
-        .execute(s, *fwd.weights_layer_r_, *weights_layer_);
+    MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_);
   }
 
   if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) {
     weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle());
   } else {
-    mkldnn::reorder(*fwd.weights_iter_r_, *weights_iter_)
-        .execute(s, *fwd.weights_iter_r_, *weights_iter_);
+    MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_);
   }
-  s.wait();
 
   // bias are always in format_tag::ldgo
   this->bias_ = fwd.bias_;
@@ -687,18 +706,17 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
         {fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype), 
format_tag::tnc}));
   }
 
-  initialized_ = true;
+  if (!is_training) initialized_ = true;
 }
 
 void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining& 
fwd) {
   using memory = mkldnn::memory;
   auto& cpu_engine = CpuEngine::Get()->get_engine();
-  auto s = mkldnn::stream(cpu_engine);
 
-  if (this->weights_layer_ == nullptr)
+  if (this->weights_layer_ == nullptr || this-> weights_iter_ == nullptr) {
     this->weights_layer_ = mkldnn_shared_mem_t(new 
memory(bwd_.weights_layer_desc_, cpu_engine));
-  if (this->weights_iter_ == nullptr)
     this->weights_iter_ = mkldnn_shared_mem_t(new 
memory(bwd_.weights_iter_desc_, cpu_engine));
+  }
 
   for (auto& kv : fwd.net_args_) {
     const mkldnn::memory* valid_mem;
@@ -707,17 +725,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const 
MKLDNNRnnForwardTraining& fwd)
         if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
           this->weights_layer_->set_data_handle(kv.second.get_data_handle());
         } else {
-          mkldnn::reorder(*fwd.weights_layer_, *this->weights_layer_)
-              .execute(s, *fwd.weights_layer_, *this->weights_layer_);
+          MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_);
         }
         valid_mem = this->weights_layer_.get();
       } break;
       case MKLDNN_ARG_WEIGHTS_ITER: {
-        if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
+        if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) {
           this->weights_iter_->set_data_handle(kv.second.get_data_handle());
         } else {
-          mkldnn::reorder(*fwd.weights_iter_, *this->weights_iter_)
-              .execute(s, *fwd.weights_iter_, *this->weights_iter_);
+          MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_);
         }
         valid_mem = this->weights_iter_.get();
       } break;
@@ -727,20 +743,50 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const 
MKLDNNRnnForwardTraining& fwd)
     }
     EmplaceNetArgs(&this->net_args_, kv.first, valid_mem);
   }
-  s.wait();
 }
 
 void MKLDNNRnnBackward::SetWeightsGradsMem() {
-  auto& cpu_engine = CpuEngine::Get()->get_engine();
-  if (this->diff_weights_layer_ == nullptr)
-    this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
-        bwd_.diff_weights_layer_desc_, cpu_engine);
-  if (this->diff_weights_iter_ == nullptr)
-    this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
-        bwd_.diff_weights_iter_desc_, cpu_engine);
-  if (this->diff_bias_ == nullptr)
+  using tag = mkldnn::memory::format_tag;
+
+  if (this->diff_weights_layer_ == nullptr
+      || this->diff_weights_iter_ == nullptr
+      || this->diff_bias_ == nullptr) {
+    const auto& cpu_engine = CpuEngine::Get()->get_engine();
+    const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+    const auto mkldnn_type = static_cast<mkldnn::memory::data_type>(
+        bwd_.diff_weights_layer_desc_.data.data_type);
+
+    auto native_layer_desc = mkldnn::memory::desc(param.weight_layer_dims, 
mkldnn_type, tag::ldgoi);
+    auto native_iter_desc = mkldnn::memory::desc(param.weight_iter_dims, 
mkldnn_type, tag::ldgoi);
+
+    this->diff_weights_layer_r_ = std::make_shared<mkldnn::memory>(
+        native_layer_desc, cpu_engine);
+    this->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
+        native_iter_desc, cpu_engine);
+
+    if (native_layer_desc == bwd_.diff_weights_layer_desc_) {
+      this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_layer_desc_, cpu_engine, 
diff_weights_layer_r_->get_data_handle());
+    } else {
+      this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_layer_desc_, cpu_engine);
+    }
+    if (native_iter_desc == bwd_.diff_weights_iter_desc_) {
+      this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_iter_desc_, cpu_engine, 
diff_weights_iter_r_->get_data_handle());
+    } else {
+      this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+          bwd_.diff_weights_iter_desc_, cpu_engine);
+    }
     this->diff_bias_ = std::make_shared<mkldnn::memory>(
         bwd_.diff_bias_desc_, cpu_engine);
+  }
+  std::memset(this->diff_weights_layer_->get_data_handle(), 0,
+      bwd_.diff_weights_layer_desc_.get_size());
+  std::memset(this->diff_weights_iter_->get_data_handle(), 0,
+      bwd_.diff_weights_iter_desc_.get_size());
+  std::memset(this->diff_bias_->get_data_handle(), 0,
+      bwd_.diff_bias_desc_.get_size());
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
       this->diff_weights_layer_.get());
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_ITER,
@@ -776,30 +822,40 @@ void MKLDNNRnnBackward::SetDataGradsMem(
   }
 }
 
-template <typename DType>
-void HalveWeightsDiff(DType* w, const size_t size) {
-  const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-  #pragma omp parallel for num_threads(omp_threads)
-  for (int i = 0; i < static_cast<int>(size); ++i) {
-    w[i] *= 0.5;
+void MKLDNNRnnBackward::SetNativeWeightsGrads() const {
+  if (this->diff_weights_layer_->get_desc() != 
this->diff_weights_layer_r_->get_desc()) {
+    MKLDNNMemoryReorder(*this->diff_weights_layer_, 
*this->diff_weights_layer_r_);
+  }
+  if (this->diff_weights_iter_->get_desc() != 
this->diff_weights_iter_r_->get_desc()) {
+    MKLDNNMemoryReorder(*this->diff_weights_iter_, 
*this->diff_weights_iter_r_);
   }
 }
 
-void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias, 
int dtype) {
-  using tag = mkldnn::memory::format_tag;
-  auto& cpu_engine = CpuEngine::Get()->get_engine();
-  auto s = mkldnn::stream(cpu_engine);
+#define OPREQTYPE_SWITCH(ReqType, DType, FWrapper, ...)                     \
+std::function<void(DType*, DType*, size_t)> FWrapper = nullptr;             \
+if (kWriteTo == ReqType || kWriteInplace == ReqType)                        \
+  FWrapper = common::ParallelCopy<DType>;                                   \
+else                                                                        \
+  FWrapper = common::ParallelAdd<DType>;                                    \
+{__VA_ARGS__}
 
+void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
+                                           const OpReqType req, const int 
dtype) {
   const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+
+  void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle();
+  void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle();
+  if (this->diff_weights_layer_->get_desc() != 
this->diff_weights_layer_r_->get_desc())
+    diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle();
+  if (this->diff_weights_iter_->get_desc() != 
this->diff_weights_iter_r_->get_desc())
+    diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle();
+
   const int num_layer = param.num_layer;
   const int direction = param.bidirectional ? 2 : 1;
   const int ngates = GetRnnGatesNum(param.mode);
-  const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
-  const size_t wxh_bytes = param.single_w_size * dtype_bytes;
-  const size_t wx_bytes = param.input_size * param.state_size * ngates * 
dtype_bytes;
-  const size_t wh_bytes = param.state_size * param.state_size * ngates * 
dtype_bytes;
-  char* diff_wx_ptr = static_cast<char 
*>(diff_weights_layer_->get_data_handle());
-  char* diff_wh_ptr = static_cast<char 
*>(diff_weights_iter_->get_data_handle());
+  const size_t wxh_size = param.single_w_size;
+  const size_t wx_size = param.input_size * param.state_size * ngates;
+  const size_t wh_size = param.state_size * param.state_size * ngates;
 
   /* naive weights layout is:
           1st-layer: | wx_lr  | wh_lr  | wx_rl | wh_rl |
@@ -807,82 +863,81 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void* 
diff_weights, void* diff_bias, i
   size:              |    wxh_bytes    |
                      |wx_bytes|wh_bytes|      
   */
-  char* naive_weights = static_cast<char *>(diff_weights);
-  if (param.mode != rnn_enum::kGru) {
-    for (int shift = 0; shift < num_layer * direction; ++shift) {
-      std::memcpy(naive_weights + shift * wxh_bytes,
-          diff_wx_ptr + shift * wx_bytes, wx_bytes);
-    }
-    // align naive_weights to weights_iter memory
-    naive_weights += wx_bytes;
-    for (int shift = 0; shift < num_layer * direction; ++shift) {
-      std::memcpy(naive_weights + shift * wxh_bytes,
-          diff_wh_ptr + shift * wh_bytes, wh_bytes);
-    }
-  } else {
-    const size_t wx_bytes_per_gate = param.input_size * param.state_size * 
dtype_bytes;
-    const size_t wh_bytes_per_gate = param.state_size * param.state_size * 
dtype_bytes;
-    for (int shift = 0; shift < num_layer * direction; ++shift) {
-      std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate,
-          diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate);
-      std::memcpy(naive_weights + shift * wxh_bytes,
-          diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate, 
wx_bytes_per_gate);
-      std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate,
-          diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate, 
wx_bytes_per_gate);
-    }
-    // align naive_weights to weights_iter memory
-    naive_weights += wx_bytes;
-    for (int shift = 0; shift < num_layer * direction; ++shift) {
-      std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate,
-          diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate);
-      std::memcpy(naive_weights + shift * wxh_bytes,
-          diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate, 
wh_bytes_per_gate);
-      std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate,
-          diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate, 
wh_bytes_per_gate);
-    }
-  }
-
-  char* naive_bias = static_cast<char *>(diff_bias);
-  char* diff_bias_ptr = static_cast<char 
*>(this->diff_bias_->get_data_handle());
-  const size_t bias_bytes = param.single_b_size * dtype_bytes;
-  const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes;
-  if (param.mode != rnn_enum::kGru) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      DType* typed_bias = reinterpret_cast<DType *>(diff_bias_ptr);
-      HalveWeightsDiff(typed_bias, num_layer * direction * 
param.single_b_size);
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    DType* native_weights = static_cast<DType *>(diff_weights);
+    DType* diff_wx_ptr = static_cast<DType *>(diff_weights_layer_ptr);
+    DType* diff_wh_ptr = static_cast<DType *>(diff_weights_iter_ptr);
+    OPREQTYPE_SWITCH(req, DType, FAccGrad, {
+      if (param.mode != rnn_enum::kGru) {
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          FAccGrad(native_weights + shift * wxh_size, diff_wx_ptr + shift * 
wx_size, wx_size);
+        }
+        // align native_weights to weights_iter memory
+        native_weights += wx_size;
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          FAccGrad(native_weights + shift * wxh_size, diff_wh_ptr + shift * 
wh_size, wh_size);
+        }
+      } else {
+        const size_t wx_size_per_gate = param.input_size * param.state_size;
+        const size_t wh_size_per_gate = param.state_size * param.state_size;
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          FAccGrad(native_weights + shift * wxh_size + wx_size_per_gate,
+              diff_wx_ptr + shift * wx_size, wx_size_per_gate);
+          FAccGrad(native_weights + shift * wxh_size,
+              diff_wx_ptr + shift * wx_size + wx_size_per_gate, 
wx_size_per_gate);
+          FAccGrad(native_weights + shift * wxh_size + 2 * wx_size_per_gate,
+              diff_wx_ptr + shift * wx_size + 2 * wx_size_per_gate, 
wx_size_per_gate);
+        }
+        // align native_weights to weights_iter memory
+        native_weights += wx_size;
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          FAccGrad(native_weights + shift * wxh_size + wh_size_per_gate,
+              diff_wh_ptr + shift * wh_size, wh_size_per_gate);
+          FAccGrad(native_weights + shift * wxh_size,
+              diff_wh_ptr + shift * wh_size + wh_size_per_gate, 
wh_size_per_gate);
+          FAccGrad(native_weights + shift * wxh_size + 2 * wh_size_per_gate,
+              diff_wh_ptr + shift * wh_size + 2 * wh_size_per_gate, 
wh_size_per_gate);
+        }
+      }
     });
-    for (int shift = 0; shift < num_layer * direction; ++shift) {
-      std::memcpy(naive_bias + shift * naive_bias_bytes,
-          diff_bias_ptr + shift * bias_bytes, bias_bytes);
-      std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes,
-          diff_bias_ptr + shift * bias_bytes, bias_bytes);
-    }
-  } else {
-    const size_t bias_bytes_per_gate = param.state_size * dtype_bytes;
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      for (int shift = 0; shift < num_layer * direction; ++shift) {
-        char* naive_reset = naive_bias + shift * naive_bias_bytes;
-        char* naive_update = naive_reset + bias_bytes_per_gate;
-        char* update = diff_bias_ptr + shift * bias_bytes;
-        char* reset = update + bias_bytes_per_gate;
-
-        DType* typed_update = reinterpret_cast<DType *>(update);
-        HalveWeightsDiff(typed_update, param.state_size * 2);
-
-        std::memcpy(naive_update, update, bias_bytes_per_gate);
-        std::memcpy(naive_reset, reset, bias_bytes_per_gate);
-        std::memcpy(naive_update + naive_bias_bytes / 2, update, 
bias_bytes_per_gate);
-        std::memcpy(naive_reset + naive_bias_bytes / 2, reset, 
bias_bytes_per_gate);
-
-        char* naive_new_bx = naive_update + bias_bytes_per_gate;
-        char* naive_new_bh = naive_new_bx + naive_bias_bytes / 2;
-        char* new_bx = reset + bias_bytes_per_gate;
-        char* new_bh = new_bx + bias_bytes_per_gate;
-        std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate);
-        std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate);
+  });
+
+  const size_t bias_size = param.single_b_size;
+  const size_t naive_bias_size = param.naive_single_b_size;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    DType* native_bias = static_cast<DType *>(diff_bias);
+    DType* diff_bias_ptr = static_cast<DType 
*>(this->diff_bias_->get_data_handle());
+    OPREQTYPE_SWITCH(req, DType, FAccGrad, {
+      if (param.mode != rnn_enum::kGru) {
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          FAccGrad(native_bias + shift * naive_bias_size,
+              diff_bias_ptr + shift * bias_size, bias_size);
+          FAccGrad(native_bias + shift * naive_bias_size + bias_size,
+              diff_bias_ptr + shift * bias_size, bias_size);
+        }
+      } else {
+        const size_t bias_size_per_gate = param.state_size;
+        for (int shift = 0; shift < num_layer * direction; ++shift) {
+          DType* native_reset = native_bias + shift * naive_bias_size;
+          DType* native_update = native_reset + bias_size_per_gate;
+          DType* update = diff_bias_ptr + shift * bias_size;
+          DType* reset = update + bias_size_per_gate;
+
+          FAccGrad(native_update, update, bias_size_per_gate);
+          FAccGrad(native_reset, reset, bias_size_per_gate);
+          FAccGrad(native_update + naive_bias_size / 2, update, 
bias_size_per_gate);
+          FAccGrad(native_reset + naive_bias_size / 2, reset, 
bias_size_per_gate);
+
+          DType* native_new_bx = native_update + bias_size_per_gate;
+          DType* native_new_bh = native_new_bx + naive_bias_size / 2;
+          DType* new_bx = reset + bias_size_per_gate;
+          DType* new_bh = new_bx + bias_size_per_gate;
+          FAccGrad(native_new_bx, new_bx, bias_size_per_gate);
+          FAccGrad(native_new_bh, new_bh, bias_size_per_gate);
+        }
       }
     });
-  }
+  });
 }
 
 template <typename MKLDNNRnnX>
@@ -893,25 +948,18 @@ inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) {
 template <>
 inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) {
   MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap());
+  rnn.SetNativeWeightsGrads();
 }
 
 void MKLDNNRnnOp::Forward(const OpContext &ctx,
                           const std::vector<NDArray> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &outputs) {
+  TmpMemMgr::Get()->Init(ctx.requested[0]);
   // In the `autograd.record()` context, RNNOp is required to run into
   // forward_training mode.
   const bool is_training = (ctx.is_train || ctx.need_grad);
-  // check output requests
-  if (kAddTo == req[rnn_enum::kOut])
-    LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
   const RNNParam& default_param = full_param_.default_param;
-  if (default_param.state_outputs) {
-    if (kAddTo == req[rnn_enum::kStateOut])
-      LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
-    if (default_param.mode == rnn_enum::kLstm && kAddTo == 
req[rnn_enum::kStateCellOut])
-      LOG(FATAL) << "Currently, `add` operation against lstm-cell output is 
not supported.";
-  }
 
   // Initialize weights version
   if (!initialized_ && weights_version_ == 0) {
@@ -919,8 +967,8 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   }
 
   // Check if weights NDArray was changed. If so, reset initialized_
-  if (weights_version_ != inputs[rnn_enum::kParams].version() &&
-      fwd_inf_vec_.size() > 0) {
+  if (!is_training && fwd_inf_vec_.size() > 0
+      && weights_version_ != inputs[rnn_enum::kParams].version()) {
     initialized_ = false;
     for (auto& fwd : fwd_inf_vec_) fwd.Reset();
     weights_version_ = inputs[rnn_enum::kParams].version();
@@ -932,24 +980,40 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
 
   // Get data type
   int data_dtype = inputs[rnn_enum::kData].dtype();
+  // Get temporary memory for output, state_out, statecell_out
+  const int num_layers = default_param.num_layers;
+  const int seq_length = default_param.seq_length_;
+  const int batch_size = default_param.batch_size_;
+  const int state_size = default_param.state_size;
+  const int directions = default_param.bidirectional ? 2 : 1;
+  mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * 
state_size},
+      get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc);
+  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, 
state_size},
+      get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
+  auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, 
req[rnn_enum::kOut]);
+  mkldnn_output_t stateout_mem;
+  mkldnn_output_t statecellout_mem;
 
   // Get input & output NDArray
   char *src = static_cast<char *>(inputs[rnn_enum::kData].data().dptr_);
   char *src_state = static_cast<char *>(inputs[rnn_enum::kState].data().dptr_);
-  char *dst = req[rnn_enum::kOut] == kNullOp ? nullptr
-      : static_cast<char *>(outputs[rnn_enum::kOut].data().dptr_);
+  char *dst = static_cast<char *>(out_mem.second->get_data_handle());
   char *dst_state = nullptr;          // Output state
   char *src_state_cell = nullptr;     // Used in LSTM for cell state
   char *dst_state_cell = nullptr;     // Used in LSTM for cell state
 
   if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) {
-    dst_state = static_cast<char *>(outputs[rnn_enum::kStateOut].data().dptr_);
+    stateout_mem = CreateMKLDNNMem(
+        outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]);
+    dst_state = static_cast<char *>(stateout_mem.second->get_data_handle());
   }
 
   if (default_param.mode == rnn_enum::kLstm) {
     src_state_cell = static_cast<char 
*>(inputs[rnn_enum::kStateCell].data().dptr_);
     if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != 
kNullOp) {
-      dst_state_cell = static_cast<char 
*>(outputs[rnn_enum::kStateCellOut].data().dptr_);
+      statecellout_mem = CreateMKLDNNMem(
+          outputs[rnn_enum::kStateCellOut], state_desc, 
req[rnn_enum::kStateCellOut]);
+      dst_state_cell = static_cast<char 
*>(statecellout_mem.second->get_data_handle());
     }
   }
 
@@ -1000,6 +1064,12 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   } else {
     for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr);
   }
+  CommitOutput(outputs[rnn_enum::kOut], out_mem);
+  if (default_param.state_outputs) {
+    CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem);
+    if (default_param.mode == rnn_enum::kLstm)
+      CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem);
+  }
   MKLDNNStream::Get()->Submit();
 }
 
@@ -1008,18 +1078,11 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
                            const std::vector<OpReqType>& req,
                            const std::vector<NDArray>& outputs) {
   using tag = mkldnn::memory::format_tag;
+  TmpMemMgr::Get()->Init(ctx.requested[0]);
   const RNNParam& default_param = full_param_.default_param;
-  if (kAddTo == req[rnn_enum::kData] || kAddTo == req[rnn_enum::kParams])
-    LOG(FATAL) << "Currently, `add` operations against gradients of input and 
weights"
-               << " are not supported by RNNs.";
-  if (default_param.state_outputs) {
-    if (kAddTo == req[rnn_enum::kStateOut])
-      LOG(FATAL) << "Currently, `add` operation against gradients of begining 
state"
-                 << " is not supported by RNNs.";
-    if (default_param.mode == rnn_enum::kLstm && req[rnn_enum::kStateCell])
-      LOG(FATAL) << "Currently, `add` operation against gradients of begining 
cell-state"
-                 << " is not supported by LSTM.";
-  }
+  const int data_dtype = inputs[rnn_enum::kData].dtype();
+  const int w_dtype = inputs[rnn_enum::kParams].dtype();
+
   // Initialize the bwd_vec_
   if (bwd_vec_.size() != fwd_inf_vec_.size()) {
     bwd_vec_.clear();
@@ -1035,24 +1098,39 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
     bwd_vec_.at(lyr).SetWeightsGradsMem();
   }
 
-  const int data_dtype = inputs[rnn_enum::kData].dtype();
-  const int w_dtype = inputs[rnn_enum::kParams].dtype();
   const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype);
+  // Get temporary memory for diff_src, diff_state, diff_statecell
+  const int num_layers = default_param.num_layers;
+  const int seq_length = default_param.seq_length_;
+  const int batch_size = default_param.batch_size_;
+  const int input_size = default_param.input_size_;
+  const int state_size = default_param.state_size;
+  const int directions = default_param.bidirectional ? 2 : 1;
+  mkldnn::memory::desc src_desc({seq_length, batch_size, input_size},
+      get_mkldnn_type(data_dtype), tag::tnc);
+  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, 
state_size},
+      get_mkldnn_type(data_dtype), tag::ldnc);
+  auto diff_input_mem = CreateMKLDNNMem(outputs[rnn_enum::kData], src_desc, 
req[rnn_enum::kData]);
+  mkldnn_output_t diff_state_mem;
+  mkldnn_output_t diff_statecell_mem;
   // index description of outputs NDArray
   //   0    1    2     3
   // | dx | dw | dhx | dcx|
-  char* dx = req[rnn_enum::kData] == kNullOp ? nullptr
-      : static_cast<char *>(outputs[rnn_enum::kData].data().dptr_);
+  char* dx = static_cast<char *>(diff_input_mem.second->get_data_handle());
   char* dw = static_cast<char *>(outputs[rnn_enum::kParams].data().dptr_);
   char* db = dw + (inputs[rnn_enum::kParams].data().Size() -
       GetRnnBiasSize(default_param.num_layers, default_param.state_size,
         default_param.bidirectional + 1, default_param.mode)) * w_bytes;
-  char* dhx = req[rnn_enum::kState] == kNullOp ? nullptr
-      : static_cast<char *>(outputs[rnn_enum::kState].data().dptr_);
+  diff_state_mem = CreateMKLDNNMem(
+      outputs[rnn_enum::kState], state_desc, req[rnn_enum::kState]);
+  char* dhx = static_cast<char *>(diff_state_mem.second->get_data_handle());
   char* dcx = nullptr;
   if (full_param_.default_param.mode == rnn_enum::kLstm
-      && req[rnn_enum::kStateCell] != kNullOp)
-    dcx = static_cast<char *>(outputs[rnn_enum::kStateCell].data().dptr_);
+      && req[rnn_enum::kStateCell] != kNullOp) {
+    diff_statecell_mem = CreateMKLDNNMem(
+        outputs[rnn_enum::kStateCell], state_desc, req[rnn_enum::kStateCell]);
+    dcx = static_cast<char *>(diff_statecell_mem.second->get_data_handle());
+  }
 
   // index description of inputs NDArray
   //   0   1   2    3   4    5     6    7    8     9
@@ -1100,12 +1178,16 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
       RegisterMKLDNNRnn(*bwd);
     }
   }
+  CommitOutput(outputs[rnn_enum::kData], diff_input_mem);
+  CommitOutput(outputs[rnn_enum::kState], diff_state_mem);
+  if (full_param_.default_param.mode == rnn_enum::kLstm)
+    CommitOutput(outputs[rnn_enum::kStateCell], diff_statecell_mem);
   MKLDNNStream::Get()->Submit();
 
   // Commit weights diff
   if (req[rnn_enum::kParams] != kNullOp) {
     for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) {
-      bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, w_dtype);
+      bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams], 
w_dtype);
       dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes;
       db += full_param_.layer_params.at(lyr).single_b_size * w_bytes;
     }
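
One detail of the file above worth spelling out: the OPREQTYPE_SWITCH macro picks the gradient-accumulation routine from the output request, so kWriteTo/kWriteInplace resolve to common::ParallelCopy and kAddTo resolves to the new common::ParallelAdd, which is what lets CommitWeightsGrads honor `add` requests. A minimal standalone sketch of that request-based dispatch, with simplified stand-ins rather than MXNet's actual types:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Simplified analogues of common::ParallelCopy / common::ParallelAdd.
template <typename DType>
void Copy(DType* dst, const DType* src, size_t size) {
  for (size_t i = 0; i < size; ++i) dst[i] = src[i];
}

template <typename DType>
void Add(DType* dst, const DType* src, size_t size) {
  for (size_t i = 0; i < size; ++i) dst[i] += src[i];
}

int main() {
  std::vector<float> grad(4, 1.0f), part(4, 0.5f);
  const OpReqType req = kAddTo;

  // Same selection the macro performs: write requests copy, add accumulates.
  std::function<void(float*, const float*, size_t)> acc_grad =
      (req == kWriteTo || req == kWriteInplace) ? Copy<float> : Add<float>;

  acc_grad(grad.data(), part.data(), grad.size());
  std::cout << grad[0] << std::endl;  // prints 1.5 because req == kAddTo
  return 0;
}
```
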
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index c23a5a8..c3cb5c8 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -561,6 +561,24 @@ class OpSignature {
       case mkldnn_format_kind_rnn_packed:
         hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format;
         eles.push_back(desc.data.format_desc.rnn_packed_desc.format);
+        hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n_parts;
+        eles.push_back(desc.data.format_desc.rnn_packed_desc.n_parts);
+        for (int i = 0; i < desc.data.format_desc.rnn_packed_desc.n_parts; 
++i) {
+          hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.parts[i];
+          hash = hash * 2 + 
desc.data.format_desc.rnn_packed_desc.part_pack_size[i];
+          hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.pack_part[i];
+          eles.push_back(desc.data.format_desc.rnn_packed_desc.parts[i]);
+          
eles.push_back(desc.data.format_desc.rnn_packed_desc.part_pack_size[i]);
+          eles.push_back(desc.data.format_desc.rnn_packed_desc.pack_part[i]);
+        }
+        hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n;
+        hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.ldb;
+        hash = hash * 2 + 
desc.data.format_desc.rnn_packed_desc.offset_compensation;
+        hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.size;
+        eles.push_back(desc.data.format_desc.rnn_packed_desc.n);
+        eles.push_back(desc.data.format_desc.rnn_packed_desc.ldb);
+        
eles.push_back(desc.data.format_desc.rnn_packed_desc.offset_compensation);
+        eles.push_back(desc.data.format_desc.rnn_packed_desc.size);
         break;
       default:
       // nothing need to add
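
The extra fields folded into the signature above keep two rnn-packed memory descriptors that differ only in packing details (n_parts, per-part sizes, ldb, offset compensation, total size) from colliding on the same cached-reorder key. A simplified standalone sketch of the fold-and-compare idea (the struct and its fields are illustrative, not oneDNN's real descriptor):

```cpp
#include <iostream>
#include <vector>

// Illustrative subset of an rnn-packed memory descriptor.
struct PackedDesc {
  int format;
  int n_parts;
  int n;
  int ldb;
};

// Fold each field into a running hash the way OpSignature::AddSign does
// (hash = hash * 2 + field) and keep the raw fields for exact comparison.
void AddSign(const PackedDesc& d, size_t* hash, std::vector<int>* eles) {
  for (int field : {d.format, d.n_parts, d.n, d.ldb}) {
    *hash = *hash * 2 + field;
    eles->push_back(field);
  }
}

int main() {
  PackedDesc a{1, 2, 64, 64};
  PackedDesc b{1, 3, 64, 64};  // differs only in n_parts
  size_t ha = 0, hb = 0;
  std::vector<int> ea, eb;
  AddSign(a, &ha, &ea);
  AddSign(b, &hb, &eb);
  // Without folding n_parts in, a and b would produce identical signatures.
  std::cout << (ea == eb ? "collide" : "distinct") << std::endl;  // distinct
  return 0;
}
```
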
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 6d568c8..a8e1b12 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -182,6 +182,10 @@ static std::vector<ResourceRequest> RNNResourceEx(const 
NodeAttrs& attrs, const
       request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
     }
 #endif
+  } else {
+#if MXNET_USE_MKLDNN == 1
+    request.emplace_back(ResourceRequest::kTempSpace);
+#endif
   }
   return request;
 }
@@ -243,7 +247,8 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs 
&attrs,
 
 #if MXNET_USE_MKLDNN == 1
   if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
-      && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU) {
+      && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
+      && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
     const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
     state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
         data_shape[1], data_shape[2]);
@@ -269,8 +274,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& 
state_ptr,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
-  if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == 
mshadow::kFloat16) &&
-      inputs[0].shape().ndim() == 3) {
+  if (SupportMKLDNNRnn(inputs[0])) {
     MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
     op.Forward(ctx, inputs, req, outputs);
   } else {
@@ -283,8 +287,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& 
state_ptr,
                                         const std::vector<NDArray>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<NDArray>& outputs) {
-  if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() == 
mshadow::kFloat16) &&
-      inputs[0].shape().ndim() == 3) {
+  if (SupportMKLDNNRnn(inputs[0])) {
     MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
     op.Backward(ctx, inputs, req, outputs);
   } else {
diff --git a/tests/python/unittest/test_gluon_rnn.py 
b/tests/python/unittest/test_gluon_rnn.py
index 309756b..0f27f53 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -19,6 +19,8 @@ import mxnet as mx
 from mxnet import gluon, nd
 import numpy as np
 import copy
+from itertools import product
+from functools import partial
 from numpy.testing import assert_allclose
 import unittest
 from mxnet.test_utils import almost_equal, assert_almost_equal
@@ -545,6 +547,128 @@ def test_rnn_layers_fp16():
     run_rnn_layers('float16', 'float32', mx.gpu())
 
 
+def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, 
hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4):
+    fused_begin_state = fused_layer.begin_state(1)
+    stack_state = stack_layer.begin_state(batch_size=1)
+    x = nd.random.normal(shape=(1, 5, input_size))
+    x.attach_grad()
+    y = nd.random.normal(shape=(1, 5, hidden_size * 2 if bidirectional else 
hidden_size))
+
+    with mx.autograd.record():
+        fused_out, fused_state = fused_layer(x, fused_begin_state)
+        l = loss(fused_out, y).mean()
+    l.backward()
+    fused_grads = dict([(name, p.grad()) for name, p in 
fused_layer.collect_params().items()])
+    fused_input_grad = x.grad.asnumpy()
+
+    with mx.autograd.record():
+        stack_out, stack_state = stack_layer.unroll(5, x, stack_state, 
merge_outputs=True)
+        l = loss(stack_out, y).mean()
+    l.backward()
+    stack_grads = dict([(name, p.grad()) for name, p in 
stack_layer.collect_params().items()])
+    stack_input_grad = x.grad.asnumpy()
+
+    assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, 
atol=atol)
+    assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
+    for key, value in fused_grads.items():
+        assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), 
rtol=rtol, atol=atol)
+
+
+def create_op_by_mode(mode):
+    if mode == 'lstm':
+        fused_op = gluon.rnn.LSTM
+        stack_op = gluon.rnn.LSTMCell
+        recurrent_block_prefix = 'lstm0_'
+    elif mode == 'gru':
+        fused_op = gluon.rnn.GRU
+        stack_op = gluon.rnn.GRUCell
+        recurrent_block_prefix = 'gru0_'
+    elif mode == 'rnn_relu':
+        fused_op = partial(gluon.rnn.RNN, activation='relu')
+        stack_op = partial(gluon.rnn.RNNCell, activation='relu')
+        recurrent_block_prefix = 'rnn0_'
+    elif mode == 'rnn_tanh':
+        fused_op = partial(gluon.rnn.RNN, activation='tanh')
+        stack_op = partial(gluon.rnn.RNNCell, activation='tanh')
+        recurrent_block_prefix = 'rnn0_'
+
+    return fused_op, stack_op, recurrent_block_prefix
+
+
+def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
+    fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
+    # ==== Single layer ====
+    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', 
bidirectional=False)
+    fused_layer.collect_params().initialize()
+
+    params = fused_layer.collect_params()
+    stack_layer = 
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, 
params=params)
+    with stack_layer.name_scope():
+        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
+    stack_layer.initialize()
+
+    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, 
hidden_size)
+
+    # ==== Multiple layer ====
+    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', 
bidirectional=False)
+    fused_layer.collect_params().initialize()
+
+    params = fused_layer.collect_params()
+    stack_layer = 
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, 
params=params)
+    with stack_layer.name_scope():
+        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
+        stack_layer.add(stack_op(hidden_size, prefix='l1_'))
+        stack_layer.add(stack_op(hidden_size, prefix='l2_'))
+    stack_layer.initialize()
+
+    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, 
hidden_size)
+
+
+def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
+    fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
+    # ==== Single layer ====
+    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', 
bidirectional=True)
+    fused_layer.collect_params().initialize()
+
+    params = fused_layer.collect_params()
+    stack_layer = 
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, 
params=params)
+    with stack_layer.name_scope():
+        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, 
prefix='l0_'),
+                                                    stack_op(hidden_size, 
prefix='r0_')))
+    stack_layer.initialize()
+
+    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, 
hidden_size, bidirectional=True)
+
+    # ==== Multiple layer ====
+    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', 
bidirectional=True)
+    fused_layer.collect_params().initialize()
+
+    params = fused_layer.collect_params()
+    stack_layer = 
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix, 
params=params)
+    with stack_layer.name_scope():
+        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, 
prefix='l0_'),
+                                                    stack_op(hidden_size, 
prefix='r0_')))
+        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, 
prefix='l1_'),
+                                                    stack_op(hidden_size, 
prefix='r1_')))
+        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, 
prefix='l2_'),
+                                                    stack_op(hidden_size, 
prefix='r2_')))
+    stack_layer.initialize()
+
+    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, 
hidden_size, bidirectional=True)
+
+
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_fused_rnn_layer():
+    input_sizes = [128]
+    hidden_sizes = [128, 256]
+    modes = ['lstm', 'gru', 'rnn_relu', 'rnn_tanh']
+    # single layer
+    for mode, input_size, hidden_size in product(modes, input_sizes, 
hidden_sizes):
+        loss = mx.gluon.loss.L2Loss()
+        check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss)
+        check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss)
+
+
 def test_rnn_unroll_variant_length():
     # Test for imperative usage
     cell_list = []
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index 6fdb3c8..9ae35f1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -36,15 +36,6 @@ import unittest
 import os
 
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, 
atol=1e-4):
-    if default_context().device_type == 'cpu':
-    # NOTE(zixuanweeei): Currently, we don't add `add` requests support on 
fused mkl-dnn rnn operator.
-    # We tracked this issue by 
https://github.com/apache/incubator-mxnet/issues/16578
-        if isinstance(grad_req, dict) and 'add' in grad_req.values():
-            print("Skip the test when requiring `add` operation against 
gradients on CPU context.")
-            return
-        if isinstance(grad_req, str) and grad_req == 'add':
-            print("Skip the test when requiring `add` operation against 
gradients on CPU context.")
-            return
     dshape = (N, T, I)
     data = mx.sym.Variable('data')
 
@@ -182,9 +173,9 @@ def test_gru_sym():
         stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
         stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
 
-        check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4)
-        check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4)
-        check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4)
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@@ -208,9 +199,9 @@ def test_gru_bidirectional():
                     mx.rnn.GRUCell(H, prefix='r1_'),
                     output_prefix='bi_gru_1_'))
 
-        check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4)
-        check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4)
-        check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4)
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
