rongzha1 commented on a change in pull request #16466: [mkldnn-1.0] upgrade int8 concat to MKLDNN1.0
URL: https://github.com/apache/incubator-mxnet/pull/16466#discussion_r334325833
 
 

 ##########
 File path: src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
 ##########
 @@ -71,36 +71,37 @@ static void MKLDNNQuantizedConcatForward(const 
nnvm::NodeAttrs& attrs, const OpC
       CHECK(in_data[i].dtype() == out_dtype);
       auto mem = in_data[i].GetMKLDNNData();
       data_mem.push_back(mem);
-      data_md.push_back(mem->get_primitive_desc());
+      data_md.push_back(mem->get_desc());
     } else {
       auto mem = in_data[i].GetMKLDNNData();
-      auto pd = mem->get_primitive_desc();
+      auto mem_desc = mem->get_desc();
       if (in_data[i].dtype() != out_dtype) {
-        auto mem_desc = pd.desc();
-        mkldnn::memory::desc new_md(
-            mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + 
mem_desc.data.ndims),
-            get_mkldnn_type(out_dtype), 
static_cast<mkldnn::memory::format>(mem_desc.data.format));
-        pd = mkldnn::memory::primitive_desc(new_md, 
CpuEngine::Get()->get_engine());
+        mem_desc.data.data_type = 
static_cast<mkldnn_data_type_t>(get_mkldnn_type(out_dtype));
       }
-      const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
+      const auto rescaled_mem =
+          std::make_shared<mkldnn::memory>(mem_desc, 
CpuEngine::Get()->get_engine());
       new_data_mem.push_back(rescaled_mem);
       std::vector<float> reorder_scale = {out_scale / i_scale};
-      primitive_attr reorder_attr;
-      reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
+      mkldnn::primitive_attr reorder_attr;
       reorder_attr.set_output_scales(0, reorder_scale);
-      const auto reorder_pd =
-          mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, 
reorder_attr);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, 
*rescaled_mem));
+      const auto reorder_pd = mkldnn::reorder::primitive_desc(*mem, 
*rescaled_mem, reorder_attr);
+      mkldnn_args_map_t reorder_args;
+      reorder_args[MKLDNN_ARG_SRC] = *mem;
+      reorder_args[MKLDNN_ARG_DST] = *rescaled_mem;
+      MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), 
reorder_args);
       data_mem.push_back(rescaled_mem.get());
-      data_md.push_back(pd);
+      data_md.push_back(mem_desc);
     }
   }
   MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md);
-  mxnet::mkldnn_output_t out_mem =
-      CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], 
fwd.fwd_pd.dst_primitive_desc(),
-                      req[concat_enum::kOut]);
-  fwd.SetNewMem(data_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  mxnet::mkldnn_output_t out_mem = 
CreateMKLDNNMem(out_data[quantized_concat_enum::kOut],
+                                                   fwd.fwd_pd.dst_desc(), 
req[concat_enum::kOut]);
+  std::unordered_map<int, mkldnn::memory> net_args;
+  net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
 
 Review comment:
  Use the same style as line 89, i.e. `net_args[MKLDNN_ARG_DST] = ...`, instead of `net_args.insert({...})`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to