PawelGlomski-Intel commented on code in PR #20983:
URL: https://github.com/apache/incubator-mxnet/pull/20983#discussion_r853259275
##########
python/mxnet/amp/amp.py:
##########
@@ -497,22 +495,30 @@ def convert_symbol(sym, input_dtypes, param_dtypes,
target_dtype="float16", targ
"conditional_fp32_ops should be a list of (str, str, list of str)"
cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
- nodes_attr = sym.attr_dict()
+ nodes_attrs = sym.attr_dict()
nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
- if not set(excluded_sym_names).issubset(set(nodes_op.keys())):
- logging.warning("excluded_sym_names are not present in the network.
Missing layers: {}".format(
- set(excluded_sym_names) - set(nodes_op.keys())))
-
for node_name, node_op in nodes_op.items():
if node_op not in cond_ops:
continue
- node_attrs = nodes_attr[node_name]
+ node_attrs = nodes_attrs[node_name]
for attr_name, attr_vals in cond_ops[node_op].items():
assert attr_name in node_attrs
if node_attrs[attr_name] in attr_vals:
- excluded_sym_names += node_name
+ excluded_sym_names.append(node_name)
break
- excluded_sym_names = list(set(excluded_sym_names))
+
+ excluded_sym_names = set(excluded_sym_names)
+ for node in sym.get_internals():
+ if node.name in excluded_sym_names:
+ excluded_sym_names.remove(node.name)
+ opt_constraints = node.attr('__opt_constraint__')
+ opt_constraints = 0 if opt_constraints is None else opt_constraints
Review Comment:
Good catch! I believe all attributes are actually strings. I added a test
covering this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]