Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/25 10:14:37 UTC

[GitHub] [incubator-tvm] jcf94 opened a new pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

jcf94 opened a new pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924


   This PR is part of #5883. It fixes a rewrite_simplify error that occurs in some cases when doing vectorized cooperative fetching.
   
   The code generated with the bug looks like this:
   ```
   A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
   (float32x4*)A_2[(((broadcast(((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)), 4) + (floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))*broadcast(512, 4))) + broadcast((k.outer.outer*64), 4)) + floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)))])
   ```
   This is eventually lowered to incorrect CUDA C instructions.
   The index expression should instead be simplified so that it produces the correct RampNode:
   ```
   A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
   (float32x4*)A_2[ramp((((((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)) + (floordiv(threadIdx.x_1, 16)*512)) + (k.outer.outer*64)) + (floormod(threadIdx.x_1, 16)*4)), 1, 4)])
   ```
   
   The main problems inside this expression are:
   ```
   threadIdx.x_1 = [0, 64]
   floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)) * broadcast(512, 4)
   floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))
   ```
   These should be simplified to:
   ```
   threadIdx.x_1 = [0, 64]
   broadcast((floordiv(threadIdx.x_1, 16)*512), 4)
   ramp(floormod(threadIdx.x_1, 16)*4, 1, 4)
   ```
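The equivalence claimed above can be checked lane by lane on plain integers. The following is a minimal sketch, not part of the PR: it assumes Python's `//` and `%` match `floordiv` and `floormod` for these values, and the helper names (`original_lanes`, `simplified_lanes`, `forms_agree`) are hypothetical.

```python
LANES = 4  # vector width used in the expressions above


def original_lanes(t):
    """Lane-wise values of the two unsimplified vector terms for thread index t."""
    ramp = [t * 4 + i for i in range(LANES)]      # ramp(t*4, 1, 4)
    div_term = [(v // 64) * 512 for v in ramp]    # floordiv(ramp, 64) * 512
    mod_term = [v % 64 for v in ramp]             # floormod(ramp, 64)
    return div_term, mod_term


def simplified_lanes(t):
    """Lane-wise values of the proposed simplified forms for thread index t."""
    div_term = [(t // 16) * 512] * LANES                 # broadcast(floordiv(t, 16)*512, 4)
    mod_term = [(t % 16) * 4 + i for i in range(LANES)]  # ramp(floormod(t, 16)*4, 1, 4)
    return div_term, mod_term


def forms_agree(num_threads=64):
    """True if both forms give identical lanes for every t in range(num_threads)."""
    return all(original_lanes(t) == simplified_lanes(t) for t in range(num_threads))
```

The reason the two forms agree is that `t*4 + i = 64*(t // 16) + 4*(t % 16) + i`, and the remainder part `4*(t % 16) + i` never exceeds 63 for `i` in 0..3, so all four lanes fall into the same block of 64.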
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650045223


   > @jcf94 Did our old rule affect the correctness of common operators?
   
   Yes, with those rules several other UTs will fail, for example `test_arith_intset.py`.
   





[GitHub] [incubator-tvm] jcf94 commented on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650022894


   > Thanks @jcf94 we should add a testcase to test_arith_rewrite_simplify, by constructing the case and
   > 
   > * verify each of the rules added in this PR.
   > * Use `isinstance(x, tvm.ir.Ramp)` to assert the ramp node
   > * You mentioned a bug in the previous rule, it would be great if the testcase covers the bug you mentioned
   
   Comments are all addressed.
   
   - Added several test cases to test_arith_rewrite_simplify for the simplify rules
   - Updated the test_target_codegen_cuda UTs and used pre/post functions to check the RampNode patterns
   - The bug mentioned above was introduced by our former implementation; currently everything works fine





[GitHub] [incubator-tvm] tqchen commented on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-649619450


   @jcf94 we should add a test case to test_arith_rewrite_simplify by constructing the case and verifying the rules you have worked on. Use `isinstance(x, tvm.ir.Ramp)` to assert the ramp node.





[GitHub] [incubator-tvm] tqchen commented on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650204858


   Thanks @jcf94 @merrymercy. This PR is now merged.





[GitHub] [incubator-tvm] tqchen merged pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924


   





[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650045223


   > @jcf94 Did our old rule affect the correctness of common operators?
   
   Yes, with those rules several other UTs will fail.
   For example, in `test_arith_intset.py:test_mod()`, the check
   ```
   ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (0, 7))
   ```
   is expected to give `(0, 7)`, but our rules turn the result into
   ```
   (((z*8) + (x*4)) - (8*floordiv(((z*8) + (x*4)), 8))), ((((z*8) + (x*4)) + 3) - (8*floordiv(((z*8) + (x*4)), 8)))
   ```
   
   





[GitHub] [incubator-tvm] merrymercy commented on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650034842


   @jcf94  Did our old rule affect the correctness of common operators?





[GitHub] [incubator-tvm] jcf94 edited a comment on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
jcf94 edited a comment on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650045223


   > @jcf94 Did our old rule affect the correctness of common operators?
   
   Yes, with those rules several other UTs will fail.
   
   





[GitHub] [incubator-tvm] jcf94 commented on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-650045223


   > @jcf94 Did our old rule affect the correctness of common operators?
   
   Yes, with those rules several other UTs will fail; the rules are actually not always correct.
   For example, we have:
   ```
       TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x * c1, c2),
                          c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
                          c2.Eval()->value % c1.Eval()->value == 0 &&
                          CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1));
   ```
   while `floordiv(x * 4 + 4, 8)` cannot be simplified to `floordiv(x * 4, 8)`.
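The claim about `floordiv(x * 4 + 4, 8)` can be confirmed by brute force on plain integers. This is a small sketch, not code from the PR: `first_mismatch` is a hypothetical helper, and it assumes Python's `//` (which rounds toward negative infinity) behaves like `floordiv`. Note that the side condition quoted above, `CanProveGreaterEqual(-y, -c1 + 1)`, bounds `y` only from above (`y <= c1 - 1`), so a negative `y` is included in the checks as well.

```python
def first_mismatch(c1, c2, y, xs=range(-8, 9)):
    """Return the first x where floordiv(x*c1 + y, c2) != floordiv(x*c1, c2), else None."""
    for x in xs:
        if (x * c1 + y) // c2 != (x * c1) // c2:
            return x
    return None


# y = 4: the case mentioned in the comment; dropping y changes the result.
has_counterexample_y4 = first_mismatch(4, 8, 4) is not None

# y = 3: lies in [0, c1), and dropping y is safe over the sampled range.
safe_y3 = first_mismatch(4, 8, 3) is None

# y = -1: satisfies y <= c1 - 1, yet dropping y still changes the result.
has_counterexample_neg = first_mismatch(4, 8, -1) is not None
```

On this sample, removing `y` is only safe when `0 <= y < c1`, which suggests that any such rewrite would need a lower bound on `y` in addition to the upper bound.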
   
   





[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#issuecomment-649619450


   Thanks @jcf94, we should add a test case to test_arith_rewrite_simplify by constructing the case and:
   
   - Verify each of the rules added in this PR.
   - Use `isinstance(x, tvm.ir.Ramp)` to assert the ramp node.
   - You mentioned a bug in the previous rule; it would be great if the test case covers that bug.





[GitHub] [incubator-tvm] tqchen commented on a change in pull request #5924: [Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #5924:
URL: https://github.com/apache/incubator-tvm/pull/5924#discussion_r445642433



##########
File path: src/arith/rewrite_simplify.cc
##########
@@ -722,8 +728,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
       ModularSet bmod = analyzer_->modular_set(b1.Eval());
       int64_t ramp_min = floordiv(bmod->base, c2val);
       int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
-      if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
-        return broadcast(floordiv(b1, c2), lanes).Eval();
+      if (ramp_min == ramp_max) {
+        // If b1 can device c2

Review comment:
       device-> divide

##########
File path: src/arith/rewrite_simplify.cc
##########
@@ -722,8 +728,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
       ModularSet bmod = analyzer_->modular_set(b1.Eval());
       int64_t ramp_min = floordiv(bmod->base, c2val);
       int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
-      if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
-        return broadcast(floordiv(b1, c2), lanes).Eval();
+      if (ramp_min == ramp_max) {
+        // If b1 can device c2
+        if (bmod->coeff % c2val == 0) {
+          return broadcast(floordiv(b1, c2), lanes).Eval();
+        }
+        // If all indices can be guaranteed to settle inside a coeff range
+        if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {

Review comment:
       Please add a unit test in tests_arith_rewrite_simplify to cover this rule.
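Before writing such a test, the arithmetic behind the two conditions in the hunk can be sanity-checked on plain integers. The sketch below is only an illustration under stated assumptions: it takes `b1 = coeff*k + base` with `0 <= base < coeff`, assumes a positive stride `c1` and a positive divisor `c2`, and uses hypothetical helper names. It checks only the property that every lane of `floordiv(ramp(b1, c1, lanes), c2)` has the same quotient; what the second branch returns is not visible in the quoted hunk and is not assumed here.

```python
def lanes_share_quotient(b1, c1, c2, lanes):
    """True if floordiv(b1 + i*c1, c2) is identical for every lane i."""
    return len({(b1 + i * c1) // c2 for i in range(lanes)}) == 1


def condition_holds(base, coeff, c1, c2, lanes):
    """Plain-integer mirror of the two conditions shown in the quoted hunk."""
    ramp_min = base // c2
    ramp_max = (base + (lanes - 1) * c1) // c2
    if ramp_min != ramp_max:
        return False
    divides = coeff % c2 == 0
    inside_coeff = c2 % coeff == 0 and base + (lanes - 1) * c1 < coeff
    return divides or inside_coeff


def check_all(limit=9, lanes=4):
    """Brute-force small positive parameters; return how many cases were checked."""
    checked = 0
    for coeff in range(1, limit):
        for base in range(coeff):            # assumption: 0 <= base < coeff
            for c1 in range(1, limit):       # assumption: positive stride
                for c2 in range(1, limit):   # assumption: positive divisor
                    if not condition_holds(base, coeff, c1, c2, lanes):
                        continue
                    for k in range(-4, 5):   # b1 = coeff*k + base
                        assert lanes_share_quotient(coeff * k + base, c1, c2, lanes)
                        checked += 1
    return checked
```

The vectorized cooperative fetching case from the PR description corresponds to `coeff=4, base=0, c1=1, c2=64, lanes=4`: the first condition does not hold (`4 % 64 != 0`), but the second one does, since `64 % 4 == 0` and `0 + 3 < 4`.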



