PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824647707



##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, 
list), \
+        "conditional_fp32_ops should be a list"
 
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
+    target_dtype = get_dtype_name(target_dtype)
+    assert target_dtype in ['float16', *bfloat16.names], \
+        "Only float16 and bfloat16 types are currently supported as 
target_dtype"
 
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be 
a list of strs"
-    else:
+    if target_dtype_ops is None:
         target_dtype_ops = list_lp16_ops(target_dtype)
-
-    if fp32_ops is not None:
-        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
-    else:
+    if fp32_ops is None:
         fp32_ops = list_fp32_ops(target_dtype)
 
-    if conditional_fp32_ops is not None:
-        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops 
should be a list"
-    else:
+    # conditional ops
+    if conditional_fp32_ops is None:
         conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+    cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+    for cond_op in conditional_fp32_ops:
+        op_name, attr_name, attr_vals = cond_op
+        assert isinstance(op_name, str) and isinstance(attr_name, str) and 
isinstance(attr_vals, list), \
+            "conditional_fp32_ops should be a list of (str, str, list of str)"
+        cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+    nodes_attr = sym.attr_dict()
+    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+    assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \

Review comment:
       I changed it to use `logging.warning(...)` since this is how logging is 
handled in this file. Do you think we should instead add a logger as an 
argument to the conversion functions?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to