You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ar...@apache.org on 2019/05/20 13:40:50 UTC

[incubator-mxnet] branch master updated: Add matrix inversion operator in linalg (#14963)

This is an automated email from the ASF dual-hosted git repository.

arcadiaphy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cbfe48  Add matrix inversion operator in linalg (#14963)
3cbfe48 is described below

commit 3cbfe48c27b370babf7994cd4ffdd5d158a5a8ed
Author: Wang Jiajun <wa...@gmail.com>
AuthorDate: Mon May 20 08:40:07 2019 -0500

    Add matrix inversion operator in linalg (#14963)
    
    * add inverse cpu
    
    * add comment
    
    * add inverse backward cpu
    
    * add inverse gpu
    
    * able to compile
    
    * fix
    
    * fix
    
    * guard for lower version cuda
    
    * update docs
    
    * update docs
    
    * fix misaligned memory
    
    * add test
    
    * fix lint
    
    * fix android
    
    * fix indent
    
    * change transfer gradient
    
    * fix
    
    * refactor test
    
    * delete unnecessary copy
    
    * trigger CI
    
    * fix test
---
 docs/api/python/symbol/linalg.md       |   1 +
 src/operator/c_lapack_api.cc           |  26 +++-
 src/operator/c_lapack_api.h            | 100 ++++++++++++-
 src/operator/linalg.h                  |  49 +++++++
 src/operator/linalg_impl.h             | 254 ++++++++++++++++++++++++++++++++-
 src/operator/tensor/la_op-inl.h        |  40 +++++-
 src/operator/tensor/la_op.cc           |  50 +++++++
 src/operator/tensor/la_op.cu           |   6 +
 src/operator/tensor/la_op.h            |  15 ++
 tests/python/unittest/test_operator.py |  28 ++++
 10 files changed, 556 insertions(+), 13 deletions(-)

diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md
index 5b467b5..436bab7 100644
--- a/docs/api/python/symbol/linalg.md
+++ b/docs/api/python/symbol/linalg.md
@@ -59,6 +59,7 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
     makediag
     extracttrian
     maketrian
+    inverse
 ```
 
 ## API Reference
diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc
index c6293bf..33a5b08 100644
--- a/src/operator/c_lapack_api.cc
+++ b/src/operator/c_lapack_api.cc
@@ -36,15 +36,29 @@
 
   #define MXNET_LAPACK_CWRAPPER2(func, dtype) \
   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
-                                 int lda, dtype* tau, dtype* work, int lwork) { \
+                          int lda, dtype* tau, dtype* work, int lwork) { \
     LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
     return 1; \
   }
 
   #define MXNET_LAPACK_CWRAPPER3(func, dtype) \
   int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
-                                 int lda, dtype *w, dtype *work, int lwork, \
-                                 int *iwork, int liwork) { \
+                          int lda, dtype *w, dtype *work, int lwork, \
+                          int *iwork, int liwork) { \
+    LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
+    return 1; \
+  }
+
+  #define MXNET_LAPACK_CWRAPPER4(func, dtype) \
+  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
+                          dtype *a, int lda, int *ipiv) { \
+    LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
+    return 1; \
+  }
+
+  #define MXNET_LAPACK_CWRAPPER5(func, dtype) \
+  int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
+                          int *ipiv, dtype *work, int lwork) { \
     LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
     return 1; \
   }
@@ -69,4 +83,10 @@
 
   MXNET_LAPACK_CWRAPPER3(ssyevd, float)
   MXNET_LAPACK_CWRAPPER3(dsyevd, double)
+
+  MXNET_LAPACK_CWRAPPER4(sgetrf, float)
+  MXNET_LAPACK_CWRAPPER4(dgetrf, double)
+
+  MXNET_LAPACK_CWRAPPER5(sgetri, float)
+  MXNET_LAPACK_CWRAPPER5(dgetri, double)
 #endif  // MSHADOW_USE_MKL == 0
diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h
index cd69775..c63229c 100644
--- a/src/operator/c_lapack_api.h
+++ b/src/operator/c_lapack_api.h
@@ -119,6 +119,30 @@ extern "C" {
 
   MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
   MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)
+
+  #ifdef __ANDROID__
+    #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
+      int func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
+  #else
+    #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
+      void func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
+  #endif
+
+  MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
+  MXNET_LAPACK_FSIG_GETRF(dgetrf, double)
+
+  #ifdef __ANDROID__
+    #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
+      int func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
+                  int *lwork, int *info);
+  #else
+    #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
+      void func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
+                   int *lwork, int *info);
+  #endif
+
+  MXNET_LAPACK_FSIG_GETRI(sgetri, float)
+  MXNET_LAPACK_FSIG_GETRI(dgetri, double)
 }
 
 #endif  // MSHADOW_USE_MKL == 0
@@ -171,8 +195,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
   // MXNET_LAPACK-signature and have to be wrapped.
   #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
   inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
-                                          dtype *a, int lda, dtype* tau, \
-                                          dtype* work, int lwork) { \
+                                          dtype *a, int lda, dtype *tau, \
+                                          dtype *work, int lwork) { \
     if (lwork != -1) { \
       return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
     } \
@@ -184,8 +208,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
 
   #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
   inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
-                                          dtype *a, int lda, dtype* tau, \
-                                          dtype* work, int lwork) { \
+                                          dtype *a, int lda, dtype *tau, \
+                                          dtype *work, int lwork) { \
     if (lwork != -1) { \
       return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
     } \
@@ -215,6 +239,21 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
   MXNET_LAPACK_CWRAP_SYEVD(s, float)
   MXNET_LAPACK_CWRAP_SYEVD(d, double)
 
+  #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
+  #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf
+
+  #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
+  inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
+                                          int *ipiv, dtype *work, int lwork) { \
+    if (lwork != -1) { \
+      return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
+    } \
+    *work = 0; \
+    return 0; \
+  }
+  MXNET_LAPACK_CWRAP_GETRI(s, float)
+  MXNET_LAPACK_CWRAP_GETRI(d, double)
+
 #elif MXNET_USE_LAPACK
 
   #define MXNET_LAPACK_ROW_MAJOR 101
@@ -322,6 +361,38 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
   MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
   MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)
 
+  // Note: Both MXNET_LAPACK_*getrf and MXNET_LAPACK_*getri can only be called with col-major
+  // format (as MXNet does, for performance); a row-major layout is rejected with a fatal CHECK.
+  #define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
+  inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
+                                          dtype *a, int lda, int *ipiv) { \
+    if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
+      CHECK(false) << "MXNET_LAPACK_" << #prefix << "getrf implemented for col-major layout only"; \
+      return 1; \
+    } else { \
+      int info(0); \
+      prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
+      return info; \
+    } \
+  }
+  MXNET_LAPACK_CWRAP_GETRF(s, float)
+  MXNET_LAPACK_CWRAP_GETRF(d, double)
+
+  #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
+  inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
+                                          int *ipiv, dtype *work, int lwork) { \
+    if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
+      CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
+      return 1; \
+    } else { \
+      int info(0); \
+      prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
+      return info; \
+    } \
+  }
+  MXNET_LAPACK_CWRAP_GETRI(s, float)
+  MXNET_LAPACK_CWRAP_GETRI(d, double)
+
 #else
 
 
@@ -335,12 +406,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
 
   #define MXNET_LAPACK_CWRAPPER2(func, dtype) \
   int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
-                                 int lda, dtype* tau, dtype* work, int lwork);
+                          int lda, dtype* tau, dtype* work, int lwork);
 
   #define MXNET_LAPACK_CWRAPPER3(func, dtype) \
   int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
-                                 int lda, dtype *w, dtype *work, int lwork, \
-                                 int *iwork, int liwork);
+                          int lda, dtype *w, dtype *work, int lwork, \
+                          int *iwork, int liwork);
+
+  #define MXNET_LAPACK_CWRAPPER4(func, dtype) \
+  int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
+                          dtype *a, int lda, int *ipiv);
+
+  #define MXNET_LAPACK_CWRAPPER5(func, dtype) \
+  int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
+                          int *ipiv, dtype *work, int lwork);
 
   #define MXNET_LAPACK_UNAVAILABLE(func) \
   int mxnet_lapack_##func(...);
@@ -359,9 +438,17 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
 
   MXNET_LAPACK_CWRAPPER3(ssyevd, float)
   MXNET_LAPACK_CWRAPPER3(dsyevd, double)
+
+  MXNET_LAPACK_CWRAPPER4(sgetrf, float)
+  MXNET_LAPACK_CWRAPPER4(dgetrf, double)
+
+  MXNET_LAPACK_CWRAPPER5(sgetri, float)
+  MXNET_LAPACK_CWRAPPER5(dgetri, double)
   #undef MXNET_LAPACK_CWRAPPER1
   #undef MXNET_LAPACK_CWRAPPER2
   #undef MXNET_LAPACK_CWRAPPER3
+  #undef MXNET_LAPACK_CWRAPPER4
+  #undef MXNET_LAPACK_CWRAPPER5
   #undef MXNET_LAPACK_UNAVAILABLE
 #endif
 
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index dc59400..ee713e5 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -191,6 +191,55 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
                                  const Tensor<xpu, 1, DType>& L,
                                  Stream<xpu> *s = 0);
 
+//////////////////////////////// GETRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
+// LAPACK documentation for further details.
+// Note that this is A = getrf(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_getrf(const Tensor<xpu, 2, DType>& A,
+                  const Tensor<xpu, 1, DType>& work,
+                  Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
+                        const Tensor<xpu, 1, DType>& work,
+                        Stream<xpu> *s = 0);
+
+//////////////////////////////// GETRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getri". Please refer to the
+// LAPACK documentation for further details.
+// Note that this is A = getri(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_getri(const Tensor<xpu, 2, DType>& A,
+                  const Tensor<xpu, 1, DType>& work,
+                  Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
+                        const Tensor<xpu, 3, DType>& B,
+                        const Tensor<xpu, 1, DType>& work,
+                        Stream<xpu> *s = 0);
+
+// This function determines the amount of workspace needed for linalg_getri to operate
+// on a batch of matrices which is returned as number of elements of type DType.
+template<typename xpu, typename DType>
+int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
+                                 Stream<xpu> *s = 0);
+
+//////////////////////////////// INVERSE ////////////////////////////////////////////
+
+// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
+// Note that A = inverse(B)
+template<typename xpu, typename DType>
+void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
+                          const Tensor<xpu, 3, DType>& B,
+                          const Tensor<xpu, 1, DType>& work,
+                          Stream<xpu> *s = 0);
+
 #include "linalg_impl.h"
 
 #endif  // MXNET_OPERATOR_LINALG_H_
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 4e63f61..718e3f9 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 
 #include "../common/cuda_utils.h"
+#include "mxnet_op.h"
 
 // Convenience functions.
 inline void linalg_check_batch_size(int A, int B, int C) {
@@ -1133,7 +1134,7 @@ void linalg_syevd<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
                        A.dptr_, A.stride_, L.dptr_, work.dptr_, -1, &liwork, \
                       -1); \
   int lwork(static_cast<int>(*work.dptr_)); \
-  int *iwork = static_cast<int*>(static_cast<void*>(work.dptr_ + lwork)); \
+  int *iwork = static_cast<int *>(static_cast<void *>(work.dptr_ + lwork)); \
   int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
                                A.dptr_, A.stride_, L.dptr_, work.dptr_, \
                                lwork, iwork, liwork)); \
@@ -1233,4 +1234,255 @@ LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double)
 
 #endif  // __CUDACC__
 
+//////////////////////////////// GETRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getrf"
+
+// The input of this function should be col-major for performance.
+// Tensor work holds space for ipiv in getrf
+#define LINALG_CPU_GETRF(fname, DType) \
+template<> inline \
+void linalg_getrf<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+                              const Tensor<cpu, 1, DType>& work, \
+                              Stream<cpu> *s) { \
+  int *ipiv = reinterpret_cast<int *>(work.dptr_); \
+  int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \
+                               A.dptr_, A.stride_, ipiv)); \
+  CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+
+LINALG_CPU_GETRF(sgetrf, float)
+LINALG_CPU_GETRF(dgetrf, double)
+
+#ifdef __CUDACC__
+
+// "getrfBatched" and "getriBatched" in cuBLAS must have DType *matrices[] as input
+// to store the pointers of each batch matrix. This kernel is used to build the
+// pointer array.
+struct set_matrix : public mxnet::op::mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType **p, DType *m, int step) {
+    p[i] = m + i * step;
+  }
+};
+
+// GETRF only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+// Since there is no "getri" in cuSolver, we use the batched versions of
+// "getrf" and "getri" in cuBLAS here. These routines are good for large
+// batches of small matrices, so performance issues may occur when computing
+// large matrices. We leave it here until MAGMA, which has "getri", is introduced
+// into MXNet.
+#define LINALG_GPU_BATCH_GETRF(fname, DType) \
+template<> inline \
+void linalg_batch_getrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                    const Tensor<gpu, 1, DType>& work, \
+                                    Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using namespace mxnet::op::mxnet_op; \
+  CHECK_NOTNULL(s); \
+  Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
+  const Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
+  int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
+  int *info = pivot + A.size(0) * A.size(1); \
+  Copy(temp, A, s); \
+  Kernel<set_matrix, gpu>::Launch(s, temp.size(0), \
+                                  A_ptr, temp.dptr_, \
+                                  temp.size(1) * temp.size(2)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            A.size(1), A_ptr, A.size(2), pivot, \
+                            info, A.size(0))) \
+  Storage::Get()->Free(A_ptr_buf); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_GETRF(fname, DType) \
+template<> inline \
+void linalg_batch_getrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                    const Tensor<gpu, 1, DType>& work, \
+                                    Stream<gpu> *s) { \
+  LOG(FATAL) << "batched getrf requires CUDA version >= 8.0!"; \
+}
+
+#endif  // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_GETRF(SgetrfBatched, float)
+LINALG_GPU_BATCH_GETRF(DgetrfBatched, double)
+
+#endif  // __CUDACC__
+
+//////////////////////////////// GETRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getri"
+
+// The input of this function should be col-major for performance.
+// Tensor work holds space for ipiv, then (rounded up to a whole DType element) work in getri
+#define LINALG_CPU_GETRI(fname, DType) \
+template<> inline \
+void linalg_getri<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+                              const Tensor<cpu, 1, DType>& work, \
+                              Stream<cpu> *s) { \
+  DType wkopt(0); \
+  MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \
+                       A.stride_, nullptr, &wkopt, -1); \
+  int lwork(static_cast<int>(wkopt)); \
+  int *ipiv = reinterpret_cast<int *>(work.dptr_); \
+  DType *pwork = work.dptr_ + (sizeof(int) * A.size(0) + sizeof(DType) - 1) / sizeof(DType); \
+  int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \
+                               A.stride_, ipiv, pwork, lwork)); \
+  CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+LINALG_CPU_GETRI(sgetri, float)
+LINALG_CPU_GETRI(dgetri, double)
+
+// Query workspace for the whole batch of matrices. For the cpu version, the workspace
+// is re-used, so space for only one matrix is enough.
+#define LINALG_CPU_GETRI_WORKSPACE_QUERY(func, DType) \
+template<> inline \
+int linalg_getri_workspace_query<cpu, DType>(const Tensor<cpu, 3, DType>& A, \
+                                             Stream<cpu> *s) { \
+  const Tensor<cpu, 2, DType>& matrix = A[0]; \
+  DType lwork(0); \
+  MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, matrix.size(0), matrix.dptr_, \
+                      matrix.stride_, nullptr, &lwork, -1); \
+  int ipiv = (sizeof(int) * matrix.size(0) + sizeof(DType) - 1) / sizeof(DType); \
+  return ipiv + static_cast<int>(lwork); \
+}
+LINALG_CPU_GETRI_WORKSPACE_QUERY(sgetri, float)
+LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double)
+
+#ifdef __CUDACC__
+
+// GETRI only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+// Since there is no "getri" in cuSolver, we use the batched versions of
+// "getrf" and "getri" in cuBLAS here. These routines are good for large
+// batches of small matrices, so performance issues may occur when computing
+// large matrices. We leave it here until MAGMA, which has "getri", is introduced
+// into MXNet.
+#define LINALG_GPU_BATCH_GETRI(fname, DType) \
+template<> inline \
+void linalg_batch_getri<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                    const Tensor<gpu, 3, DType>& B, \
+                                    const Tensor<gpu, 1, DType>& work, \
+                                    Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using namespace mxnet::op::mxnet_op; \
+  CHECK_NOTNULL(s); \
+  Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
+  Storage::Handle B_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  DType **B_ptr = static_cast<DType **>(B_ptr_buf.dptr); \
+  Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
+  int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
+  int *info = pivot + A.size(0) * A.size(1); \
+  Kernel<set_matrix, gpu>::Launch(s, A.size(0), \
+                                  A_ptr, A.dptr_, \
+                                  A.size(1) * A.size(2)); \
+  Kernel<set_matrix, gpu>::Launch(s, temp.size(0), \
+                                  B_ptr, temp.dptr_, \
+                                  temp.size(1) * temp.size(2)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            A.size(1), const_cast<const DType **>(B_ptr), \
+                            B.size(2), const_cast<const int *>(pivot), \
+                            A_ptr, A.size(2), info, A.size(0))) \
+  Storage::Get()->Free(A_ptr_buf); \
+  Storage::Get()->Free(B_ptr_buf); \
+}
+
+#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_getri_workspace_query<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                             Stream<gpu> *s) { \
+  int pivot_size = sizeof(int) * A.size(0) * A.size(1); \
+  int info_size = sizeof(int) * A.size(0); \
+  int matrix_size = sizeof(DType) * A.shape_.Size(); \
+  return (pivot_size + info_size + matrix_size + sizeof(DType) - 1) / sizeof(DType); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_GETRI(fname, DType) \
+template<> inline \
+void linalg_batch_getri<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                    const Tensor<gpu, 3, DType>& B, \
+                                    const Tensor<gpu, 1, DType>& work, \
+                                    Stream<gpu> *s) { \
+  LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \
+}
+
+#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_getri_workspace_query<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+                                             Stream<gpu> *s) { \
+  LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \
+}
+
+#endif  // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_GETRI(SgetriBatched, float)
+LINALG_GPU_BATCH_GETRI(DgetriBatched, double)
+
+LINALG_GPU_GETRI_WORKSPACE_QUERY(SgetriBatched, float)
+LINALG_GPU_GETRI_WORKSPACE_QUERY(DgetriBatched, double)
+
+#endif  // __CUDACC__
+
+//////////////////////////////// INVERSE ////////////////////////////////////////////
+
+// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
+
+// Note A = inverse(B)
+#define LINALG_CPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+                                      const Tensor<xpu, 3, DType>& B, \
+                                      const Tensor<xpu, 1, DType>& work, \
+                                      Stream<cpu> *s) { \
+  if (A.dptr_ != B.dptr_) Copy(A, B, s); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_getrf(A[i], work, s); \
+    linalg_getri(A[i], work, s); \
+  } \
+}
+LINALG_CPU_BATCH_INVERSE(cpu, float)
+LINALG_CPU_BATCH_INVERSE(cpu, double)
+
+#ifdef __CUDACC__
+
+// GETRF and GETRI only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+#define LINALG_GPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+                                      const Tensor<xpu, 3, DType>& B, \
+                                      const Tensor<xpu, 1, DType>& work, \
+                                      Stream<gpu> *s) { \
+  linalg_batch_getrf(B, work, s); \
+  linalg_batch_getri(A, B, work, s); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+                                      const Tensor<xpu, 3, DType>& B, \
+                                      const Tensor<xpu, 1, DType>& work, \
+                                      Stream<gpu> *s) { \
+  LOG(FATAL) << "batched getrf and getri requires CUDA version >= 8.0!"; \
+}
+
+#endif  // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_INVERSE(gpu, float)
+LINALG_GPU_BATCH_INVERSE(gpu, double)
+
+#endif  // __CUDACC__
+
 #endif  // MXNET_OPERATOR_LINALG_IMPL_H_
diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h
index bda8137..4dead87 100644
--- a/src/operator/tensor/la_op-inl.h
+++ b/src/operator/tensor/la_op-inl.h
@@ -98,11 +98,16 @@ struct gemm {
 struct gemm2 {
   template<typename xpu, int dim, typename DType>
   static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, DType alpha, bool tA, bool tB,
+                 Stream<xpu> *s) {
+    gemm::op(A, B, C, DType(alpha), DType(0), tA, tB, s);
+  }
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
                  const Tensor<xpu, dim, DType>& C, Stream<xpu> *s,
                  const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
-    gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a,
-             param.transpose_b, s);
+    op(A, B, C, DType(param.alpha), param.transpose_a, param.transpose_b, s);
   }
   template<typename xpu, int dim, typename DType>
   static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
@@ -448,6 +453,22 @@ struct syevd {
   }
 };
 
+// A = inverse(B). Note that op() takes the input B first and the output A second.
+struct inverse {
+  template<typename xpu, typename DType>
+  static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
+                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    // Reserve workspace (size in DType elements, determined by the getri workspace query)
+    int lwork(linalg_getri_workspace_query(A, s));
+    Tensor<xpu, 1, DType> work = ctx.requested[0]
+      .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
+    // Since inverse(A) = trans(inverse(trans(A))), we don't need to transpose A even
+    // though we are using the col-major version of the getrf and getri routines.
+    linalg_batch_inverse(A, B, work, s);
+  }
+};
+
 // Backward operators (always using batch processing)
 
 struct gemm_backward {
@@ -789,6 +810,21 @@ struct syevd_backward {
   }
 };
 
+struct inverse_backward {
+  template<typename xpu, typename DType>
+  static void op(const Tensor<xpu, 3, DType>& dA,
+                 const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& dB,
+                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+    // Backward of A = inverse(B); reads as temp = dA * A^T, dB = -A^T * temp (confirm vs gemm)
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 3, DType> temp = ctx.requested[0]
+      .get_space_typed<xpu, 3, DType>(A.shape_, s);
+    gemm2::op(dA, A, temp, DType(1), false, true, s);
+    gemm2::op(A, temp, dB, DType(-1), true, false, s);
+  }
+};
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index d6e64c4..2fa1fd3 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -889,5 +889,55 @@ NNVM_REGISTER_OP(_backward_linalg_syevd)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", LaOpBackwSyevd<cpu, syevd_backward>);
 
+NNVM_REGISTER_OP(_linalg_inverse)
+.add_alias("linalg_inverse")
+.describe(R"code(Compute the inverse of a matrix.
+Input is a tensor *A* of dimension *n >= 2*.
+
+If *n=2*, *A* is a square matrix. We compute:
+
+  *out* = *A*\ :sup:`-1`
+
+If *n>2*, *inverse* is performed separately on the trailing two dimensions
+for all inputs (batch mode).
+
+.. note:: The operator supports float32 and float64 data types only.
+
+Examples::
+
+   // Single matrix inversion
+   A = [[1., 4.], [2., 3.]]
+   inverse(A) = [[-0.6, 0.8], [0.4, -0.2]]
+
+   // Batch matrix inversion
+   A = [[[1., 4.], [2., 3.]],
+        [[1., 3.], [2., 4.]]]
+   inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]],
+                 [[-2., 1.5], [1., -0.5]]]
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
+  { return std::vector<std::string>{"A"}; } )
+.set_attr<mxnet::FInferShape>("FInferShape", InverseShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+  { return std::vector<std::pair<int, int>>{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 1, 1, inverse>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_linalg_inverse"})
+.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix");
+
+NNVM_REGISTER_OP(_backward_linalg_inverse)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+  { return std::vector<std::pair<int, int> >{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+  { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, inverse_backward>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index ec310fe..3ef714e 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -93,6 +93,12 @@ NNVM_REGISTER_OP(_linalg_potri)
 NNVM_REGISTER_OP(_backward_linalg_potri)
 .set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 1, potri_backward>);
 
+NNVM_REGISTER_OP(_linalg_inverse)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, inverse>);
+
+NNVM_REGISTER_OP(_backward_linalg_inverse)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, inverse_backward>);
+
 #if MXNET_USE_CUSOLVER == 1
 
 NNVM_REGISTER_OP(_linalg_potrf)
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 3b36f7c..5b0c7e3 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -398,6 +398,21 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
   return false;
 }
 
+// Shape inference function for linalg_inverse
+// Inputs: A. Outputs: inverse(A)
+inline bool InverseShape(const nnvm::NodeAttrs& attrs,
+                         mxnet::ShapeVector* in_attrs,
+                         mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const mxnet::TShape& in = (*in_attrs)[0];
+  const int ndim(in.ndim());
+  CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2";
+  CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimension must be equal";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, in);
+  return true;
+}
+
 // Shape inference function for linalg_syevd
 // Inputs: A. Outputs: U, L
 inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1768da2..ee94629 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6425,6 +6425,34 @@ def test_laop_5():
                     check_symbolic_forward(test_trian, [data_in], [res_trian])
                     check_numeric_gradient(test_trian, [data_in])
 
+# Tests for linalg.inverse
+@with_seed()
+def test_laop_6():
+    dtype = np.float64
+    rtol_fw = 1e-7
+    atol_fw = 1e-9
+    num_eps = 1e-6
+    rtol_bw = 1e-4
+    atol_bw = 1e-6
+
+    data = mx.symbol.Variable('data')
+
+    check_fw = lambda sym, location, expected:\
+        check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
+                               atol=atol_fw, dtype=dtype)
+    check_grad = lambda sym, location:\
+        check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
+                               atol=atol_bw, dtype=dtype)
+
+    a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
+    a = np.tile(a, (3, 1, 1))
+    r = np.eye(4)
+    r = np.tile(r, (3, 1, 1))
+    test_inverse = mx.sym.linalg.inverse(data)
+    test_eye = mx.sym.linalg.gemm2(data, test_inverse)
+    check_fw(test_eye, [a], [r])
+    check_grad(test_inverse, [a])
+
 @with_seed()
 def test_stack():
     for _ in range(100):