Posted to discuss-archive@tvm.apache.org by Wei Sun via TVM Discuss <no...@discuss.tvm.ai> on 2020/04/01 10:24:11 UTC

[TVM Discuss] [Questions] How CUDA kernel is launched in TVM stack


Hi all:

I am learning the TVM CUDA backend, and I have a question about how a CUDA kernel is launched.

Below is my simple test program:
```
import tvm
import tvm.testing
from tvm import te
import numpy as np

dtype = "float32"
# GEMM size
M, K, N = 16, 8, 16
# declare the algorithm
k = te.reduce_axis((0, K), 'k') # loop over dimension K
A = te.placeholder((M, K), name='A')
B = te.placeholder((K, N), name='B')
C = te.compute(
           (M, N),
           lambda x, y: te.sum(A[x, k] * B[k, y], axis=k),
           name='C')
# default schedule
s = te.create_schedule(C.op)
#print(tvm.lower(s, [A, B, C], simple_mode=True))
# optimized schedule : tiling
bn = 4  # tile size over both M and N
# returned axes are ordered outer -> inner
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
#print(tvm.lower(s, [A, B, C], simple_mode=True))
AS = s.cache_read(A, 'shared', [C])
BS = s.cache_read(B, 'shared', [C])
s[AS].compute_at(s[C], xo)
s[BS].compute_at(s[C], yo)
s[C].bind(xo, te.thread_axis("blockIdx.x"))
s[C].bind(yo, te.thread_axis("blockIdx.y"))
s[C].bind(xi, te.thread_axis("threadIdx.x"))
s[C].bind(yi, te.thread_axis("threadIdx.y"))
target = 'cuda'
ctx = tvm.context(target, 0)
a = tvm.nd.array(np.random.rand(M, K).astype(dtype), ctx)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), ctx)
# compute the reference result C with numpy
answer = np.dot(a.asnumpy(), b.asnumpy())

func = tvm.build(s, [A, B, C], target=target, name='mmult')
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), ctx)
# a, b: input matrices; c: result
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
#print(func.get_source())
dev_module = func.imported_modules[0]
print(dev_module)
print("-----GPU code-----")
print(dev_module.get_source())
```

The generated CUDA code:
```
extern "C" __global__ void mmult_kernel0( float* __restrict__ A,  float* __restrict__ B,  float* __restrict__ C) {
  __shared__ float A_shared[32];
  __shared__ float B_shared[32];
  for (int ax0 = 0; ax0 < 4; ++ax0) {
    for (int ax1 = 0; ax1 < 8; ++ax1) {
      A_shared[(((ax0 * 8) + ax1))] = A[((((((int)blockIdx.x) * 32) + (ax0 * 8)) + ax1))];
    }
  }
  for (int ax01 = 0; ax01 < 8; ++ax01) {
    for (int ax11 = 0; ax11 < 4; ++ax11) {
      B_shared[(((ax01 * 4) + ax11))] = B[((((ax01 * 16) + (((int)blockIdx.y) * 4)) + ax11))];
    }
  }
  C[(((((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 16)) + (((int)blockIdx.y) * 4)) + ((int)threadIdx.y)))] = 0.000000e+00f;
  __syncthreads();
  for (int k = 0; k < 8; ++k) {
    C[(((((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 16)) + (((int)blockIdx.y) * 4)) + ((int)threadIdx.y)))] = (C[(((((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 16)) + (((int)blockIdx.y) * 4)) + ((int)threadIdx.y)))] + (A_shared[(((((int)threadIdx.x) * 8) + k))] * B_shared[(((k * 4) + ((int)threadIdx.y)))]));
  }
}
```

This is straightforward. What confuses me is how this kernel **mmult_kernel0** is launched by the host (CPU, LLVM backend): I do not see where blockDim and gridDim are configured.
Normally, we launch a CUDA kernel from the CPU with:
```
kernel<<<griddim,blockdim>>>(a,b,c)
``` 
How does TVM manage these settings?
Could anyone give me some tips?
@tqchen @FrozenGene





---

Posted by Wei Sun via TVM Discuss <no...@discuss.tvm.ai>.

Hi: 

I am investigating the capabilities of TVM's scheduling primitives (CUDA backend), taking CUTLASS as my baseline for a highly optimized CUDA library.

I think most of the optimization techniques used in CUTLASS, such as tiling and shared-memory management, are supported by TVM primitives.

Streaming is another optimization technique that I think is important, but I did not find it exposed in TVM's Python front end. So I am wondering how we can use streaming in the TVM stack; it seems like an important capability for the CUDA backend.
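
To make concrete what I mean by streaming, here is a plain-CUDA sketch of the kind of copy/compute overlap I have in mind (this is ordinary CUDA, not TVM code; `scale`, `run`, and the buffer names are placeholders I made up):
```
// Plain CUDA illustration of streams (not TVM code).
__global__ void scale(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

// Work enqueued on s0 can overlap with work enqueued on s1.
// For real copy/compute overlap the host buffers should be pinned (cudaMallocHost).
void run(float* h_a, float* h_b, float* d_a, float* d_b, int n) {
  size_t bytes = n * sizeof(float);
  cudaStream_t s0, s1;
  cudaStreamCreate(&s0);
  cudaStreamCreate(&s1);
  cudaMemcpyAsync(d_a, h_a, bytes, cudaMemcpyHostToDevice, s0);
  scale<<<(n + 255) / 256, 256, 0, s0>>>(d_a, n);
  cudaMemcpyAsync(d_b, h_b, bytes, cudaMemcpyHostToDevice, s1);
  scale<<<(n + 255) / 256, 256, 0, s1>>>(d_b, n);
  cudaStreamSynchronize(s0);
  cudaStreamSynchronize(s1);
  cudaStreamDestroy(s0);
  cudaStreamDestroy(s1);
}
```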





---

Posted by masahi via TVM Discuss <no...@discuss.tvm.ai>.

I don't think we expose the CUDA stream abstraction to the Python frontend. We typically don't care about CUDA streams (we don't support any concurrency at runtime).

What is your use case?





---

Posted by Wei Sun via TVM Discuss <no...@discuss.tvm.ai>.

By the way, I am also wondering whether the TVM stack supports CUDA streaming features like those described in [GPU Pro Tip: CUDA 7 Streams Simplify Concurrency](https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/).





---

Posted by Wei Sun via TVM Discuss <no...@discuss.tvm.ai>.

Hi: 
Thank you for your help! 
Based on my understanding of the code, in Python
```
func(a,b,c)
```
will call this:
```
void operator()(TVMArgs args,
                TVMRetValue* rv,
                void** void_args) const
```
And grid_dim and block_dim are inferred from **TVMArgs args** (a and b in this case) by:
```
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
```
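
For reference, the extents themselves seem to be fixed by the schedule at compile time. Printing `tvm.lower(s, [A, B, C], simple_mode=True)` for my GEMM above shows `thread_extent` annotations, roughly like this (abridged from memory; the exact printed form may differ by version):
```
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 4
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 4
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 4
```
So the launch configuration is already known when the module is built.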
And we cannot manually set grid_dim and block_dim.
Am I correct?

Thank you very much!





---

Posted by masahi via TVM Discuss <no...@discuss.tvm.ai>.

The answer is that we use the CUDA driver API to launch kernels from C++ code. ```kernel<<<griddim,blockdim>>>(a,b,c)``` is not the only way to launch a kernel, and it requires compiling with NVCC.


See 
https://github.com/apache/incubator-tvm/blob/e0122c0ea68043372220e4e02b81692c34832227/src/runtime/cuda/cuda_module.cc#L189
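
Conceptually the runtime does something like the following with the driver API (a simplified, illustrative sketch, not the actual TVM source; see the link above for the real code):
```
// Illustrative sketch of a driver-API launch (not the actual TVM source).
CUmodule mod;
CUfunction fn;
// ptx: the PTX string produced by TVM's CUDA codegen (omitted here).
cuModuleLoadData(&mod, ptx);
cuModuleGetFunction(&fn, mod, "mmult_kernel0");

// For the GEMM above the thread extents are all 4, so grid = block = (4, 4, 1).
unsigned grid_x = 4, grid_y = 4, grid_z = 1;
unsigned block_x = 4, block_y = 4, block_z = 1;

CUdeviceptr dA, dB, dC;                // device buffers for A, B, C (allocation omitted)
void* args[] = {&dA, &dB, &dC};
cuLaunchKernel(fn,
               grid_x, grid_y, grid_z,
               block_x, block_y, block_z,
               0,                       // dynamic shared memory bytes
               nullptr,                 // stream; TVM passes its thread-local stream here
               args, nullptr);
```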

There is a longer explanation of the life of a vector add, from Python definition to CUDA kernel launch, here:

https://docs.tvm.ai/dev/codebase_walkthrough.html#vector-add-example





---

Posted by Wei Sun via TVM Discuss <no...@discuss.tvm.ai>.

Hi: 

Thanks for your answer. I will check autotvm to see how it tunes grid/block, because in my experience grid/block dims affect performance.

Another question: I see there is an argument for the **CUDA stream**:

```
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
```
I didn't find any documentation about CUDA stream support in TVM; could you give me a hint about how we could use streams?
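
From skimming the runtime headers, my unverified guess is that one could hand TVM a user-created stream from C++ via `TVMSetStream`, roughly like this (a sketch only; I have not tested it):
```
// Unverified guess based on include/tvm/runtime/c_runtime_api.h.
#include <cuda_runtime.h>
#include <tvm/runtime/c_runtime_api.h>

void use_my_stream() {
  cudaStream_t strm;
  cudaStreamCreate(&strm);
  // Ask TVM to enqueue subsequent GPU work for device 0 on this stream.
  TVMSetStream(kDLGPU, 0, static_cast<TVMStreamHandle>(strm));
  // ... call TVM functions here ...
  TVMSetStream(kDLGPU, 0, nullptr);  // restore the default stream
  cudaStreamDestroy(strm);
}
```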

Thank you very much!





---

Posted by masahi via TVM Discuss <no...@discuss.tvm.ai>.

Correct. You can tweak the schedule to change the launch config, but as a user you shouldn't need to care about the exact grid/block sizes.

If you really want the best performance, use autotvm to tune your schedule; the resulting grid/block size will be optimal based on real measurements.
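
For example, a template can expose the tile size as a tunable knob, and the value autotvm picks then fixes the grid/block dims. A minimal sketch in the style of the autotvm tutorials of that time (names are illustrative, untested):
```
from tvm import autotvm, te

@autotvm.template
def mmult_template(M, K, N):
    k = te.reduce_axis((0, K), 'k')
    A = te.placeholder((M, K), name='A')
    B = te.placeholder((K, N), name='B')
    C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name='C')
    s = te.create_schedule(C.op)

    cfg = autotvm.get_config()
    cfg.define_knob('bn', [2, 4, 8])  # tunable tile size
    bn = cfg['bn'].val
    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
    s[C].bind(xo, te.thread_axis('blockIdx.x'))   # grid dims follow from bn
    s[C].bind(yo, te.thread_axis('blockIdx.y'))
    s[C].bind(xi, te.thread_axis('threadIdx.x'))  # block dims follow from bn
    s[C].bind(yi, te.thread_axis('threadIdx.y'))
    return s, [A, B, C]
```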




