You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/24 01:45:33 UTC

[GitHub] [tvm] vinx13 commented on a change in pull request #9492: [TVMScript] Add syntax sugar for T.handle and T.match_buffer

vinx13 commented on a change in pull request #9492:
URL: https://github.com/apache/tvm/pull/9492#discussion_r755632879



##########
File path: python/tvm/script/parser.py
##########
@@ -1047,6 +1056,91 @@ def transform_TypeConstant(self, node):
         """
         return node.value
 
+    def transform_TypeTuple(self, node):
+        return node.values
+
+    def transform_TypeCall(self, node):
+        """Call value visitor for TypeCall.
+
+        This method is for syntax sugar of T.match_buffer()
+        """
+
+        def parse_typecall_params(func, params, keyword_params):
+            args = []
+            for arg in params:
+                if isinstance(arg, ast.TypeTuple):
+                    values = []
+                    for value in self.transform(arg):
+                        values.append(self.transform(value))
+                else:
+                    values = self.transform(arg)
+                args.append(values)
+            kw_args = {}
+            for k, v in keyword_params.items():
+                if isinstance(v, ast.TypeTuple):
+                    values = []
+                    for value in self.transform(v):
+                        values.append(self.transform(value))
+                else:
+                    values = self.transform(v)
+                kw_args[self.transform(k)] = values
+            # get the name and parameter list of func
+            func_name, param_list = func.signature()
+            # check arguments and parameter list and get a list of arguments
+            reader = CallArgumentReader(func_name, args, kw_args, self, node)
+            pos_only, kwargs, varargs = param_list
+            internal_args = list()
+            for i, arg_name in enumerate(pos_only):
+                internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
+            for i, arg_info in enumerate(kwargs):
+                arg_name, default = arg_info
+                internal_args.append(
+                    reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)
+                )
+            return internal_args
+
+        func = self.transform(node.func_name)
+
+        if isinstance(func, SpecialStmt):
+            # parse args and kwargs for TypeCall
+            arg_list = parse_typecall_params(func, node.params, node.keyword_params)
+            buf = func.handle(node, self.context, arg_list, node.func_name.span)
+            return buf
+        self.report_error(
+            "Syntax sugar for T.match_buffer needs to be evaluated into a SpecialStmt.",
+            node.span,
+        )
+
+    def transform_TypeApply(self, node):
+        """Call value visitor for TypeApply.
+
+        This method is for syntax sugar of T.match_buffer()
+        """
+
+        def parse_typeapply_params(params):
+            args = []
+            for arg in params:
+                if isinstance(arg, ast.TypeTuple):
+                    values = []
+                    for value in self.transform(arg):
+                        values.append(self.transform(value))
+                else:
+                    values = self.transform(arg)
+                args.append(values)
+            return args
+
+        func = self.transform(node.func_name)
+
+        if isinstance(func, SpecialStmt):
+            # parse args for TypeApply
+            arg_list = parse_typeapply_params(node.params)
+            buf = func.handle(node, self.context, arg_list, node.func_name.span)
+            return buf
+        self.report_error(
+            "Syntax sugar for T.match_buffer needs to be evaluated into a SpecialStmt.",

Review comment:
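       Ditto: as with `transform_TypeCall`, this error message should not mention `T.match_buffer`, because the visitor is generic.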

##########
File path: python/tvm/script/parser.py
##########
@@ -1047,6 +1056,91 @@ def transform_TypeConstant(self, node):
         """
         return node.value
 
+    def transform_TypeTuple(self, node):
+        return node.values
+
+    def transform_TypeCall(self, node):
+        """Call value visitor for TypeCall.
+
+        This method is for syntax sugar of T.match_buffer()
+        """
+
+        def parse_typecall_params(func, params, keyword_params):
+            args = []
+            for arg in params:
+                if isinstance(arg, ast.TypeTuple):
+                    values = []
+                    for value in self.transform(arg):
+                        values.append(self.transform(value))
+                else:
+                    values = self.transform(arg)
+                args.append(values)
+            kw_args = {}
+            for k, v in keyword_params.items():
+                if isinstance(v, ast.TypeTuple):
+                    values = []
+                    for value in self.transform(v):
+                        values.append(self.transform(value))
+                else:
+                    values = self.transform(v)
+                kw_args[self.transform(k)] = values
+            # get the name and parameter list of func
+            func_name, param_list = func.signature()
+            # check arguments and parameter list and get a list of arguments
+            reader = CallArgumentReader(func_name, args, kw_args, self, node)
+            pos_only, kwargs, varargs = param_list
+            internal_args = list()
+            for i, arg_name in enumerate(pos_only):
+                internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
+            for i, arg_info in enumerate(kwargs):
+                arg_name, default = arg_info
+                internal_args.append(
+                    reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)
+                )
+            return internal_args
+
+        func = self.transform(node.func_name)
+
+        if isinstance(func, SpecialStmt):
+            # parse args and kwargs for TypeCall
+            arg_list = parse_typecall_params(func, node.params, node.keyword_params)
+            buf = func.handle(node, self.context, arg_list, node.func_name.span)
+            return buf
+        self.report_error(
+            "Syntax sugar for T.match_buffer needs to be evaluated into a SpecialStmt.",

Review comment:
       Although `T.match_buffer` is currently the only usage, this visitor is generic, so its error message shouldn't mention `T.match_buffer`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org