You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/15 17:08:16 UTC
[incubator-mxnet] branch master updated: Add repr for SymbolBlock (#14423)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 226212b Add repr for SymbolBlock (#14423)
226212b is described below
commit 226212b40b5b1a43a3d91d3a810541887beaae8c
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Fri Mar 15 10:07:55 2019 -0700
Add repr for SymbolBlock (#14423)
* Add repr for SymbolBlock
* Add a test
* Correct self.cached_graph
* Address review comments
---
python/mxnet/gluon/block.py | 8 ++++++++
tests/python/unittest/test_gluon.py | 5 +++++
2 files changed, 13 insertions(+)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7047364..2f3ed91 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1024,6 +1024,14 @@ class SymbolBlock(HybridBlock):
ret.collect_params().load(param_file, ctx=ctx)
return ret
+    def __repr__(self):
+        """Return a three-line summary of the wrapped symbol graph.
+
+        The middle line reads ``<output symbol> : <num inputs> -> <num outputs>``
+        and is wrapped in ``<ClassName>(`` and ``)``.
+        """
+        # Index 0 of _cached_graph is counted as the graph inputs; index 1 is
+        # the output symbol, whose outputs are enumerated via list_outputs().
+        inputs = self._cached_graph[0]
+        output = self._cached_graph[1]
+        modstr = '{block} : {numinputs} -> {numoutputs}'.format(
+            block=output,
+            numinputs=len(inputs),
+            numoutputs=len(output.list_outputs()))
+        return '{name}(\n{modstr}\n)'.format(name=self.__class__.__name__,
+                                             modstr=modstr)
def __init__(self, outputs, inputs, params=None):
super(SymbolBlock, self).__init__(prefix=None, params=None)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 34380dc..6af7a5f 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -882,8 +882,13 @@ def test_import():
net2 = gluon.SymbolBlock.imports(
'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
out2 = net2(data)
+ lines = str(net2).splitlines()
assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+ assert lines[0] == 'SymbolBlock('
+ assert lines[1]
+ assert lines[2] == ')'
+
@with_seed()
def test_hybrid_stale_cache():