You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/10/17 09:19:43 UTC
[tvm] branch main updated: [TVMC] Support dot inside of TVMC input shape name arguments (#9294)
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2b06ab3 [TVMC] Support dot inside of TVMC input shape name arguments (#9294)
2b06ab3 is described below
commit 2b06ab312fb27c7eb567dfd128a7ee9470e7809e
Author: lixiaoquan <ra...@163.com>
AuthorDate: Sun Oct 17 17:19:05 2021 +0800
[TVMC] Support dot inside of TVMC input shape name arguments (#9294)
* [TVMC] Support dot inside of TVMC input shape name arguments
* dot -> dots
---
python/tvm/driver/tvmc/common.py | 5 +++--
tests/python/driver/tvmc/test_shape_parser.py | 7 +++++++
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py
index f4bc3ec..1ee24cf 100644
--- a/python/tvm/driver/tvmc/common.py
+++ b/python/tvm/driver/tvmc/common.py
@@ -418,7 +418,7 @@ def parse_shape_string(inputs_string):
----------
inputs_string: str
A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that
- indicates the desired shape for specific model inputs. Colons and forward slashes
+ indicates the desired shape for specific model inputs. Colons, forward slashes and dots
within input_names are supported. Spaces are supported inside of dimension arrays.
Returns
@@ -432,7 +432,8 @@ def parse_shape_string(inputs_string):
# * Spaces inside arrays
# * forward slashes inside names (but not at the beginning or end)
# * colons inside names (but not at the beginning or end)
- pattern = r"(?:\w+\/)?[:\w]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
+ # * dots inside names
+ pattern = r"(?:\w+\/)?[:\w.]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]"
input_mappings = re.findall(pattern, inputs_string)
if not input_mappings:
raise argparse.ArgumentTypeError(
diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py
index c021078..f49d89a 100644
--- a/tests/python/driver/tvmc/test_shape_parser.py
+++ b/tests/python/driver/tvmc/test_shape_parser.py
@@ -94,3 +94,10 @@ def test_invalid_colon():
def test_invalid_slashes(shape_string):
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
+
+
+def test_dot():
+ # Check dot in input name
+ shape_string = "input.1:[10,10,10]"
+ shape_dict = tvmc.common.parse_shape_string(shape_string)
+ assert shape_dict == {"input.1": [10, 10, 10]}