piiswrong closed pull request #10317: [MXNET-264] Improve performance of MKLDNN in small batch sizes.
URL: https://github.com/apache/incubator-mxnet/pull/10317
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
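
For readers skimming the patch: most of the speedup at small batch sizes comes from not rebuilding MKL-DNN primitives and memory descriptors on every forward call. The diff below caches them in thread-local maps keyed by an operator signature (see GetConcatForward() and GetLRNFwd() further down) and only rebinds data handles per invocation. A minimal standalone sketch of that pattern follows; the Signature, CachedForward and GetForward names are illustrative stand-ins, not MXNet code.

    // Find-or-create cache: expensive construction happens once per distinct key,
    // later calls only rebind pointers and re-execute the cached object.
    #include <cstdio>
    #include <functional>
    #include <unordered_map>

    struct Signature {
      int op_id;
      int shape_hash;
      bool operator==(const Signature &o) const {
        return op_id == o.op_id && shape_hash == o.shape_hash;
      }
    };

    struct SignatureHash {
      size_t operator()(const Signature &s) const {
        return std::hash<int>()(s.op_id) * 31u + std::hash<int>()(s.shape_hash);
      }
    };

    struct CachedForward {
      explicit CachedForward(const Signature &s) {
        std::printf("building primitive for op %d\n", s.op_id);  // imagine costly setup here
      }
      void SetNewMem(const float *in, float *out) { in_ = in; out_ = out; }  // cheap per-call rebind
      void Execute() const { *out_ = *in_; }  // stand-in for submitting the cached primitive
      const float *in_ = nullptr;
      float *out_ = nullptr;
    };

    static CachedForward &GetForward(const Signature &key) {
      static thread_local std::unordered_map<Signature, CachedForward, SignatureHash> fwds;
      auto it = fwds.find(key);
      if (it == fwds.end())
        it = fwds.emplace(key, CachedForward(key)).first;  // built only on the first call
      return it->second;
    }

    int main() {
      float in = 1.0f, out = 0.0f;
      for (int i = 0; i < 3; ++i) {  // "building..." is printed once, not three times
        CachedForward &fwd = GetForward({7, 42});
        fwd.SetNewMem(&in, &out);
        fwd.Execute();
      }
      return 0;
    }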

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index b8b7f20fcf3..8b5ae58f63f 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -74,6 +74,7 @@ enum NDArrayFormatErr {
   kRSPIdxErr,     // indices error for row sparse
 };
 
+class MKLDNNMemory;
 
 /*!
  * \brief ndarray interface
@@ -693,7 +694,7 @@ class NDArray {
 #if MXNET_USE_MKLDNN == 1
     /*! This is created when data is stored in MKLDNN format.
      */
-    std::shared_ptr<mkldnn::memory> mkl_mem_;
+    std::shared_ptr<MKLDNNMemory> mkl_mem_;
 #endif
     /*! \brief variable from engine */
     Engine::VarHandle var;
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 52b96fad692..85e70840411 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -100,7 +100,7 @@ struct ChunkMem {
   Storage::Handle h;
   std::vector<Storage::Handle> aux_h;
 #if MXNET_USE_MKLDNN == 1
-  std::shared_ptr<mkldnn::memory> mem;
+  std::shared_ptr<MKLDNNMemory> mem;
 #endif
 };
 
@@ -117,8 +117,8 @@ NDArray::Chunk::~Chunk() {
     if (skip_free == false) {
 #if MXNET_USE_MKLDNN == 1
       if (mem.mem) {
-        CHECK_LE(mem.mem->get_primitive_desc().get_size(), mem.h.size);
-        CHECK_EQ(mem.mem->get_data_handle(), mem.h.dptr);
+        CHECK_LE(mem.mem->GetSize(), mem.h.size);
+        CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr);
       }
 #endif
       if (mem.h.size > 0) Storage::Get()->Free(mem.h);
@@ -181,19 +181,20 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
     NDArray ret(shape, ctx(), true, dtype());
     // We shouldn't submit the reorder primitive here because submit will
     // be called in operators.
-    auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
-    CHECK_NE(format, ptr_->mkl_mem_->get_primitive_desc().desc().data.format);
-    auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
-    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
+    CHECK_NE(format, ptr_->mkl_mem_->GetFormat());
+    mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     MKLDNNStream *stream = MKLDNNStream::Get();
-    stream->RegisterMem(ptr_->mkl_mem_);
-    stream->RegisterPrim(mkldnn::reorder(*ptr_->mkl_mem_, *def_mem));
+    std::shared_ptr<mkldnn::memory> curr_mem = ptr_->mkl_mem_->GetMem();
+    stream->RegisterMem(curr_mem);
+    stream->RegisterPrim(mkldnn::reorder(*curr_mem, *def_mem));
     // def_mem points to a memory region in the temp space. It's only valid
     // inside an operator. As such, the returned NDArray can only be valid
     // inside an operator and the shared point doesn't need to do anything
     // when it's destroyed.
-    ret.ptr_->mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem,
-                                                         [](mkldnn::memory *mem){});
+    auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem){});
+    ret.ptr_->mkl_mem_.reset(new MKLDNNMemory(tmp));
     ret.ptr_->shandle.dptr = def_mem->get_data_handle();
     ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size();
     ret.ptr_->delay_alloc = false;
@@ -323,27 +324,13 @@ void NDArray::set_fresh_out_grad(bool state) const {
 }
 
 #if MXNET_USE_MKLDNN == 1
-static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) {
-  if (shape.ndim() != (size_t)ndims)
-    return false;
-  for (int i = 0; i < ndims; i++)
-    if (shape[i] != dims[i])
-      return false;
-  return true;
-}
-
-static inline bool same_shape(const TShape &shape, int dtype, mkldnn::memory::desc desc) {
-  return same_shape(shape, desc.data.dims, desc.data.ndims)
-      && get_mkldnn_type(dtype) == desc.data.data_type;
-}
 
 bool NDArray::Chunk::IsMKLDNN() const {
   if (storage_type != kDefaultStorage)
     return false;
   if (mkl_mem_ == nullptr)
     return false;
-  auto desc = mkl_mem_->get_primitive_desc().desc();
-  return desc.data.format != GetDefaultFormat(desc);
+  return mkl_mem_->IsMKLDNN();
 }
 
 bool NDArray::Chunk::IsDefault() const {
@@ -353,23 +340,19 @@ bool NDArray::Chunk::IsDefault() const {
   // format.
   if (mkl_mem_ == nullptr)
     return true;
-  auto desc = mkl_mem_->get_primitive_desc().desc();
-  return desc.data.format == GetDefaultFormat(desc);
+  return !mkl_mem_->IsMKLDNN();
 }
 
 void NDArray::Chunk::Reorder2Default() {
   if (mkl_mem_ == nullptr)
     return;
 
-  auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc());
-  CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format);
+  mkldnn_memory_format_t format = mkl_mem_->GetDefaultFormat();
+  CHECK_NE(format, mkl_mem_->GetFormat());
 
-  auto def_pd = GetPrimitiveDesc(mkl_mem_->get_primitive_desc(), format);
+  mkldnn::memory::primitive_desc def_pd = mkl_mem_->GetPrimitiveDesc(format);
   mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
-  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*mkl_mem_, *def_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  mkl_mem_->ReorderTo(def_mem.get());
 
   CHECK(shandle.size >= def_pd.get_size());
   CheckAndAlloc(def_pd.get_size());
@@ -380,11 +363,11 @@ void NDArray::Chunk::Reorder2Default() {
 
 void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
   // If the memory already uses the specified layout, don't do anything.
-  if (mkl_mem_ != nullptr && mkl_mem_->get_primitive_desc() == pd)
+  if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(pd))
     return;
-  auto _pd = pd;
-  auto _desc = _pd.desc();
-  auto def_format = GetDefaultFormat(_desc);
+  mkldnn::memory::primitive_desc _pd = pd;
+  mkldnn::memory::desc _desc = _pd.desc();
+  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc);
   // If the memory is default, don't do anything.
   if (def_format == _desc.data.format && IsDefault())
     return;
@@ -397,10 +380,10 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd)
   std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
   std::shared_ptr<mkldnn::memory> old_mem;
   if (IsDefault()) {
-    auto def_pd = GetPrimitiveDesc(pd, def_format);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(pd, def_format);
     old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr));
   } else {
-    old_mem = this->mkl_mem_;
+    old_mem = this->mkl_mem_->GetMem();
   }
   CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims);
 
@@ -413,15 +396,15 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd)
   CheckAndAlloc(pd.get_size());
   // TODO(zhengda) We need to avoid memory copy here.
   memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size());
-  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
 }
 
 void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   // The shape of the array and the one of the MKL memory may mismatch.
   // For example, if the array stores parameters, the MKL memory may store data
   // in 5 dimensions while the NDArray stores data in 4 dimensions.
-  if (mkl_mem_ && mkl_mem_->get_data_handle() == shandle.dptr
-      && same_shape(shape, dtype, mkl_mem_->get_primitive_desc().desc())) {
+  if (mkl_mem_ && mkl_mem_->GetDataHandle() == shandle.dptr
+      && mkl_mem_->SameFormat(shape, dtype)) {
     return;
   }
 
@@ -459,7 +442,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   }
   mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
   CHECK(shandle.size >= pd.get_size());
-  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+  mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr));
 }
 
 /*
@@ -469,7 +452,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
  */
 static inline mkldnn::memory *GetMKLDNNExact(
     const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
-  auto src_desc = mem->get_primitive_desc();
+  mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc();
  if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) {
     return const_cast<mkldnn::memory *>(mem);
   } else {
@@ -486,16 +469,16 @@ const mkldnn::memory *NDArray::GetMKLDNNData(
    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
   }
-  auto mem = GetMKLDNNData();
+  const mkldnn::memory *mem = GetMKLDNNData();
   mkldnn::memory::primitive_desc _desc = desc;
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
+  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
+  mkldnn::memory::desc desc2 = _desc.desc();
   // The MKL memory has the same format and shape as required,
   // or both use the default format, we can return the MKL memory.
   if (mem->get_primitive_desc() == desc
       || (desc1.data.format == GetDefaultFormat(desc1)
         && desc2.data.format == GetDefaultFormat(desc2))) {
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   } else {
     return nullptr;
   }
@@ -509,7 +492,7 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   }
   CHECK(storage_type() == kDefaultStorage);
 
-  auto mem = GetMKLDNNData();
+  const mkldnn::memory *mem = GetMKLDNNData();
   // If the memory descriptor matches, it's easy.
   MKLDNNStream *stream = MKLDNNStream::Get();
   if (mem->get_primitive_desc() == desc) {
@@ -519,15 +502,15 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   mkldnn::memory::primitive_desc _desc = desc;
   // Now we need to determine if we should reorder the memory.
   // If both use the default formats, we think we don't need to reorder.
-  auto desc1 = mem->get_primitive_desc().desc();
-  auto desc2 = _desc.desc();
+  mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc();
+  mkldnn::memory::desc desc2 = _desc.desc();
   if (desc1.data.format == GetDefaultFormat(desc1) &&
       desc2.data.format == GetDefaultFormat(desc2)) {
     mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle()));
     stream->RegisterMem(ret);
     return ret.get();
   } else {
-    auto ret = TmpMemMgr::Get()->Alloc(desc);
+    mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(desc);
     stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
     return ret;
   }
@@ -538,18 +521,15 @@ NDArray NDArray::Reorder2Default() const {
 
   if (ptr_->mkl_mem_ == nullptr)
     return *this;
-  auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
-  if (format == ptr_->mkl_mem_->get_primitive_desc().desc().data.format)
+  mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat();
+  if (format == ptr_->mkl_mem_->GetFormat())
     return *this;
 
   NDArray ret(shape(), ctx(), false, dtype());
-  auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
+  mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format);
   CHECK(ret.ptr_->shandle.size >= def_pd.get_size());
   mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr);
-  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*ptr_->mkl_mem_, def_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  ptr_->mkl_mem_->ReorderTo(&def_mem);
   return ret;
 }
 
@@ -584,17 +564,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
   if (IsMKLDNNData())
     CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
-  // If shandle has data, the data in shandle and mkl_mem_ should match.
-  if (ptr_->shandle.dptr)
-    CHECK(ptr_->shandle.dptr == ptr_->mkl_mem_->get_data_handle());
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-  auto pd = ptr_->mkl_mem_->get_primitive_desc();
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
   if (IsView()) {
+    mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc();
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
-  }
-  if (IsView()) {
-    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->get_data_handle())
+    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle())
         + byte_offset_;
 
     // Create the primitive desc for the new mkldnn memory.
@@ -612,13 +587,13 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   } else {
-    return ptr_->mkl_mem_.get();
+    return ptr_->mkl_mem_->GetRaw();
   }
 }
 
 void NDArray::CopyFrom(const mkldnn::memory &mem) {
   CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
-  if (ptr_->mkl_mem_.get() == &mem)
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)
     return;
 
  CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
@@ -630,10 +605,10 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
                   dtype_);
-  stream->RegisterMem(ptr_->mkl_mem_);
-  auto from_desc = mem.get_primitive_desc().desc();
-  auto this_desc = ptr_->mkl_mem_->get_primitive_desc().desc();
-  auto from_def_format = GetDefaultFormat(from_desc);
+  stream->RegisterMem(ptr_->mkl_mem_->GetMem());
+  mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc();
+  mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
   if (IsView()) {
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
@@ -652,12 +627,13 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
   } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
     // In this case, the source memory stores data in a customized layout. We
     // need to reorganize the data in memory before we can reshape.
-    auto def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), from_def_format);
-    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(),
+                                                             from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
     // Now we can reshape it
     mkldnn::memory::dims dims(this_desc.data.dims,
@@ -668,32 +644,32 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
-  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->get_primitive_desc()) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
+  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) {
     // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
+    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
   } else {
-    auto src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
-    auto dst_def = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
+    mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
+    mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat();
     // If both are not using the default layouts. There isn't much we can do,
     // other than reorder data layout directly.
-    if (dst_def != ptr_->mkl_mem_->get_primitive_desc().desc().data.format
+    if (dst_def != ptr_->mkl_mem_->GetFormat()
         && src_def != mem.get_primitive_desc().desc().data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
-    } else if (dst_def == ptr_->mkl_mem_->get_primitive_desc().desc().data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
+    } else if (dst_def == ptr_->mkl_mem_->GetFormat()) {
       // If the dest mem uses the default memory layout, we can simply use
       // the default format of the source memory to improve perf of reorder.
-      auto pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), src_def);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->get_data_handle()));
+      mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle()));
       stream->RegisterMem(tmp_mem);
       stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
     } else {
       // If the src mem uses the default memory layout, we can use
       // the default format of the source memory to improve perf.
-      auto pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
       mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
     }
   }
 }
@@ -710,28 +686,28 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   }
 
   mkldnn::memory::primitive_desc _desc = desc;
-  auto required_format = _desc.desc().data.format;
-  auto def_format = GetDefaultFormat(_desc.desc());
+  mkldnn_memory_format_t required_format = _desc.desc().data.format;
+  mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
   // If the required format is a default format, we don't need to worry about the shape.
   // If the shape isn't the same, it actually implicitly reshapes data.
   if (required_format == def_format) {
     ptr_->SetMKLMem(shape_, dtype_);
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   }
 
   if (ptr_->mkl_mem_)
-    CHECK(ptr_->mkl_mem_->get_data_handle() == ptr_->shandle.dptr);
-  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->get_primitive_desc() == desc) {
-    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+    CHECK(ptr_->mkl_mem_->GetDataHandle() == ptr_->shandle.dptr);
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetPrimitiveDesc() == desc) {
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+    return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
   }
 
   CHECK(ptr_->shandle.size >= desc.get_size());
   ptr_->CheckAndAlloc(desc.get_size());
-  ptr_->mkl_mem_.reset(new mkldnn::memory(desc, ptr_->shandle.dptr));
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
-  return ptr_->mkl_mem_.get();
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(desc, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
+  return ptr_->mkl_mem_->GetRaw();
 }
 #endif
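
A side note on the MKLDNNDataReshape() hunk above: the temp-space memory returned by TmpMemMgr is wrapped in a shared_ptr whose deleter is a no-op, so the returned NDArray never frees memory it does not own. That non-owning shared_ptr idiom is plain C++; a minimal standalone sketch:

    #include <cstdio>
    #include <memory>

    int main() {
      int scratch[4] = {1, 2, 3, 4};            // stand-in for externally managed temp space
      {
        // The empty lambda is the deleter: destroying 'view' must not free 'scratch'.
        std::shared_ptr<int> view(scratch, [](int *) {});
        std::printf("%d\n", view.get()[2]);     // prints 3
      }                                         // 'view' destroyed; 'scratch' untouched
      std::printf("%d\n", scratch[0]);          // buffer still valid, prints 1
      return 0;
    }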
 
diff --git a/src/operator/nn/lrn-inl.h b/src/operator/nn/lrn-inl.h
index fdae1eca0ae..cb441de9927 100644
--- a/src/operator/nn/lrn-inl.h
+++ b/src/operator/nn/lrn-inl.h
@@ -58,8 +58,35 @@ struct LRNParam : public dmlc::Parameter<LRNParam> {
     DMLC_DECLARE_FIELD(nsize)
     .describe("normalization window width in elements.");
   }
+
+  bool operator==(const LRNParam& other) const {
+    return (this->nsize == other.nsize &&
+            fabs(this->alpha - other.alpha) < 1e-6 &&
+            fabs(this->beta  - other.beta)  < 1e-6 &&
+            fabs(this->knorm - other.knorm) < 1e-6);
+  }
 };  // struct LRNParam
 
+}  // namespace op
+}  // namespace mxnet
+
+namespace std {
+template<>
+struct hash<mxnet::op::LRNParam> {
+  size_t operator()(const mxnet::op::LRNParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.alpha);
+    ret = dmlc::HashCombine(ret, val.beta);
+    ret = dmlc::HashCombine(ret, val.knorm);
+    ret = dmlc::HashCombine(ret, val.nsize);
+    return ret;
+  }
+};
+}  // namespace std
+
+namespace mxnet {
+namespace op {
+
 template<typename xpu>
 void LRNForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                 const std::vector<TBlob> &in_data,
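
The operator== and std::hash<LRNParam> specialization added above exist so that LRNParam can take part (via ParamOpSign and OpHash) in the key of the thread-local forward cache introduced later in this patch. Below is a standalone sketch of the same mechanics with an illustrative Param struct; the hash-combine constant stands in for dmlc::HashCombine and none of the names are MXNet code.

    #include <cstdio>
    #include <functional>
    #include <unordered_map>

    struct Param {
      int nsize;
      float alpha, beta;
      bool operator==(const Param &o) const {
        return nsize == o.nsize && alpha == o.alpha && beta == o.beta;
      }
    };

    namespace std {
    template <> struct hash<Param> {
      size_t operator()(const Param &p) const {
        size_t h = std::hash<int>()(p.nsize);
        h ^= std::hash<float>()(p.alpha) + 0x9e3779b9 + (h << 6) + (h >> 2);
        h ^= std::hash<float>()(p.beta)  + 0x9e3779b9 + (h << 6) + (h >> 2);
        return h;
      }
    };
    }  // namespace std

    int main() {
      // With == and hash defined, the param struct can key an unordered_map cache.
      std::unordered_map<Param, int> cache;
      cache[{5, 1e-4f, 0.75f}] = 42;
      std::printf("%d\n", cache[{5, 1e-4f, 0.75f}]);  // prints 42
      return 0;
    }
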
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 61bef117a88..489351ebe2c 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -334,11 +334,104 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
                                 const mkldnn::memory::primitive_desc &target_pd,
                                  int num_groups);
 
-mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc);
+mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc);
 mkldnn_memory_format_t GetDefaultFormat(int num_dims);
 mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
                                                 mkldnn_memory_format_t format);
 
+inline bool same_shape(const TShape &shape, const mkldnn_dims_t dims, int ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+inline bool same_shape(const TShape &shape, int dtype,
+                       const mkldnn::memory::desc &desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+      && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+/*
+ * There is a large overhead of getting mkldnn::memory::primitive_desc from
+ * mkldnn::memory. This class is created to cache the metadata of mkldnn memory
+ * to provide a much more lightweight method to access them.
+ */
+class MKLDNNMemory {
+  std::shared_ptr<mkldnn::memory> mem;
+  mkldnn::memory::desc desc;
+  size_t size;      // The number of bytes.
+
+ public:
+  MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) {
+    mem.reset(new mkldnn::memory(pd, addr));
+    size = pd.get_size();
+  }
+
+  explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
+      mem->get_primitive_desc().desc()) {
+    this->mem = mem;
+    mkldnn::memory::primitive_desc pd = mem->get_primitive_desc();
+    size = pd.get_size();
+  }
+
+  void SetDataHandle(void *handle) {
+    mem->set_data_handle(handle);
+  }
+
+  void *GetDataHandle() const {
+    return mem->get_data_handle();
+  }
+
+  std::shared_ptr<mkldnn::memory> GetMem() const {
+    return mem;
+  }
+
+  mkldnn::memory *GetRaw() const {
+    return mem.get();
+  }
+
+  size_t GetSize() const {
+    return size;
+  }
+
+  mkldnn::memory::primitive_desc GetPrimitiveDesc() const {
+    return mem->get_primitive_desc();
+  }
+
+  mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn_memory_format_t format) const {
+    return mxnet::GetPrimitiveDesc(mem->get_primitive_desc(), format);
+  }
+
+  mkldnn_memory_format_t GetDefaultFormat() const {
+    return mxnet::GetDefaultFormat(desc);
+  }
+
+  mkldnn_memory_format_t GetFormat() const {
+    return desc.data.format;
+  }
+
+  bool IsMKLDNN() const {
+    return GetFormat() != GetDefaultFormat();
+  }
+
+  bool SameFormat(mkldnn::memory::primitive_desc pd) const {
+    return mem->get_primitive_desc() == pd;
+  }
+
+  bool SameFormat(const TShape &shape, int dtype) const {
+    return same_shape(shape, dtype, desc);
+  }
+
+  void ReorderTo(mkldnn::memory *other) const {
+    std::vector<mkldnn::primitive> net;
+    net.push_back(mkldnn::reorder(*mem, *other));
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  }
+};
+
 void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<NDArray> &inputs,
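
The MKLDNNMemory wrapper above is essentially memoisation: the memory::desc and byte size are captured once at construction, so hot-path queries such as GetFormat(), GetSize() and IsMKLDNN() become plain field reads instead of trips back through mkldnn::memory::get_primitive_desc(). A standalone sketch of that caching idea with a dummy type standing in for mkldnn::memory (Slow and CachedMeta are illustrative only):

    #include <cstdio>

    // Stand-in for mkldnn::memory: pretend querying its descriptor is costly.
    struct Slow {
      int format_;
      long size_;
      int QueryFormat() const { /* imagine an expensive library call */ return format_; }
      long QuerySize() const { return size_; }
    };

    // Stand-in for MKLDNNMemory: pay the query cost once, then answer from fields.
    class CachedMeta {
      const Slow *mem_;
      int format_;
      long size_;
     public:
      explicit CachedMeta(const Slow *m)
          : mem_(m), format_(m->QueryFormat()), size_(m->QuerySize()) {}
      int GetFormat() const { return format_; }    // field read, no library call
      long GetSize() const { return size_; }
      const Slow *GetRaw() const { return mem_; }  // raw handle still available when needed
    };

    int main() {
      Slow raw{5, 1024};
      CachedMeta meta(&raw);
      std::printf("format=%d size=%ld\n", meta.GetFormat(), meta.GetSize());
      return 0;
    }
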
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 820cca1402c..684abd24685 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -190,7 +190,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
   }
 }
 
-mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
+mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
   if (desc.data.ndims == 1) {
     return desc.data.format;
   } else if (desc.data.ndims == 2) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 16f9874bd5c..d1c80a63eee 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -234,20 +234,15 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
     DType* weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     if (!param.fix_gamma) {
-#pragma omp parallel for
-      for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = weight_ptr[i];
-        weight_buf[channels_ + i] = bias_ptr[i];  // bias
-      }
+      memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_);
+      memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = (DType)1.0f;
         weight_ptr[i] = (DType)1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         weight_buf[i] = (DType)1.0f;
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
@@ -260,7 +255,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
      DType* inmean   = aux_states[batchnorm::kMovingMean].data().dptr<DType>();
       DType* invar    = aux_states[batchnorm::kMovingVar].data().dptr<DType>();
       // to align with origin implmentation: batch_norm.cc: L164
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
@@ -282,14 +276,13 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
      MKLDNNStream::Get()->Submit();
      DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle());
      DType* var_mem_ptr  = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle());
-#pragma omp parallel for
       for (int i = 0; i < channels_; i++) {
         omean[i] = mean_mem_ptr[i];
         ovar[i]  = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
       }
     }
   } else {  // no input gamma and beta
-      LOG(FATAL) << "MKLDNN batch normalization: should not reach here ...";
+    LOG(FATAL) << "MKLDNN batch normalization: should not reach here ...";
   }
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index d3e6e775020..240673de4ab 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -30,6 +30,67 @@
 namespace mxnet {
 namespace op {
 
+class MKLDNNConcatFwd {
+  std::shared_ptr<mkldnn::concat> fwd;
+  std::vector<std::shared_ptr<mkldnn::memory>> data;
+  std::vector<mkldnn::primitive::at> data_mem;
+  std::shared_ptr<mkldnn::memory> out;
+
+ public:
+  mkldnn::concat::primitive_desc fwd_pd;
+
+  MKLDNNConcatFwd(
+      int concat_dim,
+      const std::vector<mkldnn::memory::primitive_desc> &data_md): fwd_pd(concat_dim, data_md) {
+    data.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                 const mkldnn::memory &output) {
+    CHECK_EQ(in_data.size(), data.size());
+    for (size_t i = 0; i < data.size(); i++) {
+      if (this->data[i] == nullptr) {
+        this->data[i] = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+                in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+        this->data_mem.push_back(*this->data[i]);
+      } else {
+        this->data[i]->set_data_handle(in_data[i]->get_data_handle());
+      }
+    }
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+              fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (this->fwd == nullptr)
+      fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out));
+  }
+
+  const mkldnn::concat &GetFwd() const {
+    return *fwd;
+  }
+};
+
+static MKLDNNConcatFwd &GetConcatForward(
+    int concat_dim, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+  OpSignature key;
+  key.AddSign(concat_dim);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNConcatFwd fwd(concat_dim, data_md);
+    auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>(
+            key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const std::vector<NDArray> &in_data,
                          const std::vector<OpReqType> &req,
@@ -39,18 +100,21 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   int num_in_data = param.num_args;
   int concat_dim = param.dim;
   std::vector<mkldnn::memory::primitive_desc> data_md;
-  std::vector<mkldnn::primitive::at> data_mem;
+  std::vector<const mkldnn::memory *> data_mem;
+  data_md.reserve(num_in_data);
+  data_mem.reserve(num_in_data);
   for (int i =0; i < num_in_data; i++) {
-      auto tmp_mem = in_data[i].GetMKLDNNData();
-      auto tmp_pd = tmp_mem->get_primitive_desc();
-      data_md.push_back(tmp_pd);
-      data_mem.push_back(*tmp_mem);
+    const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
+    mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
+    data_md.push_back(tmp_pd);
+    data_mem.push_back(tmp_mem);
   }
-  mkldnn::concat::primitive_desc fwd_pd(concat_dim, data_md);
-  auto engine = CpuEngine::Get()->get_engine();
-  auto out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
-      fwd_pd.dst_primitive_desc(), req[concat_enum::kOut]);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::concat(fwd_pd, data_mem, *out_mem.second));
+  MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md);
+  mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut],
+                                                   fwd.fwd_pd.dst_primitive_desc(),
+                                                   req[concat_enum::kOut]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
   CommitOutput(out_data[concat_enum::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
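
MKLDNNConcatFwd::SetNewMem() above shows the other half of the caching strategy: the mkldnn::memory objects and the concat primitive are built once, and later calls only repoint their data handles. A minimal sketch of that handle-rebinding idea, using the same MKL-DNN 0.x calls that appear elsewhere in this diff (a reorder serves purely as a stand-in primitive; error handling omitted):

    #include <mkldnn.hpp>
    #include <cstdio>
    #include <vector>

    int main() {
      using namespace mkldnn;
      engine cpu_engine(engine::cpu, 0);

      memory::dims shape = {1, 8, 4, 4};
      memory::desc md(shape, memory::data_type::f32, memory::format::nchw);
      memory::primitive_desc pd(md, cpu_engine);

      std::vector<float> a(pd.get_size() / sizeof(float), 1.0f);
      std::vector<float> b(a.size(), 0.0f);

      // Build the memory objects and the primitive once...
      memory src(pd, a.data());
      memory dst(pd, b.data());
      reorder copy(src, dst);

      // ...and reuse them across calls by only swapping data handles, which is
      // what SetNewMem()/SetDataHandle() do for the cached operators in this patch.
      for (int iter = 0; iter < 2; ++iter) {
        src.set_data_handle(a.data());
        dst.set_data_handle(b.data());
        std::vector<primitive> net;
        net.push_back(copy);
        stream(stream::kind::eager).submit(net).wait();
      }
      std::printf("%.1f\n", b[0]);  // prints 1.0
      return 0;
    }
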
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 9a9bf62b67d..b0b715a9da0 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_
 
 #if MXNET_USE_MKLDNN == 1
+#include <utility>
 #include <mkldnn.hpp>
 #include "../lrn-inl.h"
 #include "./mkldnn_base-inl.h"
@@ -39,11 +40,11 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam &param) {
   return algorithm::lrn_across_channels;
 }
 
-inline lrn_forward::primitive_desc GetLRNFwd(const LRNParam &param,
-                                             const bool is_train,
-                                             const memory::desc &src_md) {
-  const auto  engine = CpuEngine::Get()->get_engine();
-  const auto  alg = GetMKLDNNLRNAlgo(param);
+inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam &param,
+                                                 const bool is_train,
+                                                 const memory::desc &src_md) {
+  mkldnn::engine &engine = CpuEngine::Get()->get_engine();
+  const algorithm  alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int   nsize = param.nsize;
@@ -63,8 +64,8 @@ GetLRNBwd(const LRNParam &param,
           const mkldnn::memory::desc &diff_in_md,
           const mkldnn::memory::desc &diff_md,
           const lrn_forward::primitive_desc &lrnFwd_desc) {
-  const auto engine = CpuEngine::Get()->get_engine();
-  const auto alg = GetMKLDNNLRNAlgo(param);
+  mkldnn::engine &engine = CpuEngine::Get()->get_engine();
+  const algorithm alg = GetMKLDNNLRNAlgo(param);
   const float alpha = param.alpha;
   const float beta = param.beta;
   const int nsize = param.nsize;
@@ -76,28 +77,113 @@ GetLRNBwd(const LRNParam &param,
                                engine, lrnFwd_desc);
 }
 
+
+typedef ParamOpSign<LRNParam> MKLDNNLRNSignature;
+
+// LRN Forward Class
+class MKLDNNLRNFwd {
+ public:
+  MKLDNNLRNFwd(const LRNParam& param,
+               bool  is_train,
+               const NDArray &in_data):
+               is_train(is_train) {
+    _Init(param, is_train, in_data);
+  }
+
+  ~MKLDNNLRNFwd() {}
+
+  void SetDataHandle(const NDArray &data,
+                     const NDArray &output);
+
+  void Execute();
+
+ private:
+  std::shared_ptr<mkldnn::lrn_forward> fwd;
+  std::shared_ptr<mkldnn::memory> in_mem;
+  std::shared_ptr<mkldnn::memory> out_mem;
+  std::shared_ptr<mkldnn::memory> ws_mem;
+  bool is_train;
+
+ private:
+  void _Init(const LRNParam &param, bool is_train, const NDArray &in_data);
+};  // End of LRN Forword Class
+
+void MKLDNNLRNFwd::_Init(const LRNParam &param,
+                         bool is_train,
+                         const NDArray &in_data) {
+  mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc();
+  lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md);
+
+  this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData()
+                     ->get_primitive_desc()));
+  this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc()));
+  if (is_train) {
+    // If it's training, we have to create a workspace memory. Otherwise, MKLDNN
+    // will have segmentation fault.
+    ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc()));
+    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
+        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem),
+                                *this->ws_mem, *this->out_mem));
+  } else {
+    this->fwd = std::shared_ptr<mkldnn::lrn_forward>(
+        new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)),
+                                *(this->out_mem)));
+  }
+}
+
+void MKLDNNLRNFwd::SetDataHandle(const NDArray &in_data,
+                                 const NDArray &out_data) {
+  const mkldnn::memory *in_data_mem   = in_data.GetMKLDNNData();
+  mkldnn::memory *out_data_mem  = const_cast<NDArray&>(out_data).CreateMKLDNNData(
+                       this->out_mem->get_primitive_desc());
+  this->in_mem->set_data_handle(in_data_mem->get_data_handle());
+  this->out_mem->set_data_handle(out_data_mem->get_data_handle());
+}
+
+void MKLDNNLRNFwd::Execute() {
+  MKLDNNStream::Get()->RegisterPrim(*(this->fwd));
+  MKLDNNStream::Get()->Submit();
+}
+// End of LRN Class and its functions
+
+static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
+                               const OpContext &ctx,
+                               const NDArray &in_data) {
+  static thread_local std::unordered_map<MKLDNNLRNSignature,
+                                         MKLDNNLRNFwd,
+                                         OpHash> lrn_fwds;
+  auto alg_ = algorithm::lrn_across_channels;
+  auto kind_ = prop_kind::forward_training;
+  if (ctx.is_train) {
+    kind_ = prop_kind::forward_training;
+  } else {
+    kind_ = prop_kind::forward_scoring;
+  }
+
+  MKLDNNLRNSignature key(param);
+  key.AddSign(alg_);
+  key.AddSign(kind_);
+  key.AddSign(in_data);
+
+  auto it = lrn_fwds.find(key);
+  if (it == lrn_fwds.end()) {
+    MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
+    auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>
+                                   (key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNLRNForward(const OpContext &ctx,
                       const LRNParam &param,
                       const NDArray &in_data,
                       const OpReqType req,
                       const NDArray &out_data) {
-  auto src_mem = in_data.GetMKLDNNData();
-  const auto src_md = src_mem->get_primitive_desc().desc();
-  const auto pdesc = GetLRNFwd(param, ctx.is_train, src_md);
-  auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
-          pdesc.dst_primitive_desc());
-  if (ctx.is_train) {
-    std::shared_ptr<const mkldnn::memory> ws_mem(
-            new mkldnn::memory(pdesc.workspace_primitive_desc()));
-    MKLDNNStream::Get()->RegisterPrim(
-        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem),
-        *ws_mem, *dst_mem));
-    MKLDNNStream::Get()->Submit();
-  } else {
-    MKLDNNStream::Get()->RegisterPrim(
-        lrn_forward(pdesc, mkldnn::primitive::at(*src_mem), *dst_mem));
-    MKLDNNStream::Get()->Submit();
-  }
+  MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_data);
+  fwd.SetDataHandle(in_data, out_data);
+  fwd.Execute();
 }
 
 void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
@@ -109,9 +195,10 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
     return;
   }
   // Repeat FW for getting workspace
-  auto data_mem = in_data.GetMKLDNNData();
-  const auto data_md = data_mem->get_primitive_desc().desc();
-  const auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md);
+  const mkldnn::memory *data_mem = in_data.GetMKLDNNData();
+  const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc();
+  const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train,
+                                                              data_md);
 
   // TODO(Patric): To keep the function stateless, we can't pass workspace
   //               from LRN forward to backward. We have to re-compute
@@ -125,12 +212,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param,
           lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem),
           *ws_mem, *dst_temp));
 
-  const auto data_in_md = pdesc_fwd.src_primitive_desc().desc();
-  auto diff_mem = out_grad.GetMKLDNNData();
-  const auto diff_md = diff_mem->get_primitive_desc().desc();
-  const auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd);
-  auto diff_src_mem = CreateMKLDNNMem(in_grad,
-          pdesc_bwd.diff_src_primitive_desc(), req);
+  const mkldnn::memory::desc data_in_md = pdesc_fwd.src_primitive_desc().desc();
+  const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData();
+  const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc();
+  const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_in_md,
+                                                                   diff_md, pdesc_fwd);
+  mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad,
+                                                 pdesc_bwd.diff_src_primitive_desc(), req);
 
   MKLDNNStream::Get()->RegisterPrim(
         lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem),


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
