PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824568944



##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, 
list), \
+        "conditional_fp32_ops should be a list"
 
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
+    target_dtype = get_dtype_name(target_dtype)
+    assert target_dtype in ['float16', *bfloat16.names], \
+        "Only float16 and bfloat16 types are currently supported as 
target_dtype"
 
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be 
a list of strs"
-    else:
+    if target_dtype_ops is None:
         target_dtype_ops = list_lp16_ops(target_dtype)
-
-    if fp32_ops is not None:
-        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
-    else:
+    if fp32_ops is None:
         fp32_ops = list_fp32_ops(target_dtype)
 
-    if conditional_fp32_ops is not None:
-        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops 
should be a list"
-    else:
+    # conditional ops
+    if conditional_fp32_ops is None:
         conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+    cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+    for cond_op in conditional_fp32_ops:
+        op_name, attr_name, attr_vals = cond_op
+        assert isinstance(op_name, str) and isinstance(attr_name, str) and 
isinstance(attr_vals, list), \
+            "conditional_fp32_ops should be a list of (str, str, list of str)"
+        cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+    nodes_attr = sym.attr_dict()
+    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+    assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \
+        "excluded_sym_names are not present in the network. Missing layers: 
{}".format(
+            set(excluded_sym_names) - set(nodes_op.keys()))
+
+    for node_name, node_op in nodes_op.items():
+        if node_op not in cond_ops:
+            continue
+        node_attrs = nodes_attr[node_name]
+        for attr_name, attr_vals in cond_ops[node_op].items():
+            assert attr_name in node_attrs
+            if node_attrs[attr_name] in attr_vals:
+                excluded_sym_names += node_name
+                break
+    excluded_sym_names = list(set(excluded_sym_names))
 
-    original_conditional_op_names = []
-    conditional_op_names = []
-    param_names = []
-    param_vals = []
-    indptr = [0]
-    for conditional_fp32_op in conditional_fp32_ops:
-        assert isinstance(conditional_fp32_op[0], str) and 
isinstance(conditional_fp32_op[1], str) \
-            and isinstance(conditional_fp32_op[2], list), 
"conditional_fp32_ops should be a list of " \
-                                                          "(str, str, list of 
str)"
-        param_vals += conditional_fp32_op[2]
-        indptr.append(len(param_vals))
-        param_names.append(conditional_fp32_op[1])
-        conditional_op_names.append(conditional_fp32_op[0])
-
-    if excluded_sym_names is not None:
-        assert isinstance(excluded_sym_names, list), "excluded_sym_names 
should be a list of strs"
-    else:
-        excluded_sym_names = []
-
-    for original_conditional_fp32_op in 
list_conditional_fp32_ops(target_dtype):
-        original_conditional_op_names.append(original_conditional_fp32_op[0])
-
-    # Op lists should not have intersection
+    # Op lists should not intersect
     common_ops = set(target_dtype_ops) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and fp32_ops 
{}".format(common_ops)
-    common_ops = set(target_dtype_ops) & set(conditional_op_names)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and 
conditional_fp32_ops {}".format(common_ops)
-    common_ops = set(conditional_op_names) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in fp32_ops and 
conditional_fp32_ops {}".format(common_ops)
-
-    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
-    all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + 
list_fp32_ops(target_dtype)
-                            + list_lp16_fp32_ops(target_dtype) + 
original_conditional_op_names)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and fp32_ops: 
{}".format(common_ops)
+    common_ops = set(target_dtype_ops) & set(cond_ops)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and 
conditional_fp32_ops: {}".format(
+        common_ops)
+    common_ops = set(cond_ops) & set(fp32_ops)
+    assert len(common_ops) == 0, "Common ops in fp32_ops and 
conditional_fp32_ops: {}".format(common_ops)
+
+    combined_ops = set(target_dtype_ops + fp32_ops + list(cond_ops.keys()))
+    original_cond_ops = [cond_op[0] for cond_op in 
list_conditional_fp32_ops(target_dtype)]

Review comment:
       The collection built on line 491 (`cond_ops`) may contain conditional ops 
specified by the user. Here, by contrast, we are creating a list of all ops 
listed in the `symbol_bf16.py` file (for bfloat16) or the `symbol_fp16.py` file 
(for float16).




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to