You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/03 20:10:54 UTC

[incubator-mxnet] branch master updated: improve convert_symbol.py add support to SUM with coeff (#7120)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9add5ae  improve convert_symbol.py add support to SUM with coeff (#7120)
9add5ae is described below

commit 9add5ae417cd6fa5e9153c1f19195f5b88c01305
Author: 梁德澎 <li...@gmail.com>
AuthorDate: Fri Aug 4 04:10:52 2017 +0800

    improve convert_symbol.py add support to SUM with coeff (#7120)
    
    * improve convert_symbol.py add support to SUM with coeff
    
    * fix code style
    
    * fix code style
    
    * fix code style
---
 tools/caffe_converter/convert_symbol.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/tools/caffe_converter/convert_symbol.py b/tools/caffe_converter/convert_symbol.py
index fad89c4..c384c76 100644
--- a/tools/caffe_converter/convert_symbol.py
+++ b/tools/caffe_converter/convert_symbol.py
@@ -207,6 +207,7 @@ def _parse_proto(prototxt_fname):
             need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
         if layer.type == 'Eltwise':
             type_string = 'mx.symbol.broadcast_add'
+            param = layer.eltwise_param
             param_string = ""
             need_flatten[name] = False
         if layer.type == 'Reshape':
@@ -239,8 +240,15 @@ def _parse_proto(prototxt_fname):
                 symbol_string += "%s = %s(name='%s', data=%s %s)\n" % (
                     name, type_string, name, mapping[bottom[0]], param_string)
             else:
-                symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
-                    name, type_string, name, ','.join([mapping[x] for x in bottom]), param_string)
+                if layer.type == 'Eltwise' and param.operation == 1 and len(param.coeff) > 0:
+                    symbol_string += "%s = " % name
+                    symbol_string += " + ".join(["%s * %s" % (
+                        mapping[bottom[i]], param.coeff[i]) for i in range(len(param.coeff))])
+                    symbol_string += "\n"
+                else:
+                    symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
+                        name, type_string, name, ','.join(
+                            [mapping[x] for x in bottom]), param_string)
         for j in range(len(layer.top)):
             mapping[layer.top[j]] = name
         output_name = name

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.