You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/11 20:52:58 UTC

[incubator-mxnet] branch master updated: Symbol __getitem__ using list_outputs() is too expensive (#8989)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2700ddb  Symbol __getitem__ using list_outputs() is too expensive (#8989)
2700ddb is described below

commit 2700ddbbeef212879802f7f0c0812192ec5c2b77
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Mon Dec 11 12:52:54 2017 -0800

    Symbol __getitem__ using list_outputs() is too expensive (#8989)
    
    * Symbol __getitem__ using list_outputs() is too expensive when, in most cases, it only needs the output count
    
    * unit
    
    * GetNumOutputs() and __len__ changes per PR comments
    
    * Set commits
    
    * return unsigned int (as far as pylint is concerned)
---
 include/mxnet/c_api.h                | 11 +++++++++++
 nnvm                                 |  2 +-
 python/mxnet/symbol/symbol.py        | 27 ++++++++++++++++++++++++---
 src/c_api/c_api_symbolic.cc          |  5 +++++
 tests/python/unittest/test_symbol.py |  3 +++
 5 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index faa4535..d34b194 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1051,6 +1051,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
                                   mx_uint *out_size,
                                   const char ***out_str_array);
+
+/*!
+ * \brief Get number of outputs of the symbol.
+ * \param symbol The symbol
+ * \param output_count receives the number of outputs
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
+                                     mx_uint *output_count);
+
 /*!
  * \brief Get a symbol that contains all the internals.
  * \param symbol The symbol
@@ -1077,6 +1087,7 @@ MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
 MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
                                 mx_uint index,
                                 SymbolHandle *out);
+
 /*!
  * \brief List auxiliary states in the symbol.
  * \param symbol the symbol
diff --git a/nnvm b/nnvm
index 8d79cfd..7a052d6 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 8d79cfd0b42fbe9f6ad75886d495065d5500b9dd
+Subproject commit 7a052d678455f1c96538c1cc5a25f11115363558
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index ce7776d..22212b0 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -491,14 +491,16 @@ class Symbol(SymbolBase):
             Indexing key
 
         """
-        output_names = self.list_outputs()
+        output_count = len(self)
         if isinstance(index, py_slice):
             start = 0 if index.start is None else index.start
-            stop = len(output_names) if index.stop is None else index.stop
+            stop = output_count if index.stop is None else index.stop
             step = 1 if index.step is None else index.step
             return Group([self[i] for i in range(start, stop, step)])
 
         if isinstance(index, string_types):
+            # Returning this list of names is expensive. Some symbols may have hundreds of outputs
+            output_names = self.list_outputs()
             idx = None
             for i, name in enumerate(output_names):
                 if name == index:
@@ -511,7 +513,7 @@ class Symbol(SymbolBase):
 
         if not isinstance(index, int):
             raise TypeError('Symbol only support integer index to fetch i-th output')
-        if index >= len(output_names):
+        if index >= output_count:
             # Important, python determines the end by this exception
             raise IndexError
         handle = SymbolHandle()
@@ -745,6 +747,25 @@ class Symbol(SymbolBase):
             self.handle, ctypes.byref(size), ctypes.byref(sarr)))
         return [py_str(sarr[i]) for i in range(size.value)]
 
+    def __len__(self):
+        """Get number of outputs for the symbol.
+
+        Example
+        -------
+        >>> a = mx.sym.var('a')
+        >>> b = mx.sym.var('b')
+        >>> c = a + b
+        >>> len(c)
+
+        Returns
+        -------
+        int
+            Number of outputs
+        """
+        output_count = mx_uint()
+        check_call(_LIB.MXSymbolGetNumOutputs(self.handle, ctypes.byref(output_count)))
+        return output_count.value
+
     def list_auxiliary_states(self):
         """Lists all the auxiliary states in the symbol.
 
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index dad71b0..3668af0 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -310,6 +310,11 @@ int MXSymbolListOutputs(SymbolHandle symbol,
   return NNSymbolListOutputNames(symbol, out_size, out_str_array);
 }
 
+int MXSymbolGetNumOutputs(SymbolHandle symbol,
+                           mx_uint *output_count) {
+  return NNSymbolGetNumOutputs(symbol, output_count);
+}
+
 int MXSymbolCompose(SymbolHandle sym,
                     const char *name,
                     mx_uint num_args,
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 30e76a2..8fba1cc 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -46,6 +46,7 @@ def test_symbol_compose():
     composed = net2(fc3_data=net1, name='composed')
     multi_out = mx.symbol.Group([composed, net1])
     assert len(multi_out.list_outputs()) == 2
+    assert len(multi_out) == 2
 
 
 def test_symbol_copy():
@@ -72,7 +73,9 @@ def test_symbol_children():
     net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100)
 
     assert net1.get_children().list_outputs() == ['fc1_output', 'fc2_weight', 'fc2_bias']
+    assert len(net1.get_children()) == 3
     assert net1.get_children().get_children().list_outputs() == ['data', 'fc1_weight', 'fc1_bias']
+    assert len(net1.get_children().get_children()) == 3
     assert net1.get_children()['fc2_weight'].list_arguments() == ['fc2_weight']
     assert net1.get_children()['fc2_weight'].get_children() is None
 

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.