Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/08 04:58:09 UTC

[GitHub] [tvm] ilovetvm opened a new issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

ilovetvm opened a new issue #7421:
URL: https://github.com/apache/tvm/issues/7421


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] ilovetvm commented on issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

Posted by GitBox <gi...@apache.org>.
ilovetvm commented on issue #7421:
URL: https://github.com/apache/tvm/issues/7421#issuecomment-795122960


   Thanks for checking this and finding the related code! Do you have any idea on how we could resolve this?



[GitHub] [tvm] masahi closed issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

masahi closed issue #7421:
URL: https://github.com/apache/tvm/issues/7421


   



[GitHub] [tvm] leeexyz commented on issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

leeexyz commented on issue #7421:
URL: https://github.com/apache/tvm/issues/7421#issuecomment-791995877


   I found the relevant code. In your case, i.outer.inner is bound to threadIdx.y and j.outer.inner is bound to threadIdx.x.
   
   If the scope is **local**, the storage rank is 3, which is greater than the rank of threadIdx (1), so the value used for this axis is 0. This is not correct.
   https://github.com/apache/tvm/blob/8aa2a7cdbc81a0633b1f78ab28f31921e9fa9e98/src/te/operation/op_utils.cc#L161-L164
   ```c++
   // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
   // attr [iter_var(i.outer.outer.outer, )] loop_scope = 0
   // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
   // attr [iter_var(j.outer.outer.outer, )] loop_scope = 0
   for (i.outer.outer.inner, 0, 2) {
     // attr [iter_var(i.outer.outer.inner, )] loop_scope = i.outer.outer.inner
     for (j.outer.outer.inner, 0, 2) {
       // attr [iter_var(j.outer.outer.inner, )] loop_scope = j.outer.outer.inner
       // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 4
       // attr [iter_var(i.outer.inner, )] loop_scope = 0
       // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 4
       // attr [iter_var(j.outer.inner, )] loop_scope = 0
       for (i.inner, 0, 4) {
         // attr [iter_var(i.inner, )] loop_scope = i.inner
         for (j.inner, 0, 4) {
           // attr [iter_var(j.inner, )] loop_scope = j.inner
           gemm[(i.inner + ((i.outer.outer.inner*4)*4)), (j.inner + ((j.outer.outer.inner*4)*4))] =0f
           for (k, 0, 32) {
             // attr [iter_var(k, )] loop_scope = k
             gemm[(i.inner + ((i.outer.outer.inner*4)*4)), (j.inner + ((j.outer.outer.inner*4)*4))] =(gemm[(i.inner + ((i.outer.outer.inner*4)*4)), (j.inner + ((j.outer.outer.inner*4)*4))] + (A[(i.inner + ((i.outer.outer.inner*4)*4)), k]*B[k, (j.inner + ((j.outer.outer.inner*4)*4))]))
           }
         }
       }
     }
   }
   ```
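   A minimal C++ sketch of the decision described above, assuming the ranks quoted in this comment (local = 3, shared = 1, threadIdx = 1; global = 0 and blockIdx = 0 are assumptions). The helpers `StorageRankOf`, `ThreadRankOf` and `SubstitutedValue` are hypothetical and are not TVM's API; any other scopes or thread tags handled by the real op_utils.cc are not modeled.

   ```c++
   #include <stdexcept>
   #include <string>

   // Hypothetical model of the check in op_utils.cc: a thread-bound axis keeps
   // its thread variable only if the stage's storage rank is <= the thread rank;
   // otherwise the axis value becomes 0.

   // Storage ranks as quoted in the issue (global = 0 is an assumption).
   int StorageRankOf(const std::string& scope) {
     if (scope == "global") return 0;
     if (scope == "shared") return 1;
     if (scope == "local") return 3;
     throw std::invalid_argument("scope not modeled: " + scope);
   }

   // Thread ranks: blockIdx.* = 0 (assumption), threadIdx.* = 1.
   int ThreadRankOf(const std::string& tag) {
     if (tag.rfind("blockIdx.", 0) == 0) return 0;
     if (tag.rfind("threadIdx.", 0) == 0) return 1;
     throw std::invalid_argument("thread tag not modeled: " + tag);
   }

   // Returns the expression substituted for the bound axis, as text.
   std::string SubstitutedValue(const std::string& scope, const std::string& tag) {
     if (scope.empty() || StorageRankOf(scope) <= ThreadRankOf(tag)) {
       return tag;  // keep the thread variable, e.g. "threadIdx.y"
     }
     return "0";    // storage rank exceeds thread rank: axis collapses to 0
   }
   ```

   With these assumptions, `SubstitutedValue("local", "threadIdx.y")` returns `"0"`, matching the local-scope dump, while `SubstitutedValue("shared", "threadIdx.y")` returns `"threadIdx.y"`, matching the shared-scope dump.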
   
   
   But if the scope is **shared**, the storage rank is 1, which does not exceed the rank of threadIdx, so the value for this axis is the thread variable itself (threadIdx.y or threadIdx.x).
   https://github.com/apache/tvm/blob/8aa2a7cdbc81a0633b1f78ab28f31921e9fa9e98/src/te/operation/op_utils.cc#L176-L178
   ```c++
   // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
   // attr [iter_var(i.outer.outer.outer, )] loop_scope = 0
   // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
   // attr [iter_var(j.outer.outer.outer, )] loop_scope = 0
   for (i.outer.outer.inner, 0, 2) {
     // attr [iter_var(i.outer.outer.inner, )] loop_scope = i.outer.outer.inner
     for (j.outer.outer.inner, 0, 2) {
       // attr [iter_var(j.outer.outer.inner, )] loop_scope = j.outer.outer.inner
       // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 4
       // attr [iter_var(i.outer.inner, )] loop_scope = threadIdx.y
       // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 4
       // attr [iter_var(j.outer.inner, )] loop_scope = threadIdx.x
       for (i.inner, 0, 4) {
         // attr [iter_var(i.inner, )] loop_scope = i.inner
         for (j.inner, 0, 4) {
           // attr [iter_var(j.inner, )] loop_scope = j.inner
           gemm[(i.inner + ((threadIdx.y + (i.outer.outer.inner*4))*4)), (j.inner + ((threadIdx.x + (j.outer.outer.inner*4))*4))] =0f
           for (k, 0, 32) {
             // attr [iter_var(k, )] loop_scope = k
             gemm[(i.inner + ((threadIdx.y + (i.outer.outer.inner*4))*4)), (j.inner + ((threadIdx.x + (j.outer.outer.inner*4))*4))] =(gemm[(i.inner + ((threadIdx.y + (i.outer.outer.inner*4))*4)), (j.inner + ((threadIdx.x + (j.outer.outer.inner*4))*4))] + (A[(i.inner + ((threadIdx.y + (i.outer.outer.inner*4))*4)), k]*B[k, (j.inner + ((threadIdx.x + (j.outer.outer.inner*4))*4))]))
           }
         }
       }
     }
   }
   
   ```
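   To see what the substituted 0 does to the indexing, the sketch below simply enumerates the row index of `gemm` printed in the two dumps above, using their loop extents (i.outer.outer.inner < 2, threadIdx.y < 4, i.inner < 4). `RowsReached` is a hypothetical helper; it only evaluates the printed expressions and does not model how TVM allocates the buffer.

   ```c++
   #include <set>

   // Row index of gemm as printed in the two dumps:
   //   local : i.inner + ((i.outer.outer.inner*4)*4)
   //   shared: i.inner + ((threadIdx.y + (i.outer.outer.inner*4))*4)
   // Returns the distinct row indices over all iterations and all threadIdx.y values.
   std::set<int> RowsReached(bool keep_thread_var) {
     std::set<int> rows;
     for (int ooi = 0; ooi < 2; ++ooi) {       // i.outer.outer.inner
       for (int ty = 0; ty < 4; ++ty) {        // threadIdx.y
         int t = keep_thread_var ? ty : 0;     // local scope replaces threadIdx.y with 0
         for (int ii = 0; ii < 4; ++ii) {      // i.inner
           rows.insert(ii + (t + ooi * 4) * 4);
         }
       }
     }
     return rows;
   }
   ```

   `RowsReached(false)` yields only the 8 rows 0 to 3 and 16 to 19, the same for every thread, whereas `RowsReached(true)` yields all 32 rows 0 to 31.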



[GitHub] [tvm] masahi commented on issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

masahi commented on issue #7421:
URL: https://github.com/apache/tvm/issues/7421#issuecomment-1008435437


   I assume this is no longer active.



[GitHub] [tvm] leeexyz commented on issue #7421: [Bug] Incorrect results if memory scope is set to 'local'

leeexyz commented on issue #7421:
URL: https://github.com/apache/tvm/issues/7421#issuecomment-796392197


   > Thanks for checking this and finding the related code! Do you have any idea on how we could resolve this?
   
   I'm not sure yet, but I will try.

