This is an automated email from the ASF dual-hosted git repository.

manuseth pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 0f57c64  Reuse params from cached_op_args (#20221)
0f57c64 is described below

commit 0f57c64110231e9e24dcce7d2d8e077125e11f30
Author: Sam Skalicky <[email protected]>
AuthorDate: Wed Apr 28 13:08:15 2021 -0700

    Reuse params from cached_op_args (#20221)
    
    * initial commit
    
    * fixed handling
    
    * fixed formatting
    
    Co-authored-by: Ubuntu <[email protected]>
---
 python/mxnet/gluon/block.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d415c5f..f8dc1bc 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1015,6 +1015,7 @@ class HybridBlock(Block):
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = []
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -1066,10 +1067,17 @@ class HybridBlock(Block):
     def _build_cache(self, *args):
         data, out = self._get_graph(*args)
         data_names = {data.name: i for i, data in enumerate(data)}
-        params = self.collect_params()
         input_names = out.list_inputs()
-        param_names = set(params.keys())
         expected_names = set(input_names)
+
+        # try to reuse cached_op_args for params
+        if len(self._cached_op_args) > 0:
+            params = {param_tuple[1].name:param_tuple[1]
+                      for param_tuple in self._cached_op_args
+                      if isinstance(param_tuple[1], Parameter)}
+        else:
+            params = self.collect_params()
+        param_names = set(params.keys())
         for name in expected_names:
             assert name in param_names or name in data_names, \
                 "Unknown input to HybridBlock: %s" %name
@@ -1280,10 +1288,11 @@ class HybridBlock(Block):
         """
         if len(kwargs) > 0:
             self._backend_opts = kwargs
+        if not backend:
+            raise ValueError('Must specify "backend" to optimize_for')
 
-        if clear or not self._active:
-            self.hybridize(True, backend, clear, static_alloc, static_shape,
-                           inline_limit, forward_bulk_size, backward_bulk_size)
+        self.hybridize(True, backend, clear, static_alloc, static_shape,
+                       inline_limit, forward_bulk_size, backward_bulk_size)
 
         # do part of forward API call
         has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
@@ -1307,6 +1316,7 @@ class HybridBlock(Block):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_op_args = []
 
     def register_child(self, block, name=None):
         if not isinstance(block, HybridBlock):

Reply via email to