You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/11 20:52:58 UTC
[incubator-mxnet] branch master updated: Symbol __getitem__ using
list_outputs() is too expensive (#8989)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 2700ddb Symbol __getitem__ using list_outputs() is too expensive (#8989)
2700ddb is described below
commit 2700ddbbeef212879802f7f0c0812192ec5c2b77
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Mon Dec 11 12:52:54 2017 -0800
Symbol __getitem__ using list_outputs() is too expensive (#8989)
* Symbol __getitem__ calling list_outputs() is too expensive when, in most cases, it only needs the output count
* Add unit tests
* GetNumOutputs() and __len__ changes per PR comments
* Set commits
* return unsigned int (as far as pylint is concerned)
---
include/mxnet/c_api.h | 11 +++++++++++
nnvm | 2 +-
python/mxnet/symbol/symbol.py | 27 ++++++++++++++++++++++++---
src/c_api/c_api_symbolic.cc | 5 +++++
tests/python/unittest/test_symbol.py | 3 +++
5 files changed, 44 insertions(+), 4 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index faa4535..d34b194 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1051,6 +1051,16 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
+
+/*!
+ * \brief Get number of outputs of the symbol.
+ * \param symbol The symbol
+ * \param output_count receives the number of outputs
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
+ mx_uint *output_count);
+
/*!
* \brief Get a symbol that contains all the internals.
* \param symbol The symbol
@@ -1077,6 +1087,7 @@ MXNET_DLL int MXSymbolGetChildren(SymbolHandle symbol,
MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
mx_uint index,
SymbolHandle *out);
+
/*!
* \brief List auxiliary states in the symbol.
* \param symbol the symbol
diff --git a/nnvm b/nnvm
index 8d79cfd..7a052d6 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 8d79cfd0b42fbe9f6ad75886d495065d5500b9dd
+Subproject commit 7a052d678455f1c96538c1cc5a25f11115363558
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index ce7776d..22212b0 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -491,14 +491,16 @@ class Symbol(SymbolBase):
Indexing key
"""
- output_names = self.list_outputs()
+ output_count = len(self)
if isinstance(index, py_slice):
start = 0 if index.start is None else index.start
- stop = len(output_names) if index.stop is None else index.stop
+ stop = output_count if index.stop is None else index.stop
step = 1 if index.step is None else index.step
return Group([self[i] for i in range(start, stop, step)])
if isinstance(index, string_types):
+ # Returning this list of names is expensive. Some symbols may have hundreds of outputs
+ output_names = self.list_outputs()
idx = None
for i, name in enumerate(output_names):
if name == index:
@@ -511,7 +513,7 @@ class Symbol(SymbolBase):
if not isinstance(index, int):
raise TypeError('Symbol only support integer index to fetch i-th output')
- if index >= len(output_names):
+ if index >= output_count:
# Important, python determines the end by this exception
raise IndexError
handle = SymbolHandle()
@@ -745,6 +747,25 @@ class Symbol(SymbolBase):
self.handle, ctypes.byref(size), ctypes.byref(sarr)))
return [py_str(sarr[i]) for i in range(size.value)]
+ def __len__(self):
+ """Get number of outputs for the symbol.
+
+ Example
+ -------
+ >>> a = mx.sym.var('a')
+ >>> b = mx.sym.var('b')
+ >>> c = a + b
+ >>> len(c)
+
+ Returns
+ -------
+ int
+ Number of outputs of the symbol.
+ """
+ output_count = mx_uint()
+ check_call(_LIB.MXSymbolGetNumOutputs(self.handle, ctypes.byref(output_count)))
+ return output_count.value
+
def list_auxiliary_states(self):
"""Lists all the auxiliary states in the symbol.
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index dad71b0..3668af0 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -310,6 +310,11 @@ int MXSymbolListOutputs(SymbolHandle symbol,
return NNSymbolListOutputNames(symbol, out_size, out_str_array);
}
+int MXSymbolGetNumOutputs(SymbolHandle symbol,
+ mx_uint *output_count) {
+ return NNSymbolGetNumOutputs(symbol, output_count);
+}
+
int MXSymbolCompose(SymbolHandle sym,
const char *name,
mx_uint num_args,
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 30e76a2..8fba1cc 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -46,6 +46,7 @@ def test_symbol_compose():
composed = net2(fc3_data=net1, name='composed')
multi_out = mx.symbol.Group([composed, net1])
assert len(multi_out.list_outputs()) == 2
+ assert len(multi_out) == 2
def test_symbol_copy():
@@ -72,7 +73,9 @@ def test_symbol_children():
net1 = mx.symbol.FullyConnected(data=oldfc, name='fc2', num_hidden=100)
assert net1.get_children().list_outputs() == ['fc1_output', 'fc2_weight', 'fc2_bias']
+ assert len(net1.get_children()) == 3
assert net1.get_children().get_children().list_outputs() == ['data', 'fc1_weight', 'fc1_bias']
+ assert len(net1.get_children().get_children()) == 3
assert net1.get_children()['fc2_weight'].list_arguments() == ['fc2_weight']
assert net1.get_children()['fc2_weight'].get_children() is None
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.