This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 181a2e2  fix unidirectional model's parameter format (#12055)
181a2e2 is described below

commit 181a2e28a61e9406318a8ba08f36fc4c8e23e8c1
Author: Sheng Zha <[email protected]>
AuthorDate: Wed Aug 8 11:14:56 2018 -0700

    fix unidirectional model's parameter format (#12055)
    
    * fix unidirectional model's parameter format
    
    * Update rnn_layer.py
---
 python/mxnet/gluon/rnn/rnn_layer.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 4a7a0be..d2c6ac9 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -21,6 +21,8 @@
 # pylint: disable=too-many-lines, arguments-differ
 """Definition of various recurrent neural network layers."""
 from __future__ import print_function
+import re
+
 __all__ = ['RNN', 'LSTM', 'GRU']
 
 from ... import ndarray, symbol
@@ -92,10 +94,17 @@ class _RNNLayer(HybridBlock):
     def _collect_params_with_prefix(self, prefix=''):
         if prefix:
             prefix += '.'
-        def convert_key(key): # for compatibility with old parameter format
-            key = key.split('_')
-            return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], 
'_'.join(key[1:]))
-        ret = {prefix + convert_key(key) : val for key, val in 
self._reg_params.items()}
+        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
+        def convert_key(m, bidirectional): # for compatibility with old 
parameter format
+            d, l, g, t = [m.group(i) for i in range(1, 5)]
+            if bidirectional:
+                return '_unfused.{}.{}_cell.{}_{}'.format(l, d, g, t)
+            else:
+                return '_unfused.{}.{}_{}'.format(l, g, t)
+        bidirectional = any(pattern.match(k).group(1) == 'r' for k in 
self._reg_params)
+
+        ret = {prefix + convert_key(pattern.match(key), bidirectional) : val
+               for key, val in self._reg_params.items()}
         for name, child in self._children.items():
             ret.update(child._collect_params_with_prefix(prefix + name))
         return ret

Reply via email to