Posted to commits@tvm.apache.org by "sergio-grovety (via GitHub)" <gi...@apache.org> on 2023/05/04 11:29:00 UTC

[GitHub] [tvm] sergio-grovety opened a new pull request, #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

sergio-grovety opened a new pull request, #14765:
URL: https://github.com/apache/tvm/pull/14765

   A separate channel-dimension nn.pad Relay operator is rewritten as a Relay concatenate operation.
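
The idea behind the rewrite can be illustrated without TVM. The sketch below is not the code from this pull request: it uses plain nested lists in NHWC order, and the helper names (`pad_channels`, `full`, `concat_channels`, `pad_via_concat`) are hypothetical. It only shows that padding the channel axis with a constant gives the same result as concatenating constant blocks before and after the input along axis 3.

```python
# Hypothetical, dependency-free illustration; not the code from the pull request.
def pad_channels(tensor, before, after, fill):
    """Pad the last (channel) axis of a nested-list NHWC tensor with `fill`."""
    return [
        [[[fill] * before + list(px) + [fill] * after for px in row] for row in img]
        for img in tensor
    ]


def full(n, h, w, c, fill):
    """Build a constant NHWC block of shape (n, h, w, c)."""
    return [[[[fill] * c for _ in range(w)] for _ in range(h)] for _ in range(n)]


def concat_channels(blocks):
    """Concatenate NHWC blocks along the channel axis (axis 3)."""
    out = []
    for imgs in zip(*blocks):  # walk the batch axis of all blocks together
        out_img = []
        for rows in zip(*imgs):  # walk the height axis
            out_row = []
            for pixels in zip(*rows):  # walk the width axis
                merged = []
                for px in pixels:  # join the channel vectors in order
                    merged.extend(px)
                out_row.append(merged)
            out_img.append(out_row)
        out.append(out_img)
    return out


def pad_via_concat(tensor, before, after, fill):
    """Express a channel pad as a concatenation of constant blocks and the input."""
    n, h, w = len(tensor), len(tensor[0]), len(tensor[0][0])
    blocks = []
    if before > 0:
        blocks.append(full(n, h, w, before, fill))
    blocks.append(tensor)
    if after > 0:
        blocks.append(full(n, h, w, after, fill))
    return concat_channels(blocks)


ifm = [[[[1, 2, 3], [4, 5, 6]]]]  # shape (1, 1, 2, 3)
assert pad_via_concat(ifm, 1, 2, 0) == pad_channels(ifm, 1, 2, 0)
```

In a quantized graph the fill value would be the tensor's zero point rather than a literal 0, which is why `fill` is a parameter here.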


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] ekalda commented on pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "ekalda (via GitHub)" <gi...@apache.org>.
ekalda commented on PR #14765:
URL: https://github.com/apache/tvm/pull/14765#issuecomment-1554746842

   Thanks @arina-grovety, @sergio-grovety and @Aleksei-grovety, this is now merged!




[GitHub] [tvm] tvm-bot commented on pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14765:
URL: https://github.com/apache/tvm/pull/14765#issuecomment-1534604259

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * cc @Mousius, @leandron, @lhutton1 <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1196352564


##########
python/tvm/relay/backend/contrib/ethosu/legalize.py:
##########
@@ -1447,6 +1447,84 @@ def callback(
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert ethos-u.pad2d composite function to the Relay concatenate operation"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ChannelPadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        concat_args = list()
+        # Activations requiring LUT is currently not supported, so setting it to an empty list

Review Comment:
   Got it, thank you.





[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1185816248


##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to a ethosu.pad2d composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.channel-pad"
+    # The ethos-u.channel-pad composite function will be transformed
+    # to the Relay concatenate operation.
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
+
+        # there is no 'layout' attribute in nn.pad
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            tensor=func_body.args[QPadArgs.IFM.value],
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+        self.ch_padding = self.extract_ch_padding(func_body)
+        self.ofm = TensorParams(
+            tensor=func_body,
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def extract_ch_padding(
+        padding: relay.Call,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Here we check whether a separate channel-dimension padding operation can be
+        rewritten as Relay concatenate operation. If the padding specified by the
+        separate nn.pad operation is not supported by NPU, None will be returned.
+        This will cause the nn.pad not to be offloaded to NPU.
+        """
+        pad_width = padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if (
+            list(pad_width[0]) != [0, 0]
+            or list(pad_width[1]) != [0, 0]

Review Comment:
   Yes, you are right: combined spatial and channel padding can of course occur in neural networks. Supporting it is a separate task, which we discussed and expect to be useful in the future. We plan to address it when we have time or when a network with such a pad appears.





[GitHub] [tvm] sergio-grovety commented on pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "sergio-grovety (via GitHub)" <gi...@apache.org>.
sergio-grovety commented on PR #14765:
URL: https://github.com/apache/tvm/pull/14765#issuecomment-1554169757

   @tvm-bot  rerun




[GitHub] [tvm] ekalda commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "ekalda (via GitHub)" <gi...@apache.org>.
ekalda commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1195229954


##########
python/tvm/relay/backend/contrib/ethosu/legalize.py:
##########
@@ -1447,6 +1447,84 @@ def callback(
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert ethos-u.pad2d composite function to the Relay concatenate operation"""

Review Comment:
   ```suggestion
       """Convert ethos-u.channel-pad composite function to the Relay concatenate operation"""
   ```



##########
python/tvm/relay/backend/contrib/ethosu/legalize.py:
##########
@@ -1447,6 +1447,84 @@ def callback(
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert ethos-u.pad2d composite function to the Relay concatenate operation"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ChannelPadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        concat_args = list()
+        # Activations requiring LUT is currently not supported, so setting it to an empty list

Review Comment:
   I know every operator here has this copy-pasted legacy comment, but let's remove it. Firstly, LUT-based activations are supported, and secondly, it could leave the impression that implementing something like a fused pad + sigmoid is a TODO.



##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -462,6 +463,118 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the Graph recursively and checks if it contains 'nn.pad' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks if 'nn.pad' op is ongraph

Review Comment:
   Nit:
   ```suggestion
               This function recursively visits the graph and checks if 'nn.pad' op is on graph
   ```



##########
python/tvm/relay/backend/contrib/ethosu/legalize.py:
##########
@@ -1447,6 +1447,84 @@ def callback(
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert ethos-u.pad2d composite function to the Relay concatenate operation"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ChannelPadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        concat_args = list()
+        # Activations requiring LUT is currently not supported, so setting it to an empty list
+        lut = relay.const([], dtype="int8")
+        # pad channels before
+        if params.ch_padding[0] > 0:
+            shape1 = list(params.ifm.shape)
+            shape1[3] = params.ch_padding[0].value
+            pad_channels = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape1,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity1 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity1)
+
+        identity2 = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+        )
+        concat_args.append(identity2)
+
+        # pad channels after
+        if params.ch_padding[1] > 0:
+            shape3 = list(params.ifm.shape)
+            shape3[3] = params.ch_padding[1].value
+            pad_channels3 = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape3,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity3 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels3,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity3)
+
+        axis = 3
+        return relay.op.concatenate(relay.Tuple(concat_args), axis=axis)

Review Comment:
   Since it is not used elsewhere, maybe just 
   ```suggestion
           return relay.op.concatenate(relay.Tuple(concat_args), axis=3)
   ```



##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -760,7 +873,98 @@ def verify(ext_func):
             ethosu.PadParams.composite_name,
             ethosu.pad_pattern(),
             lambda pat: ethosu.PadParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, pad_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.PadRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize("channel_padding", [(0, 1), (1, 1), (5, 2)])
+@pytest.mark.parametrize("const_value", [0, 5, 125, -5])
+def test_tflite_separate_channel_padding_legalize(ifm_shape, channel_padding, const_value):

Review Comment:
   Shouldn't this test be using `ChannelPadRewriter` and then check in `verify` that the concatenates got created? 



##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -462,6 +463,118 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the Graph recursively and checks if it contains 'nn.pad' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks if 'nn.pad' op is ongraph
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [0, 0], [0, 0], [padding_ch[0], padding_ch[1]]],
+                    "CONSTANT",
+                )
+                # HWIO
+                weight_shape = [
+                    kernel_shape[0],
+                    kernel_shape[1],
+                    ifm_shape[3] + padding_ch[0] + padding_ch[1],
+                    3,
+                ]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+
+        assert ArePadOnGraph().are_pad_on_graph(ext_func.body) == True
+
+    conv2d_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])

Review Comment:
   I'm a bit confused about what that test does... It creates a TFLite graph with channel pad and conv2d, then partitions them for microNPU, then legalizes the conv2d into `ethosu.conv2d` and then checks that the Relay pad is still there? 





[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1196374217


##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -462,6 +463,118 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the Graph recursively and checks if it contains 'nn.pad' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks if 'nn.pad' op is ongraph
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [0, 0], [0, 0], [padding_ch[0], padding_ch[1]]],
+                    "CONSTANT",
+                )
+                # HWIO
+                weight_shape = [
+                    kernel_shape[0],
+                    kernel_shape[1],
+                    ifm_shape[3] + padding_ch[0] + padding_ch[1],
+                    3,
+                ]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+
+        assert ArePadOnGraph().are_pad_on_graph(ext_func.body) == True
+
+    conv2d_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])

Review Comment:
   Hi @ekalda, thank you for the review! 
   Yes, that is correct. Here we check that the channel pad is not merged into the conv2d, as happens with a spatial pad.





[GitHub] [tvm] Aleksei-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "Aleksei-grovety (via GitHub)" <gi...@apache.org>.
Aleksei-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1184956051


##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to a ethosu.pad2d composite function

Review Comment:
   This class will parse a call to an ethos-u.channel-pad composite function



##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to a ethosu.pad2d composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.channel-pad"
+    # The ethos-u.channel-pad composite function will be transformed
+    # to the Relay concatenate operation.
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
+
+        # there is no 'layout' attribute in nn.pad
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            tensor=func_body.args[QPadArgs.IFM.value],
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+        self.ch_padding = self.extract_ch_padding(func_body)
+        self.ofm = TensorParams(
+            tensor=func_body,
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def extract_ch_padding(
+        padding: relay.Call,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Here we check whether a separate channel-dimension padding operation can be
+        rewritten as Relay concatenate operation. If the padding specified by the
+        separate nn.pad operation is not supported by NPU, None will be returned.
+        This will cause the nn.pad not to be offloaded to NPU.
+        """
+        pad_width = padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if (
+            list(pad_width[0]) != [0, 0]
+            or list(pad_width[1]) != [0, 0]

Review Comment:
   Are there networks that pad height, width and channels together? If so, the width and height restrictions could be removed, and width and height padding could be handled in the legalization with a depthwise convolution, as is done for pad2d.
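
For context, the restriction under discussion can be sketched roughly as follows. This is a simplified stand-in that works on plain Python lists, not the `extract_ch_padding` method from the diff (which reads `pad_width` from Relay call attributes); the function name `channel_only_padding` is hypothetical, and the sketch assumes an NHWC layout in which only the last axis may carry padding.

```python
from typing import Optional, Sequence, Tuple


def channel_only_padding(
    pad_width: Sequence[Sequence[int]],
) -> Optional[Tuple[int, int]]:
    """Return (before, after) channel padding, or None if the pad is not channel-only."""
    # Only 4-D (NHWC) pads are considered.
    if len(pad_width) != 4:
        return None
    # Batch, height and width (axes 0 to 2) must not be padded.
    if any(list(p) != [0, 0] for p in pad_width[:3]):
        return None
    before, after = pad_width[3]
    return (before, after)


print(channel_only_padding([[0, 0], [0, 0], [0, 0], [1, 1]]))  # (1, 1)
print(channel_only_padding([[0, 0], [1, 1], [0, 0], [1, 1]]))  # None
```

Lifting the height and width restriction would mean returning the spatial padding as well and legalizing it separately, which is the follow-up the question above points to.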





[GitHub] [tvm] sergio-grovety commented on pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "sergio-grovety (via GitHub)" <gi...@apache.org>.
sergio-grovety commented on PR #14765:
URL: https://github.com/apache/tvm/pull/14765#issuecomment-1534607348

   cc @neildhickey, @ekalda, @ilyag-grovety, @Alex-grovety @arina-grovety




[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1197851648


##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -760,7 +873,98 @@ def verify(ext_func):
             ethosu.PadParams.composite_name,
             ethosu.pad_pattern(),
             lambda pat: ethosu.PadParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, pad_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.PadRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
+@pytest.mark.parametrize("channel_padding", [(0, 1), (1, 1), (5, 2)])
+@pytest.mark.parametrize("const_value", [0, 5, 125, -5])
+def test_tflite_separate_channel_padding_legalize(ifm_shape, channel_padding, const_value):

Review Comment:
   Thanks for the comment; the test indeed did not match the functionality under test. I have corrected it.





[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1185816248


##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to a ethosu.pad2d composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.channel-pad"
+    # The ethos-u.channel-pad composite function will be transformed
+    # to the Relay concatenate operation.
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
+
+        # there is no 'layout' attribute in nn.pad
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            tensor=func_body.args[QPadArgs.IFM.value],
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+        self.ch_padding = self.extract_ch_padding(func_body)
+        self.ofm = TensorParams(
+            tensor=func_body,
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def extract_ch_padding(
+        padding: relay.Call,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Here we check whether a separate channel-dimension padding operation can be
+        rewritten as Relay concatenate operation. If the padding specified by the
+        separate nn.pad operation is not supported by NPU, None will be returned.
+        This will cause the nn.pad not to be offloaded to NPU.
+        """
+        pad_width = padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if (
+            list(pad_width[0]) != [0, 0]
+            or list(pad_width[1]) != [0, 0]

Review Comment:
   Yes, you are right: combined spatial and channel padding can of course occur in neural networks. Supporting it is a separate task in our backlog. We plan to address it when we have time or when a network with such a pad appears.





[GitHub] [tvm] arina-grovety commented on a diff in pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "arina-grovety (via GitHub)" <gi...@apache.org>.
arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1185813101


##########
python/tvm/relay/op/contrib/ethosu.py:
##########
@@ -2000,6 +2000,79 @@ def is_valid(self):
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to a ethosu.pad2d composite function

Review Comment:
   Hi @Aleksei-grovety, thank you, done.





[GitHub] [tvm] sergio-grovety commented on pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "sergio-grovety (via GitHub)" <gi...@apache.org>.
sergio-grovety commented on PR #14765:
URL: https://github.com/apache/tvm/pull/14765#issuecomment-1537076322

   @tvm-bot rerun




[GitHub] [tvm] ekalda merged pull request #14765: [microNPU][ETHOSU] Channel pad offloaded to NPU

Posted by "ekalda (via GitHub)" <gi...@apache.org>.
ekalda merged PR #14765:
URL: https://github.com/apache/tvm/pull/14765

