This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f94c9a  [MXNET-264] Improve performance of MKLDNN in small batch sizes. (#10317)
5f94c9a is described below

commit 5f94c9aa320703432b04c692931aae021604e2b9
Author: Da Zheng <zhengda1...@gmail.com>
AuthorDate: Tue Apr 10 15:33:34 2018 -0700

    [MXNET-264] Improve performance of MKLDNN in small batch sizes. (#10317)
    
    * Create MKLDNNMemory to cache metadata.
    
    * Fix lint error.
    
    * Cache concat.
    
    * Fix a bug in NDArray.
    
    * improve hashing.
    
    * don't use omp for gamma and beta in batchnorm.
    
    * address the comments.
    
    * Avoid computing out mean&var in batchnorm.
    
    * Cache LRN.
    
    * Fix a bug in LRN.
    
    * Fix lint error.
    
    * Revert "Avoid computing out mean&var in batchnorm."
    
    This reverts commit 71d0dec0a8bae90a9dc53e08ed1297eede34afa1.
    
    * remove more omp in batchnorm.
    
    * add comments for MKLDNNMemory.
    
    * Revert "improve hashing."
    
    This reverts commit 58854be7e82a65383f3a9989f358e26a85765a17.
    
    * Remove unnecessary TODO.
    
    * address comments.
    
    * Remove additional auto.
    
    * Fix compile error.
    
    * remove more auto.
---
 include/mxnet/ndarray.h                        |   3 +-
 src/ndarray/ndarray.cc                         | 178 +++++++++++--------------
 src/operator/nn/lrn-inl.h                      |  27 ++++
 src/operator/nn/mkldnn/mkldnn_base-inl.h       |  95 ++++++++++++-
 src/operator/nn/mkldnn/mkldnn_base.cc          |   2 +-
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h |  13 +-
 src/operator/nn/mkldnn/mkldnn_concat.cc        |  84 ++++++++++--
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h        | 154 ++++++++++++++++-----
 8 files changed, 399 insertions(+), 157 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index cd9004f..e172279 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -74,6 +74,7 @@ enum NDArrayFormatErr {
   kRSPIdxErr,     // indices error for row sparse
 };
 
+class MKLDNNMemory;
 
 /*!
  * \brief ndarray interface
@@ -730,7 +731,7 @@ class NDArray {
 #if MXNET_USE_MKLDNN == 1
     /*! This is created when data is stored in MKLDNN format.
      */
-    std::shared_ptr<mkldnn::memory> mkl_mem_;
+    std::shared_ptr<MKLDNNMemory> mkl_mem_;
 #endif
     /*! \brief variable from engine */
     Engine::VarHandle var;
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 8745093..d175a13 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -100,7 +100,7 @@ struct ChunkMem {
   Storage::Handle h;
   std::vector<Storage::Handle> aux_h;
 #if MXNET_USE_MKLDNN == 1
-  std::shared_ptr<mkldnn::memory> mem;
+  std::shared_ptr<MKLDNNMemory> mem;
 #endif
 };
 
@@ -117,8 +117,8 @@ NDArray::Chunk::~Chunk() {
     if (skip_free == false) {
 #if MXNET_USE_MKLDNN == 1
       if (mem.mem) {
-        CHECK_LE(mem.mem->get_primitive_desc().get_size(), mem.h.size);
-        CHECK_EQ(mem.mem->get_data_handle(), mem.h.dptr);
+        CHECK_LE(mem.mem->GetSize(), mem.h.size);
+        CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
       }
 #endif
       if (mem.h.size > 0) Storage::Get()->Free(mem.h);
@@ -181,19 +181,20 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
     NDArray ret(shape, ctx(), true, dtype());
     // We shouldn't submit the reorder primitive here because submit will
     // be called in operators.
-    auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
-    CHECK_NE(format, ptr_->mkl_mem_->get_primitive_desc().desc().data.format);
-    auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
-    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
+    CHECK_NE(format, ptr_->mkl_mem_->GetFormat());
+    mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     MKLDNNStream *stream = MKLDNNStream::Get();
-    stream->RegisterMem(ptr_->mkl_mem_);
-    stream->RegisterPrim(mkldnn::reorder(*ptr_->mkl_mem_, *def_mem));
+    std::shared_ptr<mkldnn::memory> curr_mem = ptr_->mkl_mem_->GetMem();
+    stream->RegisterMem(curr_mem);
+    stream->RegisterPrim(mkldnn::reorder(*curr_mem, *def_mem));
     // def_mem points to a memory region in the temp space. It's only valid
     // inside an operator. As such, the returned NDArray can only be valid
     // inside an operator and the shared point doesn't need to do anything
     // when it's destroyed.
-    ret.ptr_->mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem,
-                                                         [](mkldnn::memory *mem){});
+    auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem){});
+    ret.ptr_->mkl_mem_.reset(new MKLDNNMemory(tmp));
     ret.ptr_->shandle.dptr = def_mem->get_data_handle();
     ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size();
     ret.ptr_->delay_alloc = false;
@@ -323,27 +324,13 @@ void NDArray::set_fresh_out_grad(bool state) const {
 }
 
 #if MXNET_USE_MKLDNN == 1
-static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) {
-  if (shape.ndim() != (size_t)ndims)
-    return false;
-  for (int i = 0; i < ndims; i++)
-    if (shape[i] != dims[i])
-      return false;
-  return true;
-}
-
-static inline bool same_shape(const TShape &shape, int dtype, mkldnn::memory::desc desc) {
-  return same_shape(shape, desc.data.dims, desc.data.ndims)
-      && get_mkldnn_type(dtype) == desc.data.data_type;
-}
 
 bool NDArray::Chunk::IsMKLDNN() const {
   if (storage_type != kDefaultStorage)
     return false;
   if (mkl_mem_ == nullptr)
     return false;
-  auto desc = mkl_mem_->get_primitive_desc().desc();
-  return desc.data.format != GetDefaultFormat(desc);
+  return mkl_mem_->IsMKLDNN();
 }
 
 bool NDArray::Chunk::IsDefault() const {
@@ -353,23 +340,19 @@ bool NDArray::Chunk::IsDefault() const {
   // format.
   if (mkl_mem_ == nullptr)
     return true;
-  auto desc = mkl_mem_->get_primitive_desc().desc();
-  return desc.data.format == GetDefaultFormat(desc);
+  return !mkl_mem_->IsMKLDNN();
 }
 
 void NDArray::Chunk::Reorder2Default() {
   if (mkl_mem_ == nullptr)
     return;
 
-  auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc());
-  CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format);
+  mkldnn_memory_format_t format = mkl_mem_->GetDefaultFormat();
+  CHECK_NE(format, mkl_mem_->GetFormat());
 
-  auto def_pd = GetPrimitiveDesc(mkl_mem_->get_primitive_desc(), format);
+  mkldnn::memory::primitive_desc def_pd = mkl_mem_->GetPrimitiveDesc(format);
   mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
-  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*mkl_mem_, *def_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  mkl_mem_->ReorderTo(def_mem.get());
 
   CHECK(shandle.size >= def_pd.get_size());
   CheckAndAlloc(def_pd.get_size());
@@ -380,11 +363,11 @@ void NDArray::Chunk::Reorder2Default() {
 
 void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
   // If the memory already uses the specified layout, don't do anything.
-  if (mkl_mem_ != nullptr && mkl_mem_->get_primitive_desc() == pd)
+  if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(pd))
     return;
-  auto _pd = pd;
-  auto _desc = _pd.desc();
-  auto def_format = GetDefaultFormat(_desc);
+  mkldnn::memory::primitive_desc _pd = pd;
+  mkldnn::memory::desc _desc = _pd.desc();
+  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc);
   // If the memory is default, don't do anything.
   if (def_format == _desc.data.format && IsDefault())
     return;
@@ -397,10 +380,10 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd)
   std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
   std::shared_ptr<mkldnn::memory> old_mem;
   if (IsDefault()) {
-    auto def_pd = GetPrimitiveDesc(pd, def_format);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(pd, def_format);
     old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr));
   } else {
-    old_mem = this->mkl_mem_;
+    old_mem = this->mkl_mem_->GetMem();
   }
   CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims);
 
@@ -413,15 +396,15 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd)
   CheckAndAlloc(pd.get_size());
   // TODO(zhengda) We need to avoid memory copy here.
   memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size());
-  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
 }
 
 void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   // The shape of the array and the one of the MKL memory may mismatch.
   // For example, if the array stores parameters, the MKL memory may store data
   // in 5 dimensions while the NDArray stores data in 4 dimensions.
-  if (mkl_mem_ && mkl_mem_->get_data_handle() == shandle.dptr
-      && same_shape(shape, dtype, mkl_mem_->get_primitive_desc().desc())) {
+  if (mkl_mem_ && mkl_mem_->GetDataHandle() == shandle.dptr
+      && mkl_mem_->SameFormat(shape, dtype)) {
     return;
   }
 
@@ -459,7 +442,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   }
   mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
   CHECK(shandle.size >= pd.get_size());
-  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
 }
 
 /*
@@ -469,7 +452,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
  */
 static inline mkldnn::memory *GetMKLDNNExact(
     const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
-  auto src_desc = mem->get_primitive_desc();
+  mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc();
   if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) {
     return const_cast<mkldnn::memory *>(mem);
   } else {
@@ -486,16 +469,16 @@ const mkldnn::memory *NDArray::GetMKLDNNData(
     LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
-  auto mem = GetMKLDNNData();
+  const mkldnn::memory *mem = GetMKLDNNData();
   mkldnn::memory::primitive_desc _desc = desc;
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
+  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
+  mkldnn::memory::desc desc2 = _desc.desc();
   // The MKL memory has the same format and shape as required,
   // or both use the default format, we can return the MKL memory.
   if (mem->get_primitive_desc() == desc
       || (desc1.data.format == GetDefaultFormat(desc1)
         && desc2.data.format == GetDefaultFormat(desc2))) {
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   } else {
     return nullptr;
   }
@@ -509,7 +492,7 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   }
   CHECK(storage_type() == kDefaultStorage);
 
-  auto mem = GetMKLDNNData();
+  const mkldnn::memory *mem = GetMKLDNNData();
   // If the memory descriptor matches, it's easy.
   MKLDNNStream *stream = MKLDNNStream::Get();
   if (mem->get_primitive_desc() == desc) {
@@ -519,15 +502,15 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   mkldnn::memory::primitive_desc _desc = desc;
   // Now we need to determine if we should reorder the memory.
   // If both use the default formats, we think we don't need to reorder.
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
+  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
+  mkldnn::memory::desc desc2 = _desc.desc();
   if (desc1.data.format == GetDefaultFormat(desc1) &&
       desc2.data.format == GetDefaultFormat(desc2)) {
     mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle()));
     stream->RegisterMem(ret);
     return ret.get();
   } else {
-    auto ret = TmpMemMgr::Get()->Alloc(desc);
+    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(desc);
     stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
     return ret;
   }
@@ -538,18 +521,15 @@ NDArray NDArray::Reorder2Default() const {
 
   if (ptr_->mkl_mem_ == nullptr)
     return *this;
-  auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
-  if (format == ptr_->mkl_mem_->get_primitive_desc().desc().data.format)
+  mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
+  if (format == ptr_->mkl_mem_->GetFormat())
     return *this;
 
   NDArray ret(shape(), ctx(), false, dtype());
-  auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
+  mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
   CHECK(ret.ptr_->shandle.size >= def_pd.get_size());
   mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr);
-  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*ptr_->mkl_mem_, def_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  ptr_->mkl_mem_->ReorderTo(&def_mem);
   return ret;
 }
 
@@ -584,17 +564,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
   if (IsMKLDNNData())
     CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
-  // If shandle has data, the data in shandle and mkl_mem_ should match.
-  if (ptr_->shandle.dptr)
-    CHECK(ptr_->shandle.dptr == ptr_->mkl_mem_->get_data_handle());
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-  auto pd = ptr_->mkl_mem_->get_primitive_desc();
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
   if (IsView()) {
+    mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc();
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
-  }
-  if (IsView()) {
-    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->get_data_handle())
+    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle())
         + byte_offset_;
 
     // Create the primitive desc for the new mkldnn memory.
@@ -612,13 +587,13 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   } else {
-    return ptr_->mkl_mem_.get();
+    return ptr_->mkl_mem_->GetRaw();
   }
 }
 
 void NDArray::CopyFrom(const mkldnn::memory &mem) {
   CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
-  if (ptr_->mkl_mem_.get() == &mem)
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)
     return;
 
   CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
@@ -630,10 +605,10 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
                   dtype_);
-  stream->RegisterMem(ptr_->mkl_mem_);
-  auto from_desc = mem.get_primitive_desc().desc();
-  auto this_desc = ptr_->mkl_mem_->get_primitive_desc().desc();
-  auto from_def_format = GetDefaultFormat(from_desc);
+  stream->RegisterMem(ptr_->mkl_mem_->GetMem());
+  mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc();
+  mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
   if (IsView()) {
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
@@ -652,12 +627,13 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
   } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
     // In this case, the source memory stores data in a customized layout. We
     // need to reorganize the data in memory before we can reshape.
-    auto def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), from_def_format);
-    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(),
+                                                             from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
     // Now we can reshape it
     mkldnn::memory::dims dims(this_desc.data.dims,
@@ -668,32 +644,32 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
-  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->get_primitive_desc()) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
+  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) {
     // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
+    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
   } else {
-    auto src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
-    auto dst_def = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
+    mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
+    mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat();
     // If both are not using the default layouts. There isn't much we can do,
     // other than reorder data layout directly.
-    if (dst_def != ptr_->mkl_mem_->get_primitive_desc().desc().data.format
+    if (dst_def != ptr_->mkl_mem_->GetFormat()
         && src_def != mem.get_primitive_desc().desc().data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
-    } else if (dst_def == ptr_->mkl_mem_->get_primitive_desc().desc().data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
+    } else if (dst_def == ptr_->mkl_mem_->GetFormat()) {
       // If the dest mem uses the default memory layout, we can simply use
       // the default format of the source memory to improve perf of reorder.
-      auto pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), src_def);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->get_data_handle()));
+      mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle()));
       stream->RegisterMem(tmp_mem);
       stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
     } else {
       // If the src mem uses the default memory layout, we can use
       // the default format of the source memory to improve perf.
-      auto pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
       mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
     }
   }
 }
@@ -710,28 +686,28 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   }
 
   mkldnn::memory::primitive_desc _desc = desc;
-  auto required_format = _desc.desc().data.format;
-  auto def_format = GetDefaultFormat(_desc.desc());
+  mkldnn_memory_format_t required_format = _desc.desc().data.format;
+  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
   // If the required format is a default format, we don't need to worry about the shape.
   // If the shape isn't the same, it actually implicitly reshapes data.
   if (required_format == def_format) {
     ptr_->SetMKLMem(shape_, dtype_);
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   }
 
   if (ptr_->mkl_mem_)
-    CHECK(ptr_->mkl_mem_->get_data_handle() == ptr_->shandle.dptr);
-  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->get_primitive_desc() == desc) {
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    CHECK(ptr_->mkl_mem_->GetDataHandle() == ptr_->shandle.dptr);
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetPrimitiveDesc() == desc) {
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   }
 
   CHECK(ptr_->shandle.size >= desc.get_size());
   ptr_->CheckAndAlloc(desc.get_size());
-  ptr_->mkl_mem_.reset(new mkldnn::memory(desc, ptr_->shandle.dptr));
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-  return ptr_->mkl_mem_.get();
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(desc, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+  return ptr_->mkl_mem_->GetRaw();
 }
 #endif
 
diff --git a/src/operator/nn/lrn-inl.h b/src/operator/nn/lrn-inl.h
index fdae1ec..cb441de 100644
--- a/src/operator/nn/lrn-inl.h
+++ b/src/operator/nn/lrn-inl.h
@@ -58,8 +58,35 @@ struct LRNParam : public dmlc::Parameter<LRNParam> {
     DMLC_DECLARE_FIELD(nsize)
     .describe("normalization window width in elements.");
   }
+
+  bool operator==(const LRNParam& other) const {
+    return (this->nsize == other.nsize &&
+            fabs(this->alpha - other.alpha) < 1e-6 &&
+            fabs(this->beta  - other.beta)  < 1e-6 &&
+            fabs(this->knorm - other.knorm) < 1e-6);
+  }
 };  // struct LRNParam
 
+}  // namespace op
+}  // namespace mxnet
+
+namespace std {
+template<>
+struct hash<mxnet::op::LRNParam> {
+  size_t operator()(const mxnet::op::LRNParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.alpha);
+    ret = dmlc::HashCombine(ret, val.beta);
+    ret = dmlc::HashCombine(ret, val.knorm);
+    ret = dmlc::HashCombine(ret, val.nsize);
+    return ret;
+  }
+};
+}  // namespace std
+
+namespace mxnet {
+namespace op {
+
 template<typename xpu>
 void LRNForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                 const std::vector<TBlob> &in_data,
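
For context, the operator== and the std::hash<LRNParam> specialization added above let LRNParam serve as part of the lookup key for the per-thread cache of LRN forward primitives introduced in mkldnn_lrn-inl.h further below. A minimal standalone sketch of the same idiom, using a hypothetical Param struct rather than the real LRNParam:

    #include <functional>
    #include <unordered_map>

    // Hypothetical stand-in for LRNParam: field-wise equality plus a std::hash
    // specialization are what let the struct act as a hash-map key.
    struct Param {
      int nsize;
      float alpha, beta, knorm;
      bool operator==(const Param &o) const {
        return nsize == o.nsize && alpha == o.alpha &&
               beta == o.beta && knorm == o.knorm;
      }
    };

    namespace std {
    template<> struct hash<Param> {
      size_t operator()(const Param &p) const {
        size_t h = std::hash<int>()(p.nsize);      // combine the fields, in the
        h = h * 31 + std::hash<float>()(p.alpha);  // spirit of dmlc::HashCombine
        h = h * 31 + std::hash<float>()(p.beta);
        h = h * 31 + std::hash<float>()(p.knorm);
        return h;
      }
    };
    }  // namespace std

    // With both defined, Param can key a cache of prebuilt primitives, e.g.
    //   std::unordered_map<Param, CachedForward> cache;
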
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 61bef11..489351e 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -334,11 +334,104 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
                                  const mkldnn::memory::primitive_desc &target_pd,
                                  int num_groups);
 
-mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc);
+mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc);
 mkldnn_memory_format_t GetDefaultFormat(int num_dims);
 mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
                                                 mkldnn_memory_format_t format);
 
+inline bool same_shape(const TShape &shape, const mkldnn_dims_t dims, int ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+inline bool same_shape(const TShape &shape, int dtype,
+                       const mkldnn::memory::desc &desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+      && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+/*
+ * There is a large overhead of getting mkldnn::memory::primitive_desc from
+ * mkldnn::memory. This class is created to cache the metadata of mkldnn memory
+ * to provide a much more lightweight method to access them.
+ */
+class MKLDNNMemory {
+  std::shared_ptr<mkldnn::memory> mem;
+  mkldnn::memory::desc desc;
+  size_t size;      // The number of bytes.
+
+ public:
+  MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) {
+    mem.reset(new mkldnn::memory(pd, addr));
+    size = pd.get_size();
+  }
+
+  explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
+      mem->get_primitive_desc().desc()) {
+    this->mem = mem;
+    mkldnn::memory::primitive_desc pd = mem->get_primitive_desc();
+    size = pd.get_size();
+  }
+
+  void SetDataHandle(void *handle) {
+    mem->set_data_handle(handle);
+  }
+
+  void *GetDataHandle() const {
+    return mem->get_data_handle();
+  }
+
+  std::shared_ptr<mkldnn::memory> GetMem() const {
+    return mem;
+  }
+
+  mkldnn::memory *GetRaw() const {
+    return mem.get();
+  }
+
+  size_t GetSize() const {
+    return size;
+  }
+
+  mkldnn::memory::primitive_desc GetPrimitiveDesc() const {
+    return mem->get_primitive_desc();
+  }
+
+  mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn_memory_format_t format) const {
+    return mxnet::GetPrimitiveDesc(mem->get_primitive_desc(), format);
+  }
+
+  mkldnn_memory_format_t GetDefaultFormat() const {
+    return mxnet::GetDefaultFormat(desc);
+  }
+
+  mkldnn_memory_format_t GetFormat() const {
+    return desc.data.format;
+  }
+
+  bool IsMKLDNN() const {
+    return GetFormat() != GetDefaultFormat();
+  }
+
+  bool SameFormat(mkldnn::memory::primitive_desc pd) const {
+    return mem->get_primitive_desc() == pd;
+  }
+
+  bool SameFormat(const TShape &shape, int dtype) const {
+    return same_shape(shape, dtype, desc);
+  }
+
+  void ReorderTo(mkldnn::memory *other) const {
+    std::vector<mkldnn::primitive> net;
+    net.push_back(mkldnn::reorder(*mem, *other));
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  }
+};
+
 void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<NDArray> &inputs,
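
The MKLDNNMemory wrapper added above is what ndarray.cc now stores in place of a raw mkldnn::memory: the memory descriptor and byte size are captured once at construction, so accessors such as GetSize(), GetFormat() and IsMKLDNN() avoid the comparatively expensive get_primitive_desc() round trip on every call. A minimal usage sketch, assuming an MKL-DNN build of MXNet with this header on the include path (the free function below is illustrative, not part of the patch):

    #include <mkldnn.hpp>
    #include "mkldnn_base-inl.h"  // assumed include path for mxnet::MKLDNNMemory

    // Wrap an existing allocation once, then rely on the cached metadata.
    void NormalizeLayout(const mkldnn::memory::primitive_desc &pd, void *dptr) {
      mxnet::MKLDNNMemory mem(pd, dptr);             // caches desc and byte size
      size_t bytes = mem.GetSize();                  // cheap: no primitive_desc query
      mkldnn_memory_format_t fmt = mem.GetFormat();  // cheap: read from cached desc
      if (mem.IsMKLDNN()) {
        // Data is in a blocked MKL-DNN layout; reorder it into the default layout.
        mkldnn::memory def(mem.GetPrimitiveDesc(mem.GetDefaultFormat()));
        mem.ReorderTo(&def);
      }
      (void)bytes;
      (void)fmt;
    }
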
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 820cca1..684abd2 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -190,7 +190,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
   }
 }
 
-mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
+mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
   if (desc.data.ndims == 1) {
     return desc.data.format;
   } else if (desc.data.ndims == 2) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 16f9874..d1c80a6 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -234,20 +234,15 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
     DType* weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     if (!param.fix_gamma) {
-#pragma omp parallel for
-      for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = weight_ptr[i];
-        weight_buf[channels_ + i] = bias_ptr[i];  // bias
-      }
+      memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
+      memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = (DType)1.0f;
         weight_ptr[i] = (DType)1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = (DType)1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
@@ -260,7 +255,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       DType* inmean   = aux_states[batchnorm::kMovingMean].data().dptr<DType>();
       DType* invar    = aux_states[batchnorm::kMovingVar].data().dptr<DType>();
       // to align with origin implmentation: batch_norm.cc: L164
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
@@ -282,14 +276,13 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       MKLDNNStream::Get()->Submit();
       DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle());
       DType* var_mem_ptr  = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle());
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         omean[i] = mean_mem_ptr[i];
         ovar[i]  = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
       }
     }
   } else {  // no input gamma and beta
-      LOG(FATAL) << "MKLDNN batch normalization: should not reach here ...";
+    LOG(FATAL) << "MKLDNN batch normalization: should not reach here ...";
   }
 }
 
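The batch-norm change above removes #pragma omp parallel for from loops over channels_ (typically a small count), where the cost of forking an OpenMP thread team outweighs the work, and uses a plain memcpy where the copy is contiguous. A hypothetical micro-illustration of the same trade-off (names are illustrative, not MXNet code):

    #include <cstring>

    // For a small, contiguous per-channel buffer, a single memcpy (or a plain
    // serial loop) beats spinning up an OpenMP team for a few hundred elements.
    void CopyGammaBeta(float *weight_buf, const float *gamma,
                       const float *beta, int channels) {
      std::memcpy(weight_buf, gamma, sizeof(float) * channels);
      std::memcpy(weight_buf + channels, beta, sizeof(float) * channels);
    }
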
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index d3e6e77..240673d 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -30,6 +30,67 @@
 namespace mxnet {
 namespace op {
 
+class MKLDNNConcatFwd {
+  std::shared_ptr<mkldnn::concat> fwd;
+  std::vector<std::shared_ptr<mkldnn::memory>> data;
+  std::vector<mkldnn::primitive::at> data_mem;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  mkldnn::concat::primitive_desc fwd_pd;
+
+  MKLDNNConcatFwd(
+      int concat_dim,
+      const std::vector<mkldnn::memory::primitive_desc> &data_md): fwd_pd(concat_dim, data_md) {
+    data.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                 const mkldnn::memory &output) {
+    CHECK_EQ(in_data.size(), data.size());
+    for (size_t i = 0; i < data.size(); i++) {
+      if (this->data[i] == nullptr) {
+        this->data[i] = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+                in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+        this->data_mem.push_back(*this->data[i]);
+      } else {
+        this->data[i]->set_data_handle(in_data[i]->get_data_handle());
+      }
+    }
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (this->fwd == nullptr)
+      fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out));
+  }
+
+  const mkldnn::concat &GetFwd() const {
+    return *fwd;
+  }
+};
+
+static MKLDNNConcatFwd &GetConcatForward(
+    int concat_dim, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+  OpSignature key;
+  key.AddSign(concat_dim);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNConcatFwd fwd(concat_dim, data_md);
+    auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>(
+            key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const std::vector<NDArray> &in_data,
                          const std::vector<OpReqType> &req,
@@ -39,18 +100,21 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   int num_in_data = param.num_args;
   int concat_dim = param.dim;
   std::vector<mkldnn::memory::primitive_desc> data_md;
-  std::vector<mkldnn::primitive::at> data_mem;
+  std::vector<const mkldnn::memory *> data_mem;
+  data_md.reserve(num_in_data);
+  data_mem.reserve(num_in_data);
   for (int i =0; i < num_in_data; i++) {
-      auto tmp_mem = in_data[i].GetMKLDNNData();
-      auto tmp_pd = tmp_mem->get_primitive_desc();
-      data_md.push_back(tmp_pd);
-      data_mem.push_back(*tmp_mem);
+    const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
+    mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
+    data_md.push_back(tmp_pd);
+    data_mem.push_back(tmp_mem);
   }
-  mkldnn::concat::primitive_desc fwd_pd(concat_dim, data_md);
-  auto engine = CpuEngine::Get()->get_engine();
-  auto out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
-      fwd_pd.dst_primitive_desc(), req[concat_enum::kOut]);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::concat(fwd_pd, data_mem, *out_mem.second));
+  MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md);
+  mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
+                                                   fwd.fwd_pd.dst_primitive_desc(),
+                                                   req[concat_enum::kOut]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
   CommitOutput(out_data[concat_enum::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 9a9bf62..b0b715a 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 
 #if MXNET_USE_MKLDNN == 1
+#include <utility>
 #include <mkldnn.hpp>
 #include "../lrn-inl.h"
 #include "./mkldnn_base-inl.h"
@@ -39,11 +40,11 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   return algorithm::lrn_across_channels;
 }
 
-inline lrn_forward::primitive_desc GetLRNFwd(const LRNParam &param,
-                                             const bool is_train,
-                                             const memory::desc &src_md) {
-  const auto  engine = CpuEngine::Get()->get_engine();
-  const auto  alg = GetMKLDNNLRNAlgo(param);
+inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
+                                                 const bool is_train,
+                                                 const memory::desc &src_md) {
+  mkldnn::engine &engine = CpuEngine::Get()->get_engine();
+  const algorithm  alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int   nsize = param.nsize;
@@ -63,8 +64,8 @@ GetLRNBwd(const LRNParam &param,
           const mkldnn::memory::desc &diff_in_md,
           const mkldnn::memory::desc &diff_md,
           const lrn_forward::primitive_desc &lrnFwd_desc) {
-  const auto engine = CpuEngine::Get()->get_engine();
-  const auto alg = GetMKLDNNLRNAlgo(param);
+  mkldnn::engine &engine = CpuEngine::Get()->get_engine();
+  const algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
@@ -76,28 +77,113 @@ GetLRNBwd(const LRNParam &param,
                                engine, lrnFwd_desc);
 }
 
+
+typedef ParamOpSign<LRNParam> MKLDNNLRNSignature;
+
+// LRN Forward Class
+class MKLDNNLRNFwd {
+ public:
+  MKLDNNLRNFwd(const LRNParam& param,
+               bool  is_train,
+               const NDArray &in_data):
+               is_train(is_train) {
+    _Init(param, is_train, in_data);
+  }
+
+  ~MKLDNNLRNFwd() {}
+
+  void SetDataHandle(const NDArray &data,
+                     const NDArray &output);
+
+  void Execute();
+
+ private:
+  std::shared_ptr<mkldnn::lrn_forward> fwd;
+  std::shared_ptr<mkldnn::memory> in_mem;
+  std::shared_ptr<mkldnn::memory> out_mem;
+  std::shared_ptr<mkldnn::memory> ws_mem;
+  bool is_train;
+
+ private:
+  void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
+};  // End of LRN Forward Class
+
+void MKLDNNLRNFwd::_Init(const LRNParam &param,
+                         bool is_train,
+                         const NDArray &in_data) {
+  mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc();
+  lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+
+  this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
+                     ->get_primitive_desc()));
+  this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
+  if (is_train) {
+    // If it's training, we have to create a workspace memory. Otherwise, MKLDNN
+    // will have segmentation fault.
+    ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
+    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
+        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
+                                *this->ws_mem, *this->out_mem));
+  } else {
+    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
+        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
+                                *(this->out_mem)));
+  }
+}
+
+void MKLDNNLRNFwd::SetDataHandle(const NDArray &in_data,
+                                 const NDArray &out_data) {
+  const mkldnn::memory *in_data_mem   = in_data.GetMKLDNNData();
+  mkldnn::memory *out_data_mem  = const_cast<NDArray&>(out_data).CreateMKLDNNData(
+                       this->out_mem->get_primitive_desc());
+  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
+  this->out_mem->set_data_handle(out_data_mem->get_data_handle());
+}
+
+void MKLDNNLRNFwd::Execute() {
+  MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
+  MKLDNNStream::Get()->Submit();
+}
+// End of LRN Class and its functions
+
+static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
+                               const OpContext &ctx,
+                               const NDArray &in_data) {
+  static thread_local std::unordered_map<MKLDNNLRNSignature,
+                                         MKLDNNLRNFwd,
+                                         OpHash> lrn_fwds;
+  auto alg_ = algorithm::lrn_across_channels;
+  auto kind_ = prop_kind::forward_training;
+  if (ctx.is_train) {
+    kind_ = prop_kind::forward_training;
+  } else {
+    kind_ = prop_kind::forward_scoring;
+  }
+
+  MKLDNNLRNSignature key(param);
+  key.AddSign(alg_);
+  key.AddSign(kind_);
+  key.AddSign(in_data);
+
+  auto it = lrn_fwds.find(key);
+  if (it == lrn_fwds.end()) {
+    MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
+    auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>
+                                   (key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNLRNForward(const OpContext &ctx,
                       const LRNParam &param,
                       const NDArray &in_data,
                       const OpReqType req,
                       const NDArray &out_data) {
-  auto src_mem = in_data.GetMKLDNNData();
-  const auto src_md = src_mem->get_primitive_desc().desc();
-  const auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
-  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
-          pdesc.dst_primitive_desc());
-  if (ctx.is_train) {
-    std::shared_ptr<const mkldnn::memory> ws_mem(
-            new mkldnn::memory(pdesc.workspace_primitive_desc()));
-    MKLDNNStream::Get()->RegisterPrim(
-        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem),
-        *ws_mem, *dst_mem));
-    MKLDNNStream::Get()->Submit();
-  } else {
-    MKLDNNStream::Get()->RegisterPrim(
-        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem), *dst_mem));
-    MKLDNNStream::Get()->Submit();
-  }
+  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_data);
+  fwd.SetDataHandle(in_data, out_data);
+  fwd.Execute();
 }
 
 void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
@@ -109,9 +195,10 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
     return;
   }
   // Repeat FW for getting workspace
-  auto data_mem = in_data.GetMKLDNNData();
-  const auto data_md = data_mem->get_primitive_desc().desc();
-  const auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);
+  const mkldnn::memory *data_mem = in_data.GetMKLDNNData();
+  const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();
+  const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train,
+                                                              data_md);
 
   // TODO(Patric): To keep the function stateless, we can't pass workspace
   //               from LRN forward to backward. We have to re-compute
@@ -125,12 +212,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
           lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
           *ws_mem, *dst_temp));
 
-  const auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
-  auto diff_mem = out_grad.GetMKLDNNData();
-  const auto diff_md = diff_mem->get_primitive_desc().desc();
-  const auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);
-  auto diff_src_mem = CreateMKLDNNMem(in_grad,
-          pdesc_bwd.diff_src_primitive_desc(), req);
+  const mkldnn::memory::desc data_in_md = pdesc_fwd.src_primitive_desc().desc();
+  const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData();
+  const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
+  const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_in_md,
+                                                                   diff_md, pdesc_fwd);
+  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
+                                                 pdesc_bwd.diff_src_primitive_desc(), req);
 
   MKLDNNStream::Get()->RegisterPrim(
         lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem),

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
