You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ek...@apache.org on 2023/01/18 15:09:49 UTC

[tvm] branch main updated: [microNPU] Add hardware constraints for binary elementwise (#13772)

This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 60358a145b [microNPU] Add hardware constraints for binary elementwise (#13772)
60358a145b is described below

commit 60358a145ba7094d9a41aabfaa25544f58e04dae
Author: Alexey Yazev <11...@users.noreply.github.com>
AuthorDate: Wed Jan 18 19:09:40 2023 +0400

    [microNPU] Add hardware constraints for binary elementwise (#13772)
    
    Do not fuse min and max operations with requantize when the scales differ, as this is not supported on the NPU. Due to hardware constraints, a min or max operation cannot be fused with requantize when the scales differ (see the NPU_SET_OFM_SCALE register description: https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
    min/max operations with matching scales are offloaded to the NPU as ethosu_binary_elementwise.
    min/max operations with different scales are offloaded to the NPU as ethosu_binary_elementwise + ethosu_identity.
---
 python/tvm/relay/op/contrib/ethosu.py             | 80 +++++++++++++++++-----
 tests/python/contrib/test_ethosu/test_codegen.py  | 23 +++++++
 tests/python/contrib/test_ethosu/test_legalize.py | 81 ++++++++++++++++++-----
 3 files changed, 150 insertions(+), 34 deletions(-)

diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index bd9a7d5ba0..5d1e75b030 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -700,15 +700,13 @@ class BinaryElementwiseParams:
         clip = None
         requantize = None
 
-        if is_quantized_operation:
-            if str(current_call.op.name) == "clip":
-                clip = current_call
-                current_call = clip.args[0]
-        else:
-            if str(current_call.op.name) == "qnn.requantize":
-                requantize = current_call
-                clip = current_call.args[0]
-                current_call = clip.args[0]
+        if str(current_call.op.name) == "clip":
+            clip = current_call
+            current_call = clip.args[0]
+        elif str(current_call.op.name) == "qnn.requantize":
+            requantize = current_call
+            clip = current_call.args[0]
+            current_call = clip.args[0]
         binary_op = current_call
 
         layout = "NHWC"
@@ -941,21 +939,40 @@ class MinParams(BinaryElementwiseParams):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MIN with different scales is not supported on NPU
+        # (please look at NPU_SET_OFM_SCALE register description
+        # https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
+# This pattern covers the case where the requantize scales differ, so
+# minimum + clip + qnn.requantize can't be offloaded to the NPU as one operation
+# due to hardware constraints.
+# It is offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
 def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
-    This function creates the pattern for minimum with optional fused RELU activation.
+    This function creates the pattern for minimum with optional fused RELU activation without
+    requantize.
     """
     minimum = is_op("minimum")(wildcard(), wildcard())
     optional_min_clip = is_op("clip")(minimum)
-    optional_min_clip = is_op("qnn.requantize")(
-        optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return minimum | optional_min_clip
 
 
+def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with fused RELU activation with requantize.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class MaxParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -979,21 +996,40 @@ class MaxParams(BinaryElementwiseParams):
             [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
         ):
             return False
+        # MAX with different scales is not supported on NPU
+        # (please look at NPU_SET_OFM_SCALE register description
+        # https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
+        if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
+            return False
         return True
 
 
+# This pattern covers the case where the requantize scales differ, so
+# maximum + clip + qnn.requantize can't be offloaded to the NPU as one operation
+# due to hardware constraints.
+# It is offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
 def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     """
-    This function creates the pattern for maximum with optional fused RELU activation.
+    This function creates the pattern for maximum with optional fused RELU activation without
+    requantize.
     """
     maximum = is_op("maximum")(wildcard(), wildcard())
     optional_max_clip = is_op("clip")(maximum)
-    optional_max_clip = is_op("qnn.requantize")(
-        optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
-    )
     return maximum | optional_max_clip
 
 
+def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with fused RELU activation with requantize.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = is_op("clip")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern, is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    return pattern
+
+
 class ShlParams(BinaryElementwiseParams):
     """
     This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1913,11 +1949,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_mul_pattern(),
             lambda pat: MulParams(pat).is_valid(),
         ),
+        (
+            MinParams.composite_name,
+            minimum_clip_requantize_pattern(),
+            lambda pat: MinParams(pat).is_valid(),
+        ),
         (
             MinParams.composite_name,
             minimum_pattern(),
             lambda pat: MinParams(pat).is_valid(),
         ),
+        (
+            MaxParams.composite_name,
+            maximum_clip_requantize_pattern(),
+            lambda pat: MaxParams(pat).is_valid(),
+        ),
         (
             MaxParams.composite_name,
             maximum_pattern(),
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index dc54ef071d..05ba7467b3 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1191,6 +1191,29 @@ def test_tflite_relu6():
     )
 
 
+# Specific case where the operation cannot be offloaded to the NPU as a single binary elementwise operation:
+# min and max operations cannot be fused with requantize when the scales differ, as the NPU does not support it.
+@pytest.mark.parametrize("operation", [tf.math.minimum, tf.math.maximum])
+def test_tflite_min_max_relu_n1_to_1(operation):
+    np.random.seed(0)
+    accel_type = "ethos-u55-128"
+    ifm_shape = (1, 12, 16, 8)
+
+    @tf.function
+    def min_max_relu_n1_to_1(lhs, rhs):
+        op = operation(lhs, rhs)
+        # This specific pattern is replaced with RELU_N1_TO_1 by TFLite.
+        return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))
+
+    infra.compare_tvm_with_tflite(
+        min_max_relu_n1_to_1,
+        [ifm_shape, ifm_shape],
+        accel_type,
+        enable_cascader=True,
+        ranges=[(-1, 1), (0, 2)],
+    )
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
 @pytest.mark.parametrize("ofm_channels", [32, 64])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 5ddc7565f2..5bc31dacb5 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -53,6 +53,13 @@ def partition_ethosu_by_table(mod, pattern_table):
     return mod
 
 
+def relu_n1_to_1(x):
+    """
+    This specific pattern is replaced with RELU_N1_TO_1 by TFLite.
+    """
+    return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))
+
+
 def test_split_indices_legalize():
     def create_graph(axis):
         x = relay.var("x", shape=(1, 50, 50, 3))
@@ -881,7 +888,7 @@ def test_tflite_pool2d_legalize(
         ([1, 4, 4], [4, 1], False),
     ],
 )
-@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+@pytest.mark.parametrize("activation_function", [None, tf.nn.relu])
 def test_tflite_binary_elemwise_legalize(
     operator_type,
     ifm_shape,
@@ -906,8 +913,8 @@ def test_tflite_binary_elemwise_legalize(
                     op = tf.math.minimum(x, y)
                 elif operator_type == "MAX":
                     op = tf.math.maximum(x, y)
-                if activation_function == "RELU":
-                    op = tf.nn.relu(op)
+                if activation_function:
+                    op = activation_function(op)
                 return op
 
         model = Model()
@@ -938,9 +945,13 @@ def test_tflite_binary_elemwise_legalize(
         op = ext_func.body
 
         has_reshaped_output = False
+        has_separate_requantize = False
         shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
         out_padded = [1] * (4 - len(out_shape)) + out_shape
-        if op.op.name != "contrib.ethosu.binary_elementwise":
+        if op.op.name == "contrib.ethosu.identity":
+            op = op.args[0]
+            has_separate_requantize = True
+        if op.op.name == "reshape":
             has_reshaped_output = True
             op = op.args[0]
 
@@ -951,20 +962,30 @@ def test_tflite_binary_elemwise_legalize(
         assert op.checked_type.dtype == dtype
         assert op.attrs.operator_type == operator_type
         assert op.attrs.reversed_operands == reversed_operands
-        if activation_function == "RELU":
+        if activation_function != None:
             assert str(op.attrs.activation) == "CLIP"
 
             if operator_type in ["MIN", "MAX"]:
-                # MIN and MAX with an activation must have a requantize operation
-                # baked into the output. To check the extra requantize node was
-                # picked up by the pattern, we can make sure the quantization
-                # information is not default.
-                assert float(op.attrs.ifm_scale) != 1.0
-                assert int(op.attrs.ifm_zero_point) != 0
-                assert float(op.attrs.ifm2_scale) != 1.0
-                assert int(op.attrs.ifm2_zero_point) != 0
-                assert float(op.attrs.ofm_scale) != 1.0
-                assert int(op.attrs.ofm_zero_point) != 0
+                if has_separate_requantize:
+                    # When requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints,
+                    # the quantization values should be default since requantize is a separate operation.
+                    assert float(op.attrs.ifm_scale) == 1.0
+                    assert int(op.attrs.ifm_zero_point) == 0
+                    assert float(op.attrs.ifm2_scale) == 1.0
+                    assert int(op.attrs.ifm2_zero_point) == 0
+                    assert float(op.attrs.ofm_scale) == 1.0
+                    assert int(op.attrs.ofm_zero_point) == 0
+                else:
+                    # MIN and MAX with an activation must have a requantize operation
+                    # baked into the output. To check the extra requantize node was
+                    # picked up by the pattern, we can make sure the quantization
+                    # information is not default.
+                    assert float(op.attrs.ifm_scale) != 1.0
+                    assert int(op.attrs.ifm_zero_point) != 0
+                    assert float(op.attrs.ifm2_scale) != 1.0
+                    assert int(op.attrs.ifm2_zero_point) != 0
+                    assert float(op.attrs.ofm_scale) != 1.0
+                    assert int(op.attrs.ofm_zero_point) != 0
 
         if has_reshaped_output:
             assert list(ext_func.body.checked_type.shape) == out_shape
@@ -997,22 +1018,42 @@ def test_tflite_binary_elemwise_legalize(
             ),
         ]
     elif operator_type == "MIN":
-        rewriter = legalize.MinRewriter()
+        rewriter = [legalize.MinRewriter(), legalize.RequantizeRewriter()]
         pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_clip_requantize_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
             (
                 ethosu.MinParams.composite_name,
                 ethosu.minimum_pattern(),
                 lambda pat: ethosu.MinParams(pat).is_valid(),
             ),
+            (
+                ethosu.RequantizeParams.composite_name,
+                ethosu.requantize_pattern(),
+                lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+            ),
         ]
     elif operator_type == "MAX":
-        rewriter = legalize.MaxRewriter()
+        rewriter = [legalize.MaxRewriter(), legalize.RequantizeRewriter()]
         pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_clip_requantize_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
             (
                 ethosu.MaxParams.composite_name,
                 ethosu.maximum_pattern(),
                 lambda pat: ethosu.MaxParams(pat).is_valid(),
             ),
+            (
+                ethosu.RequantizeParams.composite_name,
+                ethosu.requantize_pattern(),
+                lambda pat: ethosu.RequantizeParams(pat).is_valid(),
+            ),
         ]
 
     tflite_graph = create_tflite_graph()
@@ -1031,6 +1072,12 @@ def test_tflite_binary_elemwise_legalize(
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+# This test is for checking the case when requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints.
+def test_tflite_max_relu_n1_to_1_legalize():
+    ifm_shape = [1, 4, 8, 16]
+    test_tflite_binary_elemwise_legalize("MAX", ifm_shape, ifm_shape, False, relu_n1_to_1)
+
+
 def test_binary_add_from_constant_scalar():
     dtype = "uint8"
     ifm_shape = (1, 4, 4, 8)