You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/10/11 02:54:47 UTC

[tvm] branch unity updated: [Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 67d61935df [Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)
67d61935df is described below

commit 67d61935df054f8113b526464d404ebf4bfaa36f
Author: Thrsu <89...@users.noreply.github.com>
AuthorDate: Wed Oct 11 10:54:40 2023 +0800

    [Unity] [Bugfix] Fix MaxPool TypeError in ONNX frontend (#15908)
    
    * Fix MaxPool TypeError
    
    * Add regression test case.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +-
 tests/python/relax/test_frontend_onnx.py        | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fbd478ee5a..5333812c05 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1424,7 +1424,7 @@ class MaxPool(OnnxOpConverter):
         dilations = attr.get("dilations", [1, 1])
         kernel_shape = attr.get("kernel_shape")
         pads = attr.get("pads", 0)
-        strides = attr.get("strides", 1)
+        strides = attr.get("strides", [1, 1])
 
         assert len(kernel_shape) == 2, "Currently only 2D pooling is supported."
         assert auto_pad in [
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index d587d70636..a896ebb0b9 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1590,6 +1590,15 @@ def test_max_pool():
             strides=[2, 2],
         ),
     )
+    verify_unary(
+        "MaxPool",
+        [1, 1, 32, 32],
+        dict(
+            auto_pad="SAME_UPPER",
+            kernel_shape=[3, 3],
+            pads=None,
+        ),
+    )
 
 
 def test_global_average_pool():