This is an automated email from the ASF dual-hosted git repository.
taolv pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
new ad1ff3a [v1.6.x] Cherry-pick MKL-DNN Rnn operator enhancements to v1.6.x (#17225)
ad1ff3a is described below
commit ad1ff3aa8532f2f7e42a732f0d35dfb0574ca05c
Author: Zixuan Wei <[email protected]>
AuthorDate: Tue Jan 7 09:50:53 2020 +0800
[v1.6.x] Cherry-pick MKL-DNN Rnn operator enhancements to v1.6.x (#17225)
* [MKLDNN] mkldnn RNN operator enhancement (#17075)
* mkldnn rnn operator enhancement
`add` operation support
Rename AddTo
Add MXNET_USE_MKLDNN_RNN env
Add Env var for switching to naive RNN impl and naive add/copy impl
* Re-run CI, op:test_reduce failed on Unix-CPU
* Rerun CI, Python2 CPU on Unix-CPU timeout
* MKL-DNN RNN backward path enhancement (#17183)
* Flush memory before RNN backward primitive
* Add gluon rnn unit test for gradients check
* Cache reorder
* Re-write rnn supporting check
* Update OpSignature.AddSign to avoid potential hash collision for rnn-packed memory
Get the data type from mkldnn memory descriptor when setting grad handle
---
docs/static_site/src/pages/api/faq/env_var.md | 12 +-
src/common/utils.h | 20 +-
src/operator/nn/mkldnn/mkldnn_base-inl.h | 9 +-
src/operator/nn/mkldnn/mkldnn_rnn-inl.h | 38 ++-
src/operator/nn/mkldnn/mkldnn_rnn.cc | 430 +++++++++++++++-----------
src/operator/operator_common.h | 18 ++
src/operator/rnn.cc | 13 +-
tests/python/unittest/test_gluon_rnn.py | 124 ++++++++
tests/python/unittest/test_operator.py | 21 +-
9 files changed, 466 insertions(+), 219 deletions(-)
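As a quick orientation (not part of the patch itself), the changes are exercised end-to-end by running a fused Gluon RNN layer on a CPU context under autograd; a condensed, illustrative sketch with arbitrary shapes:

# Illustrative sketch only: a fused LSTM on CPU under autograd. With an MKL-DNN
# build and MXNET_USE_MKLDNN_RNN left at its default, forward and backward go
# through the MKL-DNN RNN primitives touched by this commit.
import mxnet as mx
from mxnet import autograd, nd

fused = mx.gluon.rnn.LSTM(128, num_layers=1, layout='NTC')
fused.initialize()

x = nd.random.normal(shape=(1, 5, 128))   # (batch, seq_len, input_size)
x.attach_grad()
with autograd.record():
    out, state = fused(x, fused.begin_state(batch_size=1))
    loss = out.sum()
loss.backward()                            # backward runs the RNN backward primitive
print(x.grad.shape)

The new unit tests added below perform the same flow and additionally compare outputs and gradients against the unfused cell implementations.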
diff --git a/docs/static_site/src/pages/api/faq/env_var.md
b/docs/static_site/src/pages/api/faq/env_var.md
index e4fe58a..d63da61 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -283,11 +283,11 @@ If ctypes is used, it must be
`mxnet._ctypes.ndarray.NDArrayBase`.
If no such algorithm exists given other constraints, MXNet will error out.
This variable affects the choice
of CUDNN convolution algorithms. Please see [CUDNN developer
guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html)
for more details.
-* MXNET_CPU_PARALLEL_COPY_SIZE
+* MXNET_CPU_PARALLEL_SIZE
  - Values: Int ```(default=200000)```
-  - The minimum size to call parallel copy by OpenMP in CPU2CPU mode.
-  - When the array size is bigger than or equal to this threshold, NDArray::Copy(from, to) is implemented by OpenMP with the Recommended OMP Thread Count.
-  - When the array size is less than this threshold, NDArray::Copy(from , to)) is implemented by memcpy in single thread.
+  - The minimum array size for invoking parallel operations via OpenMP on the CPU context.
+  - When the array size is greater than or equal to this threshold, the operation is parallelized by OpenMP with the Recommended OMP Thread Count.
+  - When the array size is less than this threshold, the operation is executed naively in a single thread.
* MXNET_OPTIMIZER_AGGREGATION_SIZE
- Values: Int ```(default=4)```
@@ -343,6 +343,10 @@ If ctypes is used, it must be
`mxnet._ctypes.ndarray.NDArrayBase`.
- Values: 0(false) or 1(true) ```(default=1)```
- If this variable is set, MXNet will simplify the computation graph,
eliminating duplicated operations on the same inputs.
+* MXNET_USE_MKLDNN_RNN
+  - Values: 0(false) or 1(true) ```(default=1)```
+  - This variable controls whether the fused RNN operator uses the MKL-DNN backend on the CPU context. MXNet has two fused RNN implementations: the MKL-DNN one offers better performance, while the naive one is currently more stable in the backward pass.
+
Settings for Minimum Memory Usage
---------------------------------
- Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
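As a usage note (not part of the patch), the two variables documented above are read from the process environment, so they are typically set before MXNet runs any operator; a minimal sketch, assuming a CPU build with MKL-DNN:

# Minimal sketch: disable the MKL-DNN RNN backend and keep the default
# OpenMP parallelization threshold. Values here are illustrative.
import os
os.environ['MXNET_USE_MKLDNN_RNN'] = '0'          # fall back to the naive fused RNN
os.environ['MXNET_CPU_PARALLEL_SIZE'] = '200000'  # OpenMP threshold for parallel copy/add

import mxnet as mx                                # import after setting the variables
layer = mx.gluon.rnn.GRU(64, num_layers=1)
layer.initialize()
print(layer(mx.nd.random.normal(shape=(5, 1, 32))).shape)   # (5, 1, 64) with TNC layout

MXNET_USE_MKLDNN_RNN is checked on every call, while MXNET_CPU_PARALLEL_SIZE is cached by the copy/add helpers on first use, so setting both before the first operator call is the safe pattern.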
diff --git a/src/common/utils.h b/src/common/utils.h
index 0e3e354..fcb61b7 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -760,7 +760,7 @@ inline void EmplaceBackZeros(const NDArrayStorageType
stype, const mxnet::TShape
*/
template<typename DType>
inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
- static index_t copy_block_size =
dmlc::GetEnv("MXNET_CPU_PARALLEL_COPY_SIZE", 200000);
+ static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE",
200000);
if (size >= copy_block_size) {
#pragma omp parallel for
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (index_t i = 0; i < size; ++i) {
@@ -772,6 +772,24 @@ inline void ParallelCopy(DType* dst, const DType* src,
index_t size) {
}
/*!
+ * \brief parallelize add by OpenMP
+ */
+template<typename DType>
+inline void ParallelAdd(DType* dst, const DType* src, index_t size) {
+ static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE",
200000);
+ if (size >= add_block_size) {
+ #pragma omp parallel for
num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+ for (index_t i = 0; i < size; ++i) {
+ dst[i] += src[i];
+ }
+ } else {
+ for (index_t i = 0; i < size; ++i) {
+ dst[i] += src[i];
+ }
+ }
+}
+
+/*!
* \brief If numpy compatibility is turned off (default), the shapes passed in
* by users follow the legacy shape definition:
* 1. 0 ndim means the shape is completely unknown.
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h
b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 9bfc20c..9763c42 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -132,9 +132,12 @@ static inline bool SupportMKLDNN(int dtype, const
mxnet::TShape &shape) {
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}
-static inline bool SupportMKLDNNRNN(const NDArray &input) {
- int ndim = input.shape().ndim();
- return (input.dtype() == mshadow::kFloat32) && (ndim == 3);
+static inline bool SupportMKLDNNRnn(const NDArray &input) {
+ if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
+ && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
+ return true;
+ }
+ return false;
}
static inline bool SupportMKLDNNQuantize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index ad3f733..a4104bf 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -120,25 +120,24 @@ class RnnPrimitive {
template<typename rnn_fwd, typename... Args>
static RnnPrimitive Create(Args&&... args) {
RnnPrimitive rnn_fwd_prim;
- rnn_fwd_prim.pd_.reset(
- new typename rnn_fwd::desc(std::forward<Args>(args)...),
- [](typename rnn_fwd::desc* pd) {
- delete reinterpret_cast<typename rnn_fwd::desc*>(pd);
+ auto fwd_desc = typename rnn_fwd::desc(std::forward<Args>(args)...);
+ rnn_fwd_prim.fwd_pd_.reset(
+ new typename rnn_fwd::primitive_desc(fwd_desc,
CpuEngine::Get()->get_engine()),
+ [](typename rnn_fwd::primitive_desc* pd) {
+ delete reinterpret_cast<typename rnn_fwd::primitive_desc*>(pd);
});
- const typename rnn_fwd::desc& fwd_desc =
- *(reinterpret_cast<typename rnn_fwd::desc*>(rnn_fwd_prim.pd_.get()));
- typename rnn_fwd::primitive_desc fwd_pd(fwd_desc,
CpuEngine::Get()->get_engine());
- rnn_fwd_prim.weights_layer_desc_ = fwd_pd.weights_layer_desc();
- rnn_fwd_prim.weights_iter_desc_ = fwd_pd.weights_iter_desc();
- rnn_fwd_prim.workspace_desc_ = fwd_pd.workspace_desc();
+ auto fwd_pd = reinterpret_cast<typename
rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
+ rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
+ rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc();
+ rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
- rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new
rnn_fwd(fwd_pd));
+ rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new
rnn_fwd(*fwd_pd));
return rnn_fwd_prim;
}
RnnPrimitive() {
- this->pd_ = nullptr;
+ this->fwd_pd_ = nullptr;
this->primitive_ = nullptr;
this->weights_layer_desc_ = mkldnn::memory::desc();
this->weights_iter_desc_ = mkldnn::memory::desc();
@@ -146,7 +145,7 @@ class RnnPrimitive {
}
RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) {
- this->pd_ = rnn_fwd_prim.pd_;
+ this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -155,7 +154,7 @@ class RnnPrimitive {
RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) {
if (this != &rnn_fwd_prim) {
- this->pd_ = rnn_fwd_prim.pd_;
+ this->fwd_pd_ = rnn_fwd_prim.fwd_pd_;
this->primitive_ = rnn_fwd_prim.primitive_;
this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
@@ -165,7 +164,7 @@ class RnnPrimitive {
return *this;
}
- const void* GetPrimDesc() const { return pd_.get(); }
+ const void* GetPrimDesc() const { return fwd_pd_.get(); }
const mkldnn::primitive& GetPrim() const { return *primitive_; }
const mkldnn::memory::desc& GetLayerDesc() const {
@@ -181,7 +180,7 @@ class RnnPrimitive {
}
private:
- std::shared_ptr<void> pd_;
+ std::shared_ptr<void> fwd_pd_;
std::shared_ptr<mkldnn::primitive> primitive_;
mkldnn::memory::desc weights_layer_desc_;
mkldnn::memory::desc weights_iter_desc_;
@@ -370,7 +369,10 @@ class MKLDNNRnnBackward {
void SetDataGradsMem(void* diff_src, void* diff_state, void* diff_statecell,
void* diff_out, void* diff_state_out, void*
diff_statecell_out,
const int dtype = mshadow::kFloat32);
- void CommitWeightsDiff(void* diff_weights, void* diff_bias, const int dtype
= mshadow::kFloat32);
+ void SetNativeWeightsGrads() const;
+ void CommitWeightsGrads(void* diff_weights, void* diff_bias,
+ const OpReqType req,
+ const int dtype = mshadow::kFloat32);
const mkldnn::primitive& GetBwd() const { return *bwd_.primitive_; }
const mkldnn_args_map_t& GetArgsMap() const { return net_args_; }
@@ -385,6 +387,8 @@ class MKLDNNRnnBackward {
mkldnn_shared_mem_t diff_weights_layer_;
mkldnn_shared_mem_t diff_weights_iter_;
+ mkldnn_shared_mem_t diff_weights_layer_r_;
+ mkldnn_shared_mem_t diff_weights_iter_r_;
mkldnn_shared_mem_t diff_bias_;
mkldnn_args_map_t net_args_;
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc
b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index e797b64..8af0e99 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -213,13 +213,13 @@ RnnBwdPrimitive GetRnnBwdPrim(const
MKLDNNRnnForwardTraining &fwd,
auto dst_state_desc = layer_param.state_outputs ? memory::desc(
layer_param.state_dims, data_type, tag::ldnc) : memory::desc();
- const void* fwd_desc = fwd.GetPrimDesc();
+ const void* fwd_pd = fwd.GetPrimDesc();
auto bwd = RnnBwdPrimitive();
switch (mode) {
case rnn_enum::kLstm: {
- const lstm_forward::primitive_desc* desc =
- reinterpret_cast<const lstm_forward::primitive_desc*>(fwd_desc);
- bwd = RnnBwdPrimitive::Create<lstm_forward, lstm_backward>(*desc,
+ const lstm_forward::primitive_desc* pd =
+ reinterpret_cast<const lstm_forward::primitive_desc*>(fwd_pd);
+ bwd = RnnBwdPrimitive::Create<lstm_forward, lstm_backward>(*pd,
prop, mkldnn_rnn_direction,
// data desc
src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc,
@@ -231,9 +231,9 @@ RnnBwdPrimitive GetRnnBwdPrim(const
MKLDNNRnnForwardTraining &fwd,
dst_state_desc);
} break;
case rnn_enum::kGru: {
- const lbr_gru_forward::primitive_desc* desc =
- reinterpret_cast<const lbr_gru_forward::primitive_desc*>(fwd_desc);
- bwd = RnnBwdPrimitive::Create<lbr_gru_forward, lbr_gru_backward>(*desc,
+ const lbr_gru_forward::primitive_desc* pd =
+ reinterpret_cast<const lbr_gru_forward::primitive_desc*>(fwd_pd);
+ bwd = RnnBwdPrimitive::Create<lbr_gru_forward, lbr_gru_backward>(*pd,
prop, mkldnn_rnn_direction,
// data desc
src_layer_desc, src_state_desc, weight_layer_desc,
@@ -244,10 +244,10 @@ RnnBwdPrimitive GetRnnBwdPrim(const
MKLDNNRnnForwardTraining &fwd,
} break;
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh: {
- const vanilla_rnn_forward::primitive_desc* desc =
- reinterpret_cast<const
vanilla_rnn_forward::primitive_desc*>(fwd_desc);
+ const vanilla_rnn_forward::primitive_desc* pd =
+ reinterpret_cast<const vanilla_rnn_forward::primitive_desc*>(fwd_pd);
bwd = RnnBwdPrimitive::Create<vanilla_rnn_forward, vanilla_rnn_backward>(
- *desc, prop,
+ *pd, prop,
mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh :
algorithm::eltwise_relu,
mkldnn_rnn_direction,
// data desc
@@ -364,18 +364,38 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx,
void* cx,
}
}
+inline void MKLDNNMemoryReorder(const mkldnn::memory& src,
+ const mkldnn::memory& dst) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<OpSignature,
+ mkldnn::reorder, OpHash> reorderPrimitives;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+ mkldnn::reorder, OpHash> reorderPrimitives;
+#endif
+ OpSignature key{};
+ key.AddSign(src);
+ key.AddSign(dst);
+
+ auto it = reorderPrimitives.find(key);
+ if (it == reorderPrimitives.end()) {
+ auto reorder = mkldnn::reorder(src, dst);
+ it = AddToCache(&reorderPrimitives, key, reorder);
+ }
+
+ mkldnn_args_map_t net_args;
+ net_args.emplace(MKLDNN_ARG_SRC, src);
+ net_args.emplace(MKLDNN_ARG_DST, dst);
+ MKLDNNStream::Get()->RegisterPrimArgs(it->second, net_args);
+}
+
/*
 * Reorder the concatenated weights memory to an efficient memory block
 * with primitive-preferred format.
*/
void MKLDNNRnnForward::ReorderWeights() {
- auto& cpu_engine = CpuEngine::Get()->get_engine();
- mkldnn::stream s(cpu_engine);
- mkldnn::reorder(*weights_layer_r_, *weights_layer_)
- .execute(s, *weights_layer_r_, *weights_layer_);
- mkldnn::reorder(*weights_iter_r_, *weights_iter_)
- .execute(s, *weights_iter_r_, *weights_iter_);
- s.wait();
+ MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
+ MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
}
void AdjustGruGateOrder(char* weight,
@@ -394,7 +414,7 @@ void AdjustGruGateOrder(char* weight,
* Fuse uni-directional bias among single layer.
*/
template <typename DType>
-void FuseBias(DType* fuse_bias, DType* naive_bias,
+void FuseBias(DType* fuse_bias, DType* native_bias,
const int mode, const size_t state_size) {
const size_t ngates = GetRnnGatesNum(mode);
const int omp_threads =
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
@@ -403,8 +423,8 @@ void FuseBias(DType* fuse_bias, DType* naive_bias,
// OpenMP 'for' statement.
const int state_size_ = static_cast<int>(state_size);
const int single_b_sz = static_cast<int>(nbias * state_size);
- DType* bx = naive_bias;
- DType* bh = naive_bias + state_size * ngates;
+ DType* bx = native_bias;
+ DType* bh = native_bias + state_size * ngates;
if (mode == rnn_enum::kGru) {
// While mxnet gru gate order is reset, update and new gates,
// mkldnn gru gate order is update, reset and new gates. So
@@ -528,12 +548,6 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
void *w_ptr, void *b_
}
}
}
- // Reorder after adjustment only when is_train == false. When is_train ==
true, i.e.
- // in forward training path, we use plain memory (ldxxx) as the space for
weights and
- // their gradients. Then, forward training primitives could fetch them from
the scope
- // of forward inference. And from there, we don't need to reorder the plain
memory to
- // the optimal rnn-packed memory for forward inference.
- ReorderWeights();
// Process bias
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -553,7 +567,15 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
void *w_ptr, void *b_
EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER,
this->weights_iter_);
EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS, this->bias_);
- initialized_ = true;
+ if (!is_train) {
+ // Reorder after adjustment only when is_train == false. When is_train ==
true, i.e.
+ // in forward training path, we use plain memory (ldxxx) as the space for
weights and
+ // their gradients. Then, forward training primitives could fetch them
from the scope
+ // of forward inference. And from there, we don't need to reorder the
plain memory to
+ // the optimal rnn-packed memory for forward inference.
+ ReorderWeights();
+ initialized_ = true;
+ }
}
void MKLDNNRnnForwardTraining::SetTrnMem(const MKLDNNRnnForward& fwd) {
@@ -572,17 +594,14 @@ void MKLDNNRnnForwardTraining::SetTrnMem(const
MKLDNNRnnForward& fwd) {
if (fwd.weights_layer_r_->get_desc() == fwd_trn_.GetLayerDesc()) {
weights_layer_->set_data_handle(fwd.weights_layer_r_->get_data_handle());
} else {
- mkldnn::reorder(*fwd.weights_layer_r_, *weights_layer_)
- .execute(s, *fwd.weights_layer_r_, *weights_layer_);
+ MKLDNNMemoryReorder(*fwd.weights_layer_r_, *weights_layer_);
}
if (fwd.weights_iter_r_->get_desc() == fwd_trn_.GetIterDesc()) {
weights_iter_->set_data_handle(fwd.weights_iter_r_->get_data_handle());
} else {
- mkldnn::reorder(*fwd.weights_iter_r_, *weights_iter_)
- .execute(s, *fwd.weights_iter_r_, *weights_iter_);
+ MKLDNNMemoryReorder(*fwd.weights_iter_r_, *weights_iter_);
}
- s.wait();
// bias are always in format_tag::ldgo
this->bias_ = fwd.bias_;
@@ -687,18 +706,17 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
{fwd->GetParam().dst_dims, get_mkldnn_type(data_dtype),
format_tag::tnc}));
}
- initialized_ = true;
+ if (!is_training) initialized_ = true;
}
void MKLDNNRnnBackward::FetchDataWeightsMem(const MKLDNNRnnForwardTraining&
fwd) {
using memory = mkldnn::memory;
auto& cpu_engine = CpuEngine::Get()->get_engine();
- auto s = mkldnn::stream(cpu_engine);
- if (this->weights_layer_ == nullptr)
+ if (this->weights_layer_ == nullptr || this-> weights_iter_ == nullptr) {
this->weights_layer_ = mkldnn_shared_mem_t(new
memory(bwd_.weights_layer_desc_, cpu_engine));
- if (this->weights_iter_ == nullptr)
this->weights_iter_ = mkldnn_shared_mem_t(new
memory(bwd_.weights_iter_desc_, cpu_engine));
+ }
for (auto& kv : fwd.net_args_) {
const mkldnn::memory* valid_mem;
@@ -707,17 +725,15 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const
MKLDNNRnnForwardTraining& fwd)
if (bwd_.weights_layer_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
this->weights_layer_->set_data_handle(kv.second.get_data_handle());
} else {
- mkldnn::reorder(*fwd.weights_layer_, *this->weights_layer_)
- .execute(s, *fwd.weights_layer_, *this->weights_layer_);
+ MKLDNNMemoryReorder(*fwd.weights_layer_, *this->weights_layer_);
}
valid_mem = this->weights_layer_.get();
} break;
case MKLDNN_ARG_WEIGHTS_ITER: {
- if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetLayerDesc()) {
+ if (bwd_.weights_iter_desc_ == fwd.fwd_trn_.GetIterDesc()) {
this->weights_iter_->set_data_handle(kv.second.get_data_handle());
} else {
- mkldnn::reorder(*fwd.weights_iter_, *this->weights_iter_)
- .execute(s, *fwd.weights_iter_, *this->weights_iter_);
+ MKLDNNMemoryReorder(*fwd.weights_iter_, *this->weights_iter_);
}
valid_mem = this->weights_iter_.get();
} break;
@@ -727,20 +743,50 @@ void MKLDNNRnnBackward::FetchDataWeightsMem(const
MKLDNNRnnForwardTraining& fwd)
}
EmplaceNetArgs(&this->net_args_, kv.first, valid_mem);
}
- s.wait();
}
void MKLDNNRnnBackward::SetWeightsGradsMem() {
- auto& cpu_engine = CpuEngine::Get()->get_engine();
- if (this->diff_weights_layer_ == nullptr)
- this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
- bwd_.diff_weights_layer_desc_, cpu_engine);
- if (this->diff_weights_iter_ == nullptr)
- this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
- bwd_.diff_weights_iter_desc_, cpu_engine);
- if (this->diff_bias_ == nullptr)
+ using tag = mkldnn::memory::format_tag;
+
+ if (this->diff_weights_layer_ == nullptr
+ || this->diff_weights_iter_ == nullptr
+ || this->diff_bias_ == nullptr) {
+ const auto& cpu_engine = CpuEngine::Get()->get_engine();
+ const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+ const auto mkldnn_type = static_cast<mkldnn::memory::data_type>(
+ bwd_.diff_weights_layer_desc_.data.data_type);
+
+ auto native_layer_desc = mkldnn::memory::desc(param.weight_layer_dims,
mkldnn_type, tag::ldgoi);
+ auto native_iter_desc = mkldnn::memory::desc(param.weight_iter_dims,
mkldnn_type, tag::ldgoi);
+
+ this->diff_weights_layer_r_ = std::make_shared<mkldnn::memory>(
+ native_layer_desc, cpu_engine);
+ this->diff_weights_iter_r_ = std::make_shared<mkldnn::memory>(
+ native_iter_desc, cpu_engine);
+
+ if (native_layer_desc == bwd_.diff_weights_layer_desc_) {
+ this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+ bwd_.diff_weights_layer_desc_, cpu_engine,
diff_weights_layer_r_->get_data_handle());
+ } else {
+ this->diff_weights_layer_ = std::make_shared<mkldnn::memory>(
+ bwd_.diff_weights_layer_desc_, cpu_engine);
+ }
+ if (native_iter_desc == bwd_.diff_weights_iter_desc_) {
+ this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+ bwd_.diff_weights_iter_desc_, cpu_engine,
diff_weights_iter_r_->get_data_handle());
+ } else {
+ this->diff_weights_iter_ = std::make_shared<mkldnn::memory>(
+ bwd_.diff_weights_iter_desc_, cpu_engine);
+ }
this->diff_bias_ = std::make_shared<mkldnn::memory>(
bwd_.diff_bias_desc_, cpu_engine);
+ }
+ std::memset(this->diff_weights_layer_->get_data_handle(), 0,
+ bwd_.diff_weights_layer_desc_.get_size());
+ std::memset(this->diff_weights_iter_->get_data_handle(), 0,
+ bwd_.diff_weights_iter_desc_.get_size());
+ std::memset(this->diff_bias_->get_data_handle(), 0,
+ bwd_.diff_bias_desc_.get_size());
EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
this->diff_weights_layer_.get());
EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_DIFF_WEIGHTS_ITER,
@@ -776,30 +822,40 @@ void MKLDNNRnnBackward::SetDataGradsMem(
}
}
-template <typename DType>
-void HalveWeightsDiff(DType* w, const size_t size) {
- const int omp_threads =
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
- #pragma omp parallel for num_threads(omp_threads)
- for (int i = 0; i < static_cast<int>(size); ++i) {
- w[i] *= 0.5;
+void MKLDNNRnnBackward::SetNativeWeightsGrads() const {
+ if (this->diff_weights_layer_->get_desc() !=
this->diff_weights_layer_r_->get_desc()) {
+ MKLDNNMemoryReorder(*this->diff_weights_layer_,
*this->diff_weights_layer_r_);
+ }
+ if (this->diff_weights_iter_->get_desc() !=
this->diff_weights_iter_r_->get_desc()) {
+ MKLDNNMemoryReorder(*this->diff_weights_iter_,
*this->diff_weights_iter_r_);
}
}
-void MKLDNNRnnBackward::CommitWeightsDiff(void* diff_weights, void* diff_bias,
int dtype) {
- using tag = mkldnn::memory::format_tag;
- auto& cpu_engine = CpuEngine::Get()->get_engine();
- auto s = mkldnn::stream(cpu_engine);
+#define OPREQTYPE_SWITCH(ReqType, DType, FWrapper, ...) \
+std::function<void(DType*, DType*, size_t)> FWrapper = nullptr; \
+if (kWriteTo == ReqType || kWriteInplace == ReqType) \
+ FWrapper = common::ParallelCopy<DType>; \
+else \
+ FWrapper = common::ParallelAdd<DType>; \
+{__VA_ARGS__}
+void MKLDNNRnnBackward::CommitWeightsGrads(void* diff_weights, void* diff_bias,
+ const OpReqType req, const int
dtype) {
const MKLDNNRnnLayerParam& param = fwd_ptr_->GetParam();
+
+ void* diff_weights_layer_ptr = this->diff_weights_layer_->get_data_handle();
+ void* diff_weights_iter_ptr = this->diff_weights_iter_->get_data_handle();
+ if (this->diff_weights_layer_->get_desc() !=
this->diff_weights_layer_r_->get_desc())
+ diff_weights_layer_ptr = this->diff_weights_layer_r_->get_data_handle();
+ if (this->diff_weights_iter_->get_desc() !=
this->diff_weights_iter_r_->get_desc())
+ diff_weights_iter_ptr = this->diff_weights_iter_r_->get_data_handle();
+
const int num_layer = param.num_layer;
const int direction = param.bidirectional ? 2 : 1;
const int ngates = GetRnnGatesNum(param.mode);
- const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
- const size_t wxh_bytes = param.single_w_size * dtype_bytes;
- const size_t wx_bytes = param.input_size * param.state_size * ngates *
dtype_bytes;
- const size_t wh_bytes = param.state_size * param.state_size * ngates *
dtype_bytes;
- char* diff_wx_ptr = static_cast<char
*>(diff_weights_layer_->get_data_handle());
- char* diff_wh_ptr = static_cast<char
*>(diff_weights_iter_->get_data_handle());
+ const size_t wxh_size = param.single_w_size;
+ const size_t wx_size = param.input_size * param.state_size * ngates;
+ const size_t wh_size = param.state_size * param.state_size * ngates;
/* naive weights layout is:
1st-layer: | wx_lr | wh_lr | wx_rl | wh_rl |
@@ -807,82 +863,81 @@ void MKLDNNRnnBackward::CommitWeightsDiff(void*
diff_weights, void* diff_bias, i
size: | wxh_bytes |
|wx_bytes|wh_bytes|
*/
- char* naive_weights = static_cast<char *>(diff_weights);
- if (param.mode != rnn_enum::kGru) {
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- std::memcpy(naive_weights + shift * wxh_bytes,
- diff_wx_ptr + shift * wx_bytes, wx_bytes);
- }
- // align naive_weights to weights_iter memory
- naive_weights += wx_bytes;
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- std::memcpy(naive_weights + shift * wxh_bytes,
- diff_wh_ptr + shift * wh_bytes, wh_bytes);
- }
- } else {
- const size_t wx_bytes_per_gate = param.input_size * param.state_size *
dtype_bytes;
- const size_t wh_bytes_per_gate = param.state_size * param.state_size *
dtype_bytes;
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- std::memcpy(naive_weights + shift * wxh_bytes + wx_bytes_per_gate,
- diff_wx_ptr + shift * wx_bytes, wx_bytes_per_gate);
- std::memcpy(naive_weights + shift * wxh_bytes,
- diff_wx_ptr + shift * wx_bytes + wx_bytes_per_gate,
wx_bytes_per_gate);
- std::memcpy(naive_weights + shift * wxh_bytes + 2 * wx_bytes_per_gate,
- diff_wx_ptr + shift * wx_bytes + 2 * wx_bytes_per_gate,
wx_bytes_per_gate);
- }
- // align naive_weights to weights_iter memory
- naive_weights += wx_bytes;
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- std::memcpy(naive_weights + shift * wxh_bytes + wh_bytes_per_gate,
- diff_wh_ptr + shift * wh_bytes, wh_bytes_per_gate);
- std::memcpy(naive_weights + shift * wxh_bytes,
- diff_wh_ptr + shift * wh_bytes + wh_bytes_per_gate,
wh_bytes_per_gate);
- std::memcpy(naive_weights + shift * wxh_bytes + 2 * wh_bytes_per_gate,
- diff_wh_ptr + shift * wh_bytes + 2 * wh_bytes_per_gate,
wh_bytes_per_gate);
- }
- }
-
- char* naive_bias = static_cast<char *>(diff_bias);
- char* diff_bias_ptr = static_cast<char
*>(this->diff_bias_->get_data_handle());
- const size_t bias_bytes = param.single_b_size * dtype_bytes;
- const size_t naive_bias_bytes = param.naive_single_b_size * dtype_bytes;
- if (param.mode != rnn_enum::kGru) {
- MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- DType* typed_bias = reinterpret_cast<DType *>(diff_bias_ptr);
- HalveWeightsDiff(typed_bias, num_layer * direction *
param.single_b_size);
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ DType* native_weights = static_cast<DType *>(diff_weights);
+ DType* diff_wx_ptr = static_cast<DType *>(diff_weights_layer_ptr);
+ DType* diff_wh_ptr = static_cast<DType *>(diff_weights_iter_ptr);
+ OPREQTYPE_SWITCH(req, DType, FAccGrad, {
+ if (param.mode != rnn_enum::kGru) {
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ FAccGrad(native_weights + shift * wxh_size, diff_wx_ptr + shift *
wx_size, wx_size);
+ }
+ // align native_weights to weights_iter memory
+ native_weights += wx_size;
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ FAccGrad(native_weights + shift * wxh_size, diff_wh_ptr + shift *
wh_size, wh_size);
+ }
+ } else {
+ const size_t wx_size_per_gate = param.input_size * param.state_size;
+ const size_t wh_size_per_gate = param.state_size * param.state_size;
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ FAccGrad(native_weights + shift * wxh_size + wx_size_per_gate,
+ diff_wx_ptr + shift * wx_size, wx_size_per_gate);
+ FAccGrad(native_weights + shift * wxh_size,
+ diff_wx_ptr + shift * wx_size + wx_size_per_gate,
wx_size_per_gate);
+ FAccGrad(native_weights + shift * wxh_size + 2 * wx_size_per_gate,
+ diff_wx_ptr + shift * wx_size + 2 * wx_size_per_gate,
wx_size_per_gate);
+ }
+ // align native_weights to weights_iter memory
+ native_weights += wx_size;
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ FAccGrad(native_weights + shift * wxh_size + wh_size_per_gate,
+ diff_wh_ptr + shift * wh_size, wh_size_per_gate);
+ FAccGrad(native_weights + shift * wxh_size,
+ diff_wh_ptr + shift * wh_size + wh_size_per_gate,
wh_size_per_gate);
+ FAccGrad(native_weights + shift * wxh_size + 2 * wh_size_per_gate,
+ diff_wh_ptr + shift * wh_size + 2 * wh_size_per_gate,
wh_size_per_gate);
+ }
+ }
});
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- std::memcpy(naive_bias + shift * naive_bias_bytes,
- diff_bias_ptr + shift * bias_bytes, bias_bytes);
- std::memcpy(naive_bias + shift * naive_bias_bytes + bias_bytes,
- diff_bias_ptr + shift * bias_bytes, bias_bytes);
- }
- } else {
- const size_t bias_bytes_per_gate = param.state_size * dtype_bytes;
- MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- for (int shift = 0; shift < num_layer * direction; ++shift) {
- char* naive_reset = naive_bias + shift * naive_bias_bytes;
- char* naive_update = naive_reset + bias_bytes_per_gate;
- char* update = diff_bias_ptr + shift * bias_bytes;
- char* reset = update + bias_bytes_per_gate;
-
- DType* typed_update = reinterpret_cast<DType *>(update);
- HalveWeightsDiff(typed_update, param.state_size * 2);
-
- std::memcpy(naive_update, update, bias_bytes_per_gate);
- std::memcpy(naive_reset, reset, bias_bytes_per_gate);
- std::memcpy(naive_update + naive_bias_bytes / 2, update,
bias_bytes_per_gate);
- std::memcpy(naive_reset + naive_bias_bytes / 2, reset,
bias_bytes_per_gate);
-
- char* naive_new_bx = naive_update + bias_bytes_per_gate;
- char* naive_new_bh = naive_new_bx + naive_bias_bytes / 2;
- char* new_bx = reset + bias_bytes_per_gate;
- char* new_bh = new_bx + bias_bytes_per_gate;
- std::memcpy(naive_new_bx, new_bx, bias_bytes_per_gate);
- std::memcpy(naive_new_bh, new_bh, bias_bytes_per_gate);
+ });
+
+ const size_t bias_size = param.single_b_size;
+ const size_t naive_bias_size = param.naive_single_b_size;
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ DType* native_bias = static_cast<DType *>(diff_bias);
+ DType* diff_bias_ptr = static_cast<DType
*>(this->diff_bias_->get_data_handle());
+ OPREQTYPE_SWITCH(req, DType, FAccGrad, {
+ if (param.mode != rnn_enum::kGru) {
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ FAccGrad(native_bias + shift * naive_bias_size,
+ diff_bias_ptr + shift * bias_size, bias_size);
+ FAccGrad(native_bias + shift * naive_bias_size + bias_size,
+ diff_bias_ptr + shift * bias_size, bias_size);
+ }
+ } else {
+ const size_t bias_size_per_gate = param.state_size;
+ for (int shift = 0; shift < num_layer * direction; ++shift) {
+ DType* native_reset = native_bias + shift * naive_bias_size;
+ DType* native_update = native_reset + bias_size_per_gate;
+ DType* update = diff_bias_ptr + shift * bias_size;
+ DType* reset = update + bias_size_per_gate;
+
+ FAccGrad(native_update, update, bias_size_per_gate);
+ FAccGrad(native_reset, reset, bias_size_per_gate);
+ FAccGrad(native_update + naive_bias_size / 2, update,
bias_size_per_gate);
+ FAccGrad(native_reset + naive_bias_size / 2, reset,
bias_size_per_gate);
+
+ DType* native_new_bx = native_update + bias_size_per_gate;
+ DType* native_new_bh = native_new_bx + naive_bias_size / 2;
+ DType* new_bx = reset + bias_size_per_gate;
+ DType* new_bh = new_bx + bias_size_per_gate;
+ FAccGrad(native_new_bx, new_bx, bias_size_per_gate);
+ FAccGrad(native_new_bh, new_bh, bias_size_per_gate);
+ }
}
});
- }
+ });
}
template <typename MKLDNNRnnX>
@@ -893,25 +948,18 @@ inline void RegisterMKLDNNRnn(MKLDNNRnnX const& rnn) {
template <>
inline void RegisterMKLDNNRnn(MKLDNNRnnBackward const& rnn) {
MKLDNNStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap());
+ rnn.SetNativeWeightsGrads();
}
void MKLDNNRnnOp::Forward(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
+ TmpMemMgr::Get()->Init(ctx.requested[0]);
// In the `autograd.record()` context, RNNOp is required to run into
// forward_training mode.
const bool is_training = (ctx.is_train || ctx.need_grad);
- // check output requests
- if (kAddTo == req[rnn_enum::kOut])
- LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
const RNNParam& default_param = full_param_.default_param;
- if (default_param.state_outputs) {
- if (kAddTo == req[rnn_enum::kStateOut])
- LOG(FATAL) << "Currently, `add` operation is not supported by RNNs.";
- if (default_param.mode == rnn_enum::kLstm && kAddTo ==
req[rnn_enum::kStateCellOut])
- LOG(FATAL) << "Currently, `add` operation against lstm-cell output is
not supported.";
- }
// Initialize weights version
if (!initialized_ && weights_version_ == 0) {
@@ -919,8 +967,8 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
}
// Check if weights NDArray was changed. If so, reset initialized_
- if (weights_version_ != inputs[rnn_enum::kParams].version() &&
- fwd_inf_vec_.size() > 0) {
+ if (!is_training && fwd_inf_vec_.size() > 0
+ && weights_version_ != inputs[rnn_enum::kParams].version()) {
initialized_ = false;
for (auto& fwd : fwd_inf_vec_) fwd.Reset();
weights_version_ = inputs[rnn_enum::kParams].version();
@@ -932,24 +980,40 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
// Get data type
int data_dtype = inputs[rnn_enum::kData].dtype();
+ // Get temporary memory for output, state_out, statecell_out
+ const int num_layers = default_param.num_layers;
+ const int seq_length = default_param.seq_length_;
+ const int batch_size = default_param.batch_size_;
+ const int state_size = default_param.state_size;
+ const int directions = default_param.bidirectional ? 2 : 1;
+ mkldnn::memory::desc dst_desc({seq_length, batch_size, directions *
state_size},
+ get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc);
+ mkldnn::memory::desc state_desc({num_layers, directions, batch_size,
state_size},
+ get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
+ auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc,
req[rnn_enum::kOut]);
+ mkldnn_output_t stateout_mem;
+ mkldnn_output_t statecellout_mem;
// Get input & output NDArray
char *src = static_cast<char *>(inputs[rnn_enum::kData].data().dptr_);
char *src_state = static_cast<char *>(inputs[rnn_enum::kState].data().dptr_);
- char *dst = req[rnn_enum::kOut] == kNullOp ? nullptr
- : static_cast<char *>(outputs[rnn_enum::kOut].data().dptr_);
+ char *dst = static_cast<char *>(out_mem.second->get_data_handle());
char *dst_state = nullptr; // Output state
char *src_state_cell = nullptr; // Used in LSTM for cell state
char *dst_state_cell = nullptr; // Used in LSTM for cell state
if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) {
- dst_state = static_cast<char *>(outputs[rnn_enum::kStateOut].data().dptr_);
+ stateout_mem = CreateMKLDNNMem(
+ outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]);
+ dst_state = static_cast<char *>(stateout_mem.second->get_data_handle());
}
if (default_param.mode == rnn_enum::kLstm) {
src_state_cell = static_cast<char
*>(inputs[rnn_enum::kStateCell].data().dptr_);
if (default_param.state_outputs && req[rnn_enum::kStateCellOut] !=
kNullOp) {
- dst_state_cell = static_cast<char
*>(outputs[rnn_enum::kStateCellOut].data().dptr_);
+ statecellout_mem = CreateMKLDNNMem(
+ outputs[rnn_enum::kStateCellOut], state_desc,
req[rnn_enum::kStateCellOut]);
+ dst_state_cell = static_cast<char
*>(statecellout_mem.second->get_data_handle());
}
}
@@ -1000,6 +1064,12 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
} else {
for (auto& inf_lyr : fwd_inf_vec_) RegisterMKLDNNRnn(inf_lyr);
}
+ CommitOutput(outputs[rnn_enum::kOut], out_mem);
+ if (default_param.state_outputs) {
+ CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem);
+ if (default_param.mode == rnn_enum::kLstm)
+ CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem);
+ }
MKLDNNStream::Get()->Submit();
}
@@ -1008,18 +1078,11 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using tag = mkldnn::memory::format_tag;
+ TmpMemMgr::Get()->Init(ctx.requested[0]);
const RNNParam& default_param = full_param_.default_param;
- if (kAddTo == req[rnn_enum::kData] || kAddTo == req[rnn_enum::kParams])
- LOG(FATAL) << "Currently, `add` operations against gradients of input and
weights"
- << " are not supported by RNNs.";
- if (default_param.state_outputs) {
- if (kAddTo == req[rnn_enum::kStateOut])
- LOG(FATAL) << "Currently, `add` operation against gradients of begining
state"
- << " is not supported by RNNs.";
- if (default_param.mode == rnn_enum::kLstm && req[rnn_enum::kStateCell])
- LOG(FATAL) << "Currently, `add` operation against gradients of begining
cell-state"
- << " is not supported by LSTM.";
- }
+ const int data_dtype = inputs[rnn_enum::kData].dtype();
+ const int w_dtype = inputs[rnn_enum::kParams].dtype();
+
// Initialize the bwd_vec_
if (bwd_vec_.size() != fwd_inf_vec_.size()) {
bwd_vec_.clear();
@@ -1035,24 +1098,39 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
bwd_vec_.at(lyr).SetWeightsGradsMem();
}
- const int data_dtype = inputs[rnn_enum::kData].dtype();
- const int w_dtype = inputs[rnn_enum::kParams].dtype();
const size_t w_bytes = mshadow::mshadow_sizeof(w_dtype);
+ // Get temporary memory for diff_src, diff_state, diff_statecell
+ const int num_layers = default_param.num_layers;
+ const int seq_length = default_param.seq_length_;
+ const int batch_size = default_param.batch_size_;
+ const int input_size = default_param.input_size_;
+ const int state_size = default_param.state_size;
+ const int directions = default_param.bidirectional ? 2 : 1;
+ mkldnn::memory::desc src_desc({seq_length, batch_size, input_size},
+ get_mkldnn_type(data_dtype), tag::tnc);
+ mkldnn::memory::desc state_desc({num_layers, directions, batch_size,
state_size},
+ get_mkldnn_type(data_dtype), tag::ldnc);
+ auto diff_input_mem = CreateMKLDNNMem(outputs[rnn_enum::kData], src_desc,
req[rnn_enum::kData]);
+ mkldnn_output_t diff_state_mem;
+ mkldnn_output_t diff_statecell_mem;
// index description of outputs NDArray
// 0 1 2 3
// | dx | dw | dhx | dcx|
- char* dx = req[rnn_enum::kData] == kNullOp ? nullptr
- : static_cast<char *>(outputs[rnn_enum::kData].data().dptr_);
+ char* dx = static_cast<char *>(diff_input_mem.second->get_data_handle());
char* dw = static_cast<char *>(outputs[rnn_enum::kParams].data().dptr_);
char* db = dw + (inputs[rnn_enum::kParams].data().Size() -
GetRnnBiasSize(default_param.num_layers, default_param.state_size,
default_param.bidirectional + 1, default_param.mode)) * w_bytes;
- char* dhx = req[rnn_enum::kState] == kNullOp ? nullptr
- : static_cast<char *>(outputs[rnn_enum::kState].data().dptr_);
+ diff_state_mem = CreateMKLDNNMem(
+ outputs[rnn_enum::kState], state_desc, req[rnn_enum::kState]);
+ char* dhx = static_cast<char *>(diff_state_mem.second->get_data_handle());
char* dcx = nullptr;
if (full_param_.default_param.mode == rnn_enum::kLstm
- && req[rnn_enum::kStateCell] != kNullOp)
- dcx = static_cast<char *>(outputs[rnn_enum::kStateCell].data().dptr_);
+ && req[rnn_enum::kStateCell] != kNullOp) {
+ diff_statecell_mem = CreateMKLDNNMem(
+ outputs[rnn_enum::kStateCell], state_desc, req[rnn_enum::kStateCell]);
+ dcx = static_cast<char *>(diff_statecell_mem.second->get_data_handle());
+ }
// index description of inputs NDArray
// 0 1 2 3 4 5 6 7 8 9
@@ -1100,12 +1178,16 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
RegisterMKLDNNRnn(*bwd);
}
}
+ CommitOutput(outputs[rnn_enum::kData], diff_input_mem);
+ CommitOutput(outputs[rnn_enum::kState], diff_state_mem);
+ if (full_param_.default_param.mode == rnn_enum::kLstm)
+ CommitOutput(outputs[rnn_enum::kStateCell], diff_statecell_mem);
MKLDNNStream::Get()->Submit();
// Commit weights diff
if (req[rnn_enum::kParams] != kNullOp) {
for (size_t lyr = 0; lyr < bwd_vec_.size(); ++lyr) {
- bwd_vec_.at(lyr).CommitWeightsDiff(dw, db, w_dtype);
+ bwd_vec_.at(lyr).CommitWeightsGrads(dw, db, req[rnn_enum::kParams],
w_dtype);
dw += full_param_.layer_params.at(lyr).single_w_size * w_bytes;
db += full_param_.layer_params.at(lyr).single_b_size * w_bytes;
}
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index c23a5a8..c3cb5c8 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -561,6 +561,24 @@ class OpSignature {
case mkldnn_format_kind_rnn_packed:
hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.format;
eles.push_back(desc.data.format_desc.rnn_packed_desc.format);
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n_parts;
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.n_parts);
+ for (int i = 0; i < desc.data.format_desc.rnn_packed_desc.n_parts;
++i) {
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.parts[i];
+ hash = hash * 2 +
desc.data.format_desc.rnn_packed_desc.part_pack_size[i];
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.pack_part[i];
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.parts[i]);
+
eles.push_back(desc.data.format_desc.rnn_packed_desc.part_pack_size[i]);
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.pack_part[i]);
+ }
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.n;
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.ldb;
+ hash = hash * 2 +
desc.data.format_desc.rnn_packed_desc.offset_compensation;
+ hash = hash * 2 + desc.data.format_desc.rnn_packed_desc.size;
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.n);
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.ldb);
+
eles.push_back(desc.data.format_desc.rnn_packed_desc.offset_compensation);
+ eles.push_back(desc.data.format_desc.rnn_packed_desc.size);
break;
default:
// nothing need to add
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 6d568c8..a8e1b12 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -182,6 +182,10 @@ static std::vector<ResourceRequest> RNNResourceEx(const
NodeAttrs& attrs, const
request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
}
#endif
+ } else {
+#if MXNET_USE_MKLDNN == 1
+ request.emplace_back(ResourceRequest::kTempSpace);
+#endif
}
return request;
}
@@ -243,7 +247,8 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs
&attrs,
#if MXNET_USE_MKLDNN == 1
if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
- && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU) {
+ && in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
+ && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
data_shape[1], data_shape[2]);
@@ -269,8 +274,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr&
state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() ==
mshadow::kFloat16) &&
- inputs[0].shape().ndim() == 3) {
+ if (SupportMKLDNNRnn(inputs[0])) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Forward(ctx, inputs, req, outputs);
} else {
@@ -283,8 +287,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr&
state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- if ((inputs[0].dtype() == mshadow::kFloat32 || inputs[0].dtype() ==
mshadow::kFloat16) &&
- inputs[0].shape().ndim() == 3) {
+ if (SupportMKLDNNRnn(inputs[0])) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Backward(ctx, inputs, req, outputs);
} else {
diff --git a/tests/python/unittest/test_gluon_rnn.py
b/tests/python/unittest/test_gluon_rnn.py
index 309756b..0f27f53 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -19,6 +19,8 @@ import mxnet as mx
from mxnet import gluon, nd
import numpy as np
import copy
+from itertools import product
+from functools import partial
from numpy.testing import assert_allclose
import unittest
from mxnet.test_utils import almost_equal, assert_almost_equal
@@ -545,6 +547,128 @@ def test_rnn_layers_fp16():
run_rnn_layers('float16', 'float32', mx.gpu())
+def check_rnn_consistency(fused_layer, stack_layer, loss, input_size,
hidden_size, bidirectional=False, rtol=1e-2, atol=1e-4):
+ fused_begin_state = fused_layer.begin_state(1)
+ stack_state = stack_layer.begin_state(batch_size=1)
+ x = nd.random.normal(shape=(1, 5, input_size))
+ x.attach_grad()
+ y = nd.random.normal(shape=(1, 5, hidden_size * 2 if bidirectional else
hidden_size))
+
+ with mx.autograd.record():
+ fused_out, fused_state = fused_layer(x, fused_begin_state)
+ l = loss(fused_out, y).mean()
+ l.backward()
+ fused_grads = dict([(name, p.grad()) for name, p in
fused_layer.collect_params().items()])
+ fused_input_grad = x.grad.asnumpy()
+
+ with mx.autograd.record():
+ stack_out, stack_state = stack_layer.unroll(5, x, stack_state,
merge_outputs=True)
+ l = loss(stack_out, y).mean()
+ l.backward()
+ stack_grads = dict([(name, p.grad()) for name, p in
stack_layer.collect_params().items()])
+ stack_input_grad = x.grad.asnumpy()
+
+ assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol,
atol=atol)
+ assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
+ for key, value in fused_grads.items():
+ assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(),
rtol=rtol, atol=atol)
+
+
+def create_op_by_mode(mode):
+ if mode == 'lstm':
+ fused_op = gluon.rnn.LSTM
+ stack_op = gluon.rnn.LSTMCell
+ recurrent_block_prefix = 'lstm0_'
+ elif mode == 'gru':
+ fused_op = gluon.rnn.GRU
+ stack_op = gluon.rnn.GRUCell
+ recurrent_block_prefix = 'gru0_'
+ elif mode == 'rnn_relu':
+ fused_op = partial(gluon.rnn.RNN, activation='relu')
+ stack_op = partial(gluon.rnn.RNNCell, activation='relu')
+ recurrent_block_prefix = 'rnn0_'
+ elif mode == 'rnn_tanh':
+ fused_op = partial(gluon.rnn.RNN, activation='tanh')
+ stack_op = partial(gluon.rnn.RNNCell, activation='tanh')
+ recurrent_block_prefix = 'rnn0_'
+
+ return fused_op, stack_op, recurrent_block_prefix
+
+
+def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
+ fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
+ # ==== Single layer ====
+ fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC',
bidirectional=False)
+ fused_layer.collect_params().initialize()
+
+ params = fused_layer.collect_params()
+ stack_layer =
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix,
params=params)
+ with stack_layer.name_scope():
+ stack_layer.add(stack_op(hidden_size, prefix='l0_'))
+ stack_layer.initialize()
+
+ check_rnn_consistency(fused_layer, stack_layer, loss, input_size,
hidden_size)
+
+ # ==== Multiple layer ====
+ fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC',
bidirectional=False)
+ fused_layer.collect_params().initialize()
+
+ params = fused_layer.collect_params()
+ stack_layer =
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix,
params=params)
+ with stack_layer.name_scope():
+ stack_layer.add(stack_op(hidden_size, prefix='l0_'))
+ stack_layer.add(stack_op(hidden_size, prefix='l1_'))
+ stack_layer.add(stack_op(hidden_size, prefix='l2_'))
+ stack_layer.initialize()
+
+ check_rnn_consistency(fused_layer, stack_layer, loss, input_size,
hidden_size)
+
+
+def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
+ fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
+ # ==== Single layer ====
+ fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC',
bidirectional=True)
+ fused_layer.collect_params().initialize()
+
+ params = fused_layer.collect_params()
+ stack_layer =
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix,
params=params)
+ with stack_layer.name_scope():
+ stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size,
prefix='l0_'),
+ stack_op(hidden_size,
prefix='r0_')))
+ stack_layer.initialize()
+
+ check_rnn_consistency(fused_layer, stack_layer, loss, input_size,
hidden_size, bidirectional=True)
+
+ # ==== Multiple layer ====
+ fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC',
bidirectional=True)
+ fused_layer.collect_params().initialize()
+
+ params = fused_layer.collect_params()
+ stack_layer =
mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix,
params=params)
+ with stack_layer.name_scope():
+ stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size,
prefix='l0_'),
+ stack_op(hidden_size,
prefix='r0_')))
+ stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size,
prefix='l1_'),
+ stack_op(hidden_size,
prefix='r1_')))
+ stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size,
prefix='l2_'),
+ stack_op(hidden_size,
prefix='r2_')))
+ stack_layer.initialize()
+
+ check_rnn_consistency(fused_layer, stack_layer, loss, input_size,
hidden_size, bidirectional=True)
+
+
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_fused_rnn_layer():
+ input_sizes = [128]
+ hidden_sizes = [128, 256]
+ modes = ['lstm', 'gru', 'rnn_relu', 'rnn_tanh']
+ # single layer
+ for mode, input_size, hidden_size in product(modes, input_sizes,
hidden_sizes):
+ loss = mx.gluon.loss.L2Loss()
+ check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss)
+ check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss)
+
+
def test_rnn_unroll_variant_length():
# Test for imperative usage
cell_list = []
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 6fdb3c8..9ae35f1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -36,15 +36,6 @@ import unittest
import os
def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2,
atol=1e-4):
- if default_context().device_type == 'cpu':
- # NOTE(zixuanweeei): Currently, we don't add `add` requests support on
fused mkl-dnn rnn operator.
- # We tracked this issue by
https://github.com/apache/incubator-mxnet/issues/16578
- if isinstance(grad_req, dict) and 'add' in grad_req.values():
- print("Skip the test when requiring `add` operation against
gradients on CPU context.")
- return
- if isinstance(grad_req, str) and grad_req == 'add':
- print("Skip the test when requiring `add` operation against
gradients on CPU context.")
- return
dshape = (N, T, I)
data = mx.sym.Variable('data')
@@ -182,9 +173,9 @@ def test_gru_sym():
stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
- check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4)
- check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4)
- check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4)
+ check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+ check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+ check_rnn_consistency(fused, stack, T, N, I, H, 'null')
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
@@ -208,9 +199,9 @@ def test_gru_bidirectional():
mx.rnn.GRUCell(H, prefix='r1_'),
output_prefix='bi_gru_1_'))
- check_rnn_consistency(fused, stack, T, N, I, H, 'write', atol=2e-4)
- check_rnn_consistency(fused, stack, T, N, I, H, 'add', atol=2e-4)
- check_rnn_consistency(fused, stack, T, N, I, H, 'null', atol=2e-4)
+ check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+ check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+ check_rnn_consistency(fused, stack, T, N, I, H, 'null')
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
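With the CPU-specific skip removed above, the `add` gradient request is now exercised against the fused RNN on CPU as well. A Gluon-level sketch of what that accumulation path looks like (layer sizes and names below are illustrative, not taken from the tests):

# Sketch of gradient accumulation on the CPU fused RNN path introduced by
# this commit; shapes are illustrative only.
import mxnet as mx

layer = mx.gluon.rnn.LSTM(64, num_layers=1, layout='NTC')
layer.initialize()
for p in layer.collect_params().values():
    p.grad_req = 'add'                 # accumulate into .grad instead of overwriting

x = mx.nd.random.normal(shape=(2, 5, 32))
for _ in range(2):                     # two backward passes; gradients are summed
    with mx.autograd.record():
        loss = layer(x).mean()
    loss.backward()
# Accumulated gradients would then typically be cleared with collect_params().zero_grad().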