You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/02/22 01:49:57 UTC

[incubator-mxnet] branch master updated: Add infer_type_partial (#14214)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 87441c3  Add infer_type_partial (#14214)
87441c3 is described below

commit 87441c38175205e99be08a4889b15bdb81a1cbb9
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Thu Feb 21 17:49:34 2019 -0800

    Add infer_type_partial (#14214)
    
    * Add infer_type_partial
    
    * Added infer_type_partial to symbol docs main area
    
    * Added test
---
 docs/api/python/symbol/symbol.md     |  1 +
 include/mxnet/c_api.h                | 32 ++++++++++++++
 python/mxnet/symbol/symbol.py        | 81 +++++++++++++++++++++++++++++++++++-
 src/c_api/c_api_symbolic.cc          | 21 ++++++++++
 tests/python/unittest/test_symbol.py |  7 ++++
 5 files changed, 141 insertions(+), 1 deletion(-)

diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 9cab2c5..fea746b 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -337,6 +337,7 @@ Composite multiple symbols into a new one by an operator.
     :nosignatures:
 
     Symbol.infer_type
+    Symbol.infer_type_partial
     Symbol.infer_shape
     Symbol.infer_shape_partial
 ```
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 13ee903..76a4995 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1563,6 +1563,38 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
                                 int *complete);
 
 /*!
+ * \brief partially infer type of unknown input types given the known one.
+ *
+ *  Return partially inferred results if not all types could be inferred.
+ *  The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data.
+ *  The call will be treated as a kwargs call if keys != nullptr or num_args==0, otherwise it is positional.
+ *
+ * \param sym symbol handle
+ * \param num_args number of input arguments.
+ * \param keys the key of keyword args (optional)
+ * \param arg_type_data the content of the CSR
+ * \param in_type_size sizeof the returning array of in_types
+ * \param in_type_data returning array of pointers to head of the input type.
+ * \param out_type_size sizeof the returning array of out_types
+ * \param out_type_data returning array of pointers to head of the output type.
+ * \param aux_type_size sizeof the returning array of aux_types
+ * \param aux_type_data returning array of pointers to head of the auxiliary type.
+ * \param complete whether infer type completes or more information is needed.
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolInferTypePartial(SymbolHandle sym,
+                                       mx_uint num_args,
+                                       const char** keys,
+                                       const int *arg_type_data,
+                                       mx_uint *in_type_size,
+                                       const int **in_type_data,
+                                       mx_uint *out_type_size,
+                                       const int **out_type_data,
+                                       mx_uint *aux_type_size,
+                                       const int **aux_type_data,
+                                       int *complete);
+
+/*!
  * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8
  * \param sym_handle symbol to be converted
  * \param ret_sym_handle quantized symbol result
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 43de0c9..3e3e79e 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -882,6 +882,81 @@ class Symbol(SymbolBase):
             List of auxiliary state types.
             The order is same as the order of list_auxiliary_states().
         """
+        try:
+            res = self._infer_type_impl(False, *args, **kwargs)
+            if res[1] is None:
+                arg_shapes, _, _ = self._infer_type_impl(True, *args, **kwargs)
+                arg_names = self.list_arguments()
+                unknowns = []
+                for name, dtype in zip(arg_names, arg_shapes):
+                    if not dtype:
+                        if len(unknowns) >= 10:
+                            unknowns.append('...')
+                            break
+                        unknowns.append('%s: %s' % (name, str(dtype)))
+                warnings.warn(
+                    "Cannot decide type for the following arguments. " +
+                    "Consider providing them as input:\n\t" +
+                    "\n\t".join(unknowns), stacklevel=2)
+            return res
+        except MXNetError:
+            print("infer_type error. Arguments:")
+            for i, arg in enumerate(args):
+                print("  #%d: %s" % (i, arg))
+            for k, v in kwargs.items():
+                print("  %s: %s" % (k, v))
+            raise
+
+    def infer_type_partial(self, *args, **kwargs):
+        """Infers the type partially.
+
+        This function works the same way as `infer_type`,
+        except that this function can return partial results.
+
+        In the following example, the type of `prev` is not available. So, `infer_type`
+        will return a tuple of `None` values but `infer_type_partial` will return partial values.
+
+        Example
+        -------
+        >>> data = mx.sym.Variable('data')
+        >>> prev = mx.sym.Variable('prev')
+        >>> casted_prev  = mx.sym.cast(prev, dtype='float32')
+        >>> out  = mx.sym.Activation(data=mx.sym.elemwise_add(data, casted_prev), act_type='relu')
+        >>> out.list_arguments()
+        ['data', 'prev']
+        >>> out.infer_type(data='float32')
+        (None, None, None)
+        >>> out.infer_type_partial(data='float32')
+        ([numpy.float32, None], [numpy.float32], [])
+        >>> # infers type if you give information about prev
+        >>> out.infer_type(data='float32', prev='float16')
+        ([numpy.float32, numpy.float16], [numpy.float32], [])
+
+        Parameters
+        ----------
+        *args :
+            Type of known arguments in a positional way.
+            Unknown type can be marked as None.
+
+        **kwargs :
+            Keyword arguments of known types.
+
+        Returns
+        -------
+        arg_types : list of numpy.dtype or None
+            List of argument types.
+            The order is same as the order of list_arguments().
+        out_types : list of numpy.dtype or None
+            List of output types.
+            The order is same as the order of list_outputs().
+        aux_types : list of numpy.dtype or None
+            List of auxiliary state types.
+            The order is same as the order of list_auxiliary_states().
+        """
+        return self._infer_type_impl(True, *args, **kwargs)
+
+    def _infer_type_impl(self, partial, *args, **kwargs):
+        """The actual implementation for calling type inference API."""
         # pylint: disable=too-many-locals
         if len(args) != 0 and len(kwargs) != 0:
             raise ValueError('Can only specify known argument \
@@ -912,7 +987,11 @@ class Symbol(SymbolBase):
         aux_type_size = mx_uint()
         aux_type_data = ctypes.POINTER(ctypes.c_int)()
         complete = ctypes.c_int()
-        check_call(_LIB.MXSymbolInferType(
+        if partial:
+            infer_func = _LIB.MXSymbolInferTypePartial
+        else:
+            infer_func = _LIB.MXSymbolInferType
+        check_call(infer_func(
             self.handle,
             mx_uint(len(sdata)),
             keys,
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 32b63c1..9f0d283 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -638,6 +638,27 @@ int MXSymbolInferType(SymbolHandle sym,
   API_END();
 }
 
+int MXSymbolInferTypePartial(SymbolHandle sym,
+                             mx_uint num_args,
+                             const char** keys,
+                             const int *arg_type_data,
+                             mx_uint *in_type_size,
+                             const int **in_type_data,
+                             mx_uint *out_type_size,
+                             const int **out_type_data,
+                             mx_uint *aux_type_size,
+                             const int **aux_type_data,
+                             int *complete) {
+  int succ;
+  *complete = 1;
+  return MXSymbolInferType(sym, num_args, keys,
+                            arg_type_data,
+                            in_type_size, in_type_data,
+                            out_type_size, out_type_data,
+                            aux_type_size, aux_type_data,
+                            &succ);
+}
+
 int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
   API_BEGIN();
   LOG(FATAL) << "not implemented";
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index c5c1b01..ac4564b 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -120,6 +120,13 @@ def test_symbol_infer_type():
     assert out == [np.float32]
     assert aux == []
 
+    # partial infer type
+    arg, out, aux = mlp.infer_type_partial()
+    assert arg == [None, np.float32, np.float32, np.float32]
+    assert out == [np.float32]
+    assert aux == []
+
+
 def test_symbol_infer_shape():
     num_hidden = 128
     num_dim    = 64