You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/16 19:39:23 UTC
[incubator-mxnet] branch master updated: Add embedding to print_summary (#12796)
This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d096aa5 Add embedding to print_summary (#12796)
d096aa5 is described below
commit d096aa5e88307c6e08c03c686893ef35470c8485
Author: Hyung-jun Kim <so...@users.noreply.github.com>
AuthorDate: Wed Oct 17 04:39:05 2018 +0900
Add embedding to print_summary (#12796)
---
python/mxnet/visualization.py | 2 ++
tests/python/unittest/test_viz.py | 5 +++--
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 8294620..a0eb253 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -157,6 +157,8 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
if show_shape:
num_filter = shape_dict[key][1]
cur_param = int(num_filter) * 2
+ elif op == 'Embedding':
+ cur_param = int(node["attrs"]['input_dim']) * int(node["attrs"]['output_dim'])
if not pre_node:
first_connection = ''
else:
diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py
index eb5921f..fe564b0 100644
--- a/tests/python/unittest/test_viz.py
+++ b/tests/python/unittest/test_viz.py
@@ -24,7 +24,8 @@ import mxnet as mx
def test_print_summary():
data = mx.sym.Variable('data')
bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)
- conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
+ emb1= mx.symbol.Embedding(data = data, name='emb1', input_dim=100, output_dim=28)
+ conv1= mx.symbol.Convolution(data = emb1, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')
@@ -33,7 +34,7 @@ def test_print_summary():
sc1 = mx.symbol.SliceChannel(data=fc2, num_outputs=10, name="slice_1", squeeze_axis=0)
mx.viz.print_summary(sc1)
shape = {}
- shape["data"]=(1,3,28,28)
+ shape["data"]=(1,3,28)
mx.viz.print_summary(sc1, shape)
def graphviz_exists():