PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824645133



##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool IsCastOp(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& mirror_map) {
-  static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  // Converts the dtype of this NodeEntry. This should be called after a node has been converted and
+  // dtypes of its outputs may have changed
+  void UpdateDTypeAfterConversion(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    CHECK(entry.node->op());
+    dtype = new_dtype;
+  }
+
+  // If dtype of this NodeEntry was not changed, returns the mapped entry. Otherwise returns a
+  // NodeEntry to the node which casts to the original dtype of this NodeEntry.
+  const NodeEntry& AsOriginal() {
+    return AsType(original_dtype);
+  }
+
+  // If dtype of this NodeEntry matches the specified dtype, returns the mapped entry. Otherwise
+  // returns a NodeEntry to the node which casts to that type.
+  const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = true) {
+    if (dtype == target_dtype) {
+      return entry;
     }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      CHECK(can_add_cast);
+      cast_entry = Cast(target_dtype);
+      CHECK(cast_entry.node);
+    }
+    return cast_entry;
+  }
+
+  // Returns whether this NodeEntry has the specified dtype or an existing cast to that dtype
+  bool HasDTypeEntry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  // Returns whether this NodeEntry (of a parameter) can be cast offline
+  bool CanBeCastOfflineTo(const int target_dtype) const {
+    CHECK(entry.node->is_variable());
+    return casts.count(target_dtype) > 0;
+  }
+
+ private:
+  ObjectPtr CreateCastNode(const std::string& op_name, const std::string& node_name) {
+    CHECK_GT(op_name.size(), 0);
+
+    ObjectPtr node   = Node::Create();
+    node->attrs.name = node_name;
+    node->attrs.op   = Op::Get(op_name);
+    node->inputs.emplace_back(entry);
+    return node;
+  }
+
+  NodeEntry Cast(const int new_dtype) {
+    CHECK(new_dtype == nnvm::kBfloat16 || new_dtype == nnvm::kFloat16 ||
+          new_dtype == nnvm::kFloat32);
+
+    const std::string dt_name        = mxnet::op::type_string(new_dtype);
+    const std::string suffix         = "_" + std::to_string(entry.index);
+    const std::string cast_node_name = entry.node->attrs.name + suffix + "_amp_cast_" + dt_name;
+    ObjectPtr cast_node              = CreateCastNode("amp_cast", cast_node_name);
+    cast_node->attrs.dict["dtype"]   = dt_name;
+    cast_node->op()->attr_parser(&(cast_node->attrs));
+    return NodeEntry{std::move(cast_node), 0, 0};
+  }
+
+ public:
+  const NodeEntry entry;
+  const int original_dtype;  // original dtype of the entry
+
+ private:
+  int dtype;  // current dtype of the entry
+  std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t     = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t      = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t     = std::unordered_map<Node*, std::unordered_map<Node*, NodeEntry>>;
+
+// Makes sure the node in the new graph will work with the same precision as in the original graph
+static void KeepOriginalNode(const ObjectPtr& old_node,
+                             const NodeMap_t& node_map,
+                             EntryMap_t* const entry_map) {
+  const ObjectPtr& new_node = node_map.at(old_node.get());
+  for (const auto& old_ne : old_node->inputs) {
+    new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
   }
-  return suffix;
 }
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
+// Tries to convert the node to low precision. Returns whether the node has been successfully
+// converted
+static bool TryLowPrecision(const int target_dtype,
+                            const ObjectPtr& old_node,
+                            const NodeMap_t& node_map,
+                            const NodesEntries_t& nodes_entries,
+                            EntryMap_t* const entry_map) {
+  static auto& infertype      = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+  static auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+  std::vector<int> in_types(old_node->inputs.size(), -1);
+  std::vector<int> out_types(old_node->num_outputs(), -1);
+  in_types[0] = target_dtype;
+
+  if (infertype.count(old_node->op()) == 0 ||
+      infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == false) {
+    return false;
+  }
+
+  if (fmutate_inputs.count(old_node->op()) != 0) {
+    std::vector<uint32_t> mutable_inputs = fmutate_inputs[old_node->op()](old_node->attrs);
+    for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+      if (in_types[i] == target_dtype) {
+        if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != mutable_inputs.end()) {
+          return false;
+        }
+      }
+    }
+  }

Review comment:
       Me too 🤣 I just left this since it was already here. I couldn't find any explanation as to why we have to do this.
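
       For anyone reading along, here is a minimal standalone sketch of what the check above amounts to: look up the op's registered `FMutateInputs` attribute and reject the conversion when an input that would receive the target dtype is reported as mutable. The helper name `InputIsMutable` is hypothetical and not part of this PR; only the `FMutateInputs` lookup mirrors the diff.

       ```cpp
       #include <nnvm/node.h>
       #include <nnvm/op_attr_types.h>
       #include <algorithm>
       #include <vector>

       // Hypothetical helper (illustration only): returns true if input `input_idx` of `node`
       // is a mutable (aux) input, as reported by the op's registered FMutateInputs attribute.
       static bool InputIsMutable(const nnvm::ObjectPtr& node, const size_t input_idx) {
         static const auto& fmutate_inputs = nnvm::Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
         if (fmutate_inputs.count(node->op()) == 0)
           return false;  // op declares no mutable inputs
         const std::vector<uint32_t> mutable_inputs = fmutate_inputs[node->op()](node->attrs);
         return std::find(mutable_inputs.begin(), mutable_inputs.end(), input_idx) != mutable_inputs.end();
       }
       ```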



