You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by gi...@apache.org on 2022/06/06 14:10:30 UTC

[tvm] branch main updated: [microNPU] Optimize separate padding operation for conv2d (#11468)

This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1aac4d6826 [microNPU] Optimize separate padding operation for conv2d (#11468)
1aac4d6826 is described below

commit 1aac4d6826192383a755369ab5ccfe4876e8902b
Author: Luke Hutton <lu...@arm.com>
AuthorDate: Mon Jun 6 15:10:22 2022 +0100

    [microNPU] Optimize separate padding operation for conv2d (#11468)
    
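    Optimizes the case where padding appears as a separate nn.pad operation followed by a qnn.conv2d. Where possible, the nn.pad is partitioned and offloaded together with the qnn.conv2d, with its padding folded into the padding attribute of a single NPU convolution, rather than being offloaded on its own. If the nn.pad cannot be merged (for example, its pad value differs from the convolution's IFM zero point, or it pads the batch or channel axes), the two operations are considered separately as a fallback.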
    
    cc Mousius NicolaLancellotti ekalda manupa-arm
---
 python/tvm/relay/op/contrib/ethosu.py             |  66 ++++++-
 tests/python/contrib/test_ethosu/infra.py         |  11 +-
 tests/python/contrib/test_ethosu/test_codegen.py  |  68 ++++++-
 tests/python/contrib/test_ethosu/test_legalize.py | 216 ++++++++++++++++++++++
 4 files changed, 349 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index dfdc0c82fb..806bf6dce2 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -201,6 +201,8 @@ class QnnConv2DParams:
         from tvm.relay.backend.contrib.ethosu.util import RequantArgs
 
         activation = None
+        separate_padding = None
+
         if str(func_body.op) in self.activation_map.keys():
             activation = func_body
             requantize_op = activation.args[0]
@@ -208,8 +210,11 @@ class QnnConv2DParams:
             requantize_op = func_body
         bias_add = requantize_op.args[0]
         qnn_conv2d = bias_add.args[0]
+        if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad":
+            separate_padding = qnn_conv2d.args[0]
         data_layout = qnn_conv2d.attrs.data_layout
         self.kernel_layout = qnn_conv2d.attrs.kernel_layout
+
         # We consider the weights & biases as params as it should be a Constant
         self.weights = TensorParams(
             qnn_conv2d.args[QConv2DArgs.WEIGHTS.value],
@@ -224,8 +229,11 @@ class QnnConv2DParams:
             requantize_op.args[RequantArgs.IFM_SCALE.value],
             requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
         )
+        ifm_tensor = (
+            separate_padding.args[0] if separate_padding else qnn_conv2d.args[QConv2DArgs.IFM.value]
+        )
         self.ifm = TensorParams(
-            qnn_conv2d.args[QConv2DArgs.IFM.value],
+            ifm_tensor,
             data_layout,
             qnn_conv2d.args[QConv2DArgs.IFM_SCALE.value],
             qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value],
@@ -237,7 +245,10 @@ class QnnConv2DParams:
             requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
         )
         attrs = qnn_conv2d.attrs
-        self.padding = attrs.padding
+
+        pad_value = int(qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value].data.asnumpy())
+        self.padding = self.extract_padding(attrs.padding, separate_padding, pad_value)
+
         self.strides = attrs.strides
         self.dilation = attrs.dilation
         self.activation = activation
@@ -250,6 +261,37 @@ class QnnConv2DParams:
         if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]:
             self.is_depthwise = True
 
+    @staticmethod
+    def extract_padding(
+        operator_padding: Tuple[int, int, int, int],
+        separate_padding: relay.Call,
+        pad_value: int,
+    ) -> Optional[Tuple[int, int, int, int]]:
+        """
+        Convolution operations can sometimes have padding represented as a separate
+        padding operation before the convolution operation itself. Here we can check
+        whether these representations can be combined into a single padding attribute
+        as part of the NPU convolution itself. If the padding specified by the separate
+        nn.pad operation is not supported, None will be returned. This will cause the
+        nn.pad to be offloaded separately.
+        """
+        if separate_padding is None:
+            return operator_padding
+        if pad_value != int(separate_padding.args[1].data.asnumpy()):
+            return None
+        pad_width = separate_padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]:
+            return None
+        top, left, bottom, right = operator_padding
+        return [
+            top + pad_width[1][0],
+            left + pad_width[2][0],
+            bottom + pad_width[1][1],
+            right + pad_width[2][1],
+        ]
+
     def is_valid(self) -> bool:
         """
         This function checks whether QnnConv2D has compatible attributes with the NPU
@@ -267,7 +309,7 @@ class QnnConv2DParams:
             return False
         if not check_dilation(self.dilation):
             return False
-        if not check_padding(self.padding, self.padding_bounds):
+        if not self.padding or not check_padding(self.padding, self.padding_bounds):
             return False
         legal_groups = [1, self.ofm.shape[3]]
         if self.groups not in legal_groups:
@@ -437,7 +479,7 @@ class QnnDepthwiseConv2DParams(QnnConv2DParams):
             return False
         if not check_dilation(self.dilation):
             return False
-        if not check_padding(self.padding, self.padding_bounds):
+        if not self.padding or not check_padding(self.padding, self.padding_bounds):
             return False
         if self.weights.layout != "HWOI":
             return False
@@ -453,8 +495,14 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     This function creates the pattern for qnn.conv2D with optional fused RELU activation.
     """
+    optional_pad = is_op("nn.pad")(wildcard(), is_constant())
     qnn_conv2d = is_op("qnn.conv2d")(
-        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        optional_pad | wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
     ).has_attr({"kernel_layout": "HWIO"})
     bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
     req = is_op("qnn.requantize")(
@@ -468,8 +516,14 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
     This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation.
     """
+    optional_pad = is_op("nn.pad")(wildcard(), is_constant())
     qnn_conv2d = is_op("qnn.conv2d")(
-        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        optional_pad | wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
     ).has_attr({"kernel_layout": "HWOI"})
     bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
     req = is_op("qnn.requantize")(
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index a1bdcb47e6..1f999781e3 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -473,10 +473,17 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]
     assert len(strides) == 2
     assert len(dilation) == 2
     assert len(kernel_shape) == 2
-    if padding.lower() == "valid":
+    if isinstance(padding, tuple):
+        h = (
+            ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0] + padding[0] + padding[2]
+        ) // strides[0]
+        w = (
+            ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1] + padding[1] + padding[3]
+        ) // strides[1]
+    elif padding.lower() == "valid":
         h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0])
         w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1])
-    if padding.lower() == "same":
+    elif padding.lower() == "same":
         h = math.ceil(ifm_shape[1] / strides[0])
         w = math.ceil(ifm_shape[2] / strides[1])
     ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]]
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index b73ebd5361..2d3489889e 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -72,13 +72,43 @@ def test_ethosu_conv2d_single(
             padding=padding,
             dilations=dilation,
         )
-        if activation:
+        if activation == "RELU":
             op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)
 
 
+def test_tflite_conv2d_with_separate_pad():
+    np.random.seed(0)
+
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding = (0, 0, 1, 1)
+
+    @tf.function
+    def conv2d(x):
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.pad(
+            x,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        return tf.nn.conv2d(
+            op,
+            weight,
+            strides=tf_strides,
+            padding="VALID",
+            dilations=dilation,
+        )
+
+    infra.compare_tvm_with_tflite(conv2d, [ifm_shape], "ethos-u55-256")
+
+
 @pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 2), (1, 27, 42, 3)])
 @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
 @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
@@ -120,7 +150,7 @@ def test_ethosu_conv2d_double(
             padding=padding,
             dilations=dilation,
         )
-        if activation:
+        if activation == "RELU":
             op2 = tf.nn.relu(op2)
         return op2
 
@@ -156,7 +186,7 @@ def test_out_of_range_scaling(weight_min, weight_max):
             padding=padding,
             dilations=dilation,
         )
-        if activation:
+        if activation == "RELU":
             op = tf.nn.relu(op)
         return op
 
@@ -191,13 +221,43 @@ def test_tflite_depthwise_conv2d(
         op = tf.nn.depthwise_conv2d(
             x, weight, strides=tf_strides, padding=padding, dilations=dilation
         )
-        if activation_function:
+        if activation_function == "RELU":
             op = tf.nn.relu(op)
         return op
 
     infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
 
 
+def test_tflite_depthwise_conv2d_with_separate_pad():
+    np.random.seed(0)
+
+    ifm_shape = (1, 23, 32, 7)
+    kernel_shape = (1, 2)
+    strides = (3, 2)
+    dilation = (1, 1)
+    padding = (0, 0, 1, 1)
+
+    @tf.function
+    def depthwise_conv2d(x):
+        tf_strides = [1, strides[0], strides[1], 1]
+        op = tf.pad(
+            x,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        return tf.nn.depthwise_conv2d(
+            op,
+            weight,
+            strides=tf_strides,
+            padding="VALID",
+            dilations=dilation,
+        )
+
+    infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256")
+
+
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 2dd5eff913..3f8b5f7d5b 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -347,6 +347,114 @@ def test_tflite_conv2d_legalize(ifm_shape, kernel_shape, padding, strides, dilat
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding = (0, 0, 1, 1)
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+                    "CONSTANT",
+                )
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 3
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        assert list(op.attrs.padding) == list(padding)
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
 @pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -458,6 +566,114 @@ def test_tflite_depthwise_conv_2d_legalize(
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_depthwise_conv2d_with_separate_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 23, 32, 7)
+    kernel_shape = (1, 2)
+    strides = (3, 2)
+    dilation = (1, 1)
+    padding = (0, 0, 1, 1)
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+                    "CONSTANT",
+                )
+                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.depthwise_conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        assert list(op.attrs.padding) == list(padding)
+        assert op.attrs.ofm_channels == ofm_channels
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+
+    depthwise_pattern_table = [
+        (
+            ethosu.QnnDepthwiseConv2DParams.composite_name,
+            ethosu.qnn_depthwise_conv2d_pattern(),
+            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, depthwise_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(