You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/22 13:14:15 UTC

[incubator-mxnet] branch master updated: Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 51e5204  Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
51e5204 is described below

commit 51e5204cab23b59182860aa7dbedef67dedc1e06
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Mar 22 14:12:25 2022 +0100

    Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
---
 src/ndarray/ndarray.cc            | 27 +++------------------------
 src/operator/nn/dnnl/dnnl_base.cc | 12 ++++++++++++
 2 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 49e1f94..3baa29c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -605,36 +605,15 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape& shape, int dtype) {
 
   dnnl::memory::dims dims;
   // These are shapes supprted by DNNL.
-  if (shape.ndim() >= 1 && shape.ndim() <= 6) {
+  const int MAX_ONEDNN_DIMS = 12;
+  if (shape.ndim() >= 1 && shape.ndim() <= MAX_ONEDNN_DIMS) {
     dims.resize(shape.ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape[i];
   } else {
     LOG(FATAL) << "oneDNN doesn't support " << shape.ndim() << " dimensions";
   }
-  dnnl::memory::format_tag layout = dnnl::memory::format_tag::undef;
-  switch (dims.size()) {
-    case 1:
-      layout = dnnl::memory::format_tag::a;
-      break;
-    case 2:
-      layout = dnnl::memory::format_tag::ab;
-      break;
-    case 3:
-      layout = dnnl::memory::format_tag::abc;
-      break;
-    case 4:
-      layout = dnnl::memory::format_tag::abcd;
-      break;
-    case 5:
-      layout = dnnl::memory::format_tag::abcde;
-      break;
-    case 6:
-      layout = dnnl::memory::format_tag::abcdef;
-      break;
-    default:
-      LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for oneDNN";
-  }
+  auto layout = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(dims.size()));
   dnnl::memory::desc data_md{dims, get_dnnl_type(dtype), layout};
   if (shandle.dptr == nullptr) {
     CHECK(delay_alloc);
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 27345a0..05fabd5 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -329,6 +329,18 @@ dnnl_format_tag_t GetDefaultFormat(int num_dims) {
       return dnnl_abcde;
     case 6:
       return dnnl_abcdef;
+    case 7:
+      return dnnl_abcdefg;
+    case 8:
+      return dnnl_abcdefgh;
+    case 9:
+      return dnnl_abcdefghi;
+    case 10:
+      return dnnl_abcdefghij;
+    case 11:
+      return dnnl_abcdefghijk;
+    case 12:
+      return dnnl_abcdefghijkl;
     default:
       LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
       return dnnl_format_tag_undef;