bartekkuncer commented on a change in pull request #20606:
URL: https://github.com/apache/incubator-mxnet/pull/20606#discussion_r718220715
##########
File path: src/operator/subgraph/dnnl/dnnl_conv.cc
##########
@@ -168,42 +166,42 @@ void SgMKLDNNConvOperator::Forward(const OpContext& ctx,
}
}
CHECK_EQ(input_size, idx);
- bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias;
+ bool has_bias = dnnl_param.with_bn || !conv_param.no_bias;
NDArray data = inputs[in_data];
- NDArray output = mkldnn_param.with_sum ? inputs[in_sum] : outputs[kOut];
+ NDArray output = dnnl_param.with_sum ? inputs[in_sum] : outputs[kOut];
// Copy inputs[in_sum] into outputs[kOut] in case inplace optimization failed.
- if (mkldnn_param.with_sum) {
+ if (dnnl_param.with_sum) {
if (!initialized_) {
- // TODO(zhennan): Currently, mkldnn fallback mechanism will break inplace option,
+ // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
// which make check (req[kOut] == kWriteInplace) useless.
- auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
- auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+ auto in_mkl_mem = inputs[in_sum].GetDNNLData();
+ auto out_mkl_mem = outputs[kOut].GetDNNLData();
if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
inplace_ = true;
}
}
if (!inplace_) {
- auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
- auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
+ auto in_mkl_mem = inputs[in_sum].GetDNNLData();
+ auto out_mkl_mem = outputs[kOut].GetDNNLData();
if (outputs[kOut].dtype() == mshadow::kInt32) {
const auto& mem_desc = in_mkl_mem->get_desc();
- const auto this_dtype = get_mkldnn_type(mshadow::kInt32);
+ const auto this_dtype = get_dnnl_type(mshadow::kInt32);
auto omd = mem_desc;
- omd.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
- mkldnn_mem_ptr tmp_mem(new mkldnn::memory(
- omd, CpuEngine::Get()->get_engine(), out_mkl_mem->get_data_handle()));
- MKLDNNStream::Get()->RegisterMem(tmp_mem);
- MKLDNNStream::Get()->RegisterPrimArgs(
- mkldnn::reorder(*in_mkl_mem, *tmp_mem),
- {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
+ omd.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
+ dnnl_mem_ptr tmp_mem(
+ new dnnl::memory(omd, CpuEngine::Get()->get_engine(), out_mkl_mem->get_data_handle()));
Review comment:
done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]