You are viewing a plain-text version of this content. The canonical link to the original message is not included in this plain-text rendering.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2022/01/18 07:09:24 UTC

[incubator-mxnet] branch master updated: Fix the regular expression in RTC code (#20810)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 59b4c18  Fix the regular expression in RTC code (#20810)
59b4c18 is described below

commit 59b4c188f655ec4596ab9369aaa441672110e064
Author: Przemyslaw Tredak <pt...@nvidia.com>
AuthorDate: Mon Jan 17 23:06:38 2022 -0800

    Fix the regular expression in RTC code (#20810)
---
 python/mxnet/rtc.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/rtc.py b/python/mxnet/rtc.py
index ada4876..f03685a 100644
--- a/python/mxnet/rtc.py
+++ b/python/mxnet/rtc.py
@@ -141,21 +141,22 @@ class CudaModule(object):
         is_ndarray = []
         is_const = []
         dtypes = []
-        pattern = re.compile(r"""^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$""")
+        pattern = re.compile(r"""^(const)?\s?([\w_]+)\s?(\*)?\s?([\w_]+)?$""")
         args = re.sub(r"\s+", " ", signature).split(",")
         for arg in args:
-            match = pattern.match(arg)
+            sanitized_arg = " ".join(arg.split())
+            match = pattern.match(sanitized_arg)
             if not match or match.groups()[1] == 'const':
                 raise ValueError(
                     'Invalid function prototype "%s". Must be in the '
-                    'form of "(const) type (*) (name)"'%arg)
+                    'form of "(const) type (*) (name)"'%sanitized_arg)
             is_const.append(bool(match.groups()[0]))
             dtype = match.groups()[1]
             is_ndarray.append(bool(match.groups()[2]))
             if dtype not in _DTYPE_CPP_TO_NP:
                 raise TypeError(
                     "Unsupported kernel argument type %s. Supported types are: %s."%(
-                        arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
+                        sanitized_arg, ','.join(_DTYPE_CPP_TO_NP.keys())))
             dtypes.append(_DTYPE_NP_TO_MX[_DTYPE_CPP_TO_NP[dtype]])
 
         check_call(_LIB.MXRtcCudaKernelCreate(