You are viewing a plain-text version of this content. The link to its canonical version ("here") was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/03/15 17:08:16 UTC

[incubator-mxnet] branch master updated: Add repr for SymbolBlock (#14423)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 226212b  Add repr for SymbolBlock (#14423)
226212b is described below

commit 226212b40b5b1a43a3d91d3a810541887beaae8c
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Fri Mar 15 10:07:55 2019 -0700

    Add repr for SymbolBlock (#14423)
    
    * Add repr for SymbolBlock
    
    * Add a test
    
    * Correct self.cached_graph
    
    * Address review comments
---
 python/mxnet/gluon/block.py         | 8 ++++++++
 tests/python/unittest/test_gluon.py | 5 +++++
 2 files changed, 13 insertions(+)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7047364..2f3ed91 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1024,6 +1024,14 @@ class SymbolBlock(HybridBlock):
             ret.collect_params().load(param_file, ctx=ctx)
         return ret
 
+    def __repr__(self):
+        s = '{name}(\n{modstr}\n)'
+        modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'.format(block=self._cached_graph[1],
+                                                                           numinputs=len(self._cached_graph[0]),
+                                                                           numoutputs=len(self._cached_graph[1].
+                                                                                          list_outputs()))])
+        return s.format(name=self.__class__.__name__,
+                        modstr=modstr)
 
     def __init__(self, outputs, inputs, params=None):
         super(SymbolBlock, self).__init__(prefix=None, params=None)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 34380dc..6af7a5f 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -882,8 +882,13 @@ def test_import():
     net2 = gluon.SymbolBlock.imports(
         'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
     out2 = net2(data)
+    lines = str(net2).splitlines()
 
     assert_almost_equal(out1.asnumpy(), out2.asnumpy())
+    assert lines[0] == 'SymbolBlock('
+    assert lines[1]
+    assert lines[2] == ')'
+
 
 @with_seed()
 def test_hybrid_stale_cache():