You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/11 10:10:10 UTC

[tvm] branch main updated: enable bmm (#12018)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d5072858c enable bmm (#12018)
2d5072858c is described below

commit 2d5072858c9749217256913cad3c14fe52be0367
Author: billishyahao <ya...@intel.com>
AuthorDate: Mon Jul 11 18:09:58 2022 +0800

    enable bmm (#12018)
---
 python/tvm/relay/op/contrib/dnnl.py           |  4 ++-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 50 ++++++++++++++++++++++++++-
 tests/python/contrib/test_dnnl.py             | 29 ++++++++++++++++
 3 files changed, 81 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 9b6b45240a..05416bb9a3 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -105,6 +105,7 @@ _register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
 _register_external_op_helper("nn.layer_norm")
+_register_external_op_helper("nn.batch_matmul")
 
 
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -563,6 +564,7 @@ class IsComputeIntensiveGraph(ExprVisitor):
                 "nn.conv3d_transpose",
                 "nn.dense",
                 "nn.layer_norm",
+                "nn.batch_matmul",
             ]
         )
         if isinstance(call.op, tvm.tir.op.Op):
@@ -679,7 +681,7 @@ class LayerNormRewrite(DFPatternCallback):
         const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
         p1 = is_op("power")(cdiff, const_two)
         mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
-        eps = is_expr(relay.const(1e-5))
+        eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
         added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
         div_out = is_op("divide")(diff, deno)
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 6c0fd64066..c6e50eafea 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -269,6 +269,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
           Binary(nid, dnnl::algorithm::binary_mul);
         } else if ("nn.layer_norm" == op_name) {
           LayerNorm(nid);
+        } else if ("nn.batch_matmul" == op_name) {
+          BatchMatMul(nid);
         } else {
           LOG(FATAL) << "Unsupported op: " << op_name;
         }
@@ -483,6 +485,52 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
            {sum_in_tr, DNNL_ARG_DST});
   }
 
+  void BatchMatMul(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    // Gather the two inputs and the output of this node; the bias requisite starts out empty.
+    auto src_tr = GetInput(nid, 0);
+    auto wgh_tr = GetInput(nid, 1);
+    auto dst_tr = GetOutput(nid, 0);
+    auto bias_tr = TensorRequisite{};
+    // bias_tr is passed by pointer, so ParseAttrs presumably may populate it -- TODO confirm.
+    auto attr = ParseAttrs(nid, &bias_tr);
+    attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+    bool transpose_a = GetNodeAttr<bool>(node, "transpose_a");
+    bool transpose_b = GetNodeAttr<bool>(node, "transpose_b");
+    // A transposed operand is handled by swapping its last two axes; axis 0 stays in place.
+    if (transpose_a) {
+      src_tr = src_tr.Permute({0, 2, 1});
+    }
+    if (transpose_b) {
+      wgh_tr = wgh_tr.Permute({0, 2, 1});
+    }
+    // Assumes any bias can be squeezed to 1D, sized by the second output dimension.
+    // NOTE(review): confirm dims()[1] is the right axis given the rank-3 permutations above.
+    bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});
+
+    // Matmul descriptor built from the LayoutAny() form of each operand's descriptor.
+    auto bmm_desc = dnnl::matmul::desc(src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(),
+                                       bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc());
+
+    // Primitive descriptor for engine_, created with the attributes prepared above.
+    auto bmm_prim_desc = dnnl::matmul::primitive_desc(bmm_desc, attr, engine_);
+    // Request the src/weights/dst/bias layouts reported by the primitive descriptor.
+    src_tr = src_tr.RequestLayout(bmm_prim_desc.src_desc());
+    wgh_tr = wgh_tr.RequestLayout(bmm_prim_desc.weights_desc());
+    dst_tr = dst_tr.RequestLayout(bmm_prim_desc.dst_desc());
+    bias_tr = bias_tr.RequestLayout(bmm_prim_desc.bias_desc());
+    // Scratchpad mode is `user` (set above), so a scratchpad tensor is passed in explicitly.
+    auto scratchpad_tr = TensorRequisite::AsIs(bmm_prim_desc.scratchpad_desc());
+
+    Submit(dnnl::matmul(bmm_prim_desc), {{DNNL_ARG_SRC, src_tr},
+                                         {DNNL_ARG_WEIGHTS, wgh_tr},
+                                         {DNNL_ARG_BIAS, bias_tr},
+                                         {DNNL_ARG_SCRATCHPAD, scratchpad_tr},
+                                         {DNNL_ARG_DST, dst_tr}});
+  }
+
   void BatchNorm(const size_t& nid) {
     auto node = nodes_[nid];
 
@@ -755,7 +803,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
   TensorRequisite GetOutput(const size_t& nid, const int idx) {
     if (idx == -1) return {};  // -1 reserved value for empty input.
-
     const JSONGraphNode& node = nodes_[nid];
 
     ICHECK_LT(idx, node.GetNumOutput());
@@ -764,6 +811,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     auto eid = node_row_ptr_[nid] + static_cast<uint32_t>(idx);
 
     ICHECK(data_entry_[eid] == nullptr);
+
     auto desc = MakePlainDesc(shape, dtype);
 
     return TensorRequisite::AsIs(desc, eid).Backward();
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 6c7034741a..dfe1b7265d 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -556,6 +556,35 @@ def get_dense(
     return out, dic, param_lst
 
 
+def get_bmm(  # builds a relay.nn.batch_matmul expression over vars "x" and "kernel"
+    x_shape=(1, 16, 8), k_shape=(1, 4, 8), dtype="float32", transpose_a=False, transpose_b=True
+):
+    x = relay.var("x", shape=(x_shape), dtype=dtype)
+    kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
+    out = relay.nn.batch_matmul(
+        x, kernel, out_dtype=dtype, transpose_a=transpose_a, transpose_b=transpose_b
+    )
+    dic = {"x": x_shape, "kernel": k_shape}  # variable name -> shape
+    param_lst = ["kernel"]  # presumably names bound as constant params -- confirm with callers
+    return out, dic, param_lst  # (expression, shape dict, parameter name list)
+
+
+def test_bmm(run_module, dtype="float32"):
+    x_shape = (1, 2, 4)
+    k_shape = (1, 3, 4)
+    # Case 1: get_bmm defaults (transpose_b=True), so the kernel is given as (1, 3, 4).
+    dense, dic, param_lst = get_bmm(x_shape, k_shape, dtype=dtype)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+    # Case 2: transpose_b=False, so the kernel is passed already laid out as (1, 4, 3).
+    k_shape_t = (1, 4, 3)
+    dense, dic, param_lst = get_bmm(x_shape, k_shape_t, dtype=dtype, transpose_b=False)
+    dense = tvm.IRModule.from_expr(dense)
+    config = dense, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def get_dense_bias(
     x_shape=(1, 16),
     k_shape=(32, 16),