Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/20 11:14:11 UTC

[GitHub] sbodenstein opened a new pull request #13336: GEMM Tensor Core Support

URL: https://github.com/apache/incubator-mxnet/pull/13336
 
 
   This PR lets the `float32` `linalg.gemm` and `linalg.gemm2` ops use Tensor Cores via implicit casting to `float16`. The motivation is that MXNet currently has no way of doing `GEMM` with Tensor Core support: `linalg.gemm` and `linalg.gemm2` don't support `float16`, and `batch_dot` is very slow for `float16` (#11796). A fast version of this operator is critical for applications such as Transformers.
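To make the numerical effect of the implicit cast concrete, here is a small pure-Python sketch (not MXNet or cuBLAS code). It assumes the converted GEMM rounds each `float32` input to `float16` while keeping products and sums in `float32`, which is how V100 Tensor Cores accumulate; the helper names `to_f16`, `to_f32`, `gemm_f32` and `gemm_implicit_f16` are hypothetical, and the real rounding order inside cuBLAS will differ.

```python
import random
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def gemm_f32(a, b):
    """Plain GEMM C = A @ B, emulating float32 by rounding after every op."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0
            for p in range(k):
                acc = to_f32(acc + to_f32(a[i][p] * b[p][j]))
            c[i][j] = acc
    return c


def gemm_implicit_f16(a, b):
    """Hypothetical emulation of the implicit conversion: inputs are rounded
    to float16, while products and the accumulator stay in float32."""
    a16 = [[to_f16(x) for x in row] for row in a]
    b16 = [[to_f16(x) for x in row] for row in b]
    return gemm_f32(a16, b16)


# Usage: compare both paths on small random float32 matrices.
random.seed(0)
n = 16
a = [[to_f32(random.gauss(0, 1)) for _ in range(n)] for _ in range(n)]
b = [[to_f32(random.gauss(0, 1)) for _ in range(n)] for _ in range(n)]
ref = gemm_f32(a, b)
approx = gemm_implicit_f16(a, b)
max_err = max(abs(x - y) for r1, r2 in zip(ref, approx) for x, y in zip(r1, r2))
print(f"max abs difference: {max_err:.2e}")
```

The result is still returned as `float32`; only the precision of the inputs to each multiply is reduced, which is the trade-off the environment variable opts into.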
   
   A simple benchmark shows significant speedups on a V100:
   ```
   import os
   import time

   import mxnet as mx


   def benchmark(dtype, dev, shape):
       # Random batched input matrices on the target device
       a = mx.nd.random.normal(0, 1, shape=shape, ctx=dev, dtype=dtype)
       b = mx.nd.random.normal(0, 1, shape=shape, ctx=dev, dtype=dtype)
       s = mx.symbol.linalg.gemm2(mx.sym.Variable("a"), mx.sym.Variable("b"),
                                  transpose_b=False, alpha=1.0)
       e = s.bind(dev, {"a": a, "b": b})

       # Warmup; asnumpy() blocks until the asynchronous computation finishes
       e.forward()
       e.outputs[0].asnumpy()

       begin = time.time()
       for _ in range(100):
           e.forward()
       e.outputs[0].asnumpy()
       end = time.time()
       print(end - begin)


   shape = (64, 1024, 1024)
   # Baseline: plain float32 GEMM
   os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "0"
   benchmark("float32", mx.gpu(), shape)
   # Let cuBLAS cast to float16 internally and use Tensor Cores
   os.environ["MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION"] = "1"
   benchmark("float32", mx.gpu(), shape)
   ```
   gives `1.153s` and `0.3957s` respectively, a 2.9x speedup. It is also around 50% faster to let cuBLAS do the casting than to use MXNet ops to cast to `float16`, run the GEMM in `float16`, and cast back to `float32`.
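As a rough sanity check, the arithmetic below turns these timings into throughput. It assumes each batched `gemm2` call on shape `(64, 1024, 1024)` performs `2 * 64 * 1024**3` floating-point operations and ignores the single `asnumpy()` copy inside the timed region, so both rates are somewhat underestimated.

```python
# Numbers taken from the benchmark above; everything else is derived.
BATCH, N = 64, 1024       # shape = (64, 1024, 1024), square matrices
ITERATIONS = 100          # forward() calls inside the timed loop
t_fp32 = 1.153            # seconds, MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION=0
t_tensor_core = 0.3957    # seconds, MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION=1

# An (N x N) @ (N x N) GEMM needs N**3 multiply-adds, i.e. 2 * N**3 flops.
flops_per_forward = 2 * BATCH * N**3
total_flops = ITERATIONS * flops_per_forward

speedup = t_fp32 / t_tensor_core
tflops_fp32 = total_flops / t_fp32 / 1e12
tflops_tensor_core = total_flops / t_tensor_core / 1e12

print(f"speedup: {speedup:.2f}x")                          # speedup: 2.91x
print(f"float32 GEMM: {tflops_fp32:.1f} TFLOP/s")          # float32 GEMM: 11.9 TFLOP/s
print(f"with conversion: {tflops_tensor_core:.1f} TFLOP/s")  # with conversion: 34.7 TFLOP/s
```

NVIDIA's published V100 (SXM2) peaks are about 15.7 TFLOP/s for FP32 and 125 TFLOP/s for Tensor Cores, so only the converted run exceeds the FP32 peak, consistent with it running on Tensor Cores.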
   
   
   
