This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b225fa5 Python level HybridBlock export API (#19220)
b225fa5 is described below
commit b225fa5d3c9b6d9ba10533fcbf20a20c6aa96aff
Author: Leonard Lausen <[email protected]>
AuthorDate: Fri Sep 25 10:20:21 2020 -0700
Python level HybridBlock export API (#19220)
Allow HybridBlock.export to return the Symbol and parameter objects directly
instead of writing them to files.
This is useful for users of HybridBlock.forward who would like to obtain a
legacy Symbol object without accessing the internal HybridBlock._cached_op
data structure.
---
python/mxnet/gluon/block.py | 27 ++++++++++++++++++++-------
tests/python/unittest/test_gluon.py | 6 ++++--
2 files changed, 24 insertions(+), 9 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index d430aee..a89d4bc 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -31,7 +31,7 @@ import contextvars
import re
import numpy as np
-from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str,
check_call, _LIB
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as
dc, name as _name, \
profiler as _profiler, context as _context
from ..symbol.numpy import _symbol as np_symbol
@@ -1306,13 +1306,16 @@ class HybridBlock(Block):
Parameters
----------
- path : str
+ path : str or None
Path to save model. Two files `path-symbol.json` and
`path-xxxx.params`
will be created, where xxxx is the 4 digits epoch number.
+ If None, do not export to file but return Python Symbol object and
+ corresponding dictionary of parameters.
epoch : int
Epoch number of saved model.
remove_amp_cast : bool, optional
Whether to remove the amp_cast and amp_multicast operators, before
saving the model.
+
Returns
-------
symbol_filename : str
@@ -1338,8 +1341,9 @@ class HybridBlock(Block):
if var.name in rename_map:
var._set_attr(name=rename_map[var.name])
- sym_filename = '%s-symbol.json'%path
- sym.save(sym_filename, remove_amp_cast=remove_amp_cast)
+ sym_filename = '%s-symbol.json' % (path if path is not None else "")
+ if path is not None:
+ sym.save(sym_filename, remove_amp_cast=remove_amp_cast)
arg_names = set(sym.list_arguments())
aux_names = set(sym.list_auxiliary_states())
@@ -1355,9 +1359,18 @@ class HybridBlock(Block):
else:
arg_dict['aux:%s'%name] = param._reduce()
save_fn = _mx_npx.save if is_np_array() else ndarray.save
- params_filename = '%s-%04d.params'%(path, epoch)
- save_fn(params_filename, arg_dict)
- return (sym_filename, params_filename)
+ params_filename = '%s-%04d.params'%((path if path is not None else
""), epoch)
+
+ if path is not None:
+ save_fn(params_filename, arg_dict)
+ return (sym_filename, params_filename)
+
+ if remove_amp_cast:
+ handle = SymbolHandle()
+ import ctypes
+ check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle,
ctypes.byref(handle)))
+ sym = type(sym)(handle)
+ return sym, arg_dict
def register_op_hook(self, callback, monitor_all=False):
"""Install op hook for block recursively.
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index eb07f43..6386fc8 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1532,8 +1532,10 @@ def test_symbol_block_save_load(tmpdir):
backbone.initialize()
backbone.hybridize()
backbone(mx.nd.random.normal(shape=(1, 3, 32, 32)))
- sym_file, params_file = backbone.export(tmpfile)
- self.backbone = gluon.SymbolBlock.imports(sym_file, 'data',
params_file)
+ sym, params = backbone.export(None)
+ data = mx.sym.var('data')
+ self.backbone = gluon.SymbolBlock(sym, data)
+ self.backbone.load_dict(params)
self.body = nn.Conv2D(3, 1)
def hybrid_forward(self, F, x):