You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/03/27 20:50:26 UTC

[incubator-mxnet] branch numpy updated: [numpy] Fix numpy import in python2 (#14537)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy by this push:
     new db37dd9  [numpy] Fix numpy import in python2 (#14537)
db37dd9 is described below

commit db37dd9a93ab64f3247b9ca785b347d2239c8a70
Author: reminisce <wu...@gmail.com>
AuthorDate: Wed Mar 27 13:49:59 2019 -0700

    [numpy] Fix numpy import in python2 (#14537)
    
    * Fix several test failures
    
    * Fix subgraph op infer shape
    
    * Fix sparse slice
    
    * Fix deconv infer shape
    
    * Fix numpy import compatibility problem in python2
---
 python/mxnet/ndarray/_internal.py     |  2 --
 python/mxnet/ndarray/contrib.py       |  1 +
 python/mxnet/ndarray/register.py      |  7 ++---
 python/mxnet/symbol/_internal.py      |  2 --
 python/mxnet/symbol/register.py       |  7 ++---
 src/common/utils.h                    |  5 ++++
 src/operator/leaky_relu-inl.h         |  4 +--
 src/operator/nn/deconvolution-inl.h   | 12 +++++----
 src/operator/nn/deconvolution.cc      | 50 ++++++++++++++++++++++++++---------
 src/operator/tensor/matrix_op-inl.h   |  2 +-
 tests/python/unittest/test_ndarray.py |  6 ++++-
 11 files changed, 66 insertions(+), 32 deletions(-)

diff --git a/python/mxnet/ndarray/_internal.py b/python/mxnet/ndarray/_internal.py
index 5f3ce97..8045d9b 100644
--- a/python/mxnet/ndarray/_internal.py
+++ b/python/mxnet/ndarray/_internal.py
@@ -20,8 +20,6 @@
 import os as _os
 import sys as _sys
 
-import numpy as np
-
 try:
     if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
         from .._ctypes.ndarray import NDArrayBase, CachedOp
diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py
index 74c355d..1718a2c 100644
--- a/python/mxnet/ndarray/contrib.py
+++ b/python/mxnet/ndarray/contrib.py
@@ -18,6 +18,7 @@
 # coding: utf-8
 # pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
 """Contrib NDArray API of MXNet."""
+from __future__ import absolute_import
 import math
 import numpy as np
 from ..context import current_context
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 05d7f17..1ccf228 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -16,9 +16,10 @@
 # under the License.
 
 """Register backend ops in mxnet.ndarray namespace"""
+from __future__ import absolute_import
 import os as _os
 import ctypes
-import numpy as np  # pylint: disable=unused-import
+import numpy as _np  # pylint: disable=unused-import
 
 from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
 from ..ndarray_doc import _build_doc
@@ -103,7 +104,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
             if dtype_name is not None:
                 code.append("""
     if '%s' in kwargs:
-        kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
+        kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
             dtype_name, dtype_name, dtype_name))
             code.append("""
     _ = kwargs.pop('name', None)
@@ -136,7 +137,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
                 code.append("""
     if %s is not _Null:
         keys.append('%s')
-        vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
+        vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
 
     if not signature_only:
         code.append("""
diff --git a/python/mxnet/symbol/_internal.py b/python/mxnet/symbol/_internal.py
index 53fc684..7e9787e 100644
--- a/python/mxnet/symbol/_internal.py
+++ b/python/mxnet/symbol/_internal.py
@@ -22,8 +22,6 @@
 import sys as _sys
 import os as _os
 
-import numpy as np
-
 try:
     if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
         from .._ctypes.symbol import SymbolBase, _set_symbol_class
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index 15c8e5e..ac59f8b 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -17,9 +17,10 @@
 
 # pylint: disable=unused-import
 """Register backend ops in mxnet.symbol namespace."""
+from __future__ import absolute_import
 import os as _os
 import ctypes
-import numpy as np
+import numpy as _np
 
 from . import _internal
 from ._internal import SymbolBase, _symbol_creator
@@ -109,7 +110,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
             if dtype_name is not None:
                 code.append("""
     if '%s' in kwargs:
-        kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
+        kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
             dtype_name, dtype_name, dtype_name))
             code.append("""
     attr = kwargs.pop('attr', None)
@@ -175,7 +176,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
                 code.append("""
     if %s is not _Null:
         _keys.append('%s')
-        _vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
+        _vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
 
             code.append("""
     if not hasattr(NameManager._current, "value"):
diff --git a/src/common/utils.h b/src/common/utils.h
index 4843d7e..4fb398d 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -746,6 +746,11 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
  * 4. -1 dim size means the dimension's size is unknown.
  * so that operator's infer shape function can work in backend.
  * \param shape to be converted.
+ * Note: It is possible that the shape to be converted is already
+ * numpy compatible. For example, when a subgraph operator's infer
+ * shape function is called from the infer shape pass of the whole
+ * graph, its input/output shapes have been converted to numpy
+ * compatible shapes.
  */
 inline void ConvertToNumpyShape(mxnet::TShape* shape) {
   if (shape->ndim() == 0) {  // legacy shape ndim = 0 means unknown
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 22f5229..5518352 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -338,10 +338,10 @@ class LeakyReLUProp : public OperatorProperty {
       CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
     }
     const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
-    if (dshape.ndim() == 0) return false;
+    if (!mxnet::ndim_is_known(dshape)) return false;
     if (param_.act_type == leakyrelu::kPReLU) {
       const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
-      if (gshape.ndim() == 0) {
+      if (!mxnet::ndim_is_known(gshape)) {
         in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
       }
       if (dshape == gshape) {
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index b28e478..5f3137f 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -134,11 +134,13 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
       for (size_t i = 0; i < ndim; i++) {
         // input.ndim() can be larger than ndim, in case that the complete input
         // shape was passed and not only the ndim last ones
-        o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i);
-        CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape";
-        o_pad[i] -= target_shape[i];
-        o_adj[i] = o_pad[i] % 2;
-        o_pad[i] = (o_pad[i] + 1) / 2;
+        if (mxnet::dim_size_is_known(input, input_ndim - ndim + i)) {
+          o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i);
+          CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape";
+          o_pad[i] -= target_shape[i];
+          o_adj[i] = o_pad[i] % 2;
+          o_pad[i] = (o_pad[i] + 1) / 2;
+        }
       }
     } else {
       for (size_t i = 0; i < ndim; i++) {
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index d8c91f7..09b255d 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -54,7 +54,7 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
   }
   out_shape->resize(1, mxnet::TShape());
   const mxnet::TShape &dshape = (*in_shape)[deconv::kData];
-  if (!shape_is_known(dshape)) return false;
+  if (!mxnet::ndim_is_known(dshape)) return false;
 
   if (param_.kernel.ndim() == 1) {
     // 1d conv
@@ -90,8 +90,12 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
     Shape<3> oshape;
     oshape[0] = dshape_ncw[0];
     oshape[1] = param_.num_filter;
-    oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) +
-      dilated_ksize_x - 2 * o_pad[0] + o_adj[0];
+    if (mxnet::dim_size_is_known(dshape_ncw[2])) {
+      oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) +
+          dilated_ksize_x - 2 * o_pad[0] + o_adj[0];
+    } else {
+      oshape[2] = -1;
+    }
 
     if (param_.target_shape.ndim() > 0) {
       if (param_.target_shape[0] > 0) {
@@ -141,10 +145,18 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
     Shape<4> oshape;
     oshape[0] = dshape_nchw[0];
     oshape[1] = param_.num_filter;
-    oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) +
-      dilated_ksize_y - 2 * o_pad[0] + o_adj[0];
-    oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) +
-      dilated_ksize_x - 2 * o_pad[1] + o_adj[1];
+    if (mxnet::dim_size_is_known(dshape_nchw[2])) {
+      oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) +
+          dilated_ksize_y - 2 * o_pad[0] + o_adj[0];
+    } else {
+      oshape[2] = -1;
+    }
+    if (mxnet::dim_size_is_known(dshape_nchw[3])) {
+      oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) +
+          dilated_ksize_x - 2 * o_pad[1] + o_adj[1];
+    } else {
+      oshape[3] = -1;
+    }
 
     if (param_.target_shape.ndim() > 1) {
       if (param_.target_shape[0] > 0) {
@@ -203,12 +215,24 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
     Shape<5> oshape;
     oshape[0] = dshape_ncdhw[0];
     oshape[1] = param_.num_filter;
-    oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) +
-      dilated_ksize_d - 2 * o_pad[0] + o_adj[0];
-    oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) +
-      dilated_ksize_y - 2 * o_pad[1] + o_adj[1];
-    oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) +
-      dilated_ksize_x - 2 * o_pad[2] + o_adj[2];
+    if (mxnet::dim_size_is_known(dshape_ncdhw[2])) {
+      oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) +
+          dilated_ksize_d - 2 * o_pad[0] + o_adj[0];
+    } else {
+      oshape[2] = -1;
+    }
+    if (mxnet::dim_size_is_known(dshape_ncdhw[3])) {
+      oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) +
+          dilated_ksize_y - 2 * o_pad[1] + o_adj[1];
+    } else {
+      oshape[3] = -1;
+    }
+    if (mxnet::dim_size_is_known(dshape_ncdhw[4])) {
+      oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) +
+          dilated_ksize_x - 2 * o_pad[2] + o_adj[2];
+    } else {
+      oshape[4] = -1;
+    }
 
     if (param_.target_shape.ndim() > 2) {
       if (param_.target_shape[0] > 0) {
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index f79af7c..65f3b2f 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -594,7 +594,7 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
   mxnet::TShape begin(N, -1), end(N, -1);
   for (int i = 0; i < N; ++i) {
     int s = 0;
-    if (param.begin[i]) {
+    if (i < param.begin.ndim() && param.begin[i]) {
       s = *param.begin[i];
       if (s < 0) s += ishape[i];
     }
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 3a17a1e..c71209f 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -120,7 +120,11 @@ def test_ndarray_setitem():
 
     # numpy assignment for empty axis
     for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
-        x = mx.nd.zeros(trivial_shape)
+        if trivial_shape == tuple():
+            with mx.numpy.enable_np_comp():
+                x = mx.nd.zeros(trivial_shape)
+        else:
+            x = mx.nd.zeros(trivial_shape)
         x[:] = np.ones(trivial_shape)
         x_np = np.ones(trivial_shape, dtype=x.dtype)
         assert x.shape == trivial_shape