Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/04/13 19:18:54 UTC

[GitHub] [incubator-mxnet] sxjscience opened a new issue #18043: [Performance][Numpy] np.einsum can be 500 - 1000 times slower than torch.einsum

URL: https://github.com/apache/incubator-mxnet/issues/18043
 
 
   The performance of `np.einsum` on GPU is poor: it is usually hundreds of times slower than `th.einsum` (roughly 150 to 225 times in the kernel timings below). Because `einsum` is essential for implementing the attention mechanism used in NLP and CV models, we should accelerate the implementation.
   
   Here is the code to profile different implementations of einsum (also in gist: https://gist.github.com/sxjscience/bfda1a8bd2942d93eef5ddf8a15b52b8). The profiling results show the following ordering, from fastest to slowest:
   
   **PyTorch einsum > MXNet no-einsum >> MXNet einsum**
   
   ```python
   
   import mxnet as mx
   import numpy as np
   import torch as th
   import argparse
   mx.npx.set_np()
   
   parser = argparse.ArgumentParser(description='Profile einsum')
   # Both flags are required, so they take no defaults.
   parser.add_argument('--mode', choices=['einsum', 'no_einsum', 'th_einsum'],
                       required=True, help='Implementation to profile.')
   parser.add_argument('--problem', type=int, choices=[0, 1, 2],
                       required=True, help='Problem type.')
   args = parser.parse_args()
   
   np.random.seed(100)
   batch_size = 64
   num_heads = 8
   seq_length_A = 100
   seq_length_B = 50
   units = 128
   
   if args.problem == 0:
       lhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_A, units))
       rhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_B, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bnid,bnjd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.npx.batch_dot(mx_lhs, mx_rhs, transpose_b=True)
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
   elif args.problem == 1:
       lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
       rhs = np.random.normal(0, 1, (batch_size, seq_length_B, num_heads, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bind,bjnd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.npx.batch_dot(mx.np.swapaxes(mx_lhs, 1, 2),
                                  mx.np.swapaxes(mx_rhs, 1, 2),
                                  transpose_b=True)
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
   elif args.problem == 2:
       lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
       rhs = np.random.normal(0, 1, (seq_length_B, num_heads, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bind,jnd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.np.matmul(mx.np.swapaxes(mx_lhs, 1, 2),
                              mx.np.transpose(mx_rhs, (1, 2, 0)))
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
   
   ```
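
As a quick CPU-only sanity check that each `no_einsum` branch computes the same contraction as its einsum string, here is a minimal NumPy sketch. It assumes that `mx.npx.batch_dot(..., transpose_b=True)` and the `mx.np` calls behave like the NumPy `matmul`/`swapaxes`/`transpose` used as stand-ins here, and it uses small hypothetical sizes instead of the profiling sizes.

```python
import numpy as np

# Small hypothetical sizes; the profiling script uses B=64, K=8, T0=100, T1=50, C=128.
B, K, T0, T1, C = 2, 3, 5, 4, 6
rng = np.random.default_rng(0)

def check_problem_0():
    lhs = rng.standard_normal((B, K, T0, C))
    rhs = rng.standard_normal((B, K, T1, C))
    ref = np.einsum('bnid,bnjd->bnij', lhs, rhs)
    # Stand-in for batch_dot(lhs, rhs, transpose_b=True): lhs @ rhs^T per (b, n)
    alt = np.matmul(lhs, np.swapaxes(rhs, -1, -2))
    return np.allclose(ref, alt)

def check_problem_1():
    lhs = rng.standard_normal((B, T0, K, C))
    rhs = rng.standard_normal((B, T1, K, C))
    ref = np.einsum('bind,bjnd->bnij', lhs, rhs)
    # Move the head axis next to the batch axis, then do a batched lhs @ rhs^T
    l = np.swapaxes(lhs, 1, 2)  # (B, K, T0, C)
    r = np.swapaxes(rhs, 1, 2)  # (B, K, T1, C)
    alt = np.matmul(l, np.swapaxes(r, -1, -2))
    return np.allclose(ref, alt)

def check_problem_2():
    lhs = rng.standard_normal((B, T0, K, C))
    rhs = rng.standard_normal((T1, K, C))
    ref = np.einsum('bind,jnd->bnij', lhs, rhs)
    # (B, K, T0, C) @ (K, C, T1) broadcasts rhs over the batch axis
    alt = np.matmul(np.swapaxes(lhs, 1, 2), np.transpose(rhs, (1, 2, 0)))
    return np.allclose(ref, alt)

print(check_problem_0(), check_problem_1(), check_problem_2())  # True True True
```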
   
   We profiled three different usages of einsum:
   
   1. (B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
         | 41.009ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_|
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
         | 198.75us | volta_sgemm_128x64_tn |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
         | 192.35us | volta_sgemm_128x64_tn |
   
   2. (B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
         | 40.665ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_|
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
         | 185.76us | volta_sgemm_128x64_tn |
         | 89.519us | void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5) |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
         | 193.02us | volta_sgemm_128x64_tn |
         | 61.967us | \_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1\_ |
   
   
   3. (B, T0, K, C) X (T1, K, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
         | 40.551ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_|
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
         | 322.33us | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_16broadcast_kernelINS0_10mshadow_op8identityEEEJPfS7_N7mshadow5ShapeILi5EEESA_NS_9OpReqTypeEmEEEviDpT0\_|
         | 183.23us | volta_sgemm_128x64_nn |
         | 120.13us | void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5) |
         | 5.3120us | void mxnet::op::cuda::transpose_pseudo2D<float, unsigned long, bool=0>(float*, float, int, int, int, int) |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
         | 152.16us | volta_sgemm_128x64_tn |
         | 28.704us | \_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1\_ |
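
The slowdown factors can be recomputed from the nvprof tables above. This small sketch sums the listed kernel times for each implementation (ignoring launch overhead, host-device copies, and any kernels not listed in the tables) and divides the MXNet einsum total by each baseline total.

```python
# Kernel times (microseconds) copied from the nvprof tables above.
TIMINGS_US = {
    "problem 0": {
        "mx_einsum": [41009.0],
        "mx_no_einsum": [198.75],
        "th_einsum": [192.35],
    },
    "problem 1": {
        "mx_einsum": [40665.0],
        "mx_no_einsum": [185.76, 89.519],
        "th_einsum": [193.02, 61.967],
    },
    "problem 2": {
        "mx_einsum": [40551.0],
        "mx_no_einsum": [322.33, 183.23, 120.13, 5.3120],
        "th_einsum": [152.16, 28.704],
    },
}

def slowdown(problem, baseline="th_einsum"):
    """Total MXNet einsum kernel time divided by total baseline kernel time."""
    t = TIMINGS_US[problem]
    return sum(t["mx_einsum"]) / sum(t[baseline])

for name in TIMINGS_US:
    print(f"{name}: {slowdown(name):.0f}x vs th.einsum, "
          f"{slowdown(name, 'mx_no_einsum'):.0f}x vs MXNet no-einsum")
```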
   
   @yzhliu @hzfan @haojin2 @reminisce @szha 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-mxnet] sxjscience commented on issue #18043: [Performance][Numpy] np.einsum can be 500 - 1000 times slower than torch.einsum

URL: https://github.com/apache/incubator-mxnet/issues/18043#issuecomment-613053572
 
 
   @ptrendx Would you have time to take a look? I'm planning to use it in the numpy version of GluonNLP and found that our einsum operator's performance is poor.


[GitHub] [incubator-mxnet] ptrendx commented on issue #18043: [Performance][Numpy] np.einsum can be 500 - 1000 times slower than torch.einsum

URL: https://github.com/apache/incubator-mxnet/issues/18043#issuecomment-613070091
 
 
   Will take a look, but I'm not sure when (probably not this week). If somebody wants to work on it in the meantime, reassign it to yourself and go for it.


[GitHub] [incubator-mxnet] sxjscience commented on issue #18043: [Performance][Numpy] np.einsum can be 500 times slower than torch.einsum

URL: https://github.com/apache/incubator-mxnet/issues/18043#issuecomment-616339427


   Here are the einsum workloads that are related to attention:
   
   Consider an attention cell that maps query, key, and value to out.
   We denote the batch size as `B`, the number of heads as `K`, the query length as `L_q`, the memory length as `L_m`, the key dimension as `C_k`, and the value dimension as `C_v`. In the numpy version of GluonNLP, the attention cell will support three layouts:
   
   - layout = 'NKT'
      - query.shape = `(B, K, L_q, C_k)`
      - key.shape = `(B, K, L_m, C_k)`
      - value.shape = `(B, K, L_m, C_v)`
      - out.shape = `(B, L_q, K * C_v)`
   
      We need the following einsums:
      - `'bnic,bnjc->bnij'`
      - `'bnij,bnjc->binc'`
   
   - layout = 'NTK'
      - query.shape = `(B, L_q, K, C_k)`
      - key.shape = `(B, L_m, K, C_k)`
      - value.shape = `(B, L_m, K, C_v)`
      - out.shape = `(B, L_q, K * C_v)`
   
      We need the following einsums:
      - `'binc,bjnc->bnij'`
      - `'bnij,bjnc->binc'`
   
   - layout = 'TNK'
      - query.shape = `(L_q, B, K, C_k)`
      - key.shape = `(L_m, B, K, C_k)`
      - value.shape = `(L_m, B, K, C_v)`
      - out.shape = `(L_q, B, K * C_v)`
   
      We need the following einsums:
      - `'ibnc,jbnc->bnij'`
      - `'bnij,jbnc->ibnc'`
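
To check that the einsum strings for the three layouts describe the same computation, here is a minimal NumPy sketch of multi-head attention in each layout. The softmax over the memory axis and the `1 / sqrt(C_k)` scaling are assumptions (standard scaled dot-product attention), not taken from GluonNLP, and the function names and sizes are hypothetical. NTK inputs are permuted into the NKT and TNK layouts, and all three outputs should agree once the batch axis is moved first.

```python
import numpy as np

# Hypothetical sizes: batch B, heads K, query length L_q, memory length L_m, C_k, C_v.
B, K, L_q, L_m, C_k, C_v = 2, 3, 4, 5, 6, 7
rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_ntk(query, key, value):
    """query (B, L_q, K, C_k), key (B, L_m, K, C_k), value (B, L_m, K, C_v)."""
    scores = np.einsum('binc,bjnc->bnij', query, key)       # (B, K, L_q, L_m)
    attn = softmax(scores / np.sqrt(query.shape[-1]), axis=-1)
    out = np.einsum('bnij,bjnc->binc', attn, value)         # (B, L_q, K, C_v)
    return out.reshape(out.shape[0], out.shape[1], -1)      # (B, L_q, K * C_v)

def attention_nkt(query, key, value):
    """query (B, K, L_q, C_k), key (B, K, L_m, C_k), value (B, K, L_m, C_v)."""
    scores = np.einsum('bnic,bnjc->bnij', query, key)       # (B, K, L_q, L_m)
    attn = softmax(scores / np.sqrt(query.shape[-1]), axis=-1)
    out = np.einsum('bnij,bnjc->binc', attn, value)         # (B, L_q, K, C_v)
    return out.reshape(out.shape[0], out.shape[1], -1)      # (B, L_q, K * C_v)

def attention_tnk(query, key, value):
    """query (L_q, B, K, C_k), key (L_m, B, K, C_k), value (L_m, B, K, C_v)."""
    scores = np.einsum('ibnc,jbnc->bnij', query, key)       # (B, K, L_q, L_m)
    attn = softmax(scores / np.sqrt(query.shape[-1]), axis=-1)
    out = np.einsum('bnij,jbnc->ibnc', attn, value)         # (L_q, B, K, C_v)
    return out.reshape(out.shape[0], out.shape[1], -1)      # (L_q, B, K * C_v)

# Build NTK inputs, then permute them into the other two layouts.
q = rng.standard_normal((B, L_q, K, C_k))
k = rng.standard_normal((B, L_m, K, C_k))
v = rng.standard_normal((B, L_m, K, C_v))
out_ntk = attention_ntk(q, k, v)
out_nkt = attention_nkt(*(x.transpose(0, 2, 1, 3) for x in (q, k, v)))
out_tnk = attention_tnk(*(x.transpose(1, 0, 2, 3) for x in (q, k, v)))
print(out_ntk.shape, np.allclose(out_ntk, out_nkt),
      np.allclose(out_ntk, out_tnk.transpose(1, 0, 2)))  # (2, 4, 21) True True
```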
   
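   In fact, `out = np.einsum('ibnc,jbnc->bnij', query, key)` can be implemented with a single `cublasGemmStridedBatched` call. Consider the `(b, n)`-th slice of the output:
   
   ```python
   out[b, n, :, :] = query[:, b, n, :] @ key[:, b, n, :].T
   ```
   
   Each slice is one GEMM of an `(L_q, C_k)` matrix with the transpose of an `(L_m, C_k)` matrix. Because `query` is contiguous with shape `(L_q, B, K, C_k)`, slice `(b, n)` starts `(b * K + n) * C_k` elements into the buffer and its rows are `B * K * C_k` elements apart, so all `B * K` slices share one batch stride and one leading dimension (and likewise for `key`). The following loop can therefore be computed by a single batched GEMM with those parameters:
   ```python
   for b in range(B):
       for n in range(K):
           out[b, n, :, :] = query[:, b, n, :] @ key[:, b, n, :].T
   ```
   This is the technique already used in `interleaved_matmul`, so we should be able to remove `interleaved_matmul` once einsum is accelerated.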
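
The stride arithmetic behind the single strided-batched GEMM can be emulated on the CPU. In this hypothetical NumPy sketch, `query` `(L_q, B, K, C)` and `key` `(L_m, B, K, C)` are the two einsum operands; `as_strided` builds a zero-copy view of each as `B * K` matrices with a constant batch stride of `C` elements and a row stride of `B * K * C` elements, and one batched `matmul` over those views reproduces `np.einsum('ibnc,jbnc->bnij', query, key)`. It only illustrates the strides; translating them into the column-major argument order that cuBLAS expects is not shown.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

L_q, L_m, B, K, C = 4, 5, 2, 3, 6  # hypothetical sizes
rng = np.random.default_rng(0)
query = np.ascontiguousarray(rng.standard_normal((L_q, B, K, C)))
key = np.ascontiguousarray(rng.standard_normal((L_m, B, K, C)))

def strided_batch_view(x):
    """View a contiguous (L, B, K, C) array as B*K matrices of shape (L, C)
    without copying. Matrix p = b*K + n starts p*C elements into the buffer
    (batch stride C), and consecutive rows are B*K*C elements apart (the
    leading dimension a strided-batched GEMM would be given)."""
    L, b, k, c = x.shape
    item = x.itemsize
    return as_strided(x, shape=(b * k, L, c),
                      strides=(c * item, b * k * c * item, item),
                      writeable=False)

q_batches = strided_batch_view(query)  # (B*K, L_q, C), no copy
k_batches = strided_batch_view(key)    # (B*K, L_m, C), no copy
# One batched GEMM: (B*K, L_q, C) @ (B*K, C, L_m) -> (B*K, L_q, L_m)
out = np.matmul(q_batches, np.swapaxes(k_batches, 1, 2)).reshape(B, K, L_q, L_m)

ref = np.einsum('ibnc,jbnc->bnij', query, key)
print(np.shares_memory(q_batches, query), np.allclose(out, ref))  # True True
```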
   
   @ptrendx @eric-haibin-lin 



