You are viewing a plain text version of this content. The canonical link for it is here.
Posted to by GitBox <> on 2018/11/21 21:09:15 UTC

[GitHub] eric-haibin-lin closed pull request #13336: GEMM Tensor Core Support

eric-haibin-lin closed pull request #13336: GEMM Tensor Core Support

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d7c3c651916..4e63f61f105 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -315,9 +315,57 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
                               &beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0))) \
-  LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
   LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
+  #if CUDA_VERSION < 9010
+  LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
+  #else
+    // Batched float32 GEMM: for each index i along axis 0,
+    //   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i],
+    // where op() transposes when tA / tB is set. cuBLAS is column-major, so the
+    // operands are handed to it as (B, A) with dimensions (C.size(2), C.size(1)),
+    // which yields the row-major result without explicit transposes.
+    //
+    // When GetEnvAllowTensorCore() and GetEnvAllowTensorCoreConversion() both
+    // return true (MXNET_CUDA_ALLOW_TENSOR_CORE /
+    // MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION) and the device has compute
+    // capability >= 5.0, the call goes through cublasGemmStridedBatchedEx with
+    // tensor-op math so cuBLAS may implicitly use half precision internally.
+    // Otherwise it falls back to plain cublasSgemmStridedBatched.
+    template <>
+    inline void linalg_batch_gemm<gpu, float>(const Tensor<gpu, 3, float>& A,
+                                              const Tensor<gpu, 3, float>& B,
+                                              const Tensor<gpu, 3, float>& C,
+                                              float alpha, float beta, bool tA,
+                                              bool tB, Stream<gpu>* s) {
+      using namespace mxnet;
+      using mshadow::gpu;
+      CHECK_NOTNULL(s);
+      linalg_check_batch_size(A.size(0), B.size(0), C.size(0));
+      check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB);
+      auto blas_handle = Stream<gpu>::GetBlasHandle(s);
+      bool use_tensor_ops =
+          GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion();
+      using namespace mshadow::cuda;
+      auto cublas_math_mode =
+          use_tensor_ops ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
+      // The math mode is a property of the (shared) handle: switch it for this
+      // call only and put the previous mode back afterwards.
+      // NOTE(review): if a CUBLAS_CALL below fails by throwing, the previous
+      // mode is not restored — consider a scope guard if that matters.
+      auto previous_math_mode = SetCublasMathMode(blas_handle, cublas_math_mode);
+      // cublasGemmStridedBatchedEx is only supported for GPU with architecture
+      // capabilities equal or greater than 5.0. Fall back to
+      // cublasSgemmStridedBatched, which doesn't support implicit conversion
+      // to half-precision to use TensorCores
+      auto cc_major = (s->prop).major;
+      if ((cc_major >= 5) && use_tensor_ops) {
+        // All inputs, outputs and the compute type stay float32; the
+        // tensor-op algorithm is what permits the internal fp16 conversion.
+        CUBLAS_CALL(cublasGemmStridedBatchedEx(
+            blas_handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(2), C.size(1),
+            (tB ? B.size(2) : B.size(1)), &alpha, B.dptr_, CUDA_R_32F,
+            B.stride_, B.size(1) * B.stride_, A.dptr_, CUDA_R_32F, A.stride_,
+            A.size(1) * A.stride_, &beta, C.dptr_, CUDA_R_32F, C.stride_,
+            C.size(1) * C.stride_, A.size(0), CUDA_R_32F,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      } else {
+        CUBLAS_CALL(cublasSgemmStridedBatched(
+            blas_handle, (tB ? CUBLAS_OP_T : CUBLAS_OP_N),
+            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), C.size(2), C.size(1),
+            (tB ? B.size(2) : B.size(1)), &alpha, B.dptr_, B.stride_,
+            B.size(1) * B.stride_, A.dptr_, A.stride_, A.size(1) * A.stride_,
+            &beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0)));
+      }
+      SetCublasMathMode(blas_handle, previous_math_mode);
+    }
+  #endif  // CUDA_VERSION < 9010
 // Version where matrix rows are given by second axis.
 #define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
   template<> inline \
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 9ba764904e1..82b03c0fafc 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
 .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
+When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
+and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
+pseudo-float16 precision (float32 math with float16 I/O) in order to use
+Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
 **Vanilla RNN**
 Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index f8a130d0ce4..0f3c2954a0f 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -62,6 +62,11 @@ calls. For example let *A*, *B*, *C* be 5 dimensional tensors. Then gemm(*A*, *B
 without the overhead of the additional swapaxis operations.
+When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
+and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
+pseudo-float16 precision (float32 math with float16 I/O) in order to use
+Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
 .. note:: The operator supports float32 and float64 data types only.
@@ -134,6 +139,11 @@ calls. For example let *A*, *B* be 5 dimensional tensors. Then gemm(*A*, *B*, ax
 without the overhead of the additional swapaxis operations.
+When the input data is of type float32 and the environment variables MXNET_CUDA_ALLOW_TENSOR_CORE
+and MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION are set to 1, this operator will try to use
+pseudo-float16 precision (float32 math with float16 I/O) in order to use
+Tensor Cores on suitable NVIDIA GPUs. This can sometimes give significant speedups.
 .. note:: The operator supports float32 and float64 data types only.
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1bf9ca0237a..6babd50633e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -29,6 +29,7 @@
 from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
 import unittest
+import os
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
     dshape = (N, T, I)
@@ -5224,46 +5225,10 @@ def test_deformable_psroipooling():
                                                grad_nodes=grad_nodes, ctx=mx.gpu(0))
-# Helper functions for test_laop
-def _make_symm_symbol(a, ndims):
-    assert ndims >= 2
-    tr_shape = list(range(ndims))
-    tr_shape[-1] = ndims-2
-    tr_shape[-2] = ndims-1
-    tr_shape = tuple(tr_shape)
-    return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))
-def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
-    assert ndims >= 2
-    # The last two dimensions must both be m
-    # Create mask for lower triangle and diagonal
-    index = mx.sym.arange(start=0, stop=m, step=1, dtype=np.int32)
-    lt_mask = mx.sym.one_hot(index, depth=m, dtype=dtype)
-    for j in range(1, m):
-        part1 = mx.sym.zeros(shape=(j, m), dtype=dtype)
-        index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
-        part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
-        lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
-    if not lower:
-        lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
-        lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
-    shp = tuple([1]*(ndims-2) + [m, m])
-    lt_mask = mx.sym.reshape(lt_mask, shape=shp)
-    return mx.sym.broadcast_mul(a, lt_mask)
-# @ankkhedia: Getting rid of fixed seed as flakiness could not be reproduced
-# tracked at
-def test_laop():
-    dtype = np.float64
-    rtol_fw = 1e-7
-    atol_fw = 1e-9
+def _gemm_test_helper(dtype, grad_check, rtol_fw = 1e-7, atol_fw = 1e-9):
     num_eps = 1e-6
     rtol_bw = 1e-5
     atol_bw = 1e-6
-    # enable numerical checking of gradients
-    grad_check = 1
     data1 = mx.symbol.Variable('data1')
     data2 = mx.symbol.Variable('data2')
@@ -5278,15 +5243,14 @@ def test_laop():
     rep_3x = lambda a, m, n :\
         np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))
-    # Test gemm separately from other la-operators.
     shape1 = (2, 3)
     shape2 = (3, 2)
     shape3 = (3, 3)
     shape4 = (2, 2)
-    data_in1 = np.random.uniform(1, 10, shape1)
-    data_in2 = np.random.uniform(1, 10, shape2)
-    data_in3 = np.random.uniform(1, 10, shape3)
-    data_in4 = np.random.uniform(1, 10, shape4)
+    data_in1 = np.random.uniform(1, 10, shape1).astype(dtype)
+    data_in2 = np.random.uniform(1, 10, shape2).astype(dtype)
+    data_in3 = np.random.uniform(1, 10, shape3).astype(dtype)
+    data_in4 = np.random.uniform(1, 10, shape4).astype(dtype)
     # Check all transpositions of gemm operator.
     data_in1_t = np.transpose(data_in1)
     data_in2_t = np.transpose(data_in2)
@@ -5388,7 +5352,71 @@ def test_laop():
     if grad_check == 1:
         check_grad(test_gemm, [a2, b2])
-    # Now test all the other operators.
+# Test gemm separately from other la-operators.
+def test_gemm():
+    _gemm_test_helper(np.float64, True)
+    _gemm_test_helper(np.float32, False, rtol_fw = 1e-5, atol_fw = 1e-7)
+    if default_context().device_type == 'gpu':
+        _gemm_test_helper(np.float32, False, rtol_fw = 2e-5, atol_fw = 2e-7)
+# Helper functions for test_laop
+def _make_symm_symbol(a, ndims):
+    assert ndims >= 2
+    tr_shape = list(range(ndims))
+    tr_shape[-1] = ndims-2
+    tr_shape[-2] = ndims-1
+    tr_shape = tuple(tr_shape)
+    return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))
+def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
+    assert ndims >= 2
+    # The last two dimensions must both be m
+    # Create mask for lower triangle and diagonal
+    index = mx.sym.arange(start=0, stop=m, step=1, dtype=np.int32)
+    lt_mask = mx.sym.one_hot(index, depth=m, dtype=dtype)
+    for j in range(1, m):
+        part1 = mx.sym.zeros(shape=(j, m), dtype=dtype)
+        index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
+        part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
+        lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
+    if not lower:
+        lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
+        lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
+    shp = tuple([1]*(ndims-2) + [m, m])
+    lt_mask = mx.sym.reshape(lt_mask, shape=shp)
+    return mx.sym.broadcast_mul(a, lt_mask)
+# @ankkhedia: Getting rid of fixed seed as flakiness could not be reproduced
+# tracked at
+def test_laop():
+    dtype = np.float64
+    rtol_fw = 1e-7
+    atol_fw = 1e-9
+    num_eps = 1e-6
+    rtol_bw = 1e-5
+    atol_bw = 1e-6
+    # enable numerical checking of gradients
+    grad_check = 1
+    data1 = mx.symbol.Variable('data1')
+    data2 = mx.symbol.Variable('data2')
+    data3 = mx.symbol.Variable('data3')
+    check_fw = lambda sym, location, expected :\
+        check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
+                               atol=atol_fw, dtype=dtype)
+    check_grad = lambda sym, location:\
+        check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
+                               atol=atol_bw, dtype=dtype)
+    rep_3x = lambda a, m, n :\
+        np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))
     for lower in [True, False]:
         upper = not lower


This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:

With regards,
Apache Git Services