Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/13 17:11:05 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12904: support upper triangular matrices in linalg

sandeep-krishnamurthy closed pull request #12904: support upper triangular matrices in linalg
URL: https://github.com/apache/incubator-mxnet/pull/12904

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h
index b3353e285c7..e89a0824a94 100644
--- a/src/operator/tensor/la_op-inl.h
+++ b/src/operator/tensor/la_op-inl.h
@@ -21,6 +21,7 @@
  * Copyright (c) 2017 by Contributors
  * \file la_op-inl.h
  * \brief Operators for advanced linear algebra.
+ * \note  See https://arxiv.org/pdf/1710.08717.pdf for details of gradient computations.
  */
 #ifndef MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
 #define MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
@@ -32,20 +33,29 @@ namespace op {
 
 using namespace mshadow;
 
-// Helper functions.
-struct CopyLowerToUpper {
+// Copies lower/upper triangular part to upper/lower, i.e. to the opposite side.
+struct CopyTriangularToOppositeSide {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
+  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) {
     // Below computation works even when we are dealing with a batch of matrices.
     const int row((i % matrix_size) / stride), col(i % stride);
-    if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i];
+    if (row > col) {
+       if (to_lower) {
+         data[i] = data[i + (col - row) * (stride - 1)];
+       } else {
+         data[i + (col - row) * (stride - 1)] = data[i];
+       }
+    }
   }
 };
-struct ZeroUpper {
+
+// Zeroes the lower/upper triangular part of a matrix.
+struct ZeroTriangular {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data) {
+  MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data,
+                                  bool zero_lower) {
     const int row((i % matrix_size) / stride), col(i % stride);
-    if ( row < col ) data[i] = 0;
+    if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0;
   }
 };
 struct Scale {
@@ -103,87 +113,91 @@ struct gemm2 {
   }
 };
 
-// L = potrf(A).
+// B = potrf(A).
 struct potrf {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
                  Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != L.dptr_ ) Copy(L, A, s);
-    linalg_batch_potrf(L, true, s);
+    const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+    if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+    linalg_batch_potrf(B, param.lower, s);
     using namespace mxnet_op;
-    Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_, L.stride_, L.dptr_);
+    Kernel<ZeroTriangular, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_, B.stride_,
+                                   B.dptr_, !param.lower);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& L,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(A, L, s, attrs);
+    op(A, B, s, attrs);
   }
 };
 
-// A = potri(L).
+// A = potri(B).
 struct potri {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
+  static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
                  Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != L.dptr_ ) Copy(A, L, s);
-    linalg_batch_potri(A, true, s);
+    const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+    if ( A.dptr_ != B.dptr_ ) Copy(A, B, s);
+    linalg_batch_potri(A, param.lower, s);
     using namespace mxnet_op;
-    Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_, A.dptr_);
+    Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_, A.stride_,
+                                          A.dptr_, !param.lower);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
+  static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& A,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(L, A, s, attrs);
+    op(B, A, s, attrs);
   }
 };
 
-// B = trsm(L,A)
+// C = trsm(A,B)
 struct trsm {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
-                 DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
-    linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s);
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
+    linalg_batch_trsm(A, C, alpha, rightside, lower, transpose, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C,
                  Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+    if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
-    op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
+    op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(L, A, B, s, attrs);
+    op(A, B, C, s, attrs);
   }
 };
 
-// B = trmm(L,A)
+// C = trmm(A,B)
 struct trmm {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& B,
-                 DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
-    linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s);
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& C,
+                 DType alpha, bool rightside, bool lower, bool transpose, Stream<xpu> *s) {
+    linalg_batch_trmm(A, C, alpha, rightside, lower, transpose, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, Stream<xpu> *s,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
                  const nnvm::NodeAttrs& attrs) {
-    if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+    if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
-    op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
+    op(A, C, DType(param.alpha), param.rightside, param.lower, param.transpose, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>& A,
-                 const Tensor<xpu, 3, DType>& B, const OpContext& ctx,
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
                  const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(L, A, B, s, attrs);
+    op(A, B, C, s, attrs);
   }
 };
 
@@ -223,8 +237,8 @@ struct syrk {
     linalg_batch_syrk(A, B, alpha, beta, tA, s);
     // Symmetric B is in lower triangle: Copy to upper
     using namespace mxnet_op;
-    Kernel<CopyLowerToUpper, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
-                                          B.stride_, B.dptr_);
+    Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
+                                          B.stride_, B.dptr_, false);
   }
   template<typename xpu, typename DType>
   static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
@@ -276,8 +290,8 @@ struct gelqf {
       Tensor<xpu, 2, DType> QLeft(Qi.dptr_, Shape2(m, m), Qi.stride_, s);
       Copy(Li, QLeft, s);
       using namespace mxnet_op;
-      Kernel<ZeroUpper, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
-                                     Li.dptr_);
+      Kernel<ZeroTriangular, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
+                                     Li.dptr_, false);
       // Call orglq: Input is Qi and part of work. Overwrites Qi by final Q
       // matrix (conversion from internal representation)
       linalg_orglq(Qi, work, s);
@@ -395,117 +409,129 @@ struct gemm2_backward {
 
 struct potrf_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
+  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
                  const Tensor<xpu, 3, DType>& dA,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
-    // Backward of L = potrf(A).
-    //   dA = 0.5 * L**T * copyLTU(L**T * dL) * L**(-1)
+    // Backward of B = potrf(A).
+    //   dA = 0.5 * B**T * copyLTU(B**T * dB) * B**(-1)
     // Here, copyLTU(M) creates a symmetric matrix from the square matrix M
     // by setting the upper triangle to be equal to the lower triangle, leaving
     // lower triangle and diagonal unchanged.
-    if ( dL.dptr_ != dA.dptr_ ) {
-      Copy(dA, dL, s);
+    // The function also handles the case when B is upper triangular by appropriate
+    // transpositions.
+    const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+    if ( dB.dptr_ != dA.dptr_ ) {
+      Copy(dA, dB, s);
     }
-    trmm::op(L, dA, DType(1.0), false, true, s);
+    trmm::op(B, dA, DType(1.0), !param.lower, param.lower, true, s);
     using namespace mxnet_op;
-    Kernel<CopyLowerToUpper, xpu>::Launch
-           (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_);
-    trsm::op(L, dA, DType(1.0), false, true, s);
-    trsm::op(L, dA, DType(0.5), true, false, s);
+    Kernel<CopyTriangularToOppositeSide, xpu>::Launch
+           (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_, !param.lower);
+    trsm::op(B, dA, DType(1.0), false, param.lower, param.lower, s);
+    trsm::op(B, dA, DType(0.5), true, param.lower, !param.lower, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& L,
+  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& B,
                  const Tensor<xpu, 3, DType>& dA,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(dL, L, dA, s, attrs);
+    op(dB, B, dA, s, attrs);
   }
 };
 
 struct potri_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
+  static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
-    // Backward of A = potri(L).
-    // dL = -tril( A * (dA + dA**T) * L**(-T)), where tril() extracts lower triangle
+    // Backward of A = potri(B).
+    // dB = -tril( A * (dA + dA**T) * B**(-T)), where tril() extracts lower triangle
     // and diagonal. We must not assume that dA is symmetric.
+    // The function also handles the case when B is upper triangular by appropriate
+    // transpositions.
     // Note: Calling gemm twice here is a bit wasteful, but otherwise the symmetrization
     // of dA would require temporary memory.
-    gemm::op(A, dA, dL, DType(1.), DType(0.), false, false, s);
-    gemm::op(A, dA, dL, DType(1.), DType(1.), false, true, s);
-    trsm::op(L, dL, DType(-1.), true, true, s);
+    const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+    if (param.lower) {
+      gemm::op(A, dA, dB, DType(1.), DType(0.), false, false, s);
+      gemm::op(A, dA, dB, DType(1.), DType(1.), false, true, s);
+    } else {
+      gemm::op(dA, A, dB, DType(1.), DType(0.), false, false, s);
+      gemm::op(dA, A, dB, DType(1.), DType(1.), true, false, s);
+    }
+    trsm::op(B, dB, DType(-1.), param.lower, param.lower, true, s);
     using namespace mxnet_op;
-    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
-                                   dL.dptr_);
+    Kernel<ZeroTriangular, xpu>::Launch(s, dB.MSize(), dB.size(1)*dB.stride_, dB.stride_,
+                                   dB.dptr_, !param.lower);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
+  static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& B,
+                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dB,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(dA, L, A, dL, s, attrs);
+    op(dA, B, A, dB, s, attrs);
   }
 };
 
 struct trsm_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
+  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
+                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
                  Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
-    // Backward of B = trsm(L,A).
+    // Backward of C = trsm(A,B).
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
+    // Compute dB
+    if ( dB.dptr_ != dC.dptr_ ) Copy(dB, dC, s);
+    trsm::op(A, dB, DType(param.alpha), param.rightside, param.lower, !param.transpose, s);
     // Compute dA
-    if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
-    trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s);
-    // Compute dL
     const bool da_left(param.rightside == param.transpose);
     DType scale(-1.0/param.alpha);
-    (da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose, !param.transpose, s)
-             : gemm::op(B, dA, dL, scale, DType(0), !param.transpose, param.transpose, s));
+    (da_left ? gemm::op(dB, C, dA, scale, DType(0), param.transpose, !param.transpose, s)
+             : gemm::op(C, dB, dA, scale, DType(0), !param.transpose, param.transpose, s));
     using namespace mxnet_op;
-    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_, dL.dptr_);
+    Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
+                                   dA.dptr_, !param.lower);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& B,
-                 const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>& dA,
+  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& C,
+                 const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>& dB,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(dB, L, A, B, dL, dA, s, attrs);
+    op(dC, A, B, C, dA, dB, s, attrs);
   }
 };
 
 struct trmm_backward {
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
-                 const Tensor<xpu, 3, DType>& dA, Stream<xpu>* s,
+  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
+                 const Tensor<xpu, 3, DType>& dB, Stream<xpu>* s,
                  const nnvm::NodeAttrs& attrs) {
-    // Backward of B = trmm(L,A).
+    // Backward of C = trmm(A,B).
     const LaTriangMatrixMultParam& param = nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
-    // Compute dL
+    // Compute dA
     DType scale(param.alpha);
     if (param.rightside == param.transpose) {
-      gemm::op(dB, A, dL, scale, DType(0.), param.transpose, !param.transpose, s);
+      gemm::op(dC, B, dA, scale, DType(0.), param.transpose, !param.transpose, s);
     } else {
-      gemm::op(A, dB, dL, scale, DType(0.), !param.transpose, param.transpose, s);
+      gemm::op(B, dC, dA, scale, DType(0.), !param.transpose, param.transpose, s);
     }
     using namespace mxnet_op;
-    Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_, dL.stride_,
-                                   dL.dptr_);
-    // Compute dA
-    if (dA.dptr_ != dB.dptr_) Copy(dA, dB, s);
-    trmm::op(L, dA, scale, param.rightside, !param.transpose, s);
+    Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_,
+                                   dA.dptr_, !param.lower);
+    // Compute dB
+    if (dB.dptr_ != dC.dptr_) Copy(dB, dC, s);
+    trmm::op(A, dB, scale, param.rightside, param.lower, !param.transpose, s);
   }
   template<typename xpu, typename DType>
-  static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>& L,
-                 const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>& dL,
-                 const Tensor<xpu, 3, DType>& dA, const OpContext& ctx,
+  static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>& A,
+                 const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& dA,
+                 const Tensor<xpu, 3, DType>& dB, const OpContext& ctx,
                  const nnvm::NodeAttrs& attrs) {
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    op(dB, L, A, dL, dA, s, attrs);
+    op(dC, A, B, dA, dB, s, attrs);
   }
 };
 
@@ -586,13 +612,13 @@ struct gelqf_backward {
     Tensor<xpu, 3, DType> tempM = ctx.requested[0]
       .get_space_typed<xpu, 3, DType>(dL.shape_, s);
     Copy(tempM, dL, s);
-    trmm::op(L, tempM, DType(1.0), false, true, s);
+    trmm::op(L, tempM, DType(1.0), false, true, true, s);
     gemm::op(dA, Q, tempM, DType(-1.0), DType(1.0), false, true, s);
-    Kernel<CopyLowerToUpper, xpu>::Launch
+    Kernel<CopyTriangularToOppositeSide, xpu>::Launch
            (s, tempM.MSize(), tempM.size(1)*tempM.stride_, tempM.stride_,
-            tempM.dptr_);
+            tempM.dptr_, false);
     gemm::op(tempM, Q, dA, DType(1.0), DType(1.0), false, false, s);
-    trsm::op(L, dA, DType(1.0), false, true, s);
+    trsm::op(L, dA, DType(1.0), false, true, true, s);
   }
 };
 
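For readers skimming the la_op-inl.h changes above: the two generalized kernels are simply triangle mirroring and triangle zeroing with a selectable side. Here is a minimal NumPy sketch (not part of the patch; the helper names are made up for illustration) of what they compute on a single matrix:

    import numpy as np

    def copy_triangular_to_opposite_side(a, to_lower):
        # Mirror one triangle onto the other; the diagonal is left unchanged.
        out = a.copy()
        if to_lower:
            idx = np.tril_indices_from(out, k=-1)   # strictly lower positions
        else:
            idx = np.triu_indices_from(out, k=1)    # strictly upper positions
        out[idx] = out.T[idx]                       # pull values from the mirrored side
        return out

    def zero_triangular(a, zero_lower):
        # Zero the strictly lower (zero_lower=True) or strictly upper triangle.
        return np.triu(a) if zero_lower else np.tril(a)

The flat-index arithmetic in the C++ kernels (row = (i % matrix_size) / stride, col = i % stride, mirrored index i + (col - row) * (stride - 1)) is the batched, in-place equivalent of the mirroring above.
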
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 91bcdd314d8..f8a130d0ce4 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -30,6 +30,7 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(LaMatrixMacParam);
 DMLC_REGISTER_PARAMETER(LaMatrixMultParam);
+DMLC_REGISTER_PARAMETER(LaCholeskyParam);
 DMLC_REGISTER_PARAMETER(LaTriangMatrixMultParam);
 DMLC_REGISTER_PARAMETER(LaSyrkParam);
 
@@ -178,11 +179,12 @@ NNVM_REGISTER_OP(_linalg_potrf)
 .describe(R"code(Performs Cholesky factorization of a symmetric positive-definite matrix.
 Input is a tensor *A* of dimension *n >= 2*.
 
-If *n=2*, the Cholesky factor *L* of the symmetric, positive definite matrix *A* is
-computed. *L* is lower triangular (entries of upper triangle are all zero), has
+If *n=2*, the Cholesky factor *B* of the symmetric, positive definite matrix *A* is
+computed. *B* is triangular (entries of upper or lower triangle are all zero), has
 positive diagonal entries, and:
 
-  *A* = *L* \* *L*\ :sup:`T`
+  *A* = *B* \* *B*\ :sup:`T`  if *lower* = *true*
+  *A* = *B*\ :sup:`T` \* *B*  if *lower* = *false*
 
 If *n>2*, *potrf* is performed separately on the trailing two dimensions for all inputs
 (batch mode).
@@ -201,6 +203,7 @@ Examples::
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
   { return std::vector<std::string>{"A"}; } )
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
@@ -214,6 +217,7 @@ Examples::
 NNVM_REGISTER_OP(_backward_linalg_potrf)
 .set_num_inputs(2)
 .set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
   { return std::vector<std::pair<int, int> >{{0, 0}}; })
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
@@ -227,10 +231,11 @@ NNVM_REGISTER_OP(_linalg_potri)
 .describe(R"code(Performs matrix inversion from a Cholesky factorization.
 Input is a tensor *A* of dimension *n >= 2*.
 
-If *n=2*, *A* is a lower triangular matrix (entries of upper triangle are all zero)
+If *n=2*, *A* is a triangular matrix (entries of upper or lower triangle are all zero)
 with positive diagonal. We compute:
 
-  *out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1`
+  *out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1` if *lower* = *true*
+  *out* = *A*\ :sup:`-1` \* *A*\ :sup:`-T` if *lower* = *false*
 
 In other words, if *A* is the Cholesky factor of a symmetric positive definite matrix
 *B* (obtained by *potrf*), then
@@ -259,6 +264,7 @@ Examples::
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
   { return std::vector<std::string>{"A"}; } )
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
@@ -272,6 +278,7 @@ Examples::
 NNVM_REGISTER_OP(_backward_linalg_potri)
 .set_num_inputs(3)
 .set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
   { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
@@ -283,7 +290,7 @@ NNVM_REGISTER_OP(_linalg_trmm)
 Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape
 on the leading *n-2* dimensions.
 
-If *n=2*, *A* must be lower triangular. The operator performs the BLAS3 function
+If *n=2*, *A* must be triangular. The operator performs the BLAS3 function
 *trmm*:
 
    *out* = *alpha* \* *op*\ (*A*) \* *B*
@@ -346,7 +353,7 @@ NNVM_REGISTER_OP(_linalg_trsm)
 Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same shape
 on the leading *n-2* dimensions.
 
-If *n=2*, *A* must be lower triangular. The operator performs the BLAS3 function
+If *n=2*, *A* must be triangular. The operator performs the BLAS3 function
 *trsm*, solving for *out* in:
 
    *op*\ (*A*) \* *out* = *alpha* \* *B*
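
The documentation changes in la_op.cc come down to: with lower=False, potrf returns the upper Cholesky factor (A = B**T * B) and potri computes the inverse from such an upper factor. A small front-end sketch of that behaviour, assuming the usual mx.nd interface (illustrative only, not part of the patch):

    import mxnet as mx

    A = mx.nd.array([[4., 2.],
                     [2., 3.]])                      # symmetric positive definite
    U = mx.nd.linalg.potrf(A, lower=False)           # upper triangular factor
    print(mx.nd.dot(U, U, transpose_a=True))         # U**T * U reproduces A
    A_inv = mx.nd.linalg.potri(U, lower=False)       # inverse of A from its upper factor
    print(mx.nd.dot(A, A_inv))                       # approximately the identity
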
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 433789cf66b..0327dd19b72 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -81,10 +81,22 @@ struct LaMatrixMultParam : public dmlc::Parameter<LaMatrixMultParam> {
   }
 };
 
+// Parameters for Cholesky factorization and matrix inversion
+struct LaCholeskyParam : public dmlc::Parameter<LaCholeskyParam> {
+  bool lower;
+  DMLC_DECLARE_PARAMETER(LaCholeskyParam) {
+    DMLC_DECLARE_FIELD(lower)
+      .set_default(true)
+      .describe
+         ("True if the triangular matrix is lower triangular, false if it is upper triangular.");
+  }
+};
+
 // Parameters for matrix-matrix multiplication where one is a triangular matrix.
 struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam> {
   bool transpose;
   bool rightside;
+  bool lower;
   double alpha;
   DMLC_DECLARE_PARAMETER(LaTriangMatrixMultParam) {
     DMLC_DECLARE_FIELD(transpose)
@@ -93,6 +105,10 @@ struct LaTriangMatrixMultParam : public dmlc::Parameter<LaTriangMatrixMultParam>
     DMLC_DECLARE_FIELD(rightside)
       .set_default(false)
       .describe("Multiply triangular matrix from the right to non-triangular one.");
+    DMLC_DECLARE_FIELD(lower)
+      .set_default(true)
+      .describe
+         ("True if the triangular matrix is lower triangular, false if it is upper triangular.");
     DMLC_DECLARE_FIELD(alpha)
       .set_default(1.0)
       .describe("Scalar factor to be applied to the result.");
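
The new lower field on LaTriangMatrixMultParam (and the LaCholeskyParam struct above) is what exposes the upper-triangular variants to the front end. A quick trmm/trsm round trip with an upper triangular matrix, assuming the mx.nd interface (illustrative only, not part of the patch):

    import mxnet as mx
    import numpy as np

    U = mx.nd.array(np.triu(np.random.uniform(1., 2., (4, 4))))  # upper triangular A
    B = mx.nd.random.uniform(shape=(4, 4))
    C = mx.nd.linalg.trmm(U, B, lower=False)                     # C = U * B
    X = mx.nd.linalg.trsm(U, C, lower=False)                     # solves U * X = C
    print(np.allclose(X.asnumpy(), B.asnumpy(), rtol=1e-4, atol=1e-5))  # X recovers B
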
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 27d75d132fd..ec7496a3606 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5228,7 +5228,7 @@ def _make_symm_symbol(a, ndims):
     tr_shape = tuple(tr_shape)
     return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))
 
-def _make_lower_triangle_symm(a, ndims, m, dtype=np.float32):
+def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
     assert ndims >= 2
     # The last two dimensions must both be m
     # Create mask for lower triangle and diagonal
@@ -5239,6 +5239,9 @@ def _make_lower_triangle_symm(a, ndims, m, dtype=np.float32):
         index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
         part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
         lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
+    if not lower:
+        lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
+        lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
     shp = tuple([1]*(ndims-2) + [m, m])
     lt_mask = mx.sym.reshape(lt_mask, shape=shp)
     return mx.sym.broadcast_mul(a, lt_mask)
@@ -5380,141 +5383,147 @@ def test_laop():
         check_grad(test_gemm, [a2, b2])
 
     # Now test all the other operators.
+    for lower in [True, False]:
+        upper = not lower
+
+        # Tests with trivial 1x1 matrices.
+        shape = (4, 4, 1, 1)
+        data_in = np.random.uniform(1, 10, shape)
+        # test potrf
+        # Note: Have to symmetrize input, for gradient test to work
+        res_potrf = np.sqrt(data_in)
+        test_potrf = mx.sym.linalg.potrf(data1, lower=lower)
+        check_fw(test_potrf, [data_in], [res_potrf])
+        if grad_check == 1:
+            check_grad(test_potrf, [data_in])
+        # test potri
+        ones = mx.nd.ones(shape).asnumpy()
+        res_potri = np.divide(ones, data_in * data_in)
+        test_potri = mx.sym.linalg.potri(data1, lower=lower)
+        check_fw(test_potri, [data_in], [res_potri])
+        if grad_check == 1:
+            check_grad(test_potri, [data_in])
+        # test trsm
+        trian_in = data_in * 7.
+        test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7., lower=lower)
+        check_fw(test_trsm, [trian_in, data_in], [ones])
+        if grad_check == 1:
+            check_grad(test_trsm, [trian_in,data_in])
+        # test trmm
+        trian_in = np.divide(ones, trian_in)
+        test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True,
+                                       rightside=True, lower=lower)
+        check_fw(test_trmm, [trian_in, data_in], [ones])
+        if grad_check == 1:
+            check_grad(test_trmm, [trian_in, data_in])
+        # test sumlogdiag
+        res_sumlogdiag = np.reshape(np.log(data_in), (4, 4))
+        test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1)
+        check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag])
+        if grad_check == 1:
+            check_grad(test_sumlogdiag, [data_in])
+
+        # more elaborate example of Cholesky factorization
+        matrix = np.array([[9., 3., -6., 12.],
+                           [3., 26., -7., -11.],
+                           [-6., -7., 9., 7.],
+                           [12., -11., 7., 65.]])
+        trian  = np.array([[3., 0., 0., 0.],
+                           [1., 5., 0., 0.],
+                           [-2., -1., 2., 0.],
+                           [4., -3., 6., 2.]])
+        pow    = np.array([[2., 1., 1., 1.],
+                           [1., 4., 1., 1.],
+                           [1., 1., 8., 1.],
+                           [1., 1., 1., 16.]])
+        inv    = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.],
+                           [0.05/3., 0.05, 0.05, 0.],
+                           [2.65, 0.05, 2.5, -0.75],
+                           [-2.5/3., 0., -0.75, 0.25]])
+        ident  = np.eye(4)
+
+        low_trian = trian
+        if not lower:
+            trian = np.transpose(trian)
+
+        # test potrf
+        test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4), lower=lower)
+        a = rep_3x(matrix, 4, 4)
+        r = rep_3x(trian, 4, 4)
+        check_fw(test_potrf, [a], [r])
+        if grad_check == 1:
+            check_grad(test_potrf, [a])
+
+        #test potri
+        data1_ltri = _make_triangle_symm(
+            data1, ndims=4, m=4, lower=lower, dtype=dtype)
+        test_potri = mx.sym.linalg.potri(data1_ltri, lower=lower)
+        a = rep_3x(trian, 4, 4)
+        r = rep_3x(inv, 4, 4)
+        check_fw(test_potri, [a], [r])
+        if grad_check == 1:
+            check_grad(test_potri, [a])
+
+        # test trsm
+        test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7., transpose=upper, lower=lower)
+        a = rep_3x(trian, 4, 4)
+        b = rep_3x(matrix, 4, 4)
+        r = rep_3x(7. * np.transpose(low_trian), 4, 4)
+        check_fw(test_trsm, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trsm, [a, b])
 
-    # Tests with trivial 1x1 matrices.
-    shape = (4, 4, 1, 1)
-    data_in = np.random.uniform(1, 10, shape)
-    # test potrf
-    # Note: Have to symmetrize input, for gradient test to work
-    res_potrf = np.sqrt(data_in)
-    test_potrf = mx.sym.linalg.potrf(data1)
-    check_fw(test_potrf, [data_in], [res_potrf])
-    if grad_check == 1:
-        check_grad(test_potrf, [data_in])
-    # test potri
-    ones = mx.nd.ones(shape).asnumpy()
-    res_potri = np.divide(ones, data_in * data_in)
-    test_potri = mx.sym.linalg.potri(data1)
-    check_fw(test_potri, [data_in], [res_potri])
-    if grad_check == 1:
-        check_grad(test_potri, [data_in])
-    # test trsm
-    trian_in = data_in * 7.
-    test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7.)
-    check_fw(test_trsm, [trian_in, data_in], [ones])
-    if grad_check == 1:
-        check_grad(test_trsm, [trian_in,data_in])
-    # test trmm
-    trian_in = np.divide(ones, trian_in)
-    test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True,
-                                   rightside=True)
-    check_fw(test_trmm, [trian_in, data_in], [ones])
-    if grad_check == 1:
-        check_grad(test_trmm, [trian_in, data_in])
-    # test sumlogdiag
-    res_sumlogdiag = np.reshape(np.log(data_in), (4, 4))
-    test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1)
-    check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag])
-    if grad_check == 1:
-        check_grad(test_sumlogdiag, [data_in])
-
-    # more elaborate example of Cholesky factorization
-    matrix = np.array([[9., 3., -6., 12.],
-                       [3., 26., -7., -11.],
-                       [-6., -7., 9., 7.],
-                       [12., -11., 7., 65.]])
-    trian  = np.array([[3., 0., 0., 0.],
-                       [1., 5., 0., 0.],
-                       [-2., -1., 2., 0.],
-                       [4., -3., 6., 2.]])
-    pow    = np.array([[2., 1., 1., 1.],
-                       [1., 4., 1., 1.],
-                       [1., 1., 8., 1.],
-                       [1., 1., 1., 16.]])
-    inv    = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.],
-                       [0.05/3., 0.05, 0.05, 0.],
-                       [2.65, 0.05, 2.5, -0.75],
-                       [-2.5/3., 0., -0.75, 0.25]])
-    ident  = np.eye(4)
-
-    # test potrf
-    test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4))
-    a = rep_3x(matrix, 4, 4)
-    r = rep_3x(trian, 4, 4)
-    check_fw(test_potrf, [a], [r])
-    if grad_check == 1:
-        check_grad(test_potrf, [a])
-
-    #test potri
-    data1_ltri = _make_lower_triangle_symm(
-        data1, ndims=4, m=4, dtype=dtype)
-    test_potri = mx.sym.linalg.potri(data1_ltri)
-    a = rep_3x(trian, 4, 4)
-    r = rep_3x(inv, 4, 4)
-    check_fw(test_potri, [a], [r])
-    if grad_check == 1:
-        check_grad(test_potri, [a])
-
-    # test trsm
-    test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7.)
-    a = rep_3x(trian, 4, 4)
-    b = rep_3x(matrix, 4, 4)
-    r = rep_3x(7. * np.transpose(trian), 4, 4)
-    check_fw(test_trsm, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trsm, [a, b])
-
-    test_trsm2 = mx.sym.linalg.trsm(
-        data1_ltri, data2, alpha=-2., rightside=True, transpose=True)
-    r = rep_3x(-2. * trian, 4, 4)
-    check_fw(test_trsm2, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trsm2, [a, b])
+        test_trsm2 = mx.sym.linalg.trsm(
+            data1_ltri, data2, alpha=-2., rightside=True, transpose=lower, lower=lower)
+        r = rep_3x(-2. * low_trian, 4, 4)
+        check_fw(test_trsm2, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trsm2, [a, b])
 
-    test_trsm3 = mx.sym.linalg.trsm(
-        data1_ltri, data2, alpha=0.5, transpose=True)
-    b = rep_3x(np.transpose(trian), 4, 4)
-    r = rep_3x(0.5 * ident, 4, 4)
-    check_fw(test_trsm3, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trsm3, [a, b])
+        test_trsm3 = mx.sym.linalg.trsm(
+            data1_ltri, data2, alpha=0.5, transpose=lower, lower=lower)
+        b = rep_3x(np.transpose(low_trian), 4, 4)
+        r = rep_3x(0.5 * ident, 4, 4)
+        check_fw(test_trsm3, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trsm3, [a, b])
 
-    test_trsm4 = mx.sym.linalg.trsm(
-        data1_ltri, data2, alpha=-0.5, rightside=True)
-    b = rep_3x(trian, 4, 4)
-    r = rep_3x(-0.5 * ident, 4, 4)
-    check_fw(test_trsm4, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trsm4, [a, b])
-
-    # test trmm
-    test_trmm = mx.sym.linalg.trmm(
-        data1_ltri, data2, alpha=7., transpose=True, rightside=True)
-    a = rep_3x(trian, 4, 4)
-    b = rep_3x(matrix, 4, 4)
-    r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4)
-    check_fw(test_trmm, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trmm, [a, b])
+        test_trsm4 = mx.sym.linalg.trsm(
+            data1_ltri, data2, alpha=-0.5, rightside=True, transpose=upper, lower=lower)
+        b = rep_3x(low_trian, 4, 4)
+        r = rep_3x(-0.5 * ident, 4, 4)
+        check_fw(test_trsm4, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trsm4, [a, b])
+
+        # test trmm
+        test_trmm = mx.sym.linalg.trmm(
+            data1_ltri, data2, alpha=7., transpose=True, rightside=True, lower=lower)
+        a = rep_3x(trian, 4, 4)
+        b = rep_3x(matrix, 4, 4)
+        r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4)
+        check_fw(test_trmm, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trmm, [a, b])
 
-    test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2.)
-    r = rep_3x(-2. * np.dot(trian, matrix), 4, 4)
-    check_fw(test_trmm2, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trmm2, [a, b])
+        test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2., lower=lower)
+        r = rep_3x(-2. * np.dot(trian, matrix), 4, 4)
+        check_fw(test_trmm2, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trmm2, [a, b])
 
-    test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True)
-    r = rep_3x(np.dot(matrix, trian), 4, 4)
-    check_fw(test_trmm3, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trmm3, [a, b])
+        test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True, lower=lower)
+        r = rep_3x(np.dot(matrix, trian), 4, 4)
+        check_fw(test_trmm3, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trmm3, [a, b])
 
-    test_trmm4 = mx.sym.linalg.trmm(
-        data1_ltri, data2, alpha=1.2, transpose=True)
-    r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4)
-    check_fw(test_trmm4, [a, b], [r])
-    if grad_check == 1:
-        check_grad(test_trmm4, [a, b])
+        test_trmm4 = mx.sym.linalg.trmm(
+            data1_ltri, data2, alpha=1.2, transpose=True, lower=lower)
+        r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4)
+        check_fw(test_trmm4, [a, b], [r])
+        if grad_check == 1:
+            check_grad(test_trmm4, [a, b])
 
     # test sumlogdiag
     a = rep_3x(pow, 4, 4)

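As a sanity check on the hand-coded matrices in the Cholesky test above, the relations they are meant to satisfy can be verified directly in NumPy (not part of the patch):

    import numpy as np

    low_trian = np.array([[3., 0., 0., 0.],
                          [1., 5., 0., 0.],
                          [-2., -1., 2., 0.],
                          [4., -3., 6., 2.]])
    matrix = low_trian.dot(low_trian.T)        # the SPD input used in the test
    # np.linalg.cholesky returns the lower factor; its transpose is the expected
    # output of linalg.potrf(..., lower=False) on the same input.
    assert np.allclose(np.linalg.cholesky(matrix), low_trian)
    # 'inv' in the test is matrix**-1, which potri recovers from either factor.
    print(np.linalg.inv(matrix))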

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services