This is an automated email from the ASF dual-hosted git repository.
muli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new d65d363 [caffe] support convert mtcnn and MobileNet model (#6956)
d65d363 is described below
commit d65d363be82fcf8f29994b6506a7047af29b488d
Author: joey2014 <[email protected]>
AuthorDate: Thu Aug 3 15:33:40 2017 -0500
[caffe] support convert mtcnn and MobileNet model (#6956)
* support convert mtcnn and MobileNet model
* pass python lint
* put "import re" before "import caffe_parser" as lint required
* correct missed checkin and pass pylint
---
tools/caffe_converter/convert_model.py | 7 ++++---
tools/caffe_converter/convert_symbol.py | 24 ++++++++++++++++++++++--
2 files changed, 26 insertions(+), 5 deletions(-)
diff --git a/tools/caffe_converter/convert_model.py b/tools/caffe_converter/convert_model.py
index 2d8c994..d1e4cd0 100644
--- a/tools/caffe_converter/convert_model.py
+++ b/tools/caffe_converter/convert_model.py
@@ -3,6 +3,7 @@
from __future__ import print_function
import argparse
import sys
+import re
import caffe_parser
import mxnet as mx
import numpy as np
@@ -53,8 +54,8 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
or layer_type == 'Deconvolution' or layer_type == 39:
if layer_type == 'PReLU':
assert (len(layer_blobs) == 1)
- wmat = layer_blobs[0].data
weight_name = layer_name + '_gamma'
+ wmat =
np.array(layer_blobs[0].data).reshape(arg_shape_dic[weight_name])
arg_params[weight_name] = mx.nd.zeros(wmat.shape)
arg_params[weight_name][:] = wmat
continue
@@ -148,7 +149,7 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
aux_params[var_name] = mx.nd.zeros(var.shape)
# Get the original epsilon
for idx, layer in enumerate(layers_proto):
- if layer.name == bn_name:
+ if layer.name == bn_name or re.sub('[-/]', '_', layer.name) ==
bn_name:
bn_index = idx
eps_caffe = layers_proto[bn_index].batch_norm_param.eps
# Compensate for the epsilon shift performed in convert_symbol
@@ -180,7 +181,7 @@ def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
assert len(layer_blobs) == 0
if output_prefix is not None:
- model = mx.mod.Module(symbol=sym, label_names=['prob_label', ])
+ model = mx.mod.Module(symbol=sym, label_names=[arg_names[-1], ])
model.bind(data_shapes=[('data', tuple(input_dim))])
model.init_params(arg_params=arg_params, aux_params=aux_params)
model.save_checkpoint(output_prefix, 0)
diff --git a/tools/caffe_converter/convert_symbol.py b/tools/caffe_converter/convert_symbol.py
index c384c76..100a64f 100644
--- a/tools/caffe_converter/convert_symbol.py
+++ b/tools/caffe_converter/convert_symbol.py
@@ -120,6 +120,7 @@ def _parse_proto(prototxt_fname):
flatten_count = 0
output_name = ""
prev_name = None
+ _output_name = {}
# convert reset layers one by one
for i, layer in enumerate(layers):
@@ -252,6 +253,22 @@ def _parse_proto(prototxt_fname):
for j in range(len(layer.top)):
mapping[layer.top[j]] = name
output_name = name
+ for k in range(len(layer.bottom)):
+ if layer.bottom[k] in _output_name:
+ _output_name[layer.bottom[k]]['count'] =
_output_name[layer.bottom[k]]['count']+1
+ else:
+ _output_name[layer.bottom[k]] = {'count':0}
+ for k in range(len(layer.top)):
+ if layer.top[k] in _output_name:
+ _output_name[layer.top[k]]['count'] =
_output_name[layer.top[k]]['count']+1
+ else:
+ _output_name[layer.top[k]] = {'count':0, 'name':name}
+
+ output_name = []
+ for i in _output_name:
+ if 'name' in _output_name[i] and _output_name[i]['count'] == 0:
+ output_name.append(_output_name[i]['name'])
+
return symbol_string, output_name, input_dim
def convert_symbol(prototxt_fname):
@@ -272,8 +289,11 @@ def convert_symbol(prototxt_fname):
sym, output_name, input_dim = _parse_proto(prototxt_fname)
exec(sym) # pylint: disable=exec-used
_locals = locals()
- exec("ret = " + output_name, globals(), _locals) # pylint:
disable=exec-used
- ret = _locals['ret']
+ ret = []
+ for i in output_name:
+ exec("ret = " + i, globals(), _locals) # pylint: disable=exec-used
+ ret.append(_locals['ret'])
+ ret = mx.sym.Group(ret)
return ret, input_dim
def main():
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].