Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/10 00:50:10 UTC

[GitHub] [tvm] yzh119 opened a new pull request #10207: Support sub warp reduction for CUDA target.

yzh119 opened a new pull request #10207:
URL: https://github.com/apache/tvm/pull/10207


   Previously, the `LowerThreadAllReduce` pass only emitted code that uses `shfl_down` when the reduction extent equals the warp size. When the reduction extent is smaller than the warp size, codegen falls back to a shared-memory tree reduction, which is less efficient. Since CUDA supports sub-warp reduction by specifying a lane mask, we can still use the shuffle-down approach in that case by adjusting the mask.
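A minimal sketch of what "adjusting the mask" means, assuming each reduction group occupies `reduce_extent` consecutive lanes starting at lane `group_index * reduce_extent`. The helper `sub_warp_mask` is hypothetical (it is not a TVM API); it only reproduces the constant masks that appear in the generated code in this thread, `2047` for an 11-lane reduction and `255 << (threadIdx.y * 8)` for groups of 8 lanes, and leaves out the `__activemask()` term that the real kernels AND in.

```python
WARP_SIZE = 32


def sub_warp_mask(reduce_extent: int, group_index: int = 0) -> int:
    """Hypothetical helper: bit mask of the lanes in one sub-warp reduction group.

    Assumes group `group_index` occupies lanes
    [group_index * reduce_extent, (group_index + 1) * reduce_extent) of a warp.
    """
    first_lane = group_index * reduce_extent
    if reduce_extent < 1 or group_index < 0 or first_lane + reduce_extent > WARP_SIZE:
        raise ValueError("the reduction group must fit inside a single warp")
    # Set `reduce_extent` low bits, then shift them to the group's first lane.
    return ((1 << reduce_extent) - 1) << first_lane


print(sub_warp_mask(11))         # 2047, as in the 11-lane kernel
print(hex(sub_warp_mask(8, 1)))  # 0xff00, i.e. 255 << (1 * 8)
```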
   
   Example code:
   ```python
   import tvm
   import numpy as np
   from tvm.script import tir as T
   
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle) -> None:
       A = T.match_buffer(a, [1024, 11])
       B = T.match_buffer(b, [1024])
   
       for i, j in T.grid(1024, 11):
           with T.block("reduce"):
               vi, vj = T.axis.remap("SR", [i, j])
               with T.init():
                   B[vi] = 0.
               B[vi] = B[vi] + A[vi, vj]
   
   sch = tvm.tir.Schedule(reduce)
   blk = sch.get_block("reduce")
   i, j = sch.get_loops(blk)
   sch.bind(i, "blockIdx.x")
   sch.bind(j, "threadIdx.x")
   f = tvm.build(sch.mod["main"], target="cuda")
   print(f.imported_modules[0].get_source())
   ```
   
   
   Emitted code before this PR:
   ```cuda
   extern "C" __global__ void __launch_bounds__(11) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     __shared__ float red_buf0[11];
     __syncthreads();
     ((volatile float*)red_buf0)[(((int)threadIdx.x))] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
     __syncthreads();
     if (((int)threadIdx.x) < 3) {
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 8))]);
     }
     __syncthreads();
     if (((int)threadIdx.x) < 4) {
       float w_4_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 4))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_4_0;
       float w_2_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 2))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_2_0;
       float w_1_0 = (((volatile float*)red_buf0)[(((int)threadIdx.x))] + ((volatile float*)red_buf0)[((((int)threadIdx.x) + 1))]);
       ((volatile float*)red_buf0)[(((int)threadIdx.x))] = w_1_0;
     }
     __syncthreads();
     B[(((int)blockIdx.x))] = ((volatile float*)red_buf0)[(0)];
   }
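For reference, the pre-PR kernel above can be replayed in plain Python. This sketch models each statement as executing in lockstep across the active threads (every thread reads `red_buf0` before any thread writes), which is what the `volatile` warp-synchronous code relies on. The name `shared_mem_tree_reduce` is hypothetical, and the sketch only covers the 11-thread case shown.

```python
def shared_mem_tree_reduce(values):
    """Replay the pre-PR shared-memory reduction for an 11-thread block."""
    assert len(values) == 11, "this sketch mirrors the 11-lane kernel only"
    red_buf0 = list(values)  # each thread stores its A element into shared memory
    # Threads 0..2 fold the tail (lanes 8..10) onto lanes 0..2.
    snapshot = list(red_buf0)
    for t in range(3):
        red_buf0[t] = snapshot[t] + snapshot[t + 8]
    # Threads 0..3 finish with a tree reduction using offsets 4, 2, 1.
    for offset in (4, 2, 1):
        snapshot = list(red_buf0)  # lockstep: read everything before writing
        for t in range(4):
            red_buf0[t] = snapshot[t] + snapshot[t + offset]
    # The kernel writes red_buf0[0] to B[blockIdx.x].
    return red_buf0[0]


vals = list(range(1, 12))
print(shared_mem_tree_reduce(vals), sum(vals))  # 66 66
```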
   ```
   
   Emitted code after this PR:
   ```cuda
   extern "C" __global__ void __launch_bounds__(11) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     float red_buf0[1];
     uint mask[1];
     float t0[1];
     red_buf0[(0)] = A[(((((int)blockIdx.x) * 11) + ((int)threadIdx.x)))];
     mask[(0)] = (__activemask() & (uint)2047);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 8, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], 0, 32);
     B[(((int)blockIdx.x))] = red_buf0[(0)];
   }
   ```
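The shuffle offsets in the emitted code start at 8, the largest power of two below the reduction extent of 11, and halve down to 1; the 8-lane example later in this thread uses 4, 2, 1. The hypothetical helper below reproduces that schedule. It is inferred from the generated code, not copied from the `LowerThreadAllReduce` source.

```python
from typing import List


def shuffle_offsets(reduce_extent: int) -> List[int]:
    """Offsets passed to __shfl_down_sync, largest first (schedule inferred from the emitted code)."""
    if reduce_extent < 2:
        return []  # a single lane has nothing to reduce
    # Largest power of two strictly smaller than reduce_extent.
    offset = 1 << ((reduce_extent - 1).bit_length() - 1)
    offsets = []
    while offset >= 1:
        offsets.append(offset)
        offset //= 2
    return offsets


print(shuffle_offsets(11))  # [8, 4, 2, 1]
print(shuffle_offsets(8))   # [4, 2, 1]
```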
   
   # Future work
   CUDA 11 supports [cooperative groups reduction](https://developer.nvidia.com/blog/cuda-11-features-revealed/), which we could use directly.
   
   cc @vinx13 @junrushao1994 
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 merged pull request #10207: Support sub warp reduction for CUDA target.

Posted by GitBox <gi...@apache.org>.
vinx13 merged pull request #10207:
URL: https://github.com/apache/tvm/pull/10207


   





[GitHub] [tvm] junrushao1994 commented on pull request #10207: Support sub warp reduction for CUDA target.

junrushao1994 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034562205


   CC @MasterJH5574, I believe you will be interested.





[GitHub] [tvm] yzh119 edited a comment on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 edited a comment on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1035995100


   > BTW do we have this requirement in the codebase now?
   
   @MasterJH5574 Yes, there is a notion of `group_extent` and `reduce_extent`.





[GitHub] [tvm] yzh119 edited a comment on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 edited a comment on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034575574


   There are some issues to be solved:
   
   Consider the following case:
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle, n: T.int32) -> None:
       A = T.match_buffer(a, [1024, 4, 8])
       B = T.match_buffer(b, [1024, 4])
   
       for i, j, k in T.grid(1024, 4, 8):
           with T.block("reduce"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   B[vi, vj] = 0.
               B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
   ```
   If we bind `j` to `threadIdx.y` and `k` to `threadIdx.x`, different `j`s may be mapped to the same warp, so we need a different mask for each `j` to distinguish them.
   
   Another thing worth noting: shuffle-down only exchanges values within a warp, so a reduction group must not span multiple warps. Thus the warp size must be a multiple of `blockDim.x` when `blockDim.y * blockDim.z != 1`.






[GitHub] [tvm] yzh119 commented on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034535980


   Sure, below are the measured times for this kernel:
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle, n: T.int32) -> None:
       A = T.match_buffer(a, [1048576, n])
       B = T.match_buffer(b, [1048576])
   
       for i, j in T.grid(1048576, n):
           with T.block("reduce"):
               vi, vj = T.axis.remap("SR", [i, j])
               with T.init():
                   B[vi] = 0.
               B[vi] = B[vi] + A[vi, vj]
   ```
   with `n` varied over 2, 4, 8, 16, and 32.
   
   | n      | 2                  | 4                 | 8                  | 16                 | 32                 |
   |----------|--------------------|-------------------|--------------------|--------------------|--------------------|
   | shuffle-down time(ms) | 1.840511957804362  | 1.877586046854655 | 2.1820863087972007 | 2.2471348444620767 | 2.1001497904459634 |
   | shared mem time(ms) | 1.7892122268676758 | 1.922925313313802 | 2.053538958231608  | 2.0630757013956704 | 2.1170775095621743 |
   
   There is some variance across multiple runs.





[GitHub] [tvm] junrushao1994 commented on pull request #10207: Support sub warp reduction for CUDA target.

junrushao1994 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1039779883


   will leave the PR to @vinx13 and @masahi for a second look :-)






[GitHub] [tvm] yzh119 commented on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1035993696


   @MasterJH5574 Ah, I think n=4 is the only case where shuffle-down is slower than shared memory.
   Another benefit of shuffle-down is that it reduces shared memory usage, which can increase the number of blocks that execute concurrently.






[GitHub] [tvm] Hzfengsy commented on pull request #10207: Support sub warp reduction for CUDA target.

Hzfengsy commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034434444


   Do you have any performance results? Also, please add test cases.





[GitHub] [tvm] yzh119 edited a comment on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 edited a comment on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034535980


   Sure, below are the measured times for this kernel:
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle, n: T.int32) -> None:
       A = T.match_buffer(a, [1048576, n])
       B = T.match_buffer(b, [1048576])
   
       for i, j in T.grid(1048576, n):
           with T.block("reduce"):
               vi, vj = T.axis.remap("SR", [i, j])
               with T.init():
                   B[vi] = 0.
               B[vi] = B[vi] + A[vi, vj]
   ```
   with `n` varied over 2, 4, 8, 16, and 32.
   
   | n                      | 2           | 4           | 8           | 16          | 32          |
   |------------------------|-------------|-------------|-------------|-------------|-------------|
   | shared-mem time (ms)   | 0.836363387 | 0.902631863 | 1.214023657 | 1.249731274 | 1.175273217 |
   | shuffle-down time (ms) | 0.80920489  | 0.999711047 | 1.076497658 | 1.103504739 | 1.116779527 |
   
   There is some variance across runs. Times were measured with TVM's native `time_evaluator`, averaged over 1000 runs.
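To read the table more easily, the ratio of the two rows gives the speedup of shuffle-down over shared memory for each `n`. The snippet below only recomputes that ratio from the numbers above; values above 1 mean shuffle-down is faster, and `n = 4` is the only case below 1.

```python
# Times in ms copied from the table above, keyed by n.
SHARED_MEM_MS = {2: 0.836363387, 4: 0.902631863, 8: 1.214023657, 16: 1.249731274, 32: 1.175273217}
SHUFFLE_DOWN_MS = {2: 0.80920489, 4: 0.999711047, 8: 1.076497658, 16: 1.103504739, 32: 1.116779527}


def speedup(n: int) -> float:
    """Shared-memory time divided by shuffle-down time (> 1 means shuffle-down wins)."""
    return SHARED_MEM_MS[n] / SHUFFLE_DOWN_MS[n]


for n in sorted(SHARED_MEM_MS):
    print(f"n={n:2d}: {speedup(n):.2f}x")
```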





[GitHub] [tvm] yzh119 commented on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034810548


   @Hzfengsy I wrote a unit test and found a bug (#10210) in the original shared-memory-based tree reduction; it is fixed in this PR.





[GitHub] [tvm] yzh119 edited a comment on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 edited a comment on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1034575574


   Some other notes:
   
   Consider the following case:
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def reduce(a: T.handle, b: T.handle, n: T.int32) -> None:
       A = T.match_buffer(a, [1, 4, 8])
       B = T.match_buffer(b, [1, 4])
   
       for i, j, k in T.grid(1, 4, 8):
           with T.block("reduce"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   B[vi, vj] = 0.
               B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
   ```
   If we bind `j` to `threadIdx.y` and `k` to `threadIdx.x`, different `j`s may be mapped to the same warp, so we need a different mask for each `j` to distinguish them.
   
   Below is an example of generated code:
   ```cuda
   extern "C" __global__ void __launch_bounds__(32) default_function_kernel0(float* __restrict__ A, float* __restrict__ B) {
     float red_buf0[1];
     uint mask[1];
     float t0[1];
     red_buf0[(0)] = A[(((((int)threadIdx.y) * 8) + ((int)threadIdx.x)))];
     mask[(0)] = (__activemask() & ((uint)(255 << (((int)threadIdx.y) * 8))));
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 4, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 2, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     t0[(0)] = __shfl_down_sync(mask[(0)], red_buf0[(0)], 1, 32);
     red_buf0[(0)] = (red_buf0[(0)] + t0[(0)]);
     red_buf0[(0)] = __shfl_sync(mask[(0)], red_buf0[(0)], (((int)threadIdx.y) * 8), 32);
     B[(((int)threadIdx.y))] = red_buf0[(0)];
   }
   ```
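The sketch below replays this kernel in plain Python on one 32-lane warp to show why a per-group mask plus the final `__shfl_sync` from lane `threadIdx.y * 8` gives each group its own sum. `shfl_down` and `shfl` are simplified models of `__shfl_down_sync` and `__shfl_sync` that assume all 32 lanes are active and ignore the mask argument, and `warp_reduce_groups` is a hypothetical name. Some non-leader lanes read across a group boundary (for example, lane 4 reads lane 8 at offset 4), but their values are discarded: each leader's result depends only on the 8 lanes of its own group.

```python
WARP_SIZE = 32


def shfl_down(vals, delta, width=WARP_SIZE):
    """Simplified __shfl_down_sync: lane i reads lane i + delta, or keeps its own
    value if that would leave its `width`-lane segment. The mask is ignored."""
    return [
        vals[lane + delta] if (lane % width) + delta < width else v
        for lane, v in enumerate(vals)
    ]


def shfl(vals, src_lanes):
    """Simplified __shfl_sync: lane i reads the value held by lane src_lanes[i]."""
    return [vals[src] for src in src_lanes]


def warp_reduce_groups(a):
    """Replay the kernel above for blockDim = (8, 4, 1): one 32-lane warp, 4 groups of 8."""
    group_extent, reduce_extent = len(a), len(a[0])
    assert group_extent * reduce_extent == WARP_SIZE
    # lane = threadIdx.y * reduce_extent + threadIdx.x
    red_buf0 = [a[lane // reduce_extent][lane % reduce_extent] for lane in range(WARP_SIZE)]
    offset = reduce_extent // 2  # 4, 2, 1 for reduce_extent = 8 (power of two assumed)
    while offset >= 1:
        t0 = shfl_down(red_buf0, offset)
        red_buf0 = [x + y for x, y in zip(red_buf0, t0)]
        offset //= 2
    # Broadcast each group leader's value (lane threadIdx.y * 8) to its group.
    leaders = [(lane // reduce_extent) * reduce_extent for lane in range(WARP_SIZE)]
    red_buf0 = shfl(red_buf0, leaders)
    return [red_buf0[g * reduce_extent] for g in range(group_extent)]


A = [[y * 8 + x for x in range(8)] for y in range(4)]
print(warp_reduce_groups(A))  # [28, 92, 156, 220]
```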
   
   Another thing worth noting: shuffle-down only exchanges values within a warp, so a reduction group must not span multiple warps. Thus the warp size must be a multiple of `blockDim.x` when `blockDim.y * blockDim.z != 1`.
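A quick way to see why this rule is sufficient: if `blockDim.x` divides 32, every group of `blockDim.x` consecutive lanes starts at a multiple of `blockDim.x` and so never crosses a warp boundary. The sketch below checks this exhaustively for small block shapes; it assumes the usual CUDA linearization of thread indices (`x + blockDim.x * (y + blockDim.y * z)`), and both function names are hypothetical.

```python
WARP_SIZE = 32


def group_lanes_straddle_warp(block_x, block_y, block_z=1):
    """True if some reduction group (threads sharing threadIdx.y/z) crosses a warp boundary."""
    for group in range(block_y * block_z):
        first = group * block_x
        last = first + block_x - 1
        if first // WARP_SIZE != last // WARP_SIZE:
            return True
    return False


def satisfies_rule(block_x, block_y, block_z=1):
    """The stated requirement, only enforced when there are several groups."""
    return block_y * block_z == 1 or WARP_SIZE % block_x == 0


# Whenever the rule holds with several groups, no group straddles a warp.
for bx in range(1, WARP_SIZE + 1):
    for by in range(2, 9):
        if satisfies_rule(bx, by):
            assert not group_lanes_straddle_warp(bx, by)

print(group_lanes_straddle_warp(8, 4))   # False
print(group_lanes_straddle_warp(11, 3))  # True: the third group covers lanes 22..32
```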





[GitHub] [tvm] MasterJH5574 commented on pull request #10207: Support sub warp reduction for CUDA target.

MasterJH5574 commented on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1035982910


   Interesting. It looks like the performance improvement is small? The shuffle-down implementation is only faster than the shared-memory implementation when `n = 4` 🤔
   
   > Another thing worth noting is, we can only allow cross warp reduction by shuffle-down, thus warp size must be a multiple of blockDim.x when blockDim.y * blockDim.z != 1.
   
   BTW do we have this requirement in the codebase now?





[GitHub] [tvm] yzh119 edited a comment on pull request #10207: Support sub warp reduction for CUDA target.

yzh119 edited a comment on pull request #10207:
URL: https://github.com/apache/tvm/pull/10207#issuecomment-1035993696


   > Looks like the perf improvement isn't very much? Only when n = 4 the shuffle-down implementation is better than the shared memory implementation 🤔
   
   My typo, I have fixed it.
   
   Another benefit of shuffle-down is that it reduces shared memory usage, which can increase the number of blocks that execute concurrently.
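As a back-of-the-envelope illustration of the occupancy argument: the number of blocks an SM can hold is capped by its shared memory divided by the shared memory each block needs. The capacity and block-limit values below are hypothetical placeholders, not the specs of any particular GPU, and real occupancy also depends on registers and threads per SM. For the tiny 11-float buffer of the pre-PR kernel (44 bytes), shared memory is unlikely to be the bottleneck; the effect matters more when the shared reduction buffer is large.

```python
def max_blocks_by_smem(smem_per_block: int, smem_per_sm: int, block_limit: int) -> int:
    """Upper bound on resident blocks per SM imposed by shared memory alone (bytes)."""
    if smem_per_block == 0:
        return block_limit  # shuffle-only kernel: shared memory is not a limit
    return min(block_limit, smem_per_sm // smem_per_block)


SMEM_PER_SM = 48 * 1024  # hypothetical capacity in bytes
BLOCK_LIMIT = 16         # hypothetical hardware cap on resident blocks

print(max_blocks_by_smem(11 * 4, SMEM_PER_SM, BLOCK_LIMIT))    # pre-PR kernel: 11 floats
print(max_blocks_by_smem(1024 * 4, SMEM_PER_SM, BLOCK_LIMIT))  # hypothetical 1024-float buffer
print(max_blocks_by_smem(0, SMEM_PER_SM, BLOCK_LIMIT))         # shuffle-down: no shared buffer
```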

