This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 38a032c [COREML] Update the json getter (#8698) 38a032c is described below commit 38a032c886f56f94cdad004c89fd4e1926f85ba6 Author: Tianqi Chen <tqc...@users.noreply.github.com> AuthorDate: Sun Nov 19 12:53:19 2017 -0800 [COREML] Update the json getter (#8698) * [COREML] Update the json getter * add docstring --- tools/coreml/converter/_layers.py | 40 +++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/tools/coreml/converter/_layers.py b/tools/coreml/converter/_layers.py index fe00232..4c5ebc6 100644 --- a/tools/coreml/converter/_layers.py +++ b/tools/coreml/converter/_layers.py @@ -38,6 +38,30 @@ def _get_node_name(net, node_id): def _get_node_shape(net, node_id): return net['nodes'][node_id]['shape'] +def _get_attrs(node): + """get attribute dict from node + + This functions keeps backward compatibility + for both attr and attrs key in the json field. + + Parameters + ---------- + node : dict + The json graph Node + + Returns + ------- + attrs : dict + The attr dict, returns empty dict if + the field do not exist. 
+ """ + if 'attrs' in node: + return node['attrs'] + elif 'attr' in node: + return node['attr'] + else: + return {} + # TODO These operators still need to be converted (listing in order of priority): # High priority: @@ -108,7 +132,7 @@ def convert_transpose(net, node, module, builder): """ input_name, output_name = _get_input_output_name(net, node) name = node['name'] - param = node['attr'] + param = _get_attrs(node) axes = literal_eval(param['axes']) builder.add_permute(name, axes, input_name, output_name) @@ -180,7 +204,7 @@ def convert_activation(net, node, module, builder): """ input_name, output_name = _get_input_output_name(net, node) name = node['name'] - mx_non_linearity = node['attr']['act_type'] + mx_non_linearity = _get_attrs(node)['act_type'] #TODO add SCALED_TANH, SOFTPLUS, SOFTSIGN, SIGMOID_HARD, LEAKYRELU, PRELU, ELU, PARAMETRICSOFTPLUS, THRESHOLDEDRELU, LINEAR if mx_non_linearity == 'relu': non_linearity = 'RELU' @@ -281,7 +305,7 @@ def convert_convolution(net, node, module, builder): """ input_name, output_name = _get_input_output_name(net, node) name = node['name'] - param = node['attr'] + param = _get_attrs(node) inputs = node['inputs'] args, _ = module.get_params() @@ -361,7 +385,7 @@ def convert_pooling(net, node, module, builder): """ input_name, output_name = _get_input_output_name(net, node) name = node['name'] - param = node['attr'] + param = _get_attrs(node) layer_type_mx = param['pool_type'] if layer_type_mx == 'max': @@ -445,9 +469,9 @@ def convert_batchnorm(net, node, module, builder): eps = 1e-3 # Default value of eps for MXNet. use_global_stats = False # Default value of use_global_stats for MXNet. 
- if 'attr' in node: - if 'eps' in node['attr']: - eps = literal_eval(node['attr']['eps']) + attrs = _get_attrs(node) + if 'eps' in attrs: + eps = literal_eval(attrs['eps']) args, aux = module.get_params() gamma = args[_get_node_name(net, inputs[1][0])].asnumpy() @@ -511,7 +535,7 @@ def convert_deconvolution(net, node, module, builder): """ input_name, output_name = _get_input_output_name(net, node) name = node['name'] - param = node['attr'] + param = _get_attrs(node) inputs = node['inputs'] args, _ = module.get_params() -- To stop receiving notification emails like this one, please contact comm...@mxnet.apache.org.