This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 5f94c9a [MXNET-264] Improve performance of MKLDNN in small batch sizes. (#10317) 5f94c9a is described below commit 5f94c9aa320703432b04c692931aae021604e2b9 Author: Da Zheng <zhengda1...@gmail.com> AuthorDate: Tue Apr 10 15:33:34 2018 -0700 [MXNET-264] Improve performance of MKLDNN in small batch sizes. (#10317) * Create MKLDNNMemory to cache metadata. * Fix lint error. * Cache concat. * Fix a bug in NDArray. * improve hashing. * don't use omp for gamma and beta in batchnorm. * address the comments. * Avoid computing out mean&var in batchnorm. * Cache LRN. * Fix a bug in LRN. * Fix lint error. * Revert "Avoid computing out mean&var in batchnorm." This reverts commit 71d0dec0a8bae90a9dc53e08ed1297eede34afa1. * remove more omp in batchnorm. * add comments for MKLDNNMemory. * Revert "improve hashing." This reverts commit 58854be7e82a65383f3a9989f358e26a85765a17. * Remove unnecessary TODO. * address comments. * Remove additional auto. * Fix compile error. * remove more auto. --- include/mxnet/ndarray.h | 3 +- src/ndarray/ndarray.cc | 178 +++++++++++-------------- src/operator/nn/lrn-inl.h | 27 ++++ src/operator/nn/mkldnn/mkldnn_base-inl.h | 95 ++++++++++++- src/operator/nn/mkldnn/mkldnn_base.cc | 2 +- src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 13 +- src/operator/nn/mkldnn/mkldnn_concat.cc | 84 ++++++++++-- src/operator/nn/mkldnn/mkldnn_lrn-inl.h | 154 ++++++++++++++++----- 8 files changed, 399 insertions(+), 157 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index cd9004f..e172279 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -74,6 +74,7 @@ enum NDArrayFormatErr { kRSPIdxErr, // indices error for row sparse }; +class MKLDNNMemory; /*! * \brief ndarray interface @@ -730,7 +731,7 @@ class NDArray { #if MXNET_USE_MKLDNN == 1 /*! This is created when data is stored in MKLDNN format. 
*/ - std::shared_ptr<mkldnn::memory> mkl_mem_; + std::shared_ptr<MKLDNNMemory> mkl_mem_; #endif /*! \brief variable from engine */ Engine::VarHandle var; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 8745093..d175a13 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -100,7 +100,7 @@ struct ChunkMem { Storage::Handle h; std::vector<Storage::Handle> aux_h; #if MXNET_USE_MKLDNN == 1 - std::shared_ptr<mkldnn::memory> mem; + std::shared_ptr<MKLDNNMemory> mem; #endif }; @@ -117,8 +117,8 @@ NDArray::Chunk::~Chunk() { if (skip_free == false) { #if MXNET_USE_MKLDNN == 1 if (mem.mem) { - CHECK_LE(mem.mem->get_primitive_desc().get_size(), mem.h.size); - CHECK_EQ(mem.mem->get_data_handle(), mem.h.dptr); + CHECK_LE(mem.mem->GetSize(), mem.h.size); + CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr); } #endif if (mem.h.size > 0) Storage::Get()->Free(mem.h); @@ -181,19 +181,20 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const { NDArray ret(shape, ctx(), true, dtype()); // We shouldn't submit the reorder primitive here because submit will // be called in operators. 
- auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc()); - CHECK_NE(format, ptr_->mkl_mem_->get_primitive_desc().desc().data.format); - auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format); - auto def_mem = TmpMemMgr::Get()->Alloc(def_pd); + mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat(); + CHECK_NE(format, ptr_->mkl_mem_->GetFormat()); + mkldnn::memory::primitive_desc def_pd = ptr_->mkl_mem_->GetPrimitiveDesc(format); + mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd); MKLDNNStream *stream = MKLDNNStream::Get(); - stream->RegisterMem(ptr_->mkl_mem_); - stream->RegisterPrim(mkldnn::reorder(*ptr_->mkl_mem_, *def_mem)); + std::shared_ptr<mkldnn::memory> curr_mem = ptr_->mkl_mem_->GetMem(); + stream->RegisterMem(curr_mem); + stream->RegisterPrim(mkldnn::reorder(*curr_mem, *def_mem)); // def_mem points to a memory region in the temp space. It's only valid // inside an operator. As such, the returned NDArray can only be valid // inside an operator and the shared point doesn't need to do anything // when it's destroyed. 
- ret.ptr_->mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem, - [](mkldnn::memory *mem){}); + auto tmp = std::shared_ptr<mkldnn::memory>(def_mem, [](mkldnn::memory *mem){}); + ret.ptr_->mkl_mem_.reset(new MKLDNNMemory(tmp)); ret.ptr_->shandle.dptr = def_mem->get_data_handle(); ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size(); ret.ptr_->delay_alloc = false; @@ -323,27 +324,13 @@ void NDArray::set_fresh_out_grad(bool state) const { } #if MXNET_USE_MKLDNN == 1 -static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) { - if (shape.ndim() != (size_t)ndims) - return false; - for (int i = 0; i < ndims; i++) - if (shape[i] != dims[i]) - return false; - return true; -} - -static inline bool same_shape(const TShape &shape, int dtype, mkldnn::memory::desc desc) { - return same_shape(shape, desc.data.dims, desc.data.ndims) - && get_mkldnn_type(dtype) == desc.data.data_type; -} bool NDArray::Chunk::IsMKLDNN() const { if (storage_type != kDefaultStorage) return false; if (mkl_mem_ == nullptr) return false; - auto desc = mkl_mem_->get_primitive_desc().desc(); - return desc.data.format != GetDefaultFormat(desc); + return mkl_mem_->IsMKLDNN(); } bool NDArray::Chunk::IsDefault() const { @@ -353,23 +340,19 @@ bool NDArray::Chunk::IsDefault() const { // format. 
if (mkl_mem_ == nullptr) return true; - auto desc = mkl_mem_->get_primitive_desc().desc(); - return desc.data.format == GetDefaultFormat(desc); + return !mkl_mem_->IsMKLDNN(); } void NDArray::Chunk::Reorder2Default() { if (mkl_mem_ == nullptr) return; - auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc()); - CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format); + mkldnn_memory_format_t format = mkl_mem_->GetDefaultFormat(); + CHECK_NE(format, mkl_mem_->GetFormat()); - auto def_pd = GetPrimitiveDesc(mkl_mem_->get_primitive_desc(), format); + mkldnn::memory::primitive_desc def_pd = mkl_mem_->GetPrimitiveDesc(format); mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd)); - // This may be called in MKLDNN operators. We can't use MKLDNNStream here. - std::vector<mkldnn::primitive> net; - net.push_back(mkldnn::reorder(*mkl_mem_, *def_mem)); - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + mkl_mem_->ReorderTo(def_mem.get()); CHECK(shandle.size >= def_pd.get_size()); CheckAndAlloc(def_pd.get_size()); @@ -380,11 +363,11 @@ void NDArray::Chunk::Reorder2Default() { void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) { // If the memory already uses the specified layout, don't do anything. - if (mkl_mem_ != nullptr && mkl_mem_->get_primitive_desc() == pd) + if (mkl_mem_ != nullptr && mkl_mem_->SameFormat(pd)) return; - auto _pd = pd; - auto _desc = _pd.desc(); - auto def_format = GetDefaultFormat(_desc); + mkldnn::memory::primitive_desc _pd = pd; + mkldnn::memory::desc _desc = _pd.desc(); + mkldnn_memory_format_t def_format = GetDefaultFormat(_desc); // If the memory is default, don't do anything. 
if (def_format == _desc.data.format && IsDefault()) return; @@ -397,10 +380,10 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd)); std::shared_ptr<mkldnn::memory> old_mem; if (IsDefault()) { - auto def_pd = GetPrimitiveDesc(pd, def_format); + mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(pd, def_format); old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr)); } else { - old_mem = this->mkl_mem_; + old_mem = this->mkl_mem_->GetMem(); } CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims); @@ -413,15 +396,15 @@ void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) CheckAndAlloc(pd.get_size()); // TODO(zhengda) We need to avoid memory copy here. memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size()); - mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr)); + mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); } void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { // The shape of the array and the one of the MKL memory may mismatch. // For example, if the array stores parameters, the MKL memory may store data // in 5 dimensions while the NDArray stores data in 4 dimensions. 
- if (mkl_mem_ && mkl_mem_->get_data_handle() == shandle.dptr - && same_shape(shape, dtype, mkl_mem_->get_primitive_desc().desc())) { + if (mkl_mem_ && mkl_mem_->GetDataHandle() == shandle.dptr + && mkl_mem_->SameFormat(shape, dtype)) { return; } @@ -459,7 +442,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { } mkldnn::memory::primitive_desc pd(data_md, cpu_engine); CHECK(shandle.size >= pd.get_size()); - mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr)); + mkl_mem_.reset(new MKLDNNMemory(pd, shandle.dptr)); } /* @@ -469,7 +452,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) { */ static inline mkldnn::memory *GetMKLDNNExact( const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) { - auto src_desc = mem->get_primitive_desc(); + mkldnn::memory::primitive_desc src_desc = mem->get_primitive_desc(); if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) { return const_cast<mkldnn::memory *>(mem); } else { @@ -486,16 +469,16 @@ const mkldnn::memory *NDArray::GetMKLDNNData( LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc"; return nullptr; } - auto mem = GetMKLDNNData(); + const mkldnn::memory *mem = GetMKLDNNData(); mkldnn::memory::primitive_desc _desc = desc; - auto desc1 = mem->get_primitive_desc().desc(); - auto desc2 = _desc.desc(); + mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc(); + mkldnn::memory::desc desc2 = _desc.desc(); // The MKL memory has the same format and shape as required, // or both use the default format, we can return the MKL memory. 
if (mem->get_primitive_desc() == desc || (desc1.data.format == GetDefaultFormat(desc1) && desc2.data.format == GetDefaultFormat(desc2))) { - return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc); + return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); } else { return nullptr; } @@ -509,7 +492,7 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder( } CHECK(storage_type() == kDefaultStorage); - auto mem = GetMKLDNNData(); + const mkldnn::memory *mem = GetMKLDNNData(); // If the memory descriptor matches, it's easy. MKLDNNStream *stream = MKLDNNStream::Get(); if (mem->get_primitive_desc() == desc) { @@ -519,15 +502,15 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder( mkldnn::memory::primitive_desc _desc = desc; // Now we need to determine if we should reorder the memory. // If both use the default formats, we think we don't need to reorder. - auto desc1 = mem->get_primitive_desc().desc(); - auto desc2 = _desc.desc(); + mkldnn::memory::desc desc1 = mem->get_primitive_desc().desc(); + mkldnn::memory::desc desc2 = _desc.desc(); if (desc1.data.format == GetDefaultFormat(desc1) && desc2.data.format == GetDefaultFormat(desc2)) { mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle())); stream->RegisterMem(ret); return ret.get(); } else { - auto ret = TmpMemMgr::Get()->Alloc(desc); + mkldnn::memory *ret = TmpMemMgr::Get()->Alloc(desc); stream->RegisterPrim(mkldnn::reorder(*mem, *ret)); return ret; } @@ -538,18 +521,15 @@ NDArray NDArray::Reorder2Default() const { if (ptr_->mkl_mem_ == nullptr) return *this; - auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc()); - if (format == ptr_->mkl_mem_->get_primitive_desc().desc().data.format) + mkldnn_memory_format_t format = ptr_->mkl_mem_->GetDefaultFormat(); + if (format == ptr_->mkl_mem_->GetFormat()) return *this; NDArray ret(shape(), ctx(), false, dtype()); - auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format); + mkldnn::memory::primitive_desc def_pd = 
ptr_->mkl_mem_->GetPrimitiveDesc(format); CHECK(ret.ptr_->shandle.size >= def_pd.get_size()); mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr); - // This may be called in MKLDNN operators. We can't use MKLDNNStream here. - std::vector<mkldnn::primitive> net; - net.push_back(mkldnn::reorder(*ptr_->mkl_mem_, def_mem)); - mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + ptr_->mkl_mem_->ReorderTo(&def_mem); return ret; } @@ -584,17 +564,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { if (IsMKLDNNData()) CHECK(!IsView()); ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_); - // If shandle has data, the data in shandle and mkl_mem_ should match. - if (ptr_->shandle.dptr) - CHECK(ptr_->shandle.dptr == ptr_->mkl_mem_->get_data_handle()); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_); - auto pd = ptr_->mkl_mem_->get_primitive_desc(); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); if (IsView()) { + mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(); // Sliced array must use the default layout. CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format); - } - if (IsView()) { - void *off_addr = static_cast<char *>(ptr_->mkl_mem_->get_data_handle()) + void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle()) + byte_offset_; // Create the primitive desc for the new mkldnn memory. 
@@ -612,13 +587,13 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { MKLDNNStream::Get()->RegisterMem(ret); return ret.get(); } else { - return ptr_->mkl_mem_.get(); + return ptr_->mkl_mem_->GetRaw(); } } void NDArray::CopyFrom(const mkldnn::memory &mem) { CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized"; - if (ptr_->mkl_mem_.get() == &mem) + if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem) return; CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_)) @@ -630,10 +605,10 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { CHECK(!IsView()); ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_); - stream->RegisterMem(ptr_->mkl_mem_); - auto from_desc = mem.get_primitive_desc().desc(); - auto this_desc = ptr_->mkl_mem_->get_primitive_desc().desc(); - auto from_def_format = GetDefaultFormat(from_desc); + stream->RegisterMem(ptr_->mkl_mem_->GetMem()); + mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc(); + mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc(); + mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc); if (IsView()) { // Sliced array must use the default layout. CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format); @@ -652,12 +627,13 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine()); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_)); + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) { // In this case, the source memory stores data in a customized layout. We // need to reorganize the data in memory before we can reshape. 
- auto def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), from_def_format); - auto def_mem = TmpMemMgr::Get()->Alloc(def_pd); + mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), + from_def_format); + mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd); stream->RegisterPrim(mkldnn::reorder(mem, *def_mem)); // Now we can reshape it mkldnn::memory::dims dims(this_desc.data.dims, @@ -668,32 +644,32 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) { mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine()); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_)); - } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->get_primitive_desc()) { + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); + } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) { // If the layout is the same, we can just copy data. - stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_)); + stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw())); } else { - auto src_def = GetDefaultFormat(mem.get_primitive_desc().desc()); - auto dst_def = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc()); + mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc()); + mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat(); // If both are not using the default layouts. There isn't much we can do, // other than reorder data layout directly. 
- if (dst_def != ptr_->mkl_mem_->get_primitive_desc().desc().data.format + if (dst_def != ptr_->mkl_mem_->GetFormat() && src_def != mem.get_primitive_desc().desc().data.format) { - stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_)); - } else if (dst_def == ptr_->mkl_mem_->get_primitive_desc().desc().data.format) { + stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw())); + } else if (dst_def == ptr_->mkl_mem_->GetFormat()) { // If the dest mem uses the default memory layout, we can simply use // the default format of the source memory to improve perf of reorder. - auto pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), src_def); - mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->get_data_handle())); + mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def); + mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle())); stream->RegisterMem(tmp_mem); stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem)); } else { // If the src mem uses the default memory layout, we can use // the default format of the source memory to improve perf. 
- auto pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def); + mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def); mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle())); stream->RegisterMem(tmp_mem); - stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_)); + stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw())); } } } @@ -710,28 +686,28 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc & } mkldnn::memory::primitive_desc _desc = desc; - auto required_format = _desc.desc().data.format; - auto def_format = GetDefaultFormat(_desc.desc()); + mkldnn_memory_format_t required_format = _desc.desc().data.format; + mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc()); // If the required format is a default format, we don't need to worry about the shape. // If the shape isn't the same, it actually implicitly reshapes data. if (required_format == def_format) { ptr_->SetMKLMem(shape_, dtype_); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_); - return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); + return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); } if (ptr_->mkl_mem_) - CHECK(ptr_->mkl_mem_->get_data_handle() == ptr_->shandle.dptr); - if (ptr_->mkl_mem_ && ptr_->mkl_mem_->get_primitive_desc() == desc) { - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_); - return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc); + CHECK(ptr_->mkl_mem_->GetDataHandle() == ptr_->shandle.dptr); + if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetPrimitiveDesc() == desc) { + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); + return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc); } CHECK(ptr_->shandle.size >= desc.get_size()); ptr_->CheckAndAlloc(desc.get_size()); - ptr_->mkl_mem_.reset(new mkldnn::memory(desc, ptr_->shandle.dptr)); - MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_); - return 
ptr_->mkl_mem_.get(); + ptr_->mkl_mem_.reset(new MKLDNNMemory(desc, ptr_->shandle.dptr)); + MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem()); + return ptr_->mkl_mem_->GetRaw(); } #endif diff --git a/src/operator/nn/lrn-inl.h b/src/operator/nn/lrn-inl.h index fdae1ec..cb441de 100644 --- a/src/operator/nn/lrn-inl.h +++ b/src/operator/nn/lrn-inl.h @@ -58,8 +58,35 @@ struct LRNParam : public dmlc::Parameter<LRNParam> { DMLC_DECLARE_FIELD(nsize) .describe("normalization window width in elements."); } + + bool operator==(const LRNParam& other) const { + return (this->nsize == other.nsize && + fabs(this->alpha - other.alpha) < 1e-6 && + fabs(this->beta - other.beta) < 1e-6 && + fabs(this->knorm - other.knorm) < 1e-6); + } }; // struct LRNParam +} // namespace op +} // namespace mxnet + +namespace std { +template<> +struct hash<mxnet::op::LRNParam> { + size_t operator()(const mxnet::op::LRNParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.alpha); + ret = dmlc::HashCombine(ret, val.beta); + ret = dmlc::HashCombine(ret, val.knorm); + ret = dmlc::HashCombine(ret, val.nsize); + return ret; + } +}; +} // namespace std + +namespace mxnet { +namespace op { + template<typename xpu> void LRNForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector<TBlob> &in_data, diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 61bef11..489351e 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -334,11 +334,104 @@ const mkldnn::memory *GetWeights(const NDArray &arr, const mkldnn::memory::primitive_desc &target_pd, int num_groups); -mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc); +mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc); mkldnn_memory_format_t GetDefaultFormat(int num_dims); mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd, mkldnn_memory_format_t format); 
+inline bool same_shape(const TShape &shape, const mkldnn_dims_t dims, int ndims) { + if (shape.ndim() != (size_t)ndims) + return false; + for (int i = 0; i < ndims; i++) + if (shape[i] != dims[i]) + return false; + return true; +} + +inline bool same_shape(const TShape &shape, int dtype, + const mkldnn::memory::desc &desc) { + return same_shape(shape, desc.data.dims, desc.data.ndims) + && get_mkldnn_type(dtype) == desc.data.data_type; +} + +/* + * There is a large overhead of getting mkldnn::memory::primitive_desc from + * mkldnn::memory. This class is created to cache the metadata of mkldnn memory + * to provide a much more lightweight method to access them. + */ +class MKLDNNMemory { + std::shared_ptr<mkldnn::memory> mem; + mkldnn::memory::desc desc; + size_t size; // The number of bytes. + + public: + MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) { + mem.reset(new mkldnn::memory(pd, addr)); + size = pd.get_size(); + } + + explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc( + mem->get_primitive_desc().desc()) { + this->mem = mem; + mkldnn::memory::primitive_desc pd = mem->get_primitive_desc(); + size = pd.get_size(); + } + + void SetDataHandle(void *handle) { + mem->set_data_handle(handle); + } + + void *GetDataHandle() const { + return mem->get_data_handle(); + } + + std::shared_ptr<mkldnn::memory> GetMem() const { + return mem; + } + + mkldnn::memory *GetRaw() const { + return mem.get(); + } + + size_t GetSize() const { + return size; + } + + mkldnn::memory::primitive_desc GetPrimitiveDesc() const { + return mem->get_primitive_desc(); + } + + mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn_memory_format_t format) const { + return mxnet::GetPrimitiveDesc(mem->get_primitive_desc(), format); + } + + mkldnn_memory_format_t GetDefaultFormat() const { + return mxnet::GetDefaultFormat(desc); + } + + mkldnn_memory_format_t GetFormat() const { + return desc.data.format; + } + + bool IsMKLDNN() const { + return 
GetFormat() != GetDefaultFormat(); + } + + bool SameFormat(mkldnn::memory::primitive_desc pd) const { + return mem->get_primitive_desc() == pd; + } + + bool SameFormat(const TShape &shape, int dtype) const { + return same_shape(shape, dtype, desc); + } + + void ReorderTo(mkldnn::memory *other) const { + std::vector<mkldnn::primitive> net; + net.push_back(mkldnn::reorder(*mem, *other)); + mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait(); + } +}; + void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector<NDArray> &inputs, diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index 820cca1..684abd2 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -190,7 +190,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) { } } -mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) { +mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { if (desc.data.ndims == 1) { return desc.data.format; } else if (desc.data.ndims == 2) { diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index 16f9874..d1c80a6 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -234,20 +234,15 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, DType* weight_ptr = gamma.data().dptr<DType>(); DType* bias_ptr = beta.data().dptr<DType>(); if (!param.fix_gamma) { -#pragma omp parallel for - for (int i = 0; i < channels_; i++) { - weight_buf[i] = weight_ptr[i]; - weight_buf[channels_ + i] = bias_ptr[i]; // bias - } + memcpy(weight_buf, weight_ptr, sizeof(weight_buf[0]) * channels_); + memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_); } else if (IsBNWriting(req[batchnorm::kGamma])) { -#pragma omp parallel for for (int i = 0; i < channels_; i++) { weight_buf[i] 
= (DType)1.0f; weight_ptr[i] = (DType)1.0f; weight_buf[channels_ + i] = bias_ptr[i]; // bias } } else { -#pragma omp parallel for for (int i = 0; i < channels_; i++) { weight_buf[i] = (DType)1.0f; weight_buf[channels_ + i] = bias_ptr[i]; // bias @@ -260,7 +255,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr<DType>(); DType* invar = aux_states[batchnorm::kMovingVar].data().dptr<DType>(); // to align with origin implmentation: batch_norm.cc: L164 -#pragma omp parallel for for (int i = 0; i < channels_; i++) { omean[i] = inmean[i]; ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps); @@ -282,14 +276,13 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, MKLDNNStream::Get()->Submit(); DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle()); DType* var_mem_ptr = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle()); -#pragma omp parallel for for (int i = 0; i < channels_; i++) { omean[i] = mean_mem_ptr[i]; ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps); } } } else { // no input gamma and beta - LOG(FATAL) << "MKLDNN batch normalization: should not reach here ..."; + LOG(FATAL) << "MKLDNN batch normalization: should not reach here ..."; } } diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc index d3e6e77..240673d 100644 --- a/src/operator/nn/mkldnn/mkldnn_concat.cc +++ b/src/operator/nn/mkldnn/mkldnn_concat.cc @@ -30,6 +30,67 @@ namespace mxnet { namespace op { +class MKLDNNConcatFwd { + std::shared_ptr<mkldnn::concat> fwd; + std::vector<std::shared_ptr<mkldnn::memory>> data; + std::vector<mkldnn::primitive::at> data_mem; + std::shared_ptr<mkldnn::memory> out; + + public: + mkldnn::concat::primitive_desc fwd_pd; + + MKLDNNConcatFwd( + int concat_dim, + const std::vector<mkldnn::memory::primitive_desc> &data_md): fwd_pd(concat_dim, data_md) { + data.resize(data_md.size()); + } 
+ + void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, + const mkldnn::memory &output) { + CHECK_EQ(in_data.size(), data.size()); + for (size_t i = 0; i < data.size(); i++) { + if (this->data[i] == nullptr) { + this->data[i] = std::shared_ptr<mkldnn::memory>(new mkldnn::memory( + in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle())); + this->data_mem.push_back(*this->data[i]); + } else { + this->data[i]->set_data_handle(in_data[i]->get_data_handle()); + } + } + if (this->out == nullptr) + this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory( + fwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (this->fwd == nullptr) + fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out)); + } + + const mkldnn::concat &GetFwd() const { + return *fwd; + } +}; + +static MKLDNNConcatFwd &GetConcatForward( + int concat_dim, const std::vector<NDArray> &in_data, + const std::vector<mkldnn::memory::primitive_desc> &data_md) { + static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds; + OpSignature key; + key.AddSign(concat_dim); + key.AddSign(in_data); + + auto it = fwds.find(key); + if (it == fwds.end()) { + MKLDNNConcatFwd fwd(concat_dim, data_md); + auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>( + key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector<NDArray> &in_data, const std::vector<OpReqType> &req, @@ -39,18 +100,21 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, int num_in_data = param.num_args; int concat_dim = param.dim; std::vector<mkldnn::memory::primitive_desc> data_md; - std::vector<mkldnn::primitive::at> data_mem; + std::vector<const mkldnn::memory *> data_mem; + data_md.reserve(num_in_data); + data_mem.reserve(num_in_data); for (int i =0; i < num_in_data; 
i++) { - auto tmp_mem = in_data[i].GetMKLDNNData(); - auto tmp_pd = tmp_mem->get_primitive_desc(); - data_md.push_back(tmp_pd); - data_mem.push_back(*tmp_mem); + const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData(); + mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc(); + data_md.push_back(tmp_pd); + data_mem.push_back(tmp_mem); } - mkldnn::concat::primitive_desc fwd_pd(concat_dim, data_md); - auto engine = CpuEngine::Get()->get_engine(); - auto out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut], - fwd_pd.dst_primitive_desc(), req[concat_enum::kOut]); - MKLDNNStream::Get()->RegisterPrim(mkldnn::concat(fwd_pd, data_mem, *out_mem.second)); + MKLDNNConcatFwd &fwd = GetConcatForward(concat_dim, in_data, data_md); + mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data[concat_enum::kOut], + fwd.fwd_pd.dst_primitive_desc(), + req[concat_enum::kOut]); + fwd.SetNewMem(data_mem, *out_mem.second); + MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd()); CommitOutput(out_data[concat_enum::kOut], out_mem); MKLDNNStream::Get()->Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h index 9a9bf62..b0b715a 100644 --- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h @@ -26,6 +26,7 @@ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_LRN_INL_H_ #if MXNET_USE_MKLDNN == 1 +#include <utility> #include <mkldnn.hpp> #include "../lrn-inl.h" #include "./mkldnn_base-inl.h" @@ -39,11 +40,11 @@ inline algorithm GetMKLDNNLRNAlgo(const LRNParam ¶m) { return algorithm::lrn_across_channels; } -inline lrn_forward::primitive_desc GetLRNFwd(const LRNParam ¶m, - const bool is_train, - const memory::desc &src_md) { - const auto engine = CpuEngine::Get()->get_engine(); - const auto alg = GetMKLDNNLRNAlgo(param); +inline lrn_forward::primitive_desc GetLRNFwdDesc(const LRNParam ¶m, + const bool is_train, + const memory::desc &src_md) { + mkldnn::engine &engine = CpuEngine::Get()->get_engine(); + 
const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; const float beta = param.beta; const int nsize = param.nsize; @@ -63,8 +64,8 @@ GetLRNBwd(const LRNParam &param, const mkldnn::memory::desc &diff_in_md, const mkldnn::memory::desc &diff_md, const lrn_forward::primitive_desc &lrnFwd_desc) { - const auto engine = CpuEngine::Get()->get_engine(); - const auto alg = GetMKLDNNLRNAlgo(param); + mkldnn::engine &engine = CpuEngine::Get()->get_engine(); + const algorithm alg = GetMKLDNNLRNAlgo(param); const float alpha = param.alpha; const float beta = param.beta; const int nsize = param.nsize; @@ -76,28 +77,113 @@ GetLRNBwd(const LRNParam &param, engine, lrnFwd_desc); } + +typedef ParamOpSign<LRNParam> MKLDNNLRNSignature; + +// LRN Forward Class + class MKLDNNLRNFwd { + public: + MKLDNNLRNFwd(const LRNParam& param, + bool is_train, + const NDArray &in_data): + is_train(is_train) { + _Init(param, is_train, in_data); + } + + ~MKLDNNLRNFwd() {} + + void SetDataHandle(const NDArray &data, + const NDArray &output); + + void Execute(); + + private: + std::shared_ptr<mkldnn::lrn_forward> fwd; + std::shared_ptr<mkldnn::memory> in_mem; + std::shared_ptr<mkldnn::memory> out_mem; + std::shared_ptr<mkldnn::memory> ws_mem; + bool is_train; + + private: + void _Init(const LRNParam &param, bool is_train, const NDArray &in_data); +}; // End of LRN Forword Class + +void MKLDNNLRNFwd::_Init(const LRNParam &param, + bool is_train, + const NDArray &in_data) { + mkldnn::memory::desc in_data_md = in_data.GetMKLDNNData()->get_primitive_desc().desc(); + lrn_forward::primitive_desc fwd_pd = GetLRNFwdDesc(param, is_train, in_data_md); + + this->in_mem.reset(new mkldnn::memory(in_data.GetMKLDNNData() + ->get_primitive_desc())); + this->out_mem.reset(new mkldnn::memory(fwd_pd.dst_primitive_desc())); + if (is_train) { + // If it's training, we have to create a workspace memory. Otherwise, MKLDNN + // will have segmentation fault. 
+ ws_mem.reset(new mkldnn::memory(fwd_pd.workspace_primitive_desc())); + this->fwd = std::shared_ptr<mkldnn::lrn_forward>( + new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*this->in_mem), + *this->ws_mem, *this->out_mem)); + } else { + this->fwd = std::shared_ptr<mkldnn::lrn_forward>( + new mkldnn::lrn_forward(fwd_pd, mkldnn::primitive::at(*(this->in_mem)), + *(this->out_mem))); + } +} + +void MKLDNNLRNFwd::SetDataHandle(const NDArray &in_data, + const NDArray &out_data) { + const mkldnn::memory *in_data_mem = in_data.GetMKLDNNData(); + mkldnn::memory *out_data_mem = const_cast<NDArray&>(out_data).CreateMKLDNNData( + this->out_mem->get_primitive_desc()); + this->in_mem->set_data_handle(in_data_mem->get_data_handle()); + this->out_mem->set_data_handle(out_data_mem->get_data_handle()); +} + +void MKLDNNLRNFwd::Execute() { + MKLDNNStream::Get()->RegisterPrim(*(this->fwd)); + MKLDNNStream::Get()->Submit(); +} +// End of LRN Class and its functions + +static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param, + const OpContext &ctx, + const NDArray &in_data) { + static thread_local std::unordered_map<MKLDNNLRNSignature, + MKLDNNLRNFwd, + OpHash> lrn_fwds; + auto alg_ = algorithm::lrn_across_channels; + auto kind_ = prop_kind::forward_training; + if (ctx.is_train) { + kind_ = prop_kind::forward_training; + } else { + kind_ = prop_kind::forward_scoring; + } + + MKLDNNLRNSignature key(param); + key.AddSign(alg_); + key.AddSign(kind_); + key.AddSign(in_data); + + auto it = lrn_fwds.find(key); + if (it == lrn_fwds.end()) { + MKLDNNLRNFwd fwd(param, ctx.is_train, in_data); + auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd> + (key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNLRNForward(const OpContext &ctx, const LRNParam &param, const NDArray &in_data, const OpReqType req, const NDArray &out_data) { - auto src_mem = in_data.GetMKLDNNData(); - const auto src_md = src_mem->get_primitive_desc().desc(); - 
const auto pdesc = GetLRNFwd(param, ctx.is_train, src_md); - auto dst_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData( - pdesc.dst_primitive_desc()); - if (ctx.is_train) { - std::shared_ptr<const mkldnn::memory> ws_mem( - new mkldnn::memory(pdesc.workspace_primitive_desc())); - MKLDNNStream::Get()->RegisterPrim( - lrn_forward(pdesc, mkldnn::primitive::at(*src_mem), - *ws_mem, *dst_mem)); - MKLDNNStream::Get()->Submit(); - } else { - MKLDNNStream::Get()->RegisterPrim( - lrn_forward(pdesc, mkldnn::primitive::at(*src_mem), *dst_mem)); - MKLDNNStream::Get()->Submit(); - } + MKLDNNLRNFwd fwd = GetLRNFwd(param, ctx, in_data); + fwd.SetDataHandle(in_data, out_data); + fwd.Execute(); } void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param, @@ -109,9 +195,10 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param, return; } // Repeat FW for getting workspace - auto data_mem = in_data.GetMKLDNNData(); - const auto data_md = data_mem->get_primitive_desc().desc(); - const auto pdesc_fwd = GetLRNFwd(param, ctx.is_train, data_md); + const mkldnn::memory *data_mem = in_data.GetMKLDNNData(); + const mkldnn::memory::desc data_md = data_mem->get_primitive_desc().desc(); + const lrn_forward::primitive_desc pdesc_fwd = GetLRNFwdDesc(param, ctx.is_train, + data_md); // TODO(Patric): To keep the function stateless, we can't pass workspace // from LRN forward to backward. 
We have to re-compute @@ -125,12 +212,13 @@ void MKLDNNLRNBackward(const OpContext &ctx, const LRNParam &param, lrn_forward(pdesc_fwd, mkldnn::primitive::at(*data_mem), *ws_mem, *dst_temp)); - const auto data_in_md = pdesc_fwd.src_primitive_desc().desc(); - auto diff_mem = out_grad.GetMKLDNNData(); - const auto diff_md = diff_mem->get_primitive_desc().desc(); - const auto pdesc_bwd = GetLRNBwd(param, data_in_md, diff_md, pdesc_fwd); - auto diff_src_mem = CreateMKLDNNMem(in_grad, - pdesc_bwd.diff_src_primitive_desc(), req); + const mkldnn::memory::desc data_in_md = pdesc_fwd.src_primitive_desc().desc(); + const mkldnn::memory *diff_mem = out_grad.GetMKLDNNData(); + const mkldnn::memory::desc diff_md = diff_mem->get_primitive_desc().desc(); + const mkldnn::lrn_backward::primitive_desc pdesc_bwd = GetLRNBwd(param, data_in_md, + diff_md, pdesc_fwd); + mkldnn_output_t diff_src_mem = CreateMKLDNNMem(in_grad, + pdesc_bwd.diff_src_primitive_desc(), req); MKLDNNStream::Get()->RegisterPrim( lrn_backward(pdesc_bwd, mkldnn::primitive::at(*data_mem), -- To stop receiving notification emails like this one, please contact j...@apache.org.