Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/15 20:12:26 UTC

[GitHub] piiswrong closed pull request #8846: Batching improvements for GEMM/TRSM operators and full MKL usage docs.

URL: https://github.com/apache/incubator-mxnet/pull/8846

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/MKL_README.md b/MKL_README.md
index 80a31c9a40..0f97416ac3 100644
--- a/MKL_README.md
+++ b/MKL_README.md
@@ -1,3 +1,22 @@
+# Full MKL Installation
+
+## Build/Install MXNet with a full MKL installation:
+Installing the full version of MKL and building against it enables MKL support for all operators under the linalg namespace.
+
+  1. Download and install the latest full MKL version following the instructions on the [Intel website](https://software.intel.com/en-us/articles/intel-mkl-111-install-guide).
+
+  2. Set USE_BLAS=mkl in make/config.mk
+
+        2.1 Set ADD_LDFLAGS=-L<path/to/mkl/lib/folder> (ex. ADD_LDFLAGS=-L/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/lib)
+
+        2.2 Set ADD_CFLAGS=-I<path/to/mkl/include/folder> (ex. ADD_CFLAGS=-I/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/include)
+
+  3. Run 'make -j $(nproc)'
+
+  4. Navigate into the python directory
+
+  5. Run 'sudo python setup.py install'
+
 # MKL2017 PLUGIN
 
 MKL2017 is an INTEL released library to accelerate Deep Neural Network (DNN) applications on Intel architecture.
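
For reference, the standalone sketch below (not part of this PR) can be used to verify a full MKL installation and the include/library paths configured in the steps above. It is a minimal C++ program that calls cblas_sgemm directly; the compile command in the comment assumes the example installation path from those steps and links against MKL's single dynamic library, mkl_rt.

    // check_mkl.cc -- build sketch (paths are the example ones from the steps above):
    //   g++ check_mkl.cc -I/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/include \
    //       -L/opt/intel/compilers_and_libraries_2018.0.128/linux/mkl/lib -lmkl_rt -o check_mkl
    #include <mkl.h>
    #include <cstdio>

    int main() {
      // 2x2 single-precision GEMM through MKL's CBLAS interface: C = 1.0 * A * B + 0.0 * C
      float A[4] = {1, 2, 3, 4};
      float B[4] = {5, 6, 7, 8};
      float C[4] = {0, 0, 0, 0};
      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                  2, 2, 2, 1.0f, A, 2, B, 2, 0.0f, C, 2);
      std::printf("C = [%g %g; %g %g]\n", C[0], C[1], C[2], C[3]);
      return 0;
    }
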
diff --git a/make/config.mk b/make/config.mk
index 9f7564b88f..a322fee062 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -110,21 +110,13 @@ USE_LAPACK = 1
 # path to lapack library in case of a non-standard installation
 USE_LAPACK_PATH =
 
-# by default, disable lapack when using MKL
-# switch on when there is a full installation of MKL available (not just MKL2017/MKL_ML)
-ifeq ($(USE_BLAS), mkl)
-USE_LAPACK = 0
-endif
-
 # add path to intel library, you may need it for MKL, if you did not add the path
 # to environment variable
 USE_INTEL_PATH = NONE
 
 # If use MKL only for BLAS, choose static link automatically to allow python wrapper
-ifeq ($(USE_MKL2017), 0)
 ifeq ($(USE_BLAS), mkl)
 USE_STATIC_MKL = 1
-endif
 else
 USE_STATIC_MKL = NONE
 endif
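
The config.mk change above stops forcing USE_LAPACK=0 when USE_BLAS=mkl, because a full MKL installation also provides the LAPACK routines that the linalg operators rely on. As an illustration only (not part of the PR), here is a minimal sketch of such a LAPACK call through MKL's LAPACKE interface, a Cholesky factorization of the kind the linalg potrf operator performs:

    // Cholesky factorization of a 2x2 SPD matrix via MKL's LAPACKE interface.
    // Links the same way as the GEMM sketch above (-lmkl_rt).
    #include <mkl_lapacke.h>
    #include <cstdio>

    int main() {
      // Row-major symmetric positive-definite matrix.
      float a[4] = {4.0f, 2.0f,
                    2.0f, 3.0f};
      lapack_int info = LAPACKE_spotrf(LAPACK_ROW_MAJOR, 'L', 2, a, 2);
      // On success, the lower triangle of a holds the Cholesky factor L.
      std::printf("info = %d, L = [%g 0; %g %g]\n",
                  static_cast<int>(info), a[0], a[2], a[3]);
      return 0;
    }
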
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index b3e6573f78..b2a672ffd9 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -69,14 +69,14 @@ void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2
                 A.dptr_, A.stride_, B.dptr_, B.stride_, beta, C.dptr_, C.stride_); \
 }
 
-#define LINALG_CPU_BATCH_GEMM(DType) \
+#define LINALG_XPU_BATCH_GEMM(xpu, DType) \
 template<> inline \
-void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
-                                   const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
-                                   bool tA, bool tB, Stream<cpu> *s) { \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \
+                                   const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
   linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
   for (index_t i = 0; i < A.size(0); ++i) { \
-    linalg_gemm(A[i], B[i], C[i], alpha, beta, tA, tB); \
+    linalg_gemm(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
   } \
 }
 
@@ -90,11 +90,11 @@ void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2
   LOG(FATAL) << "linalg_gemm (without req arg) not implemented by mxnet for cpu, needs cblas!"; \
 }
 
-#define LINALG_CPU_BATCH_GEMM(DType) \
+#define LINALG_XPU_BATCH_GEMM(xpu, DType) \
 template<> inline \
-void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
-                                   const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
-                                   bool tA, bool tB, Stream<cpu> *s) { \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \
+                                   const Tensor<xpu, 3, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
   LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
 }
 
@@ -103,8 +103,8 @@ void linalg_batch_gemm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<
 LINALG_CPU_GEMM(sgemm, float)
 LINALG_CPU_GEMM(dgemm, double)
 
-LINALG_CPU_BATCH_GEMM(float)
-LINALG_CPU_BATCH_GEMM(double)
+LINALG_XPU_BATCH_GEMM(cpu, float)
+LINALG_XPU_BATCH_GEMM(cpu, double)
 
 // Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
 template<> inline
@@ -119,13 +119,6 @@ void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half:
 
 #ifdef __CUDACC__
 
-template<typename DType>
-__global__ void linalgCollectBatchOffsetsGPU(DType *a[], DType* b, int stride, int N) {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
-      a[i] = b + i * stride;
-    }
-}
-
 // cublas col-major processing accounted for by switching first two operands
 
 #define LINALG_GPU_GEMM(fname, DType) \
@@ -195,43 +188,36 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
 #endif  // CUDA_VERSION >= 7050
 }
 
-
+// As of CUDA 8, cuBLAS provides a strided version of batched GEMM.
+#if CUDA_VERSION < 8000
+  LINALG_XPU_BATCH_GEMM(gpu, float)
+  LINALG_XPU_BATCH_GEMM(gpu, double)
+#else
 #define LINALG_GPU_BATCH_GEMM(fname, DType) \
-template<> inline \
-void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
-                                   const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
-                                   bool tA, bool tB, Stream<gpu> *s) { \
-  using namespace mxnet; \
-  using mshadow::gpu; \
-  CHECK_NOTNULL(s); \
-  linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
-  check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
-  Storage::Handle offsetsA, offsetsB, offsetsC; \
-  offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \
-  offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \
-  offsetsC = Storage::Get()->Alloc(sizeof(DType*)*C.size(0), Context::GPU()); \
-  using namespace mshadow::cuda; \
-  int ngrid = std::min(kMaxGridNum, \
-                       static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \
-  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
-    (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \
-  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
-    (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, B.size(0)); \
-  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
-    (static_cast<DType **>(offsetsC.dptr), C.dptr_, C.size(1)*C.stride_, C.size(0)); \
-  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
-                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
-                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
-                            C.size(2), C.size(1), (tB ? B.size(2) : B.size(1)), \
-                            &alpha, static_cast<const DType **>(offsetsB.dptr), B.stride_, \
-                            static_cast<const DType **>(offsetsA.dptr),  A.stride_, \
-                            &beta, static_cast<DType **>(offsetsC.dptr), C.stride_, A.size(0))) \
-  Storage::Get()->Free(offsetsA); \
-  Storage::Get()->Free(offsetsB); \
-  Storage::Get()->Free(offsetsC); \
-}
-LINALG_GPU_BATCH_GEMM(SgemmBatched, float)
-LINALG_GPU_BATCH_GEMM(DgemmBatched, double)
+  template<> inline \
+  void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                     const Tensor<gpu, 3, DType>& B, \
+                                     const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
+                                     bool tA, bool tB, Stream<gpu> *s) { \
+    using namespace mxnet; \
+    using mshadow::gpu; \
+    CHECK_NOTNULL(s); \
+    linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+    check_gemm(A[0], B[0], C[0], alpha, beta, tA, tB); \
+    using namespace mshadow::cuda; \
+    CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                              (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                              (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                              C.size(2), C.size(1), (tB ? B.size(2) : B.size(1)), \
+                              &alpha, B.dptr_, B.stride_, B.size(1) * B.stride_, \
+                              A.dptr_,  A.stride_, A.size(1) * A.stride_, \
+                              &beta, C.dptr_, C.stride_, C.size(1) * C.stride_, A.size(0))) \
+  }
+
+  LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
+  LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
+
+#endif  // CUDA < 8000
 
 #endif  // __CUDACC__
 
@@ -266,13 +252,13 @@ void linalg_trsm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2
                 A.stride_, B.dptr_, B.stride_); \
 }
 
-#define LINALG_CPU_BATCH_TRSM(DType) \
+#define LINALG_XPU_BATCH_TRSM(xpu, DType) \
 template<> inline \
-void linalg_batch_trsm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
-                   DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \
+void linalg_batch_trsm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \
+                   DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) { \
   linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \
   for (index_t i = 0; i < A.size(0); ++i) { \
-    linalg_trsm(A[i], B[i], alpha, rightside, lower, transpose); \
+    linalg_trsm(A[i], B[i], alpha, rightside, lower, transpose, s); \
   } \
 }
 
@@ -285,10 +271,10 @@ void linalg_trsm<cpu, DType>(const Tensor<cpu, 2, DType>& A, const Tensor<cpu, 2
   LOG(FATAL) << "linalg_trsm not implemented, needs cblas!"; \
 }
 
-#define LINALG_CPU_BATCH_TRSM(DType) \
+#define LINALG_XPU_BATCH_TRSM(xpu, DType) \
 template<> inline \
-void linalg_batch_trsm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
-                   DType alpha, bool rightside, bool lower, bool transpose, Stream<cpu> *s) { \
+void linalg_batch_trsm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B, \
+                   DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) { \
   LOG(FATAL) << "linalg_batch_trsm not implemented, needs cblas!"; \
 }
 
@@ -297,8 +283,8 @@ void linalg_batch_trsm<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<
 LINALG_CPU_TRSM(strsm, float)
 LINALG_CPU_TRSM(dtrsm, double)
 
-LINALG_CPU_BATCH_TRSM(float)
-LINALG_CPU_BATCH_TRSM(double)
+LINALG_XPU_BATCH_TRSM(cpu, float)
+LINALG_XPU_BATCH_TRSM(cpu, double)
 
 #ifdef __CUDACC__
 
@@ -322,37 +308,8 @@ void linalg_trsm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
 LINALG_GPU_TRSM(Strsm, float)
 LINALG_GPU_TRSM(Dtrsm, double)
 
-#define LINALG_GPU_BATCH_TRSM(fname, DType) \
-template<> inline \
-void linalg_batch_trsm<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
-                   DType alpha, bool rightside, bool lower, bool transpose, Stream<gpu> *s) { \
-  using namespace mxnet; \
-  using mshadow::gpu; \
-  CHECK_NOTNULL(s); \
-  linalg_check_batch_size(A.size(0), B.size(0), B.size(0)); \
-  check_trsm(A[0], B[0], alpha, rightside, lower, transpose); \
-  Storage::Handle offsetsA, offsetsB; \
-  offsetsA = Storage::Get()->Alloc(sizeof(DType*)*A.size(0), Context::GPU()); \
-  offsetsB = Storage::Get()->Alloc(sizeof(DType*)*B.size(0), Context::GPU()); \
-  using namespace mshadow::cuda; \
-  int ngrid = std::min(kMaxGridNum, \
-                       static_cast<int>((A.size(0) + kBaseThreadNum - 1) / kBaseThreadNum)); \
-  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
-    (static_cast<DType **>(offsetsA.dptr), A.dptr_, A.size(1)*A.stride_, A.size(0)); \
-  linalgCollectBatchOffsetsGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
-    (static_cast<DType **>(offsetsB.dptr), B.dptr_, B.size(1)*B.stride_, A.size(0)); \
-  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
-                            (rightside ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT), \
-                            (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
-                            (transpose ? CUBLAS_OP_T : CUBLAS_OP_N), \
-                            CUBLAS_DIAG_NON_UNIT, B.size(2), B.size(1), &alpha, \
-                            static_cast<const DType **>(offsetsA.dptr), A.stride_, \
-                            static_cast<DType **>(offsetsB.dptr), B.stride_, A.size(0))); \
-  Storage::Get()->Free(offsetsA); \
-  Storage::Get()->Free(offsetsB); \
-}
-LINALG_GPU_BATCH_TRSM(StrsmBatched, float)
-LINALG_GPU_BATCH_TRSM(DtrsmBatched, double)
+LINALG_XPU_BATCH_TRSM(gpu, float)
+LINALG_XPU_BATCH_TRSM(gpu, double)
 
 #endif  // __CUDACC__
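
For context on the GPU changes above: prior to this PR, batched GEMM/TRSM on the GPU used the pointer-array cublas*Batched interfaces, which required allocating device arrays of per-matrix pointers and filling them with the linalgCollectBatchOffsetsGPU kernel. On CUDA >= 8 the GEMM path now uses the strided-batched interface instead, which takes a single base pointer plus a fixed stride between consecutive matrices. A minimal standalone sketch of that call follows (illustrative only; the names, shapes, and row-major layout convention are assumptions, and error handling is left to the caller):

    #include <cublas_v2.h>

    // Multiplies `batch` independent row-major matrix pairs stored back to back in
    // device memory: for each i in [0, batch), C_i = alpha * A_i * B_i + beta * C_i,
    // where A_i is m x k, B_i is k x n, and C_i is m x n.
    cublasStatus_t batched_gemm_strided(cublasHandle_t handle,
                                        const float* dA, const float* dB, float* dC,
                                        int m, int n, int k, int batch) {
      const float alpha = 1.0f, beta = 0.0f;
      // cuBLAS is column-major, so the operands are passed swapped (B first) to
      // obtain the row-major product -- the same trick the MXNet macros above use.
      return cublasSgemmStridedBatched(
          handle,
          CUBLAS_OP_N, CUBLAS_OP_N,
          n, m, k,
          &alpha,
          dB, n, static_cast<long long>(k) * n,   // B_i, leading dim, stride to B_{i+1}
          dA, k, static_cast<long long>(m) * k,   // A_i, leading dim, stride to A_{i+1}
          &beta,
          dC, n, static_cast<long long>(m) * n,   // C_i, leading dim, stride to C_{i+1}
          batch);
    }

Because the strided variant addresses each matrix as base + i * stride, no device-side pointer array has to be built, which is why the offset-collection kernel and the temporary Storage allocations are removed in this PR. Batched TRSM on the GPU correspondingly falls back to the shared LINALG_XPU_BATCH_TRSM loop over per-matrix linalg_trsm calls.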
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services