Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/01 08:58:55 UTC

[GitHub] [tvm] shingjan opened a new pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan opened a new pull request #9620:
URL: https://github.com/apache/tvm/pull/9620


   This PR adds syntax sugar to the `for` loop scope handlers in TIR.
   Before this PR:
   ```
   for i in T.serial(0, 128):
   for i in T.parallel(0, 128):
   for i in T.vectorized(0, 128):
   for i in T.unroll(0, 128):
   for i in T.thread_binding(0, 128, thread="threadIdx.x"):
   ```
   After this PR, the starting 0 can be omitted:
   ```
   for i in T.serial(128):
   for i in T.parallel(128):
   for i in T.vectorized(128):
   for i in T.unroll(128):
   for i in T.thread_binding(128, thread="threadIdx.x"):
   ```
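The one-argument form follows the same convention as Python's built-in `range`: a single argument is the exclusive end, and the start defaults to 0. The pure-Python model below is a hypothetical illustration of that convention (`serial_model` is an invented name, not part of TVM), not the PR's implementation.

```python
def serial_model(begin, end=None):
    """Hypothetical model of the loop-range sugar (not TVM code).

    serial_model(128)    iterates 0..127, like range(128)
    serial_model(0, 128) iterates 0..127, like range(0, 128)
    """
    if end is None:
        # Only one bound was given: treat it as the end and start at 0.
        begin, end = 0, begin
    return range(begin, end)


print(list(serial_model(4)))     # [0, 1, 2, 3]
print(list(serial_model(0, 4)))  # [0, 1, 2, 3]
print(list(serial_model(2, 4)))  # [2, 3]
```

Both spellings describe the same iteration space, which is what the sugar is meant to guarantee.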


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r761326040



##########
File path: python/tvm/script/tir/scope_handler.py
##########
@@ -560,10 +572,19 @@ class ThreadBinding(ForScopeHandler):
     def __init__(self):
         def thread_binding(
             begin: PrimExpr,
-            end: PrimExpr,
-            thread: str,
+            end: PrimExpr = None,
+            thread: str = None,
             annotations: Optional[Mapping[str, Object]] = None,
         ):
+            if not thread:

Review comment:
       You need to explicitly check `thread is not None` (same for `end`); see the CI error.
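The distinction matters because a falsy value is not the same as a missing one: an explicit bound of `0` is falsy, and a symbolic expression may refuse to be converted to `bool` at all (non-constant TVM expressions are expected to raise in a boolean context). The sketch below uses a hypothetical `SymbolicExtent` class as a stand-in for such an expression; it and both resolver functions are invented for illustration and are not TVM's `PrimExpr` or the PR's code.

```python
class SymbolicExtent:
    """Hypothetical stand-in for a symbolic loop bound with no truth value."""

    def __init__(self, name):
        self.name = name

    def __bool__(self):
        raise ValueError(f"truth value of symbolic expression {self.name!r} is undefined")


def resolve_end_truthy(begin, end=None):
    # Buggy: `not end` calls bool(end), which misreads 0 and breaks on symbolic bounds.
    if not end:
        return 0, begin
    return begin, end


def resolve_end_explicit(begin, end=None):
    # Correct: only a truly omitted argument (None) triggers the sugar.
    if end is None:
        return 0, begin
    return begin, end


print(resolve_end_explicit(128))   # (0, 128)
print(resolve_end_explicit(5, 0))  # (5, 0): the explicit 0 is kept
print(resolve_end_truthy(5, 0))    # (0, 5): the explicit 0 is mistaken for "omitted"
# resolve_end_truthy(0, SymbolicExtent("n")) raises ValueError from bool(n)
```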







[GitHub] [tvm] vinx13 merged pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 merged pull request #9620:
URL: https://github.com/apache/tvm/pull/9620


   





[GitHub] [tvm] junrushao1994 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

junrushao1994 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r763306526



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -319,28 +319,28 @@ Scope handler - Loops
 
 def serial(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,

Review comment:
       Could we use `@overload` to make this clearer?
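For reference, `typing.overload` declares several signatures for type checkers while a single undecorated implementation handles every call at runtime. In a `.pyi` stub such as `__init__.pyi`, only the `@overload` signatures are written; in a regular module they must be followed by one implementation. A minimal sketch for a hypothetical `serial_sketch` function, assuming plain `int` bounds instead of `PrimExpr`, omitting `annotations`, and using positional-only parameters (`/`, Python 3.8+) so each overload is compatible with the implementation:

```python
from typing import Iterable, Optional, overload


@overload
def serial_sketch(end: int, /) -> Iterable[int]: ...
@overload
def serial_sketch(begin: int, end: int, /) -> Iterable[int]: ...


def serial_sketch(begin: int, end: Optional[int] = None, /) -> Iterable[int]:
    # Only this definition exists at runtime; the @overload stubs above are
    # read by type checkers and replaced when this name is rebound.
    if end is None:
        begin, end = 0, begin
    return range(begin, end)


print(list(serial_sketch(3)))     # [0, 1, 2]
print(list(serial_sketch(2, 5)))  # [2, 3, 4]
```

With the overloads, a type checker can report a call that matches neither form, which a single signature with `end = None` cannot express as precisely.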







[GitHub] [tvm] shingjan commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r760813705



##########
File path: python/tvm/tir/utils.py
##########
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
+"""Common utility functions in TVM tir"""
+import inspect
+import tvm
+from tvm.ir.diagnostics import override_renderer
+
+
+def check_error(func, rel_lineno):

Review comment:
       done







[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r764270200



##########
File path: tests/python/unittest/test_tvmscript_syntax_sugar.py
##########
@@ -62,5 +63,43 @@ def test_reads_writes_syntax_sugar():
     assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar)
 
 
+@T.prim_func
+def loop_no_syntax_sugar(a: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128, 128, 128))
+    for i in T.serial(0, 128):
+        for j in T.parallel(0, 128):
+            for k in T.vectorized(0, 128):
+                for x in T.unroll(0, 128):
+                    for y in T.thread_binding(0, 128, thread="threadIdx.x"):
+                        for z in T.thread_binding(0, 128, thread="threadIdx.x"):
+                            A[i, j, k, x] = A[i, j, k, x] * 2.0
+
+
+@T.prim_func
+def loop_syntax_sugar(a: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128, 128, 128))
+    for i in T.serial(128):
+        for j in T.parallel(128):
+            for k in T.vectorized(128):
+                for x in T.unroll(128):
+                    for y in T.thread_binding(128, "threadIdx.x"):
+                        for z in T.thread_binding(128, thread="threadIdx.x"):
+                            A[i, j, k, x] = A[i, j, k, x] * 2.0
+
+
+def loop_syntax_sugar_fail(a: T.handle) -> None:

Review comment:
       The decorator is still missing for this one; it is important to make sure the test case passes locally.







[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r763325151



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -316,36 +316,72 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr:
 """
 Scope handler - Loops
 """
-
+@overload
 def serial(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
+def serial(
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
 def parallel(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
+def parallel(
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
 def vectorized(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
+def vectorized(
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
 def unroll(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
+def unroll(
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
 def thread_binding(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    thread: str = None,

Review comment:
       Given that we have overloads, the default value `thread = None` is not needed.
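With overloads, `thread` is required in every signature, so no default is needed. The runtime handler still has to tell the call forms apart, including the positional spelling `T.thread_binding(128, "threadIdx.x")` used in `test_tvmscript_syntax_sugar.py`. A minimal sketch, assuming a hypothetical helper `resolve_thread_binding` (not the code in this PR) and assuming that a string in the `end` position means the extent-only form with a positional thread tag:

```python
from typing import Optional, Tuple, Union


def resolve_thread_binding(
    begin: int,
    end: Optional[Union[int, str]] = None,
    thread: Optional[str] = None,
) -> Tuple[int, int, str]:
    """Hypothetical helper: normalize the accepted call forms.

    thread_binding(128, thread="threadIdx.x")
    thread_binding(128, "threadIdx.x")
    thread_binding(0, 128, thread="threadIdx.x")
    """
    if isinstance(end, str) and thread is None:
        # Two positional args where the second is a string: it is the thread tag.
        end, thread = None, end
    if thread is None:
        raise TypeError("thread_binding() requires a thread tag, e.g. 'threadIdx.x'")
    if end is None:
        # Single bound given: it is the end, and the loop starts at 0.
        begin, end = 0, begin
    return begin, end, thread


print(resolve_thread_binding(128, "threadIdx.x"))  # (0, 128, 'threadIdx.x')
```

Note that every branch uses `is None` rather than truthiness, in line with the earlier CI fix.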







[GitHub] [tvm] junrushao1994 commented on pull request #9620: [TVMScript] Add for loop syntax sugar

junrushao1994 commented on pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#issuecomment-987171688


   Let's improve the printer in a subsequent PR





[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r761326040



##########
File path: python/tvm/script/tir/scope_handler.py
##########
@@ -560,10 +572,19 @@ class ThreadBinding(ForScopeHandler):
     def __init__(self):
         def thread_binding(
             begin: PrimExpr,
-            end: PrimExpr,
-            thread: str,
+            end: PrimExpr = None,
+            thread: str = None,
             annotations: Optional[Mapping[str, Object]] = None,
         ):
+            if not thread:

Review comment:
       You need to explicitly check `thread is None` (same for `end`); see the CI error.







[GitHub] [tvm] vinx13 commented on pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#issuecomment-987459102


   @shingjan the other PR has been merged; please rebase this one.





[GitHub] [tvm] shingjan commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r764315817



##########
File path: tests/python/unittest/test_tvmscript_syntax_sugar.py
##########
@@ -62,5 +63,43 @@ def test_reads_writes_syntax_sugar():
     assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar)
 
 
+@T.prim_func
+def loop_no_syntax_sugar(a: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128, 128, 128))
+    for i in T.serial(0, 128):
+        for j in T.parallel(0, 128):
+            for k in T.vectorized(0, 128):
+                for x in T.unroll(0, 128):
+                    for y in T.thread_binding(0, 128, thread="threadIdx.x"):
+                        for z in T.thread_binding(0, 128, thread="threadIdx.x"):
+                            A[i, j, k, x] = A[i, j, k, x] * 2.0
+
+
+@T.prim_func
+def loop_syntax_sugar(a: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128, 128, 128))
+    for i in T.serial(128):
+        for j in T.parallel(128):
+            for k in T.vectorized(128):
+                for x in T.unroll(128):
+                    for y in T.thread_binding(128, "threadIdx.x"):
+                        for z in T.thread_binding(128, thread="threadIdx.x"):
+                            A[i, j, k, x] = A[i, j, k, x] * 2.0
+
+
+def loop_syntax_sugar_fail(a: T.handle) -> None:

Review comment:
       This is intentional, as `check_error` adds the decorator to the source code itself.
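Presumably the function is left undecorated because applying `@T.prim_func` at definition time would raise the expected parse error while the test module is imported, before any test runs; adding the decorator inside the helper lets the test catch the error and check where it was reported. The sketch below illustrates that idea on source strings with a hypothetical `toy_parse` and a hypothetical `check_error_sketch`; it is not the `check_error` from this PR, and its line-number convention for `rel_lineno` is an assumption.

```python
import textwrap


class ToyParseError(Exception):
    """Hypothetical parse error carrying the 1-based line where it was found."""

    def __init__(self, message: str, lineno: int):
        super().__init__(f"line {lineno}: {message}")
        self.lineno = lineno


def toy_parse(source: str) -> None:
    # Hypothetical stand-in for a TVMScript parser: reject a loop with no bounds.
    for lineno, line in enumerate(source.splitlines(), start=1):
        if "T.serial()" in line:
            raise ToyParseError("serial() needs at least one bound", lineno)


def check_error_sketch(func_source: str, rel_lineno: int) -> None:
    """Prepend the decorator, parse, and check the error line.

    Assumed convention: rel_lineno counts from the `def` line (line 1),
    which becomes line 2 once the decorator line is prepended.
    """
    source = "@T.prim_func\n" + textwrap.dedent(func_source).lstrip("\n")
    try:
        toy_parse(source)
    except ToyParseError as err:
        assert err.lineno == rel_lineno + 1, f"error reported at line {err.lineno}"
        return
    raise AssertionError("expected a parse error, but parsing succeeded")


BAD_FUNC = """
def loop_syntax_sugar_fail(a: T.handle) -> None:
    for i in T.serial():
        pass
"""

check_error_sketch(BAD_FUNC, rel_lineno=2)  # passes: the error is on the `for` line
```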







[GitHub] [tvm] shingjan commented on pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#issuecomment-987842488


   @vinx13 this PR is rebased





[GitHub] [tvm] shingjan commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r761439264



##########
File path: python/tvm/script/tir/scope_handler.py
##########
@@ -560,10 +572,19 @@ class ThreadBinding(ForScopeHandler):
     def __init__(self):
         def thread_binding(
             begin: PrimExpr,
-            end: PrimExpr,
-            thread: str,
+            end: PrimExpr = None,
+            thread: str = None,
             annotations: Optional[Mapping[str, Object]] = None,
         ):
+            if not thread:

Review comment:
       fixed







[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r760452389



##########
File path: python/tvm/script/parser.py
##########
@@ -272,6 +272,13 @@ def parse_arg_list(self, func, node_call):
                 f"but it is {type(func).__name__}",
                 node_call.span,
             )
+
+        # for loop syntax sugar, check if starting 0 is omitted
+        # param_list[0] is the list of positional args which could include kw_args
+        # therefore the sum of len(args) and len(kw_args) is used here
+        if isinstance(func, ForScopeHandler) and len(args) + len(kw_args) < len(param_list[0]):
+            if args[0] != 0:
+                args.insert(0, 0)

Review comment:
       If possible, I'd prefer implementing such logic in `ForScopeHandler` (or its subclasses like `Parallel`, `Vectorized`, etc.) to separate the details of these statements from the parser. See the example here: https://github.com/apache/tvm/blob/main/python/tvm/script/tir/scope_handler.py#L591-L593
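A minimal sketch of the suggested structure, where each loop kind owns its argument handling and the parser only forwards the call. The class names mirror the handlers mentioned in the comment, but everything here (`ForScopeHandlerSketch`, `build`, `parse_loop_call`) is a hypothetical, self-contained illustration rather than TVM's `ForScopeHandler` API. One practical benefit is that the handler can test `end is None` directly instead of inferring an omitted start from the number of arguments, as the parser-level check in this diff does.

```python
from typing import Any, Dict, Optional


class ForScopeHandlerSketch:
    """Hypothetical base class: shared loop-range sugar for every loop kind."""

    kind = "serial"

    def build(
        self,
        begin: int,
        end: Optional[int] = None,
        annotations: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        if end is None:
            # One bound given: it is the end, and the loop starts at 0.
            begin, end = 0, begin
        return {
            "kind": self.kind,
            "min": begin,
            "extent": end - begin,
            "annotations": annotations or {},
        }


class Serial(ForScopeHandlerSketch):
    kind = "serial"


class Parallel(ForScopeHandlerSketch):
    kind = "parallel"


class Vectorized(ForScopeHandlerSketch):
    kind = "vectorized"


class Unroll(ForScopeHandlerSketch):
    kind = "unroll"


def parse_loop_call(handler: ForScopeHandlerSketch, *args, **kwargs) -> Dict[str, Any]:
    # The "parser" side stays generic: no loop-specific special cases here.
    return handler.build(*args, **kwargs)


print(parse_loop_call(Parallel(), 128))
# {'kind': 'parallel', 'min': 0, 'extent': 128, 'annotations': {}}
```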







[GitHub] [tvm] vinx13 commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

vinx13 commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r760734868



##########
File path: python/tvm/tir/utils.py
##########
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
+"""Common utility functions in TVM tir"""
+import inspect
+import tvm
+from tvm.ir.diagnostics import override_renderer
+
+
+def check_error(func, rel_lineno):

Review comment:
       Let's move this test-only utility file to `tvm/testing`.







[GitHub] [tvm] shingjan commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r760686863



##########
File path: python/tvm/script/parser.py
##########
@@ -272,6 +272,13 @@ def parse_arg_list(self, func, node_call):
                 f"but it is {type(func).__name__}",
                 node_call.span,
             )
+
+        # for loop syntax sugar, check if starting 0 is omitted
+        # param_list[0] is the list of positional args which could include kw_args
+        # therefore the sum of len(args) and len(kw_args) is used here
+        if isinstance(func, ForScopeHandler) and len(args) + len(kw_args) < len(param_list[0]):
+            if args[0] != 0:
+                args.insert(0, 0)

Review comment:
       done







[GitHub] [tvm] shingjan commented on a change in pull request #9620: [TVMScript] Add for loop syntax sugar

shingjan commented on a change in pull request #9620:
URL: https://github.com/apache/tvm/pull/9620#discussion_r763359829



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -316,36 +316,72 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr:
 """
 Scope handler - Loops
 """
-
+@overload
 def serial(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
+def serial(
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
 def parallel(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
+def parallel(
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
 def vectorized(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
+def vectorized(
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
 def unroll(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
     annotations: Optional[Mapping[str, Object]] = None,
 ) -> Iterable[IterVar]: ...
+@overload
+def unroll(
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+@overload
 def thread_binding(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int] = None,
+    end: Union[PrimExpr, int],
+    thread: str = None,

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -319,28 +319,28 @@ Scope handler - Loops
 
 def serial(
     begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,

Review comment:
       done



