Posted to dev@tvm.apache.org by Siyuan Feng <no...@github.com> on 2019/10/03 05:14:15 UTC

[dmlc/tvm] [RFC] Tensor Core Support (#4052)

Tensor Cores are a defining feature of NVIDIA's Volta and Turing GPU architectures and give a large boost to matrix multiplication and convolution. They let us use mixed precision to achieve higher throughput without sacrificing accuracy.

## Tensor Core Overview
Each Tensor Core provides a 4×4×4 matrix processing array that performs `D = A * B + C`, where `A`, `B`, `C` and `D` are 4×4 matrices, as the figure shows. The matrix multiply inputs `A` and `B` are FP16 matrices, while the accumulation matrices `C` and `D` may be FP16 or FP32 matrices.
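To make the `D = A * B + C` semantics concrete, here is a minimal pure-Python reference model of one 4×4×4 step with FP16 inputs and an FP32 accumulator. It is only a sketch: rounding is emulated with `struct`'s half (`'e'`) and single (`'f'`) formats, and the order in which partial sums are rounded is an assumption, not a description of the hardware datapath.

```python
import struct


def to_fp16(x: float) -> float:
    # Round to the nearest IEEE 754 half-precision value ('e' format).
    return struct.unpack('<e', struct.pack('<e', x))[0]


def to_fp32(x: float) -> float:
    # Round to the nearest IEEE 754 single-precision value ('f' format).
    return struct.unpack('<f', struct.pack('<f', x))[0]


def mma_4x4x4(A, B, C):
    """Reference model of D = A * B + C on 4x4 tiles.

    A and B are rounded to FP16; each product is added to an FP32 accumulator,
    rounding after every addition (an assumed order, not the hardware's).
    """
    n = 4
    D = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            acc = to_fp32(C[i][j])
            for k in range(n):
                # The product of two FP16 values is exact in FP32.
                acc = to_fp32(acc + to_fp16(A[i][k]) * to_fp16(B[k][j]))
            D[i][j] = acc
    return D
```

For example, `to_fp16(0.1)` returns `0.0999755859375`, showing the precision lost by storing the inputs as FP16 while the sum itself is kept in FP32.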

![image](https://user-images.githubusercontent.com/25500082/66098090-1263f680-e556-11e9-8fc3-0c97917d43c9.png)
However, CUDA programmers can only use the warp-level primitive `wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform 16×16×16 half-precision matrix multiplication on Tensor Cores. Before invoking the matrix multiplication, programmers must explicitly load data from memory into registers with the primitive `wmma::load_matrix_sync`. The NVCC compiler translates that primitive into multiple memory load instructions. At run time, every thread loads 16 elements from matrix A and 16 elements from matrix B.
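The size relationship between the 4×4×4 hardware step and the 16×16×16 warp-level primitive can be sketched as a blocked matrix multiply: splitting each of the three 16-long axes into 4-long tiles yields (16/4)³ = 64 tile-level MMA steps. This is a conceptual decomposition only; how a warp's `mma_sync` is actually distributed across Tensor Cores is not specified here. Integer inputs keep the comparison exact.

```python
def matmul_acc(A, B, C):
    # Reference: D = A * B + C for square matrices stored as lists of lists.
    n = len(A)
    return [[C[i][j] + sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def tiled_mma(A, B, C, tile=4):
    """D = A * B + C, computed as a sequence of tile x tile x tile MMA steps."""
    n = len(A)
    D = [row[:] for row in C]
    steps = 0
    for ti in range(0, n, tile):
        for tj in range(0, n, tile):
            for tk in range(0, n, tile):
                # One small step: D[ti.., tj..] += A[ti.., tk..] * B[tk.., tj..]
                for i in range(ti, ti + tile):
                    for j in range(tj, tj + tile):
                        D[i][j] += sum(A[i][k] * B[k][j] for k in range(tk, tk + tile))
                steps += 1
    return D, steps
```

For 16×16 inputs, `tiled_mma` returns the same matrix as `matmul_acc` together with a step count of 64.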

# Proposed Design
A Tensor Core operation can be regarded as a new hardware instruction, just like the GEMM instruction in VTA, so tensorization can be used to replace the corresponding computation. Note that, unlike with other accelerators, we also need to consider shared memory when using Tensor Cores. In addition, `wmma::mma_sync` is a warp-level instruction, which means all 32 threads in a warp execute it together. This introduces a brand new schedule level: the warp.

## Warp Level Schedule
Although `wmma::mma_sync` is a warp-level operator, NVIDIA did not change the kernel launch API: a kernel is still launched with `gridDim`, `blockDim` and (optionally) dynamic shared memory. The only thing we need to do is ensure that `blockDim.x` is a multiple of the warp size (32).

In a TVM schedule, we can simply set the extent of `threadIdx.x` to 32. It is safe to use `threadIdx.y` and `threadIdx.z` as well, and their extents have no extra constraints. Note that `threadIdx.x` can only be used for memory copies or other thread-level operators.
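Why an extent of 32 for `threadIdx.x` gives a clean warp mapping: CUDA linearizes thread indices as `x + y * blockDim.x + z * blockDim.x * blockDim.y` and groups consecutive linear IDs into warps of 32. The sketch below (the helper `warp_and_lane` is hypothetical) shows that when `blockDim.x == 32`, each warp corresponds to exactly one `(threadIdx.y, threadIdx.z)` pair and the lane is `threadIdx.x`.

```python
WARP_SIZE = 32


def warp_and_lane(tx, ty, tz, block_dim):
    """Return (warp_id, lane_id) of thread (tx, ty, tz) for blockDim = (bx, by, bz)."""
    bx, by, _bz = block_dim
    # CUDA linearizes thread indices with x varying fastest, then y, then z.
    linear = tx + ty * bx + tz * bx * by
    return linear // WARP_SIZE, linear % WARP_SIZE
```

With `blockDim.x == 16`, by contrast, threads with `threadIdx.y == 0` and `threadIdx.y == 1` land in the same warp: `warp_and_lane(0, 1, 0, (16, 4, 1))` is `(0, 16)`.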

## New Memory Scope
As mentioned above, programmers must load data from memory into a new memory scope, `wmma::fragment`, before using `wmma::mma_sync`. There are three types of fragment: `matrix_a`, `matrix_b` and `accumulator`. So I created three new built-in memory scopes in TVM: `wmma.matrix_a`, `wmma.matrix_b` and `wmma.accumulator`.
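On the CUDA side, each of these scopes corresponds to an `nvcuda::wmma::fragment` type. The helper below is hypothetical (it is not TVM's code generator) and simply renders the declaration a fragment in each scope would need, assuming the 16×16×16 shape and `row_major` layout used in this RFC. Note that CUDA accumulator fragments take no layout template argument.

```python
# Hypothetical mapping from the TVM scopes above to CUDA WMMA fragment kinds.
FRAGMENT_KIND = {
    "wmma.matrix_a": "matrix_a",
    "wmma.matrix_b": "matrix_b",
    "wmma.accumulator": "accumulator",
}


def fragment_decl(scope, name, dtype, shape=(16, 16, 16), layout="row_major"):
    """Render the CUDA declaration of a fragment living in a given TVM scope."""
    kind = FRAGMENT_KIND[scope]
    m, n, k = shape
    args = [f"nvcuda::wmma::{kind}", str(m), str(n), str(k), dtype]
    if kind != "accumulator":
        # Only matrix_a / matrix_b fragments take a layout template argument.
        args.append(f"nvcuda::wmma::{layout}")
    return f"nvcuda::wmma::fragment<{', '.join(args)}> {name};"
```

For instance, `fragment_decl("wmma.accumulator", "C_frag", "float")` yields `nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> C_frag;`.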

## Memory layout
For now, the data must be re-laid out before launching the kernel. The input and output matrices have shape `[n // 16, m // 16, 16, 16]`, the same layout as VTA's input and output. The native CUDA API does support the native `[n, m]` shape, so this constraint can be dropped in the future.
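The required re-layout from `[n, m]` to `[n // 16, m // 16, 16, 16]` can be sketched in pure Python as below. `pack` and `unpack` are illustrative helpers (not part of TVM) that assume both dimensions are multiples of 16; each innermost 16×16 block is one fragment-sized tile.

```python
def pack(mat, t=16):
    """Re-layout a row-major [n, m] matrix into [n // t, m // t, t, t] tiles."""
    n, m = len(mat), len(mat[0])
    if n % t or m % t:
        raise ValueError("both dimensions must be multiples of the tile size")
    return [[[[mat[bi * t + i][bj * t + j] for j in range(t)] for i in range(t)]
             for bj in range(m // t)]
            for bi in range(n // t)]


def unpack(packed):
    """Inverse of pack: [n // t, m // t, t, t] back to [n, m]."""
    nb, mb, t = len(packed), len(packed[0]), len(packed[0][0])
    return [[packed[r // t][c // t][r % t][c % t] for c in range(mb * t)]
            for r in range(nb * t)]
```

Element `packed[bi][bj][i][j]` holds `mat[bi * 16 + i][bj * 16 + j]`, so `unpack(pack(mat)) == mat`.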

## Tensor Intrinsic
Here is an example tensor intrinsic for `mma_sync`:
```python
import tvm


def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    # FP16 inputs, FP32 ('float') accumulation: C[i, j] = sum_k A[i, k] * B[k, j]
    C = tvm.compute((n, n),
                    lambda ii, jj:
                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    # offset_factor=256 (16 * 16) forces elem_offset to be a multiple of one fragment,
    # so elem_offset // 256 below indexes a whole fragment.
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            # Reset: fill the accumulator fragment with 0.0
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            # Update: BC = BA * BB + BC on one 16x16x16 tile
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

        # decl_tensor_intrin accepts (body, reduce_reset, reduce_update)
        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})
```
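The `elem_offset // 256` arithmetic relies on `offset_factor=256`, i.e. every buffer bound to the intrinsic starts at a multiple of 16 × 16 elements. Assuming fragment-sized tiles are stored contiguously in the `[n // 16, m // 16, 16, 16]` order described above, the sketch below (hypothetical helper names) shows how a tile position maps to an element offset and back to a fragment index.

```python
FRAG_ELEMS = 16 * 16  # matches offset_factor=256 in the intrinsic


def tile_elem_offset(bi, bj, tiles_per_row):
    # Start of tile (bi, bj) in a contiguous [nb, mb, 16, 16] buffer.
    return (bi * tiles_per_row + bj) * FRAG_ELEMS


def fragment_index(elem_offset):
    # Mirrors the `elem_offset // 256` passed to tvm_fill_fragment / tvm_mma_sync.
    if elem_offset % FRAG_ELEMS:
        raise ValueError("elem_offset must be a multiple of offset_factor (256)")
    return elem_offset // FRAG_ELEMS
```

With four tiles per row, tile `(1, 2)` starts at element 1536, which is fragment index 6.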

# Performance
Speed test of a 4096×4096×4096 mixed-precision matrix multiplication on a TITAN V GPU:
- TVM w/o Tensor Cores: 11.7415 ms
- cuBLAS w/o Tensor Cores: 11.462592 ms
- TVM w/ Tensor Cores: 2.795257 ms
- cuBLAS w/ Tensor Cores: 1.787328 ms
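To put these timings in context, they can be converted to throughput. This sketch assumes the usual convention of 2·M·N·K floating-point operations for a GEMM (one multiply and one add per multiply-accumulate); all numbers are derived only from the timings listed above.

```python
M = N = K = 4096
FLOPS = 2 * M * N * K  # one multiply + one add per (i, j, k)

timings_ms = {
    "TVM w/o Tensor Cores": 11.7415,
    "cuBLAS w/o Tensor Cores": 11.462592,
    "TVM w/ Tensor Cores": 2.795257,
    "cuBLAS w/ Tensor Cores": 1.787328,
}


def tflops(ms):
    # Convert a runtime in milliseconds to TFLOPS.
    return FLOPS / (ms * 1e-3) / 1e12


for name, ms in timings_ms.items():
    print(f"{name:24s} {ms:10.3f} ms  {tflops(ms):6.1f} TFLOPS")

speedup = timings_ms["TVM w/o Tensor Cores"] / timings_ms["TVM w/ Tensor Cores"]
vs_cublas = timings_ms["cuBLAS w/ Tensor Cores"] / timings_ms["TVM w/ Tensor Cores"]
print(f"TVM speedup from Tensor Cores: {speedup:.2f}x")
print(f"TVM relative to cuBLAS (both w/ Tensor Cores): {vs_cublas:.0%}")
```

By this measure, TVM with Tensor Cores reaches roughly 49 TFLOPS versus roughly 77 TFLOPS for cuBLAS (about 64% of cuBLAS), and is about 4.2× faster than TVM without Tensor Cores.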

# Roadmap
- [x] schedule for gemm
- [ ] schedule for conv2d
- [ ] add support for col-major matrix
- [ ] add support for native layout

# Example code and schedule
https://gist.github.com/Hzfengsy/2b13215a926ae439515cc70b4e7027e3

Comments welcome!

cc @tqchen @tmoreau89 @merrymercy 

-- 
You are receiving this because you are subscribed to this thread.
Reply to this email directly or view it on GitHub:
https://github.com/dmlc/tvm/issues/4052

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by "Minmin Sun (孙敏敏)" <no...@github.com>.
@Hzfengsy Sure, we will show the code as well as a sample schedule very soon. It is under internal review now. As you will see, the schedule for TensorCore CodeGen looks no different from a normal matmul schedule for GPU. Everything is done in IR passes, including matrix_a/matrix_b/accumulator recognition, row/col_major recognition as @yangjunpro mentioned, thread index unification within a warp for TensorCore operations, loop scaling, etc.

https://github.com/dmlc/tvm/issues/4052#issuecomment-537872194

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Siyuan Feng <no...@github.com>.
@yangjunpro Really happy to see another solution for TensorCore. 

You are right! I simply extend TVM intrinsics to support it. This does cause some trouble for programmers who write the schedule: it is not easy to write a high-performance schedule.

I'm really curious about how IR passes can recognize the pattern. Does the Python code need to split the loops into chunks of 16? I would appreciate it if you could show me some details and simple examples.

https://github.com/dmlc/tvm/issues/4052#issuecomment-537821079

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Jon Soifer <no...@github.com>.
Would it be easy to extend your gemm schedule into a schedule for BatchMatMul? That would help round out the TensorCore story for matrix multiplication.

https://github.com/dmlc/tvm/issues/4052#issuecomment-538852791

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Jun Yang <no...@github.com>.
Nice to see other folks working on adding TensorCore support to TVM. We have also been working on enhancing TVM to incorporate TensorCore schedule support.
If my understanding is correct, @Hzfengsy, your solution is based on extending TVM's intrinsics, while our solution puts most of the complexity into TVM IR passes. That way, at the Python level we don't need to think much about TensorCore details; instead, the TVM core takes care of most of the TensorCore pattern recognition and schedule customization work, such as col-major/row-major pattern recognition.

@minmin.sun may provide more comments. 

Thanks

https://github.com/dmlc/tvm/issues/4052#issuecomment-537808119

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Siyuan Feng <no...@github.com>.
@soiferj Thank you for such a helpful comment. I have just extended the schedule to BatchMatMul. You can check the schedule in my fork: https://github.com/Hzfengsy/tvm/blob/master/tests/python/unittest/test_schedule_tensor_core.py#L101

https://github.com/dmlc/tvm/issues/4052#issuecomment-539355844

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Bing Xu <no...@github.com>.
cc @Laurawly 

https://github.com/dmlc/tvm/issues/4052#issuecomment-537808438

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Thierry Moreau <no...@github.com>.
Very welcome work, @Hzfengsy!

https://github.com/dmlc/tvm/issues/4052#issuecomment-537791018

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Siyuan Feng <no...@github.com>.
Closed #4052.

https://github.com/dmlc/tvm/issues/4052#event-2753595541

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Siyuan Feng <no...@github.com>.
@tmoreau89 Exactly! For now, we use the NCHWnc layout, the same layout as VTA.
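As a sketch of what NCHWnc packing means for indexing: the batch and channel axes are each split into an outer block index and an inner offset, and the two inner offsets become the innermost axes. The block sizes (16 for both `n` and `c`) are an assumption chosen to match the 16×16 fragment tiles, and both helpers are hypothetical.

```python
def packed_shape(N, C, H, W, bn=16, bc=16):
    """Shape of an NCHW tensor after packing into NCHWnc (block sizes assumed)."""
    if N % bn or C % bc:
        raise ValueError("N and C must be multiples of the block sizes")
    return (N // bn, C // bc, H, W, bn, bc)


def nchw_to_nchwnc_index(n, c, h, w, bn=16, bc=16):
    """Map an NCHW index to (n_outer, c_outer, h, w, n_inner, c_inner)."""
    return (n // bn, c // bc, h, w, n % bn, c % bc)
```

For example, a `(32, 64, 7, 7)` tensor packs to shape `(2, 4, 7, 7, 16, 16)`, and element `(17, 35, 3, 4)` moves to `(1, 2, 3, 4, 1, 3)`.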

https://github.com/dmlc/tvm/issues/4052#issuecomment-537816661

Re: [dmlc/tvm] [RFC] Tensor Core Support (#4052)

Posted by Siyuan Feng <no...@github.com>.
I have chatted with @minminsun and his team over the past few days. As they mentioned in https://github.com/dmlc/tvm/issues/4105#issuecomment-542032766, we can have different frontends but only one backend. In my previous implementation, users could only use fragments with a 16x16x16 shape and row-major layout. To solve this problem, Minmin uses `new_expr`. Here I propose a new design: we use attributes to transmit the metadata. Here is an example:
```
// attr [A.shared.wmma.matrix_a] storage_scope = "wmma.matrix_a"
// attr [A.shared.wmma.matrix_a] fragment_layout = "row_major"
// attr [A.shared.wmma.matrix_a] fragment_shape = "16, 16, 16"
allocate A.shared.wmma.matrix_a[float16 * 1024]
```
This solution has been accepted by Minmin and his team. Thanks for supporting my proposal.
Users can set this configuration in the tensor intrinsics:
```python
import tvm


def intrin_wmma_load_matrix(scope):
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    # Source tile lives in shared memory; the destination is a WMMA fragment scope.
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()

        BA = ins[0]
        BC = outs[0]
        # Arguments: fragment, shape (n, n, n), fragment index,
        # source pointer, source stride n, and 'row_major' layout
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, n, n, BC.elem_offset // 256,
                                BA.access_ptr('r'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
```
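The `'row_major'` argument and the stride passed after `BA.access_ptr('r')` (here `n`, since `BA` is a dense 16×16 tile) control how a fragment is gathered from memory. The following pure-Python model is a sketch of the CUDA `ldm` convention, where the stride separates consecutive rows for row-major data and consecutive columns for column-major data; `load_tile` is a hypothetical helper, not a TVM or CUDA API.

```python
def load_tile(buf, offset, ldm, layout="row_major", t=16):
    """Gather a t x t tile from a flat buffer (reference model of a WMMA load).

    row_major: element (i, j) lives at offset + i * ldm + j
    col_major: element (i, j) lives at offset + j * ldm + i
    """
    if layout == "row_major":
        return [[buf[offset + i * ldm + j] for j in range(t)] for i in range(t)]
    if layout == "col_major":
        return [[buf[offset + j * ldm + i] for j in range(t)] for i in range(t)]
    raise ValueError(f"unknown layout: {layout}")
```

Loading from a wider 16×32 row-major buffer with `ldm=32` and `offset=16` selects its right-hand 16×16 half; on the same dense buffer, a `col_major` load returns the transpose of a `row_major` load.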
A new IR pass transmits this information from the intrinsics to the attributes.
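Here is a hypothetical sketch of the output side of such a pass: given the metadata collected for one fragment buffer, it renders the three attributes in the format of the example above. The `FragmentInfo` record and the shape check are illustrative assumptions. The accepted shapes are the FP16 WMMA shapes supported by CUDA (16×16×16, 32×8×16, 8×32×16), which is the flexibility the attribute design is meant to expose.

```python
from dataclasses import dataclass
from typing import Tuple

# FP16 WMMA fragment shapes (m, n, k) supported by CUDA.
WMMA_HALF_SHAPES = {(16, 16, 16), (32, 8, 16), (8, 32, 16)}


@dataclass
class FragmentInfo:
    # Hypothetical record of the metadata a pass would collect for one buffer.
    buffer: str
    scope: str
    layout: str
    shape: Tuple[int, int, int]


def emit_attrs(info: FragmentInfo):
    """Render fragment metadata as attribute lines in the style of the example IR."""
    if tuple(info.shape) not in WMMA_HALF_SHAPES:
        raise ValueError(f"unsupported FP16 fragment shape: {info.shape}")
    shape_str = ", ".join(str(d) for d in info.shape)
    return [
        f'// attr [{info.buffer}] storage_scope = "{info.scope}"',
        f'// attr [{info.buffer}] fragment_layout = "{info.layout}"',
        f'// attr [{info.buffer}] fragment_shape = "{shape_str}"',
    ]
```

Calling `emit_attrs(FragmentInfo("A.shared.wmma.matrix_a", "wmma.matrix_a", "row_major", (16, 16, 16)))` reproduces the three `// attr` lines of the example.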

I am really happy to cooperate with Minmin's team. Thank you again for contributing to TVM.
cc @tqchen 

https://github.com/dmlc/tvm/issues/4052#issuecomment-542488197