zixuanweeei commented on a change in pull request #15621: MKL-DNN LBR-GRU
Inference Integration (FP32 LBR-GRU)
URL: https://github.com/apache/incubator-mxnet/pull/15621#discussion_r309994215
##########
File path: src/operator/nn/mkldnn/mkldnn_rnn_impl.h
##########
@@ -225,55 +241,66 @@ static void MKLDNNRNNForwardSingleLayerBi(bool
state_outputs,
mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo
for reorder
mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo
mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo
for reorder
- mkldnn::memory::dims bias_tz = {1, 2, ngates, H};
+ mkldnn::memory::dims bias_tz = {1, 2, nbias, H}; // ldgo
mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc
mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc
- if (!initialized) {
+ bool has_adjusted = false;
+ if (!initialized || is_train) {
if (mode == rnn_enum::kGru) {
AdjustGruWeightGateOrder(wx, I, H);
AdjustGruWeightGateOrder(back_wx, I, H);
AdjustGruWeightGateOrder(wh, H, H);
AdjustGruWeightGateOrder(back_wh, H, H);
- AdjustGruBiasGateOrder(bx, H);
- AdjustGruBiasGateOrder(back_bx, H);
- AdjustGruBiasGateOrder(bh, H);
- AdjustGruBiasGateOrder(back_bh, H);
+ has_adjusted = true;
}
- auto src_wx = (*concat_weight_memory)[2 * layer_index];
- auto src_wh = (*concat_weight_memory)[2 * layer_index + 1];
+ auto src_wx = mkldnn_mems->concat_weight_memory[2 * layer_index];
+ auto src_wh = mkldnn_mems->concat_weight_memory[2 * layer_index + 1];
std::vector<void*> srcs_data1;
srcs_data1.push_back(wx);
srcs_data1.push_back(back_wx);
ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
{weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz,
- mkldnn_dtype, 1, srcs_data1, src_wx);
+ mkldnn_dtype, 1, srcs_data1, src_wx,
&(mkldnn_mems->weight_layer_mems));
srcs_data1.clear();
srcs_data1.push_back(wh);
srcs_data1.push_back(back_wh);
ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
{weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz,
- mkldnn_dtype, 1, srcs_data1, src_wh);
- int tmpvalue = 0;
- if (lvalue > 0) {
- tmpvalue = lvalue + 1;
- }
- MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue]));
- MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue]));
+ mkldnn_dtype, 1, srcs_data1, src_wh, &(mkldnn_mems->weight_iter_mems));
+
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wx,
mkldnn_mems->wx_memory[layer_index]));
+ MKLDNNStream::Get()->RegisterPrim(reorder(src_wh,
mkldnn_mems->wh_memory[layer_index]));
DType* user_bias = reinterpret_cast<DType *>
- ((*bias_memory)[tmpvalue].get_data_handle());
- #pragma omp parallel for num_threads(omp_threads)
- for (int j = 0; j < single_b_size; j++) {
- user_bias[j] = bx[j] + bh[j];
- user_bias[single_b_size + j] = back_bx[j] + back_bh[j];
+ (mkldnn_mems->bias_memory[layer_index].get_data_handle());
+ if (mode == rnn_enum::kGru) {
+ // While mxnet gru gate order is reset, update and new gates,
+ // mkldnn gru gate order is update, reset and new gates. So
+ // we need to swap the order of reset and update from mxnet.
+ const index_t single_b_sz = nbias * H;
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int j = 0; j < H; j++) {
+ user_bias[j + H] = bx[j] + bh[j];
+ user_bias[single_b_sz + j + H] = back_bx[j] + back_bh[j];
+ user_bias[j] = bx[j + H] + bh[j + H];
+ user_bias[single_b_sz + j] = back_bx[j + H] + back_bh[j + H];
+ }
+ #pragma omp parallel for num_threads(omp_threads)
+ for (int j = 2 * H; j < 3 * H; j++) {
Review comment:
Yep, we can merge these two into one loop. Both variants have the same
performance. They cost about 18 µs with `hidden_size=4096` on one socket of
a Skylake 8180.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services