samskalicky commented on a change in pull request #19220:
URL: https://github.com/apache/incubator-mxnet/pull/19220#discussion_r494683927



##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)
+        params_filename = '%s-%04d.params'%((path if path is not None else ""), epoch)
+
+        if path is not None:
+            save_fn(params_filename, arg_dict)
+            return (sym_filename, params_filename)
+
+        if remove_amp_cast:
+            handle = SymbolHandle()
+            import ctypes
+            check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle, ctypes.byref(handle)))
+            sym = type(sym)(handle)
+        return sym, arg_dict

Review comment:
       What about aux?

##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)

Review comment:
       Is there any reason to preserve the old behavior when actually exporting to a file?

##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)

Review comment:
       Is there any reason to preserve the old behavior (returning the filenames) when actually exporting to a file?

##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)

Review comment:
       No, I meant `return (sym_filename, params_filename)`.
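For reference, the branching under discussion can be sketched in plain Python. This is a hypothetical, simplified stand-in, not MXNet code: `export_sketch`, `fake_save` and the placeholder string `"<symbol>"` are illustrative names, and the `'%s-symbol.json'` pattern for `sym_filename` is an assumption (that line is not shown in the diff). With a `path`, the parameters are saved and the two filenames are returned as before; with `path=None`, the symbol and the parameter dict come back in memory.

```python
def export_sketch(path, arg_dict, sym="<symbol>", epoch=0, save_fn=None):
    """Hypothetical stand-in for the return-value branching in the diff.

    `sym` is a placeholder for the exported symbol; `save_fn` stands in
    for `_mx_npx.save` / `ndarray.save`.
    """
    prefix = path if path is not None else ""
    # Assumed symbol filename pattern (not shown in the diff).
    sym_filename = '%s-symbol.json' % prefix
    # Same params filename pattern as the diff: '<path>-<epoch:04d>.params'.
    params_filename = '%s-%04d.params' % (prefix, epoch)

    if path is not None:
        # Exporting to disk keeps the old return value: the two filenames.
        save_fn(params_filename, arg_dict)
        return (sym_filename, params_filename)

    # No path: return the in-memory symbol and parameter dict instead.
    return sym, arg_dict


saved = {}


def fake_save(fname, data):
    # Record what would have been written instead of touching the disk.
    saved[fname] = data


params = {'arg:weight': [1.0], 'aux:running_mean': [0.0]}
print(export_sketch('model', params, epoch=3, save_fn=fake_save))
# prints ('model-symbol.json', 'model-0003.params')
print(export_sketch(None, params))
```

Keeping the filename tuple for the on-disk case means existing callers of `export(path)` see no change, while `path=None` opts into the new in-memory return.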

##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)
+        params_filename = '%s-%04d.params'%((path if path is not None else ""), epoch)
+
+        if path is not None:
+            save_fn(params_filename, arg_dict)
+            return (sym_filename, params_filename)
+
+        if remove_amp_cast:
+            handle = SymbolHandle()
+            import ctypes
+            check_call(_LIB.MXSymbolRemoveAmpCast(sym.handle, ctypes.byref(handle)))
+            sym = type(sym)(handle)
+        return sym, arg_dict

Review comment:
       Oh, OK, I saw that in the other comment. Makes sense. The variable names are confusing, though: `arg_names` and `aux_names`, and then both args and aux go into `arg_dict` :-) But that's not your code here...
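Because `arg_dict` mixes both kinds of parameters, a consumer of the in-memory return value has to tell them apart by key prefix. A minimal sketch, assuming the keys follow the `'arg:<name>'` / `'aux:<name>'` convention (the `aux:` form is visible in the diff; the `arg:` form is assumed). `split_params` is a hypothetical helper, similar in spirit to how MXNet's checkpoint loading separates the two groups when reading a `.params` file.

```python
def split_params(arg_dict):
    """Split a combined parameter dict into (arg_params, aux_params).

    Assumes every key is prefixed with 'arg:' or 'aux:', as when
    `export` fills `arg_dict`; any other key is rejected.
    """
    arg_params, aux_params = {}, {}
    for key, value in arg_dict.items():
        # 'aux:bn_moving_mean' -> ('aux', ':', 'bn_moving_mean')
        prefix, _, name = key.partition(':')
        if prefix == 'arg':
            arg_params[name] = value
        elif prefix == 'aux':
            aux_params[name] = value
        else:
            raise ValueError('unexpected parameter key: %r' % key)
    return arg_params, aux_params


args, auxs = split_params({'arg:fc_weight': 1, 'aux:bn_moving_mean': 2})
print(args)  # prints {'fc_weight': 1}
print(auxs)  # prints {'bn_moving_mean': 2}
```

Splitting on the first `:` only (via `partition`) keeps any later colons inside the parameter name intact.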

##########
File path: python/mxnet/gluon/block.py
##########
@@ -1355,9 +1359,18 @@ def export(self, path, epoch=0, remove_amp_cast=True):
                     else:
                         arg_dict['aux:%s'%name] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
-        params_filename = '%s-%04d.params'%(path, epoch)
-        save_fn(params_filename, arg_dict)
-        return (sym_filename, params_filename)

Review comment:
       Cool, looks like you're already doing that! 👍



