Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/10 11:16:13 UTC

[tvm] branch main updated: [BYOC-DNNL] add partition test on sum pattern (#12357)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 22ba659438 [BYOC-DNNL] add partition test on sum pattern (#12357)
22ba659438 is described below

commit 22ba659438a317ca59c8201430c662f86e2550fd
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Wed Aug 10 19:16:03 2022 +0800

    [BYOC-DNNL] add partition test on sum pattern (#12357)
    
    * add partition test on sum pattern
    
    * fix lint
---
 python/tvm/relay/op/contrib/dnnl.py             |  4 +-
 tests/python/relay/test_pass_partition_graph.py | 54 +++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 2 deletions(-)
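
For context, a condensed sketch (not part of the patch) of the flow the new
test relies on: batch_norm is first simplified and constant-folded into the
conv2d weights, after which the DNNL pattern table drives composite merging
and graph partitioning. This mirrors the shape of the get_partitoned_mod
helper in the test file; the pass names are from TVM's Relay API, but treat
the exact pipeline as an approximation.

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name

def partition_with_patterns(mod, params, pattern_table):
    # Bind the constant parameters so the folding passes can see them.
    mod["main"] = bind_params_by_name(mod["main"], params)
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.SimplifyInference(),  # lowers batch_norm to mul/add
            relay.transform.FoldConstant(),
            relay.transform.FoldScaleAxis(),  # folds the bn scale into conv2d
            relay.transform.MergeComposite(pattern_table),  # e.g. dnnl.conv2d_bias_sum
            relay.transform.AnnotateTarget("dnnl"),
            relay.transform.PartitionGraph(),
        ]
    )
    return seq(mod)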

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f76d4bd10d..4ef342a26b 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -249,8 +249,6 @@ def make_sum_pattren_predicate(checker):
         for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
             args = get_args(e)
             attrs = get_attrs(e.args[0])
-            if attrs is None:
-                return False
             if not checker(attrs, args, op_name):
                 return False
         return True
@@ -284,6 +282,8 @@ def add_checker(attrs, args, op_name):
         if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
             return False
     if op_name == "bias_add":
+        if attrs is None:
+            return False
         if not isinstance(args[0].op, tvm.ir.op.Op):
             return False
         if args[0].op.name != "nn.conv2d":
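
A note on the dnnl.py hunk above: make_sum_pattren_predicate (the spelling
follows the upstream source) runs a checker over both the matched sum call
and its bias_add argument. Previously a None attrs rejected the whole match
for either op; the patch narrows that guard to the bias_add branch of
add_checker, so a sum whose input expression carries no attributes stays
eligible for fusion. A toy, self-contained illustration of the resulting
control flow (not the TVM implementation; the dict-based args are invented
for the example):

def add_checker(attrs, args, op_name):
    if op_name == "sum":
        # Both summands must have identical shapes for DNNL sum fusion;
        # no attrs are needed on this path.
        if tuple(args[0]["shape"]) != tuple(args[1]["shape"]):
            return False
    if op_name == "bias_add":
        if attrs is None:  # guard moved here from the generic predicate loop
            return False
    return True

# A sum whose upstream expression exposes no attrs is now accepted ...
assert add_checker(None, [{"shape": (1, 16, 6, 6)}, {"shape": (1, 16, 6, 6)}], "sum")
# ... while bias_add still requires attrs to be present.
assert not add_checker(None, [{"shape": (16,)}, {"shape": (16,)}], "bias_add")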
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 5e1a812fa3..f073a00c19 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -930,6 +930,10 @@ def test_dnnl_fuse():
             conv2d_relu_pat = pattern
         elif pattern[0] == "dnnl.conv2d_sigmoid":
             conv2d_sigmoid_pat = pattern
+        elif pattern[0] == "dnnl.conv2d_bias_sum":
+            conv2d_bias_sum_pat = pattern
+        elif pattern[0] == "dnnl.conv2d_bias_sum_relu":
+            conv2d_bias_sum_relu_pat = pattern
 
     def get_blocks(
         prefix,
@@ -1009,6 +1013,52 @@ def test_dnnl_fuse():
         mod = get_partitoned_mod(mod, params, pattern_table)
         assert len(mod.functions) - 1 == num_expected_partition  # -1 for main
 
+    def test_sum_pattern(pattern_table, num_expected_partition):
+        def get_conv2d_bn_sum_relu(
+            x_shape=(1, 32, 8, 8),
+            k_shape=(16, 32, 3, 3),
+            sum_shape=(1, 16, 6, 6),
+            dtype="float32",
+        ):
+            x = relay.var("x", shape=(x_shape), dtype=dtype)
+            kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype))
+            bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+            beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+            gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+            moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+            moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+            sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
+
+            dic = {"x": x_shape, "bias": (k_shape[0],), "sum_data": sum_shape}
+            param_lst = ["bias", "sum_data"]
+
+            conv = relay.nn.conv2d(
+                x,
+                kernel,
+                channels=k_shape[0],
+                kernel_size=k_shape[2:4],
+            )
+            conv_bias = relay.nn.bias_add(conv, bias)
+            conv_bias_bn, _, _ = relay.nn.batch_norm(
+                conv_bias,
+                gamma=gamma,
+                beta=beta,
+                moving_mean=moving_mean,
+                moving_var=moving_var,
+                axis=1,
+                center=True,
+                scale=True,
+                epsilon=1e-5,
+            )
+            conv_bias_bn_sum = relay.add(conv_bias_bn, sum_data)
+            return relay.nn.relu(conv_bias_bn_sum), dic, param_lst
+
+        net, dic, param_lst = get_conv2d_bn_sum_relu()
+        net = tvm.IRModule.from_expr(net)
+        params = {x: np.random.uniform(-1, 1, dic[x]).astype("float32") for x in param_lst}
+        mod = get_partitoned_mod(net, params, pattern_table)
+        assert len(mod.functions) - 1 == num_expected_partition  # -1 for main
+
     def test_partition():
         # conv + bn + relu, conv + relu -> fused conv_bias_relu, conv, and relu
         test_detect_pattern([conv2d_bias_relu_pat], False, True, False, 3)
@@ -1033,6 +1083,10 @@ def test_dnnl_fuse():
         # conv + bias_add + bn + sigmoid + relu, conv + sigmoid + relu -> fused conv_bias_sigmoid,
         # fused conv_sigmoid and single op relu, relu
         test_detect_pattern([conv2d_bias_sigmoid_pat, conv2d_sigmoid_pat], True, True, True, 4)
+        # conv + bias_add + bn + add + relu -> fused conv_bias_sum, relu
+        test_sum_pattern([conv2d_bias_sum_pat], 2)
+        # conv + bias_add + bn + add + relu -> fused conv_bias_sum_relu,
+        test_sum_pattern([conv2d_bias_sum_relu_pat], 1)
 
     def test_partition_mobilenet():
         mod, params = relay.testing.mobilenet.get_workload()
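
On the expected counts in the two new test_sum_pattern calls: with only
conv2d_bias_sum in the table, the trailing relu stays outside the composite
and is offloaded as its own DNNL region (nn.relu is individually supported),
leaving two partitions; conv2d_bias_sum_relu absorbs the relu, leaving one.
A hedged, partial sketch of checking this by hand (requires a TVM build with
the DNNL patterns registered; net_mod, params, and partition_with_patterns
refer to the sketches earlier in this message):

from tvm.relay.op.contrib import dnnl

def select_patterns(*names):
    # Pull named entries out of the registered DNNL pattern table;
    # each entry is a (name, pattern, predicate) tuple.
    return [p for p in dnnl.pattern_table() if p[0] in names]

# mod = partition_with_patterns(net_mod, params, select_patterns("dnnl.conv2d_bias_sum"))
# assert len(mod.functions) - 1 == 2  # fused conv_bias_sum + standalone relu
# mod = partition_with_patterns(net_mod, params, select_patterns("dnnl.conv2d_bias_sum_relu"))
# assert len(mod.functions) - 1 == 1  # relu absorbed into the composite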