You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by wk...@apache.org on 2019/03/18 02:50:26 UTC
[incubator-mxnet] branch master updated: Fix crashes on visualization (#14425)
This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ab5b44c Fix crashes on visualization (#14425)
ab5b44c is described below
commit ab5b44ca07b6a41ee23f96559d420e58408b799a
Author: Vandana Kannan <va...@users.noreply.github.com>
AuthorDate: Sun Mar 17 19:49:57 2019 -0700
Fix crashes on visualization (#14425)
* Check for kernel in Pooling
* Fix Leakyrelu visualization
* Address review comments
* Change all occurrences of %-formatting to str.format
* Fix lint error
---
python/mxnet/visualization.py | 33 +++++++++++++++++++++------------
1 file changed, 21 insertions(+), 12 deletions(-)
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index dd3a1df..4101f74 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -205,7 +205,7 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
print('=' * line_length)
else:
print('_' * line_length)
- print('Total params: %s' % total_params)
+ print("Total params: {params}".format(params=total_params))
print('_' * line_length)
def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None, node_attrs={},
@@ -337,24 +337,33 @@ def plot_network(symbol, title="plot", save_format='pdf', shape=None, dtype=None
label = node["name"]
attr["fillcolor"] = cm[0]
elif op == "Convolution":
- label = r"Convolution\n%s/%s, %s" % ("x".join(_str2tuple(node["attrs"]["kernel"])),
- "x".join(_str2tuple(node["attrs"]["stride"]))
- if "stride" in node["attrs"] else "1",
- node["attrs"]["num_filter"])
+ label = "Convolution\n{kernel}/{stride}, {filter}".format(
+ kernel="x".join(_str2tuple(node["attrs"]["kernel"])),
+ stride="x".join(_str2tuple(node["attrs"]["stride"]))
+ if "stride" in node["attrs"] else "1",
+ filter=node["attrs"]["num_filter"]
+ )
attr["fillcolor"] = cm[1]
elif op == "FullyConnected":
- label = r"FullyConnected\n%s" % node["attrs"]["num_hidden"]
+ label = "FullyConnected\n{hidden}".format(hidden=node["attrs"]["num_hidden"])
attr["fillcolor"] = cm[1]
elif op == "BatchNorm":
attr["fillcolor"] = cm[3]
- elif op in ('Activation', 'LeakyReLU'):
- label = r"%s\n%s" % (op, node["attrs"]["act_type"])
+ elif op == 'Activation':
+ act_type = node["attrs"]["act_type"]
+ label = 'Activation\n{activation}'.format(activation=act_type)
+ attr["fillcolor"] = cm[2]
+ elif op == 'LeakyReLU':
+ attrs = node.get("attrs")
+ act_type = attrs.get("act_type", "Leaky") if attrs else "Leaky"
+ label = 'LeakyReLU\n{activation}'.format(activation=act_type)
attr["fillcolor"] = cm[2]
elif op == "Pooling":
- label = r"Pooling\n%s, %s/%s" % (node["attrs"]["pool_type"],
- "x".join(_str2tuple(node["attrs"]["kernel"])),
- "x".join(_str2tuple(node["attrs"]["stride"]))
- if "stride" in node["attrs"] else "1")
+ label = "Pooling\n{pooltype}, {kernel}/{stride}".format(pooltype=node["attrs"]["pool_type"],
+ kernel="x".join(_str2tuple(node["attrs"]["kernel"]))
+ if "kernel" in node["attrs"] else "[]",
+ stride="x".join(_str2tuple(node["attrs"]["stride"]))
+ if "stride" in node["attrs"] else "1")
attr["fillcolor"] = cm[4]
elif op in ("Concat", "Flatten", "Reshape"):
attr["fillcolor"] = cm[5]