KellenSunderland commented on a change in pull request #13697: [MKLDNN] Enable signed int8 support for convolution.
URL: https://github.com/apache/incubator-mxnet/pull/13697#discussion_r244621044
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_base.cc
 ##########
 @@ -229,51 +229,44 @@ void CommitOutput(const NDArray &arr, const 
mkldnn_output_t &res) {
   }
 }
 
-const mkldnn::memory *GetWeights(const NDArray &arr,
-                                 const mkldnn::memory::primitive_desc 
&target_pd,
-                                 int num_groups) {
-  const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
-  // If the weight array already uses the target layout, simply return it
-  // directly.
-  if (mem)
-    return mem;
-
+const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
+  const mkldnn::memory *mem = nullptr;
   auto engine = CpuEngine::Get()->get_engine();
   if (arr.shape().ndim() == 2) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz =
+        mkldnn::memory::dims{static_cast<int>(arr.shape()[0]), 
static_cast<int>(arr.shape()[1])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::oi};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else if (arr.shape().ndim() == 4 && num_groups == 1) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
-          static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz =
+        mkldnn::memory::dims{static_cast<int>(arr.shape()[0]), 
static_cast<int>(arr.shape()[1]),
+                             static_cast<int>(arr.shape()[2]), 
static_cast<int>(arr.shape()[3])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::oihw};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else if (arr.shape().ndim() == 4) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
-      static_cast<int>(arr.shape()[0] / num_groups),
-      static_cast<int>(arr.shape()[1]),
-      static_cast<int>(arr.shape()[2]),
-      static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
+    mkldnn::memory::dims tz = mkldnn::memory::dims{
+        num_groups, static_cast<int>(arr.shape()[0] / num_groups), 
static_cast<int>(arr.shape()[1]),
+        static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
+    mkldnn::memory::desc md = mkldnn::memory::desc{tz, type, 
mkldnn::memory::format::goihw};
+    mkldnn::memory::primitive_desc pd = mkldnn::memory::primitive_desc{md, 
engine};
     mem = arr.GetMKLDNNData(pd);
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
-    return nullptr;
   }
-  if (mem == nullptr)
-    mem = arr.GetMKLDNNDataReorder(target_pd);
+  return mem;
+}
+
+const mkldnn::memory *GetWeights(const NDArray &arr,
+                                 const mkldnn::memory::primitive_desc 
&target_pd, int num_groups) {
+  const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
+  // If the weight array already uses the target layout, simply return it
 
 Review comment:
   This can probably be on a single line.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to