You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/03/18 02:50:26 UTC

[incubator-mxnet] branch master updated: Fix crashes on visualization (#14425)

This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ab5b44c  Fix crashes on visualization (#14425)
ab5b44c is described below

commit ab5b44ca07b6a41ee23f96559d420e58408b799a
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Sun Mar 17 19:49:57 2019 -0700

    Fix crashes on visualization (#14425)
    
    * Check for kernel in Pooling
    
    * Fix LeakyReLU visualization
    
    * Address review comments
    
    * Change all occurrences to string format
    
    * Fix lint error
---
 python/mxnet/visualization.py | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index dd3a1df..4101f74 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -205,7 +205,7 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
             print('=' * line_length)
         else:
             print('_' * line_length)
-    print('Total params: %s' % total_params)
+    print("Total params: {params}".format(params=total_params))
     print('_' * line_length)
 
 def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
@@ -337,24 +337,33 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None
             label = node["name"]
             attr["fillcolor"] = cm[0]
         elif op == "Convolution":
-            label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])),
-                                                 "x".join(_str2tuple(node["attrs"]["stride"]))
-                                                 if "stride" in node["attrs"] else "1",
-                                                 node["attrs"]["num_filter"])
+            label = "Convolution\n{kernel}/{stride}, {filter}".format(
+                kernel="x".join(_str2tuple(node["attrs"]["kernel"])),
+                stride="x".join(_str2tuple(node["attrs"]["stride"]))
+                if "stride" in node["attrs"] else "1",
+                filter=node["attrs"]["num_filter"]
+            )
             attr["fillcolor"] = cm[1]
         elif op == "FullyConnected":
-            label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"]
+            label = "FullyConnected\n{hidden}".format(hidden=node["attrs"]["num_hidden"])
             attr["fillcolor"] = cm[1]
         elif op == "BatchNorm":
             attr["fillcolor"] = cm[3]
-        elif op in ('Activation', 'LeakyReLU'):
-            label = r"%s\n%s" % (op, node["attrs"]["act_type"])
+        elif op == 'Activation':
+            act_type = node["attrs"]["act_type"]
+            label = 'Activation\n{activation}'.format(activation=act_type)
+            attr["fillcolor"] = cm[2]
+        elif op == 'LeakyReLU':
+            attrs = node.get("attrs")
+            act_type = attrs.get("act_type", "Leaky") if attrs else "Leaky"
+            label = 'LeakyReLU\n{activation}'.format(activation=act_type)
             attr["fillcolor"] = cm[2]
         elif op == "Pooling":
-            label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"],
-                                             "x".join(_str2tuple(node["attrs"]["kernel"])),
-                                             "x".join(_str2tuple(node["attrs"]["stride"]))
-                                             if "stride" in node["attrs"] else "1")
+            label = "Pooling\n{pooltype}, {kernel}/{stride}".format(pooltype=node["attrs"]["pool_type"],
+                                                                    kernel="x".join(_str2tuple(node["attrs"]["kernel"]))
+                                                                    if "kernel" in node["attrs"] else "[]",
+                                                                    stride="x".join(_str2tuple(node["attrs"]["stride"]))
+                                                                    if "stride" in node["attrs"] else "1")
             attr["fillcolor"] = cm[4]
         elif op in ("Concat", "Flatten", "Reshape"):
             attr["fillcolor"] = cm[5]