Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/08 11:25:50 UTC

[tvm] branch main updated: [microNPU] Add support for SIGMOID (#9627)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3371a76  [microNPU] Add support for SIGMOID (#9627)
3371a76 is described below

commit 3371a76400a2ae55b4bb59cf023e013b0b6d919f
Author: Elen Kalda <el...@arm.com>
AuthorDate: Wed Dec 8 11:25:26 2021 +0000

    [microNPU] Add support for SIGMOID (#9627)
    
    Add support for SIGMOID activation function using the lookup
    table mechanism in the NPU.
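    
    As a quick illustration of the mechanism, the sketch below shows how
    such a 256-entry table is built, mirroring the new get_lut_from_func
    and sigmoid_calc_func helpers. The quantization parameters are
    illustrative only, and plain round() stands in for the codegen's
    round-away-from-zero helper:
    
        import math
        import numpy as np
    
        def sigmoid(x):
            # Saturate outside [-8, 8], matching the TFLite reference kernel
            if x <= -8.0:
                return 0.0
            if x >= 8.0:
                return 1.0
            return 1.0 / (1.0 + math.exp(-x))
    
        def build_lut(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
            qmin, qmax = np.iinfo(np.int8).min, np.iinfo(np.int8).max
            lut = []
            for x in range(qmin, qmax + 1):
                # Dequantize the input, apply the activation, requantize
                out_real = func(ifm_scale * (x - ifm_zp))
                lut.append(min(qmax, max(qmin, round(ofm_zp + out_real / ofm_scale))))
            return lut
    
        # Illustrative scales/zero points, not taken from this commit
        lut = build_lut(0.1, 0, 1.0 / 256, -128, sigmoid)
        assert len(lut) == 256 and lut[0] == -128 and lut[255] == 127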
---
 python/tvm/relay/backend/contrib/ethosu/codegen.py |  2 +-
 .../tvm/relay/backend/contrib/ethosu/legalize.py   | 84 ++++++++++++++++++----
 .../relay/backend/contrib/ethosu/te/convolution.py |  6 +-
 .../relay/backend/contrib/ethosu/te/depthwise.py   |  6 +-
 .../relay/backend/contrib/ethosu/te/identity.py    |  6 +-
 .../tvm/relay/backend/contrib/ethosu/te/pooling.py |  6 +-
 python/tvm/relay/op/contrib/ethosu.py              | 37 ++++++++--
 tests/python/contrib/test_ethosu/test_codegen.py   | 74 +++++--------------
 tests/python/contrib/test_ethosu/test_legalize.py  | 53 ++++++++++++++
 .../contrib/test_ethosu/test_lookup_table.py       |  2 +-
 .../contrib/test_ethosu/test_lut_optimizer.py      |  4 +-
 .../contrib/test_ethosu/test_replace_conv2d.py     |  2 +-
 .../test_ethosu/test_replace_depthwise_conv2d.py   |  2 +-
 13 files changed, 194 insertions(+), 90 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 002cb4b6..22f248b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -89,7 +89,7 @@ class OptimizeLUTs(ExprMutator):
             not refer to an Op. Else, a new call node with a new operator.
         """
         new_call = call
-        lut_activations = ["TANH", "LUT"]
+        lut_activations = ["TANH", "LUT", "SIGMOID"]
 
         if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
             producer_op = call.args[0]
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index f8beb7f..b2264f3 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
 """A set of passes to legalize some of operations for the NPU"""
-from typing import List, Type
+from typing import List, Type, Callable
 import math
 
 import numpy as np  # type: ignore
@@ -125,15 +125,17 @@ class LegalizeSplit:
         pass
 
 
-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(
+    ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]
+) -> List[int]:
+    """Method to calculate the values of the lookup table based on the calculation function"""
     lut_values = list()
     # Only int8 is currently supported
     dtype = np.int8
     qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
     for x in range(qmin, qmax + 1):
         x_real = ifm_scale * (x - ifm_zp)
-        out_real = math.tanh(x_real)
+        out_real = func(x_real)
         lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
         lut_result = min(qmax, max(qmin, lut_result))
         lut_values.append(lut_result)
@@ -141,16 +143,18 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
     return lut_values
 
 
-class TanhRewriter(DFPatternCallback):
-    """This pass adds tanh as a LUT to the identity operator"""
+class LutActivationRewriter(DFPatternCallback):
+    """A class to create an identity operator with the LUT"""
 
-    def __init__(self):
+    def __init__(
+        self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]
+    ):
         super().__init__(require_type=True, rewrite_once=True)
-        self.pattern = (
-            wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
-        )(wildcard())
+        self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
+        self.activation_type = activation_type
+        self.calc_func = calc_func
 
-    def callback(self, pre, post, node_map):
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
         id_input = post.args[0]
 
         quantize_args = post.op.body.args
@@ -161,7 +165,9 @@ class TanhRewriter(DFPatternCallback):
         input_scale = float(dequantize_args[1].data.asnumpy())
         input_zp = int(dequantize_args[2].data.asnumpy())
 
-        lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
+        lut_values = get_lut_from_func(
+            input_scale, input_zp, output_scale, output_zp, self.calc_func
+        )
         lut = relay.const(lut_values, dtype="uint8")
 
         # We baked the requantization into the LUT, so we don't requantize the identity operator
@@ -172,12 +178,21 @@ class TanhRewriter(DFPatternCallback):
             ifm_zero_point=input_zp,
             ofm_scale=input_scale,
             ofm_zero_point=input_zp,
-            activation="TANH",
+            activation=self.activation_type,
         )
 
         return identity
 
 
+class TanhRewriter(LutActivationRewriter):
+    """This pass adds tanh as a LUT to the identity operator"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.TanhParams, activation_type="TANH", calc_func=math.tanh
+        )
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeTanh:
     """This is the pass that wraps TanhRewriter"""
@@ -194,6 +209,48 @@ class LegalizeTanh:
         pass
 
 
+def sigmoid_calc_func(x: float) -> float:
+    """Function to calculate the values for sigmoid"""
+    # These limits are inherited from TFLite
+    upper_limit = 8.0
+    lower_limit = -8.0
+
+    if x <= lower_limit:
+        y = 0.0
+    elif x >= upper_limit:
+        y = 1.0
+    else:
+        y = 1 / (1 + math.exp(-x))
+    return y
+
+
+class SigmoidRewriter(LutActivationRewriter):
+    """This pass adds sigmoid as a LUT for identity op"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.SigmoidParams,
+            activation_type="SIGMOID",
+            calc_func=sigmoid_calc_func,
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeSigmoid:
+    """This is the pass that wraps SigmoidRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(SigmoidRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""
 
@@ -1196,6 +1253,7 @@ class LegalizeEthosU:
         mod = LegalizeTanh()(mod)
         mod = LegalizeMean()(mod)
         mod = LegalizeConcat()(mod)
+        mod = LegalizeSigmoid()(mod)
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
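
For a concrete feel for what the baked-in requantization does, here is one
hand-worked table entry under illustrative parameters (not taken from the
commit): with ifm_scale = 0.1, ifm_zp = 0, ofm_scale = 1/256 and
ofm_zp = -128, the int8 input x = 10 gives

    x_real = 0.1 * (10 - 0)             = 1.0
    sigmoid(1.0)                        ~ 0.7311
    lut    = round(-128 + 0.7311 * 256) = 59    (inside [-128, 127], no clamp)

and the saturated ends of the range land on the extremes: any x_real <= -8
maps to -128 exactly, while any x_real >= 8 maps to round(-128 + 256) = 128,
clamped to 127.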
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
index 242c6fe..6e50c6f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -140,11 +140,13 @@ def conv2d_compute(
         "dilation_w": dilation_w,
     }
 
+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
 
     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         conv2d_attrs["lut"] = lut
 
     conv = te.compute(
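
The lut_expr trick above (repeated in the depthwise, identity and pooling
computes below) works because referencing two elements of the LUT tensor
makes it a data dependency of the te.compute stage, so the tensor survives
into the TE graph even though the actual table lookup happens on the NPU.
A minimal standalone sketch of the idea, not the actual ethosu compute:

    from tvm import te

    ifm = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm")
    lut = te.placeholder((256,), dtype="uint8", name="lut")

    # The term cancels numerically, but pulls `lut` in as an input tensor
    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype)
    out = te.compute(ifm.shape, lambda *i: ifm(*i) + lut_expr - lut_expr)

    assert len(out.op.input_tensors) == 2  # both ifm and lut are inputs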
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
index c9a88e8..f54f2f3 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -139,11 +139,13 @@ def depthwise_conv2d_compute(
         "dilation_w": dilation_w,
     }
 
+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
 
     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         depthwise_conv2d_attrs["lut"] = lut
 
     depthwise = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
index 574fc66..271ca15 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
@@ -61,11 +61,13 @@ def identity_compute(
     dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
     id_attrs = {"op": "ethosu_identity", "activation": activation}
 
+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
 
     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         id_attrs["lut"] = lut
 
     identity = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index 2ab0844..e98a72d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -123,11 +123,13 @@ def pooling_compute(
         "upscale": upscale,
     }
 
+    has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
     # This is a trick to insert the LUT tensor into the TE graph if LUT is present
-    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+    lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
 
     # Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
-    if activation in ("TANH", "LUT"):
+    if has_lut:
         pooling_attrs["lut"] = lut
 
     pooling = te.compute(
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index bf9e3f8..a7d3da3 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
-class TanhParams:
+class LutActivationParams:
     """
-    This class will parse a call to a ethos-u.tanh composite function
-    and extract the parameter information.
+    A parent class for LUT-based activation functions that extracts the input and
+    output tensors and checks whether they are valid.
     """
 
-    composite_name = "ethos-u.tanh"
-
     def __init__(self, func_body: Call):
         self.ofm = TensorParams(func_body)
         self.ifm = TensorParams(func_body.args[0].args[0].args[0])
 
     def is_valid(self):
         """
-        This function checks whether reshape has compatible attributes with the NPU
+        This function checks whether the activation has attributes compatible with the NPU
         """
         if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
             return False
         return True
 
 
+class TanhParams(LutActivationParams):
+
+    composite_name = "ethos-u.tanh"
+
+
 def tanh_pattern():
     """Create pattern for tanh"""
     dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
@@ -947,6 +950,23 @@ def tanh_pattern():
     return quant
 
 
+class SigmoidParams(LutActivationParams):
+    """
+    This class will parse a call to an ethos-u.sigmoid composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.sigmoid"
+
+
+def sigmoid_pattern():
+    """Create pattern for sigmoid"""
+    dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    sigmoid = is_op("sigmoid")(dequant)
+    quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
+    return quant
+
+
 class MeanParams:
     """
     This class will parse a call to ethosu.mean composite function
@@ -1162,6 +1182,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             lambda pat: MeanParams(pat).is_valid(),
         ),
         (ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()),
+        (
+            SigmoidParams.composite_name,
+            sigmoid_pattern(),
+            lambda pat: SigmoidParams(pat).is_valid(),
+        ),
     ]
 
 
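The new sigmoid_pattern mirrors the quantize(sigmoid(dequantize(x))) shape
that the TFLite frontend emits for a quantized sigmoid. A quick standalone
check of the pattern against a hand-built Relay expression (the scale and
zero-point constants are illustrative):

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

    dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
    pattern = is_op("qnn.quantize")(
        is_op("sigmoid")(dequant), is_constant(), is_constant()
    )

    x = relay.var("x", shape=(1, 8), dtype="int8")
    deq = relay.qnn.op.dequantize(x, relay.const(0.1, "float32"), relay.const(0, "int32"))
    expr = relay.qnn.op.quantize(
        relay.sigmoid(deq), relay.const(1.0 / 256, "float32"), relay.const(-128, "int32")
    )
    assert pattern.match(expr)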
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 0e55487..21e86c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -815,66 +815,14 @@ def test_ethosu_clz(accel_type):
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 def test_tflite_tanh(accel_type):
-    dtype = "int8"
     ifm_shape = [1, 115, 32, 7]
 
-    def create_tflite_graph():
-        class Model(tf.Module):
-            @tf.function
-            def tanh_function(self, x):
-                op = tf.nn.tanh(x)
-                return op
-
-        model = Model()
-        concrete_func = model.tanh_function.get_concrete_function(
-            tf.TensorSpec(ifm_shape, dtype=tf.float32)
-        )
-
-        # Convert the model
-        def representative_dataset():
-            for _ in range(100):
-                data = np.random.rand(*tuple(ifm_shape))
-                yield [data.astype(np.float32)]
-
-        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = representative_dataset
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-        converter.inference_input_type = tf.int8
-        converter.inference_output_type = tf.int8
-        tflite_model = converter.convert()
-        return tflite_model
-
-    tflite_graph = create_tflite_graph()
-
-    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
-    relay_module, params = relay.frontend.from_tflite(
-        tflite_model,
-        shape_dict={"input": ifm_shape},
-        dtype_dict={"input": dtype},
-    )
-    mod = partition_for_ethosu(relay_module, params)
-
-    # Generate reference data
-    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+    @tf.function
+    def tanh_func(x):
+        op = tf.nn.tanh(x)
+        return op
 
-    compiled_models = infra.build_source(
-        mod,
-        input_data,
-        output_data,
-        accel_type,
-    )
-
-    # Assumes only two runtime.Modules are created -- i.e. single offload module
-    ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
-    # Verify generated C source
-    get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
-    compilation_artifacts = get_artifacts(ethosu_module)
-    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
-    infra.print_payload(cmms)
-    infra.verify_source(compiled_models, accel_type)
+    _compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)
 
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -896,5 +844,17 @@ def test_tflite_concat(shapes, axis, accel_type):
     _compare_tvm_with_tflite(concat_func, shapes, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+def test_tflite_sigmoid(accel_type):
+    ifm_shape = [1, 135, 41, 6]
+
+    @tf.function
+    def sigmoid_function(x):
+        op = tf.nn.sigmoid(x)
+        return op
+
+    _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 59bcf13..946aa95 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1297,5 +1297,58 @@ def test_tflite_concat_legalize(shapes, axis):
     ]
 
 
+def test_tflite_sigmoid_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 237, 91, 7)
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def sigmoid_func(self, x):
+                op = tf.math.sigmoid(x)
+                return op
+
+        model = Model()
+        concrete_func = model.sigmoid_func.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_output_type = tf.int8
+        converter.inference_input_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.SigmoidRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "SIGMOID"
+    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py
index 9485b4f..90a51c5 100644
--- a/tests/python/contrib/test_ethosu/test_lookup_table.py
+++ b/tests/python/contrib/test_ethosu/test_lookup_table.py
@@ -59,7 +59,7 @@ def test_tflite_lut_activations(accel_type):
                 op = tf.nn.depthwise_conv2d(
                     op, weight2, strides=(1, 1, 1, 1), padding="VALID", dilations=(2, 2)
                 )
-                op = tf.nn.tanh(op)
+                op = tf.nn.sigmoid(op)
                 op = tf.nn.max_pool(op, (1, 1), strides=(1, 1, 1, 1), padding="SAME")
                 op = tf.nn.tanh(op)
                 return op
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index 8b406d1..16835ce 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -39,7 +39,7 @@ def test_merge_lut_into_conv():
         conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
         id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH")
         conv2 = infra.make_ethosu_conv2d(id1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1))
-        id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="TANH")
+        id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="SIGMOID")
 
         func = relay.Function(relay.analysis.free_vars(id2), id2)
         mod = tvm.IRModule.from_expr(func)
@@ -50,7 +50,7 @@ def test_merge_lut_into_conv():
             ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH"
         )
         conv2 = infra.make_ethosu_conv2d(
-            conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="TANH"
+            conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="SIGMOID"
         )
 
         func = relay.Function(relay.analysis.free_vars(conv2), conv2)
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 2f2cd7a..7b09fb2 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -32,7 +32,7 @@ from .infra import make_ethosu_conv2d, get_convolutional_args
         [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
         [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
         [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TRUNCATE"],
-        [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "TFL"],
+        [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "TFL"],
         [
             (1, 8, 2, 8, 16),
             18,
diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
index afd632c..edbfb49 100644
--- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
@@ -33,7 +33,7 @@ from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args
         [(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
         [(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "TRUNCATE"],
         [(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
-        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "NATURAL"],
+        [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "NATURAL"],
         [
             (1, 8, 2, 8, 16),
             18,