You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/04/07 07:12:33 UTC

[tvm] branch main updated: [CI] Updated argument parsing of optional arguments in ci.py (#10906)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new af8569c913 [CI] Updated argument parsing of optional arguments in ci.py (#10906)
af8569c913 is described below

commit af8569c9137b379c14592a5a32035a815b38dbf9
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Apr 7 02:12:27 2022 -0500

    [CI] Updated argument parsing of optional arguments in ci.py (#10906)
    
    * [CI] Updated argument parsing of optional arguments in ci.py
    
    Previously, optional arguments were identified by checking whether the
    string form of the annotation started with `"typing.Optional"`.  This
    misses some cases, because `Optional[int]` is represented as
    `Union[int, NoneType]`.  This commit updates the check to identify
    `typing.Union` annotations where one of the types is `NoneType`.
    
    * Bugfix: correctly handle type annotations that are not `typing.*` constructs (e.g. plain `int`)
---
 tests/scripts/ci.py | 35 +++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)

diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py
old mode 100644
new mode 100755
index 25c67ec6f1..5f2034b190
--- a/tests/scripts/ci.py
+++ b/tests/scripts/ci.py
@@ -32,6 +32,8 @@ import random
 import subprocess
 import platform
 import textwrap
+import typing
+
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Tuple, Callable, Union
 
@@ -434,6 +436,28 @@ def cli_name(s: str) -> str:
     return s.replace("_", "-")
 
 
+def typing_get_origin(annotation):  # Return the typing origin of annotation, e.g. typing.Union for Optional[int].
+    if sys.version_info >= (3, 8):  # typing.get_origin() only exists on Python 3.8+.
+        return typing.get_origin(annotation)
+    else:
+        return annotation.__origin__  # Pre-3.8 fallback; unlike typing.get_origin(), raises AttributeError if __origin__ is absent.
+
+
+def typing_get_args(annotation):  # Return the type arguments of annotation, e.g. (int, NoneType) for Optional[int].
+    if sys.version_info >= (3, 8):  # typing.get_args() only exists on Python 3.8+.
+        return typing.get_args(annotation)
+    else:
+        return annotation.__args__  # Pre-3.8 fallback; unlike typing.get_args(), raises AttributeError if __args__ is absent.
+
+
+def is_optional_type(annotation):  # True if annotation is a typing.Union that includes None, e.g. Optional[int].
+    return (
+        hasattr(annotation, "__origin__")  # Guard: plain types such as int have no __origin__, which would break the pre-3.8 shim.
+        and (typing_get_origin(annotation) == typing.Union)
+        and (type(None) in typing_get_args(annotation))  # Optional[X] is represented as Union[X, None].
+    )  # NOTE(review): PEP 604 "X | None" unions may lack __origin__ and go undetected; confirm if such annotations are used.
+
+
 def add_subparser(
     func: Callable,
     subparsers: Any,
@@ -479,12 +503,11 @@ def add_subparser(
         arg_cli_name = cli_name(name)
         kwargs: Dict[str, Union[str, bool]] = {"help": arg_help_texts[arg_cli_name]}
 
-        arg_type = value.annotation
-        is_optional = False
-        if str(value.annotation).startswith("typing.Optional"):
-
-            is_optional = True
-            arg_type = value.annotation.__args__[0]
+        is_optional = is_optional_type(value.annotation)
+        if is_optional:
+            arg_type = typing_get_args(value.annotation)[0]
+        else:
+            arg_type = value.annotation
 
         # Grab the default value if present
         has_default = False