You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ar...@apache.org on 2019/05/20 13:40:50 UTC
[incubator-mxnet] branch master updated: Add matrix inversion
operator in linalg (#14963)
This is an automated email from the ASF dual-hosted git repository.
arcadiaphy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3cbfe48 Add matrix inversion operator in linalg (#14963)
3cbfe48 is described below
commit 3cbfe48c27b370babf7994cd4ffdd5d158a5a8ed
Author: Wang Jiajun <wa...@gmail.com>
AuthorDate: Mon May 20 08:40:07 2019 -0500
Add matrix inversion operator in linalg (#14963)
* add inverse cpu
* add comment
* add inverse backward cpu
* add inverse gpu
* able to compile
* fix
* fix
* guard for lower version cuda
* update docs
* update docs
* fix misaligned memory
* add test
* fix lint
* fix android
* fix indent
* change transfer gradient
* fix
* refactor test
* delete unnecessary copy
* trigger CI
* fix test
---
docs/api/python/symbol/linalg.md | 1 +
src/operator/c_lapack_api.cc | 26 +++-
src/operator/c_lapack_api.h | 100 ++++++++++++-
src/operator/linalg.h | 49 +++++++
src/operator/linalg_impl.h | 254 ++++++++++++++++++++++++++++++++-
src/operator/tensor/la_op-inl.h | 40 +++++-
src/operator/tensor/la_op.cc | 50 +++++++
src/operator/tensor/la_op.cu | 6 +
src/operator/tensor/la_op.h | 15 ++
tests/python/unittest/test_operator.py | 28 ++++
10 files changed, 556 insertions(+), 13 deletions(-)
diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md
index 5b467b5..436bab7 100644
--- a/docs/api/python/symbol/linalg.md
+++ b/docs/api/python/symbol/linalg.md
@@ -59,6 +59,7 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
makediag
extracttrian
maketrian
+ inverse
```
## API Reference
diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc
index c6293bf..33a5b08 100644
--- a/src/operator/c_lapack_api.cc
+++ b/src/operator/c_lapack_api.cc
@@ -36,15 +36,29 @@
#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
- int lda, dtype* tau, dtype* work, int lwork) { \
+ int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
- int lda, dtype *w, dtype *work, int lwork, \
- int *iwork, int liwork) { \
+ int lda, dtype *w, dtype *work, int lwork, \
+ int *iwork, int liwork) { \
+ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
+ return 1; \
+ }
+
+ #define MXNET_LAPACK_CWRAPPER4(func, dtype) \
+ int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
+ dtype *a, int lda, int *ipiv) { \
+ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
+ return 1; \
+ }
+
+ #define MXNET_LAPACK_CWRAPPER5(func, dtype) \
+ int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
+ int *ipiv, dtype *work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
@@ -69,4 +83,10 @@
MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)
+
+ MXNET_LAPACK_CWRAPPER4(sgetrf, float)
+ MXNET_LAPACK_CWRAPPER4(dgetrf, double)
+
+ MXNET_LAPACK_CWRAPPER5(sgetri, float)
+ MXNET_LAPACK_CWRAPPER5(dgetri, double)
#endif // MSHADOW_USE_MKL == 0
diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h
index cd69775..c63229c 100644
--- a/src/operator/c_lapack_api.h
+++ b/src/operator/c_lapack_api.h
@@ -119,6 +119,30 @@ extern "C" {
MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)
+
+ #ifdef __ANDROID__
+ #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
+ int func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
+ #else
+ #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \
+ void func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info);
+ #endif
+
+ MXNET_LAPACK_FSIG_GETRF(sgetrf, float)
+ MXNET_LAPACK_FSIG_GETRF(dgetrf, double)
+
+ #ifdef __ANDROID__
+ #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
+ int func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
+ int *lwork, int *info);
+ #else
+ #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \
+ void func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \
+ int *lwork, int *info);
+ #endif
+
+ MXNET_LAPACK_FSIG_GETRI(sgetri, float)
+ MXNET_LAPACK_FSIG_GETRI(dgetri, double)
}
#endif // MSHADOW_USE_MKL == 0
@@ -171,8 +195,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
// MXNET_LAPACK-signature and have to be wrapped.
#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \
- dtype *a, int lda, dtype* tau, \
- dtype* work, int lwork) { \
+ dtype *a, int lda, dtype *tau, \
+ dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \
} \
@@ -184,8 +208,8 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \
inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \
- dtype *a, int lda, dtype* tau, \
- dtype* work, int lwork) { \
+ dtype *a, int lda, dtype *tau, \
+ dtype *work, int lwork) { \
if (lwork != -1) { \
return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \
} \
@@ -215,6 +239,21 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(s, float)
MXNET_LAPACK_CWRAP_SYEVD(d, double)
+ #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf
+ #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf
+
+ #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
+ inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
+ int *ipiv, dtype *work, int lwork) { \
+ if (lwork != -1) { \
+ return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \
+ } \
+ *work = 0; \
+ return 0; \
+ }
+ MXNET_LAPACK_CWRAP_GETRI(s, float)
+ MXNET_LAPACK_CWRAP_GETRI(d, double)
+
#elif MXNET_USE_LAPACK
#define MXNET_LAPACK_ROW_MAJOR 101
@@ -322,6 +361,38 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)
+ // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format
+ // (MXNet) for performance.
+ #define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \
+ inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \
+ dtype *a, int lda, int *ipiv) { \
+ if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
+ CHECK(false) << "MXNET_LAPACK_" << #prefix << "getrf implemented for col-major layout only"; \
+ return 1; \
+ } else { \
+ int info(0); \
+ prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \
+ return info; \
+ } \
+ }
+ MXNET_LAPACK_CWRAP_GETRF(s, float)
+ MXNET_LAPACK_CWRAP_GETRF(d, double)
+
+ #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \
+ inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \
+ int *ipiv, dtype *work, int lwork) { \
+ if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
+ CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \
+ return 1; \
+ } else { \
+ int info(0); \
+ prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \
+ return info; \
+ } \
+ }
+ MXNET_LAPACK_CWRAP_GETRI(s, float)
+ MXNET_LAPACK_CWRAP_GETRI(d, double)
+
#else
@@ -335,12 +406,20 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
- int lda, dtype* tau, dtype* work, int lwork);
+ int lda, dtype* tau, dtype* work, int lwork);
#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
- int lda, dtype *w, dtype *work, int lwork, \
- int *iwork, int liwork);
+ int lda, dtype *w, dtype *work, int lwork, \
+ int *iwork, int liwork);
+
+ #define MXNET_LAPACK_CWRAPPER4(func, dtype) \
+ int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \
+ dtype *a, int lda, int *ipiv);
+
+ #define MXNET_LAPACK_CWRAPPER5(func, dtype) \
+ int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \
+ int *ipiv, dtype *work, int lwork);
#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...);
@@ -359,9 +438,16 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)
+
+ MXNET_LAPACK_CWRAPPER4(sgetrf, float)
+ MXNET_LAPACK_CWRAPPER4(dgetrf, double)
+
+ MXNET_LAPACK_CWRAPPER5(sgetri, float)
+ MXNET_LAPACK_CWRAPPER5(dgetri, double)
#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
+ #undef MXNET_LAPACK_CWRAPPER4
#undef MXNET_LAPACK_UNAVAILABLE
#endif
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index dc59400..ee713e5 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -191,6 +191,55 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
const Tensor<xpu, 1, DType>& L,
Stream<xpu> *s = 0);
+//////////////////////////////// GETRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
+// LAPACK documentation for further details.
+// Note that this is A = getrf(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_getrf(const Tensor<xpu, 2, DType>& A,
+ const Tensor<xpu, 1, DType>& work,
+ Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
+ const Tensor<xpu, 1, DType>& work,
+ Stream<xpu> *s = 0);
+
+//////////////////////////////// GETRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getri". Please refer to the
+// LAPACK documentation for further details.
+// Note that this is A = getri(A), so A is input and output parameter.
+
+template<typename xpu, typename DType>
+void linalg_getri(const Tensor<xpu, 2, DType>& A,
+ const Tensor<xpu, 1, DType>& work,
+ Stream<xpu> *s = 0);
+
+template<typename xpu, typename DType>
+void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
+ const Tensor<xpu, 3, DType>& B,
+ const Tensor<xpu, 1, DType>& work,
+ Stream<xpu> *s = 0);
+
+// This function determines the amount of workspace needed for linalg_getri to operate
+// on a batch of matrices which is returned as number of elements of type DType.
+template<typename xpu, typename DType>
+int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
+ Stream<xpu> *s = 0);
+
+//////////////////////////////// INVERSE ////////////////////////////////////////////
+
+// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
+// Note that A = inverse(B)
+template<typename xpu, typename DType>
+void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
+ const Tensor<xpu, 3, DType>& B,
+ const Tensor<xpu, 1, DType>& work,
+ Stream<xpu> *s = 0);
+
#include "linalg_impl.h"
#endif // MXNET_OPERATOR_LINALG_H_
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 4e63f61..718e3f9 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -30,6 +30,7 @@
#include <algorithm>
#include "../common/cuda_utils.h"
+#include "mxnet_op.h"
// Convenience functions.
inline void linalg_check_batch_size(int A, int B, int C) {
@@ -1133,7 +1134,7 @@ void linalg_syevd<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
A.dptr_, A.stride_, L.dptr_, work.dptr_, -1, &liwork, \
-1); \
int lwork(static_cast<int>(*work.dptr_)); \
- int *iwork = static_cast<int*>(static_cast<void*>(work.dptr_ + lwork)); \
+ int *iwork = static_cast<int *>(static_cast<void *>(work.dptr_ + lwork)); \
int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
A.dptr_, A.stride_, L.dptr_, work.dptr_, \
lwork, iwork, liwork)); \
@@ -1233,4 +1234,255 @@ LINALG_GPU_SYEVD_WORKSPACE_QUERY(DnDsyevd, double)
#endif // __CUDACC__
+//////////////////////////////// GETRF ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getrf"
+
+// The input of this function should be col-major for performance.
+// Tensor work holds space for ipiv in getrf
+#define LINALG_CPU_GETRF(fname, DType) \
+template<> inline \
+void linalg_getrf<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+ const Tensor<cpu, 1, DType>& work, \
+ Stream<cpu> *s) { \
+ int *ipiv = reinterpret_cast<int *>(work.dptr_); \
+ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(1), A.size(0), \
+ A.dptr_, A.stride_, ipiv)); \
+ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+
+LINALG_CPU_GETRF(sgetrf, float)
+LINALG_CPU_GETRF(dgetrf, double)
+
+#ifdef __CUDACC__
+
+// "getrfBatched" and "getriBatched" in cuBLAS must have DType *matrices[] as input
+// to store the pointers of each batch matrix. This kernel is used to build the
+// pointer array.
+struct set_matrix : public mxnet::op::mxnet_op::tunable {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, DType **p, DType *m, int step) {
+ p[i] = m + i * step;
+ }
+};
+
+// GETRF only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+// Since there is no "getri" in cuSolver, we are using batched version of
+// "getrf" and "getri" in cuBLAS here. These routines are good for large
+// batches of small matrices, so performance issue may happen when computing
+// large matrices. We leave it here until MAGMA, which has "getri", is introduced
+// into MXNet.
+#define LINALG_GPU_BATCH_GETRF(fname, DType) \
+template<> inline \
+void linalg_batch_getrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ const Tensor<gpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ using namespace mxnet; \
+ using namespace mxnet::op::mxnet_op; \
+ CHECK_NOTNULL(s); \
+ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+ DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
+ const Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
+ int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
+ int *info = pivot + A.size(0) * A.size(1); \
+ Copy(temp, A, s); \
+ Kernel<set_matrix, gpu>::Launch(s, temp.size(0), \
+ A_ptr, temp.dptr_, \
+ temp.size(1) * temp.size(2)); \
+ CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+ A.size(1), A_ptr, A.size(2), pivot, \
+ info, A.size(0))) \
+ Storage::Get()->Free(A_ptr_buf); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_GETRF(fname, DType) \
+template<> inline \
+void linalg_batch_getrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ const Tensor<gpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ LOG(FATAL) << "batched getrf requires CUDA version >= 8.0!"; \
+}
+
+#endif // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_GETRF(SgetrfBatched, float)
+LINALG_GPU_BATCH_GETRF(DgetrfBatched, double)
+
+#endif // __CUDACC__
+
+//////////////////////////////// GETRI ////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "getri"
+
+// The input of this function should be col-major for performance.
+// Tensor work holds space for ipiv, work in getri
+#define LINALG_CPU_GETRI(fname, DType) \
+template<> inline \
+void linalg_getri<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+ const Tensor<cpu, 1, DType>& work, \
+ Stream<cpu> *s) { \
+ DType wkopt; \
+ MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \
+ A.stride_, nullptr, &wkopt, -1); \
+ int lwork(static_cast<int>(wkopt)); \
+ int *ipiv = reinterpret_cast<int *>(work.dptr_); \
+ DType *pwork = reinterpret_cast<DType *>(ipiv + A.size(0)); \
+ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_COL_MAJOR, A.size(0), A.dptr_, \
+ A.stride_, ipiv, pwork, lwork)); \
+ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+LINALG_CPU_GETRI(sgetri, float)
+LINALG_CPU_GETRI(dgetri, double)
+
+// Query workspace for the whole batch of matrices. For the cpu version, the workspace
+// is re-used, so space for only one matrix is enough.
+#define LINALG_CPU_GETRI_WORKSPACE_QUERY(func, DType) \
+template<> inline \
+int linalg_getri_workspace_query<cpu, DType>(const Tensor<cpu, 3, DType>& A, \
+ Stream<cpu> *s) { \
+ const Tensor<cpu, 2, DType>& matrix = A[0]; \
+ DType lwork(0); \
+ MXNET_LAPACK_##func(MXNET_LAPACK_COL_MAJOR, matrix.size(0), matrix.dptr_, \
+ matrix.stride_, nullptr, &lwork, -1); \
+ int ipiv = (sizeof(int) * matrix.size(0) + sizeof(DType) - 1) / sizeof(DType); \
+ return ipiv + static_cast<int>(lwork); \
+}
+LINALG_CPU_GETRI_WORKSPACE_QUERY(sgetri, float)
+LINALG_CPU_GETRI_WORKSPACE_QUERY(dgetri, double)
+
+#ifdef __CUDACC__
+
+// GETRI only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+// Since there is no "getri" in cuSolver, we are using batched version of
+// "getrf" and "getri" in cuBLAS here. These routines are good for large
+// batches of small matrices, so performance issue may happen when computing
+// large matrices. We leave it here until MAGMA, which has "getri", is introduced
+// into MXNet.
+#define LINALG_GPU_BATCH_GETRI(fname, DType) \
+template<> inline \
+void linalg_batch_getri<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ const Tensor<gpu, 3, DType>& B, \
+ const Tensor<gpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ using namespace mxnet; \
+ using namespace mxnet::op::mxnet_op; \
+ CHECK_NOTNULL(s); \
+ Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+ DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
+ Storage::Handle B_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+ DType **B_ptr = static_cast<DType **>(B_ptr_buf.dptr); \
+ Tensor<gpu, 3, DType> temp(work.dptr_, A.shape_, s); \
+ int *pivot = reinterpret_cast<int *>(temp.dptr_ + temp.shape_.Size()); \
+ int *info = pivot + A.size(0) * A.size(1); \
+ Kernel<set_matrix, gpu>::Launch(s, A.size(0), \
+ A_ptr, A.dptr_, \
+ A.size(1) * A.size(2)); \
+ Kernel<set_matrix, gpu>::Launch(s, temp.size(0), \
+ B_ptr, temp.dptr_, \
+ temp.size(1) * temp.size(2)); \
+ CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+ A.size(1), const_cast<const DType **>(B_ptr), \
+ B.size(2), const_cast<const int *>(pivot), \
+ A_ptr, A.size(2), info, A.size(0))) \
+ Storage::Get()->Free(A_ptr_buf); \
+ Storage::Get()->Free(B_ptr_buf); \
+}
+
+#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_getri_workspace_query<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ Stream<gpu> *s) { \
+ int pivot_size = sizeof(int) * A.size(0) * A.size(1); \
+ int info_size = sizeof(int) * A.size(0); \
+ int matrix_size = sizeof(DType) * A.shape_.Size(); \
+ return (pivot_size + info_size + matrix_size + sizeof(DType) - 1) / sizeof(DType); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_GETRI(fname, DType) \
+template<> inline \
+void linalg_batch_getri<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ const Tensor<gpu, 3, DType>& B, \
+ const Tensor<gpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \
+}
+
+#define LINALG_GPU_GETRI_WORKSPACE_QUERY(fname, DType) \
+template<> inline \
+int linalg_getri_workspace_query<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
+ Stream<gpu> *s) { \
+ LOG(FATAL) << "batched getri requires CUDA version >= 8.0!"; \
+}
+
+#endif // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_GETRI(SgetriBatched, float)
+LINALG_GPU_BATCH_GETRI(DgetriBatched, double)
+
+LINALG_GPU_GETRI_WORKSPACE_QUERY(SgetriBatched, float)
+LINALG_GPU_GETRI_WORKSPACE_QUERY(DgetriBatched, double)
+
+#endif // __CUDACC__
+
+//////////////////////////////// INVERSE ////////////////////////////////////////////
+
+// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
+
+// Note A = inverse(B)
+#define LINALG_CPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+ const Tensor<xpu, 3, DType>& B, \
+ const Tensor<xpu, 1, DType>& work, \
+ Stream<cpu> *s) { \
+ if (A.dptr_ != B.dptr_) Copy(A, B, s); \
+ for (index_t i = 0; i < A.size(0); ++i) { \
+ linalg_getrf(A[i], work, s); \
+ linalg_getri(A[i], work, s); \
+ } \
+}
+LINALG_CPU_BATCH_INVERSE(cpu, float)
+LINALG_CPU_BATCH_INVERSE(cpu, double)
+
+#ifdef __CUDACC__
+
+// GETRF and GETRI only available with cuda8 or higher.
+#if CUDA_VERSION >= 8000
+
+#define LINALG_GPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+ const Tensor<xpu, 3, DType>& B, \
+ const Tensor<xpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ linalg_batch_getrf(B, work, s); \
+ linalg_batch_getri(A, B, work, s); \
+}
+
+#else
+
+#define LINALG_GPU_BATCH_INVERSE(xpu, DType) \
+template<> inline \
+void linalg_batch_inverse<xpu, DType>(const Tensor<xpu, 3, DType>& A, \
+ const Tensor<xpu, 3, DType>& B, \
+ const Tensor<xpu, 1, DType>& work, \
+ Stream<gpu> *s) { \
+ LOG(FATAL) << "batched getrf and getri requires CUDA version >= 8.0!"; \
+}
+
+#endif // CUDA_VERSION >= 8000
+
+LINALG_GPU_BATCH_INVERSE(gpu, float)
+LINALG_GPU_BATCH_INVERSE(gpu, double)
+
+#endif // __CUDACC__
+
#endif // MXNET_OPERATOR_LINALG_IMPL_H_
diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h
index bda8137..4dead87 100644
--- a/src/operator/tensor/la_op-inl.h
+++ b/src/operator/tensor/la_op-inl.h
@@ -98,11 +98,16 @@ struct gemm {
struct gemm2 {
template<typename xpu, int dim, typename DType>
static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+ const Tensor<xpu, dim, DType>& C, DType alpha, bool tA, bool tB,
+ Stream<xpu> *s) {
+ gemm::op(A, B, C, DType(alpha), DType(0), tA, tB, s);
+ }
+ template<typename xpu, int dim, typename DType>
+ static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
const Tensor<xpu, dim, DType>& C, Stream<xpu> *s,
const nnvm::NodeAttrs& attrs) {
const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
- gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a,
- param.transpose_b, s);
+ op(A, B, C, DType(param.alpha), param.transpose_a, param.transpose_b, s);
}
template<typename xpu, int dim, typename DType>
static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
@@ -448,6 +453,22 @@ struct syevd {
}
};
+// A = inverse(B).
+struct inverse {
+ template<typename xpu, typename DType>
+ static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
+ const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ // Reserve workspace (size determined by query)
+ int lwork(linalg_getri_workspace_query(A, s));
+ Tensor<xpu, 1, DType> work = ctx.requested[0]
+ .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
+ // Since inverse(A) = trans(inverse(trans(A))), so we don't need to transpose
+ // A even if we are using the col-major version of getrf and getri routines.
+ linalg_batch_inverse(A, B, work, s);
+ }
+};
+
// Backward operators (always using batch processing)
struct gemm_backward {
@@ -789,6 +810,21 @@ struct syevd_backward {
}
};
+struct inverse_backward {
+ template<typename xpu, typename DType>
+ static void op(const Tensor<xpu, 3, DType>& dA,
+ const Tensor<xpu, 3, DType>& A,
+ const Tensor<xpu, 3, DType>& dB,
+ const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+ // Backward of A = inverse(B)
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ Tensor<xpu, 3, DType> temp = ctx.requested[0]
+ .get_space_typed<xpu, 3, DType>(A.shape_, s);
+ gemm2::op(dA, A, temp, DType(1), false, true, s);
+ gemm2::op(A, temp, dB, DType(-1), true, false, s);
+ }
+};
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index d6e64c4..2fa1fd3 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -889,5 +889,55 @@ NNVM_REGISTER_OP(_backward_linalg_syevd)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackwSyevd<cpu, syevd_backward>);
+NNVM_REGISTER_OP(_linalg_inverse)
+.add_alias("linalg_inverse")
+.describe(R"code(Compute the inverse of a matrix.
+Input is a tensor *A* of dimension *n >= 2*.
+
+If *n=2*, *A* is a square matrix. We compute:
+
+ *out* = *A*\ :sup:`-1`
+
+If *n>2*, *inverse* is performed separately on the trailing two dimensions
+for all inputs (batch mode).
+
+.. note:: The operator supports float32 and float64 data types only.
+
+Examples::
+
+ // Single matrix inversion
+ A = [[1., 4.], [2., 3.]]
+ inverse(A) = [[-0.6, 0.8], [0.4, -0.2]]
+
+ // Batch matrix inversion
+ A = [[[1., 4.], [2., 3.]],
+ [[1., 3.], [2., 4.]]]
+ inverse(A) = [[[-0.6, 0.8], [0.4, -0.2]],
+ [[-2., 1.5], [1., -0.5]]]
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
+ { return std::vector<std::string>{"A"}; } )
+.set_attr<mxnet::FInferShape>("FInferShape", InverseShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+ { return std::vector<std::pair<int, int>>{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+ { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 1, 1, inverse>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_linalg_inverse"})
+.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix");
+
+NNVM_REGISTER_OP(_backward_linalg_inverse)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+ { return std::vector<std::pair<int, int> >{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+ { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, inverse_backward>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index ec310fe..3ef714e 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -93,6 +93,12 @@ NNVM_REGISTER_OP(_linalg_potri)
NNVM_REGISTER_OP(_backward_linalg_potri)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 1, potri_backward>);
+NNVM_REGISTER_OP(_linalg_inverse)
+.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, inverse>);
+
+NNVM_REGISTER_OP(_backward_linalg_inverse)
+.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, inverse_backward>);
+
#if MXNET_USE_CUSOLVER == 1
NNVM_REGISTER_OP(_linalg_potrf)
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 3b36f7c..5b0c7e3 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -398,6 +398,21 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
return false;
}
+// Shape inference function for linalg_inverse
+// Inputs: A. Outputs: inverse(A)
+inline bool InverseShape(const nnvm::NodeAttrs& attrs,
+ mxnet::ShapeVector* in_attrs,
+ mxnet::ShapeVector* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1);
+ CHECK_EQ(out_attrs->size(), 1);
+ const mxnet::TShape& in = (*in_attrs)[0];
+ const int ndim(in.ndim());
+ CHECK_GE(ndim, 2) << "Input A's dimension must be >= 2";
+ CHECK_EQ(in[ndim-2], in[ndim-1]) << "Input A's last two dimensions must be equal";
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, in);
+ return true;
+}
+
// Shape inference function for linalg_syevd
// Inputs: A. Outputs: U, L
inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1768da2..ee94629 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6425,6 +6425,34 @@ def test_laop_5():
check_symbolic_forward(test_trian, [data_in], [res_trian])
check_numeric_gradient(test_trian, [data_in])
+# Tests for linalg.inverse
+@with_seed()
+def test_laop_6():
+ dtype = np.float64
+ rtol_fw = 1e-7
+ atol_fw = 1e-9
+ num_eps = 1e-6
+ rtol_bw = 1e-4
+ atol_bw = 1e-6
+
+ data = mx.symbol.Variable('data')
+
+ check_fw = lambda sym, location, expected:\
+ check_symbolic_forward(sym, location, expected, rtol=rtol_fw,
+ atol=atol_fw, dtype=dtype)
+ check_grad = lambda sym, location:\
+ check_numeric_gradient(sym, location, numeric_eps=num_eps, rtol=rtol_bw,
+ atol=atol_bw, dtype=dtype)
+
+ a = np.sqrt(np.arange(4 * 4)).reshape(4, 4)
+ a = np.tile(a, (3, 1, 1))
+ r = np.eye(4)
+ r = np.tile(r, (3, 1, 1))
+ test_inverse = mx.sym.linalg.inverse(data)
+ test_eye = mx.sym.linalg.gemm2(data, test_inverse)
+ check_fw(test_eye, [a], [r])
+ check_grad(test_inverse, [a])
+
@with_seed()
def test_stack():
for _ in range(100):