You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/05 07:41:33 UTC

[tvm] branch main updated: [BYOC-DNNL]rewrite downsize blocks for rensetv1 to get better performance (#11822)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b7e299f4a4 [BYOC-DNNL]rewrite downsize blocks for rensetv1 to get better performance (#11822)
b7e299f4a4 is described below

commit b7e299f4a4f9a90b2538d77bc3ae9da9bbff4ef1
Author: Ivy Zhang <ya...@intel.com>
AuthorDate: Tue Jul 5 15:41:25 2022 +0800

    [BYOC-DNNL]rewrite downsize blocks for rensetv1 to get better performance (#11822)
    
    * rewrite downsize blocks for rensetv1 to get better performance
    
    * fix lint
---
 python/tvm/relay/op/contrib/dnnl.py | 179 ++++++++++++++++++++++++++++++++++++
 tests/python/contrib/test_dnnl.py   | 100 ++++++++++++++++++++
 2 files changed, 279 insertions(+)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index c251b66bfb..b3ef478f20 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -782,6 +782,185 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
     return mod
 
 
+class ResNetV1Rewrite(DFPatternCallback):
+    """
+    A callback to advance downsize operation when the patterns are as pattern1,
+    and the result is written in pattern2:
+    Pattern #1:
+    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1));
+    %27 = add(%26, ty=Tensor[(64, 1, 1));
+    %28 = nn.relu(%27);
+
+    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3));
+    %30 = add(%29, ty=Tensor[(64, 1, 1));
+    %31 = nn.relu(%30);
+
+    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1));
+    %33 = add(%32, ty=Tensor[(256, 1, 1));
+    %34 = add(%33, %25);
+    %35 = nn.relu(%34);
+
+    %36 = nn.conv2d(%35, ty=Tensor[(128, 256, 1, 1), strides=[2, 2]);
+    %37 = add(%36, ty=Tensor[(128, 1, 1));
+    %38 = nn.relu(%37);
+
+    %39 = nn.conv2d(%38, ty=Tensor[(128, 128, 3, 3));
+    %40 = add(%39, ty=Tensor[(128, 1, 1)]);
+    %41 = nn.relu(%40);
+
+    %42 = nn.conv2d(%41, ty=Tensor[(512, 128, 1, 1));
+    %43 = nn.conv2d(%35, ty=Tensor[(512, 256, 1, 1), strides=[2, 2]);
+    %44 = add(%42, ty=Tensor[(512, 1, 1));
+    %45 = add(%43, ty=Tensor[(512, 1, 1));
+
+    %46 = add(%44, %45);
+    %47 = nn.relu(%46);
+    Pattern #2:
+    %26 = nn.conv2d(%25, ty=Tensor[(64, 256, 1, 1));
+    %27 = add(%26, ty=Tensor[(64, 1, 1));
+    %28 = nn.relu(%27);
+
+    %29 = nn.conv2d(%28, ty=Tensor[(64, 64, 3, 3), strides=[2, 2]);
+    %30 = add(%29, ty=Tensor[(64, 1, 1));
+    %31 = nn.relu(%30);
+
+    %32 = nn.conv2d(%31, ty=Tensor[(256, 64, 1, 1));
+    %33 = add(%32, ty=Tensor[(256, 1, 1));
+    %34 = nn.max_pool2d(%25, pool_size=[1, 1], strides=[2, 2], padding=[0, 0, 0, 0]);
+    %35 = add(%33, %34);
+    %36 = nn.relu(%35);
+
+    %37 = nn.conv2d(%36, ty=Tensor[(128, 256, 1, 1));
+    %38 = add(%37, ty=Tensor[(128, 1, 1));
+    %39 = nn.relu(%38);
+
+    %40 = nn.conv2d(%39, ty=Tensor[(128, 128, 3, 3));
+    %41 = add(%40, ty=Tensor[(128, 1, 1));
+    %42 = nn.relu(%41);
+
+    %43 = nn.conv2d(%42, ty=Tensor[(512, 128, 1, 1));
+    %44 = nn.conv2d(%36, ty=Tensor[(512, 256, 1, 1));
+    %45 = add(%43, ty=Tensor[(512, 1, 1));
+    %46 = add(%44, ty=Tensor[(512, 1, 1));
+    %47 = add(%45, %46);
+    %48 = nn.relu(%47);
+    """
+
+    def __init__(self):
+        super(ResNetV1Rewrite, self).__init__()
+        self.attr_lst = []
+        self.data = wildcard()
+        self.w1, self.b1 = wildcard(), wildcard()
+        self.w2, self.b2 = wildcard(), wildcard()
+        self.w3, self.b3 = wildcard(), wildcard()
+        self.w4, self.b4 = wildcard(), wildcard()
+        self.w5, self.b5 = wildcard(), wildcard()
+        self.w6, self.b6 = wildcard(), wildcard()
+        self.w7, self.b7 = wildcard(), wildcard()
+
+        conv1 = is_op("nn.conv2d")(self.data, self.w1).has_attr({"kernel_size": [1, 1]})
+        conv1 = is_op("add")(conv1, self.b1)
+        conv1 = is_op("nn.relu")(conv1)
+
+        conv2 = is_op("nn.conv2d")(conv1, self.w2).has_attr({"kernel_size": [3, 3]})
+        conv2 = is_op("add")(conv2, self.b2)
+        conv2 = is_op("nn.relu")(conv2)
+
+        conv3 = is_op("nn.conv2d")(conv2, self.w3).has_attr({"kernel_size": [1, 1]})
+        conv3 = is_op("add")(conv3, self.b3)
+        conv3 = is_op("add")(conv3, self.data)
+        conv3 = is_op("nn.relu")(conv3)
+
+        left_conv4 = is_op("nn.conv2d")(conv3, self.w4).has_attr({"strides": [2, 2]})
+        left_conv4 = is_op("add")(left_conv4, self.b4)
+        left_conv4 = is_op("nn.relu")(left_conv4)
+
+        left_conv5 = is_op("nn.conv2d")(left_conv4, self.w5).has_attr({"kernel_size": [3, 3]})
+        left_conv5 = is_op("add")(left_conv5, self.b5)
+        left_conv5 = is_op("nn.relu")(left_conv5)
+
+        left_conv6 = is_op("nn.conv2d")(left_conv5, self.w6).has_attr({"kernel_size": [1, 1]})
+        left_conv6 = is_op("add")(left_conv6, self.b6)
+
+        right_conv7 = is_op("nn.conv2d")(conv3, self.w7).has_attr({"strides": [2, 2]})
+        right_conv7 = is_op("add")(right_conv7, self.b7)
+
+        out = is_op("add")(left_conv6, right_conv7)
+        out = is_op("nn.relu")(out)
+        self.pattern = out
+
+    def get_attr(self, pre):
+        """Recursively retrieve attributes from reshape operator."""
+
+        def visit_func(expr):
+            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("nn.conv2d"):
+                self.attr_lst.append(expr.attrs)
+
+        _analysis.post_order_visit(pre, visit_func)
+
+    def callback(self, pre, post, node_map):
+        self.get_attr(pre)
+        data = node_map[self.data][0]
+        w1, b1 = node_map[self.w1][0], node_map[self.b1][0]
+        w2, b2 = node_map[self.w2][0], node_map[self.b2][0]
+        w3, b3 = node_map[self.w3][0], node_map[self.b3][0]
+        w4, b4 = node_map[self.w4][0], node_map[self.b4][0]
+        w5, b5 = node_map[self.w5][0], node_map[self.b5][0]
+        w6, b6 = node_map[self.w6][0], node_map[self.b6][0]
+        w7, b7 = node_map[self.w7][0], node_map[self.b7][0]
+
+        new_attrs = self.attr_lst[-7]
+        conv1 = relay.op.nn.conv2d(data, w1, **new_attrs)
+        conv1 = relay.op.add(conv1, b1)
+        conv1 = relay.op.nn.relu(conv1)
+
+        new_attrs = dict(self.attr_lst[-6])
+        new_attrs["strides"] = [2, 2]
+        conv2 = relay.op.nn.conv2d(conv1, w2, **new_attrs)
+        conv2 = relay.op.add(conv2, b2)
+        conv2 = relay.op.nn.relu(conv2)
+
+        new_attrs = self.attr_lst[-5]
+        conv3 = relay.op.nn.conv2d(conv2, w3, **new_attrs)
+        conv3 = relay.op.add(conv3, b3)
+        max_pool = relay.op.nn.max_pool2d(
+            data, pool_size=(1, 1), strides=(2, 2), layout=new_attrs["data_layout"]
+        )
+        conv3 = relay.op.add(conv3, max_pool)
+        conv3 = relay.op.nn.relu(conv3)
+
+        new_attrs = dict(self.attr_lst[-4])
+        new_attrs["strides"] = [1, 1]
+        left_conv4 = relay.op.nn.conv2d(conv3, w4, **new_attrs)
+        left_conv4 = relay.op.add(left_conv4, b4)
+        left_conv4 = relay.op.nn.relu(left_conv4)
+
+        new_attrs = self.attr_lst[-3]
+        left_conv5 = relay.op.nn.conv2d(left_conv4, w5, **new_attrs)
+        left_conv5 = relay.op.add(left_conv5, b5)
+        left_conv5 = relay.op.nn.relu(left_conv5)
+
+        new_attrs = self.attr_lst[-2]
+        left_conv6 = relay.op.nn.conv2d(left_conv5, w6, **new_attrs)
+        left_conv6 = relay.op.add(left_conv6, b6)
+
+        new_attrs = dict(self.attr_lst[-1])
+        new_attrs["strides"] = [1, 1]
+        right_conv7 = relay.op.nn.conv2d(conv3, w7, **new_attrs)
+        right_conv7 = relay.op.add(right_conv7, b7)
+
+        out = relay.op.add(left_conv6, right_conv7)
+        out = relay.op.nn.relu(out)
+        self.attr_lst = []
+        return out
+
+
+def rewrite_resnetv1(mod):
+    """Rewrite the the ResNetV1 downsize block to reduce the computation complexity."""
+    mod["main"] = rewrite(ResNetV1Rewrite(), mod["main"])
+    return mod
+
+
 class LegalizeQnnOpForDnnl(DFPatternCallback):
     """Legalize QNN based patterns to match DNNL
 
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 2138eda086..078483798c 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -1128,6 +1128,106 @@ def test_rewrite_dense_bias_gelu_reshape_last(run_module, dtype="float32"):
     )
 
 
+def test_resnetv1_rewrite(run_module, dtype="float32"):
+    def get_graph():
+        data_shape = (1, 256, 56, 56)
+        w_shapes = [
+            (64, 256, 1, 1),
+            (64, 64, 3, 3),
+            (256, 64, 1, 1),
+            (128, 256, 1, 1),
+            (128, 128, 3, 3),
+            (512, 128, 1, 1),
+            (512, 256, 1, 1),
+        ]
+        x = relay.var("x", shape=data_shape, dtype=dtype)
+        wights = [relay.const(np.random.randint(0, 1, w).astype(dtype)) for w in w_shapes]
+        biases = [relay.const(np.random.randint(0, 1, w[0]).astype(dtype)) for w in w_shapes]
+
+        conv1 = relay.nn.conv2d(
+            x,
+            wights[0],
+            channels=w_shapes[0][0],
+            kernel_size=w_shapes[0][2:4],
+            padding=(w_shapes[0][2] // 2, w_shapes[0][3] // 2),
+        )
+        conv1 = relay.nn.bias_add(conv1, biases[0])
+        conv1 = relay.nn.relu(conv1)
+
+        conv2 = relay.nn.conv2d(
+            conv1,
+            wights[1],
+            channels=w_shapes[1][0],
+            kernel_size=w_shapes[1][2:4],
+            padding=(w_shapes[1][2] // 2, w_shapes[1][3] // 2),
+        )
+        conv2 = relay.nn.bias_add(conv2, biases[1])
+        conv2 = relay.nn.relu(conv2)
+
+        conv3 = relay.nn.conv2d(
+            conv2,
+            wights[2],
+            channels=w_shapes[2][0],
+            kernel_size=w_shapes[2][2:4],
+            padding=(w_shapes[2][2] // 2, w_shapes[2][3] // 2),
+        )
+        conv3 = relay.nn.bias_add(conv3, biases[2])
+        conv3 = relay.add(conv3, x)
+        conv3 = relay.nn.relu(conv3)
+
+        left_conv4 = relay.nn.conv2d(
+            conv3,
+            wights[3],
+            channels=w_shapes[3][0],
+            strides=(2, 2),
+            kernel_size=w_shapes[3][2:4],
+            padding=(w_shapes[3][2] // 2, w_shapes[3][3] // 2),
+        )
+        left_conv4 = relay.nn.bias_add(left_conv4, biases[3])
+        left_conv4 = relay.nn.relu(left_conv4)
+
+        left_conv5 = relay.nn.conv2d(
+            left_conv4,
+            wights[4],
+            channels=w_shapes[4][0],
+            kernel_size=w_shapes[4][2:4],
+            padding=(w_shapes[4][2] // 2, w_shapes[4][3] // 2),
+        )
+        left_conv5 = relay.nn.bias_add(left_conv5, biases[4])
+        left_conv5 = relay.nn.relu(left_conv5)
+
+        left_conv6 = relay.nn.conv2d(
+            left_conv5,
+            wights[5],
+            channels=w_shapes[5][0],
+            kernel_size=w_shapes[5][2:4],
+            padding=(w_shapes[5][2] // 2, w_shapes[5][3] // 2),
+        )
+        left_conv6 = relay.nn.bias_add(left_conv6, biases[5])
+
+        right_conv7 = relay.nn.conv2d(
+            conv3,
+            wights[6],
+            channels=w_shapes[6][0],
+            strides=(2, 2),
+            kernel_size=w_shapes[6][2:4],
+            padding=(w_shapes[6][2] // 2, w_shapes[6][3] // 2),
+        )
+        right_conv7 = relay.nn.bias_add(right_conv7, biases[6])
+
+        out = relay.add(left_conv6, right_conv7)
+        out = relay.nn.relu(out)
+
+        dic = {"x": data_shape}
+        param_lst = []
+        return out, dic, param_lst
+
+    net, dic, param_lst = get_graph()
+    net = tvm.IRModule.from_expr(net)
+    config = net, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def permute_shape(shape, l_from="", l_to=""):
     res_shape = []
     for label in l_to: