Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/20 14:45:16 UTC

[GitHub] [tvm] Johnson9009 opened a new issue #8093: Bug in StorageFlatten Cause Mussy Index

Johnson9009 opened a new issue #8093:
URL: https://github.com/apache/tvm/issues/8093


   The simple case below reproduces the issue.
   ```python
   import tvm
   from tvm import te
   
   _dtype = tvm.DataType("int8")
   dshape = (1, 14, 14, 1024)
   
   A = te.placeholder(dshape, name="A", dtype=_dtype)
   C = te.compute(dshape, lambda *i: A(*i) + 3, name="C")
   
   s = te.create_schedule(C.op)
   
   c_axis = s[C].fuse(*C.op.axis)
   outer, inner = s[C].split(c_axis, nparts=4)
   outer, inner = s[C].split(inner, 28*1024)
   
   ir_mod = tvm.lower(s, [A, C], name='fadd')
   ```
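As a rough sanity check on the schedule above, the loop extents can be worked out by hand. The sketch below is plain Python with no TVM dependency; `ceil_div` and all variable names are illustrative helpers, not TVM API. It assumes both splits round the extent up, which agrees with the loop bounds (4, 2, 28672) printed in the IR of this report.

```python
from math import prod

def ceil_div(a, b):
    """Round-up integer division (hypothetical helper, not a TVM function)."""
    return -(-a // b)

dshape = (1, 14, 14, 1024)

# Fusing all four axes gives one loop over every element.
fused_extent = prod(dshape)

# split(c_axis, nparts=4): 4 outer iterations, the inner loop takes the rest.
outer_extent = 4
inner_extent = ceil_div(fused_extent, outer_extent)

# split(inner, 28*1024): fixed inner factor, outer part rounded up.
factor = 28 * 1024
inner_outer_extent = ceil_div(inner_extent, factor)
inner_inner_extent = factor

# The second split does not divide evenly, so a bound check is required.
covered = inner_outer_extent * inner_inner_extent
needs_guard = covered > inner_extent

print(fused_extent, inner_extent, inner_outer_extent, inner_inner_extent, needs_guard)
```

Under these assumptions the second split covers 57344 iterations for an extent of 50176, which is consistent with the `< 50176` condition that appears among the `tir.likely` guards.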
   
   The IR before and after the "StorageFlatten" pass looks like this.
   ```
   PrintIR(Before StorageFlatten):
   primfn(A_1: handle, C_1: handle) -> ()
     attr = {"global_symbol": "fadd", "tir.noalias": True}
     buffers = {C: Buffer(C_2: Pointer(int32), int32, [1, 14, 14, 1024], []),
                A: Buffer(A_2: Pointer(int8), int8, [1, 14, 14, 1024], [])}
     buffer_map = {A_1: A, C_1: C} {
     attr [C] "realize_scope" = "";
     realize(C, [0:1, 0:14, 0:14, 0:1024], True {
       for (i0.i1.fused.i2.fused.i3.fused.outer: int32, 0, 4) {
         for (i0.i1.fused.i2.fused.i3.fused.inner.outer: int32, 0, 2) {
           for (i0.i1.fused.i2.fused.i3.fused.inner.inner: int32, 0, 28672) {
             if @tir.likely((floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14) < 1), dtype=bool) {
               if @tir.likely((floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14) < 14), dtype=bool) {
                 if @tir.likely((floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024) < 196), dtype=bool) {
                   if @tir.likely((((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)) < 200704), dtype=bool) {
                     if @tir.likely(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) < 50176), dtype=bool) {
                       C[floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), floormod(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), floormod(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), floormod(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024)] = (cast(int32, A[floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), floormod(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14), floormod(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), floormod(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024)]) + 3)
                     }
                   }
                 }
               }
             }
           }
         }
       }
     })
   }
   
   PrintIR(After StorageFlatten):
   primfn(A_1: handle, C_1: handle) -> ()
     attr = {"global_symbol": "fadd", "tir.noalias": True}
     buffers = {C: Buffer(C_2: Pointer(int32), int32, [1, 14, 14, 1024], []),
                A: Buffer(A_2: Pointer(int8), int8, [1, 14, 14, 1024], [])}
     buffer_map = {A_1: A, C_1: C} {
     for (i0.i1.fused.i2.fused.i3.fused.outer: int32, 0, 4) {
       for (i0.i1.fused.i2.fused.i3.fused.inner.outer: int32, 0, 2) {
         for (i0.i1.fused.i2.fused.i3.fused.inner.inner: int32, 0, 28672) {
           if @tir.likely((floordiv(floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14), 14) < 1), dtype=bool) {
             if @tir.likely((floordiv(floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024), 14) < 14), dtype=bool) {
               if @tir.likely((floordiv(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)), 1024) < 196), dtype=bool) {
                 if @tir.likely((((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) + (i0.i1.fused.i2.fused.i3.fused.outer*50176)) < 200704), dtype=bool) {
                   if @tir.likely(((i0.i1.fused.i2.fused.i3.fused.inner.inner + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28672)) < 50176), dtype=bool) {
                     C_2[((((floordiv((((i0.i1.fused.i2.fused.i3.fused.outer*49) + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28)) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 196)*200704) + (floormod(((i0.i1.fused.i2.fused.i3.fused.inner.outer*2) + floordiv(((i0.i1.fused.i2.fused.i3.fused.outer*49) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)), 14)*14336)) + (floormod(((i0.i1.fused.i2.fused.i3.fused.outer*49) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)*1024)) + floormod(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024))] = (cast(int32, (int8*)A_2[((((floordiv((((i0.i1.fused.i2.fused.i3.fused.outer*49) + (i0.i1.fused.i2.fused.i3.fused.inner.outer*28)) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 196)*200704) + (floormod(((i0.i1.fused.i2.fused.i3.fused.inner.outer*2) + floordiv(((i0.i1.fused.i2.fused.i3.fused.outer*49) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)), 14)*14336)) + (floormod(((i0.i1.fused.i2.fused.i3.fused.outer*49) + floordiv(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024)), 14)*1024)) + floormod(i0.i1.fused.i2.fused.i3.fused.inner.inner, 1024))]) + 3)
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
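One way to read the dump above is to ask whether the flattened index still addresses the same element as the plain linear offset. The sketch below is a brute-force check in plain Python, transcribing the index from the "After StorageFlatten" IR by hand. The short names `i`, `j`, `k` stand for the `outer`, `inner.outer` and `inner.inner` loop variables, and both function names are illustrative. It assumes Python's `//` and `%` match `floordiv` and `floormod` for the non-negative values involved.

```python
def flattened_index(i, j, k):
    """Index transcribed from the 'After StorageFlatten' dump."""
    return (
        ((i * 49 + j * 28 + k // 1024) // 196) * 200704
        + ((j * 2 + (i * 49 + k // 1024) // 14) % 14) * 14336
        + ((i * 49 + k // 1024) % 14) * 1024
        + k % 1024
    )

def linear_index(i, j, k):
    """The fused loop position, i.e. the offset one would hope to see."""
    return k + j * 28672 + i * 50176

checked = 0
mismatches = 0
for i in range(4):
    for j in range(2):
        for k in range(28672):
            # Same bound as the innermost tir.likely guard in the IR.
            if k + j * 28672 >= 50176:
                continue
            checked += 1
            if flattened_index(i, j, k) != linear_index(i, j, k):
                mismatches += 1

print(checked, mismatches)
```

If the transcription is faithful, the two expressions agree on every guarded iteration, which suggests the problem reported here is a missed simplification rather than a wrong address.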
   I couldn't find a smaller case that reproduces this issue; the data shape comes from a real ResNet-50 layer.
   The fix can be simple, along the lines of what is done in the temporary trial PR #8090.
   I have done some debugging and analysis of it.
   ```
   1st Dimension Index:   floordiv(floordiv(floordiv((k + j*28672 + i*50176), 1024), 14), 14)
   2nd Dimension Index:  floormod(floordiv(floordiv((k + j*28672 + i*50176), 1024), 14), 14)
   3rd Dimension Index:  floormod(floordiv((k + j*28672 + i*50176), 1024), 14)
   4th Dimension Index:  floormod((k + j*28672 + i*50176), 1024)
   
   Now merge the 1st and 2nd dimensions:
   The merged expression is 1st_dim_index * 14 + 2nd_dim_index.
   Pick out the common part of 1st_dim_index and 2nd_dim_index and call it "x1".
   x1 = floordiv(floordiv((k + j*28672 + i*50176), 1024), 14)
   x2 = floordiv(x1, 14)*14
   x3 = floormod(x1, 14)
   
   Then the whole merged expression is "x2 + x3", which can obviously be simplified to "x1".
   
   Round 1:
   1. x1 of x2 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 14)
      We can see that the (i*50176) term is moved out of the inner floordiv as i*49.
   2. x2 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 196)*14
      We can see that the two "14"s are merged into "196" because of floordiv(floordiv(xxx, 14), 14).
   3. x1 of x3 -> floordiv((floordiv((k + j*28672), 1024) + i*49), 14)
   4. x3 -> floormod(floordiv((floordiv((k + j*28672), 1024) + i*49), 14), 14)
   
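   Status (x1 is redefined here as the innermost common part):
   x1 = floordiv((k + j*28672), 1024) + i*49
   x2 = floordiv(x1, 196)*14
   x3 = floormod(floordiv(x1, 14), 14)
   
   all: x2 + x3
   After these rewrites, the add-node simplification rule "floordiv(xxx, c) * c + floormod(xxx, c)" can no longer be applied, because x2 and x3 no longer share the same operand and divisor.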
   ```
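The analysis above can be checked numerically. The following sketch, in plain Python with no TVM calls, compares the form that matches the rule with the form left after Round 1. The names `merged_before`, `merged_after` and `y` are illustrative; `y` plays the role of the redefined x1 in the Status section, and the tested range is an arbitrary choice.

```python
def merged_before(x1):
    # floordiv(x1, 14)*14 + floormod(x1, 14): same operand and divisor on
    # both sides, so the div/mod rule applies and the result is x1.
    return (x1 // 14) * 14 + x1 % 14

def merged_after(y):
    # Form after Round 1: floordiv(y, 196)*14 + floormod(floordiv(y, 14), 14).
    # The div side now uses (y, 196) while the mod side uses (y // 14, 14).
    return (y // 196) * 14 + (y // 14) % 14

ys = range(0, 1000)

# Both forms still compute floordiv(y, 14); only the syntactic shape differs.
before_ok = all(merged_before(y // 14) == y // 14 for y in ys)
after_ok = all(merged_after(y) == y // 14 for y in ys)

print(before_ok, after_ok)
```

Both forms evaluate to the same value over this range, so what is lost appears to be only the syntactic pattern the rule matches on, not the arithmetic equivalence.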


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on issue #8093: Bug in StorageFlatten Cause Mussy Index

Posted by GitBox <gi...@apache.org>.
tqchen commented on issue #8093:
URL: https://github.com/apache/tvm/issues/8093#issuecomment-854705597


   gentle ping @Johnson9009 to see if you have more followups :)





[GitHub] [tvm] Johnson9009 commented on issue #8093: Bug in StorageFlatten Cause Mussy Index

Posted by GitBox <gi...@apache.org>.
Johnson9009 commented on issue #8093:
URL: https://github.com/apache/tvm/issues/8093#issuecomment-855827225


   @tqchen Sorry for the slow response. This issue has nothing to do with the idea we discussed in #8090. My fix for this issue may not be the best one, but it is simple and workable; if you have a better one, please let me know. Thanks.
   By the way, regarding the idea we discussed earlier in #8090: after more testing, I found that the mul/mod merging code can't be replaced by the analyzer's simplifiers. As you said, the simplifiers don't work when the shapes are symbolic.





[GitHub] [tvm] tqchen closed issue #8093: Bug in StorageFlatten Cause Mussy Index

Posted by GitBox <gi...@apache.org>.
tqchen closed issue #8093:
URL: https://github.com/apache/tvm/issues/8093


   

