You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/22 01:49:57 UTC
[incubator-mxnet] branch master updated: Add infer_type_partial
(#14214)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 87441c3 Add infer_type_partial (#14214)
87441c3 is described below
commit 87441c38175205e99be08a4889b15bdb81a1cbb9
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Thu Feb 21 17:49:34 2019 -0800
Add infer_type_partial (#14214)
* Add infer_type_partial
* Added infer_type_partial to symbol docs main area
* Added test
---
docs/api/python/symbol/symbol.md | 1 +
include/mxnet/c_api.h | 32 ++++++++++++++
python/mxnet/symbol/symbol.py | 81 +++++++++++++++++++++++++++++++++++-
src/c_api/c_api_symbolic.cc | 21 ++++++++++
tests/python/unittest/test_symbol.py | 7 ++++
5 files changed, 141 insertions(+), 1 deletion(-)
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 9cab2c5..fea746b 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -337,6 +337,7 @@ Composite multiple symbols into a new one by an operator.
:nosignatures:
Symbol.infer_type
+ Symbol.infer_type_partial
Symbol.infer_shape
Symbol.infer_shape_partial
```
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 13ee903..76a4995 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1563,6 +1563,38 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
int *complete);
/*!
+ * \brief partially infer the types of unknown inputs given the known ones.
+ *
+ * Return partially inferred results if not all types could be inferred.
+ * The known types are passed in the integer array arg_type_data.
+ * The call will be treated as a kwargs call if keys != nullptr or num_args==0, otherwise it is positional.
+ *
+ * \param sym symbol handle
+ * \param num_args number of input arguments.
+ * \param keys the keys of keyword args (optional)
+ * \param arg_type_data the known argument types
+ * \param in_type_size size of the returned array of in_types
+ * \param in_type_data returned pointer to the head of the input types.
+ * \param out_type_size size of the returned array of out_types
+ * \param out_type_data returned pointer to the head of the output types.
+ * \param aux_type_size size of the returned array of aux_types
+ * \param aux_type_data returned pointer to the head of the auxiliary types.
+ * \param complete completion flag; the partial variant always sets it to 1, even if some types stay unknown.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
+ mx_uint num_args,
+ const char** keys,
+ const int *arg_type_data,
+ mx_uint *in_type_size,
+ const int **in_type_data,
+ mx_uint *out_type_size,
+ const int **out_type_data,
+ mx_uint *aux_type_size,
+ const int **aux_type_data,
+ int *complete);
+
+/*!
* \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8
* \param sym_handle symbol to be converted
* \param ret_sym_handle quantized symbol result
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 43de0c9..3e3e79e 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -882,6 +882,81 @@ class Symbol(SymbolBase):
List of auxiliary state types.
The order is same as the order of list_auxiliary_states().
"""
+ try:
+ res = self._infer_type_impl(False, *args, **kwargs)
+ if res[1] is None:
+ arg_shapes, _, _ = self._infer_type_impl(True, *args, **kwargs)
+ arg_names = self.list_arguments()
+ unknowns = []
+ for name, dtype in zip(arg_names, arg_shapes):
+ if not dtype:
+ if len(unknowns) >= 10:
+ unknowns.append('...')
+ break
+ unknowns.append('%s: %s' % (name, str(dtype)))
+ warnings.warn(
+ "Cannot decide type for the following arguments. " +
+ "Consider providing them as input:\n\t" +
+ "\n\t".join(unknowns), stacklevel=2)
+ return res
+ except MXNetError:
+ print("infer_type error. Arguments:")
+ for i, arg in enumerate(args):
+ print(" #%d: %s" % (i, arg))
+ for k, v in kwargs.items():
+ print(" %s: %s" % (k, v))
+ raise
+
+ def infer_type_partial(self, *args, **kwargs):
+ """Infers the type partially.
+
+ This function works the same way as `infer_type`,
+ except that this function can return partial results.
+
+ In the following example, the type of `prev` is not provided. So, `infer_type`
+ will return a tuple of `None` values but `infer_type_partial` will return partial values.
+
+ Example
+ -------
+ >>> data = mx.sym.Variable('data')
+ >>> prev = mx.sym.Variable('prev')
+ >>> casted_prev = mx.sym.cast(prev, dtype='float32')
+ >>> out = mx.sym.Activation(data=mx.sym.elemwise_add(data, casted_prev), act_type='relu')
+ >>> out.list_arguments()
+ ['data', 'prev']
+ >>> out.infer_type(data='float32')
+ (None, None, None)
+ >>> out.infer_type_partial(data='float32')
+ ([numpy.float32, None], [numpy.float32], [])
+ >>> # infers type if you give information about prev
+ >>> out.infer_type(data='float32', prev='float16')
+ ([numpy.float32, numpy.float16], [numpy.float32], [])
+
+ Parameters
+ ----------
+ *args :
+ Type of known arguments in a positional way.
+ Unknown type can be marked as None.
+
+ **kwargs :
+ Keyword arguments of known types.
+
+ Returns
+ -------
+ arg_types : list of numpy.dtype or None
+ List of argument types.
+ The order is same as the order of list_arguments().
+ out_types : list of numpy.dtype or None
+ List of output types.
+ The order is same as the order of list_outputs().
+ aux_types : list of numpy.dtype or None
+ List of auxiliary state types.
+ The order is same as the order of list_auxiliary_states().
+ """
+ return self._infer_type_impl(True, *args, **kwargs)
+
+ def _infer_type_impl(self, partial, *args, **kwargs):
+ """The actual implementation for calling type inference API."""
# pylint: disable=too-many-locals
if len(args) != 0 and len(kwargs) != 0:
raise ValueError('Can only specify known argument \
@@ -912,7 +987,11 @@ class Symbol(SymbolBase):
aux_type_size = mx_uint()
aux_type_data = ctypes.POINTER(ctypes.c_int)()
complete = ctypes.c_int()
- check_call(_LIB.MXSymbolInferType(
+ if partial:
+ infer_func = _LIB.MXSymbolInferTypePartial
+ else:
+ infer_func = _LIB.MXSymbolInferType
+ check_call(infer_func(
self.handle,
mx_uint(len(sdata)),
keys,
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 32b63c1..9f0d283 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -638,6 +638,27 @@ int MXSymbolInferType(SymbolHandle sym,
API_END();
}
+int MXSymbolInferTypePartial(SymbolHandle sym,
+ mx_uint num_args,
+ const char** keys,
+ const int *arg_type_data,
+ mx_uint *in_type_size,
+ const int **in_type_data,
+ mx_uint *out_type_size,
+ const int **out_type_data,
+ mx_uint *aux_type_size,
+ const int **aux_type_data,
+ int *complete) {
+ int succ;  // handed to MXSymbolInferType as its `complete` out-parameter; never propagated to the caller
+ *complete = 1;  // unconditionally report completion to the caller of the partial variant
+ return MXSymbolInferType(sym, num_args, keys,
+ arg_type_data,
+ in_type_size, in_type_data,
+ out_type_size, out_type_data,
+ aux_type_size, aux_type_data,
+ &succ);
+}
+
int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
API_BEGIN();
LOG(FATAL) << "not implemented";
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index c5c1b01..ac4564b 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -120,6 +120,13 @@ def test_symbol_infer_type():
assert out == [np.float32]
assert aux == []
+ # partial infer type
+ arg, out, aux = mlp.infer_type_partial()
+ assert arg == [None, np.float32, np.float32, np.float32]
+ assert out == [np.float32]
+ assert aux == []
+
+
def test_symbol_infer_shape():
num_hidden = 128
num_dim = 64