You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/23 20:33:59 UTC

[incubator-mxnet] branch master updated: modify parameters counting of FC and CONV (#7568)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d9b6a3  modify parameters counting of FC and CONV (#7568)
6d9b6a3 is described below

commit 6d9b6a3fd8a251dad2ce846a0ce4ade037ce7bbb
Author: qingzhouzhen <57...@qq.com>
AuthorDate: Thu Aug 24 04:33:57 2017 +0800

    modify parameters counting of FC and CONV (#7568)
---
 python/mxnet/visualization.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 4dbf680..aa00488 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -134,12 +134,20 @@ def print_summary(symbol, shape=None, line_length=120, positions=[.44, .64, .74,
                             pre_filter = pre_filter + int(shape[0])
         cur_param = 0
         if op == 'Convolution':
-            cur_param = pre_filter * int(node["attr"]["num_filter"])
-            for k in _str2tuple(node["attr"]["kernel"]):
-                cur_param *= int(k)
-            cur_param += int(node["attr"]["num_filter"])
+            if ("no_bias" in node["attr"]) and (node["attr"]["no_bias"] == 'True'):
+                cur_param = pre_filter * int(node["attr"]["num_filter"])
+                for k in _str2tuple(node["attr"]["kernel"]):
+                    cur_param *= int(k)
+            else:
+                cur_param = pre_filter * int(node["attr"]["num_filter"])
+                for k in _str2tuple(node["attr"]["kernel"]):
+                    cur_param *= int(k)
+                cur_param += int(node["attr"]["num_filter"])
         elif op == 'FullyConnected':
-            cur_param = pre_filter * (int(node["attr"]["num_hidden"]) + 1)
+            if ("no_bias" in node["attr"]) and (node["attr"]["no_bias"] == 'True'):
+                cur_param = pre_filter * (int(node["attr"]["num_hidden"]))
+            else:
+                cur_param = (pre_filter+1) * (int(node["attr"]["num_hidden"]))
         elif op == 'BatchNorm':
             key = node["name"] + "_output"
             if show_shape:

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.