You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/19 20:53:22 UTC

[GitHub] piiswrong closed pull request #8698: [COREML] Update the json getter

piiswrong closed pull request #8698: [COREML] Update the json getter
URL: https://github.com/apache/incubator-mxnet/pull/8698
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/tools/coreml/converter/_layers.py b/tools/coreml/converter/_layers.py
index fe00232828..4c5ebc6fb0 100644
--- a/tools/coreml/converter/_layers.py
+++ b/tools/coreml/converter/_layers.py
@@ -38,6 +38,30 @@ def _get_node_name(net, node_id):
 def _get_node_shape(net, node_id):
     return net['nodes'][node_id]['shape']
 
+def _get_attrs(node):
+    """get attribute dict from node
+
+    This functions keeps backward compatibility
+    for both attr and attrs key in the json field.
+
+    Parameters
+    ----------
+    node : dict
+       The json graph Node
+
+    Returns
+    -------
+    attrs : dict
+       The attr dict, returns empty dict if
+       the field do not exist.
+    """
+    if 'attrs' in node:
+        return node['attrs']
+    elif 'attr' in node:
+        return node['attr']
+    else:
+        return {}
+
 
 # TODO These operators still need to be converted (listing in order of priority):
 # High priority:
@@ -108,7 +132,7 @@ def convert_transpose(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
 
     axes = literal_eval(param['axes'])
     builder.add_permute(name, axes, input_name, output_name)
@@ -180,7 +204,7 @@ def convert_activation(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    mx_non_linearity = node['attr']['act_type']
+    mx_non_linearity = _get_attrs(node)['act_type']
     #TODO add SCALED_TANH, SOFTPLUS, SOFTSIGN, SIGMOID_HARD, LEAKYRELU, PRELU, ELU, PARAMETRICSOFTPLUS, THRESHOLDEDRELU, LINEAR
     if mx_non_linearity == 'relu':
         non_linearity = 'RELU'
@@ -281,7 +305,7 @@ def convert_convolution(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
     inputs = node['inputs']
     args, _ = module.get_params()
 
@@ -361,7 +385,7 @@ def convert_pooling(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
 
     layer_type_mx = param['pool_type']
     if layer_type_mx == 'max':
@@ -445,9 +469,9 @@ def convert_batchnorm(net, node, module, builder):
 
     eps = 1e-3 # Default value of eps for MXNet.
     use_global_stats = False # Default value of use_global_stats for MXNet.
-    if 'attr' in node:
-        if 'eps' in node['attr']:
-            eps = literal_eval(node['attr']['eps'])
+    attrs = _get_attrs(node)
+    if 'eps' in attrs:
+        eps = literal_eval(attrs['eps'])
 
     args, aux = module.get_params()
     gamma = args[_get_node_name(net, inputs[1][0])].asnumpy()
@@ -511,7 +535,7 @@ def convert_deconvolution(net, node, module, builder):
     """
     input_name, output_name = _get_input_output_name(net, node)
     name = node['name']
-    param = node['attr']
+    param = _get_attrs(node)
     inputs = node['inputs']
     args, _ = module.get_params()
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services