Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/09 05:38:11 UTC

[GitHub] [tvm] masahi opened a new pull request #7425: [TIR] Add TIR While node

masahi opened a new pull request #7425:
URL: https://github.com/apache/tvm/pull/7425


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-788690275


   @tqchen @junrushao1994 @vinx13 @ZihengJiang @zxybazh 
   
   I came to the conclusion that the While node doesn't need special handling in `storage_rewrite`.
   
   The first observation is that even if I remove all `ForNode` handling from `StoragePlanRewriter`, all tests in `test_tir_transform_storage_rewrite.py` except [test_parallel_alloc()](https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/tests/python/unittest/test_tir_transform_storage_rewrite.py#L269) pass.
   
   If we look at the visitor for `ForNode`, https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L440-L452
   it only does something special when `attach_map_` has an entry for this node. Here comes the second observation: the only case where `attach_map_` can have an entry for a `ForNode` is when that `ForNode` is a parallel for loop, due to these lines: https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L766-L772
   
   Together, these two handlers for `ForNode` lift allocations inside an inner loop and attach the merged allocation under the parallel loop scope (via the `MakeAttach` function at https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L447). This is tested in `test_parallel_alloc`. For other kinds of `For` loops, the merged allocation is placed at the global scope; see https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L457-L461.
   
   Since a `While` node doesn't involve threading, I think we can always lift allocations made inside a `While` loop to the global scope. That means `WhileNode` should be handled the same way non-parallel `ForNode`s are handled, i.e. we don't need special handling logic for `WhileNode`.
   
   I think I nailed it, thoughts? 
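The argument above can be condensed into a toy model. The sketch below is a hypothetical pure-Python simplification, not TVM's `StoragePlanRewriter`: it assumes a loop nest is just a list of loop-kind strings (outermost first) and that, as described above, only a parallel loop provides an attach point, so a `while` loop falls through to the global scope exactly like a serial `for`.

```python
from typing import List, Optional

# Hypothetical loop-kind label used only in this toy model.
PARALLEL = "parallel"


def attach_scope(loop_kinds: List[str]) -> Optional[int]:
    """Toy model (not TVM code): where a merged allocation made in the
    innermost loop body would be attached.

    loop_kinds lists the enclosing loops from outermost to innermost, e.g.
    ["parallel", "serial", "while"]. Only a parallel loop creates an attach
    point, so the allocation goes under the nearest enclosing parallel loop
    (its index is returned); otherwise it is lifted to the global scope,
    represented here by None.
    """
    for depth in range(len(loop_kinds) - 1, -1, -1):
        if loop_kinds[depth] == PARALLEL:
            return depth
    return None


# A While loop behaves exactly like a serial For loop in this model.
print(attach_scope(["serial", "while"]))    # None: global scope
print(attach_scope(["parallel", "while"]))  # 0: under the parallel loop
```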



[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-783457493


   @vinx13 can you please take another look at the PR and manage?



[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-784098499


   > Thanks @masahi ! the change has addressed my previous comments. Please add testcases to transforms that touches requires special While handling to cover these passes
   
   Yes, I'm still trying to figure out what `StorageRewrite` is doing. This pass is a beast :slightly_smiling_face: I think it does something like storage coalescing, and since it is purely an optimization (I think), I'm not sure what "failure" means in this case.
   
   For example, we should definitely prevent invalid optimizations, but so far I'm having a hard time coming up with an example program that could fail. Does anyone have an idea? Another kind of failure is a missed optimization: I have an example where coalescing happens with a `For` loop but not with a `While` loop: https://github.com/apache/tvm/pull/7425#issuecomment-779798238
   
   I've also added a non-trivial change to `StorageAccessVisitor`; I still need to look at this class and its derived classes, such as `ThreadSyncPlanner`.
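"Storage coalescing" here means letting several allocations share one buffer when their live ranges do not overlap. The sketch below is a hypothetical greedy toy planner in plain Python, not the actual `StorageRewrite` algorithm; `plan_reuse` and the `(name, size, first_use, last_use)` tuple format are assumptions made purely for illustration.

```python
from typing import Dict, List, Tuple

# (name, size_in_bytes, first_use, last_use); uses are statement indices.
Alloc = Tuple[str, int, int, int]


def plan_reuse(allocs: List[Alloc]) -> Tuple[Dict[str, int], int]:
    """Toy greedy planner (not TVM's StorageRewrite): put allocations whose
    live ranges do not overlap into the same pool, and return the
    name -> pool-id assignment plus the total bytes of all pools."""
    pools: List[List[int]] = []  # each pool is [last_use_of_occupant, size]
    assignment: Dict[str, int] = {}
    for name, size, first, last in sorted(allocs, key=lambda a: a[2]):
        for pid, pool in enumerate(pools):
            if pool[0] < first:  # previous occupant is dead: reuse this pool
                pool[0] = last
                pool[1] = max(pool[1], size)
                assignment[name] = pid
                break
        else:  # no reusable pool found: open a new one
            pools.append([last, size])
            assignment[name] = len(pools) - 1
    return assignment, sum(size for _, size in pools)


allocs = [("A", 64, 0, 2), ("C", 32, 1, 4), ("B", 64, 3, 5)]
assignment, total = plan_reuse(allocs)
print(assignment, total)  # {'A': 0, 'C': 1, 'B': 0} 96
```

In this model, a missed optimization is a pass that keeps `A` and `B` in separate pools because it cannot reason about their live ranges; the program stays correct but needs 160 bytes instead of 96.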



[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-777974370


   @tqchen Can you have a look?



[GitHub] [tvm] tqchen edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778449042


   @masahi, it would also be great for you to spend a bit more time looking into these passes :) It certainly takes more time, but we also have more experts in TIR passes :)
   
   Please also consider adding test cases for the passes that need While handling.



[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776703600


   @tqchen @junrushao1994 @vinx13 
   
   I went through the passes and here is my summary:
   * `VectorizeLoop`: Need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering but the lowered code is incorrect. Added a test case `test_vectorize_while_fail()` to make sure we error out in such cases.
   
   * `StorageAccessVisitor`: I don't understand what it does, but I added a special visitor for `While` following the existing visitor for `IfThenElse`. Please check https://github.com/apache/tvm/pull/7425/commits/1e629b68b4112a01293683edc13c3e976a22a5bb
   
   * `CoProcSync` and `LiftAttrScope`: They both have a special visitor for `IfThenElse`, but I don't understand them. They are only used by VTA, so for now I just error out if we find a `WhileNode` there. See https://github.com/apache/tvm/pull/7425/commits/a71066d49381aae62626593c8fd76e149e1e55ed and https://github.com/apache/tvm/pull/7425/commits/00c17d921005eecc07f4300df898b9107d15ea1d
   
   * `InjectVirtualThread`: I think we need some special handling here, but I don't know what it should be. For now I just added a placeholder that calls the base class visitor. See https://github.com/apache/tvm/pull/7425/commits/896b02fb8aba00c22696f92195d32454bd593454 and let me know what we should do here.
   
   * Do we need to change `MergeNest`? I haven't touched it for now https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/ir_utils.cc#L35-L59 
   
   * Probably we don't need to change `hoist_if_then_else.cc` and `loop_partition.cc`. We can do something in `remove_no_op.cc`, but I think it is not important.
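The `VectorizeLoop` restriction amounts to a recursive check. Below is a minimal sketch over a hypothetical toy IR: the `Store`, `While` and `For` dataclasses and `check_vectorize` are illustrative assumptions, not TVM's classes or the actual pass code. It rejects a `While` nested anywhere under a vectorized `For` while allowing the reverse nesting used by `test_while_vectorize`.

```python
from dataclasses import dataclass, field
from typing import List, Union


# Toy stand-ins for tir.While / tir.For / a buffer store (not TVM classes).
@dataclass
class Store:
    buffer: str


@dataclass
class While:
    body: List["Stmt"] = field(default_factory=list)


@dataclass
class For:
    kind: str  # "serial", "parallel" or "vectorized"
    body: List["Stmt"] = field(default_factory=list)


Stmt = Union[Store, While, For]


def check_vectorize(stmt: Stmt, in_vectorized: bool = False) -> None:
    """Raise if a While loop appears anywhere inside a vectorized For loop."""
    if isinstance(stmt, While):
        if in_vectorized:
            raise ValueError("While loop inside a vectorized loop is not supported")
        for s in stmt.body:
            check_vectorize(s, in_vectorized)
    elif isinstance(stmt, For):
        inner = in_vectorized or stmt.kind == "vectorized"
        for s in stmt.body:
            check_vectorize(s, inner)
    # Store is a leaf node: nothing to check.


# Allowed: a vectorized loop inside a while loop (as in test_while_vectorize).
check_vectorize(While([For("vectorized", [Store("C")])]))

# Rejected: a while loop inside a vectorized loop.
try:
    check_vectorize(For("vectorized", [While([Store("C")])]))
except ValueError as err:
    print(err)
```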
   



[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-788690275


   @tqchen @junrushao1994 @vinx13 @ZihengJiang @zxybazh 
   
   I came to the conclusion that the While node doesn't need special handling in `storage_rewrite`.
   
   The first observation is that even if I remove all `ForNode` handling from `StoragePlanRewriter`, all tests in `test_tir_transform_storage_rewrite.py` except [test_parallel_alloc()](https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/tests/python/unittest/test_tir_transform_storage_rewrite.py#L269) pass.
   
   If we look at the visitor for `ForNode`, https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L440-L452
   it only does something special when `attach_map_` has an entry for this node. Here comes the second observation: the only case where `attach_map_` can have an entry for a `ForNode` is when that `ForNode` is a parallel for loop, due to these lines: https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L766-L772
   
   Together, these two handlers for `ForNode` lift allocations inside an inner loop and attach the merged allocation under the parallel loop scope (via the `MakeAttach` function at https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L447). This is tested in `test_parallel_alloc`. For other kinds of `For` loops, the merged allocation is placed at the global scope; see https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L457-L461.
   
   Since a `While` node doesn't involve threading, I think we can always lift allocations made inside a `While` loop to the global scope. That means `WhileNode` should be handled the same way non-parallel `ForNode`s are handled, i.e. we don't need special handling logic for `WhileNode`. Two simple test cases involving a `While` loop are added at https://github.com/apache/tvm/blob/c3af5ae9aa611580004ce03d16aa952ab124d826/tests/python/unittest/test_tir_transform_storage_rewrite.py#L301 to check that allocations are attached at the right scope after `storage_rewrite`.
   
   I think I nailed it, thoughts? 



[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776703600


   @tqchen @junrushao1994 @vinx13 
   
   I went through the passes and here is my summary:
   * `VectorizeLoop`: Need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering but the lowered code is incorrect. Added a test case `test_vectorize_while_fail()` to make sure we error out in such cases.
   
   * `StorageAccessVisitor`: I don't understand what it does, but I added a special visitor for `While` following the existing visitor for `IfThenElse`. Please check https://github.com/apache/tvm/pull/7425/commits/1e629b68b4112a01293683edc13c3e976a22a5bb
   
   * `CoProcSync` and `LiftAttrScope`: They both have a special visitor for `IfThenElse`, but I don't understand them. They are only used by VTA, so for now I just error out if we find a `WhileNode` there. See https://github.com/apache/tvm/pull/7425/commits/a71066d49381aae62626593c8fd76e149e1e55ed and https://github.com/apache/tvm/pull/7425/commits/00c17d921005eecc07f4300df898b9107d15ea1d
   
   * `InjectVirtualThread`: I think we need some special handling here, but I don't know what it should be. For now I just added a placeholder that calls the base class visitor. See https://github.com/apache/tvm/pull/7425/commits/896b02fb8aba00c22696f92195d32454bd593454 and let me know what we should do here.
   
   * Do we need to change `MergeNest`? https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/ir_utils.cc#L35-L59 I haven't touched it for now.
   
   * Probably we don't need to change `hoist_if_then_else.cc` and `loop_partition.cc`. We can do something in `remove_no_op.cc`, but I think it is not important.
   



[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778445359


   @vinx13 OK, for `InplaceOpVerifier` I think I need to update
   https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L241-L251
   
   But I don't see how we should update `StoragePlanRewriter`. Maybe here? https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L757-L773 



[GitHub] [tvm] vinx13 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778455332


   @masahi For `StoragePlanRewriter`, we need to do something similar to `ForNode` https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L440-L452



[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778278954


   I left a comment on InjectVirtualThread. @junrushao1994 @ZihengJiang @vinx13, it would be great if you could also help check the StorageAccessVisitor.



[GitHub] [tvm] junrushao1994 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-788704483


   That makes sense to me. Thanks for diving deep into this issue!



[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789159545


   cc @tqchen please take a look



[GitHub] [tvm] vinx13 merged pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 merged pull request #7425:
URL: https://github.com/apache/tvm/pull/7425


   



[GitHub] [tvm] vinx13 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789356174


   Thanks everyone @masahi @tqchen @junrushao1994 @giuseros @zxybazh @ZihengJiang 



[GitHub] [tvm] tqchen edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-775986528


   Thanks @masahi! Before we merge it in, it would be really awesome to go through the current list of passes and check whether special handling of While is needed (so we don't introduce new bugs because of the mix). At a minimum, I would check the passes that need special IfThenElse handling.
   
   For example, I can see the need to update the following pass:
   - Vectorize (we will need to abort if the condition is vectorized)
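Why a vectorized condition is a problem: each lane may need a different number of iterations, so the loop has no single trip count. The plain-Python sketch below is a hypothetical illustration (not TVM code) that reuses the Collatz step count from the tests in this PR: it computes per-lane trip counts and shows that a correct vectorized while would need per-lane masking, which aborting avoids having to implement.

```python
from typing import List


def collatz_steps(n: int) -> int:
    """Scalar reference: number of Collatz steps needed to reach 1."""
    steps = 0
    while n > 1:
        n = 3 * n + 1 if n % 2 == 1 else n >> 1
        steps += 1
    return steps


def masked_vector_collatz(lanes: List[int]) -> List[int]:
    """Hypothetical masked execution of one while loop over a whole vector:
    keep iterating while ANY lane is active, and only update active lanes."""
    vals = list(lanes)
    steps = [0] * len(vals)
    active = [v > 1 for v in vals]
    while any(active):
        for k, is_active in enumerate(active):
            if is_active:
                vals[k] = 3 * vals[k] + 1 if vals[k] % 2 == 1 else vals[k] >> 1
                steps[k] += 1
        active = [v > 1 for v in vals]  # recompute the per-lane condition
    return steps


lanes = [1, 2, 3, 4]  # one value per lane of a 4-wide vector
print([collatz_steps(x) for x in lanes])  # [0, 1, 7, 2]: trip counts differ per lane
print(masked_vector_collatz(lanes))       # [0, 1, 7, 2]: correct only thanks to the mask
```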
   
   
   



[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-775995118


   Also cc @zxybazh, please help review this PR.



[GitHub] [tvm] tqchen commented on a change in pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#discussion_r585870509



##########
File path: tests/python/unittest/test_tir_ir_builder.py
##########
@@ -173,9 +173,383 @@ def check_target(target):
     check_target("cuda")
 
 
+def test_while_vectorize():
+    """Test while loop + vectorized inner loop"""
+
+    n = 64
+    num_iter = 10
+
+    def test_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+        n = C.shape[0]
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+        C = ib.buffer_ptr(C)
+        i = ib.allocate("int32", (1,), name="i", scope="local")
+        i[0] = 0
+
+        with ib.for_range(0, n) as j:
+            C[j] = 0.0
+
+        with ib.while_loop(i[0] < num_iter):
+            with ib.for_range(0, n, kind="vectorize") as j:
+                C[j] += A[j] + B[j]
+            i[0] += 1
+
+        return ib.get()
+
+    def check_target(target, ir):
+        dtype = "float32"
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.placeholder((n,), name="B", dtype=dtype)
+
+        C = te.extern(
+            (n,),
+            [A, B],
+            lambda ins, outs: ir(ins[0], ins[1], outs[0]),
+            name="while_vectorize",
+            dtype=dtype,
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [A, B, C], target)
+
+        ctx = tvm.context(target, 0)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        b_np = np.random.uniform(size=n).astype(B.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a, b, c)
+        ref = num_iter * (a_np + b_np)
+        tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5)
+
+    check_target("llvm", test_ir)
+
+
+def test_while_collatz():
+    """Test while loop + if"""
+
+    def collatz_ref(n):
+        a = n
+        i = 0
+        while a > 1:
+            if a % 2 == 1:
+                a = 3 * a + 1
+            else:
+                a = a >> 1
+            i += 1
+        return i
+
+    def collatz(ib, n, C):
+        i = ib.allocate("int32", (1,), name="i", scope="local")
+        a = ib.allocate("int32", (1,), name="a", scope="local")
+        i[0] = 0
+        a[0] = n
+        with ib.while_loop(a[0] > 1):
+            with ib.if_scope(tvm.tir.floormod(a[0], 2) == 1):
+                a[0] = 3 * a[0] + 1
+            with ib.else_scope():
+                a[0] = a[0] >> 1
+            i[0] += 1
+
+        C[n] = i[0]
+
+    def collatz_ir_cpu(C):
+        ib = tvm.tir.ir_builder.create()
+        n = C.shape[0]
+        C = ib.buffer_ptr(C)
+
+        with ib.for_range(0, n, name="i", kind="parallel") as i:
+            collatz(ib, i, C)
+
+        body = ib.get()
+
+        return body
+
+    n = 30
+
+    def check_target(target, ir):
+        C = te.extern(
+            (n,),
+            [],
+            lambda ins, outs: ir(outs[0]),
+            name="collatz",
+            dtype="int32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [C], target)
+
+        ctx = tvm.context(target, 0)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(c)
+        ref = np.array([collatz_ref(i) for i in range(n)])
+        tvm.testing.assert_allclose(c.asnumpy(), ref)
+
+    check_target("llvm", collatz_ir_cpu)
+
+
+def test_while_mandel():
+    n = 160
+    shape = (n * 2, n)
+    t = 300
+
+    def mandel_ref():
+        def complex_sqr(z):
+            return np.array([z[0] ** 2 - z[1] ** 2, z[1] * z[0] * 2])
+
+        pixels = np.zeros(shape)
+
+        for i in range(pixels.shape[0]):
+            for j in range(pixels.shape[1]):
+                c = np.array([-0.8, np.cos(t) * 0.2])
+                z = np.array([i / n - 1, j / n - 0.5]) * 2
+                iterations = 0
+
+                while np.linalg.norm(z) < 20 and iterations < 50:
+                    z = complex_sqr(z) + c
+                    iterations += 1
+
+                pixels[i, j] = 1 - iterations * 0.02
+
+        return pixels
+
+    def mandel(ib, i, j, pixels):
+        z = ib.allocate("float32", (2,), name="z", scope="local")
+        tmp = ib.allocate("float32", (1,), name="tmp", scope="local")
+        iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
+
+        z[0] = (i / float(n) - 1) * 2
+        z[1] = (j / float(n) - 0.5) * 2
+        iterations[0] = 0
+        c = [-0.8, float(np.cos(t)) * 0.2]
+
+        def norm(z):
+            return tvm.tir.sqrt(z[0] * z[0] + z[1] * z[1])
+
+        with ib.while_loop(tvm.tir.all(norm(z) < 20, iterations[0] < 50)):
+            tmp[0] = z[0]
+            z[0] = z[0] * z[0] - z[1] * z[1] + c[0]
+            z[1] = z[1] * tmp[0] * 2 + c[1]
+            iterations[0] += 1
+
+        pixels[i, j] = 1 - iterations[0] * 0.02
+
+    def mandel_ir_cpu(C):
+        ib = tvm.tir.ir_builder.create()
+        ny = C.shape[0]
+        nx = C.shape[1]
+        C = ib.buffer_ptr(C)
+
+        with ib.for_range(0, ny, name="i", kind="parallel") as i:
+            with ib.for_range(0, nx, name="j") as j:
+                mandel(ib, i, j, C)
+
+        body = ib.get()
+
+        return body
+
+    def mandel_ir_gpu(C):
+        ib = tvm.tir.ir_builder.create()
+        ny = C.shape[0]
+        nx = C.shape[1]
+        C = ib.buffer_ptr(C)
+
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ty = te.thread_axis("threadIdx.y")
+
+        max_threads = 16
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(nx + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        ib.scope_attr(by, "thread_extent", tvm.tir.indexdiv(ny + max_threads - 1, max_threads))
+        ib.scope_attr(ty, "thread_extent", max_threads)
+
+        tidx = bx * max_threads + tx
+        tidy = by * max_threads + ty
+
+        with ib.if_scope(tvm.tir.all(tidx < nx, tidy < ny)):
+            mandel(ib, tidy, tidx, C)
+
+        body = ib.get()
+
+        return body
+
+    ref = mandel_ref()
+
+    def check_target(target, ir):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        C = te.extern(
+            shape,
+            [],
+            lambda ins, outs: ir(outs[0]),
+            name="mandel_ir",
+            dtype="float32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [C], target)
+
+        ctx = tvm.context(target, 0)
+        c = tvm.nd.array(np.zeros(shape, dtype=C.dtype), ctx)
+        func(c)
+        tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5)
+
+    check_target("llvm", mandel_ir_cpu)
+    check_target("nvptx", mandel_ir_gpu)
+    check_target("cuda", mandel_ir_gpu)
+
+
+def test_while_binary_search():
+    def binary_search(ib, n, i, Aptr, Bptr, Cptr):
+        lo = ib.allocate("int32", (1,), name="lo", scope="local")
+        hi = ib.allocate("int32", (1,), name="hi", scope="local")
+
+        lo[0] = 0
+        hi[0] = n
+        v = Bptr[i]
+
+        with ib.while_loop(lo[0] < hi[0]):
+            mid = lo[0] + (hi[0] - lo[0] >> 1)
+            with ib.if_scope(Aptr[mid] < v):
+                lo[0] = mid + 1
+            with ib.else_scope():
+                hi[0] = mid
+
+        Cptr[i] = lo[0]
+
+    def searchsorted_ir_cpu(A, B, C, n):
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        with ib.for_range(0, n, name="i", kind="parallel") as i:
+            binary_search(ib, n, i, Aptr, Bptr, Cptr)
+
+        body = ib.get()
+
+        return body
+
+    def searchsorted_ir_gpu(A, B, C, n):
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < n):
+            binary_search(ib, n, tid, Aptr, Bptr, Cptr)
+
+        body = ib.get()
+
+        return body
+
+    n = 1024
+    dtype = "float32"
+    A = te.placeholder((n,), name="A", dtype=dtype)
+    B = te.placeholder((n,), name="B", dtype=dtype)
+
+    def check_target(target, ir):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        C = te.extern(
+            A.shape,
+            [A, B],
+            lambda ins, outs: ir(ins[0], ins[1], outs[0], n),
+            name="searchsorted_ir",
+            dtype="int32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [A, B, C], target)
+
+        ctx = tvm.context(target, 0)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        b_np = np.random.uniform(size=n).astype(B.dtype)
+        a_np = np.sort(a_np)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a, b, c)
+        ref = np.searchsorted(a_np, b_np)
+        tvm.testing.assert_allclose(c.asnumpy(), ref)
+
+    check_target("llvm", searchsorted_ir_cpu)
+    check_target("cuda", searchsorted_ir_gpu)
+    check_target("nvptx", searchsorted_ir_gpu)
+
+
+def test_vectorize_while_fail():
+    """A while loop inside a vectorized loop should fail."""

Review comment:
       Please move to test_tir_transform_vectorize.py
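The expected values in `test_while_binary_search` come from `np.searchsorted`, whose default `side="left"` returns the leftmost insertion point. A plain-Python transcription of the `while_loop` body (a sketch for checking the logic, not part of the PR) confirms that the two agree. Python gives `-` higher precedence than `>>`, so the test's `hi[0] - lo[0] >> 1` already means `(hi[0] - lo[0]) >> 1`; the sketch only makes the parentheses explicit.

```python
import bisect
import random
from typing import List


def binary_search(a: List[float], v: float) -> int:
    """Same lo/hi loop as the TIR binary_search helper: returns the first
    index i such that a[i] >= v (len(a) if there is none)."""
    lo, hi = 0, len(a)
    while lo < hi:
        mid = lo + ((hi - lo) >> 1)
        if a[mid] < v:
            lo = mid + 1
        else:
            hi = mid
    return lo


# Cross-check against bisect_left, which matches np.searchsorted(side="left").
rng = random.Random(0)
a = sorted(rng.randint(0, 20) for _ in range(50))
for v in range(-1, 22):
    assert binary_search(a, v) == bisect.bisect_left(a, v)
print(binary_search([1.0, 3.0, 5.0, 7.0], 4.0))  # 2
```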





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778460602


   OK, it's not obvious to me what it is doing. Time for another deep dive...


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on a change in pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#discussion_r585874988



##########
File path: tests/python/unittest/test_tir_ir_builder.py
##########
@@ -173,9 +173,383 @@ def check_target(target):
     check_target("cuda")
 
 
+def test_while_vectorize():
+    """Test while loop + vectorized inner loop"""
+
+    n = 64
+    num_iter = 10
+
+    def test_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+        n = C.shape[0]
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+        C = ib.buffer_ptr(C)
+        i = ib.allocate("int32", (1,), name="i", scope="local")
+        i[0] = 0
+
+        with ib.for_range(0, n) as j:
+            C[j] = 0.0
+
+        with ib.while_loop(i[0] < num_iter):
+            with ib.for_range(0, n, kind="vectorize") as j:
+                C[j] += A[j] + B[j]
+            i[0] += 1
+
+        return ib.get()
+
+    def check_target(target, ir):
+        dtype = "float32"
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.placeholder((n,), name="B", dtype=dtype)
+
+        C = te.extern(
+            (n,),
+            [A, B],
+            lambda ins, outs: ir(ins[0], ins[1], outs[0]),
+            name="while_vectorize",
+            dtype=dtype,
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [A, B, C], target)
+
+        ctx = tvm.context(target, 0)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        b_np = np.random.uniform(size=n).astype(B.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a, b, c)
+        ref = num_iter * (a_np + b_np)
+        tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5)
+
+    check_target("llvm", test_ir)
+
+
+def test_while_collatz():
+    """Test while loop + if"""
+
+    def collatz_ref(n):
+        a = n
+        i = 0
+        while a > 1:
+            if a % 2 == 1:
+                a = 3 * a + 1
+            else:
+                a = a >> 1
+            i += 1
+        return i
+
+    def collatz(ib, n, C):
+        i = ib.allocate("int32", (1,), name="i", scope="local")
+        a = ib.allocate("int32", (1,), name="a", scope="local")
+        i[0] = 0
+        a[0] = n
+        with ib.while_loop(a[0] > 1):
+            with ib.if_scope(tvm.tir.floormod(a[0], 2) == 1):
+                a[0] = 3 * a[0] + 1
+            with ib.else_scope():
+                a[0] = a[0] >> 1
+            i[0] += 1
+
+        C[n] = i[0]
+
+    def collatz_ir_cpu(C):
+        ib = tvm.tir.ir_builder.create()
+        n = C.shape[0]
+        C = ib.buffer_ptr(C)
+
+        with ib.for_range(0, n, name="i", kind="parallel") as i:
+            collatz(ib, i, C)
+
+        body = ib.get()
+
+        return body
+
+    n = 30
+
+    def check_target(target, ir):
+        C = te.extern(
+            (n,),
+            [],
+            lambda ins, outs: ir(outs[0]),
+            name="collatz",
+            dtype="int32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [C], target)
+
+        ctx = tvm.context(target, 0)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(c)
+        ref = np.array([collatz_ref(i) for i in range(n)])
+        tvm.testing.assert_allclose(c.asnumpy(), ref)
+
+    check_target("llvm", collatz_ir_cpu)
+
+
+def test_while_mandel():
+    n = 160
+    shape = (n * 2, n)
+    t = 300
+
+    def mandel_ref():
+        def complex_sqr(z):
+            return np.array([z[0] ** 2 - z[1] ** 2, z[1] * z[0] * 2])
+
+        pixels = np.zeros(shape)
+
+        for i in range(pixels.shape[0]):
+            for j in range(pixels.shape[1]):
+                c = np.array([-0.8, np.cos(t) * 0.2])
+                z = np.array([i / n - 1, j / n - 0.5]) * 2
+                iterations = 0
+
+                while np.linalg.norm(z) < 20 and iterations < 50:
+                    z = complex_sqr(z) + c
+                    iterations += 1
+
+                pixels[i, j] = 1 - iterations * 0.02
+
+        return pixels
+
+    def mandel(ib, i, j, pixels):
+        z = ib.allocate("float32", (2,), name="z", scope="local")
+        tmp = ib.allocate("float32", (1,), name="tmp", scope="local")
+        iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
+
+        z[0] = (i / float(n) - 1) * 2
+        z[1] = (j / float(n) - 0.5) * 2
+        iterations[0] = 0
+        c = [-0.8, float(np.cos(t)) * 0.2]
+
+        def norm(z):
+            return tvm.tir.sqrt(z[0] * z[0] + z[1] * z[1])
+
+        with ib.while_loop(tvm.tir.all(norm(z) < 20, iterations[0] < 50)):
+            tmp[0] = z[0]
+            z[0] = z[0] * z[0] - z[1] * z[1] + c[0]
+            z[1] = z[1] * tmp[0] * 2 + c[1]
+            iterations[0] += 1
+
+        pixels[i, j] = 1 - iterations[0] * 0.02
+
+    def mandel_ir_cpu(C):
+        ib = tvm.tir.ir_builder.create()
+        ny = C.shape[0]
+        nx = C.shape[1]
+        C = ib.buffer_ptr(C)
+
+        with ib.for_range(0, ny, name="i", kind="parallel") as i:
+            with ib.for_range(0, nx, name="j") as j:
+                mandel(ib, i, j, C)
+
+        body = ib.get()
+
+        return body
+
+    def mandel_ir_gpu(C):
+        ib = tvm.tir.ir_builder.create()
+        ny = C.shape[0]
+        nx = C.shape[1]
+        C = ib.buffer_ptr(C)
+
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ty = te.thread_axis("threadIdx.y")
+
+        max_threads = 16
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(nx + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        ib.scope_attr(by, "thread_extent", tvm.tir.indexdiv(ny + max_threads - 1, max_threads))
+        ib.scope_attr(ty, "thread_extent", max_threads)
+
+        tidx = bx * max_threads + tx
+        tidy = by * max_threads + ty
+
+        with ib.if_scope(tvm.tir.all(tidx < nx, tidy < ny)):
+            mandel(ib, tidy, tidx, C)
+
+        body = ib.get()
+
+        return body
+
+    ref = mandel_ref()
+
+    def check_target(target, ir):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        C = te.extern(
+            shape,
+            [],
+            lambda ins, outs: ir(outs[0]),
+            name="mandel_ir",
+            dtype="float32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [C], target)
+
+        ctx = tvm.context(target, 0)
+        c = tvm.nd.array(np.zeros(shape, dtype=C.dtype), ctx)
+        func(c)
+        tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5, atol=1e-5)
+
+    check_target("llvm", mandel_ir_cpu)
+    check_target("nvptx", mandel_ir_gpu)
+    check_target("cuda", mandel_ir_gpu)
+
+
+def test_while_binary_search():
+    def binary_search(ib, n, i, Aptr, Bptr, Cptr):
+        lo = ib.allocate("int32", (1,), name="lo", scope="local")
+        hi = ib.allocate("int32", (1,), name="hi", scope="local")
+
+        lo[0] = 0
+        hi[0] = n
+        v = Bptr[i]
+
+        with ib.while_loop(lo[0] < hi[0]):
+            mid = lo[0] + (hi[0] - lo[0] >> 1)
+            with ib.if_scope(Aptr[mid] < v):
+                lo[0] = mid + 1
+            with ib.else_scope():
+                hi[0] = mid
+
+        Cptr[i] = lo[0]
+
+    def searchsorted_ir_cpu(A, B, C, n):
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        with ib.for_range(0, n, name="i", kind="parallel") as i:
+            binary_search(ib, n, i, Aptr, Bptr, Cptr)
+
+        body = ib.get()
+
+        return body
+
+    def searchsorted_ir_gpu(A, B, C, n):
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < n):
+            binary_search(ib, n, tid, Aptr, Bptr, Cptr)
+
+        body = ib.get()
+
+        return body
+
+    n = 1024
+    dtype = "float32"
+    A = te.placeholder((n,), name="A", dtype=dtype)
+    B = te.placeholder((n,), name="B", dtype=dtype)
+
+    def check_target(target, ir):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        C = te.extern(
+            A.shape,
+            [A, B],
+            lambda ins, outs: ir(ins[0], ins[1], outs[0], n),
+            name="searchsorted_ir",
+            dtype="int32",
+        )
+        s = te.create_schedule(C.op)
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(s, [A, B, C], target)
+
+        ctx = tvm.context(target, 0)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        b_np = np.random.uniform(size=n).astype(B.dtype)
+        a_np = np.sort(a_np)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a, b, c)
+        ref = np.searchsorted(a_np, b_np)
+        tvm.testing.assert_allclose(c.asnumpy(), ref)
+
+    check_target("llvm", searchsorted_ir_cpu)
+    check_target("cuda", searchsorted_ir_gpu)
+    check_target("nvptx", searchsorted_ir_gpu)
+
+
+def test_vectorize_while_fail():
+    """A while loop inside a vectorized loop should fail."""

Review comment:
       done







[GitHub] [tvm] mbrookhart commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-786903727


   @tqchen are there any specific things you still want @masahi to change? I have a large speedup for sort waiting in the wings for this PR to merge.





[GitHub] [tvm] tqchen edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778449042


   Thanks @masahi, it would also be great for you to spend a bit more time looking into these passes :) It certainly takes more time, but we will also have more experts in TIR passes :)
   
   Please also consider adding test cases for the passes that need `While` handling.





[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776703600


   @tqchen @junrushao1994 @vinx13 
   
   I went through passes and here is my summary:
   * `VectorizeLoop`: Need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering, but the lowered code is incorrect. Added a test case `test_vectorize_while_fail()` to make sure we error out in such cases.
   
   * `StorageAccessVisitor`: I don't understand what it does, but I added a special visitor for `While` following the existing visitor for `IfThenElse`. Please check https://github.com/apache/tvm/pull/7425/commits/1e629b68b4112a01293683edc13c3e976a22a5bb
   
   * `CoProcSync` and `LiftAttrScope`: They both have a special visitor for `IfThenElse`, but I don't understand them. They are only used by VTA, so for now I just error out if we find a `WhileNode` there. See https://github.com/apache/tvm/pull/7425/commits/a71066d49381aae62626593c8fd76e149e1e55ed and https://github.com/apache/tvm/pull/7425/commits/00c17d921005eecc07f4300df898b9107d15ea1d
   
   * `InjectVirtualThread`: I think we need some special handling for this, but I don't know what it should be. For now I just added a placeholder that calls the base class visitor. See https://github.com/apache/tvm/pull/7425/commits/896b02fb8aba00c22696f92195d32454bd593454 and let me know what we should do here.
   
   * Do we need to change `MergeNest`? I haven't touched it for now: https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/ir_utils.cc#L35-L59
   
   * Probably we don't need to change `hoist_if_then_else.cc` and `loop_partition.cc`. We can do something in `remove_no_op.cc`, but I think it is not important.
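A pure-Python model (not TVM code; every name here is hypothetical) of why `VectorizeLoop` has to reject this pattern: each lane of a vectorized loop may need a different number of `While` iterations, so a lowering that evaluates the loop condition once for the whole vector can silently produce wrong results. This only illustrates one way the "no error, but incorrect code" symptom can arise; it is not a reconstruction of what TVM actually emitted.

```python
def scalar_reference(starts):
    """Per-lane semantics: each lane halves its value until it reaches 1."""
    out = []
    for a in starts:
        steps = 0
        while a > 1:
            a //= 2
            steps += 1
        out.append(steps)
    return out


def naive_vectorized(starts):
    """Hypothetical wrong lowering: the while condition is evaluated on
    lane 0 only, so every lane runs the same number of iterations."""
    lanes = list(starts)
    steps = [0] * len(lanes)
    while lanes[0] > 1:  # condition taken from a single lane
        lanes = [a // 2 for a in lanes]
        steps = [s + 1 for s in steps]
    return steps


starts = [2, 8, 64, 1]
print(scalar_reference(starts))  # [1, 3, 6, 0]
print(naive_vectorized(starts))  # [1, 1, 1, 1]
```

Lanes whose trip count differs from lane 0 get the wrong answer, which is why rejecting a `While` inside a vectorized loop is the conservative choice.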
   





[GitHub] [tvm] vinx13 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778401698


   I've checked `StorageAccessVisitor` and it looks good to me. `InplaceOpVerifier` and `StoragePlanRewriter` also need handling.





[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-784098499


   > Thanks @masahi ! the change has addressed my previous comments. Please add testcases to transforms that touches requires special While handling to cover these passes
   
   Yes, I'm still trying to figure out what `StorageRewrite` is doing. This pass is a beast :slightly_smiling_face: I think it is doing something like storage coalescing, and since this is purely for optimization (I think), I'm not sure what "failure" means in this case.
   
   For example, we should definitely prevent invalid optimization, but so far I'm having a hard time coming up with an example program that could fail. Does anyone have any ideas? Another kind of failure is a missed optimization; I have an example where a `For` loop coalesces but a `While` loop does not: https://github.com/apache/tvm/pull/7425#issuecomment-779798238
   
   I've also added a non-trivial change to `StorageAccessVisitor` following the existing visitor for `IfThenElseNode`. I need to look at this class and its derived classes, such as `ThreadSyncPlanner`.





[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789176773


   @junrushao1994 @vinx13 please help to manage the PR





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789347656


   @junrushao1994 @vinx13 @tqchen ready to merge...!!





[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778449042


   @masahi, it would also be great for you to spend a bit more time looking into these passes :) It certainly takes more time, but we also have more experts in TIR passes :)





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776703600


   @tqchen @junrushao1994 @vinx13 
   
   I went through passes and here is my summary:
   * `VectorizeLoop`: Need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering, but the lowered code is incorrect. Added a test case `test_vectorize_while_fail()` to make sure we error out in such cases.
   
   * `StorageAccessVisitor`: I don't understand what it does, but I added a special visitor for `While` following the existing visitor for `IfThenElse`. Please check https://github.com/apache/tvm/pull/7425/commits/1e629b68b4112a01293683edc13c3e976a22a5bb
   
   * `CoProcSync` and `LiftAttrScope`: They both have a special visitor for `IfThenElse`, but I don't understand them. They are only used by VTA, so for now I just error out if we find a `WhileNode` there. See https://github.com/apache/tvm/pull/7425/commits/a71066d49381aae62626593c8fd76e149e1e55ed and https://github.com/apache/tvm/pull/7425/commits/00c17d921005eecc07f4300df898b9107d15ea1d
   
   * `InjectVirtualThread`: I think we need some special handling for this, but I don't know what it should be. For now I just added a placeholder that calls the base class visitor. See https://github.com/apache/tvm/pull/7425/commits/896b02fb8aba00c22696f92195d32454bd593454 and let me know what we should do here.





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-779798238


   @tqchen @vinx13 @junrushao1994 Does the behavior of the `While` node with respect to `StorageRewrite` below look reasonable?
   
   In the following IR, the "A" and "B" buffers, which are allocated in a `For` loop, are coalesced into one buffer, but the "C" buffer, which is allocated inside a `While` loop, is not:
   ```
   def test_parallel_alloc():
       ib = tvm.tir.ir_builder.create()
       n = te.var("n")
       with ib.for_range(0, n, name="i", kind="parallel") as i:
           with ib.for_range(0, 10, name="j") as j:
               A = ib.allocate("float32", n, name="A", scope="global")
               A[j] = A[j] + 2
   
           with ib.for_range(0, 10, name="j") as j:
               B = ib.allocate("float32", n, name="B", scope="global")
               B[j] = B[j] + 2
   
           i = ib.allocate("int32", (1,), name="i", scope="local")
           i[0] = 1
           with ib.while_loop(i[0] < 10):
               C = ib.allocate("float32", n, name="C", scope="local")
               C[i[0]] = C[i[0]] + 2
               i[0] += 1
   ```
   
   ```
   parallel (i, 0, n) {
     // attr [A] storage_scope = "global"
     allocate A[float32 * n]
     // attr [i] storage_scope = "local"
     allocate i[int32 * 1]
     // attr [C] storage_scope = "local"
     allocate C[float32 * n]
     for (j, 0, 10) {
       A[j] = (A[j] + 2f)
     }
     for (j, 0, 10) {
       A[j] = (A[j] + 2f)
     }
     i[0] = 1
     while((i[0] < 10)){
       C[i[0]] = (C[i[0]] + 2f)
       i[0] = (i[0] + 1)
     }
   }
   ```
   
   In the following IR, all buffers, including the one allocated inside the `While` loop, are coalesced:
   ```
   def test_alloc_seq():
       scope_tb = "local.L0A"
       max_bits = 1024 * 1024 * 1024
   
       register_mem(scope_tb, max_bits)
   
       ib = tvm.tir.ir_builder.create()
       n = te.var("n")
       with ib.for_range(0, n, name="i") as i:
           with ib.for_range(0, 10, name="j") as j:
               A = ib.allocate("float32", 200, name="A", scope=scope_tb)
               A[j] = 1.2
           with ib.for_range(0, 10, name="j") as j:
               B = ib.allocate("float32", 200, name="B", scope=scope_tb)
               B[j] = 1.3
   
           i = ib.allocate("int32", (1,), name="i", scope="local")
           i[0] = 1
           with ib.while_loop(i[0] < 10):
               C = ib.allocate("float32", 200, name="C", scope=scope_tb)
               C[i[0]] = 1.4
               i[0] += 1
   
       body = ib.get()
   ```
   
   ```
   // attr [A] storage_scope = "local.L0A"
   allocate A[float32 * 200]
   // attr [i] storage_scope = "local"
   allocate i[int32 * 1]
   for (i, 0, n) {
     for (j, 0, 10) {
       A[j] = 1.2f
     }
     for (j, 0, 10) {
       A[j] = 1.3f
     }
     i[0] = 1
     while((i[0] < 10)){
       A[i[0]] = 1.4f
       i[0] = (i[0] + 1)
     }
   }
   
   
   ```
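A toy model of the coalescing decision, to make the two outputs above easier to compare. This is a deliberately simplified sketch, not `StoragePlanRewriter`'s algorithm: the function name, the greedy first-fit strategy, and the buffer sizes are assumptions. A buffer may reuse an earlier buffer's storage only if their live ranges do not overlap and, as assumed here, they share the same storage scope. Note that in the first example C is declared with scope "local" while A and B use "global", whereas in the second example all three use `local.L0A`; in this toy model that scope mismatch alone keeps C separate.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass
class Slot:
    scope: str
    size: int
    last_use: int
    buffers: List[str]


def coalesce(allocs: Dict[str, Tuple[str, int]], uses: List[Set[str]]) -> Dict[str, str]:
    """allocs: buffer -> (scope, size); uses: buffers touched by each
    top-level statement, in program order. Returns buffer -> storage owner."""
    first: Dict[str, int] = {}
    last: Dict[str, int] = {}
    for idx, touched in enumerate(uses):
        for name in touched:
            first.setdefault(name, idx)  # start of live range
            last[name] = idx  # end of live range
    slots: List[Slot] = []
    assignment: Dict[str, str] = {}
    for name in sorted(first, key=lambda n: first[n]):
        scope, size = allocs[name]
        for slot in slots:
            # Reuse a slot only if the scopes match and lifetimes don't overlap.
            if slot.scope == scope and slot.last_use < first[name]:
                slot.size = max(slot.size, size)
                slot.last_use = last[name]
                slot.buffers.append(name)
                break
        else:
            slot = Slot(scope, size, last[name], [name])
            slots.append(slot)
        assignment[name] = slot.buffers[0]
    return assignment


uses = [{"A"}, {"B"}, {"C"}]  # for j, for j, while
same_scope = {b: ("local.L0A", 200) for b in "ABC"}
mixed_scope = {"A": ("global", 64), "B": ("global", 64), "C": ("local", 64)}
print(coalesce(same_scope, uses))   # {'A': 'A', 'B': 'A', 'C': 'A'}
print(coalesce(mixed_scope, uses))  # {'A': 'A', 'B': 'A', 'C': 'C'}
```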





[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-784098499


   > Thanks @masahi ! the change has addressed my previous comments. Please add testcases to transforms that touches requires special While handling to cover these passes
   
   Yes, so far only a test case to prevent invalid vectorization has been added. I'm still trying to figure out what `StorageRewrite` is doing. This pass is a beast :slightly_smiling_face: I think it is doing something like storage coalescing, and since this is purely for optimization (I think), I'm not sure what "failure" means in this case.
   
   For example, we should definitely prevent invalid optimization, but so far I'm having a hard time coming up with an example program that could fail. Does anyone have any ideas? Another kind of failure is a missed optimization; I have an example where a `For` loop coalesces but a `While` loop does not: https://github.com/apache/tvm/pull/7425#issuecomment-779798238
   
   I've also added a non-trivial change to `StorageAccessVisitor` following the existing visitor for `IfThenElseNode`. I need to look at this class and its derived classes, such as `ThreadSyncPlanner`.





[GitHub] [tvm] vinx13 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-788701925


   @masahi You are right, thanks for looking into this





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789378808


   Thank you very much for the reviews!!





[GitHub] [tvm] masahi commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-786911002


   Yeah, I'm still trying to figure out whether `storage_rewrite` is safe for `While` loops and, if not, what to do and what tests to add.





[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-775986528


   Thanks @masahi. Before we merge it in, it would be really awesome to go through the current list of passes and check whether special handling of `While` is needed. A good starting point is to check for passes that need special `IfThenElse` handling.
   
   For example, I can see the need to update the following pass:
   - Vectorize (we will need to abort if the condition is vectorized)
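A minimal sketch of the conservative check, written against a toy statement tree rather than TVM's real node classes (`For`, `While`, `Store`, and the kind strings below are hypothetical stand-ins). It rejects any `While` nested inside a vectorized loop, which is stricter than rejecting only conditions that depend on the vector lane, while still accepting the opposite nesting used in `test_while_vectorize` (a vectorized loop inside a while loop).

```python
from dataclasses import dataclass, field


@dataclass
class For:
    var: str
    kind: str  # toy kinds: "serial", "parallel", "vectorized"
    body: list = field(default_factory=list)


@dataclass
class While:
    cond: str
    body: list = field(default_factory=list)


@dataclass
class Store:
    target: str


def check_no_while_in_vectorized(stmts, in_vector=False):
    """Raise if a While appears anywhere under a vectorized For."""
    for s in stmts:
        if isinstance(s, For):
            check_no_while_in_vectorized(s.body, in_vector or s.kind == "vectorized")
        elif isinstance(s, While):
            if in_vector:
                raise ValueError("while loop inside a vectorized loop is not supported")
            check_no_while_in_vectorized(s.body, in_vector)


ok = [While("i < 10", [For("j", "vectorized", [Store("C")])])]
bad = [For("j", "vectorized", [While("i < 10", [Store("C")])])]
check_no_while_in_vectorized(ok)  # accepted: the vectorized loop is inside the while
try:
    check_no_while_in_vectorized(bad)
except ValueError as err:
    print("rejected:", err)
```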
   
   
   





[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-784192951


   Thanks @masahi. The primary thing to address is correctness (e.g. we should not generate invalid code). It is totally fine to not optimize `While` well, so the test coverage is mainly to make sure that the case is handled correctly. Please let us know once you have confirmed StorageAccess and StorageRewrite, then we can revisit and merge.





[GitHub] [tvm] tqchen commented on a change in pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#discussion_r573138266



##########
File path: include/tvm/tir/stmt_functor.h
##########
@@ -109,6 +110,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
     IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
     IR_STMT_FUNCTOR_DISPATCH(ForNode);
+    IR_STMT_FUNCTOR_DISPATCH(WhileNode);

Review comment:
       Need to check through the current passes, per my comment.
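Adding a dispatch entry for `WhileNode` means every functor-based pass can now receive a `While` statement. A rough Python analogy (hypothetical classes, not TVM's API) of why each pass needs auditing: the dispatcher below looks up a handler by node type and fails loudly when a pass has none, similar in spirit to erroring out on `WhileNode` in `CoProcSync` and `LiftAttrScope`. By contrast, assuming the usual TVM design where the C++ base visitors provide a default recursive traversal for each node type, a pass that never considered `While` would typically keep running rather than fail, which is why the audit matters.

```python
class Evaluate:
    """Leaf statement (stands in for any non-loop statement)."""


class For:
    def __init__(self, body):
        self.body = body


class While:
    def __init__(self, cond, body):
        self.cond = cond
        self.body = body


class StmtVisitor:
    """Dispatch on the node's class name, e.g. For -> visit_For."""

    def visit(self, stmt):
        handler = getattr(self, "visit_" + type(stmt).__name__, None)
        if handler is None:
            # Fail loudly instead of silently mishandling an unknown node.
            raise NotImplementedError(
                f"{type(self).__name__} has no handler for {type(stmt).__name__}")
        return handler(stmt)


class LoopDepth(StmtVisitor):
    """A pass that has been audited for While."""

    def visit_Evaluate(self, stmt):
        return 0

    def visit_For(self, stmt):
        return 1 + self.visit(stmt.body)

    def visit_While(self, stmt):
        return 1 + self.visit(stmt.body)


class LegacyLoopDepth(StmtVisitor):
    """The same pass before anyone checked it against While."""

    def visit_Evaluate(self, stmt):
        return 0

    def visit_For(self, stmt):
        return 1 + self.visit(stmt.body)


tree = For(While("i < 10", Evaluate()))
print(LoopDepth().visit(tree))  # 2
try:
    LegacyLoopDepth().visit(tree)
except NotImplementedError as err:
    print(err)  # LegacyLoopDepth has no handler for While
```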







[GitHub] [tvm] junrushao1994 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776149218


   CC @spectrometerHBH: we might want to have it supported in TensorIR too, either as syntactic sugar for opaque binding or in other ways.





[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-788690275


   @tqchen @junrushao1994 @vinx13 @ZihengJiang @zxybazh 
   
   I came to the conclusion that the While node doesn't need special handling in `storage_rewrite`.
   
   The first observation is that even if I remove all `ForNode` handling from `StoragePlanRewriter`, all tests in `test_tir_transform_storage_rewrite.py` except [test_parallel_alloc()](https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/tests/python/unittest/test_tir_transform_storage_rewrite.py#L269) pass.
   
   If we look at the visitor for `ForNode`, https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L440-L452
   it only does something special when `attach_map_` has an entry for this node. Here comes the second observation: the only case where `attach_map_` can have an entry for a `ForNode` is if this `ForNode` is a parallel for loop, due to these lines: https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L766-L772
   
   Together, these two handlers for `ForNode` lift allocations inside an inner loop and attach the merged allocation under the parallel loop scope (via the `MakeAttach` function at https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L447). This is what's tested in `test_parallel_alloc()`. For other kinds of `For` loop, a merged allocation is placed at the global scope, see https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/storage_rewrite.cc#L457-L461.
   
   Since a `While` node doesn't involve threading, I think we can always lift allocations done inside a `While` loop into the global scope. That means `WhileNode` should be handled in the same way non-parallel `ForNode`s are handled, i.e. we don't need special handling logic for `WhileNode`. Two simple test cases involving a `While` loop are added in https://github.com/apache/tvm/blob/c3af5ae9aa611580004ce03d16aa952ab124d826/tests/python/unittest/test_tir_transform_storage_rewrite.py#L301 to test that allocations are attached at the right scope after `storage_rewrite`.
   
   I think I nailed it, thoughts? 





[GitHub] [tvm] tqchen commented on a change in pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#discussion_r575323069



##########
File path: src/tir/transforms/inject_virtual_thread.cc
##########
@@ -333,6 +333,12 @@ class VTInjector : public StmtExprMutator {
     }
   }
 
+  // While
+  Stmt VisitStmt_(const WhileNode* op) final {
+    // TODO(masahi): Do we need a special handling for While nodes?

Review comment:
      Let us disable it for now. It will likely need special handling later.







[GitHub] [tvm] vinx13 edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
vinx13 edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-778401698


   I've checked `StorageAccessVisitor` and it looks good to me. `InplaceOpVerifier` and `StoragePlanRewriter` also need handling.





[GitHub] [tvm] junrushao1994 commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789371094


   Really awesome work!!!





[GitHub] [tvm] masahi edited a comment on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-776703600


   @tqchen @junrushao1994 @vinx13 
   
   I went through passes and here is my summary:
   * `VectorizeLoop`: We need to disallow a while loop inside a vectorized loop. Without this check, no error occurs during lowering, but the lowered code is incorrect. I added a test case `test_vectorize_while_fail()` to make sure we error out in such cases.
   
   * `StorageAccessVisitor`: I don't understand what it does, but I added a special visitor for `While` following the existing visitor for `IfThenElse`. Please check https://github.com/apache/tvm/pull/7425/commits/1e629b68b4112a01293683edc13c3e976a22a5bb
   
   * `CoProcSync` and `LiftAttrScope`: They both have special visitors for `IfThenElse`, but I don't understand them. They are only used by VTA, so for now I just error out if we find a `WhileNode` there. See https://github.com/apache/tvm/pull/7425/commits/a71066d49381aae62626593c8fd76e149e1e55ed and https://github.com/apache/tvm/pull/7425/commits/00c17d921005eecc07f4300df898b9107d15ea1d
   
   * `InjectVirtualThread`: I think we need some special handling for this, but I don't know what it should be. For now I just added a placeholder that calls the base class visitor. See https://github.com/apache/tvm/pull/7425/commits/896b02fb8aba00c22696f92195d32454bd593454 and let me know what we should do here.
   
   * Do we need to change `MergeNest`? I haven't touched it for now: https://github.com/apache/tvm/blob/7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb/src/tir/transforms/ir_utils.cc#L35-L59
   
   * We probably don't need to change `hoist_if_then_else.cc` and `loop_partition.cc`. We could do something in `remove_no_op.cc`, but I think it is not important.
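   A rough Python sketch of the `VectorizeLoop` restriction from the first bullet, on a hypothetical toy IR rather than TVM's C++ classes: walk the loop nest and raise as soon as a `While` loop appears anywhere inside a vectorized loop. The names `Loop` and `check_no_while_in_vectorized` are illustrative only; the real check lives in the C++ pass.

```python
from dataclasses import dataclass, field


@dataclass
class Loop:
    """Hypothetical toy loop; kind is "serial", "vectorized", or "while"."""
    name: str
    kind: str
    body: list = field(default_factory=list)


def check_no_while_in_vectorized(stmts, in_vectorized=False):
    """Raise ValueError if a while loop is nested inside a vectorized loop."""
    for s in stmts:
        if not isinstance(s, Loop):
            continue
        if s.kind == "while" and in_vectorized:
            raise ValueError(
                f"while loop '{s.name}' inside a vectorized loop is not supported"
            )
        # Once we are under a vectorized loop, every nested level inherits it.
        check_no_while_in_vectorized(s.body, in_vectorized or s.kind == "vectorized")
```

   Note the check is one-directional: a vectorized loop nested inside a `While` loop passes, only the reverse nesting is rejected.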
   





[GitHub] [tvm] tqchen commented on pull request #7425: [TIR] Add TIR While node

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7425:
URL: https://github.com/apache/tvm/pull/7425#issuecomment-789176475


   @masahi you are right that `MakeAttach` is only needed for a parallel for loop, where we can no longer lift the memory outside (otherwise the memory won't be thread local).
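   The thread-locality argument can be illustrated with a small deterministic Python simulation. This is a toy model, not TVM code, and `run_interleaved` is a hypothetical name. Each iteration of a parallel loop writes its index into a scratch buffer and later reads it back; we simulate the worst-case interleaving where all writes happen before any read. If the buffer were lifted outside the parallel loop and shared, every iteration would see the last writer's value.

```python
def run_interleaved(n_iters, shared):
    """Simulate n_iters parallel iterations that each use a one-element scratch
    buffer, with all writes scheduled before all reads (worst case)."""
    if shared:
        # Allocation lifted outside the parallel loop: one buffer for everyone.
        shared_buf = [None]
        bufs = [shared_buf] * n_iters
    else:
        # Allocation attached under the parallel loop: one buffer per iteration.
        bufs = [[None] for _ in range(n_iters)]
    # Phase 1: every iteration writes its own index into its scratch buffer.
    for i in range(n_iters):
        bufs[i][0] = i
    # Phase 2: every iteration reads its scratch value back.
    return [bufs[i][0] for i in range(n_iters)]
```

   `run_interleaved(4, shared=False)` returns `[0, 1, 2, 3]`, while `run_interleaved(4, shared=True)` returns `[3, 3, 3, 3]`. A `While` loop runs its iterations sequentially on one thread, so this hazard does not arise there.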

