PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r820851332



##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,322 @@
 #include <mxnet/base.h>
 #include <algorithm>
 #include <functional>
+#include "../operator/operator_common.h"
 
 namespace mxnet {
 using nnvm::Graph;
 using nnvm::Node;
 using nnvm::NodeEntry;
 using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
-  ObjectPtr node   = Node::Create();
-  node->attrs.name = node_name;
-  if (op_name == "nullptr") {
-    node->attrs.op = nullptr;
-    // ugly workaround because VariableParam is not exposed
-    node->attrs.parsed =
-        
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
-  } else {
-    node->attrs.op = Op::Get(op_name);
-  }
-  return node;
-}
 
-static ObjectPtr InsertNode(std::string op_name,
-                            std::string node_name,
-                            ObjectPtr current,
-                            NodeEntry previous) {
-  ObjectPtr node = CreateNode(op_name, node_name);
-  node->inputs.emplace_back(previous);
-  if (current)
-    current->inputs.emplace_back(NodeEntry{node, 0, 0});
-  return node;
+bool is_cast_op(const nnvm::Op* const op) {
+  return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
 }
 
-// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
-                             const std::unordered_map<Node*, ObjectPtr>& 
mirror_map) {
-  static const auto& flist_outputs = 
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
-  std::string suffix               = "";
-  ObjectPtr mirror_node            = mirror_map.at(node_entry.node.get());
-  if (mirror_node->op() != nullptr) {
-    auto list_output_names_func = flist_outputs.get(node_entry.node->op(), 
nullptr);
-    if (list_output_names_func != nullptr) {
-      std::vector<std::string> names = 
list_output_names_func(node_entry.node->attrs);
-      suffix                         = "_" + names[node_entry.index];
-    } else {
-      suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
+ public:
+  MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+      : entry(std::move(node_entry)), original_dtype(original_dtype) {
+    dtype = original_dtype;
+  }
+
+  void convert(const int new_dtype) {
+    CHECK_EQ(dtype, original_dtype);  // dtype should be changed only once
+    dtype = new_dtype;
+  }
+
+  const NodeEntry& as_original() {
+    return as_type(original_dtype);
+  }
+
+  const NodeEntry& as_type(const int target_dtype) {
+    if (dtype == target_dtype) {
+      return entry;
+    }
+    NodeEntry& cast_entry = casts[target_dtype];
+    if (cast_entry.node == nullptr) {
+      cast_entry = cast(target_dtype);
+      CHECK(cast_entry.node);
     }
+    return cast_entry;
   }
-  return suffix;
-}
 
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
-                        const std::string& suffix,
-                        const nnvm::NodeEntry& input,
-                        const std::string dtype,
-                        nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
-                        ObjectPtr curr_node) {
-  ObjectPtr cast_node =
-      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + 
dtype, curr_node, input);
-  cast_node->attrs.dict["dtype"] = dtype;
-  cast_node->op()->attr_parser(&(cast_node->attrs));
-  (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
-  return;
-}
+  bool has_dtype_entry(const int target_dtype) const {
+    return dtype == target_dtype || casts.count(target_dtype) > 0;
+  }
+
+  bool can_be_cast_offline_to(const int target_dtype) const {

Review comment:
       `casts` is an unordered map, so `count` returns whether a cast to 
this dtype exists or not. A parameter should be cast offline only if there is a 
need for it, i.e. at least one operator uses it through a cast.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to