This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 51e5204  Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
51e5204 is described below

commit 51e5204cab23b59182860aa7dbedef67dedc1e06
Author: bartekkuncer <[email protected]>
AuthorDate: Tue Mar 22 14:12:25 2022 +0100

    Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
---
 src/ndarray/ndarray.cc            | 27 +++------------------------
 src/operator/nn/dnnl/dnnl_base.cc | 12 ++++++++++++
 2 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 49e1f94..3baa29c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -605,36 +605,15 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape& shape, int dtype) {
 
   dnnl::memory::dims dims;
   // These are shapes supported by DNNL.
-  if (shape.ndim() >= 1 && shape.ndim() <= 6) {
+  const int MAX_ONEDNN_DIMS = 12;
+  if (shape.ndim() >= 1 && shape.ndim() <= MAX_ONEDNN_DIMS) {
     dims.resize(shape.ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape[i];
   } else {
     LOG(FATAL) << "oneDNN doesn't support " << shape.ndim() << " dimensions";
   }
-  dnnl::memory::format_tag layout = dnnl::memory::format_tag::undef;
-  switch (dims.size()) {
-    case 1:
-      layout = dnnl::memory::format_tag::a;
-      break;
-    case 2:
-      layout = dnnl::memory::format_tag::ab;
-      break;
-    case 3:
-      layout = dnnl::memory::format_tag::abc;
-      break;
-    case 4:
-      layout = dnnl::memory::format_tag::abcd;
-      break;
-    case 5:
-      layout = dnnl::memory::format_tag::abcde;
-      break;
-    case 6:
-      layout = dnnl::memory::format_tag::abcdef;
-      break;
-    default:
-      LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for oneDNN";
-  }
+  auto layout = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(dims.size()));
   dnnl::memory::desc data_md{dims, get_dnnl_type(dtype), layout};
   if (shandle.dptr == nullptr) {
     CHECK(delay_alloc);
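
The removed switch is folded into the existing GetDefaultFormat() helper, and its C-API result is cast to the C++ enum. A minimal, illustrative sketch (not part of the patch) of why that static_cast is safe: dnnl.hpp defines the C++ dnnl::memory::format_tag values as the corresponding C tags, so the two are numerically identical.

    #include <cassert>
    #include <dnnl.hpp>

    int main() {
      // dnnl.hpp defines format_tag::abc = dnnl_abc, so casting the C tag
      // returned by GetDefaultFormat() to the C++ enum preserves its value.
      dnnl_format_tag_t c_tag = dnnl_abc;
      auto cpp_tag = static_cast<dnnl::memory::format_tag>(c_tag);
      assert(cpp_tag == dnnl::memory::format_tag::abc);
      return 0;
    }
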
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 27345a0..05fabd5 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -329,6 +329,18 @@ dnnl_format_tag_t GetDefaultFormat(int num_dims) {
       return dnnl_abcde;
     case 6:
       return dnnl_abcdef;
+    case 7:
+      return dnnl_abcdefg;
+    case 8:
+      return dnnl_abcdefgh;
+    case 9:
+      return dnnl_abcdefghi;
+    case 10:
+      return dnnl_abcdefghij;
+    case 11:
+      return dnnl_abcdefghijk;
+    case 12:
+      return dnnl_abcdefghijkl;
     default:
       LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
       return dnnl_format_tag_undef;
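
With GetDefaultFormat() now covering 7 to 12 dimensions, SetMKLMem() can build plain-layout descriptors for higher-rank tensors. A minimal sketch (illustrative only, not part of the patch; the CPU engine index 0, the f32 data type, and the example shape are assumptions) of what the new limit permits through the public oneDNN API:

    #include <dnnl.hpp>

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      // A 7-D tensor; the previous MXNet limit was 6 dimensions.
      dnnl::memory::dims dims = {2, 3, 4, 5, 6, 7, 8};
      dnnl::memory::desc md(dims,
                            dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::abcdefg);
      dnnl::memory mem(md, eng);  // allocates storage for the 7-D tensor
      return 0;
    }

Shapes with more than 12 dimensions still hit the LOG(FATAL) branch above, matching oneDNN's DNNL_MAX_NDIMS limit.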
