You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/10/11 02:54:47 UTC
[tvm] branch unity updated: [Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 67d61935df [Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)
67d61935df is described below
commit 67d61935df054f8113b526464d404ebf4bfaa36f
Author: Thrsu <89...@users.noreply.github.com>
AuthorDate: Wed Oct 11 10:54:40 2023 +0800
[Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)
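* Fix a MaxPool TypeError raised when the `strides` attribute is omitted: default `strides` to `[1, 1]` instead of the scalar `1`, matching the list form of `dilations`.
* Add a regression test case (MaxPool with `auto_pad="SAME_UPPER"`, a 3x3 kernel, and no explicit `strides`).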
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +-
tests/python/relax/test_frontend_onnx.py | 9 +++++++++
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fbd478ee5a..5333812c05 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1424,7 +1424,7 @@ class MaxPool(OnnxOpConverter):
dilations = attr.get("dilations", [1, 1])
kernel_shape = attr.get("kernel_shape")
pads = attr.get("pads", 0)
- strides = attr.get("strides", 1)
+ strides = attr.get("strides", [1, 1])
assert len(kernel_shape) == 2, "Currently only 2D pooling is supported."
assert auto_pad in [
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index d587d70636..a896ebb0b9 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1590,6 +1590,15 @@ def test_max_pool():
strides=[2, 2],
),
)
+ verify_unary(
+ "MaxPool",
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[3, 3],
+ pads=None,
+ ),
+ )
def test_global_average_pool():