You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/22 13:14:15 UTC
[incubator-mxnet] branch master updated: Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 51e5204 Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
51e5204 is described below
commit 51e5204cab23b59182860aa7dbedef67dedc1e06
Author: bartekkuncer <ba...@intel.com>
AuthorDate: Tue Mar 22 14:12:25 2022 +0100
Add support for up to 12 dims for oneDNN tensors in MXNet (#20913)
---
src/ndarray/ndarray.cc | 27 +++------------------------
src/operator/nn/dnnl/dnnl_base.cc | 12 ++++++++++++
2 files changed, 15 insertions(+), 24 deletions(-)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 49e1f94..3baa29c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -605,36 +605,15 @@ void NDArray::Chunk::SetMKLMem(const mxnet::TShape& shape, int dtype) {
dnnl::memory::dims dims;
// These are shapes supprted by DNNL.
- if (shape.ndim() >= 1 && shape.ndim() <= 6) {
+ const int MAX_ONEDNN_DIMS = 12;
+ if (shape.ndim() >= 1 && shape.ndim() <= MAX_ONEDNN_DIMS) {
dims.resize(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape[i];
} else {
LOG(FATAL) << "oneDNN doesn't support " << shape.ndim() << " dimensions";
}
- dnnl::memory::format_tag layout = dnnl::memory::format_tag::undef;
- switch (dims.size()) {
- case 1:
- layout = dnnl::memory::format_tag::a;
- break;
- case 2:
- layout = dnnl::memory::format_tag::ab;
- break;
- case 3:
- layout = dnnl::memory::format_tag::abc;
- break;
- case 4:
- layout = dnnl::memory::format_tag::abcd;
- break;
- case 5:
- layout = dnnl::memory::format_tag::abcde;
- break;
- case 6:
- layout = dnnl::memory::format_tag::abcdef;
- break;
- default:
- LOG(FATAL) << "Not implemented dimension (" << dims.size() << ") for oneDNN";
- }
+ auto layout = static_cast<dnnl::memory::format_tag>(GetDefaultFormat(dims.size()));
dnnl::memory::desc data_md{dims, get_dnnl_type(dtype), layout};
if (shandle.dptr == nullptr) {
CHECK(delay_alloc);
diff --git a/src/operator/nn/dnnl/dnnl_base.cc b/src/operator/nn/dnnl/dnnl_base.cc
index 27345a0..05fabd5 100644
--- a/src/operator/nn/dnnl/dnnl_base.cc
+++ b/src/operator/nn/dnnl/dnnl_base.cc
@@ -329,6 +329,18 @@ dnnl_format_tag_t GetDefaultFormat(int num_dims) {
return dnnl_abcde;
case 6:
return dnnl_abcdef;
+ case 7:
+ return dnnl_abcdefg;
+ case 8:
+ return dnnl_abcdefgh;
+ case 9:
+ return dnnl_abcdefghi;
+ case 10:
+ return dnnl_abcdefghij;
+ case 11:
+ return dnnl_abcdefghijk;
+ case 12:
+ return dnnl_abcdefghijkl;
default:
LOG(FATAL) << "Not implemented dimension (" << num_dims << ") for oneDNN";
return dnnl_format_tag_undef;