You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/11 10:10:10 UTC
[tvm] branch main updated: enable bmm (#12018)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2d5072858c enable bmm (#12018)
2d5072858c is described below
commit 2d5072858c9749217256913cad3c14fe52be0367
Author: billishyahao <ya...@intel.com>
AuthorDate: Mon Jul 11 18:09:58 2022 +0800
enable bmm (#12018)
---
python/tvm/relay/op/contrib/dnnl.py | 4 ++-
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 50 ++++++++++++++++++++++++++-
tests/python/contrib/test_dnnl.py | 29 ++++++++++++++++
3 files changed, 81 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 9b6b45240a..05416bb9a3 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -105,6 +105,7 @@ _register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
_register_external_op_helper("nn.layer_norm")
+_register_external_op_helper("nn.batch_matmul")
def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -563,6 +564,7 @@ class IsComputeIntensiveGraph(ExprVisitor):
"nn.conv3d_transpose",
"nn.dense",
"nn.layer_norm",
+ "nn.batch_matmul",
]
)
if isinstance(call.op, tvm.tir.op.Op):
@@ -679,7 +681,7 @@ class LayerNormRewrite(DFPatternCallback):
const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
p1 = is_op("power")(cdiff, const_two)
mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
- eps = is_expr(relay.const(1e-5))
+ eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
div_out = is_op("divide")(diff, deno)
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 6c0fd64066..c6e50eafea 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -269,6 +269,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
Binary(nid, dnnl::algorithm::binary_mul);
} else if ("nn.layer_norm" == op_name) {
LayerNorm(nid);
+ } else if ("nn.batch_matmul" == op_name) {
+ BatchMatMul(nid);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
@@ -483,6 +485,52 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{sum_in_tr, DNNL_ARG_DST});
}
+ void BatchMatMul(const size_t& nid) {  // Build and submit a DNNL matmul primitive for node `nid`.
+ auto node = nodes_[nid];
+
+ // Fetch src/weights/dst; bias starts empty and is handed to ParseAttrs by pointer.
+ auto src_tr = GetInput(nid, 0);
+ auto wgh_tr = GetInput(nid, 1);
+ auto dst_tr = GetOutput(nid, 0);
+ auto bias_tr = TensorRequisite{};
+
+ auto attr = ParseAttrs(nid, &bias_tr);
+ attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+ bool transpose_a = GetNodeAttr<bool>(node, "transpose_a");
+ bool transpose_b = GetNodeAttr<bool>(node, "transpose_b");
+
+ if (transpose_a) {  // apply axis permutation (0, 2, 1) to the source operand
+ src_tr = src_tr.Permute({0, 2, 1});
+ }
+ if (transpose_b) {  // apply axis permutation (0, 2, 1) to the weights operand
+ wgh_tr = wgh_tr.Permute({0, 2, 1});
+ }
+
+ // Assumes bias can be squeezed to 1D of size dst.dims()[1]; TODO confirm this axis for bmm.
+ bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
+
+ // Matmul descriptor; every operand goes through LayoutAny() (presumably DNNL picks layouts).
+ auto bmm_desc = dnnl::matmul::desc(src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(),
+ bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc());
+
+ // Primitive descriptor built with `attr` (from ParseAttrs, scratchpad mode set to user).
+ auto bmm_prim_desc = dnnl::matmul::primitive_desc(bmm_desc, attr, engine_);
+
+ src_tr = src_tr.RequestLayout(bmm_prim_desc.src_desc());  // adopt descs reported by prim desc
+ wgh_tr = wgh_tr.RequestLayout(bmm_prim_desc.weights_desc());
+ dst_tr = dst_tr.RequestLayout(bmm_prim_desc.dst_desc());
+ bias_tr = bias_tr.RequestLayout(bmm_prim_desc.bias_desc());
+
+ auto scratchpad_tr = TensorRequisite::AsIs(bmm_prim_desc.scratchpad_desc());  // user scratchpad
+
+ Submit(dnnl::matmul(bmm_prim_desc), {{DNNL_ARG_SRC, src_tr},
+ {DNNL_ARG_WEIGHTS, wgh_tr},
+ {DNNL_ARG_BIAS, bias_tr},
+ {DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+ {DNNL_ARG_DST, dst_tr}});
+ }
+
void BatchNorm(const size_t& nid) {
auto node = nodes_[nid];
@@ -755,7 +803,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
TensorRequisite GetOutput(const size_t& nid, const int idx) {
if (idx == -1) return {}; // -1 reserved value for empty input.
-
const JSONGraphNode& node = nodes_[nid];
ICHECK_LT(idx, node.GetNumOutput());
@@ -764,6 +811,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
auto eid = node_row_ptr_[nid] + static_cast<uint32_t>(idx);
ICHECK(data_entry_[eid] == nullptr);
+
auto desc = MakePlainDesc(shape, dtype);
return TensorRequisite::AsIs(desc, eid).Backward();
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 6c7034741a..dfe1b7265d 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -556,6 +556,35 @@ def get_dense(
return out, dic, param_lst
+def get_bmm(  # Build a relay nn.batch_matmul expr; returns (expr, name->shape dict, param names).
+ x_shape=(1, 16, 8), k_shape=(1, 4, 8), dtype="float32", transpose_a=False, transpose_b=True
+):
+ x = relay.var("x", shape=(x_shape), dtype=dtype)
+ kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+ out = relay.nn.batch_matmul(
+ x, kernel, out_dtype=dtype, transpose_a=transpose_a, transpose_b=transpose_b
+ )
+ dic = {"x": x_shape, "kernel": k_shape}  # relay variable name -> shape
+ param_lst = ["kernel"]  # presumably the vars the harness binds as constant params -- confirm
+ return out, dic, param_lst
+
+
+def test_bmm(run_module, dtype="float32"):  # Verify nn.batch_matmul for two kernel layouts.
+ x_shape = (1, 2, 4)
+ k_shape = (1, 3, 4)
+
+ dense, dic, param_lst = get_bmm(x_shape, k_shape, dtype=dtype)  # get_bmm default transpose_b=True
+ dense = tvm.IRModule.from_expr(dense)
+ config = dense, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+ k_shape_t = (1, 4, 3)  # k_shape with its last two axes swapped, paired with transpose_b=False
+ dense, dic, param_lst = get_bmm(x_shape, k_shape_t, dtype=dtype, transpose_b=False)
+ dense = tvm.IRModule.from_expr(dense)
+ config = dense, dic, param_lst
+ run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
def get_dense_bias(
x_shape=(1, 16),
k_shape=(32, 16),