anko-intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r828015534



##########
File path: src/serialization/cnpy.cc
##########
@@ -109,7 +109,7 @@ std::string dtype_descr(const TBlob& blob) {
     case mshadow::kUint64:
       return "'" MXNET_BYTEORDER "u8'";
     case mshadow::kBfloat16:
-      return "[('bfloat16', '" MXNET_BYTEORDER "u2')]";
+      return "'" MXNET_BYTEORDER "bfloat16'";

Review comment:
       The other values here seem to save space. What about `bf16`? Or rather,
following the convention (type kind followed by size in bytes, as in `u8`), `bf2` or `b2`?

##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,336 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+// Before the model conversion, node entries of the original graph are mapped 
to the equivalent node
+// entries in the new graph that will be then converted to a mixed precision 
graph. This class wraps
+// a mapped NodeEntry from the new graph, providing a transparent interface 
for acquiring versions
+// of the wrapped entry with a specific dtype, adding a casting nodes to the 
graph when needed (one
+// for each unique dtype that was requested).
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  // Converts the dtype of this NodeEntry. This should be called after a node 
has been converted and
+  // dtypes of its outputs may have changed
+  void UpdateDTypeAfterConversion(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(entry.node->op());
+    dtype = new_dtype;
+  }
+
+  // If dtype of this NodeEntry was not changed, returns the mapped entry. 
Otherwise returns a
+  // NodeEntry to the node which casts to the original dtype of this NodeEntry.
+  const NodeEntry& AsOriginal() {
+    return AsType(original_dtype);
+  }
+
+  // If dtype of this NodeEntry matches the specified dtype, returns the 
mapped entry. Otherwise
+  // returns a NodeEntry to the node which casts to that type.
+  const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = 
true) {
+    if (dtype == target_dtype) {
+      return entry;
     }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      CHECK(can_add_cast);
+      cast_entry = Cast(target_dtype);
+      CHECK(cast_entry.node);
+    }
+    return cast_entry;
+  }
+
+  // Returns whether this NodeEntry has the specified dtype or an existing 
cast to that dtype
+  bool HasDTypeEntry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  // Returns whether this NodeEntry (of a parameter) can be cast offline
+  bool CanBeCastOfflineTo(const int target_dtype) const {
+    CHECK(entry.node->is_variable());
+    return casts.count(target_dtype) > 0;
+  }
+
+ private:
+  ObjectPtr CreateCastNode(const std::string& op_name, const std::string& 
node_name) {
+    CHECK_GT(op_name.size(), 0);
+
+    ObjectPtr node   = Node::Create();
+    node->attrs.name = node_name;
+    node->attrs.op   = Op::Get(op_name);
+    node->inputs.emplace_back(entry);
+    return node;
+  }
+
+  NodeEntry Cast(const int new_dtype) {
+    CHECK(new_dtype == nnvm::kBfloat16 || new_dtype == nnvm::kFloat16 ||
+          new_dtype == nnvm::kFloat32);
+
+    const std::string dt_name        = mxnet::op::type_string(new_dtype);
+    const std::string suffix         = "_" + std::to_string(entry.index);
+    const std::string cast_node_name = entry.node->attrs.name + suffix + 
"_amp_cast_" + dt_name;
+    ObjectPtr cast_node              = CreateCastNode("amp_cast", 
cast_node_name);
+    cast_node->attrs.dict["dtype"]   = dt_name;
+    cast_node->op()->attr_parser(&(cast_node->attrs));
+    return NodeEntry{std::move(cast_node), 0, 0};
+  }
+
+ public:
+  const NodeEntry entry;
+  const int original_dtype;  // original dtype of the entry
+
+ private:
+  int dtype;  // current dtype of the entry
+  std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t     = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t      = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, 
nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t     = std::unordered_map<Node*, std::unordered_map<Node*, 
NodeEntry>>;
+
+// Makes sure the node in the new graph will work with the same precision as 
in the original graph
+static void KeepOriginalNode(const ObjectPtr& old_node,
+                             const NodeMap_t& node_map,
+                             EntryMap_t* const entry_map) {
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (const auto& old_ne : old_node->inputs) {
+    new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
   }
-  return suffix;
 }
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + 
dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
+// Tries to convert the node to low precision. Returns whether the node has 
been successfully
+// converted
+static bool TryLowPrecision(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  static auto& infertype      = 
nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+  static auto& fmutate_inputs = 
Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  in_types[0] = target_dtype;
+
+  if (infertype.count(old_node->op()) == 0 ||
+      infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == 
false) {
+    return false;
+  }
+
+  if (fmutate_inputs.count(old_node->op()) != 0) {
+    std::vector<uint32_t> mutable_inputs = 
fmutate_inputs[old_node->op()](old_node->attrs);
+    for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+      if (in_types[i] == target_dtype) {
+        if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != 
mutable_inputs.end()) {
+          return false;
+        }
+      }
+    }
+  }
+
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+    
new_node->inputs.push_back(entry_map->at(old_node->inputs[i]).AsType(in_types[i]));
+  }
+
+  for (const NodeEntry& old_ne : nodes_entries.at(old_node.get())) {
+    entry_map->at(old_ne).UpdateDTypeAfterConversion(out_types[old_ne.index]);
+  }
+
+  return true;
 }
 
-// add amp_multicast node between curr_node and inputs
-static void AddMultiCastNode(const std::vector<NodeEntry>& inputs,
-                             const std::string& node_name,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map,
-                             ObjectPtr curr_node) {
-  ObjectPtr node =
-      CreateNode("amp_multicast", inputs[0].node->attrs.name + node_name + 
"_amp_multicast");
-  for (const auto& node_entry : inputs) {
-    ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
-    NodeEntry mirror_entry =
-        NodeEntry{std::move(mirror_node), node_entry.index, 
node_entry.version};
-    node->inputs.emplace_back(mirror_entry);
+// Tries to convert the node to low precision if all of its inputs already 
have the correct dtype.
+// Otherwise keeps the node unchanged.
+static void HandleWidestDtypeNode(const int target_dtype,
+                                  const ObjectPtr& old_node,
+                                  const NodeMap_t& node_map,
+                                  const NodesEntries_t& nodes_entries,
+                                  EntryMap_t* const entry_map) {
+  static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+
+  std::vector<int> in_types(old_node->inputs.size(), target_dtype);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  const bool inferred = (infertype.count(old_node->op()) > 0 &&
+                         infertype[old_node->op()](old_node->attrs, &in_types, 
&out_types));
+
+  bool has_lp_inputs = inferred;
+  for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
+    const NodeEntry& input = old_node->inputs[i];
+    has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+  }
+
+  if (!has_lp_inputs ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  node->attrs.dict["num_outputs"] = std::to_string(inputs.size());
-  node->op()->attr_parser(&(node->attrs));
-  for (uint32_t i = 0; i < inputs.size(); ++i) {
-    const auto& e = inputs[i];
-    curr_node->inputs.emplace_back(NodeEntry{node, static_cast<uint32_t>(i), 
e.version});
+}
+
+// Tries to convert the node to low precision if some of its inputs already 
are converted.
+// Otherwise keeps the node unchanged.
+void HandleDTypeNeutralNode(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  const auto& is_lp = [&](const auto& old_ne) {
+    return entry_map->at(old_ne).HasDTypeEntry(target_dtype);
+  };
+  if (!std::any_of(old_node->inputs.begin(), old_node->inputs.end(), is_lp) ||
+      !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
entry_map)) {
+    KeepOriginalNode(old_node, node_map, entry_map);
   }
-  return;
 }
 
-static bool CheckConditionalFP32(
-    const std::unordered_map<std::string,
-                             std::unordered_map<std::string, 
std::vector<std::string>>>&
-        conditional_fp32_ops,
-    const std::unordered_set<std::string>& excluded_syms,
-    ObjectPtr node) {
-  if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) ||
-      conditional_fp32_ops.count(node->op()->name) == 0) {
-    return false;
-  } else {
-    // Iterate through all conditional ops
-    auto it = conditional_fp32_ops.find(node->op()->name);
-    if (it != conditional_fp32_ops.end()) {
-      auto it_params = it->second;
-      // For each param name, iterate through param values to check
-      // if the provided param name is equal to any of the values
-      for (auto& it_param : it_params) {
-        auto param_key = node->attrs.dict.find(it_param.first);
-        if (param_key != node->attrs.dict.end()) {
-          auto it_param_vals = it_param.second;
-          if (std::find(it_param_vals.begin(), it_param_vals.end(), 
param_key->second) !=
-              it_param_vals.end()) {
-            return true;
+// Decides which prameters can be cast offline and removes redundant cast 
nodes from the graph
+static void RemoveParamCasts(const int target_dtype,
+                             const std::string& offline_param_cast_attr,
+                             const NodeMap_t& node_map,
+                             const DstNodes_t& old_param_dst_nodes,
+                             EntryMap_t* entry_map) {
+  for (const auto& [old_param, old_param_dsts] : old_param_dst_nodes) {
+    const ObjectPtr& new_param      = node_map.at(old_param);
+    const auto& can_be_cast_offline = [&](const auto& old_node_x_ne_pair) {
+      const ObjectPtr& new_node        = node_map.at(old_node_x_ne_pair.first);
+      const MappedNodeEntry& mapped_ne = 
entry_map->at(old_node_x_ne_pair.second);
+      for (const NodeEntry& node_entry : new_node->inputs) {
+        if (node_entry.node == new_param) {
+          return false;
+        }
+      }
+      return mapped_ne.CanBeCastOfflineTo(target_dtype);
+    };
+
+    if (std::all_of(old_param_dsts.begin(), old_param_dsts.end(), 
can_be_cast_offline)) {
+      nnvm::NodeEntryEqual are_equal;
+      for (const auto& [old_dst_node, old_ne] : old_param_dsts) {
+        MappedNodeEntry& mapped_ne      = entry_map->at(old_ne);
+        const NodeEntry& new_ne_to_skip = mapped_ne.AsType(target_dtype, 
false);
+        const ObjectPtr& new_dst_node   = node_map.at(old_dst_node);
+        bool skipped_amp_cast           = false;
+        for (NodeEntry& new_ne : new_dst_node->inputs) {
+          if (are_equal(new_ne, new_ne_to_skip)) {
+            new_ne           = mapped_ne.entry;
+            skipped_amp_cast = true;
+            break;
           }
         }
+        CHECK(skipped_amp_cast);
       }
+      new_param->attrs.dict[offline_param_cast_attr] = 
mxnet::op::type_string(target_dtype);
     }
-    return false;
   }
 }
 
 Graph ReducePrecision(Graph&& src) {
-  static auto& fmutate_inputs = 
Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
-  static auto& infertype      = 
nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
-  const auto target_dtype_ops = 
src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
-  const auto fp32_ops         = 
src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
-  const auto widest_dtype_ops = 
src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
-  const auto target_dtype     = src.GetAttr<int>("target_dtype");
-  const auto excluded_syms    = 
src.GetAttr<std::unordered_set<std::string>>("excluded_syms");
-  const auto conditional_fp32_ops = src.GetAttr<
-      std::unordered_map<std::string, std::unordered_map<std::string, 
std::vector<std::string>>>>(
-      "conditional_fp32_ops");
-  const auto data_name_types = src.GetAttr<std::unordered_map<std::string, 
int>>("data_name_types");
-  const auto cast_optional_params = src.GetAttr<int>("cast_optional_params");
+  const auto target_dtype             = src.GetAttr<int>("target_dtype");
+  const auto cast_params_offline      = 
src.GetAttr<int>("cast_params_offline");
+  const auto& offline_param_cast_attr = 
src.GetAttr<std::string>("offline_param_cast_attr");
+  const auto& input_names             = 
src.GetAttr<std::unordered_set<std::string>>("input_names");
+  const auto& target_dtype_ops = 
src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
+  const auto& fp32_ops         = 
src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
+  const auto& widest_dtype_ops = 
src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
+  const auto& excluded_syms    = 
src.GetAttr<std::unordered_set<std::string>>("excluded_syms");
+  auto src_dtypes              = src.GetAttr<nnvm::DTypeVector>("dtype");  // 
copy, not reference
 
   CHECK(target_dtype == mshadow::kFloat16 || target_dtype == 
mshadow::kBfloat16)
       << "Only float16 and bfloat16 target_dtype is supported yet," << 
target_dtype;
 
-  std::string target_dtype_str = "float32";
-  if (target_dtype == mshadow::kFloat16) {
-    target_dtype_str = "float16";
-  } else if (target_dtype == mshadow::kBfloat16) {
-    target_dtype_str = "bfloat16";
+  const nnvm::IndexedGraph& src_idx = src.indexed_graph();
+  CHECK_EQ(src_dtypes.size(), src_idx.num_node_entries());
+  for (const int src_dtype : src_dtypes) {
+    CHECK_NE(src_dtype, -1) << "Infer type failed with full information about 
input types";
   }
 
-  // Additional data structures to share common cast node inputs among 
different nodes
-  std::unordered_map<Node*, ObjectPtr> mirror_map;
-  nnvm::NodeEntryMap<NodeEntry> mirror_fp32_map;
-  nnvm::NodeEntryMap<NodeEntry> mirror_target_dtype_map;
+  NodeMap_t node_map;
+  EntryMap_t entry_map;
+  NodesEntries_t nodes_entries;
+  DstNodes_t old_param_dst_nodes;
+
+  const auto& register_node_entry =
+      [&](const NodeEntry& old_ne, const ObjectPtr& old_dst_node, const 
ObjectPtr& new_dst_node) {
+        // new_dst_node is the node that should own `old_ne` equivalent as one 
of its input
+        const uint32_t entry_id       = src_idx.entry_id(old_ne);
+        const int original_ne_dtype   = src_dtypes[entry_id];
+        const ObjectPtr& old_src_node = old_ne.node;
+        const ObjectPtr& new_src_node = node_map.at(old_src_node.get());
+        const NodeEntry new_ne        = NodeEntry(new_src_node, old_ne.index, 
old_ne.version);
 
-  // Visit nodes in a topologically sorted order
-  DFSVisit(src.outputs, [&](const ObjectPtr& node) {
-    ObjectPtr new_node = Node::Create(*node);
+        entry_map.emplace(old_ne, MappedNodeEntry(new_ne, original_ne_dtype));
+
+        // register which nodes use parameters
+        nodes_entries[old_src_node.get()].insert(old_ne);
+        if (new_dst_node && old_src_node->is_variable() &&
+            input_names.count(old_src_node->attrs.name) == 0) {
+          CHECK(old_dst_node);
+          old_param_dst_nodes[old_src_node.get()][old_dst_node.get()] = old_ne;
+        }
+      };
+
+  // gather information about node entries and build a new graph
+  DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+    ObjectPtr new_node = Node::Create(*old_node);
     new_node->inputs.clear();
-    std::vector<uint32_t> mutable_inputs;
-    if (fmutate_inputs.count(node->op()) != 0) {
-      mutable_inputs = fmutate_inputs[node->op()](node->attrs);
+    for (const NodeEntry& old_ne : old_node->inputs) {
+      register_node_entry(old_ne, old_node, new_node);
     }
-    /* 1. for node which needs to run in FP32 mode, add amp_cast operators
-     * (to fp32) after its inputs
-     * 2. for node which needs to run in LP16 mode, add amp_cast operators
-     * (to target_dtype) after its inputs
-     * 3. for nodes which need to run in widest dtype among its inputs, add
-     * amp_multicast operators between op and its inputs
-     * 4. for nodes which need to run in FP32 mode, based on a specific 
condition,
-     * check the condition, and if true add amp_cast (to fp32) after its inputs
-     * 4. for other nodes, create copy node and add it to the mirror_map
-     */
-    if ((!node->is_variable() && fp32_ops.count(node->op()->name) > 0) ||
-        (excluded_syms.count(node->attrs.name) > 0)) {
-      // Add output entry to fp32_map
-      for (size_t i = 0; i < node->num_outputs(); ++i) {
-        const auto out_entry       = NodeEntry(node, i, 0);
-        mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0);
-      }
-      for (size_t i = 0; i < node->inputs.size(); ++i) {
-        const auto& node_entry = node->inputs[i];
-        if (mirror_fp32_map.count(node_entry)) {
-          new_node->inputs.emplace_back(mirror_fp32_map[node_entry]);
-        } else if (node_entry.node->is_variable()) {
-          // For variable, assume they are already fp32
-          ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
-          new_node->inputs.emplace_back(mirror_node, node_entry.index, 
node_entry.version);
-        } else {
-          ObjectPtr mirror_node  = mirror_map.at(node_entry.node.get());
-          NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, 
node_entry.version};
-          std::string suffix     = GetSuffix(node_entry, mirror_map);
-          AddCastNode(node_entry, suffix, mirror_entry, "float32", 
&mirror_fp32_map, new_node);
-        }
-      }
-    } else if (!node->is_variable() && 
target_dtype_ops.count(node->op()->name) > 0 &&
-               excluded_syms.count(node->attrs.name) == 0) {
-      std::vector<int> in_types(node->inputs.size(), -1);
-      std::vector<int> out_types(node->num_outputs(), -1);
-      if (infertype.count(node->op())) {
-        // Try to infertype with target dtype. And add output entry to 
mirror_target_dtype_map or
-        // mirror_fp32_map based on infered result.
-        in_types[0]             = target_dtype;
-        bool infer_type_success = infertype[node->op()](node->attrs, 
&in_types, &out_types);
-        CHECK(infer_type_success == true);
-        for (size_t i = 0; i < node->num_outputs(); ++i) {
-          const auto out_entry = NodeEntry(node, i, 0);
-          if (out_types[i] == target_dtype) {
-            mirror_target_dtype_map[out_entry] = NodeEntry(new_node, i, 0);
-          } else if (out_types[i] == 0) {
-            mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0);
-          }
-        }
-      }
-      for (size_t i = 0; i < node->inputs.size(); ++i) {
-        const auto& node_entry = node->inputs[i];
-        if (mirror_target_dtype_map.count(node_entry)) {
-          new_node->inputs.emplace_back(mirror_target_dtype_map[node_entry]);
-        } else if ((cast_optional_params && node_entry.node->is_variable() &&
-                    !data_name_types.count(node_entry.node->attrs.name)) ||
-                   (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) 
!=
-                    mutable_inputs.end()) ||
-                   !(in_types[i] == target_dtype || in_types[i] == -1)) {
-          // Here's some rules that not insert amp_cast for inputs:
-          // 1. cast_optional_params is True, node_entry.node is variable and 
its not the data of
-          //    the network. This is network params that offline converted to 
target dtype.
-          // 2. Mutable inputs.
-          // 3. Even the input[0] is target dtype, some operations still 
require float32 for other
-          //    inputs. For example, Batchnorm.
-          ObjectPtr mirror_node   = mirror_map.at(node_entry.node.get());
-          const auto mirror_entry = NodeEntry(mirror_node, node_entry.index, 
node_entry.version);
-          new_node->inputs.push_back(mirror_entry);
-          if ((cast_optional_params && node_entry.node->is_variable())) {
-            // Node is target dtype
-            mirror_target_dtype_map[node_entry] = mirror_entry;
-          }
-        } else {
-          ObjectPtr mirror_node  = mirror_map.at(node_entry.node.get());
-          NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, 
node_entry.version};
-          std::string suffix     = GetSuffix(node_entry, mirror_map);
-          AddCastNode(node_entry,
-                      suffix,
-                      mirror_entry,
-                      target_dtype_str,
-                      &mirror_target_dtype_map,
-                      new_node);
-        }
-      }
-    } else if (!node->is_variable() && 
widest_dtype_ops.count(node->op()->name) > 0 &&
-               excluded_syms.count(node->attrs.name) == 0) {
-      CHECK(node->inputs.size() > 0)
-          << "Please check the symbol. node name: " << node->attrs.name << "op 
name "
-          << node->op()->name << " has no inputs."
-          << "It is likely that something went wrong during symbolic 
construction.";
-      CHECK_EQ(mutable_inputs.size(), 0)
-          << "can't handle the widest_dtype_ops with mutable inputs.";
-      int out_dtype           = target_dtype;
-      bool have_unknown_dtype = false;
-      for (auto& input : node->inputs) {
-        // Try to infer output dtype based on input dtype
-        if (!mirror_target_dtype_map.count(input) && 
!mirror_fp32_map.count(input)) {
-          have_unknown_dtype = true;
-          break;
-        } else if (mirror_fp32_map.count(input)) {
-          out_dtype = mshadow::kFloat32;
-        }
-      }
-      if (have_unknown_dtype) {
-        // We can't infer all dtype for inputs, so we need to add 
AddMultiCastNode here.
-        const auto& e      = node->inputs[0];
-        std::string suffix = GetSuffix(e, mirror_map);
-        AddMultiCastNode(node->inputs, suffix, mirror_map, new_node);
-      } else {
-        for (size_t i = 0; i < node->num_outputs(); ++i) {
-          const auto out_entry = NodeEntry(node, i, 0);
-          if (out_dtype == target_dtype) {
-            mirror_target_dtype_map[out_entry] = NodeEntry(new_node, i, 0);
-          } else {
-            mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0);
-          }
-        }
-        // we know all dtype from inputs, then we can use amp_cast instead.
-        for (size_t i = 0; i < node->inputs.size(); ++i) {
-          const auto& node_entry = node->inputs[i];
-          if (out_dtype == target_dtype) {
-            if (mirror_target_dtype_map.count(node_entry)) {
-              
new_node->inputs.emplace_back(mirror_target_dtype_map[node_entry]);
-            } else {
-              ObjectPtr mirror_node  = mirror_map.at(node_entry.node.get());
-              NodeEntry mirror_entry = NodeEntry{mirror_node, 
node_entry.index, node_entry.version};
-              std::string suffix     = GetSuffix(node_entry, mirror_map);
-              AddCastNode(node_entry,
-                          suffix,
-                          mirror_entry,
-                          target_dtype_str,
-                          &mirror_target_dtype_map,
-                          new_node);
-            }
-          } else {
-            if (mirror_fp32_map.count(node_entry)) {
-              new_node->inputs.emplace_back(mirror_fp32_map[node_entry]);
-            } else {
-              ObjectPtr mirror_node  = mirror_map.at(node_entry.node.get());
-              NodeEntry mirror_entry = NodeEntry{mirror_node, 
node_entry.index, node_entry.version};
-              std::string suffix     = GetSuffix(node_entry, mirror_map);
-              AddCastNode(node_entry, suffix, mirror_entry, "float32", 
&mirror_fp32_map, new_node);
-            }
-          }
-        }
-      }
-    } else if (CheckConditionalFP32(conditional_fp32_ops, excluded_syms, 
node)) {
-      for (size_t i = 0; i < node->num_outputs(); ++i) {
-        const auto out_entry       = NodeEntry(node, i, 0);
-        mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0);
+    node_map.emplace(old_node.get(), std::move(new_node));
+  });
+  for (const NodeEntry& old_out_ne : src.outputs) {
+    register_node_entry(old_out_ne, nullptr, nullptr);
+  }
+
+  // convert the model
+  DFSVisit(src.outputs, [&](const ObjectPtr& old_node) {
+    if (old_node->is_variable() || old_node->op() == Op::Get("amp_multicast") 
||
+        IsCastOp(old_node->op())) {
+      const ObjectPtr& new_node = node_map.at(old_node.get());
+      for (const auto& old_ne : old_node->inputs) {
+        const ObjectPtr& new_in_node = node_map.at(old_ne.node.get());
+        new_node->inputs.emplace_back(new_in_node, old_ne.index, 
old_ne.version);
       }
-      for (size_t i = 0; i < node->inputs.size(); ++i) {
-        const auto& node_entry = node->inputs[i];
-        if (mirror_fp32_map.count(node_entry)) {
-          new_node->inputs.emplace_back(mirror_fp32_map[node_entry]);
-        } else if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) 
!=
-                   mutable_inputs.end()) {
-          // Can't insert amp_cast for this inputs. Such op have to handle 
fp32 inputs itself.
-          ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
-          new_node->inputs.emplace_back(mirror_node, node_entry.index, 
node_entry.version);
-        } else {
-          ObjectPtr mirror_node  = mirror_map.at(node_entry.node.get());
-          NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, 
node_entry.version};
-          std::string suffix     = GetSuffix(node_entry, mirror_map);
-          AddCastNode(node_entry, suffix, mirror_entry, "float32", 
&mirror_fp32_map, new_node);
-        }
+      return;
+    }
+
+    if (fp32_ops.count(old_node->op()->name) > 0 || 
excluded_syms.count(old_node->attrs.name) > 0) {
+      KeepOriginalNode(old_node, node_map, &entry_map);
+    } else if (target_dtype_ops.count(old_node->op()->name) > 0) {
+      if (!TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, 
&entry_map)) {
+        LOG(WARNING) << "Conversion to low precision of a node: " + 
old_node->attrs.name +

Review comment:
       In the other places where TryLowPrecision is used there is no warning 
message; is that intentional? If not, we could put the log inside the 
TryLowPrecision function.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to