You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/26 07:35:27 UTC

[tvm] branch main updated: [BYOC] Enable bfloat16 in DNNL BYOC (#11111)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8135860527 [BYOC] Enable bfloat16 in DNNL BYOC (#11111)
8135860527 is described below

commit 8135860527fa28853660f8f4747795b594b3f53f
Author: Youlei Yang <yo...@intel.com>
AuthorDate: Thu May 26 15:35:23 2022 +0800

    [BYOC] Enable bfloat16 in DNNL BYOC (#11111)
    
    * refine the code style (#10112)
    
    * support more data types in oneDNN BYOC
    
    * consider dtype when query layout
    
    * support more translation of blocked layout
    
    * refine log for invalid layout transform
    
    * reset N and C for the weights
    
    * support multi-blocking in TransDims2Plain()
    
    * add tests for bf16 oneDNN BYOC
    
    * unregister 'round' OP in oneDNN BYOC
    
    * restore the criteria for fp32 tests
    
    * disable test_prune_dnnl_subgraph for bf16
    
    * fix typo in dnnl.py
    
    * delete tag::format_tag_last
    
    * delete 'is_weight' in layout2tag()
    
    * reuse dtype_dl2dnnl()
    
    * fix lint errors
    
    * change to WARNING for invalid layout transform
    
    * skip bf16 tests if AVX512 is unavailable
---
 cmake/modules/contrib/DNNL.cmake               |   4 +-
 include/tvm/tir/op.h                           |  24 +-
 python/tvm/relay/op/contrib/dnnl.py            |  25 +-
 src/relay/backend/contrib/dnnl/query_layout.cc |  33 +-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc  | 656 ++++++++++++++++++++-----
 src/runtime/contrib/dnnl/dnnl_utils.cc         |  56 +++
 src/runtime/contrib/dnnl/dnnl_utils.h          |  46 ++
 src/tir/ir/data_layout.cc                      |  12 +-
 tests/python/contrib/test_dnnl.py              |  76 ++-
 9 files changed, 758 insertions(+), 174 deletions(-)

diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake
index 9e36f39891..6642719cb4 100644
--- a/cmake/modules/contrib/DNNL.cmake
+++ b/cmake/modules/contrib/DNNL.cmake
@@ -19,11 +19,11 @@ if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON"))
   add_definitions(-DUSE_JSON_RUNTIME=1)
   tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc)
   list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
-  list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})
 
   find_library(EXTERN_LIBRARY_DNNL dnnl)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
-  tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc)
+  tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+                                      src/runtime/contrib/dnnl/dnnl_utils.cc)
   list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
   message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL})
 elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC")
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 5b63016d2f..905c67f1c5 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
                                   Span span = Span());
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                       \
-  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {     \
-    static const Op& op = Op::Get("tir." #OpName);             \
-    if (x.dtype().is_bfloat16()) {                             \
-      DataType srcType = x.dtype();                            \
-      DataType dstType(kDLFloat, 32, srcType.lanes());         \
-      PrimExpr castX = tir::Cast(dstType, {x}, span);          \
-      PrimExpr result = tir::Call(dstType, op, {castX}, span); \
-      return tir::Cast(srcType, {result}, span);               \
-    } else {                                                   \
-      return tir::Call(x.dtype(), op, {x}, span);              \
-    }                                                          \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
+  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
+    static const Op& op = Op::Get("tir." #OpName);                      \
+    if (x.dtype().is_bfloat16()) {                                      \
+      DataType bf16_dtype = x.dtype();                                  \
+      DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());            \
+      PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span);               \
+      PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
+      return tir::Cast(bf16_dtype, {result_fp32}, span);                \
+    } else {                                                            \
+      return tir::Call(x.dtype(), op, {x}, span);                       \
+    }                                                                   \
   }
 
 TVM_DECLARE_INTRIN_UNARY(exp);
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 72e004b868..2e975cf49c 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -85,7 +85,6 @@ _register_external_op_helper("clip")
 _register_external_op_helper("exp")
 _register_external_op_helper("log")
 _register_external_op_helper("sqrt")
-_register_external_op_helper("round")
 _register_external_op_helper("nn.relu")
 _register_external_op_helper("nn.leaky_relu")
 _register_external_op_helper("tanh")
@@ -212,7 +211,7 @@ def pattern_table():
 
 
 def get_optimal_layout_for_conv(
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups, dtype
 ):
     """Get the optimal layout of dnnl, given shape of conv2d.
 
@@ -236,6 +235,7 @@ def get_optimal_layout_for_conv(
         strides,
         dilates,
         groups,
+        dtype,
     )
 
 
@@ -249,6 +249,7 @@ def get_optimal_layout_for_conv_transpose(
     strides,
     dilates,
     groups,
+    dtype,
 ):
     """Get the optimal layout of dnnl, given shape of tranposed conv2d.
 
@@ -274,6 +275,7 @@ def get_optimal_layout_for_conv_transpose(
         strides,
         dilates,
         groups,
+        dtype,
     )
 
 
@@ -292,6 +294,21 @@ def get_shape(tensor):
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
 
+def get_dtype(tensor):
+    """Get tensor's dtype."""
+    if isinstance(tensor, relay.expr.Var):
+        return tensor.type_annotation.dtype
+    if isinstance(tensor, relay.expr.Constant):
+        return tensor.data.dtype
+    if isinstance(tensor, tvm.ir.tensor_type.TensorType):
+        return tensor.dtype
+    if isinstance(tensor, tvm.ir.container.Array):
+        return tensor[-1].dtype
+    if isinstance(tensor, relay.expr.Call):
+        return tensor.checked_type.dtype
+    raise TypeError("Unsupport data type: %s" % type(tensor))
+
+
 def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
     """Transfer layout, denoted with `a, b, c, d, e`,
     into valid layout (NCHW / OIHW) of TVM."""
@@ -353,6 +370,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
     strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
     dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
+    dtype = get_dtype(weight)
     new_attrs = dict(attrs)
     conv_type = type(attrs).__name__.split("Attrs")[0]
 
@@ -365,6 +383,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
         strides,
         dilates,
         groups,
+        dtype,
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
@@ -389,6 +408,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
     strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
     dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
     groups = str(attrs.groups)
+    dtype = get_dtype(weight)
     new_attrs = dict(attrs)
     conv_type = type(attrs).__name__.split("Attrs")[0]
 
@@ -402,6 +422,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
         strides,
         dilates,
         groups,
+        dtype,
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc
index 7fb1d824c7..3762c1906f 100755
--- a/src/relay/backend/contrib/dnnl/query_layout.cc
+++ b/src/relay/backend/contrib/dnnl/query_layout.cc
@@ -34,16 +34,17 @@
 #include <regex>
 #include <sstream>
 
+#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
 #include "../../utils.h"
 #include "dnnl.hpp"
-
-using dim_t = dnnl_dim_t;
-using dims_t = dnnl_dims_t;
-
 namespace tvm {
 namespace relay {
 namespace contrib {
 
+using dim_t = dnnl_dim_t;
+using dims_t = dnnl_dims_t;
+using tvm::runtime::contrib::dtype_dl2dnnl;
+
 template <typename T, typename U>
 inline void array_set(T* arr, const U& val, size_t size) {
   for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
@@ -192,7 +193,7 @@ void check_layout(bool var, bool ref) {
 std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
                                         std::string weight_shape, std::string out_shape,
                                         std::string paddings, std::string strides,
-                                        std::string dilates, std::string G) {
+                                        std::string dilates, std::string G, std::string dtype) {
   check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
   check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
   check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
@@ -200,7 +201,6 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
   dnnl::stream s(eng);
   using tag = dnnl::memory::format_tag;
-  using dt = dnnl::memory::data_type;
 
   dnnl::memory::dim groups = std::stoi(G);
   dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -249,9 +249,10 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
   dnnl::memory::dims conv_padding_l = padding_dims_l;
   dnnl::memory::dims conv_padding_r = padding_dims_r;
 
-  auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dt::f32, tag::any);
-  auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dt::f32, tag::any);
-  auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dt::f32, tag::any);
+  auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
+  auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dnnl_dtype, tag::any);
+  auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dnnl_dtype, tag::any);
+  auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dnnl_dtype, tag::any);
 
   auto conv_desc = dnnl::convolution_forward::desc(
       dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
@@ -276,7 +277,7 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
                                                   std::string weight_shape, std::string out_shape,
                                                   std::string paddings, std::string output_paddings,
                                                   std::string strides, std::string dilates,
-                                                  std::string G) {
+                                                  std::string G, std::string dtype) {
   check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
   check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
   check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});
@@ -284,7 +285,6 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
   dnnl::stream s(eng);
   using tag = dnnl::memory::format_tag;
-  using dt = dnnl::memory::data_type;
 
   dnnl::memory::dim groups = std::stoi(G);
   dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -338,9 +338,10 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
   dnnl::memory::dims deconv_padding_l = padding_dims_l;
   dnnl::memory::dims deconv_padding_r = padding_dims_r;
 
-  auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dt::f32, tag::any);
-  auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dt::f32, tag::any);
-  auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dt::f32, tag::any);
+  auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
+  auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dnnl_dtype, tag::any);
+  auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dnnl_dtype, tag::any);
+  auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dnnl_dtype, tag::any);
 
   auto deconv_desc = dnnl::deconvolution_forward::desc(
       dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, deconv_src_md,
@@ -364,13 +365,13 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
 TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5],
-                                        args[6], args[7]);
+                                        args[6], args[7], args[8]);
     });
 
 TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
       *rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4],
-                                                  args[5], args[6], args[7], args[8]);
+                                                  args[5], args[6], args[7], args[8], args[9]);
     });
 
 }  // namespace contrib
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index dc2afecbaf..f6a1c3b790 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -33,6 +33,7 @@
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
 #include "dnnl.hpp"
+#include "dnnl_utils.h"
 
 namespace tvm {
 namespace runtime {
@@ -66,8 +67,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Fill in the input buffers.
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto eid = EntryID(input_nodes_[i], 0);
-      // TODO(@comaniac): Support other data lengths.
-      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t offset_in_bytes =
+          entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
       size_t buffer_size = GetDataSize(*data_entry_[eid]);
       write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
                            offset_in_bytes);
@@ -82,7 +83,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // Read output buffers.
     for (size_t i = 0; i < outputs_.size(); ++i) {
       auto eid = EntryID(outputs_[i]);
-      size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+      size_t offset_in_bytes =
+          entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
       size_t buffer_size = GetDataSize(*data_entry_[eid]);
       read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
                             offset_in_bytes);
@@ -90,7 +92,501 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
   }
 
  private:
-  // Build up the engine based on the input graph.
+  tag layout2tag(std::string layout) {
+    static const std::map<std::string, tag> str2tag = {{"nc", tag::nc},
+                                                       {"cn", tag::cn},
+                                                       {"tn", tag::tn},
+                                                       {"nt", tag::nt},
+                                                       {"ncw", tag::ncw},
+                                                       {"nwc", tag::nwc},
+                                                       {"nchw", tag::nchw},
+                                                       {"nhwc", tag::nhwc},
+                                                       {"chwn", tag::chwn},
+                                                       {"ncdhw", tag::ncdhw},
+                                                       {"ndhwc", tag::ndhwc},
+                                                       {"oi", tag::oi},
+                                                       {"io", tag::io},
+                                                       {"oiw", tag::oiw},
+                                                       {"owi", tag::owi},
+                                                       {"wio", tag::wio},
+                                                       {"iwo", tag::iwo},
+                                                       {"oihw", tag::oihw},
+                                                       {"hwio", tag::hwio},
+                                                       {"ohwi", tag::ohwi},
+                                                       {"ihwo", tag::ihwo},
+                                                       {"iohw", tag::iohw},
+                                                       {"oidhw", tag::oidhw},
+                                                       {"dhwio", tag::dhwio},
+                                                       {"odhwi", tag::odhwi},
+                                                       {"iodhw", tag::iodhw},
+                                                       {"idhwo", tag::idhwo},
+                                                       {"goiw", tag::goiw},
+                                                       {"gowi", tag::gowi},
+                                                       {"wigo", tag::wigo},
+                                                       {"gohwi", tag::gohwi},
+                                                       {"goihw", tag::goihw},
+                                                       {"hwigo", tag::hwigo},
+                                                       {"giohw", tag::giohw},
+                                                       {"goidhw", tag::goidhw},
+                                                       {"giodhw", tag::giodhw},
+                                                       {"godhwi", tag::godhwi},
+                                                       {"dhwigo", tag::dhwigo},
+                                                       {"tnc", tag::tnc},
+                                                       {"ntc", tag::ntc},
+                                                       {"ldnc", tag::ldnc},
+                                                       {"ldigo", tag::ldigo},
+                                                       {"ldgoi", tag::ldgoi},
+                                                       {"ldio", tag::ldio},
+                                                       {"ldoi", tag::ldoi},
+                                                       {"ldgo", tag::ldgo},
+                                                       {"nCdhw16c", tag::nCdhw16c},
+                                                       {"nCdhw4c", tag::nCdhw4c},
+                                                       {"nCdhw8c", tag::nCdhw8c},
+                                                       {"nChw16c", tag::nChw16c},
+                                                       {"nChw4c", tag::nChw4c},
+                                                       {"nChw8c", tag::nChw8c},
+                                                       {"nCw16c", tag::nCw16c},
+                                                       {"nCw4c", tag::nCw4c},
+                                                       {"nCw8c", tag::nCw8c},
+                                                       {"NCw16n16c", tag::NCw16n16c},
+                                                       {"NChw16n16c", tag::NChw16n16c},
+                                                       {"NCdhw16n16c", tag::NCdhw16n16c},
+                                                       {"NCdhw32n32c", tag::NCdhw32n32c},
+                                                       {"NChw32n32c", tag::NChw32n32c},
+                                                       {"IOhw16i16o", tag::IOhw16i16o},
+                                                       {"OI16i16o", tag::OI16i16o},
+                                                       {"OI16i32o", tag::OI16i32o},
+                                                       {"OI16i64o", tag::OI16i64o},
+                                                       {"OI8i16o2i", tag::OI8i16o2i},
+                                                       {"OI8i32o2i", tag::OI8i32o2i},
+                                                       {"OI8i64o2i", tag::OI8i64o2i},
+                                                       {"OI4i16o4i", tag::OI4i16o4i},
+                                                       {"OI4i32o4i", tag::OI4i32o4i},
+                                                       {"OI4i64o4i", tag::OI4i64o4i},
+                                                       {"Ohwi32o", tag::Ohwi32o},
+                                                       {"IOdhw16i16o", tag::IOdhw16i16o},
+                                                       {"gIOhw16i16o", tag::gIOhw16i16o},
+                                                       {"gOhwi32o", tag::gOhwi32o},
+                                                       {"Goidhw16g", tag::Goidhw16g},
+                                                       {"IOw16o16i", tag::IOw16o16i},
+                                                       {"OIw16i16o", tag::OIw16i16o},
+                                                       {"OIw16i32o", tag::OIw16i32o},
+                                                       {"OIw16i64o", tag::OIw16i64o},
+                                                       {"IOw16i16o", tag::IOw16i16o},
+                                                       {"gIOw16i16o", tag::gIOw16i16o},
+                                                       {"OIw16o16i", tag::OIw16o16i},
+                                                       {"Oiw16o", tag::Oiw16o},
+                                                       {"OIw4i16o4i", tag::OIw4i16o4i},
+                                                       {"OIw4i32o4i", tag::OIw4i32o4i},
+                                                       {"OIw4i64o4i", tag::OIw4i64o4i},
+                                                       {"OIw2i8o4i", tag::OIw2i8o4i},
+                                                       {"OIw4i4o", tag::OIw4i4o},
+                                                       {"OIw4o4i", tag::OIw4o4i},
+                                                       {"Oiw4o", tag::Oiw4o},
+                                                       {"OIw8i16o2i", tag::OIw8i16o2i},
+                                                       {"OIw8i32o2i", tag::OIw8i32o2i},
+                                                       {"OIw8i64o2i", tag::OIw8i64o2i},
+                                                       {"OIw8i8o", tag::OIw8i8o},
+                                                       {"OIw8o16i2o", tag::OIw8o16i2o},
+                                                       {"OIw8o8i", tag::OIw8o8i},
+                                                       {"OIw8o4i", tag::OIw8o4i},
+                                                       {"OIw16i16o4i", tag::OIw16i16o4i},
+                                                       {"OIw16i32o4i", tag::OIw16i32o4i},
+                                                       {"OIw16i48o4i", tag::OIw16i48o4i},
+                                                       {"OIw16i64o4i", tag::OIw16i64o4i},
+                                                       {"OIw16i16o2i", tag::OIw16i16o2i},
+                                                       {"OIw16i32o2i", tag::OIw16i32o2i},
+                                                       {"OIw16i48o2i", tag::OIw16i48o2i},
+                                                       {"OIw16i64o2i", tag::OIw16i64o2i},
+                                                       {"OIw16o16i2o", tag::OIw16o16i2o},
+                                                       {"Owi16o", tag::Owi16o},
+                                                       {"OwI16o2i", tag::OwI16o2i},
+                                                       {"Owi4o", tag::Owi4o},
+                                                       {"Owi8o", tag::Owi8o},
+                                                       {"IOhw16o16i", tag::IOhw16o16i},
+                                                       {"Ohwi16o", tag::Ohwi16o},
+                                                       {"OhwI16o2i", tag::OhwI16o2i},
+                                                       {"Ohwi4o", tag::Ohwi4o},
+                                                       {"Ohwi8o", tag::Ohwi8o},
+                                                       {"OIhw16i16o", tag::OIhw16i16o},
+                                                       {"OIhw16i32o", tag::OIhw16i32o},
+                                                       {"OIhw16i64o", tag::OIhw16i64o},
+                                                       {"OIhw16o16i", tag::OIhw16o16i},
+                                                       {"Oihw16o", tag::Oihw16o},
+                                                       {"OIhw4i16o4i", tag::OIhw4i16o4i},
+                                                       {"OIhw4i32o4i", tag::OIhw4i32o4i},
+                                                       {"OIhw4i64o4i", tag::OIhw4i64o4i},
+                                                       {"OIhw4i4o", tag::OIhw4i4o},
+                                                       {"OIhw4o4i", tag::OIhw4o4i},
+                                                       {"Oihw4o", tag::Oihw4o},
+                                                       {"OIhw8i16o2i", tag::OIhw8i16o2i},
+                                                       {"OIhw8i32o2i", tag::OIhw8i32o2i},
+                                                       {"OIhw8i64o2i", tag::OIhw8i64o2i},
+                                                       {"OIhw8i8o", tag::OIhw8i8o},
+                                                       {"OIhw8o16i2o", tag::OIhw8o16i2o},
+                                                       {"OIhw8o8i", tag::OIhw8o8i},
+                                                       {"OIhw8o4i", tag::OIhw8o4i},
+                                                       {"OIhw2i8o4i", tag::OIhw2i8o4i},
+                                                       {"IOdhw16o16i", tag::IOdhw16o16i},
+                                                       {"Odhwi16o", tag::Odhwi16o},
+                                                       {"OdhwI16o2i", tag::OdhwI16o2i},
+                                                       {"Odhwi4o", tag::Odhwi4o},
+                                                       {"Odhwi8o", tag::Odhwi8o},
+                                                       {"OIdhw16i16o", tag::OIdhw16i16o},
+                                                       {"OIdhw16i32o", tag::OIdhw16i32o},
+                                                       {"OIdhw16i64o", tag::OIdhw16i64o},
+                                                       {"OIdhw16o16i", tag::OIdhw16o16i},
+                                                       {"Oidhw16o", tag::Oidhw16o},
+                                                       {"OIdhw4i4o", tag::OIdhw4i4o},
+                                                       {"OIdhw4o4i", tag::OIdhw4o4i},
+                                                       {"Oidhw4o", tag::Oidhw4o},
+                                                       {"OIdhw8i16o2i", tag::OIdhw8i16o2i},
+                                                       {"OIdhw8i32o2i", tag::OIdhw8i32o2i},
+                                                       {"OIdhw8i64o2i", tag::OIdhw8i64o2i},
+                                                       {"OIdhw4i16o4i", tag::OIdhw4i16o4i},
+                                                       {"OIdhw16i16o4i", tag::OIdhw16i16o4i},
+                                                       {"OIdhw16i32o4i", tag::OIdhw16i32o4i},
+                                                       {"OIdhw16i48o4i", tag::OIdhw16i48o4i},
+                                                       {"OIdhw16i64o4i", tag::OIdhw16i64o4i},
+                                                       {"OIdhw16i16o2i", tag::OIdhw16i16o2i},
+                                                       {"OIdhw16i32o2i", tag::OIdhw16i32o2i},
+                                                       {"OIdhw16i48o2i", tag::OIdhw16i48o2i},
+                                                       {"OIdhw16i64o2i", tag::OIdhw16i64o2i},
+                                                       {"OIdhw4i32o4i", tag::OIdhw4i32o4i},
+                                                       {"OIdhw4i64o4i", tag::OIdhw4i64o4i},
+                                                       {"OIdhw2i8o4i", tag::OIdhw2i8o4i},
+                                                       {"OIdhw8i8o", tag::OIdhw8i8o},
+                                                       {"OIdhw8o8i", tag::OIdhw8o8i},
+                                                       {"OIdhw8o4i", tag::OIdhw8o4i},
+                                                       {"gIOw16o16i", tag::gIOw16o16i},
+                                                       {"gOIw16i16o", tag::gOIw16i16o},
+                                                       {"gOIw16o16i", tag::gOIw16o16i},
+                                                       {"gOiw16o", tag::gOiw16o},
+                                                       {"gOIw4i16o4i", tag::gOIw4i16o4i},
+                                                       {"gOIw2i8o4i", tag::gOIw2i8o4i},
+                                                       {"gOIw4i4o", tag::gOIw4i4o},
+                                                       {"gOIw4o4i", tag::gOIw4o4i},
+                                                       {"gOiw4o", tag::gOiw4o},
+                                                       {"gOIw8i16o2i", tag::gOIw8i16o2i},
+                                                       {"gOIw8i8o", tag::gOIw8i8o},
+                                                       {"gOIw8o16i2o", tag::gOIw8o16i2o},
+                                                       {"gOIw8o8i", tag::gOIw8o8i},
+                                                       {"gOIw8o4i", tag::gOIw8o4i},
+                                                       {"gOIw16i16o4i", tag::gOIw16i16o4i},
+                                                       {"gOIw16i16o2i", tag::gOIw16i16o2i},
+                                                       {"gOIw16o16i2o", tag::gOIw16o16i2o},
+                                                       {"gOwi16o", tag::gOwi16o},
+                                                       {"gOwI16o2i", tag::gOwI16o2i},
+                                                       {"gOwi4o", tag::gOwi4o},
+                                                       {"gOwi8o", tag::gOwi8o},
+                                                       {"Goiw8g", tag::Goiw8g},
+                                                       {"Goiw16g", tag::Goiw16g},
+                                                       {"gIOhw16o16i", tag::gIOhw16o16i},
+                                                       {"gOhwi16o", tag::gOhwi16o},
+                                                       {"gOhwI16o2i", tag::gOhwI16o2i},
+                                                       {"gOhwi4o", tag::gOhwi4o},
+                                                       {"gOhwi8o", tag::gOhwi8o},
+                                                       {"Goihw16g", tag::Goihw16g},
+                                                       {"gOIhw16i16o", tag::gOIhw16i16o},
+                                                       {"gOIhw16o16i", tag::gOIhw16o16i},
+                                                       {"gOihw16o", tag::gOihw16o},
+                                                       {"gOIhw4i16o4i", tag::gOIhw4i16o4i},
+                                                       {"gOIhw2i8o4i", tag::gOIhw2i8o4i},
+                                                       {"gOIhw4i4o", tag::gOIhw4i4o},
+                                                       {"gOIhw4o4i", tag::gOIhw4o4i},
+                                                       {"gOihw4o", tag::gOihw4o},
+                                                       {"Goihw8g", tag::Goihw8g},
+                                                       {"gOIhw8i16o2i", tag::gOIhw8i16o2i},
+                                                       {"gOIhw8i8o", tag::gOIhw8i8o},
+                                                       {"gOIhw8o16i2o", tag::gOIhw8o16i2o},
+                                                       {"OIw4o8i8o4i", tag::OIw4o8i8o4i},
+                                                       {"OIdhw4o8i8o4i", tag::OIdhw4o8i8o4i},
+                                                       {"OIhw4o8i8o4i", tag::OIhw4o8i8o4i},
+                                                       {"OIhw2o8i8o2i", tag::OIhw2o8i8o2i},
+                                                       {"gOIw4o8i8o4i", tag::gOIw4o8i8o4i},
+                                                       {"gOIdhw4o8i8o4i", tag::gOIdhw4o8i8o4i},
+                                                       {"gOIhw4o8i8o4i", tag::gOIhw4o8i8o4i},
+                                                       {"gOIhw2o8i8o2i", tag::gOIhw2o8i8o2i},
+                                                       {"OIhw16i16o4i", tag::OIhw16i16o4i},
+                                                       {"OIhw16i32o4i", tag::OIhw16i32o4i},
+                                                       {"OIhw16i48o4i", tag::OIhw16i48o4i},
+                                                       {"OIhw16i64o4i", tag::OIhw16i64o4i},
+                                                       {"OIhw16i16o2i", tag::OIhw16i16o2i},
+                                                       {"OIhw16i32o2i", tag::OIhw16i32o2i},
+                                                       {"OIhw16i48o2i", tag::OIhw16i48o2i},
+                                                       {"OIhw16i64o2i", tag::OIhw16i64o2i},
+                                                       {"OIhw16o16i2o", tag::OIhw16o16i2o},
+                                                       {"gOIhw16i16o4i", tag::gOIhw16i16o4i},
+                                                       {"gOIhw16i16o2i", tag::gOIhw16i16o2i},
+                                                       {"gOIhw16o16i2o", tag::gOIhw16o16i2o},
+                                                       {"gOIhw8o8i", tag::gOIhw8o8i},
+                                                       {"gOIhw8o4i", tag::gOIhw8o4i},
+                                                       {"gIOdhw16i16o", tag::gIOdhw16i16o},
+                                                       {"gIOdhw16o16i", tag::gIOdhw16o16i},
+                                                       {"gOdhwi16o", tag::gOdhwi16o},
+                                                       {"gOdhwI16o2i", tag::gOdhwI16o2i},
+                                                       {"gOdhwi4o", tag::gOdhwi4o},
+                                                       {"gOdhwi8o", tag::gOdhwi8o},
+                                                       {"gOIdhw16i16o", tag::gOIdhw16i16o},
+                                                       {"gOIdhw16o16i", tag::gOIdhw16o16i},
+                                                       {"gOidhw16o", tag::gOidhw16o},
+                                                       {"gOIdhw4i4o", tag::gOIdhw4i4o},
+                                                       {"gOIdhw4o4i", tag::gOIdhw4o4i},
+                                                       {"gOidhw4o", tag::gOidhw4o},
+                                                       {"gOIdhw8i16o2i", tag::gOIdhw8i16o2i},
+                                                       {"gOIdhw4i16o4i", tag::gOIdhw4i16o4i},
+                                                       {"gOIdhw16i16o4i", tag::gOIdhw16i16o4i},
+                                                       {"gOIdhw16i16o2i", tag::gOIdhw16i16o2i},
+                                                       {"gOIdhw2i8o4i", tag::gOIdhw2i8o4i},
+                                                       {"gOIdhw8i8o", tag::gOIdhw8i8o},
+                                                       {"gOIdhw8o8i", tag::gOIdhw8o8i},
+                                                       {"gOIdhw8o4i", tag::gOIdhw8o4i},
+                                                       {"gOIw2i4o2i", tag::gOIw2i4o2i},
+                                                       {"gOIhw2i4o2i", tag::gOIhw2i4o2i},
+                                                       {"gOIdhw2i4o2i", tag::gOIdhw2i4o2i},
+                                                       {"gOIw2o4i2o", tag::gOIw2o4i2o},
+                                                       {"gOIhw2o4i2o", tag::gOIhw2o4i2o},
+                                                       {"gOIdhw2o4i2o", tag::gOIdhw2o4i2o},
+                                                       {"gOIw4i8o2i", tag::gOIw4i8o2i},
+                                                       {"gOIhw4i8o2i", tag::gOIhw4i8o2i},
+                                                       {"gOIdhw4i8o2i", tag::gOIdhw4i8o2i},
+                                                       {"gOIw4o8i2o", tag::gOIw4o8i2o},
+                                                       {"gOIhw4o8i2o", tag::gOIhw4o8i2o},
+                                                       {"gOIdhw4o8i2o", tag::gOIdhw4o8i2o},
+                                                       {"ldOi32o", tag::ldOi32o},
+                                                       {"ldOI32o4i", tag::ldOI32o4i},
+                                                       {"ldgOi32o", tag::ldgOi32o},
+                                                       {"ldgOI32o2i", tag::ldgOI32o2i},
+                                                       {"ldgOI32o4i", tag::ldgOI32o4i},
+                                                       {"OwI16o4i", tag::OwI16o4i},
+                                                       {"OhwI16o4i", tag::OhwI16o4i},
+                                                       {"gOwI16o4i", tag::gOwI16o4i},
+                                                       {"gOhwI16o4i", tag::gOhwI16o4i},
+                                                       {"OdhwI16o4i", tag::OdhwI16o4i},
+                                                       {"gOdhwI16o4i", tag::gOdhwI16o4i},
+                                                       {"Owi32o", tag::Owi32o},
+                                                       {"OwI32o2i", tag::OwI32o2i},
+                                                       {"OwI32o4i", tag::OwI32o4i},
+                                                       {"Owi48o", tag::Owi48o},
+                                                       {"OwI48o2i", tag::OwI48o2i},
+                                                       {"OwI48o4i", tag::OwI48o4i},
+                                                       {"Owi64o", tag::Owi64o},
+                                                       {"OwI64o2i", tag::OwI64o2i},
+                                                       {"OwI64o4i", tag::OwI64o4i},
+                                                       {"wIo2i", tag::wIo2i},
+                                                       {"wIo4i", tag::wIo4i},
+                                                       {"gOwi32o", tag::gOwi32o},
+                                                       {"gOwI32o2i", tag::gOwI32o2i},
+                                                       {"gOwI32o4i", tag::gOwI32o4i},
+                                                       {"gOwi48o", tag::gOwi48o},
+                                                       {"gOwI48o2i", tag::gOwI48o2i},
+                                                       {"gOwI48o4i", tag::gOwI48o4i},
+                                                       {"gOwi64o", tag::gOwi64o},
+                                                       {"gOwI64o2i", tag::gOwI64o2i},
+                                                       {"gOwI64o4i", tag::gOwI64o4i},
+                                                       {"gwio", tag::gwio},
+                                                       {"gwIo2i", tag::gwIo2i},
+                                                       {"gwIo4i", tag::gwIo4i},
+                                                       {"OhwI32o", tag::OhwI32o},
+                                                       {"OhwI32o2i", tag::OhwI32o2i},
+                                                       {"OhwI32o4i", tag::OhwI32o4i},
+                                                       {"Ohwi48o", tag::Ohwi48o},
+                                                       {"OhwI48o2i", tag::OhwI48o2i},
+                                                       {"OhwI48o4i", tag::OhwI48o4i},
+                                                       {"Ohwi64o", tag::Ohwi64o},
+                                                       {"OhwI64o2i", tag::OhwI64o2i},
+                                                       {"OhwI64o4i", tag::OhwI64o4i},
+                                                       {"hwIo2i", tag::hwIo2i},
+                                                       {"hwIo4i", tag::hwIo4i},
+                                                       {"gOhwI32o", tag::gOhwI32o},
+                                                       {"gOhwI32o2i", tag::gOhwI32o2i},
+                                                       {"gOhwI32o4i", tag::gOhwI32o4i},
+                                                       {"gOhwi48o", tag::gOhwi48o},
+                                                       {"gOhwI48o2i", tag::gOhwI48o2i},
+                                                       {"gOhwI48o4i", tag::gOhwI48o4i},
+                                                       {"gOhwi64o", tag::gOhwi64o},
+                                                       {"gOhwI64o2i", tag::gOhwI64o2i},
+                                                       {"gOhwI64o4i", tag::gOhwI64o4i},
+                                                       {"ghwio", tag::ghwio},
+                                                       {"ghwIo2i", tag::ghwIo2i},
+                                                       {"ghwIo4i", tag::ghwIo4i},
+                                                       {"Odhwi32o", tag::Odhwi32o},
+                                                       {"OdhwI32o2i", tag::OdhwI32o2i},
+                                                       {"OdhwI32o4i", tag::OdhwI32o4i},
+                                                       {"Odhwi48o", tag::Odhwi48o},
+                                                       {"OdhwI48o2i", tag::OdhwI48o2i},
+                                                       {"OdhwI48o4i", tag::OdhwI48o4i},
+                                                       {"Odhwi64o", tag::Odhwi64o},
+                                                       {"OdhwI64o2i", tag::OdhwI64o2i},
+                                                       {"OdhwI64o4i", tag::OdhwI64o4i},
+                                                       {"dhwIo2i", tag::dhwIo2i},
+                                                       {"dhwIo4i", tag::dhwIo4i},
+                                                       {"gOdhwi32o", tag::gOdhwi32o},
+                                                       {"gOdhwI32o2i", tag::gOdhwI32o2i},
+                                                       {"gOdhwI32o4i", tag::gOdhwI32o4i},
+                                                       {"gOdhwi48o", tag::gOdhwi48o},
+                                                       {"gOdhwI48o2i", tag::gOdhwI48o2i},
+                                                       {"gOdhwI48o4i", tag::gOdhwI48o4i},
+                                                       {"gOdhwi64o", tag::gOdhwi64o},
+                                                       {"gOdhwI64o2i", tag::gOdhwI64o2i},
+                                                       {"gOdhwI64o4i", tag::gOdhwI64o4i},
+                                                       {"gdhwio", tag::gdhwio},
+                                                       {"gdhwIo2i", tag::gdhwIo2i},
+                                                       {"gdhwIo4i", tag::gdhwIo4i},
+                                                       {"ldIo32i", tag::ldIo32i},
+                                                       {"ldgIo32i", tag::ldgIo32i},
+                                                       {"ldgIO32i2o", tag::ldgIO32i2o},
+                                                       {"nCdhw32c", tag::nCdhw32c},
+                                                       {"nChw32c", tag::nChw32c},
+                                                       {"nCw32c", tag::nCw32c},
+                                                       {"NCw32n16c", tag::NCw32n16c},
+                                                       {"NChw32n16c", tag::NChw32n16c},
+                                                       {"NCdhw32n16c", tag::NCdhw32n16c},
+                                                       {"NCw32n32c", tag::NCw32n32c},
+                                                       {"OI16i16o4i", tag::OI16i16o4i},
+                                                       {"IOw8o16i2o", tag::IOw8o16i2o},
+                                                       {"IOhw8o16i2o", tag::IOhw8o16i2o},
+                                                       {"Owhi16o", tag::Owhi16o},
+                                                       {"OIdhw8o16i2o", tag::OIdhw8o16i2o},
+                                                       {"IOdhw8o16i2o", tag::IOdhw8o16i2o},
+                                                       {"Goiw4g", tag::Goiw4g},
+                                                       {"gIOw8o16i2o", tag::gIOw8o16i2o},
+                                                       {"Goiw32g", tag::Goiw32g},
+                                                       {"Goihw4g", tag::Goihw4g},
+                                                       {"gIOhw8o16i2o", tag::gIOhw8o16i2o},
+                                                       {"Goihw32g", tag::Goihw32g},
+                                                       {"gOwhi16o", tag::gOwhi16o},
+                                                       {"IOw4i8o8i4o", tag::IOw4i8o8i4o},
+                                                       {"IOhw4i8o8i4o", tag::IOhw4i8o8i4o},
+                                                       {"IOdhw4i8o8i4o", tag::IOdhw4i8o8i4o},
+                                                       {"gIOw4i8o8i4o", tag::gIOw4i8o8i4o},
+                                                       {"gIOhw4i8o8i4o", tag::gIOhw4i8o8i4o},
+                                                       {"gIOdhw4i8o8i4o", tag::gIOdhw4i8o8i4o},
+                                                       {"gOIdhw8o16i2o", tag::gOIdhw8o16i2o},
+                                                       {"gIOdhw8o16i2o", tag::gIOdhw8o16i2o},
+                                                       {"Goidhw32g", tag::Goidhw32g},
+                                                       {"OI16i32o4i", tag::OI16i32o4i},
+                                                       {"OI16i48o4i", tag::OI16i48o4i},
+                                                       {"OI16i64o4i", tag::OI16i64o4i},
+                                                       {"OI16i16o2i", tag::OI16i16o2i},
+                                                       {"OI16i32o2i", tag::OI16i32o2i},
+                                                       {"OI16i48o2i", tag::OI16i48o2i},
+                                                       {"OI16i64o2i", tag::OI16i64o2i},
+                                                       {"OwI16i16o2i", tag::OwI16i16o2i},
+                                                       {"gOwI16i16o2i", tag::gOwI16i16o2i},
+                                                       {"OhwI16i16o2i", tag::OhwI16i16o2i},
+                                                       {"gOhwI16i16o2i", tag::gOhwI16i16o2i},
+                                                       {"OdhwI16i16o2i", tag::OdhwI16i16o2i},
+                                                       {"gOdhwI16i16o2i", tag::gOdhwI16i16o2i},
+                                                       {"OwI16i16o4i", tag::OwI16i16o4i},
+                                                       {"gOwI16i16o4i", tag::gOwI16i16o4i},
+                                                       {"OhwI16i16o4i", tag::OhwI16i16o4i},
+                                                       {"gOhwI16i16o4i", tag::gOhwI16i16o4i},
+                                                       {"OdhwI16i16o4i", tag::OdhwI16i16o4i},
+                                                       {"gOdhwI16i16o4i", tag::gOdhwI16i16o4i},
+                                                       {"OwI16i32o2i", tag::OwI16i32o2i},
+                                                       {"OwI16i32o4i", tag::OwI16i32o4i},
+                                                       {"OwI16i48o2i", tag::OwI16i48o2i},
+                                                       {"OwI16i48o4i", tag::OwI16i48o4i},
+                                                       {"OwI16i64o2i", tag::OwI16i64o2i},
+                                                       {"OwI16i64o4i", tag::OwI16i64o4i},
+                                                       {"gOwI16i32o2i", tag::gOwI16i32o2i},
+                                                       {"gOwI16i32o4i", tag::gOwI16i32o4i},
+                                                       {"gOwI16i48o2i", tag::gOwI16i48o2i},
+                                                       {"gOwI16i48o4i", tag::gOwI16i48o4i},
+                                                       {"gOwI16i64o2i", tag::gOwI16i64o2i},
+                                                       {"gOwI16i64o4i", tag::gOwI16i64o4i},
+                                                       {"OhwI16i32o2i", tag::OhwI16i32o2i},
+                                                       {"OhwI16i32o4i", tag::OhwI16i32o4i},
+                                                       {"OhwI16i48o2i", tag::OhwI16i48o2i},
+                                                       {"OhwI16i48o4i", tag::OhwI16i48o4i},
+                                                       {"OhwI16i64o2i", tag::OhwI16i64o2i},
+                                                       {"OhwI16i64o4i", tag::OhwI16i64o4i},
+                                                       {"gOhwI16i32o2i", tag::gOhwI16i32o2i},
+                                                       {"gOhwI16i32o4i", tag::gOhwI16i32o4i},
+                                                       {"gOhwI16i48o2i", tag::gOhwI16i48o2i},
+                                                       {"gOhwI16i48o4i", tag::gOhwI16i48o4i},
+                                                       {"gOhwI16i64o2i", tag::gOhwI16i64o2i},
+                                                       {"gOhwI16i64o4i", tag::gOhwI16i64o4i},
+                                                       {"OdhwI16i32o2i", tag::OdhwI16i32o2i},
+                                                       {"OdhwI16i32o4i", tag::OdhwI16i32o4i},
+                                                       {"OdhwI16i48o2i", tag::OdhwI16i48o2i},
+                                                       {"OdhwI16i48o4i", tag::OdhwI16i48o4i},
+                                                       {"OdhwI16i64o2i", tag::OdhwI16i64o2i},
+                                                       {"OdhwI16i64o4i", tag::OdhwI16i64o4i},
+                                                       {"gOdhwI16i32o2i", tag::gOdhwI16i32o2i},
+                                                       {"gOdhwI16i32o4i", tag::gOdhwI16i32o4i},
+                                                       {"gOdhwI16i48o2i", tag::gOdhwI16i48o2i},
+                                                       {"gOdhwI16i48o4i", tag::gOdhwI16i48o4i},
+                                                       {"gOdhwI16i64o2i", tag::gOdhwI16i64o2i},
+                                                       {"gOdhwI16i64o4i", tag::gOdhwI16i64o4i},
+                                                       {"hwioG16g", tag::hwioG16g},
+                                                       {"NCdhw40n32c", tag::NCdhw40n32c},
+                                                       {"NChw40n32c", tag::NChw40n32c},
+                                                       {"NCw40n32c", tag::NCw40n32c},
+                                                       {"OIdhw4o8i8o2i", tag::OIdhw4o8i8o2i},
+                                                       {"OIhw4o8i8o2i", tag::OIhw4o8i8o2i},
+                                                       {"OIw4o8i8o2i", tag::OIw4o8i8o2i},
+                                                       {"gOIdhw4o8i8o2i", tag::gOIdhw4o8i8o2i},
+                                                       {"gOIhw4o8i8o2i", tag::gOIhw4o8i8o2i},
+                                                       {"gOIw4o8i8o2i", tag::gOIw4o8i8o2i},
+                                                       {"IOdhw4i8o8i2o", tag::IOdhw4i8o8i2o},
+                                                       {"IOhw4i8o8i2o", tag::IOhw4i8o8i2o},
+                                                       {"IOw4i8o8i2o", tag::IOw4i8o8i2o},
+                                                       {"gIOdhw4i8o8i2o", tag::gIOdhw4i8o8i2o},
+                                                       {"gIOhw4i8o8i2o", tag::gIOhw4i8o8i2o},
+                                                       {"gIOw4i8o8i2o", tag::gIOw4i8o8i2o},
+                                                       {"NCdhw40n16c", tag::NCdhw40n16c},
+                                                       {"NCw40n16c", tag::NCw40n16c},
+                                                       {"NChw40n16c", tag::NChw40n16c},
+                                                       {"NCw2c32n8c", tag::NCw2c32n8c},
+                                                       {"NChw2c32n8c", tag::NChw2c32n8c},
+                                                       {"NCdhw2c32n8c", tag::NCdhw2c32n8c},
+                                                       {"OIw2i8o16i4o", tag::OIw2i8o16i4o},
+                                                       {"OIhw2i8o16i4o", tag::OIhw2i8o16i4o},
+                                                       {"OIdhw2i8o16i4o", tag::OIdhw2i8o16i4o},
+                                                       {"OIw2o8i16o4i", tag::OIw2o8i16o4i},
+                                                       {"OIw2o8i16o2i", tag::OIw2o8i16o2i},
+                                                       {"IOw2i8o16i4o", tag::IOw2i8o16i4o},
+                                                       {"IOw2i8o16i2o", tag::IOw2i8o16i2o},
+                                                       {"OIhw2o8i16o4i", tag::OIhw2o8i16o4i},
+                                                       {"OIhw2o8i16o2i", tag::OIhw2o8i16o2i},
+                                                       {"IOhw2i8o16i4o", tag::IOhw2i8o16i4o},
+                                                       {"IOhw2i8o16i2o", tag::IOhw2i8o16i2o},
+                                                       {"OIdhw2o8i16o4i", tag::OIdhw2o8i16o4i},
+                                                       {"OIdhw2o8i16o2i", tag::OIdhw2o8i16o2i},
+                                                       {"IOdhw2i8o16i4o", tag::IOdhw2i8o16i4o},
+                                                       {"IOdhw2i8o16i2o", tag::IOdhw2i8o16i2o},
+                                                       {"gOIw2o8i16o2i", tag::gOIw2o8i16o2i},
+                                                       {"gIOw2i8o16i2o", tag::gIOw2i8o16i2o},
+                                                       {"gIOhw2i8o16i2o", tag::gIOhw2i8o16i2o},
+                                                       {"gIOdhw2i8o16i2o", tag::gIOdhw2i8o16i2o},
+                                                       {"gOIhw2o8i16o2i", tag::gOIhw2o8i16o2i},
+                                                       {"gOIdhw2o8i16o2i", tag::gOIdhw2o8i16o2i},
+                                                       {"gOIw2o8i16o4i", tag::gOIw2o8i16o4i},
+                                                       {"gOIhw2o8i16o4i", tag::gOIhw2o8i16o4i}};
+    std::string key = "";
+    for (const auto& c : layout) {
+      if (std::isalpha(c, std::locale("C"))) {
+        char lower_c = std::tolower(c);
+        if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) {
+          key.push_back(c);
+        } else {
+          key.push_back(lower_c);
+        }
+      } else if (std::isdigit(c)) {
+        key.push_back(c);
+      } else {
+        LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl;
+      }
+    }
+    if (str2tag.count(key) == 0) {
+      LOG(WARNING) << "convert unregistered layout '" << key << "' to tag::any";
+      return tag::any;
+    } else {
+      return str2tag.at(key);
+    }
+  }
 
   std::map<std::string, dnnl::algorithm> elt_name2algo{
       {"abs", dnnl::algorithm::eltwise_abs},
@@ -106,62 +602,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       {"clip", dnnl::algorithm::eltwise_clip},
   };
 
-  std::map<std::string, tag> layout_dict{
-      {"", tag::any},
-      {"NCW", tag::ncw},
-      {"NWC", tag::nwc},
-      {"OIW", tag::oiw},
-      {"GOIW", tag::goiw},
-      {"NCHW", tag::nchw},
-      {"NHWC", tag::nhwc},
-      {"OIHW", tag::oihw},
-      {"GOIHW", tag::goihw},
-      {"NCDHW", tag::ncdhw},
-      {"NDHWC", tag::ndhwc},
-      {"OIDHW", tag::oidhw},
-      {"GOIDHW", tag::goidhw},
-      {"IOHW", tag::iohw},
-      {"GIOHW", tag::giohw},
-      {"IODHW", tag::iodhw},
-      {"GIODHW", tag::giodhw},
-
-      // Blocking layout.
-      {"NCW8c", tag::nCw8c},
-      {"NCW16c", tag::nCw16c},
-      {"OIW16i16o", tag::OIw8i8o},
-      {"OIW16i16o", tag::OIw16i16o},
-      {"OWI8o", tag::Owi8o},
-      {"OWI16o", tag::Owi16o},
-      {"NCHW4c", tag::nChw4c},
-      {"NCHW8c", tag::nChw8c},
-      {"NCHW16c", tag::nChw16c},
-      {"OIHW8i8o", tag::OIhw8i8o},
-      {"IOHW8i8o", tag::any},
-      {"OIHW16i16o", tag::OIhw16i16o},
-      {"IOHW16i16o", tag::IOhw16i16o},
-      {"GOIHW4i4o", tag::gOIhw4i4o},
-      {"GOIHW8i8o", tag::gOIhw8i8o},
-      {"GOIHW16i16o", tag::gOIhw16i16o},
-      {"OHWI8o", tag::Ohwi8o},
-      {"OHWI16o", tag::Ohwi16o},
-      {"OHWI32o", tag::Ohwi32o},
-      {"OHWI48o", tag::Ohwi48o},
-      {"OHWI64o", tag::Ohwi64o},
-      {"GOIHW8g", tag::Goihw8g},
-      {"GOIHW16g", tag::Goihw16g},
-      {"NCDHW8c", tag::nCdhw8c},
-      {"NCDHW16c", tag::nCdhw16c},
-      {"OIDHW16i16o", tag::OIdhw16i16o},
-      {"IODHW16i16o", tag::IOdhw16i16o},
-      {"OIDHW8i8o", tag::OIdhw8i8o},
-      {"IODHW8i8o", tag::any},
-      {"ODHWI8o", tag::Odhwi8o},
-      {"ODHWI16o", tag::Odhwi16o},
-      {"ODHWI32o", tag::Odhwi32o},
-      {"ODHWI48o", tag::Odhwi48o},
-      {"ODHWI64o", tag::Odhwi64o},
-  };
-
   bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
     // Define RegExp.
     std::regex bias_add_pat(".*_bias.*");
@@ -202,12 +642,13 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     }
     // Push the correct shapes of each axis into the output_dims
     for (auto a : axis) {
-      dnnl::memory::dim shape = 1;
       if (layout.find(a) != std::string::npos) {
-        shape *= input_dims[layout.find(a)];
+        dnnl::memory::dim shape = input_dims[layout.find(a)];
         char lower_a = std::tolower(a);
-        if (layout.find(lower_a) != std::string::npos) {
-          shape *= input_dims[layout.find(lower_a)];
+        for (size_t i = 0; i < layout.size(); ++i) {
+          if (lower_a == layout[i]) {
+            shape *= input_dims[i];
+          }
         }
         out_dims.push_back(shape);
       }
@@ -238,6 +679,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     return out_dims;
   }
 
+  // Build up the engine based on the input graph.
   void BuildEngine() {
     engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
     stream_ = dnnl::stream(engine_);
@@ -301,11 +743,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // has not yet been bound to the other DNNL memory; otherwise it may have memory leak.
     ICHECK_EQ(entry_out_mem_.count(eid), 0);
 
-    // TODO(@comanic): Support other data types (i.e., int8).
-    auto data_node = nodes_[entry.id_];
-    auto dltype = data_node.GetOpDataType()[entry.index_];
-    ICHECK_EQ(dltype.bits, 32);
-
     entry_out_mem_[eid] = {mem, offset};
     return entry_out_mem_[eid].first;
   }
@@ -338,17 +775,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::string data_layout = node.GetAttr<std::vector<std::string>>("data_layout")[0];
     std::string kernel_layout = node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
 
-    // Check layout.
-    if (layout_dict.find(data_layout) == layout_dict.end()) {
-      LOG(FATAL) << "Unsupported data layout for conv: " << data_layout;
-    }
-
-    if (layout_dict.find(kernel_layout) == layout_dict.end()) {
-      layout_dict.insert({kernel_layout, tag::any});
-      LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
-                   << ", transfer to tag::any";
-    }
-
     // Memory shapes.
     dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
     dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout);
@@ -360,6 +786,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims dst_dims = src_dims;
     dst_dims[1] = channels;
     weights_dims_[0] = channels;
+    weights_dims_[1] = src_dims[1];
     for (size_t i = 2; i < src_dims.size(); i++) {
       dnnl::memory::dim K = weights_dims_[i];
       dnnl::memory::dim S = strides_dims[i - 2];
@@ -380,10 +807,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     }
 
     // Memory descriptions.
-    auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]);
-    auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]);
-    auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
-    auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    auto conv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout));
+    auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout));
+    auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any);
+    auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
 
     // Conv description.
     auto conv_desc =
@@ -413,7 +841,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc());
 
     // Bias memory.
-    auto conv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
+    auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
     if (has_bias) {
       auto bias_entry = node.GetInputs()[2];
       BindDNNLMemory(bias_entry, conv_bias_memory);
@@ -461,17 +889,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::string data_layout = node.GetAttr<std::vector<std::string>>("data_layout")[0];
     std::string kernel_layout = node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
 
-    // Check layout.
-    if (layout_dict.find(data_layout) == layout_dict.end()) {
-      LOG(FATAL) << "Unsupported data layout for deconv: " << data_layout;
-    }
-
-    if (layout_dict.find(kernel_layout) == layout_dict.end()) {
-      layout_dict.insert({kernel_layout, tag::any});
-      LOG(WARNING) << "Unregistered kernel layout for deconv: " << data_layout
-                   << ", transfer to tag::any";
-    }
-
     // Memory shapes.
     dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
     dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout);
@@ -482,6 +899,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
         kernel_layout.replace(kernel_layout.find("OI"), 2, "IO");
       }
     }
+    weights_dims_[0] = channels;
+    weights_dims_[1] = src_dims[1];
     dnnl::memory::dims bias_dims = {channels};
     dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
     dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
@@ -508,10 +927,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     }
 
     // Memory descriptions.
-    auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]);
-    auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]);
-    auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
-    auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout));
+    auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout));
+    auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x);
+    auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
 
     // Transposed covn2d description.
     auto deconv_desc =
@@ -541,7 +961,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc());
 
     // Bias memory.
-    auto deconv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
+    auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
     if (has_bias) {
       auto bias_entry = node.GetInputs()[2];
       BindDNNLMemory(bias_entry, deconv_bias_memory);
@@ -581,10 +1001,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims out_dims = out_shape;
 
     // Memory descriptions.
-    auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc});
-    auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc});
-    auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x});
-    auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc});
+    auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
+    auto dtype = dtype_dl2dnnl(dl_dtype);
+    auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc});
+    auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc});
+    auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x});
+    auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc});
 
     // Dense description.
     auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
@@ -607,7 +1029,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       BindDNNLMemory(bias_entry, bias_memory);
     } else {
       float bias[OC] = {0};
-      write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+      write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8));
     }
 
     // Output memory.
@@ -632,7 +1054,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
 
     // Memory description.
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
 
     // BN description.
     auto bn_desc = dnnl::batch_normalization_forward::desc(
@@ -679,11 +1102,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
     std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
 
-    // Check layout.
-    if (layout_dict.find(layout) == layout_dict.end()) {
-      LOG(FATAL) << "Unsupported layout for pooling: " << layout;
-    }
-
     // Attributes related to AvgPool
     if (algo == dnnl::algorithm::pooling_avg) {
       int int_countpad = std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
@@ -701,8 +1119,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
 
     // Memory descriptions.
-    auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]);
-    auto pool_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout));
+    auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
 
     // Pooling description.
     auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo,
@@ -729,7 +1148,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
     auto data_entry = node.GetInputs()[0];
     dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
     float alpha = 0., beta = 0.;
     if (op_name == "clip") {
       alpha = std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
@@ -762,7 +1182,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (axis < 0) {
       axis = shape.size() + axis;
     }
-    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
+    auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+    dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
 
     auto softmax_desc =
         dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis);
@@ -790,7 +1211,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     ICHECK_EQ(node.GetInputs().size(), 2U);
     for (auto entry : node.GetInputs()) {
       auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
-      dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+      auto dtype = dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]);
+      dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
 
       data_dims.push_back(data_shape);
       data_mds.push_back(data_md);
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc b/src/runtime/contrib/dnnl/dnnl_utils.cc
new file mode 100644
index 0000000000..7e79f1c939
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_utils.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_utils.cc
+ */
+
+#include "dnnl_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+using dt = dnnl::memory::data_type;
+dt dtype_dl2dnnl(DLDataType dltype) {
+  // Widen the uint8_t fields so they compare and print as integers, not as characters.
+  const int code = dltype.code;
+  const int bits = dltype.bits;
+  dt dnnl_type = dt::undef;
+  switch (code) {
+    case DataType::TypeCode::kFloat:
+      dnnl_type = (bits == 16) ? dt::f16 : (bits == 32) ? dt::f32 : dt::undef;
+      break;
+    case DataType::TypeCode::kBFloat:
+      dnnl_type = (bits == 16) ? dt::bf16 : dt::undef;
+      break;
+    case DataType::TypeCode::kInt:
+      dnnl_type = (bits == 8) ? dt::s8 : (bits == 32) ? dt::s32 : dt::undef;
+      break;
+    case DataType::TypeCode::kUInt:
+      dnnl_type = (bits == 8) ? dt::u8 : dt::undef;
+      break;
+  }
+  if (dnnl_type == dt::undef) {
+    LOG_ERROR << "unsupported datatype: code=" << code << ", bits=" << bits;
+  }
+  return dnnl_type;
+}
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h
new file mode 100644
index 0000000000..4fb236f96f
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_utils.h
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_utils.h
+ * \brief utils for DNNL.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+
+#include <tvm/runtime/data_type.h>
+
+#include "dnnl.hpp"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+/*!
+ * \brief Convert a DLPack data type to the equivalent DNNL (oneDNN) memory data type.
+ * \param dltype The DLPack data type; only its code and bits fields are inspected.
+ * \return The matching DNNL data type, or data_type::undef (after logging an error) if none.
+ */
+dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index 5e3ba83ce0..f784f7b49a 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -205,10 +205,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule,
                          const Layout& src_layout, const Layout& dst_layout) {
-  if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
-      dst_layout.name().empty()) {
+  if (!src_layout.defined() || src_layout.name().empty()) {
+    LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid.";
     return false;
   }
+  if (!dst_layout.defined() || dst_layout.name().empty()) {
+    LOG(WARNING) << "dst layout '" << dst_layout.name() << "' is invalid.";
+    return false;
+  }
+
   for (size_t i = 0; i < dst_layout.ndim(); ++i) {
     const auto& store_axis = dst_layout[i];
     const IterVar& store_axis_impl = dst_layout->axes[i];
@@ -237,7 +242,8 @@ inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rul
       }
     }
     if (tir::is_zero(index_store)) {
-      // Not convertible
+      LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name()
+                   << "' is not convertible.";
       return false;
     }
 
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 5baf6e06d3..fecd776d70 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -37,6 +37,14 @@ run_module = tvm.testing.parameter(
     ids=["compile", "run"],
 )
 
+# bf16 tests need AVX512. /proc/cpuinfo only exists on Linux, so treat a missing or
+# unreadable file as "not supported" instead of failing when this module is imported.
+try:
+    with open("/proc/cpuinfo", "r") as _cpuinfo:
+        bf16_supported = "avx512" in _cpuinfo.read()
+except OSError:
+    bf16_supported = False
+
 
 def partition_for_dnnl(mod, params=None, alter_layout=True):
     """Partition the graph greedily offloading supported operators to DNNL.
@@ -109,7 +111,10 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
 
 def vmobj_to_list(o):
     if isinstance(o, tvm.nd.NDArray):
-        return [o.numpy()]
+        o_np = o.numpy()
+        if o_np.dtype == np.uint16:
+            o_np = np.left_shift(o_np.astype("uint32"), 16).view("<f4")
+        return [o_np]
     elif isinstance(o, tvm.runtime.container.ADT) or isinstance(o, list):
         return [vmobj_to_list(f) for f in o]
     else:
@@ -121,10 +126,13 @@ def assert_result_dict_holds(result_dict):
         res1 = vmobj_to_list(result_dict[k1])
         res2 = vmobj_to_list(result_dict[k2])
         for r1, r2 in zip(res1, res2):
-            tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+            if "bf16" in k1 or "bf16" in k2:
+                np.testing.assert_array_almost_equal(r1, r2, decimal=1)
+            else:
+                tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
 
 
-def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
+def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True):
     def check_dnnl_used(mod, subgraph_num=None):
         num_dnnl_subgraphs = sum(
             [1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
@@ -137,13 +145,30 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
     dev = tvm.cpu()
     result_dict = dict()
     for mode in ["graph", "vm"]:
-        for use_dnnl, alter_layout in [(False, False), (True, False), (True, True)]:
-            result_key = mode + ("_dnnl" if use_dnnl else "") + ("_layout" if alter_layout else "")
+        configs = [
+            (False, False, False),
+            (True, False, False),
+            (True, True, False),
+        ]
+        if test_bf16 and bf16_supported:
+            configs += [(True, False, True), (True, True, True)]
+        for use_dnnl, alter_layout, use_bf16 in configs:
+            result_key = (
+                mode
+                + ("_dnnl" if use_dnnl else "")
+                + ("_layout" if alter_layout else "")
+                + ("_bf16" if use_bf16 else "_fp32")
+            )
+            processed_mod = mod
+            if use_bf16:
+                processed_mod = relay.transform.ToMixedPrecision("bfloat16")(processed_mod)
+                if tvm.ir.structural_equal(processed_mod, mod):
+                    print("can not convert to bfloat16, skipping...")
+                    continue
             if use_dnnl:
-                processed_mod = partition_for_dnnl(mod, params, alter_layout)
-                check_dnnl_used(processed_mod, subgraph_num)
-            else:
-                processed_mod = mod
+                processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
+                check_dnnl_used(processed_mod)
+
             with tvm.transform.PassContext(opt_level=3):
                 func = relay.create_executor(
                     mode, mod=processed_mod, device=dev, target=target
@@ -158,7 +183,9 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
         assert_result_dict_holds(result_dict)
 
 
-def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dtype="float32"):
+def run_and_verify_func(
+    config, run_module, subgraph_num=None, target="llvm", dtype="float32", test_bf16=True
+):
     """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
     Parameters
     ----------
@@ -176,7 +203,13 @@ def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dt
         if k not in is_param
     }
     run_and_verify(
-        f, input_dict, params, subgraph_num=subgraph_num, target=target, run_module=run_module
+        f,
+        input_dict,
+        params,
+        subgraph_num=subgraph_num,
+        target=target,
+        run_module=run_module,
+        test_bf16=test_bf16,
     )
 
 
@@ -586,7 +619,6 @@ def test_elementwise(run_module, dtype="float32"):
         relay.exp,
         relay.log,
         relay.sqrt,
-        relay.round,
         relay.nn.relu,
         relay.tanh,
         relay.sigmoid,
@@ -956,14 +988,14 @@ def test_prune_dnnl_subgraph(run_module):
     """In this test, OP "add" should be offloaded from dnnl codegen."""
 
     def get_graph():
-        x1 = relay.var("x1", shape=(1, 64, 56, 56))
-        x2 = relay.var("x2", shape=(1, 64, 56, 56))
-        bias = relay.var("bias", shape=(64,))
-        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x1 = relay.var("x1", shape=(1, 32, 56, 56))
+        x2 = relay.var("x2", shape=(1, 32, 56, 56))
+        bias = relay.var("bias", shape=(32,))
+        weight = relay.var("weight", shape=(32, 32, 3, 3))
         y = relay.nn.conv2d(
             x1,
             weight,
-            channels=64,
+            channels=32,
             kernel_size=(3, 3),
             padding=(1, 1),
         )
@@ -972,16 +1004,16 @@ def test_prune_dnnl_subgraph(run_module):
         y = relay.nn.global_max_pool2d(y)
         y = relay.add(y, x2)
         dic = {
-            "x1": (1, 64, 56, 56),
-            "x2": (1, 64, 56, 56),
-            "weight": (64, 64, 3, 3),
-            "bias": (64,),
+            "x1": (1, 32, 56, 56),
+            "x2": (1, 32, 56, 56),
+            "weight": (32, 32, 3, 3),
+            "bias": (32,),
         }
         param_lst = ["weight", "bias"]
         out = tvm.IRModule.from_expr(y)
         return out, dic, param_lst
 
-    run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module)
+    run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
 
 
 if __name__ == "__main__":