You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "cyx-6 (via GitHub)" <gi...@apache.org> on 2023/01/20 23:01:35 UTC

[GitHub] [tvm] cyx-6 commented on a diff in pull request #13819: [TVMScript] Implicit root block syntax sugar for TVMScript printer

cyx-6 commented on code in PR #13819:
URL: https://github.com/apache/tvm/pull/13819#discussion_r1083119963


##########
tests/python/unittest/test_tvmscript_printer_tir.py:
##########
@@ -717,21 +717,49 @@ def block_with_remap_explicitly():
 
     expected_output = """@T.prim_func
 def main():
-    with T.block("root"):
-        T.reads()
-        T.writes()
-        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
-            with T.block("update"):
-                v0 = T.axis.spatial(128, i0 + 1)
-                v1, v2 = T.axis.remap("SR", [i1, i2])
-                v3 = T.axis.spatial(128, i3 - 1)
-                v4, v5 = T.axis.remap("RS", [i4, i5])
-                T.reads()
-                T.writes()
-                T.evaluate(0)"""
+    for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+        with T.block("update"):
+            v0 = T.axis.spatial(128, i0 + 1)
+            v1, v2 = T.axis.remap("SR", [i1, i2])
+            v3 = T.axis.spatial(128, i3 - 1)
+            v4, v5 = T.axis.remap("RS", [i4, i5])
+            T.reads()
+            T.writes()
+            T.evaluate(0)"""
     _assert_print(block_with_remap_explicitly, expected_output)
     _assert_print(block_with_remap_implicitly, expected_output)
 
 
+def test_root_block():
+    from tvm.script import tir as T
+
+    @T.prim_func
+    def root_block_implicitly():
+        a = T.alloc_buffer([128, 128])
+        for i, j in T.grid(128, 128):
+            with T.block():
+                T.evaluate(0)
+
+    @T.prim_func
+    def root_block_explicitly():
+        with T.block("root"):
+            a = T.alloc_buffer([128, 128])
+            for i, j in T.grid(128, 128):
+                with T.block():
+                    T.evaluate(0)
+
+    expected_output = """@T.prim_func
+def main():

Review Comment:
   sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org