You are viewing a plain text version of this content. The canonical link to it (a hyperlink in the original page) is not included in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/08 11:25:50 UTC
[tvm] branch main updated: [microNPU] Add support for SIGMOID (#9627)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3371a76 [microNPU] Add support for SIGMOID (#9627)
3371a76 is described below
commit 3371a76400a2ae55b4bb59cf023e013b0b6d919f
Author: Elen Kalda <el...@arm.com>
AuthorDate: Wed Dec 8 11:25:26 2021 +0000
[microNPU] Add support for SIGMOID (#9627)
Add support for SIGMOID activation function using the lookup
table mechanism in the NPU.
---
python/tvm/relay/backend/contrib/ethosu/codegen.py | 2 +-
.../tvm/relay/backend/contrib/ethosu/legalize.py | 84 ++++++++++++++++++----
.../relay/backend/contrib/ethosu/te/convolution.py | 6 +-
.../relay/backend/contrib/ethosu/te/depthwise.py | 6 +-
.../relay/backend/contrib/ethosu/te/identity.py | 6 +-
.../tvm/relay/backend/contrib/ethosu/te/pooling.py | 6 +-
python/tvm/relay/op/contrib/ethosu.py | 37 ++++++++--
tests/python/contrib/test_ethosu/test_codegen.py | 74 +++++--------------
tests/python/contrib/test_ethosu/test_legalize.py | 53 ++++++++++++++
.../contrib/test_ethosu/test_lookup_table.py | 2 +-
.../contrib/test_ethosu/test_lut_optimizer.py | 4 +-
.../contrib/test_ethosu/test_replace_conv2d.py | 2 +-
.../test_ethosu/test_replace_depthwise_conv2d.py | 2 +-
13 files changed, 194 insertions(+), 90 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 002cb4b6..22f248b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -89,7 +89,7 @@ class OptimizeLUTs(ExprMutator):
not refer to an Op. Else, a new call node with a new operator.
"""
new_call = call
- lut_activations = ["TANH", "LUT"]
+ lut_activations = ["TANH", "LUT", "SIGMOID"]
if isinstance(call.op, tvm.ir.Op) and isinstance(call.args[0], tvm.relay.expr.Call):
producer_op = call.args[0]
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index f8beb7f..b2264f3 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
"""A set of passes to legalize some of operations for the NPU"""
-from typing import List, Type
+from typing import List, Type, Callable
import math
import numpy as np # type: ignore
@@ -125,15 +125,17 @@ class LegalizeSplit:
pass
-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
- """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(
+ ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]
+) -> List[int]:
+ """Method to calculate the values of the lookup table based on the calculation function"""
lut_values = list()
# Only int8 is currently supported
dtype = np.int8
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
- out_real = math.tanh(x_real)
+ out_real = func(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)
@@ -141,16 +143,18 @@ def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
return lut_values
-class TanhRewriter(DFPatternCallback):
- """This pass adds tanh as a LUT to the identity operator"""
+class LutActivationRewriter(DFPatternCallback):
+ """A class to create an identity operator with the LUT"""
- def __init__(self):
+ def __init__(
+ self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]
+ ):
super().__init__(require_type=True, rewrite_once=True)
- self.pattern = (
- wildcard().has_attr({"Composite": ethosu_patterns.TanhParams.composite_name})
- )(wildcard())
+ self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
+ self.activation_type = activation_type
+ self.calc_func = calc_func
- def callback(self, pre, post, node_map):
+ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
id_input = post.args[0]
quantize_args = post.op.body.args
@@ -161,7 +165,9 @@ class TanhRewriter(DFPatternCallback):
input_scale = float(dequantize_args[1].data.asnumpy())
input_zp = int(dequantize_args[2].data.asnumpy())
- lut_values = find_tanh_values(input_scale, input_zp, output_scale, output_zp)
+ lut_values = get_lut_from_func(
+ input_scale, input_zp, output_scale, output_zp, self.calc_func
+ )
lut = relay.const(lut_values, dtype="uint8")
# We baked the requantization into the LUT, so we don't requantize the identity operator
@@ -172,12 +178,21 @@ class TanhRewriter(DFPatternCallback):
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
- activation="TANH",
+ activation=self.activation_type,
)
return identity
+class TanhRewriter(LutActivationRewriter):
+ """This pass adds tanh as a LUT to the identity operator"""
+
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.TanhParams, activation_type="TANH", calc_func=math.tanh
+ )
+
+
@ir.transform.module_pass(opt_level=1)
class LegalizeTanh:
"""This is the pass that wraps TanhRewriter"""
@@ -194,6 +209,48 @@ class LegalizeTanh:
pass
+def sigmoid_calc_func(x: float) -> float:
+ """Function to calculate the values for sigmoid"""
+ # Thse limits are inherited from TFLite
+ upper_limit = 8.0
+ lower_limit = -8.0
+
+ if x <= lower_limit:
+ y = 0.0
+ elif x >= upper_limit:
+ y = 1.0
+ else:
+ y = 1 / (1 + math.exp(-x))
+ return y
+
+
+class SigmoidRewriter(LutActivationRewriter):
+ """This pass adds sigmoid as a LUT for identity op"""
+
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.SigmoidParams,
+ activation_type="SIGMOID",
+ calc_func=sigmoid_calc_func,
+ )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeSigmoid:
+ """This is the pass that wraps SigmoidRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(SigmoidRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d operators"""
@@ -1196,6 +1253,7 @@ class LegalizeEthosU:
mod = LegalizeTanh()(mod)
mod = LegalizeMean()(mod)
mod = LegalizeConcat()(mod)
+ mod = LegalizeSigmoid()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
index 242c6fe..6e50c6f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py
@@ -140,11 +140,13 @@ def conv2d_compute(
"dilation_w": dilation_w,
}
+ has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
- lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+ lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
- if activation in ("TANH", "LUT"):
+ if has_lut:
conv2d_attrs["lut"] = lut
conv = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
index c9a88e8..f54f2f3 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py
@@ -139,11 +139,13 @@ def depthwise_conv2d_compute(
"dilation_w": dilation_w,
}
+ has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
- lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+ lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
- if activation in ("TANH", "LUT"):
+ if has_lut:
depthwise_conv2d_attrs["lut"] = lut
depthwise = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/identity.py b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
index 574fc66..271ca15 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/identity.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/identity.py
@@ -61,11 +61,13 @@ def identity_compute(
dmaed_ifm = read_compute(ifm, ifm_zero_point, ifm_scale)
id_attrs = {"op": "ethosu_identity", "activation": activation}
+ has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
- lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+ lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
- if activation in ("TANH", "LUT"):
+ if has_lut:
id_attrs["lut"] = lut
identity = te.compute(
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index 2ab0844..e98a72d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -123,11 +123,13 @@ def pooling_compute(
"upscale": upscale,
}
+ has_lut = activation in ("TANH", "LUT", "SIGMOID")
+
# This is a trick to insert the LUT tensor into the TE graph if LUT is present
- lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if activation in ("TANH", "LUT") else 0
+ lut_expr = (lut[0] + lut[255]).astype(ifm.dtype) if has_lut else 0
# Add the LUT tensor to the attributes to be able to later tell which tensor is the LUT
- if activation in ("TANH", "LUT"):
+ if has_lut:
pooling_attrs["lut"] = lut
pooling = te.compute(
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index bf9e3f8..a7d3da3 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern
-class TanhParams:
+class LutActivationParams:
"""
- This class will parse a call to a ethos-u.tanh composite function
- and extract the parameter information.
+ A parent class for LUT based activation functions that extract the input and
+ output tensors and check whether they are valid.
"""
- composite_name = "ethos-u.tanh"
-
def __init__(self, func_body: Call):
self.ofm = TensorParams(func_body)
self.ifm = TensorParams(func_body.args[0].args[0].args[0])
def is_valid(self):
"""
- This function checks whether reshape has compatible attributes with the NPU
+ This function checks whether activation has compatible attributes with the NPU
"""
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
return True
+class TanhParams(LutActivationParams):
+
+ composite_name = "ethos-u.tanh"
+
+
def tanh_pattern():
"""Create pattern for tanh"""
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
@@ -947,6 +950,23 @@ def tanh_pattern():
return quant
+class SigmoidParams(LutActivationParams):
+ """
+ This class will parse a call to a ethos-u.sigmoid composite function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethos-u.sigmoid"
+
+
+def sigmoid_pattern():
+ """Create pattern for sigmoid"""
+ dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+ sigmoid = is_op("sigmoid")(dequant)
+ quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
+ return quant
+
+
class MeanParams:
"""
This class will parse a call to ethosu.mean composite function
@@ -1162,6 +1182,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
lambda pat: MeanParams(pat).is_valid(),
),
(ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()),
+ (
+ SigmoidParams.composite_name,
+ sigmoid_pattern(),
+ lambda pat: SigmoidParams(pat).is_valid(),
+ ),
]
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 0e55487..21e86c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -815,66 +815,14 @@ def test_ethosu_clz(accel_type):
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
def test_tflite_tanh(accel_type):
- dtype = "int8"
ifm_shape = [1, 115, 32, 7]
- def create_tflite_graph():
- class Model(tf.Module):
- @tf.function
- def tanh_function(self, x):
- op = tf.nn.tanh(x)
- return op
-
- model = Model()
- concrete_func = model.tanh_function.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
- )
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32)]
-
- converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph()
-
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"input": ifm_shape},
- dtype_dict={"input": dtype},
- )
- mod = partition_for_ethosu(relay_module, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+ @tf.function
+ def tanh_func(x):
+ op = tf.nn.tanh(x)
+ return op
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload module
- ethosu_module = compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts = tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ _compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -896,5 +844,17 @@ def test_tflite_concat(shapes, axis, accel_type):
_compare_tvm_with_tflite(concat_func, shapes, accel_type)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+def test_tflite_sigmoid(accel_type):
+ ifm_shape = [1, 135, 41, 6]
+
+ @tf.function
+ def sigmoid_function(x):
+ op = tf.nn.sigmoid(x)
+ return op
+
+ _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 59bcf13..946aa95 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1297,5 +1297,58 @@ def test_tflite_concat_legalize(shapes, axis):
]
+def test_tflite_sigmoid_legalize():
+ dtype = "int8"
+ ifm_shape = (1, 237, 91, 7)
+
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def sigmoid_func(self, x):
+ op = tf.math.sigmoid(x)
+ return op
+
+ model = Model()
+ concrete_func = model.sigmoid_func.get_concrete_function(
+ tf.TensorSpec(ifm_shape, dtype=tf.float32)
+ )
+
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple(ifm_shape))
+ yield [data.astype(np.float32)]
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_output_type = tf.int8
+ converter.inference_input_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ mod, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"input": ifm_shape},
+ dtype_dict={"input": dtype},
+ )
+
+ mod = ethosu.partition_for_ethosu(mod, params)
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ legalize.SigmoidRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+ )
+ mod = relay.transform.InferType()(mod)
+
+ func_body = mod["tvmgen_default_ethos_u_main_0"].body
+ assert func_body.op.name == "contrib.ethosu.identity"
+ assert func_body.attrs.activation == "SIGMOID"
+ assert tuple(func_body.args[0].checked_type.shape) == (ifm_shape)
+ assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_lookup_table.py b/tests/python/contrib/test_ethosu/test_lookup_table.py
index 9485b4f..90a51c5 100644
--- a/tests/python/contrib/test_ethosu/test_lookup_table.py
+++ b/tests/python/contrib/test_ethosu/test_lookup_table.py
@@ -59,7 +59,7 @@ def test_tflite_lut_activations(accel_type):
op = tf.nn.depthwise_conv2d(
op, weight2, strides=(1, 1, 1, 1), padding="VALID", dilations=(2, 2)
)
- op = tf.nn.tanh(op)
+ op = tf.nn.sigmoid(op)
op = tf.nn.max_pool(op, (1, 1), strides=(1, 1, 1, 1), padding="SAME")
op = tf.nn.tanh(op)
return op
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index 8b406d1..16835ce 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -39,7 +39,7 @@ def test_merge_lut_into_conv():
conv1 = infra.make_ethosu_conv2d(ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1))
id1 = infra.make_ethosu_identity(conv1, lut=lut1, activation="TANH")
conv2 = infra.make_ethosu_conv2d(id1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1))
- id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="TANH")
+ id2 = infra.make_ethosu_identity(conv2, lut=lut2, activation="SIGMOID")
func = relay.Function(relay.analysis.free_vars(id2), id2)
mod = tvm.IRModule.from_expr(func)
@@ -50,7 +50,7 @@ def test_merge_lut_into_conv():
ifm, 4, 4, (3, 3), (1, 1), (1, 1), (1, 1), lut=lut1, activation="TANH"
)
conv2 = infra.make_ethosu_conv2d(
- conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="TANH"
+ conv1, 4, 7, (2, 2), (1, 1), (1, 1), (1, 1), lut=lut2, activation="SIGMOID"
)
func = relay.Function(relay.analysis.free_vars(conv2), conv2)
diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
index 2f2cd7a..7b09fb2 100644
--- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py
@@ -32,7 +32,7 @@ from .infra import make_ethosu_conv2d, get_convolutional_args
[(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
[(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
[(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TRUNCATE"],
- [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "TFL"],
+ [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "TFL"],
[
(1, 8, 2, 8, 16),
18,
diff --git a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
index afd632c..edbfb49 100644
--- a/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
+++ b/tests/python/contrib/test_ethosu/test_replace_depthwise_conv2d.py
@@ -33,7 +33,7 @@ from .infra import make_ethosu_depthwise_conv2d, get_convolutional_args
[(1, 8, 8, 3), 3, (1, 1), (2, 1), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "NATURAL"],
[(1, 8, 8, 3), 3, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC", "TRUNCATE"],
[(1, 1, 1, 1), 1, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC", "TFL"],
- [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC", "NATURAL"],
+ [(1, 7, 9, 4), 4, (3, 2), (1, 2), (2, 1), (1, 2), "NONE", "NHWC", "NHWC", "NATURAL"],
[
(1, 8, 2, 8, 16),
18,