You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this version.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/10/17 09:19:43 UTC

[tvm] branch main updated: [TVMC] Support dot inside of TVMC input shape name arguments (#9294)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b06ab3  [TVMC] Support dot inside of TVMC input shape name arguments (#9294)
2b06ab3 is described below

commit 2b06ab312fb27c7eb567dfd128a7ee9470e7809e
Author: lixiaoquan <ra...@163.com>
AuthorDate: Sun Oct 17 17:19:05 2021 +0800

    [TVMC] Support dot inside of TVMC input shape name arguments (#9294)
    
    * [TVMC] Support dot inside of TVMC input shape name arguments
    
    * dot -> dots
---
 python/tvm/driver/tvmc/common.py              | 5 +++--
 tests/python/driver/tvmc/test_shape_parser.py | 7 +++++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index f4bc3ec..1ee24cf 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -418,7 +418,7 @@ def parse_shape_string(inputs_string):
     ----------
     inputs_string: str
         A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that
-        indicates the desired shape for specific model inputs. Colons and forward slashes
+        indicates the desired shape for specific model inputs. Colons, forward slashes and dots
         within input_names are supported. Spaces are supported inside of dimension arrays.
 
     Returns
@@ -432,7 +432,8 @@ def parse_shape_string(inputs_string):
     # * Spaces inside arrays
     # * forward slashes inside names (but not at the beginning or end)
     # * colons inside names (but not at the beginning or end)
-    pattern = r"(?:\w+\/)?[:\w]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
+    # * dots inside names
+    pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
     input_mappings = re.findall(pattern, inputs_string)
     if not input_mappings:
         raise argparse.ArgumentTypeError(
diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py
index c021078..f49d89a 100644
--- a/tests/python/driver/tvmc/test_shape_parser.py
+++ b/tests/python/driver/tvmc/test_shape_parser.py
@@ -94,3 +94,10 @@ def test_invalid_colon():
 def test_invalid_slashes(shape_string):
     with pytest.raises(argparse.ArgumentTypeError):
         tvmc.common.parse_shape_string(shape_string)
+
+
+def test_dot():
+    # Check dot in input name
+    shape_string = "input.1:[10,10,10]"
+    shape_dict = tvmc.common.parse_shape_string(shape_string)
+    assert shape_dict == {"input.1": [10, 10, 10]}