cjolivier01 commented on a change in pull request #10317: [MXNET-264] Improve 
performance of MKLDNN in small batch sizes.
URL: https://github.com/apache/incubator-mxnet/pull/10317#discussion_r179326679
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_base-inl.h
 ##########
 @@ -334,11 +334,104 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
                                  const mkldnn::memory::primitive_desc 
&target_pd,
                                  int num_groups);
 
-mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc);
+mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc);
 mkldnn_memory_format_t GetDefaultFormat(int num_dims);
 mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc 
pd,
                                                 mkldnn_memory_format_t format);
 
+static inline bool same_shape(const TShape &shape, const mkldnn_dims_t dims, 
int ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+static inline bool same_shape(const TShape &shape, int dtype,
+                              const mkldnn::memory::desc &desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+      && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+/*
+ * There is a large overhead of getting mkldnn::memory::primitive_desc from
+ * mkldnn::memory. This class is created to cache the metadata of mkldnn memory
+ * to provide a much more lightweight method to access them.
+ */
+class MKLDNNMemory {
+  std::shared_ptr<mkldnn::memory> mem;
+  mkldnn::memory::desc desc;
+  size_t size;      // The number of bytes.
+
+ public:
+  MKLDNNMemory(mkldnn::memory::primitive_desc pd, void *addr): desc(pd.desc()) 
{
+    mem.reset(new mkldnn::memory(pd, addr));
+    size = pd.get_size();
+  }
+
+  explicit MKLDNNMemory(std::shared_ptr<mkldnn::memory> mem): desc(
 
 Review comment:
   Can this std::shared_ptr be passed by const reference, to avoid the atomic (interlocked) reference-count update that passing it by value incurs?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to