You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/18 10:06:27 UTC

[GitHub] [incubator-tvm] minminsun commented on a change in pull request #4353: [Perf] Enhance cudnn and cublas backend and enable TensorCore

minminsun commented on a change in pull request #4353: [Perf] Enhance cudnn and cublas backend and enable TensorCore
URL: https://github.com/apache/incubator-tvm/pull/4353#discussion_r347291047
 
 

 ##########
 File path: src/runtime/contrib/cublas/cublas.cc
 ##########
 @@ -124,35 +169,203 @@ struct CublasDgemmBatchOp {
   }
 };
 
+// Return true if (in_dtype -> out_dtype) is one of the mixed-precision pairs accepted here: int8->int32 (only when int_support is true), int8->fp32, fp16->fp32; false otherwise.
+bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_support = true) {
+  if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {  // int32 output path; disabled when int_support is false
+    return TypeMatch(in_dtype, kDLInt, 8);
+  } else if (TypeMatch(out_dtype, kDLFloat, 32)) {  // fp32 output accepts int8 or fp16 input
+    return TypeMatch(in_dtype, kDLInt, 8) ||
+           TypeMatch(in_dtype, kDLFloat, 16);
+  } else {  // any other output dtype (e.g. fp16) is rejected
+    return false;
+  }
+}
+
+inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) {
+  DLTensor *A = args[0];
+  DLTensor *B = args[1];
+  DLTensor *C = args[2];
+  bool transa = args[3];
+  bool transb = args[4];
+  CHECK_EQ(A->ndim, 2);
+  CHECK_EQ(B->ndim, 2);
+  CHECK_EQ(C->ndim, 2);
+
+  CHECK_EQ(ElementStride(A), 1);
+  CHECK_EQ(ElementStride(B), 1);
+  CHECK_EQ(ElementStride(C), 1);
+
+  CHECK(TypeEqual(A->dtype, B->dtype));
+
+  // C can never be transposed.
+  CHECK(!IsInPlaceTransposed(C));
+
+  // Reversed strides indicates an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type";
+  CHECK(!TypeMatch(A->dtype, kDLInt, 8) ||
+      ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
+  CHECK(!TypeMatch(B->dtype, kDLInt, 8) ||
+      ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm";
+  double alpha = args.size() > 5 ? args[5] : 1.0;
+  double beta = args.size() > 6 ? args[6] : 0.0;
+
+  cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
+  cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+  void *alpha_ptr = nullptr, *beta_ptr = nullptr;
+  auto alpha_int = static_cast<int32_t>(alpha);
+  auto beta_int = static_cast<int32_t>(beta);
+  auto alpha_float = static_cast<float>(alpha);
+  auto beta_float = static_cast<float>(beta);
+  if (C->dtype.code == kDLInt) {
+    alpha_ptr = &alpha_int;
+    beta_ptr = &beta_int;
+  } else if (C->dtype.code == kDLFloat) {
+    alpha_ptr = &alpha_float;
+    beta_ptr = &beta_float;
+  }
+
+  auto A_data = reinterpret_cast<void *>(static_cast<char *>(A->data) + A->byte_offset);
+  auto B_data = reinterpret_cast<void *>(static_cast<char *>(B->data) + B->byte_offset);
+  auto C_data = reinterpret_cast<void *>(static_cast<char *>(C->data) + C->byte_offset);
+
+  CHECK_CUBLAS_ERROR(cublasGemmEx(hdl,
+                                 BooleanToTranspose(transb),
+                                 BooleanToTranspose(transa),
+                                 ColumnCount(B, transb),
+                                 RowCount(A, transa),
+                                 ColumnCount(A, transa),
+                                 alpha_ptr,
+                                 B_data, cuda_in_type, ColumnStride(B),
+                                 A_data, cuda_in_type, ColumnStride(A),
+                                 beta_ptr,
+                                 C_data, cuda_out_type, ColumnStride(C),
+                                 cuda_out_type, algo));
 
 Review comment:
  The second-to-last argument of cublasGemmEx is the computation type. When both the input and output types are fp16, we noticed that an fp16 computation type gives much lower precision than an fp32 computation type. So setting the computation type equal to the output type here may cause a loss of precision when the output type is fp16 or int8.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services