You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/19 20:53:21 UTC
[incubator-mxnet] branch master updated: [COREML] Update the json getter (#8698)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 38a032c [COREML] Update the json getter (#8698)
38a032c is described below
commit 38a032c886f56f94cdad004c89fd4e1926f85ba6
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Nov 19 12:53:19 2017 -0800
[COREML] Update the json getter (#8698)
* [COREML] Update the json getter
* add docstring
---
tools/coreml/converter/_layers.py | 40 +++++++++++++++++++++++++++++++--------
1 file changed, 32 insertions(+), 8 deletions(-)
diff --git a/tools/coreml/converter/_layers.py b/tools/coreml/converter/_layers.py
index fe00232..4c5ebc6 100644
--- a/tools/coreml/converter/_layers.py
+++ b/tools/coreml/converter/_layers.py
@@ -38,6 +38,30 @@ def _get_node_name(net, node_id):
def _get_node_shape(net, node_id):
return net['nodes'][node_id]['shape']
+def _get_attrs(node):
+ """Get the attribute dict from a json graph node.
+
+ This function keeps backward compatibility with both the
+ 'attr' and 'attrs' keys; if both exist, 'attrs' is used.
+
+ Parameters
+ ----------
+ node : dict
+ The json graph node.
+
+ Returns
+ -------
+ attrs : dict
+ The node's own attribute dict (not a copy), or a new
+ empty dict if neither key exists.
+ """
+ if 'attrs' in node:
+ return node['attrs']
+ elif 'attr' in node:
+ return node['attr']
+ else:
+ return {}
+
# TODO These operators still need to be converted (listing in order of priority):
# High priority:
@@ -108,7 +132,7 @@ def convert_transpose(net, node, module, builder):
"""
input_name, output_name = _get_input_output_name(net, node)
name = node['name']
- param = node['attr']
+ param = _get_attrs(node)
axes = literal_eval(param['axes'])
builder.add_permute(name, axes, input_name, output_name)
@@ -180,7 +204,7 @@ def convert_activation(net, node, module, builder):
"""
input_name, output_name = _get_input_output_name(net, node)
name = node['name']
- mx_non_linearity = node['attr']['act_type']
+ mx_non_linearity = _get_attrs(node)['act_type']
#TODO add SCALED_TANH, SOFTPLUS, SOFTSIGN, SIGMOID_HARD, LEAKYRELU, PRELU, ELU, PARAMETRICSOFTPLUS, THRESHOLDEDRELU, LINEAR
if mx_non_linearity == 'relu':
non_linearity = 'RELU'
@@ -281,7 +305,7 @@ def convert_convolution(net, node, module, builder):
"""
input_name, output_name = _get_input_output_name(net, node)
name = node['name']
- param = node['attr']
+ param = _get_attrs(node)
inputs = node['inputs']
args, _ = module.get_params()
@@ -361,7 +385,7 @@ def convert_pooling(net, node, module, builder):
"""
input_name, output_name = _get_input_output_name(net, node)
name = node['name']
- param = node['attr']
+ param = _get_attrs(node)
layer_type_mx = param['pool_type']
if layer_type_mx == 'max':
@@ -445,9 +469,9 @@ def convert_batchnorm(net, node, module, builder):
eps = 1e-3 # Default value of eps for MXNet.
use_global_stats = False # Default value of use_global_stats for MXNet.
- if 'attr' in node:
- if 'eps' in node['attr']:
- eps = literal_eval(node['attr']['eps'])
+ attrs = _get_attrs(node)
+ if 'eps' in attrs:
+ eps = literal_eval(attrs['eps'])
args, aux = module.get_params()
gamma = args[_get_node_name(net, inputs[1][0])].asnumpy()
@@ -511,7 +535,7 @@ def convert_deconvolution(net, node, module, builder):
"""
input_name, output_name = _get_input_output_name(net, node)
name = node['name']
- param = node['attr']
+ param = _get_attrs(node)
inputs = node['inputs']
args, _ = module.get_params()
--
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.