Posted to commits@mxnet.apache.org by sk...@apache.org on 2020/08/19 00:37:51 UTC

[incubator-mxnet] branch master updated: Numpy Dot Large Tensor Fix (#18925)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8794a0a  Numpy Dot Large Tensor Fix (#18925)
8794a0a is described below

commit 8794a0adf4918513d441d2d4408ef7f28798c500
Author: Zhaoqi Zhu <zh...@usc.edu>
AuthorDate: Tue Aug 18 17:36:42 2020 -0700

    Numpy Dot Large Tensor Fix (#18925)
    
    * fix np dot
    
    * add test
    
    * fix test
    
    * tweak test
    
    Co-authored-by: Zhu <zh...@3c22fbbb4e1a.ant.amazon.com>
    Co-authored-by: Ubuntu <ub...@ip-172-31-10-124.us-west-2.compute.internal>
    Co-authored-by: Ubuntu <ub...@ip-172-31-6-47.us-west-2.compute.internal>
---
 3rdparty/mshadow/mshadow/dot_engine-inl.h | 24 ++++++++++++------------
 src/operator/numpy/np_tensordot_op-inl.h  | 24 ++++++++++++------------
 tests/nightly/test_np_large_array.py      | 13 +++++++++++++
 3 files changed, 37 insertions(+), 24 deletions(-)
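
In substance, the commit widens the matrix dimension and stride arguments that flow
from the numpy dot/tensordot operators into the mshadow BLAS wrappers from int to
index_t, so operands with more than 2**31 - 1 elements along an axis no longer
overflow a 32-bit argument. As a rough illustration (not part of the commit, and
assuming an MXNet build with int64 tensor support and an ILP64 BLAS), the case this
fixes looks like:

    from mxnet import np, npx
    npx.set_np()                        # enable numpy-compatible semantics

    n = 2**31                           # one past the signed 32-bit range
    a = np.ones((1, n), dtype='float32')
    b = np.ones((n, 1), dtype='float32')
    c = np.dot(a, b)                    # previously overflowed the int gemm arguments
    print(c.shape)                      # (1, 1)
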

diff --git a/3rdparty/mshadow/mshadow/dot_engine-inl.h b/3rdparty/mshadow/mshadow/dot_engine-inl.h
index 225821e..9327315 100644
--- a/3rdparty/mshadow/mshadow/dot_engine-inl.h
+++ b/3rdparty/mshadow/mshadow/dot_engine-inl.h
@@ -299,17 +299,17 @@ struct BLASEngine<cpu, float> {
   }
   inline static void gemm(Stream<cpu> *stream,
                           bool transa, bool transb,
-                          int m, int n, int k, float alpha,
-                          const float *A, int lda, const float *B, int ldb,
-                          float beta, float *C, int ldc) {
+                          index_t m, index_t n, index_t k, float alpha,
+                          const float *A, index_t lda, const float *B, index_t ldb,
+                          float beta, float *C, index_t ldc) {
     cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
                 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
   }
   inline static void batched_gemm(Stream<cpu> *stream,
                                   bool transa, bool transb,
-                                  int m, int n, int k, float alpha,
-                                  const float *A, int lda, const float *B, int ldb,
-                                  float beta, float *C, int ldc, int batch_count,
+                                  index_t m, index_t n, index_t k, float alpha,
+                                  const float *A, index_t lda, const float *B, index_t ldb,
+                                  float beta, float *C, index_t ldc, index_t batch_count,
                                   float **workspace) {
 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
  // since the same m/n/k is used for all single gemms, we put all gemms into one group
@@ -408,17 +408,17 @@ struct BLASEngine<cpu, double> {
   }
   inline static void gemm(Stream<cpu> *stream,
                           bool transa, bool transb,
-                          int m, int n, int k, double alpha,
-                          const double *A, int lda, const double *B, int ldb,
-                          double beta, double *C, int ldc) {
+                          index_t m, index_t n, index_t k, double alpha,
+                          const double *A, index_t lda, const double *B, index_t ldb,
+                          double beta, double *C, index_t ldc) {
     cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
                 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
   }
   inline static void batched_gemm(Stream<cpu> *stream,
                                   bool transa, bool transb,
-                                  int m, int n, int k, double alpha,
-                                  const double *A, int lda, const double *B, int ldb,
-                                  double beta, double *C, int ldc, int batch_count,
+                                  index_t m, index_t n, index_t k, double alpha,
+                                  const double *A, index_t lda, const double *B, index_t ldb,
+                                  double beta, double *C, index_t ldc, index_t batch_count,
                                   double **workspace) {
 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
   // since same m/n/k is used for all single gemms, so we put all gemms into one group
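
The dot_engine changes above replace int with index_t for m, n, k, the leading
dimensions, and batch_count. index_t is mshadow's index type, which is 64-bit when
MXNet is built with large-tensor support, so dimensions above INT_MAX are passed
through to cblas_sgemm/cblas_dgemm without truncation. A small numeric check,
included here only for illustration (not part of the commit), shows why the old int
arguments could not carry such a dimension:

    import numpy as np

    n = 2**31
    print(np.iinfo(np.int32).max)        # 2147483647
    print(n > np.iinfo(np.int32).max)    # True: n no longer fits in a signed 32-bit int
    print(n <= np.iinfo(np.int64).max)   # True: a 64-bit index type holds it comfortably
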
diff --git a/src/operator/numpy/np_tensordot_op-inl.h b/src/operator/numpy/np_tensordot_op-inl.h
index d025f15..1e5ba7b 100644
--- a/src/operator/numpy/np_tensordot_op-inl.h
+++ b/src/operator/numpy/np_tensordot_op-inl.h
@@ -60,10 +60,10 @@ inline void ShiftAxes(Tuple<int>* axes_summed, const int ndim) {
 /**
  * Gets matrix dimensions of a and b after transpose and reshape.
  */
-inline void GetMatrixDimensions(int* ad1,
-                                int* ad2,
-                                int* bd1,
-                                int* bd2,
+inline void GetMatrixDimensions(index_t* ad1,
+                                index_t* ad2,
+                                index_t* bd1,
+                                index_t* bd2,
                                 const mxnet::Tuple<int>& a_axes_remained,
                                 const mxnet::Tuple<int>& a_axes_summed,
                                 const mxnet::Tuple<int>& b_axes_remained,
@@ -157,10 +157,10 @@ void MatrixDot(const OpContext& ctx,
                const TBlob& b,
                const TBlob& out,
                const OpReqType req,
-               const int ad1,
-               const int ad2,
-               const int bd1,
-               const int bd2,
+               const index_t ad1,
+               const index_t ad2,
+               const index_t bd1,
+               const index_t bd2,
                const bool aT = false,
                const bool bT = false) {
   using namespace mshadow;
@@ -266,7 +266,7 @@ void TensordotImpl(const Tuple<int>& a_axes_summed,
       GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
                        &b_axes, a_shape, b_shape);
 
-      int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+      index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
       GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
                           b_axes_remained, b_axes_summed, a_shape, b_shape);
 
@@ -435,7 +435,7 @@ void TensordotBackwardImpl(const Tuple<int>& a_axes_summed,
       GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
                       &b_axes, a_shape, b_shape);
 
-      int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+      index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
       GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
                           b_axes_remained, b_axes_summed, a_shape, b_shape);
 
@@ -653,7 +653,7 @@ void TensordotIntAxesImpl(const int axes,
       GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
                       &b_axes, a_shape, b_shape);
 
-      int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+      index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
       GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
                           b_axes_remained, b_axes_summed, a_shape, b_shape);
       MatrixDot<xpu>(ctx, a, b, out, req, ad1, ad2, bd1, bd2);
@@ -746,7 +746,7 @@ void TensordotIntAxesBackwardImpl(const int axes,
       GetReorderedAxes(a_axes_summed, &a_axes_remained, &a_axes, b_axes_summed, &b_axes_remained,
                       &b_axes, a_shape, b_shape);
 
-      int ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
+      index_t ad1 = 1, ad2 = 1, bd1 = 1, bd2 = 1;
       GetMatrixDimensions(&ad1, &ad2, &bd1, &bd2, a_axes_remained, a_axes_summed,
                           b_axes_remained, b_axes_summed, a_shape, b_shape);
 
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index 072e80b..7f13135 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -36,6 +36,7 @@ MEDIUM_X = 10000
 LARGE_X = 100000000
 SMALL_X = 100
 SMALL_Y = 50
+INT_OVERFLOW = 2**31
 
 
 @use_np
@@ -76,3 +77,15 @@ def test_softmax():
         true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
         output = npx.softmax(input_data, axis=axis)
         assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
+
+#@pytest.mark.skip(reason="CI hasn't switched to ILP64 OpenBLAS yet")
+@use_np
+def test_dot():
+    A = np.ones((1, INT_OVERFLOW), dtype='float32')
+    B = np.ones((INT_OVERFLOW, 1), dtype='float32')
+    A.attach_grad()
+    with mx.autograd.record():
+        C = np.dot(A, B)
+    assert_almost_equal(C.asnumpy(), [INT_OVERFLOW], rtol=1e-5, atol=1e-5)
+    C.backward()
+    assert A.grad.shape == (1, INT_OVERFLOW)
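
On a machine with a large-tensor-enabled build and an ILP64 BLAS, the new check can
be run on its own with something along these lines (treat this as a sketch of the
invocation rather than the project's official command; each (1, 2**31) float32
operand is about 8 GiB, so the test needs a correspondingly large amount of memory):

    python -m pytest -v tests/nightly/test_np_large_array.py::test_dot
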