This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b225fa5  Python level HybridBlock export API (#19220)
b225fa5 is described below

commit b225fa5d3c9b6d9ba10533fcbf20a20c6aa96aff
Author: Leonard Lausen <[email protected]>
AuthorDate: Fri Sep 25 10:20:21 2020 -0700

    Python level HybridBlock export API (#19220)
    
    Allow HybridBlock.export to return the symbol and param objects instead
    of writing them to files.
    This is useful for users of HybridBlock.forward who would like to obtain a
    legacy Symbol object without accessing the internal
    HybridBlock._cached_op data structure.
---
 python/mxnet/gluon/block.py         | 27 ++++++++++++++++++++-------
 tests/python/unittest/test_gluon.py |  6 ++++--
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d430aee..a89d4bc 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -31,7 +31,7 @@ import contextvars
 import re
 import numpy as np
 
-from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, 
check_call, _LIB
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as 
dc, name as _name, \
     profiler as _profiler, context as _context
 from ..symbol.numpy import _symbol as np_symbol
@@ -1306,13 +1306,16 @@ class HybridBlock(Block):
 
         Parameters
         ----------
-        path : str
+        path : str or None
             Path to save model. Two files `path-symbol.json` and 
`path-xxxx.params`
             will be created, where xxxx is the 4 digits epoch number.
+            If None, do not export to file but return Python Symbol object and
+            corresponding dictionary of parameters.
         epoch : int
             Epoch number of saved model.
         remove_amp_cast : bool, optional
             Whether to remove the amp_cast and amp_multicast operators, before 
saving the model.
+
         Returns
         -------
         symbol_filename : str
@@ -1338,8 +1341,9 @@ class HybridBlock(Block):
             if var.name in rename_map:
                 var._set_attr(name=rename_map[var.name])
 
-        sym_filename = '%s-symbol.json'%path
-        sym.save(sym_filename, remove_amp_cast=remove_amp_cast)
+        sym_filename = '%s-symbol.json' % (path if path is not None else "")
+        if path is not None:
+            sym.save(sym_filename, remove_amp_cast=remove_amp_cast)
 
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
@@ -1355,9 +1359,18 @@ class HybridBlock(Block):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)
+        params_filename = '%s-%04d.params'%((path if path is not None else 
""), epoch)
+
+        if path is not None:
+            save_fn(params_filename, arg_dict)
+            return (sym_filename, params_filename)
+
+        if remove_amp_cast:
+            handle = SymbolHandle()
+            import ctypes
+            check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle, 
ctypes.byref(handle)))
+            sym = type(sym)(handle)
+        return sym, arg_dict
 
     def register_op_hook(self, callback, monitor_all=False):
         """Install op hook for block recursively.
diff --git a/tests/python/unittest/test_gluon.py 
b/tests/python/unittest/test_gluon.py
index eb07f43..6386fc8 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1532,8 +1532,10 @@ def test_symbol_block_save_load(tmpdir):
             backbone.initialize()
             backbone.hybridize()
             backbone(mx.nd.random.normal(shape=(1, 3, 32, 32)))
-            sym_file, params_file = backbone.export(tmpfile)
-            self.backbone = gluon.SymbolBlock.imports(sym_file, 'data', 
params_file)
+            sym, params = backbone.export(None)
+            data = mx.sym.var('data')
+            self.backbone = gluon.SymbolBlock(sym, data)
+            self.backbone.load_dict(params)
             self.body = nn.Conv2D(3, 1)
 
         def hybrid_forward(self, F, x):

Reply via email to