PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r824646210
##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,328 @@
#include <mxnet/base.h>
#include <algorithm>
#include <functional>
+#include "../operator/operator_common.h"
namespace mxnet {
using nnvm::Graph;
using nnvm::Node;
using nnvm::NodeEntry;
using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
- ObjectPtr node = Node::Create();
- node->attrs.name = node_name;
- if (op_name == "nullptr") {
- node->attrs.op = nullptr;
- // ugly workaround because VariableParam is not exposed
- node->attrs.parsed =
-
nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
- } else {
- node->attrs.op = Op::Get(op_name);
- }
- return node;
-}
-static ObjectPtr InsertNode(std::string op_name,
- std::string node_name,
- ObjectPtr current,
- NodeEntry previous) {
- ObjectPtr node = CreateNode(op_name, node_name);
- node->inputs.emplace_back(previous);
- if (current)
- current->inputs.emplace_back(NodeEntry{node, 0, 0});
- return node;
+bool IsCastOp(const nnvm::Op* const op) {
+ return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
}
-// get suffix for a node entry so that it can be used for
amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
- const std::unordered_map<Node*, ObjectPtr>&
mirror_map) {
- static const auto& flist_outputs =
nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
- std::string suffix = "";
- ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
- if (mirror_node->op() != nullptr) {
- auto list_output_names_func = flist_outputs.get(node_entry.node->op(),
nullptr);
- if (list_output_names_func != nullptr) {
- std::vector<std::string> names =
list_output_names_func(node_entry.node->attrs);
- suffix = "_" + names[node_entry.index];
- } else {
- suffix = "_" + std::to_string(node_entry.index);
+class MappedNodeEntry {
Review comment:
Can you check if the description is clear to you?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]