This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 181a2e2 fix unidirectional model's parameter format (#12055)
181a2e2 is described below
commit 181a2e28a61e9406318a8ba08f36fc4c8e23e8c1
Author: Sheng Zha <[email protected]>
AuthorDate: Wed Aug 8 11:14:56 2018 -0700
fix unidirectional model's parameter format (#12055)
* fix unidirectional model's parameter format
* Update rnn_layer.py
---
python/mxnet/gluon/rnn/rnn_layer.py | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 4a7a0be..d2c6ac9 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -21,6 +21,8 @@
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network layers."""
from __future__ import print_function
+import re
+
__all__ = ['RNN', 'LSTM', 'GRU']
from ... import ndarray, symbol
@@ -92,10 +94,17 @@ class _RNNLayer(HybridBlock):
def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
- def convert_key(key): # for compatibility with old parameter format
- key = key.split('_')
- return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
- ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
+ pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
+ def convert_key(m, bidirectional): # for compatibility with old parameter format
+ d, l, g, t = [m.group(i) for i in range(1, 5)]
+ if bidirectional:
+ return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t)
+ else:
+ return '_unfused.{}.{}_{}'.format(l, g, t)
+ bidirectional = any(pattern.match(k).group(1) == 'r' for k in self._reg_params)
+
+ ret = {prefix + convert_key(pattern.match(key), bidirectional) : val
+ for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
return ret