bgawrych commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r823594531



##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"

Review comment:
       ```suggestion
           "target_dtype_ops should be a list of strings"
   ```

##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"

Review comment:
       ```suggestion
           "fp32_ops should be a list of strings"
   ```

##########
File path: python/mxnet/executor.py
##########
@@ -371,11 +371,7 @@ def copy_params_from(self, arg_params, aux_params=None, 
allow_extra_params=False
         for name, array in arg_params.items():
             if name in self.arg_dict:
                 dst = self.arg_dict[name]
-                if dst.dtype == np.dtype([('bfloat16', np.uint16)]):
-                    cast_array = ndarray.amp_cast(array, dtype=dst.dtype)
-                    cast_array.copyto(dst)
-                else:
-                    array.astype(dst.dtype).copyto(dst)
+                array.astype(dst.dtype).copyto(dst)

Review comment:
       Does it work for bfloat16? I mean, NumPy has no bfloat16 dtype, so uint16 is
used to indicate that it is a 16-bit data type. Will it convert the data to bfloat16
or to uint16?

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {

Review comment:
       Can you document the whole class? What is stored here, etc.?

##########
File path: src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
##########
@@ -60,6 +61,9 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLPostQuantizeAlignScalePr
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCSumFuseProperty)
     .set_attr("quantize", true);
 
+MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_AMP).set_attr("context", 
Context::CPU());

Review comment:
       Do we need a new backend for this? 

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  // Converts the dtype of this NodeEntry. This should be called after a node 
has been converted and
+  // dtypes of its outputs may have changed
+  void UpdateDTtypeAfterConversion(const int new_dtype) {

Review comment:
       `DTtype`? Looks like a typo: should this be `UpdateDTypeAfterConversion`?

##########
File path: include/mxnet/c_api.h
##########
@@ -1992,51 +1992,40 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle,
  * casting
  * \param sym_handle symbol to be converted
  * \param ret_sym_handle mixed precision symbol result
- * \param num_args number of arguments for known dtypes
- * \param arg_type_data arg types of the arguments
  * \param target_dtype target_dtype for mixed precision symbol
- * \param cast_optional_params whether to cast optional params to target_dtype
- * \param num_target_dtype_op_names number of ops to be casted to target_dtype
- * \param num_fp32_op_names number of ops to be casted to FP32
- * \param num_widest_dtype_op_names number of ops to be casted to widest dtype
- * \param num_conditional_fp32_op_names number of ops to be casted to FP32 
based on a condition
+ * \param cast_params_offline whether to cast parameters offline to 
target_dtype
+ * \param num_inputs number of model inputs
+ * \param num_all_args number of all model arguments
+ * \param num_target_dtype_ops number of ops to be casted to target_dtype
+ * \param num_fp32_ops number of ops to be casted to FP32
+ * \param num_widest_dtype_ops number of ops to be casted to widest dtype
  * \param num_excluded_symbols number of symbols to be excluded from casting
- * \param num_model_params number of model parameters
- * \param num_widest_dtype_op_names number of ops to be casted to the widest 
dtype
- * \param num_conditional_fp32_op_names number of ops to be cast to fp32 based 
on precision
- * \param target_dtype_op_names op names to be casted to target_dtype
- * \param fp32_op_names op names to be casted to fp32
- * \param widest_dtype_op_names names to be casted to widest dtype
- * \param conditional_fp32_op_names names to be casted to FP32 conditionally
- * \param excluded_symbols symbol names to be excluded from casting
- * \param param_names param names for conditional FP32 casting
- * \param param_values param values for conditional FP32 casting
- * \param arg_names argument names for which type information is provided
- * \param model_param_names names for model parameters
+ * \param input_names_p names of model inputs
+ * \param all_arg_names_p names of all model arguments
+ * \param all_arg_types_p dtypes of all model arguments
+ * \param target_dtype_ops_p op names to be casted to target_dtype
+ * \param fp32_ops_p op names to be casted to fp32
+ * \param widest_dtype_ops_p op names to be casted to widest dtype
+ * \param excluded_syms_p symbol names to be excluded from casting
  */
 MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
                                       SymbolHandle* ret_sym_handle,
-                                      uint32_t num_args,
-                                      const int* arg_type_data,
-                                      uint32_t num_ind_ptr,
-                                      const int* ind_ptr,
-                                      const int* target_dtype,
-                                      const int cast_optional_params,
-                                      const uint32_t num_target_dtype_op_names,
-                                      const uint32_t num_fp32_op_names,
-                                      const uint32_t num_widest_dtype_op_names,
-                                      const uint32_t 
num_conditional_fp32_op_names,
+                                      const int target_dtype,
+                                      const int cast_params_offline,
+                                      const uint32_t num_inputs,
+                                      const uint32_t num_all_args,
+                                      const uint32_t num_target_dtype_ops,
+                                      const uint32_t num_fp32_ops,
+                                      const uint32_t num_widest_dtype_ops,
                                       const uint32_t num_excluded_symbols,
-                                      const uint32_t num_model_params,
-                                      const char** target_dtype_op_names,
-                                      const char** fp32_op_names,
-                                      const char** widest_dtype_op_names,
-                                      const char** conditional_fp32_op_names,
-                                      const char** excluded_symbols,
-                                      const char** conditional_param_names,
-                                      const char** conditional_param_vals,
-                                      const char** model_param_names,
-                                      const char** arg_names);
+                                      const char** input_names_p,
+                                      const char** all_arg_names_p,
+                                      const int* all_arg_types_p,
+                                      const char** target_dtype_ops_p,
+                                      const char** fp32_ops_p,
+                                      const char** widest_dtype_ops_p,
+                                      const char** excluded_syms_p);
+

Review comment:
       What do you think about changing the order of parameters so that each array
length directly precedes its array, e.g.:
   const uint32_t num_inputs,
   const char** input_names_p,
   const uint32_t num_all_args,
   const char** all_arg_names_p,
   ...
   This way, in Python it would be easier to match each length with its corresponding array.

##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, 
list), \
+        "conditional_fp32_ops should be a list"

Review comment:
       ```suggestion
           "conditional_fp32_ops should be a list of strings"
   ```

##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, 
list), \
+        "conditional_fp32_ops should be a list"
 
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
+    target_dtype = get_dtype_name(target_dtype)
+    assert target_dtype in ['float16', *bfloat16.names], \
+        "Only float16 and bfloat16 types are currently supported as 
target_dtype"
 
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be 
a list of strs"
-    else:
+    if target_dtype_ops is None:
         target_dtype_ops = list_lp16_ops(target_dtype)
-
-    if fp32_ops is not None:
-        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
-    else:
+    if fp32_ops is None:
         fp32_ops = list_fp32_ops(target_dtype)
 
-    if conditional_fp32_ops is not None:
-        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops 
should be a list"
-    else:
+    # conditional ops
+    if conditional_fp32_ops is None:
         conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+    cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+    for cond_op in conditional_fp32_ops:
+        op_name, attr_name, attr_vals = cond_op
+        assert isinstance(op_name, str) and isinstance(attr_name, str) and 
isinstance(attr_vals, list), \
+            "conditional_fp32_ops should be a list of (str, str, list of str)"
+        cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+    nodes_attr = sym.attr_dict()
+    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+    assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \

Review comment:
       Should this be an assert? Maybe a warning is enough?

##########
File path: python/mxnet/amp/amp.py
##########
@@ -459,73 +461,73 @@ def convert_symbol(sym, target_dtype="float16", 
target_dtype_ops=None,
         from being casted to LP16 or FP32.
     data_names : list of strs, optional
         A list of strings that represent input data tensor names to the model
-    cast_optional_params : bool, default False
+    cast_params_offline : bool, default False
         Whether to cast the arg_params and aux_params that don't require to be 
in LP16
         because of a cast layer following it, but will reduce the computation 
and memory
         overhead of the model if casted.
     """
-    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be Symbol"
+    import json
 
-    assert target_dtype in ['float16', 'bfloat16'], \
-               "Only target_dtype float16 and bfloat16 are supported currently"
+    assert isinstance(sym, Symbol), "First argument to convert_symbol should 
be a Symbol"
+    assert target_dtype_ops is None or isinstance(target_dtype_ops, list), \
+        "target_dtype_ops should be a list of strs"
+    assert fp32_ops is None or isinstance(fp32_ops, list), \
+        "fp32_ops should be a list of strs"
+    assert conditional_fp32_ops is None or isinstance(conditional_fp32_ops, 
list), \
+        "conditional_fp32_ops should be a list"
 
-    if target_dtype == 'bfloat16':
-        target_dtype = bfloat16
+    target_dtype = get_dtype_name(target_dtype)
+    assert target_dtype in ['float16', *bfloat16.names], \
+        "Only float16 and bfloat16 types are currently supported as 
target_dtype"
 
-    if target_dtype_ops is not None:
-        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be 
a list of strs"
-    else:
+    if target_dtype_ops is None:
         target_dtype_ops = list_lp16_ops(target_dtype)
-
-    if fp32_ops is not None:
-        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
-    else:
+    if fp32_ops is None:
         fp32_ops = list_fp32_ops(target_dtype)
 
-    if conditional_fp32_ops is not None:
-        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops 
should be a list"
-    else:
+    # conditional ops
+    if conditional_fp32_ops is None:
         conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)
+    cond_ops = {cond_op[0]: {} for cond_op in conditional_fp32_ops}
+    for cond_op in conditional_fp32_ops:
+        op_name, attr_name, attr_vals = cond_op
+        assert isinstance(op_name, str) and isinstance(attr_name, str) and 
isinstance(attr_vals, list), \
+            "conditional_fp32_ops should be a list of (str, str, list of str)"
+        cond_ops[op_name].setdefault(attr_name, []).extend(attr_vals)
+
+    nodes_attr = sym.attr_dict()
+    nodes_op = {n['name']: n['op'] for n in json.loads(sym.tojson())['nodes']}
+    assert set(excluded_sym_names).issubset(set(nodes_op.keys())), \
+        "excluded_sym_names are not present in the network. Missing layers: 
{}".format(
+            set(excluded_sym_names) - set(nodes_op.keys()))
+
+    for node_name, node_op in nodes_op.items():
+        if node_op not in cond_ops:
+            continue
+        node_attrs = nodes_attr[node_name]
+        for attr_name, attr_vals in cond_ops[node_op].items():
+            assert attr_name in node_attrs
+            if node_attrs[attr_name] in attr_vals:
+                excluded_sym_names += node_name
+                break
+    excluded_sym_names = list(set(excluded_sym_names))
 
-    original_conditional_op_names = []
-    conditional_op_names = []
-    param_names = []
-    param_vals = []
-    indptr = [0]
-    for conditional_fp32_op in conditional_fp32_ops:
-        assert isinstance(conditional_fp32_op[0], str) and 
isinstance(conditional_fp32_op[1], str) \
-            and isinstance(conditional_fp32_op[2], list), 
"conditional_fp32_ops should be a list of " \
-                                                          "(str, str, list of 
str)"
-        param_vals += conditional_fp32_op[2]
-        indptr.append(len(param_vals))
-        param_names.append(conditional_fp32_op[1])
-        conditional_op_names.append(conditional_fp32_op[0])
-
-    if excluded_sym_names is not None:
-        assert isinstance(excluded_sym_names, list), "excluded_sym_names 
should be a list of strs"
-    else:
-        excluded_sym_names = []
-
-    for original_conditional_fp32_op in 
list_conditional_fp32_ops(target_dtype):
-        original_conditional_op_names.append(original_conditional_fp32_op[0])
-
-    # Op lists should not have intersection
+    # Op lists should not intersect
     common_ops = set(target_dtype_ops) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and fp32_ops 
{}".format(common_ops)
-    common_ops = set(target_dtype_ops) & set(conditional_op_names)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in target_dtype_ops and 
conditional_fp32_ops {}".format(common_ops)
-    common_ops = set(conditional_op_names) & set(fp32_ops)
-    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
-                                 "Common ops in fp32_ops and 
conditional_fp32_ops {}".format(common_ops)
-
-    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
-    all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + 
list_fp32_ops(target_dtype)
-                            + list_lp16_fp32_ops(target_dtype) + 
original_conditional_op_names)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and fp32_ops: 
{}".format(common_ops)
+    common_ops = set(target_dtype_ops) & set(cond_ops)
+    assert len(common_ops) == 0, "Common ops in target_dtype_ops and 
conditional_fp32_ops: {}".format(
+        common_ops)
+    common_ops = set(cond_ops) & set(fp32_ops)
+    assert len(common_ops) == 0, "Common ops in fp32_ops and 
conditional_fp32_ops: {}".format(common_ops)
+
+    combined_ops = set(target_dtype_ops + fp32_ops + list(cond_ops.keys()))
+    original_cond_ops = [cond_op[0] for cond_op in 
list_conditional_fp32_ops(target_dtype)]

Review comment:
       ```suggestion
       original_cond_ops = list(cond_ops.keys())
   ```
   Is it the same? A similar loop is at line 491.

##########
File path: src/operator/subgraph/dnnl/dnnl_post_amp_property.h
##########
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_AMP_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include "../../tensor/amp_cast.h"
+#include "../common.h"

Review comment:
       Use "operator/tensor/amp_cast.h" and "operator/subgraph/common.h"

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  // Converts the dtype of this NodeEntry. This should be called after a node 
has been converted and
+  // dtypes of its outputs may have changed
+  void UpdateDTtypeAfterConversion(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(entry.node->op());
+    dtype = new_dtype;
+  }
+
+  // If dtype of this NodeEntry was not changed, returns the mapped entry. 
Otherwise returns a
+  // NodeEntry to the node which casts to the original dtype of this NodeEntry.
+  const NodeEntry& AsOriginal() {
+    return AsType(original_dtype);
+  }
+
+  // If dtype of this NodeEntry matches the specified dtype, returns the 
mapped entry. Otherwise
+  // returns a NodeEntry to the node which casts to that type.
+  const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = 
true) {
+    if (dtype == target_dtype) {
+      return entry;
     }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      CHECK(can_add_cast);
+      cast_entry = Cast(target_dtype);
+      CHECK(cast_entry.node);
+    }
+    return cast_entry;
+  }
+
+  // Returns whether this NodeEntry has the specified dtype or an existing 
cast to that dtype
+  bool HasDTypeEntry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  // Returns whether this NodeEntry (of a parameter) can be cast offline
+  bool CanBeCastOfflineTo(const int target_dtype) const {
+    CHECK(entry.node->is_variable());
+    return casts.count(target_dtype) > 0;
+  }
+
+ private:
+  ObjectPtr CreateCastNode(const std::string& op_name, const std::string& 
node_name) {
+    CHECK_GT(op_name.size(), 0);
+
+    ObjectPtr node   = Node::Create();
+    node->attrs.name = node_name;
+    node->attrs.op   = Op::Get(op_name);
+    node->inputs.emplace_back(entry);
+    return node;
+  }
+
+  NodeEntry Cast(const int new_dtype) {
+    CHECK(new_dtype == nnvm::kBfloat16 || new_dtype == nnvm::kFloat16 ||
+          new_dtype == nnvm::kFloat32);
+
+    const std::string dt_name        = mxnet::op::type_string(new_dtype);
+    const std::string suffix         = "_" + std::to_string(entry.index);
+    const std::string cast_node_name = entry.node->attrs.name + suffix + 
"_amp_cast_" + dt_name;
+    ObjectPtr cast_node              = CreateCastNode("amp_cast", 
cast_node_name);
+    cast_node->attrs.dict["dtype"]   = dt_name;
+    cast_node->op()->attr_parser(&(cast_node->attrs));
+    return NodeEntry{std::move(cast_node), 0, 0};
+  }
+
+ public:
+  const NodeEntry entry;
+  const int original_dtype;  // original dtype of the entry
+
+ private:
+  int dtype;  // current dtype of the entry
+  std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t     = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t      = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, 
nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t     = std::unordered_map<Node*, std::unordered_map<Node*, 
NodeEntry>>;
+
+// Makes sure the node in the new graph will work with the same precision as 
in the original graph
+static void KeepOriginalNode(const ObjectPtr& old_node,
+                             const NodeMap_t& node_map,
+                             EntryMap_t* const entry_map) {
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (const auto& old_ne : old_node->inputs) {
+    new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
   }
-  return suffix;
 }
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + 
dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
+// Tries to convert the node to low precision. Returns whether the node has 
been successfully
+// converted
+static bool TryLowPrecision(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  static auto& infertype      = 
nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+  static auto& fmutate_inputs = 
Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  in_types[0] = target_dtype;
+
+  if (infertype.count(old_node->op()) == 0 ||
+      infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == 
false) {
+    return false;
+  }
+
+  if (fmutate_inputs.count(old_node->op()) != 0) {
+    std::vector<uint32_t> mutable_inputs = 
fmutate_inputs[old_node->op()](old_node->attrs);
+    for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+      if (in_types[i] == target_dtype) {
+        if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != 
mutable_inputs.end()) {
+          return false;
+        }
+      }
+    }
+  }

Review comment:
       I would appreciate a comment (in the code) explaining why we are checking this.

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  // Converts the dtype of this NodeEntry. This should be called after a node 
has been converted and
+  // dtypes of its outputs may have changed
+  void UpdateDTtypeAfterConversion(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(entry.node->op());
+    dtype = new_dtype;
+  }
+
+  // If dtype of this NodeEntry was not changed, returns the mapped entry. 
Otherwise returns a
+  // NodeEntry to the node which casts to the original dtype of this NodeEntry.
+  const NodeEntry& AsOriginal() {
+    return AsType(original_dtype);
+  }
+
+  // If dtype of this NodeEntry matches the specified dtype, returns the 
mapped entry. Otherwise
+  // returns a NodeEntry to the node which casts to that type.
+  const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = 
true) {
+    if (dtype == target_dtype) {
+      return entry;
     }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      CHECK(can_add_cast);
+      cast_entry = Cast(target_dtype);
+      CHECK(cast_entry.node);
+    }
+    return cast_entry;
+  }
+
+  // Returns whether this NodeEntry has the specified dtype or an existing 
cast to that dtype
+  bool HasDTypeEntry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  // Returns whether this NodeEntry (of a parameter) can be cast offline
+  bool CanBeCastOfflineTo(const int target_dtype) const {
+    CHECK(entry.node->is_variable());
+    return casts.count(target_dtype) > 0;
+  }
+
+ private:
+  ObjectPtr CreateCastNode(const std::string& op_name, const std::string& 
node_name) {
+    CHECK_GT(op_name.size(), 0);
+
+    ObjectPtr node   = Node::Create();
+    node->attrs.name = node_name;
+    node->attrs.op   = Op::Get(op_name);
+    node->inputs.emplace_back(entry);
+    return node;
+  }
+
+  NodeEntry Cast(const int new_dtype) {
+    CHECK(new_dtype == nnvm::kBfloat16 || new_dtype == nnvm::kFloat16 ||
+          new_dtype == nnvm::kFloat32);
+
+    const std::string dt_name        = mxnet::op::type_string(new_dtype);
+    const std::string suffix         = "_" + std::to_string(entry.index);
+    const std::string cast_node_name = entry.node->attrs.name + suffix + 
"_amp_cast_" + dt_name;
+    ObjectPtr cast_node              = CreateCastNode("amp_cast", 
cast_node_name);
+    cast_node->attrs.dict["dtype"]   = dt_name;
+    cast_node->op()->attr_parser(&(cast_node->attrs));
+    return NodeEntry{std::move(cast_node), 0, 0};
+  }
+
+ public:
+  const NodeEntry entry;
+  const int original_dtype;  // original dtype of the entry
+
+ private:
+  int dtype;  // current dtype of the entry
+  std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t     = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t      = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, 
nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t     = std::unordered_map<Node*, std::unordered_map<Node*, 
NodeEntry>>;
+
+// Makes sure the node in the new graph will work with the same precision as 
in the original graph
+static void KeepOriginalNode(const ObjectPtr& old_node,
+                             const NodeMap_t& node_map,
+                             EntryMap_t* const entry_map) {
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (const auto& old_ne : old_node->inputs) {
+    new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
   }
-  return suffix;
 }
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + 
dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
+// Tries to convert the node to low precision. Returns whether the node has 
been successfully
+// converted
+static bool TryLowPrecision(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  static auto& infertype      = 
nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+  static auto& fmutate_inputs = 
Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  in_types[0] = target_dtype;
+
+  if (infertype.count(old_node->op()) == 0 ||
+      infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == 
false) {
+    return false;
+  }
+
+  if (fmutate_inputs.count(old_node->op()) != 0) {
+    std::vector<uint32_t> mutable_inputs = 
fmutate_inputs[old_node->op()](old_node->attrs);
+    for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+      if (in_types[i] == target_dtype) {
+        if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != 
mutable_inputs.end()) {
+          return false;
+        }
+      }
+    }
+  }
+
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+    
new_node->inputs.push_back(entry_map->at(old_node->inputs[i]).AsType(in_types[i]));
+  }
+
+  for (const NodeEntry& old_ne : nodes_entries.at(old_node.get())) {
+    entry_map->at(old_ne).UpdateDTtypeAfterConversion(out_types[old_ne.index]);
+  }
+
+  return true;
 }
 
-// add amp_multicast node between curr_node and inputs
-static void AddMultiCastNode(const std::vector<NodeEntry>& inputs,
-                             const std::string& node_name,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map,
-                             ObjectPtr curr_node) {
-  ObjectPtr node =
-      CreateNode("amp_multicast", inputs[0].node->attrs.name + node_name + 
"_amp_multicast");
-  for (const auto& node_entry : inputs) {
-    ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
-    NodeEntry mirror_entry =
-        NodeEntry{std::move(mirror_node), node_entry.index, 
node_entry.version};
-    node->inputs.emplace_back(mirror_entry);
+// Tries to convert the node to low precision if all of its inputs already 
have the correct dtype.
+// Otherwise keeps the node unchanged.
+static void HandleWidestDtypeNode(const int target_dtype,
+                                  const ObjectPtr& old_node,
+                                  const NodeMap_t& node_map,
+                                  const NodesEntries_t& nodes_entries,
+                                  EntryMap_t* const entry_map) {
+  static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+
+  std::vector<int> in_types(old_node->inputs.size(), target_dtype);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  const bool inferred = (infertype.count(old_node->op()) > 0 &&
+                         infertype[old_node->op()](old_node->attrs, &in_types, 
&out_types));
+
+  bool has_lp_inputs = inferred;
+  for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
+    const NodeEntry& input = old_node->inputs[i];
+    has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+  }
+
+  if (!has_lp_inputs ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  node->attrs.dict["num_outputs"] = std::to_string(inputs.size());
-  node->op()->attr_parser(&(node->attrs));
-  for (uint32_t i = 0; i < inputs.size(); ++i) {
-    const auto& e = inputs[i];
-    curr_node->inputs.emplace_back(NodeEntry{node, static_cast<uint32_t>(i), 
e.version});
+}
+
+// Tries to convert the node to low precision if some of its inputs already 
are converted.
+// Otherwise keeps the node unchanged.
+void HandleDTypeNeutralNode(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  const auto& is_lp = [&](const auto& old_ne) {
+    return entry_map->at(old_ne).HasDTypeEntry(target_dtype);
+  };
+  if (!std::any_of(old_node->inputs.begin(), old_node->inputs.end(), is_lp) ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  return;
 }
 
-static bool CheckConditionalFP32(
-    const std::unordered_map<std::string,
-                             std::unordered_map<std::string, 
std::vector<std::string>>>&
-        conditional_fp32_ops,
-    const std::unordered_set<std::string>& excluded_syms,
-    ObjectPtr node) {
-  if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) ||
-      conditional_fp32_ops.count(node->op()->name) == 0) {
-    return false;
-  } else {
-    // Iterate through all conditional ops
-    auto it = conditional_fp32_ops.find(node->op()->name);
-    if (it != conditional_fp32_ops.end()) {
-      auto it_params = it->second;
-      // For each param name, iterate through param values to check
-      // if the provided param name is equal to any of the values
-      for (auto& it_param : it_params) {
-        auto param_key = node->attrs.dict.find(it_param.first);
-        if (param_key != node->attrs.dict.end()) {
-          auto it_param_vals = it_param.second;
-          if (std::find(it_param_vals.begin(), it_param_vals.end(), 
param_key->second) !=
-              it_param_vals.end()) {
-            return true;
+// Decides which prameters can be cast offline and removes redundant cast 
nodes from the graph
+static void RemoveParamCasts(const NodeMap_t& node_map,
+                             const DstNodes_t& old_param_dst_nodes,
+                             const int target_dtype,
+                             EntryMap_t* entry_map) {
+  for (const auto& [old_param, old_param_dsts] : old_param_dst_nodes) {
+    const ObjectPtr& new_param      = node_map.at(old_param);
+    const auto& can_be_cast_offline = [&](const auto& old_node_x_ne_pair) {
+      const ObjectPtr& new_node        = node_map.at(old_node_x_ne_pair.first);
+      const MappedNodeEntry& mapped_ne = 
entry_map->at(old_node_x_ne_pair.second);
+      for (const NodeEntry& node_entry : new_node->inputs) {
+        if (node_entry.node == new_param) {
+          return false;
+        }
+      }
+      return mapped_ne.CanBeCastOfflineTo(target_dtype);
+    };
+
+    if (std::all_of(old_param_dsts.begin(), old_param_dsts.end(), 
can_be_cast_offline)) {
+      nnvm::NodeEntryEqual are_equal;
+      for (const auto& [old_dst_node, old_ne] : old_param_dsts) {
+        MappedNodeEntry& mapped_ne      = entry_map->at(old_ne);
+        const NodeEntry& new_ne_to_skip = mapped_ne.AsType(target_dtype, 
false);
+        const ObjectPtr& new_dst_node   = node_map.at(old_dst_node);
+        bool skipped_amp_cast           = false;
+        for (NodeEntry& new_ne : new_dst_node->inputs) {
+          if (are_equal(new_ne, new_ne_to_skip)) {
+            new_ne           = mapped_ne.entry;
+            skipped_amp_cast = true;
+            break;
           }
         }
+        CHECK(skipped_amp_cast);
       }
+      new_param->attrs.dict["__dtype__"] = std::to_string(target_dtype);

Review comment:
       Can we use the type_string function from src/operator/operator_common.h here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to