Posted to commits@tvm.apache.org by "LeiWang1999 (via GitHub)" <gi...@apache.org> on 2023/04/21 10:10:07 UTC

[GitHub] [tvm] LeiWang1999 commented on pull request #14695: add comment to cuda kernel code about the intended block/grid layout

LeiWang1999 commented on PR #14695:
URL: https://github.com/apache/tvm/pull/14695#issuecomment-1517601264

   This pull request makes sense to me. The CUDA kernels generated by TVM do not record their intended grid and block dimensions, which can be a problem in some situations, particularly when integrating these kernels into other systems that must launch them.
   
   For instance, in my own implementation, I use an ad-hoc method to extract the grid and block dimensions from the high-level TVM IR (a `tir.Schedule`), as shown below:
   
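   ```python
   # sch is a tvm.tir.Schedule; block_i/block_j are the loops bound to
   # blockIdx.y/blockIdx.x, and i/j are the loops bound to threadIdx.y/threadIdx.z.
   GridDim_z = 1
   GridDim_y = int(sch.get_sref(block_i).stmt.extent)
   GridDim_x = int(sch.get_sref(block_j).stmt.extent)
   BlockDim_y = int(sch.get_sref(i).stmt.extent)
   BlockDim_z = int(sch.get_sref(j).stmt.extent)
   BlockDim_x = warp_size  # threadIdx.x spans exactly one warp
   ```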
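The same bookkeeping can be written generically: once the extent of each bound loop is known, keyed by its CUDA thread tag, the grid and block dimensions follow directly, with unbound axes defaulting to 1 as in a CUDA launch. This is a minimal sketch; `launch_dims` is a hypothetical helper, and the input dict stands in for extents read out of the schedule as in the snippet above.

```python
from typing import Dict, Tuple

Dim3 = Tuple[int, int, int]
_VALID_TAGS = {f"{s}.{ax}" for s in ("blockIdx", "threadIdx") for ax in "xyz"}


def launch_dims(extents: Dict[str, int]) -> Tuple[Dim3, Dim3]:
    """Turn {thread_tag: extent} into CUDA (grid, block) dimensions.

    Hypothetical helper: `extents` is assumed to map tags such as
    "blockIdx.y" or "threadIdx.x" to the extent of the loop bound to them.
    Axes with no bound loop default to 1.
    """
    unknown = set(extents) - _VALID_TAGS
    if unknown:
        raise ValueError(f"unrecognized thread tags: {sorted(unknown)}")
    grid = tuple(extents.get(f"blockIdx.{ax}", 1) for ax in "xyz")
    block = tuple(extents.get(f"threadIdx.{ax}", 1) for ax in "xyz")
    return grid, block


# Made-up extents mirroring the snippet above (warp_size = 32).
grid, block = launch_dims({
    "blockIdx.y": 64, "blockIdx.x": 32,
    "threadIdx.y": 4, "threadIdx.z": 2, "threadIdx.x": 32,
})
print(grid, block)  # (32, 64, 1) (32, 4, 2)
```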
   
   Incorporating the launch configuration into kernel comments is a sensible solution. However, I believe the format of the comments needs further discussion. Other TVM-based code generators already address this issue the same way. For example, Microsoft's [Antares](https://github.com/microsoft/antares), which is built on TVM, emits the thread extents as comments at the top of each generated kernel:
   
   ```c++
   extern "C" __global__ __launch_bounds__(16) void template_op_kernel0(half* __restrict__ input0, half* __restrict__ output0) {
     // [thread_extent] blockIdx.x = 256
     // [thread_extent] threadIdx.x = 8
     // [thread_extent] blockIdx.y = 256
     // [thread_extent] threadIdx.y = 2
     ((output0[(((((((int)blockIdx.y) * 262144) + (((int)blockIdx.x) * 512)) + (((int)threadIdx.y) * 128)) + (((int)threadIdx.x) * 2)))]) = (input0[(((((((int)blockIdx.y) * 262144) + (((int)threadIdx.y) * 65536)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) * 2)))]));
     ((output0[((((((((int)blockIdx.y) * 262144) + (((int)blockIdx.x) * 512)) + (((int)threadIdx.y) * 128)) + (((int)threadIdx.x) * 2)) + 1))]) = (input0[((((((((int)blockIdx.y) * 262144) + (((int)threadIdx.y) * 65536)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) * 2)) + 1))]));
     ((output0[((((((((int)blockIdx.y) * 262144) + (((int)blockIdx.x) * 512)) + (((int)threadIdx.y) * 128)) + (((int)threadIdx.x) * 2)) + 256))]) = (input0[((((((((int)blockIdx.y) * 262144) + (((int)threadIdx.y) * 65536)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.x) * 2)) + 16))]));
   ...
   }
   
   // Saved Perf = 3.166700e-04 sec / run; Step Produced = 138; Planned Steps = 1000;
   ```
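On the consumer side, a launch configuration in this format can be recovered with a little string matching. This is a minimal sketch, assuming one kernel per source string and exactly the `// [thread_extent] <tag> = <n>` layout shown above; `parse_thread_extents` is a hypothetical helper, not part of TVM or Antares. It also checks the total thread count against `__launch_bounds__` when present, since that value is the maximum number of threads per block.

```python
import re
from typing import Dict, Tuple

# Matches comment lines such as "// [thread_extent] blockIdx.x = 256".
_THREAD_EXTENT = re.compile(
    r"//\s*\[thread_extent\]\s*(blockIdx|threadIdx)\.([xyz])\s*=\s*(\d+)"
)
_LAUNCH_BOUNDS = re.compile(r"__launch_bounds__\(\s*(\d+)")


def parse_thread_extents(source: str) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Recover (grid, block) dims from [thread_extent] comments; missing axes are 1."""
    dims: Dict[str, Dict[str, int]] = {"blockIdx": {}, "threadIdx": {}}
    for kind, axis, value in _THREAD_EXTENT.findall(source):
        dims[kind][axis] = int(value)
    grid = tuple(dims["blockIdx"].get(ax, 1) for ax in "xyz")
    block = tuple(dims["threadIdx"].get(ax, 1) for ax in "xyz")
    # __launch_bounds__(N) caps the threads per block; reject inconsistent comments.
    match = _LAUNCH_BOUNDS.search(source)
    if match is not None:
        threads = block[0] * block[1] * block[2]
        if threads > int(match.group(1)):
            raise ValueError(f"{threads} threads exceed __launch_bounds__({match.group(1)})")
    return grid, block


kernel = """extern "C" __global__ __launch_bounds__(16) void k(half* in0, half* out0) {
  // [thread_extent] blockIdx.x = 256
  // [thread_extent] threadIdx.x = 8
  // [thread_extent] blockIdx.y = 256
  // [thread_extent] threadIdx.y = 2
}"""
grid, block = parse_thread_extents(kernel)
print(grid, block)  # (256, 256, 1) (8, 2, 1)
```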
   
   Also CC @junrushao.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.