You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/29 18:15:14 UTC

[GitHub] piiswrong closed pull request #10864: Support for axis parameter in linalg.gemm

piiswrong closed pull request #10864: Support for axis parameter in linalg.gemm
URL: https://github.com/apache/incubator-mxnet/pull/10864
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index aee67d739b0..dc5940013c6 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -64,6 +64,13 @@ void linalg_batch_gemm(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DTyp
                        const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
                        bool tA, bool tB, Stream<xpu> *s = 0);
 
+// Version of batch gemm where rows are indexed at axis 1 and columns at axis 3.
+template<typename xpu, typename DType>
+void linalg_batch_gemm(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B,
+                       const Tensor<xpu, 4, DType>& C, DType alpha, DType beta,
+                       bool tA, bool tB, Stream<xpu> *s = 0);
+
+
 template<typename xpu, typename DType>
 inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
                         const Tensor<xpu, 2, DType>& B,
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d1286170c2c..43eba9b1270 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -56,6 +56,11 @@ inline void check_gemm(const Tensor<xpu, 2, DType>& A, const Tensor<xpu, 2, DTyp
     << "Non compatible matrix dimensions between inputs A and B for gemm";
 }
 
+template<typename xpu, typename DType>
+void linalg_gemm_axis(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                      const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+                      bool tA, bool tB, Stream<xpu> *s = 0);
+
 #if (MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1)
 
 #define LINALG_CPU_GEMM(fname, DType) \
@@ -80,6 +85,38 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
   } \
 }
 
+// Batched gemm where the batch coordinate is given by the second axis.
+#define LINALG_CPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<cpu, DType>(const Tensor<cpu, 3, DType>& A, const Tensor<cpu, 3, DType>& B, \
+                                  const Tensor<cpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<cpu> *s) { \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  for (index_t i = 0; i < A.size(1); ++i) { \
+     cblas_##fname(CblasRowMajor, (tA ? CblasTrans : CblasNoTrans), \
+                   (tB ? CblasTrans : CblasNoTrans), \
+                   C.size(0), C.size(2), (tA ? A.size(0) : A.size(2)), alpha, \
+                   A.dptr_+i*A.stride_, A.size(1)*A.stride_, \
+                   B.dptr_+i*B.stride_, B.size(1)*B.stride_, beta, \
+                   C.dptr_+i*C.stride_, C.size(1)*C.stride_); \
+  } \
+}
+
+LINALG_CPU_GEMM_AXIS(sgemm, float)
+LINALG_CPU_GEMM_AXIS(dgemm, double)
+
+// Version where matrix rows are given by the second axis.
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+  linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+  for (index_t i = 0; i < A.size(0); ++i) { \
+    linalg_gemm_axis(A[i], B[i], C[i], alpha, beta, tA, tB, s); \
+  } \
+}
+
 #else
 
 #define LINALG_CPU_GEMM(fname, DType) \
@@ -98,6 +135,14 @@ void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 3, DType>& A, const Tensor<
   LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
 }
 
+#define LINALG_XPU_BATCH_GEMM_AXIS(xpu, DType) \
+template<> inline \
+void linalg_batch_gemm<xpu, DType>(const Tensor<xpu, 4, DType>& A, const Tensor<xpu, 4, DType>& B, \
+                                   const Tensor<xpu, 4, DType>& C, DType alpha, DType beta, \
+                                   bool tA, bool tB, Stream<xpu> *s) { \
+  LOG(FATAL) << "linalg_batch_gemm not implemented by mxnet for cpu, needs cblas!"; \
+}
+
 #endif  // MSHADOW_USE_CBLAS == 1 || MSHADOW_USE_MKL == 1
 
 LINALG_CPU_GEMM(sgemm, float)
@@ -106,6 +151,9 @@ LINALG_CPU_GEMM(dgemm, double)
 LINALG_XPU_BATCH_GEMM(cpu, float)
 LINALG_XPU_BATCH_GEMM(cpu, double)
 
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, float)
+LINALG_XPU_BATCH_GEMM_AXIS(cpu, double)
+
 // Specialization of linalg_gemm<cpu, DType> for DType=mshadow::half::half_t.
 template<> inline
 void linalg_gemm<cpu, mshadow::half::half_t>(const Tensor<cpu, 2, mshadow::half::half_t>& A,
@@ -140,6 +188,28 @@ void linalg_gemm<gpu, DType>(const Tensor<gpu, 2, DType>& A, const Tensor<gpu, 2
 LINALG_GPU_GEMM(Sgemm, float)
 LINALG_GPU_GEMM(Dgemm, double)
 
+// Version where matrix rows are given by first axis.
+#define LINALG_GPU_GEMM_AXIS(fname, DType) \
+template<> inline \
+void linalg_gemm_axis<gpu, DType>(const Tensor<gpu, 3, DType>& A, const Tensor<gpu, 3, DType>& B, \
+                                  const Tensor<gpu, 3, DType>& C, DType alpha, DType beta, \
+                                  bool tA, bool tB, Stream<gpu> *s) { \
+  using namespace mxnet; \
+  using mshadow::gpu; \
+  CHECK_NOTNULL(s); \
+  linalg_check_batch_size(A.size(1), B.size(1), C.size(1)); \
+  CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+                            (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+                            C.size(2), C.size(0), (tB ? B.size(2) : B.size(0)), &alpha, \
+                            B.dptr_, B.size(1)*B.stride_, B.stride_, \
+                            A.dptr_, A.size(1)*A.stride_, A.stride_, &beta, \
+                            C.dptr_, C.size(1)*C.stride_, C.stride_, A.size(1))) \
+}
+LINALG_GPU_GEMM_AXIS(SgemmStridedBatched, float)
+LINALG_GPU_GEMM_AXIS(DgemmStridedBatched, double)
+
+// Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
 // Specialization of linalg_gemm<gpu, DType> for DType=mshadow::half::half_t.
 template<> inline
 void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half::half_t>& A,
@@ -192,6 +262,8 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
 #if CUDA_VERSION < 8000
   LINALG_XPU_BATCH_GEMM(gpu, float)
   LINALG_XPU_BATCH_GEMM(gpu, double)
+  LINALG_XPU_BATCH_GEMM_AXIS(gpu, float)
+  LINALG_XPU_BATCH_GEMM_AXIS(gpu, double)
 #else
 #define LINALG_GPU_BATCH_GEMM(fname, DType) \
   template<> inline \
@@ -217,10 +289,125 @@ void linalg_gemm<gpu, mshadow::half::half_t>(const Tensor<gpu, 2, mshadow::half:
   LINALG_GPU_BATCH_GEMM(SgemmStridedBatched, float)
   LINALG_GPU_BATCH_GEMM(DgemmStridedBatched, double)
 
+// Version where matrix rows are given by second axis.
+#define LINALG_GPU_BATCH_GEMM_AXIS(fname, DType) \
+  template<> inline \
+  void linalg_batch_gemm<gpu, DType>(const Tensor<gpu, 4, DType>& A, \
+                                     const Tensor<gpu, 4, DType>& B, \
+                                     const Tensor<gpu, 4, DType>& C, DType alpha, DType beta, \
+                                     bool tA, bool tB, Stream<gpu> *s) { \
+    using namespace mxnet; \
+    using mshadow::gpu; \
+    CHECK_NOTNULL(s); \
+    linalg_check_batch_size(A.size(0), B.size(0), C.size(0)); \
+    linalg_check_batch_size(A.size(2), B.size(2), C.size(2)); \
+    for (index_t i = 0; i < A.size(2); ++i) { \
+      CUBLAS_CALL(cublas##fname(Stream<gpu>::GetBlasHandle(s), \
+          (tB ? CUBLAS_OP_T : CUBLAS_OP_N), \
+          (tA ? CUBLAS_OP_T : CUBLAS_OP_N), \
+          C.size(3), C.size(1), (tB ? B.size(3) : B.size(1)), &alpha, \
+          B.dptr_+i*B.stride_, B.size(2) * B.stride_, B.size(1)*B.size(2)*B.stride_, \
+          A.dptr_+i*A.stride_, A.size(2) * A.stride_, A.size(1)*A.size(2)*A.stride_, &beta, \
+          C.dptr_+i*C.stride_, C.size(2) * C.stride_, C.size(1)*C.size(2)*C.stride_, A.size(0))) \
+    }\
+  }
+
+  LINALG_GPU_BATCH_GEMM_AXIS(SgemmStridedBatched, float)
+  LINALG_GPU_BATCH_GEMM_AXIS(DgemmStridedBatched, double)
+
 #endif  // CUDA < 8000
 
 #endif  // __CUDACC__
 
+/*!
+ * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
+ *
+ * \param A the first operand of the gemm
+ * \param B the second operand of the gemm
+ * \param C the data to be assigned
+ * \param tA whether the `A` operand should be transposed first.
+ * \param tB whether the `B` operand should be transposed first.
+ * \param s the stream to perform the operation
+ * \param req the assignment request
+ */
+template<typename xpu, typename DType>
+inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
+                        const Tensor<xpu, 2, DType>& B,
+                        const Tensor<xpu, 2, DType>& C,
+                        bool tA, bool tB, Stream<xpu> *s,
+                        mxnet::OpReqType req) {
+  using namespace mxnet;
+  switch (req) {
+    case kNullOp:
+      break;
+    case kWriteTo:
+    case kWriteInplace:
+      linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
+      break;
+    case kAddTo:
+      linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
+      break;
+    default:
+      LOG(FATAL) << "not reached";
+  }
+}
+
+#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
+// A template for a cpu linalg_gemm implementation using mshadow::dot()
+#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
+template<> inline \
+void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+                             const Tensor<cpu, 2, DType>& B, \
+                             const Tensor<cpu, 2, DType>& C, \
+                             bool tA, bool tB, Stream<cpu> *s, \
+                             mxnet::OpReqType req) { \
+  using namespace mxnet; \
+  using mshadow::cpu; \
+  switch (req) { \
+    case kNullOp: \
+      break; \
+    case kWriteTo: \
+    case kWriteInplace: \
+      if (tA) { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
+        } \
+      } else { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
+        } \
+      } \
+      break; \
+    case kAddTo: \
+      if (tA) { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
+        } \
+      } else { \
+        if (tB) { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
+        } else { \
+          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
+        } \
+      } \
+      break; \
+    default: \
+      LOG(FATAL) << "not reached"; \
+  } \
+}
+
+LINALG_CPU_GEMM_NO_CBLAS(float)
+LINALG_CPU_GEMM_NO_CBLAS(double)
+
+#endif  // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
+
 //////////////////////////////// TRSM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trsm". Please refer to the BLAS3-documentation
@@ -313,95 +500,6 @@ LINALG_XPU_BATCH_TRSM(gpu, double)
 
 #endif  // __CUDACC__
 
-/*!
- * \brief Performs gemm, setting alpha and beta as appropriate for `req`.
- *
- * \param A the first operand of the gemm
- * \param B the second operand of the gemm
- * \param C the data to be assigned
- * \param tA whether the `A` operand should be transposed first.
- * \param tB whether the `B` operand should be transposed first.
- * \param s the stream to perform the operation
- * \param req the assignment request
- */
-template<typename xpu, typename DType>
-inline void linalg_gemm(const Tensor<xpu, 2, DType>& A,
-                        const Tensor<xpu, 2, DType>& B,
-                        const Tensor<xpu, 2, DType>& C,
-                        bool tA, bool tB, Stream<xpu> *s,
-                        mxnet::OpReqType req) {
-  using namespace mxnet;
-  switch (req) {
-    case kNullOp:
-      break;
-    case kWriteTo:
-    case kWriteInplace:
-      linalg_gemm(A, B, C, DType(1.0), DType(0.0), tA, tB, s);
-      break;
-    case kAddTo:
-      linalg_gemm(A, B, C, DType(1.0), DType(1.0), tA, tB, s);
-      break;
-    default:
-      LOG(FATAL) << "not reached";
-  }
-}
-
-#if (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
-// A template for a cpu linalg_gemm implementation using mshadow::dot()
-#define LINALG_CPU_GEMM_NO_CBLAS(DType) \
-template<> inline \
-void linalg_gemm<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
-                             const Tensor<cpu, 2, DType>& B, \
-                             const Tensor<cpu, 2, DType>& C, \
-                             bool tA, bool tB, Stream<cpu> *s, \
-                             mxnet::OpReqType req) { \
-  using namespace mxnet; \
-  using mshadow::cpu; \
-  switch (req) { \
-    case kNullOp: \
-      break; \
-    case kWriteTo: \
-    case kWriteInplace: \
-      if (tA) { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A.T(), B); \
-        } \
-      } else { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) = dot(A, B); \
-        } \
-      } \
-      break; \
-    case kAddTo: \
-      if (tA) { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A.T(), B); \
-        } \
-      } else { \
-        if (tB) { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B.T()); \
-        } else { \
-          const_cast<Tensor<cpu, 2, DType>&>(C) += dot(A, B); \
-        } \
-      } \
-      break; \
-    default: \
-      LOG(FATAL) << "not reached"; \
-  } \
-}
-
-LINALG_CPU_GEMM_NO_CBLAS(float)
-LINALG_CPU_GEMM_NO_CBLAS(double)
-
-#endif  // (MSHADOW_USE_CBLAS == 0 && MSHADOW_USE_MKL == 0)
-
 //////////////////////////////// TRMM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "trmm". Please refer to the BLAS3-documentation
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 7083efe2f1c..b1771650215 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -46,8 +46,20 @@ If *n=2*, the BLAS3 function *gemm* is performed:
 Here, *alpha* and *beta* are scalar parameters, and *op()* is either the identity or
 matrix transposition (depending on *transpose_a*, *transpose_b*).
 
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
+are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis* 
+parameter. By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
+calls. For example, let *A*, *B*, *C* be 5-dimensional tensors. Then gemm(*A*, *B*, *C*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = swapaxes(C, dim1=1, dim2=3)
+    C = gemm(A1, B1, C)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
 
 .. note:: The operator supports float32 and float64 data types only.
 
@@ -76,7 +88,7 @@ Examples::
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
   { return std::vector<std::pair<int, int>>{{2, 0}}; })
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 3, 1, gemm>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 3, 1, gemm>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm"})
 .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
 .add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -92,7 +104,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 4, 3, gemm_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 4, 3, gemm_backward>);
 
 NNVM_REGISTER_OP(_linalg_gemm2)
 .add_alias("linalg_gemm2")
@@ -107,8 +119,19 @@ If *n=2*, the BLAS3 function *gemm* is performed:
 Here *alpha* is a scalar parameter and *op()* is either the identity or the matrix
 transposition (depending on *transpose_a*, *transpose_b*).
 
-If *n>2*, *gemm* is performed separately on the trailing two dimensions for all inputs
-(batch mode).
+If *n>2*, *gemm* is performed separately for a batch of matrices. The column indices of the matrices
+are given by the last dimensions of the tensors, the row indices by the axis specified with the *axis* 
+parameter. By default, the trailing two dimensions will be used for matrix encoding.
+
+For a non-default axis parameter, the operation performed is equivalent to a series of swapaxes/gemm/swapaxes
+calls. For example, let *A*, *B* be 5-dimensional tensors. Then gemm2(*A*, *B*, axis=1) is equivalent to
+
+    A1 = swapaxes(A, dim1=1, dim2=3)
+    B1 = swapaxes(B, dim1=1, dim2=3)
+    C = gemm2(A1, B1)
+    C = swapaxes(C, dim1=1, dim2=3)
+
+without the overhead of the additional swapaxes operations.
 
 .. note:: The operator supports float32 and float64 data types only.
 
@@ -133,7 +156,7 @@ Examples::
   { return std::vector<std::string>{"A", "B"}; } )
 .set_attr<nnvm::FInferShape>("FInferShape", LaMatrixMultMacOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
-.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 2, 1, gemm2>)
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmForward<cpu, 2, 2, 2, 1, gemm2>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_linalg_gemm2"})
 .add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices")
 .add_argument("B", "NDArray-or-Symbol", "Tensor of input matrices")
@@ -148,7 +171,7 @@ NNVM_REGISTER_OP(_backward_linalg_gemm2)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 2, gemm2_backward>);
+.set_attr<FCompute>("FCompute<cpu>", LaOpGemmBackward<cpu, 2, 2, 3, 2, gemm2_backward>);
 
 NNVM_REGISTER_OP(_linalg_potrf)
 .add_alias("linalg_potrf")
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index efd705f3f44..d736845df2b 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -28,16 +28,16 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 3, 1, gemm>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 3, 1, gemm>);
 
 NNVM_REGISTER_OP(_backward_linalg_gemm)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 3, gemm_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 4, 3, gemm_backward>);
 
 NNVM_REGISTER_OP(_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, gemm2>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmForward<gpu, 2, 2, 2, 1, gemm2>);
 
 NNVM_REGISTER_OP(_backward_linalg_gemm2)
-.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 3, 2, gemm2_backward>);
+.set_attr<FCompute>("FCompute<gpu>", LaOpGemmBackward<gpu, 2, 2, 3, 2, gemm2_backward>);
 
 NNVM_REGISTER_OP(_linalg_trmm)
 .set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 2, 1, trmm>);
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 3d411b2d718..8e2acd747aa 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -40,6 +40,7 @@ namespace op {
 struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
   bool transpose_a, transpose_b;
   double alpha, beta;
+  int axis;
   DMLC_DECLARE_PARAMETER(LaMatrixMacParam) {
     DMLC_DECLARE_FIELD(transpose_a)
       .set_default(false)
@@ -53,6 +54,9 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
     DMLC_DECLARE_FIELD(beta)
       .set_default(1.0)
       .describe("Scalar factor multiplied with C.");
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(-2)
+      .describe("Axis corresponding to the matrix rows.");
   }
 };
 
@@ -60,6 +64,7 @@ struct LaMatrixMacParam : public dmlc::Parameter<LaMatrixMacParam> {
 struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
   bool transpose_a, transpose_b;
   double alpha;
+  int axis;
   DMLC_DECLARE_PARAMETER(LaMatrixMultParam) {
     DMLC_DECLARE_FIELD(transpose_a)
       .set_default(false)
@@ -70,6 +75,9 @@ struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
     DMLC_DECLARE_FIELD(alpha)
       .set_default(1.0)
       .describe("Scalar factor multiplied with A*B.");
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(-2)
+      .describe("Axis corresponding to the matrix row indices.");
   }
 };
 
@@ -112,30 +120,37 @@ inline bool LaMatrixMultMacOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_GE(in_attrs->size(), 2);
   CHECK_EQ(out_attrs->size(), 1);
   bool transpose_a(false), transpose_b(false);
+  int axis_param(-2);
   if ( in_attrs->size() == 2 ) {
      // Matrix-Matrix mult
      transpose_a = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_a;
      transpose_b = nnvm::get<LaMatrixMultParam>(attrs.parsed).transpose_b;
+     axis_param  = nnvm::get<LaMatrixMultParam>(attrs.parsed).axis;
   } else {
      // Matrix-Matrix mac
      transpose_a = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_a;
      transpose_b = nnvm::get<LaMatrixMacParam>(attrs.parsed).transpose_b;
+     axis_param  = nnvm::get<LaMatrixMacParam>(attrs.parsed).axis;
   }
   if ( (*in_attrs)[0].ndim() >= 2 && (*in_attrs)[0].ndim() == (*in_attrs)[1].ndim() ) {
     // Forward shape inference.
-    const int ndim((*in_attrs)[0].ndim());
+    const int ndim((*in_attrs)[0].ndim()), axis(axis_param < 0 ? ndim + axis_param : axis_param);
+    CHECK(axis >= 0 && axis < ndim-1)
+      << "Invalid row axis (" << axis_param << ")";
     std::vector<int> oshape(ndim);
-    for ( int i = 0; i < ndim-2; ++i ) {
-      // Both inputs must have same shape except for last two dimensions.
-      CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
-        << "Shapes of inputs 0, 1 must be the same, except on last two dimensions";
+    for ( int i = 0; i < ndim-1; ++i ) {
+      if (i != axis) {
+        // Both inputs must have same shape except for row/col dimensions.
+        CHECK_EQ((*in_attrs)[0][i], (*in_attrs)[1][i])
+          << "Shapes of inputs 0, 1 must be the same, except on row/col axis";
+      }
       oshape[i] = (*in_attrs)[0][i];
     }
-    CHECK_EQ((transpose_a ? (*in_attrs)[0][ndim-2] : (*in_attrs)[0][ndim-1]),
-             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][ndim-2]))
+    CHECK_EQ((transpose_a ? (*in_attrs)[0][axis] : (*in_attrs)[0][ndim-1]),
+             (transpose_b ? (*in_attrs)[1][ndim-1] : (*in_attrs)[1][axis]))
              << "Incompatible matrix dimensions for multiplication";
-    oshape[ndim-2] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][ndim-2]);
-    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][ndim-2] : (*in_attrs)[1][ndim-1]);
+    oshape[axis] = (transpose_a ? (*in_attrs)[0][ndim-1] : (*in_attrs)[0][axis]);
+    oshape[ndim-1] = (transpose_b ? (*in_attrs)[1][axis] : (*in_attrs)[1][ndim-1]);
     TShape tshape(oshape.begin(), oshape.end());
     SHAPE_ASSIGN_CHECK(*out_attrs, 0, tshape);
     if ( in_attrs->size() > 2 ) {
@@ -340,6 +355,33 @@ inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
   return false;
 }
 
+// Flattener for following adaptors.
+template<typename xpu, int dim, typename DType>
+mshadow::Tensor<xpu, dim, DType> LaOpFlatten(const TBlob& blob,
+                                             mshadow::Stream<xpu> *s, int axis = -2) {
+  if (axis < 0) {
+    axis = blob.ndim() + axis;
+  }
+  if (axis >= blob.ndim()-2) {
+    // Leave highest axis, collapse rest.
+    return blob.FlatToKD<xpu, dim, DType>(s);
+  }
+  // Collapse ranges [0,axis-1] and [axis+1,ndim-2].
+  CHECK_EQ(dim, 4);
+  TShape shape(dim);
+  shape[0] = 1;
+  for (int i = 0; i < axis; ++i) {
+    shape[0] *= blob.shape_[i];
+  }
+  shape[1] = blob.shape_[axis];
+  shape[2] = 1;
+  for (int i = axis+1; i < blob.ndim()-1; ++i) {
+    shape[2] *= blob.shape_[i];
+  }
+  shape[3] = blob.shape_[blob.ndim()-1];
+  return blob.get_with_shape<xpu, dim, DType>(shape.get<dim>(), s);
+}
+
 // Adapters for calling the various operators with appropriate signatures.
 
 template<typename xpu, typename DType, int idim, int odim, int inum, int onum, typename laop>
@@ -347,7 +389,7 @@ struct LaOpCaller {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     CHECK(false) << "no specialized LaOpCaller defined for template parameters";
   }
 };
@@ -356,10 +398,10 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -367,11 +409,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 1, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -379,11 +421,11 @@ struct LaOpCaller<xpu, DType, idim, odim, 2, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -391,12 +433,12 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -404,13 +446,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 3, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -418,13 +460,13 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 1, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -432,14 +474,14 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 2, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis), ctx, attrs);
   }
 };
 template<typename xpu, typename DType, int idim, int odim, typename laop>
@@ -447,15 +489,15 @@ struct LaOpCaller<xpu, DType, idim, odim, 4, 3, laop> {
   static void op(const std::vector<TBlob>& inputs,
                  const std::vector<TBlob>& outputs,
                  const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx) {
+                 const OpContext& ctx, int axis = -2) {
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    laop::op(inputs[0].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[1].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[2].FlatToKD<xpu, idim+1, DType>(s),
-             inputs[3].FlatToKD<xpu, idim+1, DType>(s),
-             outputs[0].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[1].FlatToKD<xpu, odim+1, DType>(s),
-             outputs[2].FlatToKD<xpu, odim+1, DType>(s), ctx, attrs);
+    laop::op(LaOpFlatten<xpu, idim+1, DType>(inputs[0], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[1], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[2], s, axis),
+             LaOpFlatten<xpu, idim+1, DType>(inputs[3], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[0], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[1], s, axis),
+             LaOpFlatten<xpu, odim+1, DType>(outputs[2], s, axis), ctx, attrs);
   }
 };
 
@@ -504,6 +546,64 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmForward(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,  // not read by this function
+                     const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), inum);
+  CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 2 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    if (axis == -2 || axis == inputs[0].ndim()-2) {  // axis is second-to-last: default path
+      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, outputs,
+                                                               attrs, ctx);
+    } else {  // any other axis: use one extra tensor dimension and forward the axis
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, outputs,
+                                                                   attrs, ctx, axis);
+    }
+  });
+}
+
+template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
+void LaOpGemmBackward(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  CHECK_EQ(inputs.size(), inum);
+  CHECK_EQ(outputs.size(), onum);
+  const int axis(inputs.size() == 3 ? nnvm::get<LaMatrixMultParam>(attrs.parsed).axis
+                                    : nnvm::get<LaMatrixMacParam>(attrs.parsed).axis);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    std::vector<TBlob> tspace(outputs);  // where laop writes; aliases outputs unless kAddTo
+    for ( int i = 0; i < onum; ++i ) {
+      if ( req[i] == kAddTo ) {  // redirect into scratch space so outputs[i] is not clobbered
+        tspace[i].dptr_ = ctx.requested[0]
+                             .get_space_typed<xpu, 1, OType>(Shape1(outputs[i].Size()), s).dptr_;
+      }
+    }
+    if (axis == -2 || axis == inputs[0].ndim()-2) {  // axis is second-to-last: default path
+      LaOpCaller<xpu, OType, idim, odim, inum, onum, laop>::op(inputs, tspace,
+                                                               attrs, ctx);
+    } else {  // any other axis: use one extra tensor dimension and forward the axis
+      LaOpCaller<xpu, OType, idim+1, odim+1, inum, onum, laop>::op(inputs, tspace,
+                                                                   attrs, ctx, axis);
+    }
+    for ( int i = 0; i < onum; ++i ) {
+      if ( req[i] == kAddTo ) {  // accumulate the scratch result into the existing output
+        Tensor<xpu, 1, OType> out = outputs[i].FlatTo1D<xpu, OType>(s);
+        out += tspace[i].FlatTo1D<xpu, OType>(s);
+      }
+    }
+  });
+}
+
 // Specific wrapper for syevd (cannot use the default ones, because A, U have
 // different dimensionality than L
 
diff --git a/src/operator/tensor/la_op_inline.h b/src/operator/tensor/la_op_inline.h
index a508eb77364..b483108970a 100644
--- a/src/operator/tensor/la_op_inline.h
+++ b/src/operator/tensor/la_op_inline.h
@@ -60,24 +60,24 @@ struct Scale {
 
 // D = gemm(A,B,C)
 struct gemm {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, DType alpha, DType beta,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, DType alpha, DType beta,
                  bool tA, bool tB, Stream<xpu> *s) {
     linalg_batch_gemm(A, B, C, alpha, beta, tA, tB, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
                  Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
     if ( C.dptr_ != D.dptr_ ) Copy(D, C, s);
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
     op(A, B, D, DType(param.alpha), DType(param.beta), param.transpose_a,
        param.transpose_b, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const Tensor<xpu, 3, DType>& D,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const Tensor<xpu, dim, DType>& D,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(A, B, C, D, s, attrs);
@@ -86,17 +86,17 @@ struct gemm {
 
 // C = gemm2(A,B)
 struct gemm2 {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, Stream<xpu> *s,
                  const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
     gemm::op(A, B, C, DType(param.alpha), DType(0), param.transpose_a,
              param.transpose_b, s);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& A, const Tensor<xpu, dim, DType>& B,
+                 const Tensor<xpu, dim, DType>& C, const OpContext& ctx,
                  const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(A, B, C, s, attrs);
@@ -343,11 +343,11 @@ struct syevd {
 // Backward operators (always using batch processing)
 
 struct gemm_backward {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMacParam& param = nnvm::get<LaMatrixMacParam>(attrs.parsed);
     bool tA(param.transpose_a), tB(param.transpose_b);
@@ -359,11 +359,11 @@ struct gemm_backward {
     using namespace mxnet_op;
     Kernel<Scale, xpu>::Launch(s, dC.MSize(), DType(param.beta), dC.dptr_);
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dD, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
-                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
-                 const Tensor<xpu, 3, DType>& dC,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dD, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& C,
+                 const Tensor<xpu, dim, DType>& dA, const Tensor<xpu, dim, DType>& dB,
+                 const Tensor<xpu, dim, DType>& dC,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(dD, A, B, C, dA, dB, dC, s, attrs);
@@ -371,10 +371,10 @@ struct gemm_backward {
 };
 
 struct gemm2_backward {
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
     const LaMatrixMultParam& param = nnvm::get<LaMatrixMultParam>(attrs.parsed);
     bool tA(param.transpose_a), tB(param.transpose_b);
@@ -383,10 +383,10 @@ struct gemm2_backward {
     (tB ? gemm::op(dC, A, dB, DType(param.alpha), DType(0), true, tA, s)
         : gemm::op(A, dC, dB, DType(param.alpha), DType(0), !tA, false, s));
   }
-  template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
-                 const Tensor<xpu, 3, DType>& dB,
+  template<typename xpu, int dim, typename DType>
+  static void op(const Tensor<xpu, dim, DType>& dC, const Tensor<xpu, dim, DType>& A,
+                 const Tensor<xpu, dim, DType>& B, const Tensor<xpu, dim, DType>& dA,
+                 const Tensor<xpu, dim, DType>& dB,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     op(dC, A, B, dA, dB, s, attrs);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1db836b0918..fc4f6b36aef 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4596,7 +4596,24 @@ def test_laop():
     check_fw(test_gemm, [a, b, c], [r])
     if grad_check == 1:
         check_grad(test_gemm, [a, b, c])
-
+    # Check for different axis that describes matrix rows.
+    a2 = np.copy(np.swapaxes(a, 0, 2))
+    b2 = np.copy(np.swapaxes(b, 0, 2))
+    c2 = np.copy(np.swapaxes(c, 0, 2))
+    r2 = np.copy(np.swapaxes(r, 0, 2))
+    test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = 0)
+    check_fw(test_gemm, [a2, b2, c2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2, c2])
+    a2 = np.copy(np.swapaxes(a, 1, 2))
+    b2 = np.copy(np.swapaxes(b, 1, 2))
+    c2 = np.copy(np.swapaxes(c, 1, 2))
+    r2 = np.copy(np.swapaxes(r, 1, 2))
+    test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7., axis = -3)
+    check_fw(test_gemm, [a2, b2, c2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2, c2])
+    
     # Check gemm2 operator same way as gemm.
     res_gemm = 4. * np.dot(data_in1, data_in2)
     test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.)
@@ -4628,6 +4645,20 @@ def test_laop():
     check_fw(test_gemm, [a, b], [r])
     if grad_check == 1:
         check_grad(test_gemm, [a, b])
+    a2 = np.copy(np.swapaxes(a, 0, 2))
+    b2 = np.copy(np.swapaxes(b, 0, 2))
+    r2 = np.copy(np.swapaxes(r, 0, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = 0)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])
+    a2 = np.copy(np.swapaxes(a, 1, 2))
+    b2 = np.copy(np.swapaxes(b, 1, 2))
+    r2 = np.copy(np.swapaxes(r, 1, 2))
+    test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., axis = -3)
+    check_fw(test_gemm, [a2, b2], [r2])
+    if grad_check == 1:
+        check_grad(test_gemm, [a2, b2])
 
     # Now test all the other operators.
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services