You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/03 20:10:54 UTC
[incubator-mxnet] branch master updated: improve convert_symbol.py
add support to SUM with coeff (#7120)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9add5ae improve convert_symbol.py add support to SUM with coeff (#7120)
9add5ae is described below
commit 9add5ae417cd6fa5e9153c1f19195f5b88c01305
Author: 梁德澎 <li...@gmail.com>
AuthorDate: Fri Aug 4 04:10:52 2017 +0800
improve convert_symbol.py add support to SUM with coeff (#7120)
* improve convert_symbol.py add support to SUM with coeff
* fix code style
* fix code style
* fix code style
---
tools/caffe_converter/convert_symbol.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/tools/caffe_converter/convert_symbol.py b/tools/caffe_converter/convert_symbol.py
index fad89c4..c384c76 100644
--- a/tools/caffe_converter/convert_symbol.py
+++ b/tools/caffe_converter/convert_symbol.py
@@ -207,6 +207,7 @@ def _parse_proto(prototxt_fname):
need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
if layer.type == 'Eltwise':
type_string = 'mx.symbol.broadcast_add'
+ param = layer.eltwise_param
param_string = ""
need_flatten[name] = False
if layer.type == 'Reshape':
@@ -239,8 +240,15 @@ def _parse_proto(prototxt_fname):
symbol_string += "%s = %s(name='%s', data=%s %s)\n" % (
name, type_string, name, mapping[bottom[0]], param_string)
else:
- symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
- name, type_string, name, ','.join([mapping[x] for x in bottom]), param_string)
+ if layer.type == 'Eltwise' and param.operation == 1 and len(param.coeff) > 0:
+ symbol_string += "%s = " % name
+ symbol_string += " + ".join(["%s * %s" % (
+ mapping[bottom[i]], param.coeff[i]) for i in range(len(param.coeff))])
+ symbol_string += "\n"
+ else:
+ symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
+ name, type_string, name, ','.join(
+ [mapping[x] for x in bottom]), param_string)
for j in range(len(layer.top)):
mapping[layer.top[j]] = name
output_name = name
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.