zixuanweeei commented on a change in pull request #15621: MKL-DNN LBR-GRU
Inference Integration (FP32 LBR-GRU)
URL: https://github.com/apache/incubator-mxnet/pull/15621#discussion_r310033770
##########
File path: src/operator/nn/mkldnn/mkldnn_rnn_impl.h
##########
@@ -73,64 +93,69 @@ static void ConcatData(mkldnn::memory::format src_format,
mkldnn::memory::dims dst_cds,
mkldnn::memory::data_type mkldnn_dtype,
int concat_dimension,
- std::vector<void*> srcs_data,
- const mkldnn::memory &dst) {
+ const std::vector<void*> &srcs_data,
+ const mkldnn::memory &dst,
+ std::vector<mkldnn::memory> *tmp_src_mems) {
auto cpu_engine = CpuEngine::Get()->get_engine();
std::vector<mkldnn::memory::primitive_desc> srcs_pd;
- std::vector<mkldnn::memory> srcs;
+ bool initialized = tmp_src_mems->size() > 0;
for (size_t i = 0; i < srcs_cds.size(); i++) {
auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format);
auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine);
- auto src_memory = mkldnn::memory(mpd, srcs_data[i]);
srcs_pd.push_back(mpd);
- srcs.push_back(src_memory);
- }
- std::vector<primitive::at> inputs;
- for (size_t i = 0; i < srcs_cds.size(); i++) {
- inputs.push_back(srcs[i]);
+ if (initialized) {
+ tmp_src_mems->at(i).set_data_handle(srcs_data[i]);
+ } else {
+ auto src_memory = mkldnn::memory(mpd, srcs_data[i]);
+ tmp_src_mems->push_back(src_memory);
+ }
}
+ std::vector<primitive::at> inputs(tmp_src_mems->begin(),
tmp_src_mems->end());
auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format);
auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd);
MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst));
- MKLDNNStream::Get()->Submit();
}
-// cached mkldnn memory
-// first layer wx, wh with next L - 1 layers wx and wh
-// with L layers hx and cx, src and dst data/iter etc.
-// it will prepare memory on before and after reorder and concat.
-// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H.
-// for bidirectional, it will fused as data + back_data (weight, bias, iter
etc),
-// also need to identify first layer and next layers
-static size_t GetMKLDNNRNNCacheMemorySize(int L,
- int D,
- int T,
- int N,
- int I,
- int H,
+/**
+ * Size of cached memory
+ *
+ * Cache memory of wx, wh from the first layer and next L - 1 layers
+ * separately, as well as the layer and iter memory for src and dst.
+ * Output states memory hx, hc and bias memory are also cached. It
+ * will prepare memory before and after reorder and concat. For
+ * unidirectional, it will be fused as dim like 1 + (L - 1) when I != H.
+ * For bidirectional, it will be fused as data + back_data (weight, bias,
+ * iter etc)
+ *
+ * @param num_layer Number of Layers
+ * @param direction Direction of the RNN implementation. It should be 1 or 2.
+ * @param seq_len The maximum sequence length.
+ * @param batch_size Batch size.
+ * @param input_size Input channel. Also the dimension of the input feature.
+ * @param hidden_size Hidden state size.
+ * @return The required cache size.
+ */
+static size_t GetMKLDNNRNNCacheMemorySize(int num_layer,
+ int direction,
+ int seq_len,
+ int batch_size,
+ int input_size,
+ int hidden_size,
int mode) {
- size_t size = 0;
- switch (mode) {
- case rnn_enum::kLstm:
- size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H +
- L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2)
* D * 2 * N * H +
- 6 * D * (I + H + 2) * 4 * H + T * N * I * 2;
- break;
- case rnn_enum::kGru:
- size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H +
- L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2)
* D * 2 * N * H +
- 6 * D * (I + H + 2) * 3 * H + T * N * I * 2;
- break;
- case rnn_enum::kRnnRelu:
- case rnn_enum::kRnnTanh:
- size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H +
- L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2)
* D * 2 * N * H +
- 6 * D * (I + H + 2) * 1 * H + T * N * I * 2;
- break;
- default:
- LOG(FATAL) << "unknown RNN mode " << mode;
- break;
- }
+ int n_gates = 0, n_states = 0;
+ GetMKLDNNRNNAlgo(mode, &n_gates, &n_states);
+ int n_bias = mode == rnn_enum::kGru ? n_gates + 1 : n_gates;
+ // sizes of single gates from a single cell
+ const size_t weights_size_0 = direction * (input_size + hidden_size) *
hidden_size;
Review comment:
The input params of `GetMKLDNNRNNCacheMemorySize` should be declared as
`const size_t` type. The intermediate results of `int` type may overflow.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services