bgawrych commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r828872365



##########
File path: python/mxnet/amp/amp.py
##########
@@ -700,79 +670,60 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
         from being quantized
     device : Context
         Context on which model parameters should live
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
     from ..gluon import HybridBlock, SymbolBlock
+    from ..ndarray import NDArray as ND_NDArray
+    from ..numpy import ndarray as NP_NDArray
+
     assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
+    if not isinstance(data_example, (list, tuple)):
+        assert isinstance(data_example, (ND_NDArray, NP_NDArray))
+        data_example = [data_example]
+
     if not block._cached_graph:
-        raise RuntimeError(
-            "Please first call block.hybridize() and then run forward with "
-            "this block at least once before calling export.")
-
-    # Prepare inputs to pass to the convert_symbol API
-    inputs, sym = block._cached_graph
-    input_names = []
-    for inp in inputs:
-        input_names.append(inp.name)
-    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
-                                   fp32_ops, conditional_fp32_ops,
-                                   excluded_sym_names, data_names=input_names,
-                                   cast_optional_params=cast_optional_params)
-
-    arg_names = set(converted_sym.list_arguments())
-    aux_names = set(converted_sym.list_auxiliary_states())
-    arg_dict = {}
-
-    # If dtype for the param was set in the json, cast the
-    # param to this dtype
-    attr_dict = converted_sym.attr_dict()
-    for param in block.collect_params().values():
-        name = param.name
-        if name in arg_names:
-            arg_dict['arg:%s'%name] = param._reduce()
-            if name in attr_dict and "__dtype__" in attr_dict[name]:
-                if attr_dict[name]["__dtype__"] != "-1":
-                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
-                    if typ == bfloat16:
-                        arg_dict['arg:%s' % name] = _cast_symbol_NDArray(arg_dict['arg:%s' % name], bfloat16)
-                    else:
-                        arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ)
+        block.hybridize()
+    block(*data_example)

Review comment:
       It doesn't seem to be strictly necessary, but it can be added for peace of mind.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to