You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/22 20:20:35 UTC
[incubator-mxnet] branch numpy updated: [numpy] Fix unit tests
after introducing numpy compatible shapes (#14487)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy by this push:
new 64c61b9 [numpy] Fix unit tests after introducing numpy compatible shapes (#14487)
64c61b9 is described below
commit 64c61b9d183774d53c3f42e8ba2caed4cdb8ec78
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri Mar 22 13:20:02 2019 -0700
[numpy] Fix unit tests after introducing numpy compatible shapes (#14487)
* Fix infer shape rnn
* Fix boolean mask and custom op unit tests
* Fix multi proposal
* Fix diag
* Add global switch for backward compatibility and fix infer shape bugs
* Fix slice op infer shape
* Fix rnn infer shape
* Add util funcs for ndim_is_known and dim_size_is_known
* Revert rnn_cell.py
---
include/mxnet/c_api.h | 15 +++-
include/mxnet/imperative.h | 16 ++++
include/mxnet/tuple.h | 28 ++++++-
python/mxnet/ndarray/ndarray.py | 5 +-
python/mxnet/numpy/__init__.py | 46 ++++++++++++
python/mxnet/operator.py | 26 +++----
python/mxnet/symbol/symbol.py | 22 ++++--
src/c_api/c_api.cc | 24 ++++--
src/c_api/c_api_common.h | 2 +-
src/c_api/c_api_executor.cc | 7 ++
src/c_api/c_api_ndarray.cc | 12 +++
src/c_api/c_api_symbolic.cc | 9 ++-
src/common/utils.h | 54 ++++++++++++++
src/executor/infer_graph_attr_pass.cc | 18 ++++-
src/imperative/imperative.cc | 4 +-
src/imperative/imperative_utils.h | 19 ++++-
src/operator/batch_norm_v1-inl.h | 2 +-
src/operator/contrib/multi_proposal-inl.h | 2 +-
src/operator/control_flow.cc | 6 +-
src/operator/convolution_v1-inl.h | 2 +-
src/operator/custom/custom.cc | 18 +++--
src/operator/nn/batch_norm.cc | 2 +-
src/operator/nn/concat.cc | 25 ++++---
src/operator/nn/convolution.cc | 85 ++++++++++++----------
src/operator/nn/cudnn/cudnn_batch_norm.cc | 2 +-
src/operator/nn/dropout.cc | 2 +-
src/operator/nn/fully_connected.cc | 2 +-
src/operator/nn/layer_norm.cc | 2 +-
src/operator/nn/pooling.cc | 2 +-
src/operator/operator_common.h | 6 +-
src/operator/pooling_v1-inl.h | 2 +-
src/operator/quantization/quantized_concat.cc | 2 +-
src/operator/random/unique_sample_op.h | 2 +-
src/operator/rnn-inl.h | 8 +-
src/operator/slice_channel-inl.h | 15 ++--
src/operator/softmax_output-inl.h | 2 +-
src/operator/softmax_output.cc | 4 +-
src/operator/svm_output-inl.h | 2 +-
src/operator/tensor/broadcast_reduce_op.h | 8 +-
src/operator/tensor/diag_op-inl.h | 4 +-
src/operator/tensor/elemwise_binary_broadcast_op.h | 2 +-
src/operator/tensor/init_op.h | 14 ++--
src/operator/tensor/matrix_op-inl.h | 26 +++----
tests/python/unittest/test_operator.py | 8 +-
44 files changed, 406 insertions(+), 158 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 5b77405..088117e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -170,7 +170,7 @@ typedef int (*CustomOpFBFunc)(int /*size*/, void** /*ptrs*/, int* /*tags*/,
typedef int (*CustomOpDelFunc)(void* /*state*/);
typedef int (*CustomOpListFunc)(char*** /*args*/, void* /*state*/);
typedef int (*CustomOpInferShapeFunc)(int /*num_input*/, int* /*ndims*/,
- unsigned** /*shapes*/, void* /*state*/);
+ int** /*shapes*/, void* /*state*/);
typedef int (*CustomOpInferStorageTypeFunc)(int /*num_input*/, int* /*stypes*/, void* /*state*/);
typedef int (*CustomOpBackwardInferStorageTypeFunc)(int /*num_input*/,
int * /*stypes*/,
@@ -1037,6 +1037,19 @@ MXNET_DLL int MXAutogradIsRecording(bool* curr);
*/
MXNET_DLL int MXAutogradIsTraining(bool* curr);
/*!
+ * \brief get whether numpy compatibility is on
+ * \param curr returns the current status
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXIsNumpyCompatible(bool* curr);
+/*!
+ * \brief set numpy compatibility switch
+ * \param is_np_comp 1 when numpy compatibility is on, 0 when off
+ * \param prev returns the previous status before this set
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSetIsNumpyCompatible(int is_np_comp, int* prev);
+/*!
* \brief mark NDArrays as variables to compute gradient for autograd
* \param num_var number of variable NDArrays
* \param var_handles variable NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 52cedb2..ad20991 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -97,6 +97,16 @@ class Imperative {
is_recording_ = is_recording;
return old;
}
+ /*! brief whether numpy compatibility is on. */
+ bool is_np_comp() const {
+ return is_np_comp_;
+ }
+ /*! brief turn on or turn off numpy compatibility switch. */
+ bool set_is_np_comp(bool is_np_comp) {
+ bool old = is_np_comp_;
+ is_np_comp_ = is_np_comp;
+ return old;
+ }
/*! \brief to record operator, return corresponding node. */
void RecordOp(nnvm::NodeAttrs&& attrs,
const std::vector<NDArray*>& inputs,
@@ -165,9 +175,15 @@ class Imperative {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
static thread_local bool is_recording_;
+ // TODO(junwu): Added numpy compatibility switch for backward compatibility.
+ // Delete it in the next major release.
+ static thread_local bool is_np_comp_;
#else
static MX_THREAD_LOCAL bool is_train_;
static MX_THREAD_LOCAL bool is_recording_;
+ // TODO(junwu): Added numpy compatibility switch for backward compatibility.
+ // Delete it in the next major release.
+ static MX_THREAD_LOCAL bool is_np_comp_;
#endif
/*! \brief node count used for naming */
std::atomic<uint64_t> node_count_{0};
diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h
index 49852f7..d83e843 100644
--- a/include/mxnet/tuple.h
+++ b/include/mxnet/tuple.h
@@ -607,12 +607,36 @@ class TShape : public Tuple<dim_t> {
#endif
};
+/*! brief check if a shape's ndim is known. */
+inline bool ndim_is_known(const int ndim) {
+ CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;
+ return ndim != -1;
+}
+
+/*! brief check if a shape's ndim is known. */
+inline bool ndim_is_known(const TShape& x) {
+ return ndim_is_known(x.ndim());
+}
+
+/*! brief check if a shape's dim size is known. */
+inline bool dim_size_is_known(const int dim_size) {
+ CHECK_GE(dim_size, -1) << "shape dim size must be >= -1, while received " << dim_size;
+ return dim_size != -1;
+}
+
+/*! brief check if a shape's dim size is known. */
+inline bool dim_size_is_known(const TShape& x, const int idx) {
+ CHECK(idx >= 0 && idx < x.ndim())
+ << "idx = " << idx << " exceeds shape dimension range [0, " << x.ndim() << ")";
+ return dim_size_is_known(x[idx]);
+}
+
/*! brief check if shape is known using the NumPy compatible definition.
* zero-dim and zero-size tensors are valid. -1 means unknown.*/
inline bool shape_is_known(const TShape& x) {
- if (x.ndim() == -1) return false;
+ if (!ndim_is_known(x)) return false;
for (int i = 0; i < x.ndim(); ++i) {
- if (x[i] == -1) return false;
+ if (!dim_size_is_known(x, i)) return false;
}
return true;
}
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index b97bfe7..f867065 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1852,7 +1852,10 @@ fixed-size items.
pdata = ctypes.POINTER(mx_int)()
check_call(_LIB.MXNDArrayGetShape(
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
- return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index
+ if ndim.value == -1:
+ return None
+ else:
+ return tuple(pdata[:ndim.value]) # pylint: disable=invalid-slice-index
@property
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index b1139a0..e0dfda1 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -17,4 +17,50 @@
# specific language governing permissions and limitations
# under the License.
+import ctypes
+from ..base import _LIB, check_call
+
__all__ = []
+
+
+def set_np_comp(is_np_comp):
+ prev = ctypes.c_int()
+ check_call(_LIB.MXSetIsNumpyCompatible(ctypes.c_int(is_np_comp), ctypes.byref(prev)))
+ return bool(prev.value)
+
+
+def is_np_comp():
+ curr = ctypes.c_bool()
+ check_call(_LIB.MXIsNumpyCompatible(ctypes.byref(curr)))
+ return curr.value
+
+
+class _NumpyCompatibilityStateScope(object):
+ """Scope for managing numpy compatibility state.
+
+ Example::
+
+ with _NumpyCompatibilityStateScope(True):
+ y = model(x)
+ backward([y])
+
+ """
+ def __init__(self, is_np_comp): #pylint: disable=redefined-outer-name
+ self._enter_is_np_comp = is_np_comp
+ self._prev_is_np_comp = None
+
+ def __enter__(self):
+ if self._enter_is_np_comp is not None:
+ self._prev_is_np_comp = set_np_comp(self._enter_is_np_comp)
+
+ def __exit__(self, ptype, value, trace):
+ if self._enter_is_np_comp is not None and self._prev_is_np_comp != self._enter_is_np_comp:
+ set_np_comp(self._prev_is_np_comp)
+
+
+def enable_np_comp():
+ return _NumpyCompatibilityStateScope(True)
+
+
+def disable_np_comp():
+ return _NumpyCompatibilityStateScope(False)
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index e8fa571..2c69b9b 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -28,7 +28,7 @@ from threading import Lock
from ctypes import CFUNCTYPE, POINTER, Structure, pointer
from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool
-from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf
+from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int
from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
from . import symbol, context
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
@@ -164,7 +164,7 @@ class NumpyOp(PythonOp):
fb_functype = CFUNCTYPE(None, c_int, POINTER(POINTER(mx_float)), POINTER(c_int),
POINTER(POINTER(mx_uint)), POINTER(c_int), c_void_p)
infer_functype = CFUNCTYPE(None, c_int, POINTER(c_int),
- POINTER(POINTER(mx_uint)), c_void_p)
+ POINTER(POINTER(mx_int)), c_void_p)
list_functype = CFUNCTYPE(None, POINTER(POINTER(POINTER(c_char))), c_void_p)
class NumpyOpInfo(Structure):
"""Structure that holds Callback information. Passed to NumpyOpProp"""
@@ -214,9 +214,9 @@ class NumpyOp(PythonOp):
assert len(ishape) == n_in
rshape = list(ishape) + list(oshape)
for i in range(n_in+n_out):
- tensor_shapes[i] = cast(c_array_buf(mx_uint,
- array('I', rshape[i])),
- POINTER(mx_uint))
+ tensor_shapes[i] = cast(c_array_buf(mx_int,
+ array('i', rshape[i])),
+ POINTER(mx_int))
tensor_dims[i] = len(rshape[i])
def list_outputs_entry(out, _):
@@ -266,7 +266,7 @@ class NDArrayOp(PythonOp):
def get_symbol(self, *args, **kwargs):
fb_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_void_p), POINTER(c_int), c_void_p)
infer_functype = CFUNCTYPE(c_bool, c_int, POINTER(c_int),
- POINTER(POINTER(mx_uint)), c_void_p)
+ POINTER(POINTER(mx_int)), c_void_p)
list_functype = CFUNCTYPE(c_bool, POINTER(POINTER(POINTER(c_char))), c_void_p)
deps_functype = CFUNCTYPE(c_bool, c_int_p, c_int_p, c_int_p,
c_int_p, POINTER(c_int_p), c_void_p)
@@ -335,9 +335,9 @@ class NDArrayOp(PythonOp):
assert len(ishape) == n_in
rshape = list(ishape) + list(oshape)
for i in range(n_in+n_out):
- tensor_shapes[i] = cast(c_array_buf(mx_uint,
- array('I', rshape[i])),
- POINTER(mx_uint))
+ tensor_shapes[i] = cast(c_array_buf(mx_int,
+ array('i', rshape[i])),
+ POINTER(mx_int))
tensor_dims[i] = len(rshape[i])
except Exception:
print('Error in NDArrayOp.infer_shape: %s' % traceback.format_exc())
@@ -698,7 +698,7 @@ def register(reg_name):
del_functype = CFUNCTYPE(c_int, c_void_p)
infershape_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int),
- POINTER(POINTER(mx_uint)), c_void_p)
+ POINTER(POINTER(mx_int)), c_void_p)
infertype_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), c_void_p)
inferstorage_backward_functype = CFUNCTYPE(c_int, c_int, POINTER(c_int), \
@@ -747,9 +747,9 @@ def register(reg_name):
"shapes, got %d."%(n_aux, len(ashape))
rshape = list(ishape) + list(oshape) + list(ashape)
for i in range(n_in+n_out+n_aux):
- tensor_shapes[i] = cast(c_array_buf(mx_uint,
- array('I', rshape[i])),
- POINTER(mx_uint))
+ tensor_shapes[i] = cast(c_array_buf(mx_int,
+ array('i', rshape[i])),
+ POINTER(mx_int))
tensor_dims[i] = len(rshape[i])
infer_shape_entry._ref_holder = [tensor_shapes]
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 01b851c..73d4bab 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -42,6 +42,7 @@ from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from ..ndarray import _ndarray_cls
from ..executor import Executor
+from ..numpy import is_np_comp
from . import _internal
from . import op
from ._internal import SymbolBase, _set_symbol_class
@@ -1078,7 +1079,11 @@ class Symbol(SymbolBase):
arg_names = self.list_arguments()
unknowns = []
for name, shape in zip(arg_names, arg_shapes):
- if not shape or not _numpy.prod(shape):
+ if is_np_comp():
+ shape_is_none = not shape or -1 in shape
+ else:
+ shape_is_none = not shape or 0 in shape
+ if shape_is_none:
if len(unknowns) >= 10:
unknowns.append('...')
break
@@ -1204,12 +1209,15 @@ class Symbol(SymbolBase):
ctypes.byref(aux_shape_data),
ctypes.byref(complete)))
if complete.value != 0:
- arg_shapes = [
- tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
- out_shapes = [
- tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
- aux_shapes = [
- tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)]
+ arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]])
+ if arg_shape_ndim[i] >= 0 else None
+ for i in range(arg_shape_size.value)]
+ out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]])
+ if out_shape_ndim[i] >= 0 else None
+ for i in range(out_shape_size.value)]
+ aux_shapes = [tuple(aux_shape_data[i][:aux_shape_ndim[i]])
+ if aux_shape_ndim[i] >= 0 else None
+ for i in range(aux_shape_size.value)]
return (arg_shapes, out_shapes, aux_shapes)
else:
return (None, None, None)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 9be9134..667b8a7 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -44,9 +44,11 @@
#include "mxnet/rtc.h"
#include "mxnet/storage.h"
#include "mxnet/libinfo.h"
+#include "mxnet/imperative.h"
#include "./c_api_common.h"
#include "../operator/custom/custom-inl.h"
#include "../operator/tensor/matrix_op-inl.h"
+#include "../common/utils.h"
using namespace mxnet;
@@ -498,15 +500,23 @@ int MXNDArrayGetShape(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
- const mxnet::TShape &s = arr->shape();
+ mxnet::TShape s = arr->shape();
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToLegacyShape(&s);
+ }
*out_dim = s.ndim();
- CHECK_GE(s.ndim(), 0);
- std::vector<int>& buffer = ret->arg_shape_buffer;
- buffer.resize(s.ndim());
- mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
- *out_pdata = buffer.data();
+ if (s.ndim() >= 0) {
+ std::vector<int> &buffer = ret->arg_shape_buffer;
+ buffer.resize(s.ndim());
+ mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
+ *out_pdata = buffer.data();
+ }
} else {
- *out_dim = 0;
+ if (Imperative::Get()->is_np_comp()) {
+ *out_dim = -1;
+ } else {
+ *out_dim = 0;
+ }
}
API_END();
}
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 690a1ea..329dc9a 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -91,7 +91,7 @@ struct MXAPIThreadLocalEntry {
data->resize(shapes.size());
size_t size = 0;
for (const auto& s : shapes) {
- CHECK_GE(s.ndim(), 0);
+ if (s.ndim() > 0)
size += s.ndim();
}
buffer->resize(size);
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index d000638..fc59463 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -25,8 +25,10 @@
#include <mxnet/base.h>
#include <mxnet/c_api.h>
#include <mxnet/executor.h>
+#include <mxnet/imperative.h>
#include "./c_api_common.h"
#include "../executor/graph_executor.h"
+#include "../common/utils.h"
#if MXNET_USE_TENSORRT
#include "../executor/trt_graph_executor.h"
#endif // MXNET_USE_TENSORRT
@@ -416,6 +418,11 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
CHECK(p.second) << "Duplicate shapes are provided for argument "
<< provided_arg_shape_names[i] << " in simple_bind";
}
+ if (!Imperative::Get()->is_np_comp()) {
+ for (auto &kv : arg_shape_map) {
+ common::ConvertToNumpyShape(&kv.second);
+ }
+ }
// create para name set for sharing data array memory
std::unordered_set<std::string> shared_arg_name_set(num_shared_arg_names);
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 18f6c41..0e136b0 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -276,6 +276,18 @@ int MXAutogradSetIsRecording(int is_recording, int* prev) {
API_END();
}
+int MXIsNumpyCompatible(bool* curr) {
+ API_BEGIN();
+ *curr = Imperative::Get()->is_np_comp();
+ API_END();
+}
+
+int MXSetIsNumpyCompatible(int is_np_comp, int* prev) {
+ API_BEGIN();
+ *prev = Imperative::Get()->set_is_np_comp(static_cast<bool>(is_np_comp));
+ API_END();
+}
+
int MXAutogradMarkVariables(mx_uint num_var,
NDArrayHandle *var_handles,
mx_uint *reqs_array,
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index c4d3bb0..7fa0ee0 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -24,6 +24,7 @@
*/
#include "mxnet/base.h"
#include "mxnet/c_api.h"
+#include "mxnet/imperative.h"
#include "nnvm/c_api.h"
#include "nnvm/pass.h"
#include "nnvm/pass_functions.h"
@@ -543,8 +544,14 @@ int MXSymbolInferShape(SymbolHandle sym,
throw dmlc::Error(err.msg);
}
+ // if use legacy shape definition, need to convert numpy shape to legacy shape
+ mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToLegacyShape(&shapes);
+ }
+
// copy back
- CopyAttr(g.indexed_graph(), g.GetAttr<mxnet::ShapeVector>("shape"),
+ CopyAttr(g.indexed_graph(), shapes,
&(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
// copy data back
diff --git a/src/common/utils.h b/src/common/utils.h
index 8e69669..f3df2e1 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -734,6 +734,60 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
}
}
+/*!
+ * \brief If numpy compatibility is turned off (default), the shapes passed in
+ * by users follow the legacy shape definition:
+ * 1. 0 ndim means the shape is completely unknown.
+ * 2. 0 dim size means the dim size is unknown.
+ * We need to convert those shapes to use the numpy shape definition:
+ * 1. 0 ndim means it's a scalar tensor.
+ * 2. -1 ndim means the shape is unknown.
+ * 3. 0 dim size means no elements in that dimension.
+ * 4. -1 dim size means the dimension's size is unknown.
+ * so that operator's infer shape function can work in backend.
+ * \param shape to be converted.
+ */
+inline void ConvertToNumpyShape(mxnet::TShape* shape) {
+ if (shape->ndim() == 0) { // legacy shape ndim = 0 means unknown
+ *shape = mxnet::TShape(); // unknown shape ndim = -1
+ } else {
+ for (int j = 0; j < shape->ndim(); ++j) {
+ CHECK_GE((*shape)[j], 0) << "Legacy shape cannot have dim size < 0";
+ if ((*shape)[j] == 0) { // legacy shape dim_size = 0 means unknown
+ (*shape)[j] = -1; // unknown dim size = -1
+ }
+ }
+ }
+}
+
+inline void ConvertToNumpyShape(mxnet::ShapeVector* shapes) {
+ for (size_t i = 0; i < shapes->size(); ++i) {
+ ConvertToNumpyShape(&(shapes->at(i)));
+ }
+}
+
+/*!
+ * \brief This function is used to convert shapes returned by
+ * the infer shape functions/pass to the legacy shape definition.
+ */
+inline void ConvertToLegacyShape(mxnet::TShape* shape) {
+ if (!mxnet::ndim_is_known(*shape)) {
+ *shape = mxnet::TShape(0);
+ } else {
+ for (int j = 0; j < shape->ndim(); ++j) {
+ if (!mxnet::dim_size_is_known(*shape, j)) {
+ (*shape)[j] = 0;
+ }
+ }
+ }
+}
+
+inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) {
+ for (size_t i = 0; i < shapes->size(); ++i) {
+ ConvertToLegacyShape(&(shapes->at(i)));
+ }
+}
+
} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 3a5c5ab..fa7aee5 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -24,6 +24,7 @@
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
+#include <mxnet/imperative.h>
#include "./exec_pass.h"
#include "../operator/operator_common.h"
#include "../common/exec_utils.h"
@@ -467,6 +468,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
std::vector<AttrType> ishape, oshape;
// whether a shape is dynamic
std::vector<int> is_dynamic(rshape.size(), 0);
+
+ // convert to numpy compatible shape to use operator's infer shape function
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToNumpyShape(&rshape);
+ }
+
// inference step function for nid
auto infer_step = [&](uint32_t nid, bool last_iter) {
const auto& inode = idx[nid];
@@ -483,6 +490,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
if (it != inode.source->attrs.dict.end()) {
std::istringstream is(it->second);
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToNumpyShape(&rshape[out_ent_id]);
+ }
}
}
// assign a default value to node attribute
@@ -546,7 +556,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
bool is_input_dynamic_shape = false;
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
- if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
+ if (!mxnet::ndim_is_known(ishape[i]) && is_dynamic[idx.entry_id(inode.inputs[i])]) {
is_input_dynamic_shape = true;
}
if (fis_none(ishape[i])) forward_known = false;
@@ -563,7 +573,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
if (finfer == nullptr || is_input_dynamic_shape) {
for (uint32_t i = 0; i < oshape.size(); ++i) {
- if (oshape[i].ndim() == 0) {
+ if (!mxnet::ndim_is_known(oshape[i].ndim())) {
is_dynamic[idx.entry_id(nid, i)] = 1;
}
}
@@ -650,12 +660,12 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
"shape", "shape_num_unknown_nodes",
[](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); },
[](const mxnet::TShape& s) {
- if (s.ndim() == -1) {
+ if (!mxnet::ndim_is_known(s)) {
return static_cast<size_t>(1);
}
size_t ret = 0;
for (const auto& val : s) {
- if (val == -1) {
+ if (!mxnet::dim_size_is_known(val)) {
++ret;
}
}
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 3e5b398..b027de0 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -25,9 +25,11 @@ namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
thread_local bool Imperative::is_train_ = false;
thread_local bool Imperative::is_recording_ = false;
+thread_local bool Imperative::is_np_comp_ = false;
#else
MX_THREAD_LOCAL bool Imperative::is_train_ = false;
MX_THREAD_LOCAL bool Imperative::is_recording_ = false;
+MX_THREAD_LOCAL bool Imperative::is_np_comp_ = false;
#endif
Imperative* Imperative::Get() {
@@ -109,7 +111,7 @@ OpStatePtr Imperative::Invoke(
OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode);
// the following loop is used for finding out the correct shape when some shapes are dynamic
for (size_t i = 0; i < outputs.size(); i++) {
- if (outputs[i]->shape().ndim() == 0) {
+ if (!shape_is_known(outputs[i]->shape())) {
// the WaitToRead overhead here does not seem to be avoidable
outputs[i]->WaitToRead();
outputs[i]->SetShapeFromChunk();
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 071f4fa..6864428 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -121,7 +121,24 @@ inline void SetShapeType(const Context& ctx,
if (!infershape.count(attrs.op)) {
is_dynamic_shape_existing = true;
} else {
- CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes));
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToNumpyShape(&in_shapes);
+ common::ConvertToNumpyShape(&out_shapes);
+ }
+ const bool success = infershape[attrs.op](attrs, &in_shapes, &out_shapes);
+ if (!success) {
+ std::stringstream os;
+ os << "Operator " << attrs.op->name << " inferring shapes failed.\n";
+ os << "input shapes:\n";
+ for (auto& nd : inputs) {
+ os << nd->shape() << '\n';
+ }
+ os << "output shapes:\n";
+ for (auto& nd : outputs) {
+ os << nd->shape() << '\n';
+ }
+ LOG(FATAL) << os.str();
+ }
CHECK_EQ(out_shapes.size(), outputs.size());
}
// infer type
diff --git a/src/operator/batch_norm_v1-inl.h b/src/operator/batch_norm_v1-inl.h
index 8016510..8941235 100644
--- a/src/operator/batch_norm_v1-inl.h
+++ b/src/operator/batch_norm_v1-inl.h
@@ -261,7 +261,7 @@ class BatchNormV1Prop : public OperatorProperty {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
out_shape->clear();
diff --git a/src/operator/contrib/multi_proposal-inl.h b/src/operator/contrib/multi_proposal-inl.h
index a9afb8e..4d278fb 100644
--- a/src/operator/contrib/multi_proposal-inl.h
+++ b/src/operator/contrib/multi_proposal-inl.h
@@ -108,7 +108,7 @@ class MultiProposalProp : public OperatorProperty {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3) << "Input:[cls_prob, bbox_pred, im_info]";
const mxnet::TShape &dshape = in_shape->at(proposal::kClsProb);
- if (!mxnet::op::shape_is_none(dshape)) return false;
+ if (mxnet::op::shape_is_none(dshape)) return false;
Shape<4> bbox_pred_shape;
bbox_pred_shape = Shape4(dshape[0], dshape[1] * 2, dshape[2], dshape[3]);
SHAPE_ASSIGN_CHECK(*in_shape, proposal::kBBoxPred,
diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc
index 9ba3b54..4c0d67b 100644
--- a/src/operator/control_flow.cc
+++ b/src/operator/control_flow.cc
@@ -301,7 +301,7 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < params.num_out_data; i++) {
mxnet::TShape shape = subg_out_shape[i];
// If we don't have shape info, we don't need to do anything.
- if (!shape_is_known(shape))
+ if (!mxnet::ndim_is_known(shape))
continue;
subg_out_shape[i] = SliceFirstDim(shape);
}
@@ -317,7 +317,7 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < params.num_out_data; i++) {
// If the output shape isn't inferred, we don't need to propagate the info.
const auto& g_out_shape = subg_out_shape[i];
- if (!shape_is_known(g_out_shape))
+ if (!mxnet::ndim_is_known(g_out_shape))
continue;
auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1);
@@ -336,7 +336,7 @@ static bool ForeachShape(const nnvm::NodeAttrs& attrs,
const auto &shape = subg_in_shape[loc];
// If the input data shape isn't inferred, we don't need to propagate the
// info.
- if (!shape_is_known(shape))
+ if (!mxnet::ndim_is_known(shape))
continue;
if (data_1d[i]) {
diff --git a/src/operator/convolution_v1-inl.h b/src/operator/convolution_v1-inl.h
index 0d6ffd7..080c718 100644
--- a/src/operator/convolution_v1-inl.h
+++ b/src/operator/convolution_v1-inl.h
@@ -405,7 +405,7 @@ class ConvolutionV1Prop : public OperatorProperty {
// CHECK_EQ(out_shape->size(), 1) << "Output: [output]";
out_shape->resize(1, mxnet::TShape());
const mxnet::TShape &dshp = (*in_shape)[conv_v1::kData];
- if (!shape_is_known(dshp)) return false;
+ if (!mxnet::ndim_is_known(dshp)) return false;
if (param_.kernel.ndim() == 2) {
// 2d conv_v1
CHECK_EQ(dshp.ndim(), 4U) \
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 39cca4d..8dcfcbe 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -133,17 +133,21 @@ bool InferShape(const NodeAttrs& attrs,
const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
size_t total = params.num_args + params.num_outs + params.num_auxs;
- std::vector<uint32_t*> shapes(total);
+ std::vector<int*> shapes(total);
std::vector<int> ndims(total);
size_t buff_size = 0;
- for (const auto& i : *in_shape) buff_size += i.ndim();
- std::vector<uint32_t> buff(buff_size);
- uint32_t *ptr = buff.data();
+ for (const auto& i : *in_shape) {
+ if (i.ndim() > 0) {
+ buff_size += i.ndim();
+ }
+ }
+ std::vector<int> buff(buff_size);
+ int *ptr = buff.data();
for (size_t i = 0; i < in_shape->size(); ++i) {
shapes[i] = ptr;
ndims[i] = (*in_shape)[i].ndim();
- for (size_t j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
- *ptr = static_cast<uint32_t>((*in_shape)[i][j]);
+ for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) {
+ *ptr = (*in_shape)[i][j];
}
}
@@ -268,7 +272,7 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
for (size_t i = 0; i < in_shape.size(); ++i) {
shapes[i] = ptr;
ndims[i] = in_shape[i].ndim();
- for (size_t j = 0; j < in_shape[i].ndim(); ++j, ++ptr) {
+ for (int j = 0; j < in_shape[i].ndim(); ++j, ++ptr) {
*ptr = static_cast<uint32_t>(in_shape[i][j]);
}
}
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 590b1b4..622952c 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -332,7 +332,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
const int channelCount = dshape[channelAxis];
- if (!shape_is_known(dshape)) {
+ if (!mxnet::ndim_is_known(dshape)) {
return false;
}
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 5435bd8..b534ee5 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -46,7 +46,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape tmp = (*in_shape)[i];
if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
- has_unknown_dim_size = tmp[axis] == -1 || has_unknown_dim_size;
+ has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
size += tmp[axis];
tmp[axis] = -1;
shape_assign(&dshape, tmp);
@@ -91,26 +91,27 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
int axis = -1;
for (int i = 0; i < param_.num_args; ++i) {
mxnet::TShape tmp = (*in_shape)[i];
- if (tmp.ndim()) {
+ if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
- if (tmp[axis] == 0) {
+ if (!mxnet::dim_size_is_known(tmp, axis)) {
zero_indices.emplace_back(i);
} else {
+ CHECK_GE(tmp[axis], 0);
size += tmp[axis];
}
- tmp[axis] = 0;
+ tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
}
mxnet::TShape tmp = (*out_shape)[0];
- if (tmp.ndim()) {
+ if (tmp.ndim() > 0) {
axis = CheckAxis(param_.dim, tmp.ndim());
- tmp[axis] = 0;
+ tmp[axis] = -1;
shape_assign(&dshape, tmp);
}
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
for (int i = 0; i < param_.num_args; ++i) {
CHECK(shape_assign(&(*in_shape)[i], dshape))
@@ -120,21 +121,21 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
if (zero_indices.empty()) dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
- if ((*out_shape)[0][axis] != 0 && !zero_indices.empty()) {
+ if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) {
int residual = (*out_shape)[0][axis] - size;
CHECK_GE(residual, 0)
<< "Input size already exceeds output size. Residual: " << residual;
- CHECK(zero_indices.size() <= 2 && zero_indices.size() >= 0)
+ CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0)
<< "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
- bool need_infer = !(*out_shape)[0].Size();
+ bool need_infer = !shape_is_known((*out_shape)[0]);
for (int i : zero_indices) {
(*in_shape)[i][axis] = residual / zero_indices.size();
- need_infer = need_infer || !(*in_shape)[i].Size();
+ need_infer = need_infer || !shape_is_known((*in_shape)[i]);
}
return !need_infer;
}
- return dshape.Size() != 0;
+ return shape_is_known(dshape);
}
static bool ConcatType(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index dfbc89d..536e9a7 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -96,24 +96,28 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
// CHECK_EQ(out_shape->size(), 1) << "Output: [output]";
out_shape->resize(1, mxnet::TShape());
const mxnet::TShape &dshp = (*in_shape)[conv::kData];
- if (!shape_is_known(dshp)) return false;
+ if (!mxnet::ndim_is_known(dshp)) return false;
if (param_.kernel.ndim() == 1) {
// 1d conv
CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x";
Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW);
- Shape<3> wshape = Shape3(param_.num_filter / param_.num_group, dshape[1] / param_.num_group,
+ Shape<3> wshape = Shape3(param_.num_filter / param_.num_group,
+ mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0]);
wshape = ConvertLayout(wshape, kNCW, param_.layout.value());
- wshape[0] *= param_.num_group;
+ if (wshape[0] >= 0) {
+ wshape[0] *= param_.num_group;
+ }
SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape);
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter));
}
const index_t dilated_ksize_x = param_.DilatedKernelSize(0);
- CHECK_EQ(dshape[1] % param_.num_group, 0U) \
- << "input num_filter must divide group size";
+ if (dshape[1] != -1) {
+ CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size";
+ }
CHECK_EQ(param_.num_filter % param_.num_group, 0U) \
<< "output num_filter must divide group size";
CHECK_GT(param_.kernel.Size(), 0U) \
@@ -125,21 +129,21 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<3> oshape;
oshape[0] = dshape[0];
oshape[1] = param_.num_filter;
- oshape[2] = dshape[2] ?
- (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : 0;
+ oshape[2] = dshape[2] != -1 ?
+ (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_x) / param_.stride[0] + 1 : -1;
SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value()));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0].get<3>(), param_.layout.value(), kNCW);
dshape[0] = oshape[0];
- if (oshape[2] && param_.stride[0] == 1) {
+ if (oshape[2] != -1 && param_.stride[0] == 1) {
dshape[2] = oshape[2] + dilated_ksize_x - 1 - 2 * param_.pad[0];
}
SHAPE_ASSIGN_CHECK(*in_shape, conv::kData,
ConvertLayout(dshape, kNCW, param_.layout.value()));
// Check whether the kernel sizes are valid
- if (dshape[2] != 0) {
+ if (dshape[2] != -1) {
CHECK_LE(dilated_ksize_x, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input";
}
return true;
@@ -149,10 +153,12 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
<< "Input data should be 4D in batch-num_filter-y-x";
Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW);
Shape<4> wshape = Shape4(param_.num_filter / param_.num_group,
- dshape[1] / param_.num_group,
+ mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1]);
wshape = ConvertLayout(wshape, kNCHW, param_.layout.value());
- wshape[0] *= param_.num_group;
+ if (wshape[0] >= 0) {
+ wshape[0] *= param_.num_group;
+ }
SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape);
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter));
@@ -160,8 +166,9 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
const index_t dilated_ksize_y = param_.DilatedKernelSize(0);
const index_t dilated_ksize_x = param_.DilatedKernelSize(1);
- CHECK_EQ(dshape[1] % param_.num_group, 0U) \
- << "input num_filter must divide group size";
+ if (dshape[1] != -1) {
+ CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size";
+ }
CHECK_EQ(param_.num_filter % param_.num_group, 0U) \
<< "output num_filter must divide group size";
CHECK_GT(param_.kernel.Size(), 0U) \
@@ -173,29 +180,29 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<4> oshape;
oshape[0] = dshape[0];
oshape[1] = param_.num_filter;
- oshape[2] = dshape[2] ?
- (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_y) / param_.stride[0] + 1 : 0;
- oshape[3] = dshape[3] ?
- (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_x) / param_.stride[1] + 1 : 0;
+ oshape[2] = dshape[2] != -1 ?
+ (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_y) / param_.stride[0] + 1 : -1;
+ oshape[3] = dshape[3] != -1 ?
+ (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_x) / param_.stride[1] + 1 : -1;
SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value()));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0].get<4>(), param_.layout.value(), kNCHW);
dshape[0] = oshape[0];
- if (oshape[2] && param_.stride[0] == 1) {
+ if (oshape[2] != -1 && param_.stride[0] == 1) {
dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param_.pad[0];
}
- if (oshape[3] && param_.stride[1] == 1) {
+ if (oshape[3] != -1 && param_.stride[1] == 1) {
dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param_.pad[1];
}
SHAPE_ASSIGN_CHECK(*in_shape, conv::kData,
ConvertLayout(dshape, kNCHW, param_.layout.value()));
// Check whether the kernel sizes are valid
- if (dshape[2] != 0) {
+ if (dshape[2] != -1) {
CHECK_LE(dilated_ksize_y, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input";
}
- if (dshape[3] != 0) {
+ if (dshape[3] != -1) {
CHECK_LE(dilated_ksize_x, AddPad(dshape[3], param_.pad[1])) << "kernel size exceed input";
}
return true;
@@ -204,10 +211,13 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(dshp.ndim(), 5U) \
<< "Input data should be 5D in batch-num_filter-depth-y-x";
Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW);
- Shape<5> wshape = Shape5(param_.num_filter / param_.num_group, dshape[1] / param_.num_group,
+ Shape<5> wshape = Shape5(param_.num_filter / param_.num_group,
+ mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
param_.kernel[0], param_.kernel[1], param_.kernel[2]);
wshape = ConvertLayout(wshape, kNCDHW, param_.layout.value());
- wshape[0] *= param_.num_group;
+ if (wshape[0] >= 0) {
+ wshape[0] *= param_.num_group;
+ }
SHAPE_ASSIGN_CHECK(*in_shape, conv::kWeight, wshape);
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter));
@@ -218,8 +228,9 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
const index_t dilated_ksize_d = param_.DilatedKernelSize(0);
const index_t dilated_ksize_y = param_.DilatedKernelSize(1);
const index_t dilated_ksize_x = param_.DilatedKernelSize(2);
- CHECK_EQ(dshape[1] % param_.num_group, 0U)
- << "input num_filter must divide group size";
+ if (dshape[1] >= 0) {
+ CHECK_EQ(dshape[1] % param_.num_group, 0U) << "input num_filter must divide group size";
+ }
CHECK_EQ(param_.num_filter % param_.num_group, 0U)
<< "output num_filter must divide group size";
CHECK_GT(param_.kernel.Size(), 0U) \
@@ -233,37 +244,37 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<5> oshape;
oshape[0] = dshape[0];
oshape[1] = param_.num_filter;
- oshape[2] = dshape[2] ?
- (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_d) / param_.stride[0] + 1 : 0;
- oshape[3] = dshape[3] ?
- (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_y) / param_.stride[1] + 1 : 0;
- oshape[4] = dshape[4] ?
- (AddPad(dshape[4], param_.pad[2]) - dilated_ksize_x) / param_.stride[2] + 1 : 0;
+ oshape[2] = dshape[2] != -1 ?
+ (AddPad(dshape[2], param_.pad[0]) - dilated_ksize_d) / param_.stride[0] + 1 : -1;
+ oshape[3] = dshape[3] != -1 ?
+ (AddPad(dshape[3], param_.pad[1]) - dilated_ksize_y) / param_.stride[1] + 1 : -1;
+ oshape[4] = dshape[4] != -1 ?
+ (AddPad(dshape[4], param_.pad[2]) - dilated_ksize_x) / param_.stride[2] + 1 : -1;
SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCDHW, param_.layout.value()));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input depth/height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0].get<5>(), param_.layout.value(), kNCDHW);
dshape[0] = oshape[0];
- if (oshape[2] && param_.stride[0] == 1) {
+ if (oshape[2] != -1 && param_.stride[0] == 1) {
dshape[2] = oshape[2] + dilated_ksize_d - 1 - 2 * param_.pad[0];
}
- if (oshape[3] && param_.stride[1] == 1) {
+ if (oshape[3] != -1 && param_.stride[1] == 1) {
dshape[3] = oshape[3] + dilated_ksize_y - 1 - 2 * param_.pad[1];
}
- if (oshape[4] && param_.stride[2] == 1) {
+ if (oshape[4] != -1 && param_.stride[2] == 1) {
dshape[4] = oshape[4] + dilated_ksize_x - 1 - 2 * param_.pad[2];
}
SHAPE_ASSIGN_CHECK(*in_shape, conv::kData,
ConvertLayout(dshape, kNCDHW, param_.layout.value()));
// Check whether the kernel sizes are valid
- if (dshape[2] != 0) {
+ if (dshape[2] != -1) {
CHECK_LE(dilated_ksize_d, AddPad(dshape[2], param_.pad[0])) << "kernel size exceed input";
}
- if (dshape[3] != 0) {
+ if (dshape[3] != -1) {
CHECK_LE(dilated_ksize_y, AddPad(dshape[3], param_.pad[1])) << "kernel size exceed input";
}
- if (dshape[4] != 0) {
+ if (dshape[4] != -1) {
CHECK_LE(dilated_ksize_x, AddPad(dshape[4], param_.pad[2])) << "kernel size exceed input";
}
return true;
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm.cc b/src/operator/nn/cudnn/cudnn_batch_norm.cc
index 1df888e..cb35ce1 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm.cc
+++ b/src/operator/nn/cudnn/cudnn_batch_norm.cc
@@ -37,7 +37,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_
using namespace mshadow;
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, moving_mean, moving_var]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
in_shape->at(1) = mxnet::TShape(Shape1(dshape[1]));
in_shape->at(2) = mxnet::TShape(Shape1(dshape[1]));
in_shape->at(3) = mxnet::TShape(Shape1(dshape[1]));
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 0e4d18b..afad6fd 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -95,7 +95,7 @@ Example::
CHECK_EQ(in_shape->size(), 1U);
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
mxnet::TShape dshape(in_shape->at(0));
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
out_shape->clear();
out_shape->push_back(dshape);
for (int i = 0; i < param.axes.ndim(); ++i) {
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 2fea62e..20fd3b6 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -46,7 +46,7 @@ static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape dshape = (*in_shape)[fullc::kData];
mxnet::TShape oshape = (*out_shape)[0];
// require data to be known
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
index_t num_input;
if (!param.flatten) {
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 1b0e99d..2e47503 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -48,7 +48,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
const int channelCount = dshape[axis];
- if (!shape_is_known(dshape)) {
+ if (!mxnet::ndim_is_known(dshape)) {
return false;
}
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 7c365f5..3e081c9 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -114,7 +114,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
int layout = param.GetLayout(dshape.ndim());
if (param.global_pool) {
mxnet::TShape oshape = dshape;
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index c95f859..59f5722 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -160,14 +160,14 @@ inline std::string type_string(const int& x) {
* \return whether x and y are compatible.
*/
inline bool shape_assign(mxnet::TShape *y, const mxnet::TShape& x) {
- if (y->ndim() == -1) {
+ if (!mxnet::ndim_is_known(*y)) {
*y = x;
return true;
} else if (y->ndim() != x.ndim()) {
- return x.ndim() == -1;
+ return !mxnet::ndim_is_known(x);
} else {
for (int i = 0; i < y->ndim(); ++i) {
- if ((*y)[i] == -1) {
+ if (!mxnet::dim_size_is_known(*y, i)) {
(*y)[i] = x[i];
} else if ((*y)[i] != x[i] && x[i] >= 0) {
return false;
diff --git a/src/operator/pooling_v1-inl.h b/src/operator/pooling_v1-inl.h
index 4e0ccc1..21ba270 100644
--- a/src/operator/pooling_v1-inl.h
+++ b/src/operator/pooling_v1-inl.h
@@ -247,7 +247,7 @@ class PoolingV1Prop : public OperatorProperty {
CHECK_LE(dshape.ndim(), 5U) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
<< "Or 5D in (batch, channel, d, y, x)";
mxnet::TShape oshape = dshape;
- if (dshape.ndim() == 0) return false;
+ if (dshape.ndim() == -1) return false;
if (param_.global_pool) {
if (dshape.ndim() == 4) {
oshape[2] = 1;
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index f978074..2cc2ec9 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -55,7 +55,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_sha
shape_assign(&dshape, tmp);
}
- if (dshape.ndim() == -1) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
for (int i = 0; i < param_.num_args; ++i) {
CHECK(shape_assign(&(*in_shape)[i], dshape))
diff --git a/src/operator/random/unique_sample_op.h b/src/operator/random/unique_sample_op.h
index c97d4fd..e88b95a 100644
--- a/src/operator/random/unique_sample_op.h
+++ b/src/operator/random/unique_sample_op.h
@@ -60,7 +60,7 @@ inline bool SampleUniqueShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 2U);
// output shape is known
- if ((*out_attrs)[0].ndim() == 2 && param.shape.ndim() == -1) {
+ if ((*out_attrs)[0].ndim() == 2 && !mxnet::ndim_is_known(param.shape)) {
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::Shape1((*out_attrs)[0][0]));
return true;
}
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9e612d0..f86eee7 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -676,7 +676,7 @@ class RNNProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
}
const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
CHECK_EQ(dshape.ndim(), 3U) \
<< "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
// data: [sequence len, batch, input dimension]
@@ -712,9 +712,7 @@ class RNNProp : public OperatorProperty {
oshape[2] = numDirections * param_.state_size;
}
out_shape->push_back(oshape);
- if (!param_.state_outputs) {
- return true;
- } else {
+ if (param_.state_outputs) {
// outStateShape: [layer_num, batch, state size]
mxnet::TShape outStateShape = dshape;
outStateShape[0] = total_layers;
@@ -733,8 +731,8 @@ class RNNProp : public OperatorProperty {
cellStateShape[2] = param_.state_size;
out_shape->push_back(cellStateShape);
}
- return true;
}
+ return shape_is_known(dshape);
}
bool InferType(std::vector<int> *in_type,
diff --git a/src/operator/slice_channel-inl.h b/src/operator/slice_channel-inl.h
index a51b17c..e37ffdc 100644
--- a/src/operator/slice_channel-inl.h
+++ b/src/operator/slice_channel-inl.h
@@ -195,7 +195,7 @@ class SliceChannelProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 1U);
mxnet::TShape dshape = in_shape->at(slice_enum::kData);
mxnet::TShape ishape = in_shape->at(slice_enum::kData);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
if (param_.axis >= 0) {
CHECK_LT(param_.axis, dshape.ndim());
} else {
@@ -212,15 +212,18 @@ class SliceChannelProp : public OperatorProperty {
<< " evenly sized chunks, but this is not possible because "
<< param_.num_outputs << " does not evenly divide "
<< dshape[real_axis];
- if (param_.squeeze_axis && ishape[real_axis] != 0) {
- CHECK_EQ(ishape[real_axis], static_cast<size_t>(param_.num_outputs))
+ if (param_.squeeze_axis && ishape[real_axis] != -1) {
+ CHECK_EQ(ishape[real_axis], param_.num_outputs)
<< "If squeeze axis is True, the size of the sliced axis must be the same as num_outputs."
<< " Input shape=" << ishape << ", axis=" << real_axis
<< ", num_outputs=" << param_.num_outputs << ".";
}
- dshape[real_axis] /= param_.num_outputs;
- if (param_.squeeze_axis && (dshape[real_axis] == 1 || ishape[real_axis] == 0)) {
- for (int d = real_axis; d < static_cast<int>(dshape.ndim()) - 1; ++d) {
+ if (dshape[real_axis] >= 0) {
+ dshape[real_axis] /= param_.num_outputs;
+ }
+ if (param_.squeeze_axis && (dshape[real_axis] == 1
+ || !mxnet::dim_size_is_known(ishape, real_axis))) {
+ for (int d = real_axis; d < dshape.ndim() - 1; ++d) {
dshape[d] = dshape[d+1];
}
dshape = mxnet::TShape(&dshape[0], &dshape[dshape.ndim()-1]);
diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h
index f81a232..80ab40e 100644
--- a/src/operator/softmax_output-inl.h
+++ b/src/operator/softmax_output-inl.h
@@ -349,7 +349,7 @@ class SoftmaxOutputProp : public OperatorProperty {
lshape2[i-1] = dshape[i];
mxnet::TShape lshape3 = dshape;
lshape3[1] = 1;
- if (in_shape->at(softmaxout_enum::kLabel).ndim() == -1) {
+ if (!mxnet::ndim_is_known(in_shape->at(softmaxout_enum::kLabel))) {
in_shape->at(softmaxout_enum::kLabel) = lshape1;
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) {
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) {
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 262242f..548225f 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -85,7 +85,7 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
// label.shape == data.shape: use probability as label
if (dshape != (*in_shape)[softmaxout_enum::kLabel]) {
@@ -97,7 +97,7 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
lshape2[i-1] = dshape[i];
mxnet::TShape lshape3 = dshape;
lshape3[1] = 1;
- if (in_shape->at(softmaxout_enum::kLabel).ndim() == -1) {
+ if (!mxnet::ndim_is_known(in_shape->at(softmaxout_enum::kLabel))) {
in_shape->at(softmaxout_enum::kLabel) = lshape1;
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape1) {
} else if (in_shape->at(softmaxout_enum::kLabel) == lshape2) {
diff --git a/src/operator/svm_output-inl.h b/src/operator/svm_output-inl.h
index 3d651c1..dfe9fa6 100644
--- a/src/operator/svm_output-inl.h
+++ b/src/operator/svm_output-inl.h
@@ -143,7 +143,7 @@ class SVMOutputProp : public OperatorProperty {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
mxnet::TShape label_shape(dshape.ndim() - 1, -1);
for (int i = 0; i + 1 < dshape.ndim(); ++i)
label_shape[i] = dshape[i];
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index fb55fcd..2d432c9 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -339,13 +339,13 @@ inline bool BroadcastToShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
- if (!shape_is_known(ishape)) return false;
+ if (!mxnet::ndim_is_known(ishape)) return false;
const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
CHECK_EQ(ishape.ndim(), param.shape.ndim())
<< "Operand of shape " << ishape << " cannot be broadcasted to " << param.shape;
mxnet::TShape oshape = param.shape;
for (int i = 0; i < ishape.ndim(); ++i) {
- if (oshape[i] != 0) {
+ if (oshape[i] != -1) {
CHECK(ishape[i] == oshape[i] || ishape[i] == 1)
<< "Array cannot be broadcasted from " << ishape << " to " << param.shape;
} else {
@@ -364,7 +364,7 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& lhs_shape = (*in_attrs)[0];
mxnet::TShape& rhs_shape = (*in_attrs)[1];
- if (!shape_is_known(lhs_shape) || !shape_is_known(lhs_shape)) {
+ if (!mxnet::ndim_is_known(lhs_shape) || !mxnet::ndim_is_known(rhs_shape)) {
return false;
}
@@ -378,7 +378,7 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
oshape = mxnet::TShape(rhs_shape);
for (int i = 0; i < lhs_shape.ndim(); ++i) {
- if (rhs_shape[i] != 0) {
+ if (rhs_shape[i] != -1) {
CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
<< "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
} else {
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
index b90b09a..c95c1ce 100644
--- a/src/operator/tensor/diag_op-inl.h
+++ b/src/operator/tensor/diag_op-inl.h
@@ -84,7 +84,7 @@ inline mxnet::TShape DiagShapeImpl(const mxnet::TShape& ishape, const int k,
auto s = std::min(h, w);
if (s < 0) {
- s = 0;
+ s = -1;
}
if (x1 > x2) {
@@ -114,7 +114,7 @@ inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
- if (!shape_is_known(ishape)) {
+ if (!mxnet::ndim_is_known(ishape)) {
return false;
}
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index dfb3231..64a4d7c 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -48,7 +48,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& rhs = (*in_attrs)[1];
// avoid pre-mature shape inference.
- if (lhs.ndim() == -1 || rhs.ndim() == -1) return false;
+ if (!mxnet::ndim_is_known(lhs) || !mxnet::ndim_is_known(rhs)) return false;
if (lhs == rhs) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, lhs);
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 3c4d34b..bcad602 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -28,6 +28,7 @@
#include <mxnet/base.h>
#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
+#include <mxnet/imperative.h>
#include <dmlc/parameter.h>
#include <dmlc/optional.h>
#include <vector>
@@ -213,14 +214,13 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs,
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
- if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param.shape)) return true;
- for (int i=0 ; i < param.shape.ndim() ; ++i) {
- if (param.shape[i] < 0U) {
- LOG(FATAL) << "Shape cannot contain negative values " << param.shape;
- }
+ mxnet::TShape param_shape = param.shape;
+ if (!Imperative::Get()->is_np_comp()) {
+ common::ConvertToNumpyShape(¶m_shape);
}
- SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
- return true;
+ if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true;
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape);
+ return shape_is_known(out_attrs->at(0));
}
template<typename ParamType>
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 1ab244b..f79af7c 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -110,7 +110,7 @@ inline mxnet::TShape InferReshapeShape(const mxnet::Tuple<IType>& shape,
CHECK_LT(src_idx, dshape_len-1);
const int d1 = dshape_vec[src_idx++];
const int d2 = dshape_vec[src_idx++];
- if (d1 == -1 || d2 == -1) {
+ if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) {
tmp.push_back(-1);
} else {
tmp.push_back(d1 * d2);
@@ -164,7 +164,7 @@ inline bool ReverseReshapeInferShape(mxnet::TShape *in, const mxnet::TShape& out
int zero_axis = -1;
int known_dim_size_prod = 1;
for (int i = 0; i < in->ndim(); i++) {
- if ((*in)[i] == -1) {
+ if (!mxnet::dim_size_is_known(*in, i)) {
if (zero_axis != -1)
return false; // more than 1 zero found.
else
@@ -185,7 +185,7 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape &dshape = (*in_attrs)[0];
- if (dshape.ndim() == -1) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
mxnet::TShape oshape;
if (param_.shape.ndim() != 0) {
oshape = InferReshapeShape(param_.shape, dshape, param_.reverse);
@@ -314,7 +314,7 @@ void Transpose(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
- if (param.axes.ndim() == -1) {
+ if (!mxnet::ndim_is_known(param.axes)) {
mxnet::TShape axes(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
@@ -334,7 +334,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions";
mxnet::TShape ret(shp.ndim(), -1);
- if (param.axes.ndim() == -1) {
+ if (!mxnet::ndim_is_known(param.axes)) {
for (int i = 0; i < shp.ndim(); ++i) {
ret[i] = shp[shp.ndim()-1-i];
}
@@ -367,7 +367,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known(in_attrs->at(0)) && !shape_is_known(out_attrs->at(0))) {
+ if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
return false;
}
@@ -401,7 +401,7 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < axis; ++i) ret[i] = oshape[i];
for (int i = axis+1; i < indim+1; ++i) ret[i-1] = oshape[i];
SHAPE_ASSIGN_CHECK(*in_attrs, 0, ret);
- return shape_is_known(ret);
+ return shape_is_known(in_attrs->at(0)) && shape_is_known(out_attrs->at(0));
}
// Currently MKLDNN only supports step = 1 or step has no value
@@ -668,7 +668,7 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}
}
- if (len) {
+ if (len > 0) {
if (param_begin[i].has_value()) {
b = param_begin[i].value();
if (b < 0) {
@@ -731,7 +731,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& dshape = (*in_attrs)[0];
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
mxnet::TShape oshape = dshape;
@@ -1122,9 +1122,9 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const mxnet::TShape&
int* axis, index_t* begin, index_t* end) {
*axis = param.axis;
if (*axis < 0) {
- *axis += static_cast<int>(ishape.ndim());
+ *axis += ishape.ndim();
}
- CHECK(*axis < static_cast<int>(ishape.ndim()) && *axis >= 0) <<
+ CHECK(*axis < ishape.ndim() && *axis >= 0) <<
"Transformed axis must be smaller than the source ndim and larger than zero! Recieved axis=" <<
param.axis << ", src_ndim=" << ishape.ndim() << ", transformed axis=" << *axis;
index_t axis_size = static_cast<index_t>(ishape[*axis]);
@@ -1133,7 +1133,7 @@ inline void GetSliceAxisParams(const SliceAxisParam& param, const mxnet::TShape&
if (*begin < 0) {
*begin += axis_size;
}
- if (axis_size) {
+ if (axis_size > 0) {
if (!static_cast<bool>(param.end)) {
*end = axis_size;
} else {
@@ -2598,7 +2598,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
mxnet::TShape dshape = in_attrs->at(split_enum::kData);
mxnet::TShape ishape = in_attrs->at(split_enum::kData);
- if (!shape_is_known(dshape)) return false;
+ if (!mxnet::ndim_is_known(dshape)) return false;
if (param.axis >= 0) {
CHECK_LT(static_cast<size_t>(param.axis), dshape.ndim());
} else {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5cf22c3..5173525 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2110,7 +2110,7 @@ def test_reshape():
for i in range(len(src_shape)):
holdout_src_shape = list(src_shape)
- holdout_src_shape[i] = -1
+ holdout_src_shape[i] = 0
holdout_src_shape = tuple(holdout_src_shape)
net = mx.sym.Variable('data')
net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
@@ -4243,7 +4243,8 @@ def test_tile():
assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3))
test_normal_case()
- test_empty_tensor()
+ with mx.numpy.enable_np_comp():
+ test_empty_tensor()
test_empty_reps()
test_tile_backward()
test_tile_numeric_gradient()
@@ -4303,7 +4304,8 @@ def test_one_hot():
test_normal_case(index_type=np.float64)
test_normal_case(index_type=np.float32)
test_normal_case(index_type=np.float16)
- test_empty_indices()
+ with mx.numpy.enable_np_comp():
+ test_empty_indices()
test_zero_depth()