Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/26 07:35:27 UTC
[tvm] branch main updated: [BYOC] Enable bfloat16 in DNNL BYOC (#11111)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8135860527 [BYOC] Enable bfloat16 in DNNL BYOC (#11111)
8135860527 is described below
commit 8135860527fa28853660f8f4747795b594b3f53f
Author: Youlei Yang <yo...@intel.com>
AuthorDate: Thu May 26 15:35:23 2022 +0800
[BYOC] Enable bfloat16 in DNNL BYOC (#11111)
* refine the code style (#10112)
* support more data types in oneDNN BYOC
* consider dtype when query layout
* support more translation of blocked layout
* refine log for invalid layout transform
* reset N and C for the weights
* support multi-blocking in TransDims2Plain()
* add tests for bf16 oneDNN BYOC
* unregister 'round' OP in oneDNN BYOC
* restore the criteria for fp32 tests
* disable test_prune_dnnl_subgraph for bf16
* fix typo in dnnl.py
* delete tag::format_tag_last
* delete 'is_weight' in layout2tag()
* reuse dtype_dl2dnnl()
* fix lint errors
* change to WARNING for invalid layout transform
* skip bf16 tests if AVX512 is unavailable
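For context, the bf16 path added here is exercised the way the new tests do it: cast the Relay module to bfloat16 with ToMixedPrecision, then partition it for DNNL. A minimal sketch of that flow (the pass sequence approximates the partition_for_dnnl helper defined in tests/python/contrib/test_dnnl.py and is not copied from this patch; shapes are illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import pattern_table

    # Toy conv2d workload.
    data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
    weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float32")
    out = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
    mod = tvm.IRModule.from_expr(out)
    mod = relay.transform.InferType()(mod)

    # Cast the graph to bf16, then offload supported ops to the oneDNN BYOC runtime.
    mod = relay.transform.ToMixedPrecision("bfloat16")(mod)
    seq = tvm.transform.Sequential(
        [
            relay.transform.MergeComposite(pattern_table()),
            relay.transform.AnnotateTarget("dnnl"),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(),
        ]
    )
    mod = seq(mod)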
---
cmake/modules/contrib/DNNL.cmake | 4 +-
include/tvm/tir/op.h | 24 +-
python/tvm/relay/op/contrib/dnnl.py | 25 +-
src/relay/backend/contrib/dnnl/query_layout.cc | 33 +-
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 656 ++++++++++++++++++++-----
src/runtime/contrib/dnnl/dnnl_utils.cc | 56 +++
src/runtime/contrib/dnnl/dnnl_utils.h | 46 ++
src/tir/ir/data_layout.cc | 12 +-
tests/python/contrib/test_dnnl.py | 76 ++-
9 files changed, 758 insertions(+), 174 deletions(-)
diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake
index 9e36f39891..6642719cb4 100644
--- a/cmake/modules/contrib/DNNL.cmake
+++ b/cmake/modules/contrib/DNNL.cmake
@@ -19,11 +19,11 @@ if((USE_DNNL_CODEGEN STREQUAL "ON") OR (USE_DNNL_CODEGEN STREQUAL "JSON"))
add_definitions(-DUSE_JSON_RUNTIME=1)
tvm_file_glob(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/*.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})
- list(APPEND COMPILER_SRCS ${JSON_RELAY_CONTRIB_SRC})
find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
- tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc)
+ tvm_file_glob(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+ src/runtime/contrib/dnnl/dnnl_utils.cc)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL JSON runtime: " ${EXTERN_LIBRARY_DNNL})
elseif(USE_DNNL_CODEGEN STREQUAL "C_SRC")
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 5b63016d2f..905c67f1c5 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
Span span = Span());
// Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName) \
- inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
- static const Op& op = Op::Get("tir." #OpName); \
- if (x.dtype().is_bfloat16()) { \
- DataType srcType = x.dtype(); \
- DataType dstType(kDLFloat, 32, srcType.lanes()); \
- PrimExpr castX = tir::Cast(dstType, {x}, span); \
- PrimExpr result = tir::Call(dstType, op, {castX}, span); \
- return tir::Cast(srcType, {result}, span); \
- } else { \
- return tir::Call(x.dtype(), op, {x}, span); \
- } \
+#define TVM_DECLARE_INTRIN_UNARY(OpName) \
+ inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
+ static const Op& op = Op::Get("tir." #OpName); \
+ if (x.dtype().is_bfloat16()) { \
+ DataType bf16_dtype = x.dtype(); \
+ DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
+ PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
+ PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
+ return tir::Cast(bf16_dtype, {result_fp32}, span); \
+ } else { \
+ return tir::Call(x.dtype(), op, {x}, span); \
+ } \
}
TVM_DECLARE_INTRIN_UNARY(exp);
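With this change the bf16 unary intrinsics are computed by upcasting to fp32, calling the fp32 intrinsic, and casting back. The same round trip shows up in the new tests when bf16 outputs are compared against fp32 references; a small NumPy sketch of the pattern (illustrative only, not TVM API):

    import numpy as np

    def bf16_to_fp32(x):
        # A bf16 value is the high 16 bits of the equivalent fp32 bit pattern.
        return np.left_shift(x.astype("uint32"), 16).view("<f4")

    def fp32_to_bf16(x):
        # Truncate the mantissa: keep only the high 16 bits of the fp32 pattern.
        return np.right_shift(x.astype("float32").view("uint32"), 16).astype("uint16")

    x_bf16 = fp32_to_bf16(np.array([0.5, 1.5, 2.0], dtype="float32"))
    y_bf16 = fp32_to_bf16(np.exp(bf16_to_fp32(x_bf16)))  # compute in fp32, store as bf16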
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 72e004b868..2e975cf49c 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -85,7 +85,6 @@ _register_external_op_helper("clip")
_register_external_op_helper("exp")
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
-_register_external_op_helper("round")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
@@ -212,7 +211,7 @@ def pattern_table():
def get_optimal_layout_for_conv(
- data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+ data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups, dtype
):
"""Get the optimal layout of dnnl, given shape of conv2d.
@@ -236,6 +235,7 @@ def get_optimal_layout_for_conv(
strides,
dilates,
groups,
+ dtype,
)
@@ -249,6 +249,7 @@ def get_optimal_layout_for_conv_transpose(
strides,
dilates,
groups,
+ dtype,
):
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
@@ -274,6 +275,7 @@ def get_optimal_layout_for_conv_transpose(
strides,
dilates,
groups,
+ dtype,
)
@@ -292,6 +294,21 @@ def get_shape(tensor):
raise TypeError("Unsupport data type: %s" % type(tensor))
+def get_dtype(tensor):
+ """Get tensor's dtype."""
+ if isinstance(tensor, relay.expr.Var):
+ return tensor.type_annotation.dtype
+ if isinstance(tensor, relay.expr.Constant):
+ return tensor.data.dtype
+ if isinstance(tensor, tvm.ir.tensor_type.TensorType):
+ return tensor.dtype
+ if isinstance(tensor, tvm.ir.container.Array):
+ return tensor[-1].dtype
+ if isinstance(tensor, relay.expr.Call):
+ return tensor.checked_type.dtype
+ raise TypeError("Unsupport data type: %s" % type(tensor))
+
+
def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
"""Transfer layout, denoted with `a, b, c, d, e`,
into valid layout (NCHW / OIHW) of TVM."""
@@ -353,6 +370,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
+ dtype = get_dtype(weight)
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]
@@ -365,6 +383,7 @@ def alter_conv(attrs, inputs, tinfos, out_type):
strides,
dilates,
groups,
+ dtype,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
@@ -389,6 +408,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
groups = str(attrs.groups)
+ dtype = get_dtype(weight)
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]
@@ -402,6 +422,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
strides,
dilates,
groups,
+ dtype,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc
index 7fb1d824c7..3762c1906f 100755
--- a/src/relay/backend/contrib/dnnl/query_layout.cc
+++ b/src/relay/backend/contrib/dnnl/query_layout.cc
@@ -34,16 +34,17 @@
#include <regex>
#include <sstream>
+#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
#include "../../utils.h"
#include "dnnl.hpp"
-
-using dim_t = dnnl_dim_t;
-using dims_t = dnnl_dims_t;
-
namespace tvm {
namespace relay {
namespace contrib {
+using dim_t = dnnl_dim_t;
+using dims_t = dnnl_dims_t;
+using tvm::runtime::contrib::dtype_dl2dnnl;
+
template <typename T, typename U>
inline void array_set(T* arr, const U& val, size_t size) {
for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
@@ -192,7 +193,7 @@ void check_layout(bool var, bool ref) {
std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
- std::string dilates, std::string G) {
+ std::string dilates, std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
@@ -200,7 +201,6 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
dnnl::engine eng(dnnl::engine::kind::cpu, 0);
dnnl::stream s(eng);
using tag = dnnl::memory::format_tag;
- using dt = dnnl::memory::data_type;
dnnl::memory::dim groups = std::stoi(G);
dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -249,9 +249,10 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
dnnl::memory::dims conv_padding_l = padding_dims_l;
dnnl::memory::dims conv_padding_r = padding_dims_r;
- auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dt::f32, tag::any);
- auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dt::f32, tag::any);
- auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dt::f32, tag::any);
+ auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
+ auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dnnl_dtype, tag::any);
+ auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dnnl_dtype, tag::any);
+ auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dnnl_dtype, tag::any);
auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
@@ -276,7 +277,7 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
- std::string G) {
+ std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});
@@ -284,7 +285,6 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
dnnl::engine eng(dnnl::engine::kind::cpu, 0);
dnnl::stream s(eng);
using tag = dnnl::memory::format_tag;
- using dt = dnnl::memory::data_type;
dnnl::memory::dim groups = std::stoi(G);
dnnl::memory::dims weight_dims_ = str2dims(weight_shape);
@@ -338,9 +338,10 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
dnnl::memory::dims deconv_padding_l = padding_dims_l;
dnnl::memory::dims deconv_padding_r = padding_dims_r;
- auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dt::f32, tag::any);
- auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dt::f32, tag::any);
- auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dt::f32, tag::any);
+ auto dnnl_dtype = dtype_dl2dnnl(tvm::runtime::String2DLDataType(dtype));
+ auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dnnl_dtype, tag::any);
+ auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dnnl_dtype, tag::any);
+ auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dnnl_dtype, tag::any);
auto deconv_desc = dnnl::deconvolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, deconv_src_md,
@@ -364,13 +365,13 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5],
- args[6], args[7]);
+ args[6], args[7], args[8]);
});
TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4],
- args[5], args[6], args[7], args[8]);
+ args[5], args[6], args[7], args[8], args[9]);
});
} // namespace contrib
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index dc2afecbaf..f6a1c3b790 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -33,6 +33,7 @@
#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "dnnl.hpp"
+#include "dnnl_utils.h"
namespace tvm {
namespace runtime {
@@ -66,8 +67,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Fill in the input buffers.
for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto eid = EntryID(input_nodes_[i], 0);
- // TODO(@comaniac): Support other data lengths.
- size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+ size_t offset_in_bytes =
+ entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
size_t buffer_size = GetDataSize(*data_entry_[eid]);
write_to_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
offset_in_bytes);
@@ -82,7 +83,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Read output buffers.
for (size_t i = 0; i < outputs_.size(); ++i) {
auto eid = EntryID(outputs_[i]);
- size_t offset_in_bytes = entry_out_mem_[eid].second * 4;
+ size_t offset_in_bytes =
+ entry_out_mem_[eid].second * ((data_entry_[eid]->dtype.bits + 7) / 8);
size_t buffer_size = GetDataSize(*data_entry_[eid]);
read_from_dnnl_memory(data_entry_[eid]->data, entry_out_mem_[eid].first, buffer_size,
offset_in_bytes);
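The hard-coded 4-byte element size is gone: offsets into DNNL memory are now scaled by the element size derived from dtype.bits, rounded up to whole bytes. The arithmetic, written out as a sketch:

    def elem_bytes(bits):
        # Mirrors (dtype.bits + 7) / 8 in the runtime.
        return (bits + 7) // 8

    assert elem_bytes(16) == 2  # bf16 / fp16
    assert elem_bytes(32) == 4  # fp32 / int32
    assert elem_bytes(8) == 1   # int8 / uint8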
@@ -90,7 +92,501 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}
private:
- // Build up the engine based on the input graph.
+ tag layout2tag(std::string layout) {
+ static const std::map<std::string, tag> str2tag = {{"nc", tag::nc},
+ {"cn", tag::cn},
+ {"tn", tag::tn},
+ {"nt", tag::nt},
+ {"ncw", tag::ncw},
+ {"nwc", tag::nwc},
+ {"nchw", tag::nchw},
+ {"nhwc", tag::nhwc},
+ {"chwn", tag::chwn},
+ {"ncdhw", tag::ncdhw},
+ {"ndhwc", tag::ndhwc},
+ {"oi", tag::oi},
+ {"io", tag::io},
+ {"oiw", tag::oiw},
+ {"owi", tag::owi},
+ {"wio", tag::wio},
+ {"iwo", tag::iwo},
+ {"oihw", tag::oihw},
+ {"hwio", tag::hwio},
+ {"ohwi", tag::ohwi},
+ {"ihwo", tag::ihwo},
+ {"iohw", tag::iohw},
+ {"oidhw", tag::oidhw},
+ {"dhwio", tag::dhwio},
+ {"odhwi", tag::odhwi},
+ {"iodhw", tag::iodhw},
+ {"idhwo", tag::idhwo},
+ {"goiw", tag::goiw},
+ {"gowi", tag::gowi},
+ {"wigo", tag::wigo},
+ {"gohwi", tag::gohwi},
+ {"goihw", tag::goihw},
+ {"hwigo", tag::hwigo},
+ {"giohw", tag::giohw},
+ {"goidhw", tag::goidhw},
+ {"giodhw", tag::giodhw},
+ {"godhwi", tag::godhwi},
+ {"dhwigo", tag::dhwigo},
+ {"tnc", tag::tnc},
+ {"ntc", tag::ntc},
+ {"ldnc", tag::ldnc},
+ {"ldigo", tag::ldigo},
+ {"ldgoi", tag::ldgoi},
+ {"ldio", tag::ldio},
+ {"ldoi", tag::ldoi},
+ {"ldgo", tag::ldgo},
+ {"nCdhw16c", tag::nCdhw16c},
+ {"nCdhw4c", tag::nCdhw4c},
+ {"nCdhw8c", tag::nCdhw8c},
+ {"nChw16c", tag::nChw16c},
+ {"nChw4c", tag::nChw4c},
+ {"nChw8c", tag::nChw8c},
+ {"nCw16c", tag::nCw16c},
+ {"nCw4c", tag::nCw4c},
+ {"nCw8c", tag::nCw8c},
+ {"NCw16n16c", tag::NCw16n16c},
+ {"NChw16n16c", tag::NChw16n16c},
+ {"NCdhw16n16c", tag::NCdhw16n16c},
+ {"NCdhw32n32c", tag::NCdhw32n32c},
+ {"NChw32n32c", tag::NChw32n32c},
+ {"IOhw16i16o", tag::IOhw16i16o},
+ {"OI16i16o", tag::OI16i16o},
+ {"OI16i32o", tag::OI16i32o},
+ {"OI16i64o", tag::OI16i64o},
+ {"OI8i16o2i", tag::OI8i16o2i},
+ {"OI8i32o2i", tag::OI8i32o2i},
+ {"OI8i64o2i", tag::OI8i64o2i},
+ {"OI4i16o4i", tag::OI4i16o4i},
+ {"OI4i32o4i", tag::OI4i32o4i},
+ {"OI4i64o4i", tag::OI4i64o4i},
+ {"Ohwi32o", tag::Ohwi32o},
+ {"IOdhw16i16o", tag::IOdhw16i16o},
+ {"gIOhw16i16o", tag::gIOhw16i16o},
+ {"gOhwi32o", tag::gOhwi32o},
+ {"Goidhw16g", tag::Goidhw16g},
+ {"IOw16o16i", tag::IOw16o16i},
+ {"OIw16i16o", tag::OIw16i16o},
+ {"OIw16i32o", tag::OIw16i32o},
+ {"OIw16i64o", tag::OIw16i64o},
+ {"IOw16i16o", tag::IOw16i16o},
+ {"gIOw16i16o", tag::gIOw16i16o},
+ {"OIw16o16i", tag::OIw16o16i},
+ {"Oiw16o", tag::Oiw16o},
+ {"OIw4i16o4i", tag::OIw4i16o4i},
+ {"OIw4i32o4i", tag::OIw4i32o4i},
+ {"OIw4i64o4i", tag::OIw4i64o4i},
+ {"OIw2i8o4i", tag::OIw2i8o4i},
+ {"OIw4i4o", tag::OIw4i4o},
+ {"OIw4o4i", tag::OIw4o4i},
+ {"Oiw4o", tag::Oiw4o},
+ {"OIw8i16o2i", tag::OIw8i16o2i},
+ {"OIw8i32o2i", tag::OIw8i32o2i},
+ {"OIw8i64o2i", tag::OIw8i64o2i},
+ {"OIw8i8o", tag::OIw8i8o},
+ {"OIw8o16i2o", tag::OIw8o16i2o},
+ {"OIw8o8i", tag::OIw8o8i},
+ {"OIw8o4i", tag::OIw8o4i},
+ {"OIw16i16o4i", tag::OIw16i16o4i},
+ {"OIw16i32o4i", tag::OIw16i32o4i},
+ {"OIw16i48o4i", tag::OIw16i48o4i},
+ {"OIw16i64o4i", tag::OIw16i64o4i},
+ {"OIw16i16o2i", tag::OIw16i16o2i},
+ {"OIw16i32o2i", tag::OIw16i32o2i},
+ {"OIw16i48o2i", tag::OIw16i48o2i},
+ {"OIw16i64o2i", tag::OIw16i64o2i},
+ {"OIw16o16i2o", tag::OIw16o16i2o},
+ {"Owi16o", tag::Owi16o},
+ {"OwI16o2i", tag::OwI16o2i},
+ {"Owi4o", tag::Owi4o},
+ {"Owi8o", tag::Owi8o},
+ {"IOhw16o16i", tag::IOhw16o16i},
+ {"Ohwi16o", tag::Ohwi16o},
+ {"OhwI16o2i", tag::OhwI16o2i},
+ {"Ohwi4o", tag::Ohwi4o},
+ {"Ohwi8o", tag::Ohwi8o},
+ {"OIhw16i16o", tag::OIhw16i16o},
+ {"OIhw16i32o", tag::OIhw16i32o},
+ {"OIhw16i64o", tag::OIhw16i64o},
+ {"OIhw16o16i", tag::OIhw16o16i},
+ {"Oihw16o", tag::Oihw16o},
+ {"OIhw4i16o4i", tag::OIhw4i16o4i},
+ {"OIhw4i32o4i", tag::OIhw4i32o4i},
+ {"OIhw4i64o4i", tag::OIhw4i64o4i},
+ {"OIhw4i4o", tag::OIhw4i4o},
+ {"OIhw4o4i", tag::OIhw4o4i},
+ {"Oihw4o", tag::Oihw4o},
+ {"OIhw8i16o2i", tag::OIhw8i16o2i},
+ {"OIhw8i32o2i", tag::OIhw8i32o2i},
+ {"OIhw8i64o2i", tag::OIhw8i64o2i},
+ {"OIhw8i8o", tag::OIhw8i8o},
+ {"OIhw8o16i2o", tag::OIhw8o16i2o},
+ {"OIhw8o8i", tag::OIhw8o8i},
+ {"OIhw8o4i", tag::OIhw8o4i},
+ {"OIhw2i8o4i", tag::OIhw2i8o4i},
+ {"IOdhw16o16i", tag::IOdhw16o16i},
+ {"Odhwi16o", tag::Odhwi16o},
+ {"OdhwI16o2i", tag::OdhwI16o2i},
+ {"Odhwi4o", tag::Odhwi4o},
+ {"Odhwi8o", tag::Odhwi8o},
+ {"OIdhw16i16o", tag::OIdhw16i16o},
+ {"OIdhw16i32o", tag::OIdhw16i32o},
+ {"OIdhw16i64o", tag::OIdhw16i64o},
+ {"OIdhw16o16i", tag::OIdhw16o16i},
+ {"Oidhw16o", tag::Oidhw16o},
+ {"OIdhw4i4o", tag::OIdhw4i4o},
+ {"OIdhw4o4i", tag::OIdhw4o4i},
+ {"Oidhw4o", tag::Oidhw4o},
+ {"OIdhw8i16o2i", tag::OIdhw8i16o2i},
+ {"OIdhw8i32o2i", tag::OIdhw8i32o2i},
+ {"OIdhw8i64o2i", tag::OIdhw8i64o2i},
+ {"OIdhw4i16o4i", tag::OIdhw4i16o4i},
+ {"OIdhw16i16o4i", tag::OIdhw16i16o4i},
+ {"OIdhw16i32o4i", tag::OIdhw16i32o4i},
+ {"OIdhw16i48o4i", tag::OIdhw16i48o4i},
+ {"OIdhw16i64o4i", tag::OIdhw16i64o4i},
+ {"OIdhw16i16o2i", tag::OIdhw16i16o2i},
+ {"OIdhw16i32o2i", tag::OIdhw16i32o2i},
+ {"OIdhw16i48o2i", tag::OIdhw16i48o2i},
+ {"OIdhw16i64o2i", tag::OIdhw16i64o2i},
+ {"OIdhw4i32o4i", tag::OIdhw4i32o4i},
+ {"OIdhw4i64o4i", tag::OIdhw4i64o4i},
+ {"OIdhw2i8o4i", tag::OIdhw2i8o4i},
+ {"OIdhw8i8o", tag::OIdhw8i8o},
+ {"OIdhw8o8i", tag::OIdhw8o8i},
+ {"OIdhw8o4i", tag::OIdhw8o4i},
+ {"gIOw16o16i", tag::gIOw16o16i},
+ {"gOIw16i16o", tag::gOIw16i16o},
+ {"gOIw16o16i", tag::gOIw16o16i},
+ {"gOiw16o", tag::gOiw16o},
+ {"gOIw4i16o4i", tag::gOIw4i16o4i},
+ {"gOIw2i8o4i", tag::gOIw2i8o4i},
+ {"gOIw4i4o", tag::gOIw4i4o},
+ {"gOIw4o4i", tag::gOIw4o4i},
+ {"gOiw4o", tag::gOiw4o},
+ {"gOIw8i16o2i", tag::gOIw8i16o2i},
+ {"gOIw8i8o", tag::gOIw8i8o},
+ {"gOIw8o16i2o", tag::gOIw8o16i2o},
+ {"gOIw8o8i", tag::gOIw8o8i},
+ {"gOIw8o4i", tag::gOIw8o4i},
+ {"gOIw16i16o4i", tag::gOIw16i16o4i},
+ {"gOIw16i16o2i", tag::gOIw16i16o2i},
+ {"gOIw16o16i2o", tag::gOIw16o16i2o},
+ {"gOwi16o", tag::gOwi16o},
+ {"gOwI16o2i", tag::gOwI16o2i},
+ {"gOwi4o", tag::gOwi4o},
+ {"gOwi8o", tag::gOwi8o},
+ {"Goiw8g", tag::Goiw8g},
+ {"Goiw16g", tag::Goiw16g},
+ {"gIOhw16o16i", tag::gIOhw16o16i},
+ {"gOhwi16o", tag::gOhwi16o},
+ {"gOhwI16o2i", tag::gOhwI16o2i},
+ {"gOhwi4o", tag::gOhwi4o},
+ {"gOhwi8o", tag::gOhwi8o},
+ {"Goihw16g", tag::Goihw16g},
+ {"gOIhw16i16o", tag::gOIhw16i16o},
+ {"gOIhw16o16i", tag::gOIhw16o16i},
+ {"gOihw16o", tag::gOihw16o},
+ {"gOIhw4i16o4i", tag::gOIhw4i16o4i},
+ {"gOIhw2i8o4i", tag::gOIhw2i8o4i},
+ {"gOIhw4i4o", tag::gOIhw4i4o},
+ {"gOIhw4o4i", tag::gOIhw4o4i},
+ {"gOihw4o", tag::gOihw4o},
+ {"Goihw8g", tag::Goihw8g},
+ {"gOIhw8i16o2i", tag::gOIhw8i16o2i},
+ {"gOIhw8i8o", tag::gOIhw8i8o},
+ {"gOIhw8o16i2o", tag::gOIhw8o16i2o},
+ {"OIw4o8i8o4i", tag::OIw4o8i8o4i},
+ {"OIdhw4o8i8o4i", tag::OIdhw4o8i8o4i},
+ {"OIhw4o8i8o4i", tag::OIhw4o8i8o4i},
+ {"OIhw2o8i8o2i", tag::OIhw2o8i8o2i},
+ {"gOIw4o8i8o4i", tag::gOIw4o8i8o4i},
+ {"gOIdhw4o8i8o4i", tag::gOIdhw4o8i8o4i},
+ {"gOIhw4o8i8o4i", tag::gOIhw4o8i8o4i},
+ {"gOIhw2o8i8o2i", tag::gOIhw2o8i8o2i},
+ {"OIhw16i16o4i", tag::OIhw16i16o4i},
+ {"OIhw16i32o4i", tag::OIhw16i32o4i},
+ {"OIhw16i48o4i", tag::OIhw16i48o4i},
+ {"OIhw16i64o4i", tag::OIhw16i64o4i},
+ {"OIhw16i16o2i", tag::OIhw16i16o2i},
+ {"OIhw16i32o2i", tag::OIhw16i32o2i},
+ {"OIhw16i48o2i", tag::OIhw16i48o2i},
+ {"OIhw16i64o2i", tag::OIhw16i64o2i},
+ {"OIhw16o16i2o", tag::OIhw16o16i2o},
+ {"gOIhw16i16o4i", tag::gOIhw16i16o4i},
+ {"gOIhw16i16o2i", tag::gOIhw16i16o2i},
+ {"gOIhw16o16i2o", tag::gOIhw16o16i2o},
+ {"gOIhw8o8i", tag::gOIhw8o8i},
+ {"gOIhw8o4i", tag::gOIhw8o4i},
+ {"gIOdhw16i16o", tag::gIOdhw16i16o},
+ {"gIOdhw16o16i", tag::gIOdhw16o16i},
+ {"gOdhwi16o", tag::gOdhwi16o},
+ {"gOdhwI16o2i", tag::gOdhwI16o2i},
+ {"gOdhwi4o", tag::gOdhwi4o},
+ {"gOdhwi8o", tag::gOdhwi8o},
+ {"gOIdhw16i16o", tag::gOIdhw16i16o},
+ {"gOIdhw16o16i", tag::gOIdhw16o16i},
+ {"gOidhw16o", tag::gOidhw16o},
+ {"gOIdhw4i4o", tag::gOIdhw4i4o},
+ {"gOIdhw4o4i", tag::gOIdhw4o4i},
+ {"gOidhw4o", tag::gOidhw4o},
+ {"gOIdhw8i16o2i", tag::gOIdhw8i16o2i},
+ {"gOIdhw4i16o4i", tag::gOIdhw4i16o4i},
+ {"gOIdhw16i16o4i", tag::gOIdhw16i16o4i},
+ {"gOIdhw16i16o2i", tag::gOIdhw16i16o2i},
+ {"gOIdhw2i8o4i", tag::gOIdhw2i8o4i},
+ {"gOIdhw8i8o", tag::gOIdhw8i8o},
+ {"gOIdhw8o8i", tag::gOIdhw8o8i},
+ {"gOIdhw8o4i", tag::gOIdhw8o4i},
+ {"gOIw2i4o2i", tag::gOIw2i4o2i},
+ {"gOIhw2i4o2i", tag::gOIhw2i4o2i},
+ {"gOIdhw2i4o2i", tag::gOIdhw2i4o2i},
+ {"gOIw2o4i2o", tag::gOIw2o4i2o},
+ {"gOIhw2o4i2o", tag::gOIhw2o4i2o},
+ {"gOIdhw2o4i2o", tag::gOIdhw2o4i2o},
+ {"gOIw4i8o2i", tag::gOIw4i8o2i},
+ {"gOIhw4i8o2i", tag::gOIhw4i8o2i},
+ {"gOIdhw4i8o2i", tag::gOIdhw4i8o2i},
+ {"gOIw4o8i2o", tag::gOIw4o8i2o},
+ {"gOIhw4o8i2o", tag::gOIhw4o8i2o},
+ {"gOIdhw4o8i2o", tag::gOIdhw4o8i2o},
+ {"ldOi32o", tag::ldOi32o},
+ {"ldOI32o4i", tag::ldOI32o4i},
+ {"ldgOi32o", tag::ldgOi32o},
+ {"ldgOI32o2i", tag::ldgOI32o2i},
+ {"ldgOI32o4i", tag::ldgOI32o4i},
+ {"OwI16o4i", tag::OwI16o4i},
+ {"OhwI16o4i", tag::OhwI16o4i},
+ {"gOwI16o4i", tag::gOwI16o4i},
+ {"gOhwI16o4i", tag::gOhwI16o4i},
+ {"OdhwI16o4i", tag::OdhwI16o4i},
+ {"gOdhwI16o4i", tag::gOdhwI16o4i},
+ {"Owi32o", tag::Owi32o},
+ {"OwI32o2i", tag::OwI32o2i},
+ {"OwI32o4i", tag::OwI32o4i},
+ {"Owi48o", tag::Owi48o},
+ {"OwI48o2i", tag::OwI48o2i},
+ {"OwI48o4i", tag::OwI48o4i},
+ {"Owi64o", tag::Owi64o},
+ {"OwI64o2i", tag::OwI64o2i},
+ {"OwI64o4i", tag::OwI64o4i},
+ {"wIo2i", tag::wIo2i},
+ {"wIo4i", tag::wIo4i},
+ {"gOwi32o", tag::gOwi32o},
+ {"gOwI32o2i", tag::gOwI32o2i},
+ {"gOwI32o4i", tag::gOwI32o4i},
+ {"gOwi48o", tag::gOwi48o},
+ {"gOwI48o2i", tag::gOwI48o2i},
+ {"gOwI48o4i", tag::gOwI48o4i},
+ {"gOwi64o", tag::gOwi64o},
+ {"gOwI64o2i", tag::gOwI64o2i},
+ {"gOwI64o4i", tag::gOwI64o4i},
+ {"gwio", tag::gwio},
+ {"gwIo2i", tag::gwIo2i},
+ {"gwIo4i", tag::gwIo4i},
+ {"OhwI32o", tag::OhwI32o},
+ {"OhwI32o2i", tag::OhwI32o2i},
+ {"OhwI32o4i", tag::OhwI32o4i},
+ {"Ohwi48o", tag::Ohwi48o},
+ {"OhwI48o2i", tag::OhwI48o2i},
+ {"OhwI48o4i", tag::OhwI48o4i},
+ {"Ohwi64o", tag::Ohwi64o},
+ {"OhwI64o2i", tag::OhwI64o2i},
+ {"OhwI64o4i", tag::OhwI64o4i},
+ {"hwIo2i", tag::hwIo2i},
+ {"hwIo4i", tag::hwIo4i},
+ {"gOhwI32o", tag::gOhwI32o},
+ {"gOhwI32o2i", tag::gOhwI32o2i},
+ {"gOhwI32o4i", tag::gOhwI32o4i},
+ {"gOhwi48o", tag::gOhwi48o},
+ {"gOhwI48o2i", tag::gOhwI48o2i},
+ {"gOhwI48o4i", tag::gOhwI48o4i},
+ {"gOhwi64o", tag::gOhwi64o},
+ {"gOhwI64o2i", tag::gOhwI64o2i},
+ {"gOhwI64o4i", tag::gOhwI64o4i},
+ {"ghwio", tag::ghwio},
+ {"ghwIo2i", tag::ghwIo2i},
+ {"ghwIo4i", tag::ghwIo4i},
+ {"Odhwi32o", tag::Odhwi32o},
+ {"OdhwI32o2i", tag::OdhwI32o2i},
+ {"OdhwI32o4i", tag::OdhwI32o4i},
+ {"Odhwi48o", tag::Odhwi48o},
+ {"OdhwI48o2i", tag::OdhwI48o2i},
+ {"OdhwI48o4i", tag::OdhwI48o4i},
+ {"Odhwi64o", tag::Odhwi64o},
+ {"OdhwI64o2i", tag::OdhwI64o2i},
+ {"OdhwI64o4i", tag::OdhwI64o4i},
+ {"dhwIo2i", tag::dhwIo2i},
+ {"dhwIo4i", tag::dhwIo4i},
+ {"gOdhwi32o", tag::gOdhwi32o},
+ {"gOdhwI32o2i", tag::gOdhwI32o2i},
+ {"gOdhwI32o4i", tag::gOdhwI32o4i},
+ {"gOdhwi48o", tag::gOdhwi48o},
+ {"gOdhwI48o2i", tag::gOdhwI48o2i},
+ {"gOdhwI48o4i", tag::gOdhwI48o4i},
+ {"gOdhwi64o", tag::gOdhwi64o},
+ {"gOdhwI64o2i", tag::gOdhwI64o2i},
+ {"gOdhwI64o4i", tag::gOdhwI64o4i},
+ {"gdhwio", tag::gdhwio},
+ {"gdhwIo2i", tag::gdhwIo2i},
+ {"gdhwIo4i", tag::gdhwIo4i},
+ {"ldIo32i", tag::ldIo32i},
+ {"ldgIo32i", tag::ldgIo32i},
+ {"ldgIO32i2o", tag::ldgIO32i2o},
+ {"nCdhw32c", tag::nCdhw32c},
+ {"nChw32c", tag::nChw32c},
+ {"nCw32c", tag::nCw32c},
+ {"NCw32n16c", tag::NCw32n16c},
+ {"NChw32n16c", tag::NChw32n16c},
+ {"NCdhw32n16c", tag::NCdhw32n16c},
+ {"NCw32n32c", tag::NCw32n32c},
+ {"OI16i16o4i", tag::OI16i16o4i},
+ {"IOw8o16i2o", tag::IOw8o16i2o},
+ {"IOhw8o16i2o", tag::IOhw8o16i2o},
+ {"Owhi16o", tag::Owhi16o},
+ {"OIdhw8o16i2o", tag::OIdhw8o16i2o},
+ {"IOdhw8o16i2o", tag::IOdhw8o16i2o},
+ {"Goiw4g", tag::Goiw4g},
+ {"gIOw8o16i2o", tag::gIOw8o16i2o},
+ {"Goiw32g", tag::Goiw32g},
+ {"Goihw4g", tag::Goihw4g},
+ {"gIOhw8o16i2o", tag::gIOhw8o16i2o},
+ {"Goihw32g", tag::Goihw32g},
+ {"gOwhi16o", tag::gOwhi16o},
+ {"IOw4i8o8i4o", tag::IOw4i8o8i4o},
+ {"IOhw4i8o8i4o", tag::IOhw4i8o8i4o},
+ {"IOdhw4i8o8i4o", tag::IOdhw4i8o8i4o},
+ {"gIOw4i8o8i4o", tag::gIOw4i8o8i4o},
+ {"gIOhw4i8o8i4o", tag::gIOhw4i8o8i4o},
+ {"gIOdhw4i8o8i4o", tag::gIOdhw4i8o8i4o},
+ {"gOIdhw8o16i2o", tag::gOIdhw8o16i2o},
+ {"gIOdhw8o16i2o", tag::gIOdhw8o16i2o},
+ {"Goidhw32g", tag::Goidhw32g},
+ {"OI16i32o4i", tag::OI16i32o4i},
+ {"OI16i48o4i", tag::OI16i48o4i},
+ {"OI16i64o4i", tag::OI16i64o4i},
+ {"OI16i16o2i", tag::OI16i16o2i},
+ {"OI16i32o2i", tag::OI16i32o2i},
+ {"OI16i48o2i", tag::OI16i48o2i},
+ {"OI16i64o2i", tag::OI16i64o2i},
+ {"OwI16i16o2i", tag::OwI16i16o2i},
+ {"gOwI16i16o2i", tag::gOwI16i16o2i},
+ {"OhwI16i16o2i", tag::OhwI16i16o2i},
+ {"gOhwI16i16o2i", tag::gOhwI16i16o2i},
+ {"OdhwI16i16o2i", tag::OdhwI16i16o2i},
+ {"gOdhwI16i16o2i", tag::gOdhwI16i16o2i},
+ {"OwI16i16o4i", tag::OwI16i16o4i},
+ {"gOwI16i16o4i", tag::gOwI16i16o4i},
+ {"OhwI16i16o4i", tag::OhwI16i16o4i},
+ {"gOhwI16i16o4i", tag::gOhwI16i16o4i},
+ {"OdhwI16i16o4i", tag::OdhwI16i16o4i},
+ {"gOdhwI16i16o4i", tag::gOdhwI16i16o4i},
+ {"OwI16i32o2i", tag::OwI16i32o2i},
+ {"OwI16i32o4i", tag::OwI16i32o4i},
+ {"OwI16i48o2i", tag::OwI16i48o2i},
+ {"OwI16i48o4i", tag::OwI16i48o4i},
+ {"OwI16i64o2i", tag::OwI16i64o2i},
+ {"OwI16i64o4i", tag::OwI16i64o4i},
+ {"gOwI16i32o2i", tag::gOwI16i32o2i},
+ {"gOwI16i32o4i", tag::gOwI16i32o4i},
+ {"gOwI16i48o2i", tag::gOwI16i48o2i},
+ {"gOwI16i48o4i", tag::gOwI16i48o4i},
+ {"gOwI16i64o2i", tag::gOwI16i64o2i},
+ {"gOwI16i64o4i", tag::gOwI16i64o4i},
+ {"OhwI16i32o2i", tag::OhwI16i32o2i},
+ {"OhwI16i32o4i", tag::OhwI16i32o4i},
+ {"OhwI16i48o2i", tag::OhwI16i48o2i},
+ {"OhwI16i48o4i", tag::OhwI16i48o4i},
+ {"OhwI16i64o2i", tag::OhwI16i64o2i},
+ {"OhwI16i64o4i", tag::OhwI16i64o4i},
+ {"gOhwI16i32o2i", tag::gOhwI16i32o2i},
+ {"gOhwI16i32o4i", tag::gOhwI16i32o4i},
+ {"gOhwI16i48o2i", tag::gOhwI16i48o2i},
+ {"gOhwI16i48o4i", tag::gOhwI16i48o4i},
+ {"gOhwI16i64o2i", tag::gOhwI16i64o2i},
+ {"gOhwI16i64o4i", tag::gOhwI16i64o4i},
+ {"OdhwI16i32o2i", tag::OdhwI16i32o2i},
+ {"OdhwI16i32o4i", tag::OdhwI16i32o4i},
+ {"OdhwI16i48o2i", tag::OdhwI16i48o2i},
+ {"OdhwI16i48o4i", tag::OdhwI16i48o4i},
+ {"OdhwI16i64o2i", tag::OdhwI16i64o2i},
+ {"OdhwI16i64o4i", tag::OdhwI16i64o4i},
+ {"gOdhwI16i32o2i", tag::gOdhwI16i32o2i},
+ {"gOdhwI16i32o4i", tag::gOdhwI16i32o4i},
+ {"gOdhwI16i48o2i", tag::gOdhwI16i48o2i},
+ {"gOdhwI16i48o4i", tag::gOdhwI16i48o4i},
+ {"gOdhwI16i64o2i", tag::gOdhwI16i64o2i},
+ {"gOdhwI16i64o4i", tag::gOdhwI16i64o4i},
+ {"hwioG16g", tag::hwioG16g},
+ {"NCdhw40n32c", tag::NCdhw40n32c},
+ {"NChw40n32c", tag::NChw40n32c},
+ {"NCw40n32c", tag::NCw40n32c},
+ {"OIdhw4o8i8o2i", tag::OIdhw4o8i8o2i},
+ {"OIhw4o8i8o2i", tag::OIhw4o8i8o2i},
+ {"OIw4o8i8o2i", tag::OIw4o8i8o2i},
+ {"gOIdhw4o8i8o2i", tag::gOIdhw4o8i8o2i},
+ {"gOIhw4o8i8o2i", tag::gOIhw4o8i8o2i},
+ {"gOIw4o8i8o2i", tag::gOIw4o8i8o2i},
+ {"IOdhw4i8o8i2o", tag::IOdhw4i8o8i2o},
+ {"IOhw4i8o8i2o", tag::IOhw4i8o8i2o},
+ {"IOw4i8o8i2o", tag::IOw4i8o8i2o},
+ {"gIOdhw4i8o8i2o", tag::gIOdhw4i8o8i2o},
+ {"gIOhw4i8o8i2o", tag::gIOhw4i8o8i2o},
+ {"gIOw4i8o8i2o", tag::gIOw4i8o8i2o},
+ {"NCdhw40n16c", tag::NCdhw40n16c},
+ {"NCw40n16c", tag::NCw40n16c},
+ {"NChw40n16c", tag::NChw40n16c},
+ {"NCw2c32n8c", tag::NCw2c32n8c},
+ {"NChw2c32n8c", tag::NChw2c32n8c},
+ {"NCdhw2c32n8c", tag::NCdhw2c32n8c},
+ {"OIw2i8o16i4o", tag::OIw2i8o16i4o},
+ {"OIhw2i8o16i4o", tag::OIhw2i8o16i4o},
+ {"OIdhw2i8o16i4o", tag::OIdhw2i8o16i4o},
+ {"OIw2o8i16o4i", tag::OIw2o8i16o4i},
+ {"OIw2o8i16o2i", tag::OIw2o8i16o2i},
+ {"IOw2i8o16i4o", tag::IOw2i8o16i4o},
+ {"IOw2i8o16i2o", tag::IOw2i8o16i2o},
+ {"OIhw2o8i16o4i", tag::OIhw2o8i16o4i},
+ {"OIhw2o8i16o2i", tag::OIhw2o8i16o2i},
+ {"IOhw2i8o16i4o", tag::IOhw2i8o16i4o},
+ {"IOhw2i8o16i2o", tag::IOhw2i8o16i2o},
+ {"OIdhw2o8i16o4i", tag::OIdhw2o8i16o4i},
+ {"OIdhw2o8i16o2i", tag::OIdhw2o8i16o2i},
+ {"IOdhw2i8o16i4o", tag::IOdhw2i8o16i4o},
+ {"IOdhw2i8o16i2o", tag::IOdhw2i8o16i2o},
+ {"gOIw2o8i16o2i", tag::gOIw2o8i16o2i},
+ {"gIOw2i8o16i2o", tag::gIOw2i8o16i2o},
+ {"gIOhw2i8o16i2o", tag::gIOhw2i8o16i2o},
+ {"gIOdhw2i8o16i2o", tag::gIOdhw2i8o16i2o},
+ {"gOIhw2o8i16o2i", tag::gOIhw2o8i16o2i},
+ {"gOIdhw2o8i16o2i", tag::gOIdhw2o8i16o2i},
+ {"gOIw2o8i16o4i", tag::gOIw2o8i16o4i},
+ {"gOIhw2o8i16o4i", tag::gOIhw2o8i16o4i}};
+ std::string key = "";
+ for (const auto& c : layout) {
+ if (std::isalpha(c, std::locale("C"))) {
+ char lower_c = std::tolower(c);
+ if (std::isupper(c) && (layout.find(lower_c) != std::string::npos)) {
+ key.push_back(c);
+ } else {
+ key.push_back(lower_c);
+ }
+ } else if (std::isdigit(c)) {
+ key.push_back(c);
+ } else {
+ LOG(FATAL) << "invalid char '" << c << "' in " << layout << std::endl;
+ }
+ }
+ if (str2tag.count(key) == 0) {
+ LOG(WARNING) << "convert unregistered layout '" << key << "' to tag::any";
+ return tag::any;
+ } else {
+ return str2tag.at(key);
+ }
+ }
std::map<std::string, dnnl::algorithm> elt_name2algo{
{"abs", dnnl::algorithm::eltwise_abs},
@@ -106,62 +602,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{"clip", dnnl::algorithm::eltwise_clip},
};
- std::map<std::string, tag> layout_dict{
- {"", tag::any},
- {"NCW", tag::ncw},
- {"NWC", tag::nwc},
- {"OIW", tag::oiw},
- {"GOIW", tag::goiw},
- {"NCHW", tag::nchw},
- {"NHWC", tag::nhwc},
- {"OIHW", tag::oihw},
- {"GOIHW", tag::goihw},
- {"NCDHW", tag::ncdhw},
- {"NDHWC", tag::ndhwc},
- {"OIDHW", tag::oidhw},
- {"GOIDHW", tag::goidhw},
- {"IOHW", tag::iohw},
- {"GIOHW", tag::giohw},
- {"IODHW", tag::iodhw},
- {"GIODHW", tag::giodhw},
-
- // Blocking layout.
- {"NCW8c", tag::nCw8c},
- {"NCW16c", tag::nCw16c},
- {"OIW16i16o", tag::OIw8i8o},
- {"OIW16i16o", tag::OIw16i16o},
- {"OWI8o", tag::Owi8o},
- {"OWI16o", tag::Owi16o},
- {"NCHW4c", tag::nChw4c},
- {"NCHW8c", tag::nChw8c},
- {"NCHW16c", tag::nChw16c},
- {"OIHW8i8o", tag::OIhw8i8o},
- {"IOHW8i8o", tag::any},
- {"OIHW16i16o", tag::OIhw16i16o},
- {"IOHW16i16o", tag::IOhw16i16o},
- {"GOIHW4i4o", tag::gOIhw4i4o},
- {"GOIHW8i8o", tag::gOIhw8i8o},
- {"GOIHW16i16o", tag::gOIhw16i16o},
- {"OHWI8o", tag::Ohwi8o},
- {"OHWI16o", tag::Ohwi16o},
- {"OHWI32o", tag::Ohwi32o},
- {"OHWI48o", tag::Ohwi48o},
- {"OHWI64o", tag::Ohwi64o},
- {"GOIHW8g", tag::Goihw8g},
- {"GOIHW16g", tag::Goihw16g},
- {"NCDHW8c", tag::nCdhw8c},
- {"NCDHW16c", tag::nCdhw16c},
- {"OIDHW16i16o", tag::OIdhw16i16o},
- {"IODHW16i16o", tag::IOdhw16i16o},
- {"OIDHW8i8o", tag::OIdhw8i8o},
- {"IODHW8i8o", tag::any},
- {"ODHWI8o", tag::Odhwi8o},
- {"ODHWI16o", tag::Odhwi16o},
- {"ODHWI32o", tag::Odhwi32o},
- {"ODHWI48o", tag::Odhwi48o},
- {"ODHWI64o", tag::Odhwi64o},
- };
-
bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) {
// Define RegExp.
std::regex bias_add_pat(".*_bias.*");
@@ -202,12 +642,13 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}
// Push the correct shapes of each axis into the output_dims
for (auto a : axis) {
- dnnl::memory::dim shape = 1;
if (layout.find(a) != std::string::npos) {
- shape *= input_dims[layout.find(a)];
+ dnnl::memory::dim shape = input_dims[layout.find(a)];
char lower_a = std::tolower(a);
- if (layout.find(lower_a) != std::string::npos) {
- shape *= input_dims[layout.find(lower_a)];
+ for (size_t i = 0; i < layout.size(); ++i) {
+ if (lower_a == layout[i]) {
+ shape *= input_dims[i];
+ }
}
out_dims.push_back(shape);
}
@@ -238,6 +679,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return out_dims;
}
+ // Build up the engine based on the input graph.
void BuildEngine() {
engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0);
stream_ = dnnl::stream(engine_);
@@ -301,11 +743,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// has not yet been bound to the other DNNL memory; otherwise it may have memory leak.
ICHECK_EQ(entry_out_mem_.count(eid), 0);
- // TODO(@comanic): Support other data types (i.e., int8).
- auto data_node = nodes_[entry.id_];
- auto dltype = data_node.GetOpDataType()[entry.index_];
- ICHECK_EQ(dltype.bits, 32);
-
entry_out_mem_[eid] = {mem, offset};
return entry_out_mem_[eid].first;
}
@@ -338,17 +775,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::string data_layout = node.GetAttr<std::vector<std::string>>("data_layout")[0];
std::string kernel_layout = node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
- // Check layout.
- if (layout_dict.find(data_layout) == layout_dict.end()) {
- LOG(FATAL) << "Unsupported data layout for conv: " << data_layout;
- }
-
- if (layout_dict.find(kernel_layout) == layout_dict.end()) {
- layout_dict.insert({kernel_layout, tag::any});
- LOG(WARNING) << "Unregistered kernel layout for conv: " << kernel_layout
- << ", transfer to tag::any";
- }
-
// Memory shapes.
dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout);
@@ -360,6 +786,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims dst_dims = src_dims;
dst_dims[1] = channels;
weights_dims_[0] = channels;
+ weights_dims_[1] = src_dims[1];
for (size_t i = 2; i < src_dims.size(); i++) {
dnnl::memory::dim K = weights_dims_[i];
dnnl::memory::dim S = strides_dims[i - 2];
@@ -380,10 +807,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}
// Memory descriptions.
- auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]);
- auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]);
- auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
- auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ auto conv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout));
+ auto conv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout));
+ auto conv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::any);
+ auto conv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
// Conv description.
auto conv_desc =
@@ -413,7 +841,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc());
// Bias memory.
- auto conv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
+ auto conv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
if (has_bias) {
auto bias_entry = node.GetInputs()[2];
BindDNNLMemory(bias_entry, conv_bias_memory);
@@ -461,17 +889,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::string data_layout = node.GetAttr<std::vector<std::string>>("data_layout")[0];
std::string kernel_layout = node.GetAttr<std::vector<std::string>>("kernel_layout")[0];
- // Check layout.
- if (layout_dict.find(data_layout) == layout_dict.end()) {
- LOG(FATAL) << "Unsupported data layout for deconv: " << data_layout;
- }
-
- if (layout_dict.find(kernel_layout) == layout_dict.end()) {
- layout_dict.insert({kernel_layout, tag::any});
- LOG(WARNING) << "Unregistered kernel layout for deconv: " << data_layout
- << ", transfer to tag::any";
- }
-
// Memory shapes.
dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout);
dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout);
@@ -482,6 +899,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
kernel_layout.replace(kernel_layout.find("OI"), 2, "IO");
}
}
+ weights_dims_[0] = channels;
+ weights_dims_[1] = src_dims[1];
dnnl::memory::dims bias_dims = {channels};
dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides);
dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true);
@@ -508,10 +927,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
}
// Memory descriptions.
- auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]);
- auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]);
- auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any);
- auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ auto deconv_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(data_layout));
+ auto deconv_weights_md = dnnl::memory::desc(weights_dims, dtype, layout2tag(kernel_layout));
+ auto deconv_bias_md = dnnl::memory::desc(bias_dims, dtype, tag::x);
+ auto deconv_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
// Transposed covn2d description.
auto deconv_desc =
@@ -541,7 +961,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc());
// Bias memory.
- auto deconv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_);
+ auto deconv_bias_memory = dnnl::memory({bias_dims, dtype, tag::x}, engine_);
if (has_bias) {
auto bias_entry = node.GetInputs()[2];
BindDNNLMemory(bias_entry, deconv_bias_memory);
@@ -581,10 +1001,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims out_dims = out_shape;
// Memory descriptions.
- auto data_md = dnnl::memory::desc({data_dims, dt::f32, tag::nc});
- auto weight_md = dnnl::memory::desc({weight_dims, dt::f32, tag::nc});
- auto bias_md = dnnl::memory::desc({bias_dims, dt::f32, tag::x});
- auto dst_md = dnnl::memory::desc({out_dims, dt::f32, tag::nc});
+ auto dl_dtype = nodes_[data_entry.id_].GetOpDataType()[data_entry.index_];
+ auto dtype = dtype_dl2dnnl(dl_dtype);
+ auto data_md = dnnl::memory::desc({data_dims, dtype, tag::nc});
+ auto weight_md = dnnl::memory::desc({weight_dims, dtype, tag::nc});
+ auto bias_md = dnnl::memory::desc({bias_dims, dtype, tag::x});
+ auto dst_md = dnnl::memory::desc({out_dims, dtype, tag::nc});
// Dense description.
auto dense_desc = dnnl::inner_product_forward::desc(dnnl::prop_kind::forward_inference, data_md,
@@ -607,7 +1029,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
BindDNNLMemory(bias_entry, bias_memory);
} else {
float bias[OC] = {0};
- write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
+ write_to_dnnl_memory(bias, bias_memory, OC * ((dl_dtype.bits + 7) / 8));
}
// Output memory.
@@ -632,7 +1054,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
float epsilon = std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
// Memory description.
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
// BN description.
auto bn_desc = dnnl::batch_normalization_forward::desc(
@@ -679,11 +1102,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
std::string layout = node.GetAttr<std::vector<std::string>>("layout")[0];
- // Check layout.
- if (layout_dict.find(layout) == layout_dict.end()) {
- LOG(FATAL) << "Unsupported layout for pooling: " << layout;
- }
-
// Attributes related to AvgPool
if (algo == dnnl::algorithm::pooling_avg) {
int int_countpad = std::stoi(node.GetAttr<std::vector<std::string>>("count_include_pad")[0]);
@@ -701,8 +1119,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r);
// Memory descriptions.
- auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]);
- auto pool_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ auto pool_src_md = dnnl::memory::desc(src_dims, dtype, layout2tag(layout));
+ auto pool_dst_md = dnnl::memory::desc(dst_dims, dtype, tag::any);
// Pooling description.
auto pool_desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_inference, algo,
@@ -729,7 +1148,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto data_entry = node.GetInputs()[0];
dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
float alpha = 0., beta = 0.;
if (op_name == "clip") {
alpha = std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
@@ -762,7 +1182,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (axis < 0) {
axis = shape.size() + axis;
}
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);
+ auto dtype = dtype_dl2dnnl(nodes_[data_entry.id_].GetOpDataType()[data_entry.index_]);
+ dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dtype);
auto softmax_desc =
dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, data_md, axis);
@@ -790,7 +1211,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
ICHECK_EQ(node.GetInputs().size(), 2U);
for (auto entry : node.GetInputs()) {
auto data_shape = nodes_[entry.id_].GetOpShape()[entry.index_];
- dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dt::f32);
+ auto dtype = dtype_dl2dnnl(nodes_[entry.id_].GetOpDataType()[entry.index_]);
+ dnnl::memory::desc data_md = GenDNNLMemDescByShape(data_shape, dtype);
data_dims.push_back(data_shape);
data_mds.push_back(data_md);
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.cc b/src/runtime/contrib/dnnl/dnnl_utils.cc
new file mode 100644
index 0000000000..7e79f1c939
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_utils.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_utils.cc
+ */
+
+#include "dnnl_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+using dt = dnnl::memory::data_type;
+dt dtype_dl2dnnl(DLDataType dltype) {
+ dt dnnl_type = dt::undef;
+ if (dltype.code == DataType::TypeCode::kFloat) {
+ if (dltype.bits == 16) {
+ dnnl_type = dt::f16;
+ } else if (dltype.bits == 32) {
+ dnnl_type = dt::f32;
+ }
+ } else if (dltype.code == DataType::TypeCode::kBFloat && dltype.bits == 16) {
+ dnnl_type = dt::bf16;
+ } else if (dltype.code == DataType::TypeCode::kInt) {
+ if (dltype.bits == 8) {
+ dnnl_type = dt::s8;
+ } else if (dltype.bits == 32) {
+ dnnl_type = dt::s32;
+ }
+ } else if (dltype.code == DataType::TypeCode::kUInt && dltype.bits == 8) {
+ dnnl_type = dt::u8;
+ }
+ if (dnnl_type == dt::undef) {
+ LOG_ERROR << "unsupported datatype: code=" << dltype.code << ", bits=" << dltype.bits;
+ }
+ return dnnl_type;
+}
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
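dtype_dl2dnnl() keys off the DLPack type code plus bit width. On the Python side the same pair can be inspected through tvm.DataType, which is a quick way to see what the runtime will pick (a small sketch; the printed codes are whatever the local TVM build defines):

    import tvm

    for name in ["float32", "bfloat16", "int8", "uint8"]:
        t = tvm.DataType(name)
        print(name, t.type_code, t.bits)  # the (code, bits) pair dtype_dl2dnnl switches on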
diff --git a/src/runtime/contrib/dnnl/dnnl_utils.h b/src/runtime/contrib/dnnl/dnnl_utils.h
new file mode 100644
index 0000000000..4fb236f96f
--- /dev/null
+++ b/src/runtime/contrib/dnnl/dnnl_utils.h
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/dnnl/dnnl_utils.h
+ * \brief utils for DNNL.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
+
+#include <tvm/runtime/data_type.h>
+
+#include "dnnl.hpp"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+/*!
+ * \brief Convert a DLPack data type to a DNNL data type.
+ * \param dltype The DLPack data type.
+ * \return The corresponding DNNL data type.
+ */
+dnnl::memory::data_type dtype_dl2dnnl(DLDataType dltype);
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_UTILS_H_
diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc
index 5e3ba83ce0..f784f7b49a 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/tir/ir/data_layout.cc
@@ -205,10 +205,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rule,
const Layout& src_layout, const Layout& dst_layout) {
- if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() ||
- dst_layout.name().empty()) {
+ if (!src_layout.defined() || src_layout.name().empty()) {
+ LOG(WARNING) << "src layout '" << src_layout.name() << "' is invalid.";
return false;
}
+ if (!dst_layout.defined() || dst_layout.name().empty()) {
+ LOG(WARNING) << "dst layout '" << dst_layout.name() << "' is invalid.";
+ return false;
+ }
+
for (size_t i = 0; i < dst_layout.ndim(); ++i) {
const auto& store_axis = dst_layout[i];
const IterVar& store_axis_impl = dst_layout->axes[i];
@@ -237,7 +242,8 @@ inline bool GetStoreRule(Array<PrimExpr>* index_rule, Array<PrimExpr>* shape_rul
}
}
if (tir::is_zero(index_store)) {
- // Not convertible
+ LOG(WARNING) << "layout '" << src_layout.name() << "'-->'" << dst_layout.name()
+ << "' is not convertible.";
return false;
}
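These GetStoreRule changes make a non-convertible layout transform warn instead of failing silently, which matters for the DNNL alter-layout path when it proposes blocked layouts. For a convertible pair the behaviour can be seen through the public API (a small sketch; values are illustrative):

    import tvm

    # NCHW -> NCHW16c is convertible: the C axis is split into 2 x 16.
    bl = tvm.tir.bijective_layout("NCHW", "NCHW16c")
    print(bl.forward_shape((1, 32, 56, 56)))  # e.g. [1, 2, 56, 56, 16]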
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 5baf6e06d3..fecd776d70 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -37,6 +37,8 @@ run_module = tvm.testing.parameter(
ids=["compile", "run"],
)
+bf16_supported = "avx512" in open("/proc/cpuinfo", "r").read()
+
def partition_for_dnnl(mod, params=None, alter_layout=True):
"""Partition the graph greedily offloading supported operators to DNNL.
@@ -109,7 +111,10 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
- return [o.numpy()]
+ o_np = o.numpy()
+ if o_np.dtype == np.uint16:
+ o_np = np.left_shift(o_np.astype("uint32"), 16).view("<f4")
+ return [o_np]
elif isinstance(o, tvm.runtime.container.ADT) or isinstance(o, list):
return [vmobj_to_list(f) for f in o]
else:
@@ -121,10 +126,13 @@ def assert_result_dict_holds(result_dict):
res1 = vmobj_to_list(result_dict[k1])
res2 = vmobj_to_list(result_dict[k2])
for r1, r2 in zip(res1, res2):
- tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+ if "bf16" in k1 or "bf16" in k2:
+ np.testing.assert_array_almost_equal(r1, r2, decimal=1)
+ else:
+ tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
-def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
+def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, test_bf16=True):
def check_dnnl_used(mod, subgraph_num=None):
num_dnnl_subgraphs = sum(
[1 if "dnnl" in gv.name_hint else 0 for gv in mod.get_global_vars()]
@@ -137,13 +145,30 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
dev = tvm.cpu()
result_dict = dict()
for mode in ["graph", "vm"]:
- for use_dnnl, alter_layout in [(False, False), (True, False), (True, True)]:
- result_key = mode + ("_dnnl" if use_dnnl else "") + ("_layout" if alter_layout else "")
+ configs = [
+ (False, False, False),
+ (True, False, False),
+ (True, True, False),
+ ]
+ if test_bf16 and bf16_supported:
+ configs += [(True, False, True), (True, True, True)]
+ for use_dnnl, alter_layout, use_bf16 in configs:
+ result_key = (
+ mode
+ + ("_dnnl" if use_dnnl else "")
+ + ("_layout" if alter_layout else "")
+ + ("_bf16" if use_bf16 else "_fp32")
+ )
+ processed_mod = mod
+ if use_bf16:
+ processed_mod = relay.transform.ToMixedPrecision("bfloat16")(processed_mod)
+ if tvm.ir.structural_equal(processed_mod, mod):
+ print("can not convert to bfloat16, skipping...")
+ continue
if use_dnnl:
- processed_mod = partition_for_dnnl(mod, params, alter_layout)
- check_dnnl_used(processed_mod, subgraph_num)
- else:
- processed_mod = mod
+ processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
+ check_dnnl_used(processed_mod)
+
with tvm.transform.PassContext(opt_level=3):
func = relay.create_executor(
mode, mod=processed_mod, device=dev, target=target
@@ -158,7 +183,9 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None):
assert_result_dict_holds(result_dict)
-def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dtype="float32"):
+def run_and_verify_func(
+ config, run_module, subgraph_num=None, target="llvm", dtype="float32", test_bf16=True
+):
"""Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
Parameters
----------
@@ -176,7 +203,13 @@ def run_and_verify_func(config, run_module, subgraph_num=None, target="llvm", dt
if k not in is_param
}
run_and_verify(
- f, input_dict, params, subgraph_num=subgraph_num, target=target, run_module=run_module
+ f,
+ input_dict,
+ params,
+ subgraph_num=subgraph_num,
+ target=target,
+ run_module=run_module,
+ test_bf16=test_bf16,
)
@@ -586,7 +619,6 @@ def test_elementwise(run_module, dtype="float32"):
relay.exp,
relay.log,
relay.sqrt,
- relay.round,
relay.nn.relu,
relay.tanh,
relay.sigmoid,
@@ -956,14 +988,14 @@ def test_prune_dnnl_subgraph(run_module):
"""In this test, OP "add" should be offloaded from dnnl codegen."""
def get_graph():
- x1 = relay.var("x1", shape=(1, 64, 56, 56))
- x2 = relay.var("x2", shape=(1, 64, 56, 56))
- bias = relay.var("bias", shape=(64,))
- weight = relay.var("weight", shape=(64, 64, 3, 3))
+ x1 = relay.var("x1", shape=(1, 32, 56, 56))
+ x2 = relay.var("x2", shape=(1, 32, 56, 56))
+ bias = relay.var("bias", shape=(32,))
+ weight = relay.var("weight", shape=(32, 32, 3, 3))
y = relay.nn.conv2d(
x1,
weight,
- channels=64,
+ channels=32,
kernel_size=(3, 3),
padding=(1, 1),
)
@@ -972,16 +1004,16 @@ def test_prune_dnnl_subgraph(run_module):
y = relay.nn.global_max_pool2d(y)
y = relay.add(y, x2)
dic = {
- "x1": (1, 64, 56, 56),
- "x2": (1, 64, 56, 56),
- "weight": (64, 64, 3, 3),
- "bias": (64,),
+ "x1": (1, 32, 56, 56),
+ "x2": (1, 32, 56, 56),
+ "weight": (32, 32, 3, 3),
+ "bias": (32,),
}
param_lst = ["weight", "bias"]
out = tvm.IRModule.from_expr(y)
return out, dic, param_lst
- run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module)
+ run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
if __name__ == "__main__":