Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/10 07:37:42 UTC

[GitHub] [tvm] syang-ng opened a new pull request #9485: [BugFix] Fix divide by zero error in TIR pass lower_warp_memory

syang-ng opened a new pull request #9485:
URL: https://github.com/apache/tvm/pull/9485


   Here is a PR to fix the `divide by zero` bug mentioned in this [post](https://discuss.tvm.apache.org/t/divide-by-zero-error-in-tir-pass-lower-warp-memory/11433). In short, when applying the TIR pass `lower_warp_memory`, the pass re-computes the allocation size before allocating, and if the variable `factor` is zero, that division triggers the error.
   
   https://github.com/syang-ng/tvm/blob/main/src/tir/transforms/lower_warp_memory.cc#L222-L227
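   
   For intuition only, here is a toy sketch of that failure mode (the real computation is in the C++ linked above; the names below are illustrative):
   
   ```python
   # Illustrative toy, not the pass's actual code: lower_warp_memory divides
   # the allocation size by a per-lane `factor` when rewriting warp memory.
   size = 16
   factor = 0  # the degenerate value produced by the IR below
   new_size = size // factor  # ZeroDivisionError, mirroring the C++ fault
   ```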
   
   Here is a code example that reproduces the bug:
   
   ```python
   import tvm
   from tvm import te, tir
   
   ib = tir.ir_builder.IRBuilder()
   bx = te.thread_axis("blockIdx.x")
   tx = te.thread_axis("threadIdx.x")
   
   with ib.new_scope():
       ib.scope_attr(bx, "thread_extent", 32)
       ib.scope_attr(tx, "thread_extent", 32)
       # The warp-scoped buffer is what lower_warp_memory rewrites.
       t = ib.allocate("float32", 16, name="t", scope="warp")
       n = ib.allocate("float32", 16, name="n", scope="local")
       n[0] = t[0]
   
   stmt = ib.get()
   f = tvm.tir.PrimFunc([], stmt)
   f = f.with_attr("from_legacy_te_schedule", True)
   m = tvm.lower(f)
   tvm.build(m, target=tvm.target.Target("cuda"))  # fails in lower_warp_memory
   ```
   
   





[GitHub] [tvm] Mousius commented on pull request #9485: [BugFix] Fix divide by zero error in TIR pass lower_warp_memory

Posted by GitBox <gi...@apache.org>.
Mousius commented on pull request #9485:
URL: https://github.com/apache/tvm/pull/9485#issuecomment-965533397


   Hi @syang-ng,
   
   Is it possible for you to convert your reproduction into a test case using `pytest.raises` to check an exception is raised? :smile_cat: 
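   
   For reference, a minimal sketch of the requested shape (`build_repro()` is a hypothetical helper wrapping the reproduction from the PR description; the test actually added by the PR appears in the review diff below):
   
   ```python
   import pytest
   import tvm
   
   def test_divide_by_zero_is_reported():
       # `build_repro()` is hypothetical: it would build the IR from the PR
       # description and call tvm.build, which should raise TVMError.
       with pytest.raises(tvm.error.TVMError):
           build_repro()
   ```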





[GitHub] [tvm] Mousius commented on a change in pull request #9485: [BugFix] Fix divide by zero error in TIR pass lower_warp_memory

Posted by GitBox <gi...@apache.org>.
Mousius commented on a change in pull request #9485:
URL: https://github.com/apache/tvm/pull/9485#discussion_r747345307



##########
File path: tests/python/unittest/test_tir_transform_lower_warp_memory.py
##########
@@ -310,11 +311,29 @@ def test_lower_warp_memory_same_thread():
     assert "tvm_warp_shuffle" not in fdevice.astext()
 
 
+@tvm.testing.requires_cuda
+def test_lower_warp_memory_divide_by_factor():
+    ib = tvm.tir.ir_builder.IRBuilder()
+    bx = te.thread_axis("blockIdx.x")
+    tx = te.thread_axis("threadIdx.x")
+
+    with ib.new_scope():
+        ib.scope_attr(bx, "thread_extent", 32)
+        ib.scope_attr(tx, "thread_extent", 32)
+        t = ib.allocate("float32", 16, name="t", scope="warp")
+        n = ib.allocate("float32", 16, name="n", scope="local")
+        n[0] = t[0]
+
+    stmt = ib.get()
+    func = tvm.tir.PrimFunc([], stmt)
+    func = func.with_attr("from_legacy_te_schedule", True)
+    cuda_target = tvm.target.Target("cuda")
+    mod = tvm.lower(func, name="f")
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
+    with pytest.raises(tvm.error.TVMError) as cm:

Review comment:
       Minor nit: you could use the `match` parameter instead of an assert on the `str` here.
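   
       For illustration, a hedged sketch of that suggestion (the `match` regex here is illustrative, not necessarily the pass's exact message; `mod` and `cuda_target` are as built in the test above):
   
       ```python
       import pytest
       import tvm
   
       # pytest.raises accepts a `match` regex, folding the string assertion
       # into the context manager itself; the message text is illustrative.
       with pytest.raises(tvm.error.TVMError, match="Divide by zero"):
           tvm.build(mod, target=cuda_target)
       ```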







[GitHub] [tvm] syang-ng commented on pull request #9485: [BugFix] Fix divide by zero error in TIR pass lower_warp_memory

Posted by GitBox <gi...@apache.org>.
syang-ng commented on pull request #9485:
URL: https://github.com/apache/tvm/pull/9485#issuecomment-966005639


   > Hi @syang-ng,
   > 
   > Is it possible for you to convert your reproduction into a test case using `pytest.raises` to check an exception is raised? 😸
   
   Hi @Mousius, thanks for your suggestion! I have added a test case in `test_tir_transform_lower_warp_memory.py`. :-)





[GitHub] [tvm] Hzfengsy merged pull request #9485: [BugFix] Fix divide by zero error in TIR pass lower_warp_memory

Posted by GitBox <gi...@apache.org>.
Hzfengsy merged pull request #9485:
URL: https://github.com/apache/tvm/pull/9485


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org