You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/19 20:53:21 UTC

[incubator-mxnet] branch master updated: [COREML] Update the json getter (#8698)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 38a032c  [COREML] Update the json getter (#8698)
38a032c is described below

commit 38a032c886f56f94cdad004c89fd4e1926f85ba6
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Nov 19 12:53:19 2017 -0800

    [COREML] Update the json getter (#8698)
    
    * [COREML] Update the json getter
    
    * add docstring
---
 tools/coreml/converter/_layers.py | 40 +++++++++++++++++++++++++++++++--------
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/tools/coreml/converter/_layers.py b/tools/coreml/converter/_layers.py
index fe00232..4c5ebc6 100644
--- a/tools/coreml/converter/_layers.py
+++ b/tools/coreml/converter/_layers.py
@@ -38,6 +38,30 @@ def _get_node_name(net, node_id):
 def _get_node_shape(net, node_id):
     return net['nodes'][node_id]['shape']
 
+def _get_attrs(node):
+    """get attribute dict from node
+
+    This function keeps backward compatibility
+    for both the attr and attrs keys in the json field.
+
+    Parameters
+    ----------
+    node : dict
+       The json graph Node
+
+    Returns
+    -------
+    attrs : dict
+       The attr dict; returns an empty dict if
+       neither field exists.
+    """
+    if 'attrs' in node:
+        return node['attrs']
+    elif 'attr' in node:
+        return node['attr']
+    else:
+        return {}
+
 
 # TODO These operators still need to be converted (listing in order of priority):
 # High priority:
@@ -108,7 +132,7 @@ def convert_transpose(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
 
     axes = literal_eval(param['axes'])
     builder.add_permute(name, axes, input_name, output_name)
@@ -180,7 +204,7 @@ def convert_activation(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    mx_non_linearity = node['attr']['act_type']
+    mx_non_linearity = _get_attrs(node)['act_type']
     #TODO add SCALED_TANH, SOFTPLUS, SOFTSIGN, SIGMOID_HARD, LEAKYRELU, PRELU, ELU, PARAMETRICSOFTPLUS, THRESHOLDEDRELU, LINEAR
     if mx_non_linearity == 'relu':
         non_linearity = 'RELU'
@@ -281,7 +305,7 @@ def convert_convolution(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
     inputs = node['inputs']
     args, _ = module.get_params()
 
@@ -361,7 +385,7 @@ def convert_pooling(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
 
     layer_type_mx = param['pool_type']
     if layer_type_mx == 'max':
@@ -445,9 +469,9 @@ def convert_batchnorm(net, node, module, builder):
 
     eps = 1e-3 # Default value of eps for MXNet.
     use_global_stats = False # Default value of use_global_stats for MXNet.
-    if 'attr' in node:
-        if 'eps' in node['attr']:
-            eps = literal_eval(node['attr']['eps'])
+    attrs = _get_attrs(node)
+    if 'eps' in attrs:
+        eps = literal_eval(attrs['eps'])
 
     args, aux = module.get_params()
     gamma = args[_get_node_name(net, inputs[1][0])].asnumpy()
@@ -511,7 +535,7 @@ def convert_deconvolution(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
     inputs = node['inputs']
     args, _ = module.get_params()
 

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].