PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824568944
##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be in LP16
         because of a cast layer following it, but will reduce the computation and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol"
+    import json
-    assert target_dtype in ['float16', 'bfloat16'], \
-        "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, list), \
+        "conditional_fp32_ops should be a list"
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
+    target_dtype = get_dtype_name(target_dtype)
+    assert target_dtype in ['float16', *bfloat16.names], \
+        "Only float16 and bfloat16 types are currently supported as target_dtype"
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
-    else:
+    if target_dtype_ops is None:
         target_dtype_ops = list_lp16_ops(target_dtype)
-
-    if fp32_ops is not None:
-        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
-    else:
+    if fp32_ops is None:
         fp32_ops = list_fp32_ops(target_dtype)
-    if conditional_fp32_ops is not None:
-        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
-    else:
+    # conditional ops
+    if conditional_fp32_ops is None:
         conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+    cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+    for cond_op in conditional_fp32_ops:
+        op_name, attr_name, attr_vals = cond_op
+        assert isinstance(op_name, str) and isinstance(attr_name, str) and isinstance(attr_vals, list), \
+            "conditional_fp32_ops should be a list of (str, str, list of str)"
+        cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+    nodes_attr = sym.attr_dict()
+    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+    assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \
+        "Some excluded_sym_names are not present in the network. Missing layers: {}".format(
+            set(excluded_sym_names) - set(nodes_op.keys()))
+
+    for node_name, node_op in nodes_op.items():
+        if node_op not in cond_ops:
+            continue
+        node_attrs = nodes_attr[node_name]
+        for attr_name, attr_vals in cond_ops[node_op].items():
+            assert attr_name in node_attrs
+            if node_attrs[attr_name] in attr_vals:
+                excluded_sym_names.append(node_name)
+                break
+    excluded_sym_names = list(set(excluded_sym_names))
-    original_conditional_op_names = []
-    conditional_op_names = []
-    param_names = []
-    param_vals = []
-    indptr = [0]
-    for conditional_fp32_op in conditional_fp32_ops:
-        assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
-            and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
-                                                          "(str, str, list of str)"
-        param_vals += conditional_fp32_op[2]
-        indptr.append(len(param_vals))
-        param_names.append(conditional_fp32_op[1])
-        conditional_op_names.append(conditional_fp32_op[0])
-
-    if excluded_sym_names is not None:
-        assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
-    else:
-        excluded_sym_names = []
-
-    for original_conditional_fp32_op in list_conditional_fp32_ops(target_dtype):
-        original_conditional_op_names.append(original_conditional_fp32_op[0])
-
-    # Op lists should not have intersection
+    # Op lists should not intersect
     common_ops = set(target_dtype_ops) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
-    common_ops = set(target_dtype_ops) & set(conditional_op_names)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
-    common_ops = set(conditional_op_names) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)
-
-    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
-    all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + list_fp32_ops(target_dtype)
-                            + list_lp16_fp32_ops(target_dtype) + original_conditional_op_names)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and fp32_ops: {}".format(common_ops)
+    common_ops = set(target_dtype_ops) & set(cond_ops)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and conditional_fp32_ops: {}".format(
+        common_ops)
+    common_ops = set(cond_ops) & set(fp32_ops)
+    assert len(common_ops) == 0, "Common ops in fp32_ops and conditional_fp32_ops: {}".format(common_ops)
+
+    combined_ops = set(target_dtype_ops + fp32_ops + list(cond_ops.keys()))
+    original_cond_ops = [cond_op[0] for cond_op in list_conditional_fp32_ops(target_dtype)]
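
For context on the renamed flag, a minimal usage sketch of `convert_symbol` under this PR's API; the model file name and the conditional-op entry below are hypothetical, for illustration only:

```python
import mxnet as mx
from mxnet import amp

# Hypothetical model file; any MXNet Symbol serialized to JSON works here.
sym = mx.sym.load('model-symbol.json')

# Convert the symbol to bfloat16, keeping 'softrelu' Activations in FP32
# and casting eligible parameters offline (the flag renamed in this hunk).
lp_sym = amp.convert_symbol(sym,
                            target_dtype='bfloat16',
                            conditional_fp32_ops=[('Activation', 'act_type', ['softrelu'])],
                            cast_params_offline=True)
```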
Review comment:
   The list from line 491 may contain conditional ops specified by the user. Here, in contrast, we build a list of all the conditional ops defined in the `symbol_bf16.py` file (for bfloat16) or the `symbol_fp16.py` file (for float16).
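
   To make that distinction concrete, here is a minimal, self-contained sketch of how the new `cond_ops` mapping and the exclusion loop behave. The op, attribute, and node names are hypothetical; the real code reads them from the symbol's JSON:

```python
# Hypothetical user-supplied (op, attribute, values) tuples.
conditional_fp32_ops = [
    ('Activation', 'act_type', ['softrelu']),
    ('Activation', 'act_type', ['tanh']),
]

# Collapse the tuples into one mapping per op, merging value lists
# when the same op/attribute pair appears more than once.
cond_ops = {op: {} for op, _, _ in conditional_fp32_ops}
for op_name, attr_name, attr_vals in conditional_fp32_ops:
    cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)

assert cond_ops == {'Activation': {'act_type': ['softrelu', 'tanh']}}

# Stand-ins for sym.attr_dict() and the name->op mapping from sym.tojson().
nodes_attr = {'act0': {'act_type': 'tanh'}, 'act1': {'act_type': 'relu'}}
nodes_op = {'act0': 'Activation', 'act1': 'Activation'}

# A node whose attribute value matches any listed value is excluded.
excluded_sym_names = []
for node_name, node_op in nodes_op.items():
    for attr_name, attr_vals in cond_ops.get(node_op, {}).items():
        if nodes_attr[node_name].get(attr_name) in attr_vals:
            excluded_sym_names.append(node_name)
            break

assert excluded_sym_names == ['act0']
```

   The `setdefault(...).extend(...)` merge is what lets a user list the same op/attribute pair several times without clobbering earlier entries.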