You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original message and is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2018/10/16 19:39:23 UTC

[incubator-mxnet] branch master updated: Add embedding to print_summary (#12796)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d096aa5  Add embedding to print_summary (#12796)
d096aa5 is described below

commit d096aa5e88307c6e08c03c686893ef35470c8485
Author: Hyung-jun Kim <so...@users.noreply.github.com>
AuthorDate: Wed Oct 17 04:39:05 2018 +0900

    Add embedding to print_summary (#12796)
---
 python/mxnet/visualization.py     | 2 ++
 tests/python/unittest/test_viz.py | 5 +++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 8294620..a0eb253 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -157,6 +157,8 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
             if show_shape:
                 num_filter = shape_dict[key][1]
                 cur_param = int(num_filter) * 2
+        elif op == 'Embedding':
+            cur_param = int(node["attrs"]['input_dim']) * int(node["attrs"]['output_dim'])
         if not pre_node:
             first_connection = ''
         else:
diff --git a/tests/python/unittest/test_viz.py b/tests/python/unittest/test_viz.py
index eb5921f..fe564b0 100644
--- a/tests/python/unittest/test_viz.py
+++ b/tests/python/unittest/test_viz.py
@@ -24,7 +24,8 @@ import mxnet as mx
 def test_print_summary():
     data = mx.sym.Variable('data')
     bias = mx.sym.Variable('fc1_bias', lr_mult=1.0)
-    conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
+    emb1= mx.symbol.Embedding(data = data, name='emb1', input_dim=100, output_dim=28)
+    conv1= mx.symbol.Convolution(data = emb1, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
     bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
     act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
     mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')
@@ -33,7 +34,7 @@ def test_print_summary():
     sc1 = mx.symbol.SliceChannel(data=fc2, num_outputs=10, name="slice_1", squeeze_axis=0)
     mx.viz.print_summary(sc1)
     shape = {}
-    shape["data"]=(1,3,28,28)
+    shape["data"]=(1,3,28)
     mx.viz.print_summary(sc1, shape)
 
 def graphviz_exists():