PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r828335077
##########
File path: python/mxnet/amp/amp.py
##########
@@ -700,79 +670,60 @@ def convert_hybrid_block(block, target_dtype="float16",
target_dtype_ops=None,
from being quantized
device : Context
Context on which model parameters should live
- cast_optional_params : bool, default False
+ cast_params_offline : bool, default False
Whether to cast the arg_params and aux_params that don't require to be
in LP16
because of a cast layer following it, but will reduce the computation
and memory
overhead of the model if casted.
"""
from ..gluon import HybridBlock, SymbolBlock
+ from ..ndarray import NDArray as ND_NDArray
+ from ..numpy import ndarray as NP_NDArray
+
assert isinstance(block, HybridBlock), "block input should be a
HybridBlock"
+ if not isinstance(data_example, (list, tuple)):
+ assert isinstance(data_example, (ND_NDArray, NP_NDArray))
+ data_example = [data_example]
+
if not block._cached_graph:
- raise RuntimeError(
- "Please first call block.hybridize() and then run forward with "
- "this block at least once before calling export.")
-
- # Prepare inputs to pass to the convert_symbol API
- inputs, sym = block._cached_graph
- input_names = []
- for inp in inputs:
- input_names.append(inp.name)
- converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
- fp32_ops, conditional_fp32_ops,
- excluded_sym_names, data_names=input_names,
- cast_optional_params=cast_optional_params)
-
- arg_names = set(converted_sym.list_arguments())
- aux_names = set(converted_sym.list_auxiliary_states())
- arg_dict = {}
-
- # If dtype for the param was set in the json, cast the
- # param to this dtype
- attr_dict = converted_sym.attr_dict()
- for param in block.collect_params().values():
- name = param.name
- if name in arg_names:
- arg_dict['arg:%s'%name] = param._reduce()
- if name in attr_dict and "__dtype__" in attr_dict[name]:
- if attr_dict[name]["__dtype__"] != "-1":
- typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
- if typ == bfloat16:
- arg_dict['arg:%s' % name] =
_cast_symbol_NDArray(arg_dict['arg:%s' % name], bfloat16)
- else:
- arg_dict['arg:%s'%name] =
arg_dict['arg:%s'%name].astype(typ)
+ block.hybridize()
Review comment:
I will handle this the same way it is handled in quantization: explicitly pass
the `static_alloc=False, static_shape=False` arguments to `block.hybridize()`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]