Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/26 03:20:59 UTC

[GitHub] [tvm] wyc-ruiker opened a new pull request #8137: [Codegen][CUDA] Fix make_int4x cuda codegen vectorize

wyc-ruiker opened a new pull request #8137:
URL: https://github.com/apache/tvm/pull/8137


   Added support for int4x32, int4x16, and int4x4 in `BroadcastNode`.
   
   In the int4x4 test case, the IR is:
   ```
   primfn(compute_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {compute: Buffer(compute_2: Pointer(int4), int4, [64, 4], [])}
     buffer_map = {compute_1: compute} {
     attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 64;
     compute_2[ramp((blockIdx.x*4), 1, 4)] = broadcast(1i4, 4)
   }
   ```
   Before the fix in codegen_c.cc, the generated CUDA code is:
   ```
   extern "C" __global__ void make_int4x4_kernel0(int* __restrict__ compute) {
     ((int16_t*)(compute + ((((int)blockIdx.x) * 4)) / 8))[0] = (int16_t)4369;
   }
   ```
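   This index is wrong for an `int16_t` store. The offset `((((int)blockIdx.x) * 4)) / 8` is added while `compute` is still an `int*`, so it selects the 32-bit word holding the block's lanes (8 int4 lanes per word), and the 16-bit store then always hits the first halfword of that word. Every odd `blockIdx.x` therefore overwrites the halfword written by the block before it, and its own lanes are never written.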
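   A host-side sketch in plain C (no GPU needed) of the pre-fix addressing, assuming 32-bit `int` as in CUDA. `buggy_byte_offset` and `expected_byte_offset` are hypothetical names; they reproduce only the byte offset of each block's 16-bit store, not TVM's code generator.

   ```c
   #include <assert.h>
   #include <stddef.h>
   #include <stdint.h>
   #include <stdio.h>

   /* Byte offset of the 16-bit store made by block `bx` in the pre-fix kernel:
    * the element offset (bx * 4) / 8 is added while `compute` is still an
    * int* (assumed 32-bit), so each step is sizeof(int32_t) == 4 bytes. */
   static size_t buggy_byte_offset(int bx) {
     return (size_t)((bx * 4) / 8) * sizeof(int32_t);
   }

   /* Byte offset where block `bx` should store its four int4 lanes:
    * lanes bx*4 .. bx*4+3 start at nibble bx*4, i.e. byte (bx*4) / 2. */
   static size_t expected_byte_offset(int bx) {
     return (size_t)(bx * 4) / 2;
   }

   int main(void) {
     for (int bx = 0; bx < 4; ++bx) {
       printf("blockIdx.x=%d: writes byte %zu, should write byte %zu\n",
              bx, buggy_byte_offset(bx), expected_byte_offset(bx));
     }
     /* An odd block overwrites the halfword of the preceding even block. */
     assert(buggy_byte_offset(1) == buggy_byte_offset(0));
     assert(buggy_byte_offset(1) != expected_byte_offset(1));
     return 0;
   }
   ```

   For `blockIdx.x` 0 to 3 the sketch reports buggy offsets 0, 0, 4, 4 against expected offsets 0, 2, 4, 6, so only even blocks land in the right place.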
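   After the fix in codegen_c.cc, `compute` is cast to `int16_t*` before the offset is added, and the offset is divided by 4 (4 int4 lanes per `int16_t`). The generated CUDA code is: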
   ```
   extern "C" __global__ void make_int4x4_kernel0(int* __restrict__ compute) {
     ((int16_t*)(compute) + ((((int)blockIdx.x) * 4)) / 4)[0] = (int16_t)4369;
   }
   ```
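   To check the fixed indexing end to end, here is a host-side C emulation (no GPU needed) of the post-fix store over all 64 blocks. It assumes 32-bit `int` as in CUDA and the `[64, 4]` int4 buffer from the IR above. `run_fixed_kernel` is a hypothetical name; it mirrors only the addressing of the generated store, not TVM's code generator. It confirms that each block writes its own halfword and that together the blocks fill the 128-byte buffer with nibbles equal to 1.

   ```c
   #include <assert.h>
   #include <stdint.h>
   #include <stdio.h>
   #include <string.h>

   /* From the IR: Buffer [64, 4] of int4 and thread_extent = 64. */
   enum { kBlocks = 64, kLanes = 4 };

   /* Runs the post-fix store for every blockIdx.x on the host.
    * 64 * 4 int4 lanes = 256 nibbles = 128 bytes = 32 int32 words.
    * memcpy replaces the int16_t* store to stay within C aliasing rules. */
   static void run_fixed_kernel(int32_t *compute) {
     for (int bx = 0; bx < kBlocks; ++bx) {
       int16_t value = (int16_t)4369; /* broadcast(1i4, 4) == 0x1111 */
       /* ((int16_t*)(compute) + (bx * 4) / 4)[0]: step in 2-byte halfwords */
       size_t byte = (size_t)((bx * kLanes) / 4) * sizeof(int16_t);
       memcpy((unsigned char *)compute + byte, &value, sizeof value);
     }
   }

   int main(void) {
     int32_t compute[kBlocks * kLanes / 8]; /* 8 int4 lanes per int32 */
     memset(compute, 0, sizeof compute);
     run_fixed_kernel(compute);
     const unsigned char *bytes = (const unsigned char *)compute;
     for (size_t i = 0; i < sizeof compute; ++i) {
       assert(bytes[i] == 0x11); /* both nibbles hold the int4 value 1 */
     }
     printf("all %zu bytes written\n", sizeof compute);
     return 0;
   }
   ```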
   Could you please help review this fix? @vinx13 @Hzfengsy 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 merged pull request #8137: [Codegen][CUDA] Fix make_int4x cuda codegen vectorize

Posted by GitBox <gi...@apache.org>.
vinx13 merged pull request #8137:
URL: https://github.com/apache/tvm/pull/8137


   





[GitHub] [tvm] tqchen commented on pull request #8137: [Codegen][CUDA] Fix make_int4x cuda codegen vectorize

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #8137:
URL: https://github.com/apache/tvm/pull/8137#issuecomment-849010013


   @vinx13, please help manage this PR.

