Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/21 16:29:56 UTC

[GitHub] [tvm] lazycal opened a new pull request, #11803: Support reduce axis=0 in softmax schedule.

lazycal opened a new pull request, #11803:
URL: https://github.com/apache/tvm/pull/11803

   When scheduling a 2D softmax, the current CUDA schedule assumes that the reduction axis is the last axis. When the reduction axis is 0 instead, it produces an incorrect schedule and raises error messages that are hard to debug. For example, running the following snippet:
   ```python
   import tvm
   from tvm import relay
   
   shape = (64, 2)
   dtype = 'float32'
   
   A = relay.var('A', shape=shape, dtype=dtype)
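   # Softmax over axis 0, i.e. not the last axis of the 2D input.
   B = relay.nn.softmax(A, axis=0)
   f = relay.Function([A], B)
   mod = tvm.IRModule.from_expr(f)
   
   dev = tvm.cuda()
   target = tvm.target.Target('cuda')
   # Building for CUDA applies the 2D softmax schedule; the errors
   # below are raised here, while the module is being compiled.
   with tvm.transform.PassContext(opt_level=0):
       executor = relay.build_module.create_executor(
           'graph', mod, dev, target).evaluate()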
   ```
   With `opt_level=0`, as in the snippet, the build fails with
   ```Check failed: (!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) is false: LowerWarpMemory failed to rewrite load to shuffle for index ((threadIdx.x*5) + (k.inner*2)) local_index=(((threadIdx.x*5) + (k.inner*2))/32)```
   and with `opt_level=4` it fails with
   ```Check failed: (match) is false: iter_var(blockIdx.x, , blockIdx.x) domain already inferred, cannot prove their extents are the same 64 vs 2```
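
For a 2D input of shape `(M, N)`, the reduction axis determines both the length of each max/sum reduction and how many independent reductions there are: `axis=1` gives `M` reductions of length `N`, while `axis=0` gives `N` reductions of length `M`. A schedule written for the last axis would therefore treat the two extents the wrong way round when `axis=0`. The following is a minimal pure-Python sketch, not TVM code; the helper `softmax_reduction_extents` is hypothetical and only computes the two extents for the snippet's `(64, 2)` input. The swapped pair 64 and 2 is consistent with the extents quoted in the `opt_level=4` error, though the sketch does not model TVM's bound inference.

```python
def softmax_reduction_extents(shape, axis):
    """Hypothetical helper: return (num_reductions, reduction_length)
    for a softmax over `axis` of a 2D `shape`."""
    if len(shape) != 2:
        raise ValueError("expected a 2D shape")
    if axis not in (-2, -1, 0, 1):
        raise ValueError("axis out of range for a 2D shape")
    axis = axis % 2                     # normalize negative axes, e.g. -1 -> 1
    reduction_length = shape[axis]      # elements combined by each max/sum
    num_reductions = shape[1 - axis]    # independent rows (axis=1) or columns (axis=0)
    return num_reductions, reduction_length


shape = (64, 2)
for axis in (0, 1):
    print(axis, softmax_reduction_extents(shape, axis))
# 0 (2, 64)
# 1 (64, 2)
```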
   
   This PR fixes all the CUDA 2D softmax schedules so that they also support `axis=0`, and extends the unit tests to cover all reduction axes.
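
Testing every reduction axis needs a reference result for each axis. As a minimal sketch, assuming a numerically stable (max-subtracted) softmax, the pure-Python code below computes a reference softmax along either axis of a 2D list and checks that every slice along the reduction axis sums to 1. The names `softmax_2d` and `check_all_axes` are hypothetical; this is not the PR's actual test code, and a test along these lines would additionally compare the output of the compiled CUDA kernel against such a reference.

```python
import math


def softmax_2d(x, axis):
    """Hypothetical reference: numerically stable softmax of a 2D list along axis 0 or 1."""
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1")
    if axis == 0:
        # Softmax over columns == row-wise softmax of the transpose, transposed back.
        transposed = [list(col) for col in zip(*x)]
        return [list(col) for col in zip(*softmax_2d(transposed, 1))]
    out = []
    for row in x:
        m = max(row)                            # subtract the row max for stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def check_all_axes(x, tol=1e-9):
    """Check that every slice along each reduction axis sums to 1."""
    for axis in (0, 1):
        y = softmax_2d(x, axis)
        # For axis=0 the reduced slices are the columns of y.
        slices = y if axis == 1 else [list(col) for col in zip(*y)]
        for s in slices:
            assert abs(sum(s) - 1.0) < tol, (axis, sum(s))
    return True


# A (64, 2) input, the same shape as in the snippet above.
x = [[(i * 2 + j) * 0.1 for j in range(2)] for i in range(64)]
print(check_all_axes(x))
```

Handling `axis=0` as "transpose, reduce over the last axis, transpose back" keeps a single row-wise code path in this sketch; it is a choice made for brevity here, not necessarily how the PR restructures the TE schedules.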


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] masahi merged pull request #11803: [TE Schedule] Fix broken 2D softmax TE schedules when axis=0

Posted by GitBox <gi...@apache.org>.
masahi merged PR #11803:
URL: https://github.com/apache/tvm/pull/11803


[GitHub] [tvm] lazycal commented on pull request #11803: [TE Schedule] Fix broken 2D softmax TE schedules when axis=0

Posted by GitBox <gi...@apache.org>.
lazycal commented on PR #11803:
URL: https://github.com/apache/tvm/pull/11803#issuecomment-1162267184

   @masahi 

