PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824647707
##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16",
target_dtype_ops=None,
from being cast to LP16 or FP32.
data_names : list of strs, optional
A list of strings that represent input data tensor names to the model
- cast_optional_params : bool, default False
+ cast_params_offline : bool, default False
Whether to cast the arg_params and aux_params that don't need to be
in LP16
because of a cast layer following them; casting them will reduce the
computation and memory
overhead of the model.
"""
- assert isinstance(sym, Symbol), "First argument to convert_symbol should
be Symbol"
+ import json
- assert target_dtype in ['float16', 'bfloat16'], \
- "Only target_dtype float16 and bfloat16 are supported currently"
+ assert isinstance(sym, Symbol), "First argument to convert_symbol should
be a Symbol"
+ assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+ "target_dtype_ops should be a list of strs"
+ assert fp32_ops is None or isinstance(fp32_ops, list), \
+ "fp32_ops should be a list of strs"
+ assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops,
list), \
+ "conditional_fp32_ops should be a list"
- if target_dtype == 'bfloat16':
- target_dtype = bfloat16
+ target_dtype = get_dtype_name(target_dtype)
+ assert target_dtype in ['float16', *bfloat16.names], \
+ "Only float16 and bfloat16 types are currently supported as
target_dtype"
- if target_dtype_ops is not None:
- assert isinstance(target_dtype_ops, list), "target_dtype_ops should be
a list of strs"
- else:
+ if target_dtype_ops is None:
target_dtype_ops = list_lp16_ops(target_dtype)
-
- if fp32_ops is not None:
- assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
- else:
+ if fp32_ops is None:
fp32_ops = list_fp32_ops(target_dtype)
- if conditional_fp32_ops is not None:
- assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops
should be a list"
- else:
+ # conditional ops
+ if conditional_fp32_ops is None:
conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+ cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+ for cond_op in conditional_fp32_ops:
+ op_name, attr_name, attr_vals = cond_op
+ assert isinstance(op_name, str) and isinstance(attr_name, str) and
isinstance(attr_vals, list), \
+ "conditional_fp32_ops should be a list of (str, str, list of str)"
+ cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+ nodes_attr = sym.attr_dict()
+ nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+ assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \
Review comment:
I changed it to use `logging.warning(...)` since this is how logging is
handled in this file. Do you think we should instead add a logger as an
argument to the conversion functions?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]