PawelGlomski-Intel commented on a change in pull request #20753:
URL: https://github.com/apache/incubator-mxnet/pull/20753#discussion_r836275200
##########
File path: src/nnvm/low_precision_pass.cc
##########
@@ -29,374 +29,371 @@
#include <mxnet/base.h>
#include <algorithm>
#include <functional>
+#include "operator/operator_common.h"
namespace mxnet {
using nnvm::Graph;
using nnvm::Node;
using nnvm::NodeEntry;
using nnvm::ObjectPtr;
-using nnvm::Symbol;
-
-// create a node for operator : op_name with name : node_name
-static ObjectPtr CreateNode(std::string op_name, std::string node_name) {
- ObjectPtr node = Node::Create();
- node->attrs.name = node_name;
- if (op_name == "nullptr") {
- node->attrs.op = nullptr;
- // ugly workaround because VariableParam is not exposed
- node->attrs.parsed =
- nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed;
- } else {
- node->attrs.op = Op::Get(op_name);
- }
- return node;
-}
-static ObjectPtr InsertNode(std::string op_name,
- std::string node_name,
- ObjectPtr current,
- NodeEntry previous) {
- ObjectPtr node = CreateNode(op_name, node_name);
- node->inputs.emplace_back(previous);
- if (current)
- current->inputs.emplace_back(NodeEntry{node, 0, 0});
- return node;
+bool IsCastOp(const nnvm::Op* const op) {
+ return op && (op == Op::Get("amp_cast") || op == Op::Get("Cast"));
}
+/*!
+ * \brief Before the model conversion, node entries of the original graph are mapped to the
+ * equivalent node entries in the new graph, which will then be converted to a mixed precision
+ * graph. This class wraps a mapped NodeEntry from the new graph, providing a transparent
+ * interface for acquiring versions of the wrapped entry with a specific dtype, adding casting
+ * nodes to the graph when needed (one for each unique dtype that was requested).
+ */
+class MappedNodeEntry {
+ public:
+ MappedNodeEntry(NodeEntry node_entry, const int original_dtype)
+ : entry(std::move(node_entry)), original_dtype(original_dtype) {
+ dtype = original_dtype;
+ }
-// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name
-static std::string GetSuffix(const nnvm::NodeEntry& node_entry,
- const std::unordered_map<Node*, ObjectPtr>& mirror_map) {
- static const auto& flist_outputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
- std::string suffix = "";
- ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
- if (mirror_node->op() != nullptr) {
- auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr);
- if (list_output_names_func != nullptr) {
- std::vector<std::string> names = list_output_names_func(node_entry.node->attrs);
- suffix = "_" + names[node_entry.index];
- } else {
- suffix = "_" + std::to_string(node_entry.index);
+ /*!
+ * \brief Converts the dtype of this NodeEntry. This should be called after a node has been
+ * converted and dtypes of its outputs may have changed.
+ */
+ void UpdateDTypeAfterConversion(const int new_dtype) {
+ CHECK_EQ(dtype, original_dtype); // dtype should be changed only once
+ CHECK(entry.node->op());
+ dtype = new_dtype;
+ }
+
+ /*!
+ * \brief If dtype of this NodeEntry was not changed, returns the mapped entry. Otherwise returns
+ * a NodeEntry to the node which casts to the original dtype of this NodeEntry.
+ */
+ const NodeEntry& AsOriginal() {
+ return AsType(original_dtype);
+ }
+
+ /*!
+ * \brief If dtype of this NodeEntry matches the specified dtype, returns the mapped entry.
+ * Otherwise returns a NodeEntry to the node which casts to that type.
+ */
+ const NodeEntry& AsType(const int target_dtype, const bool can_add_cast = true) {
+ if (dtype == target_dtype || target_dtype == -1) {
+ return entry;
+ }
+ NodeEntry& cast_entry = casts[target_dtype];
+ if (cast_entry.node == nullptr) {
+ CHECK(can_add_cast);
+ cast_entry = Cast(target_dtype);
+ CHECK(cast_entry.node);
}
+ return cast_entry;
+ }
+
+ /*! \brief Returns whether this entry has the specified dtype or an existing cast to that dtype */
+ bool HasDTypeEntry(const int target_dtype) const {
+ return dtype == target_dtype || casts.count(target_dtype) > 0;
+ }
+
+ /*!
+ * \brief Returns whether this entry can be cast to a specific dtype. This should be called on
+ * input entries of a node before its conversion.
+ */
+ bool CanBeCastTo(const int target_dtype) {
+ static const auto& amp_cast_op = Op::Get("amp_cast");
+ static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType")[amp_cast_op];
+ nnvm::NodeAttrs dummy_atts;
+ dummy_atts.dict["dtype"] = mxnet::op::type_string(target_dtype);
+ amp_cast_op->attr_parser(&dummy_atts);
+
+ std::vector<int> in_types = {dtype};
+ std::vector<int> out_types = {-1};
+ return infertype(dummy_atts, &in_types, &out_types);
+ }
+
+ /*! \brief Returns whether this NodeEntry (of a parameter) can be cast offline */
+ bool CanBeCastOfflineTo(const int target_dtype) const {
+ CHECK(entry.node->is_variable());
+ return casts.count(target_dtype) > 0;
+ }
+
+ private:
+ ObjectPtr CreateCastNode(const std::string& op_name, const std::string& node_name) {
+ CHECK_GT(op_name.size(), 0);
+
+ ObjectPtr node = Node::Create();
+ node->attrs.name = node_name;
+ node->attrs.op = Op::Get(op_name);
+ node->inputs.emplace_back(entry);
+ return node;
+ }
+
+ NodeEntry Cast(const int target_dtype) {
+ CHECK(CanBeCastTo(target_dtype));
+
+ const std::string dt_name = mxnet::op::type_string(target_dtype);
+ const std::string suffix = "_" + std::to_string(entry.index);
+ const std::string cast_node_name = entry.node->attrs.name + suffix + "_amp_cast_" + dt_name;
+ ObjectPtr cast_node = CreateCastNode("amp_cast", cast_node_name);
+ cast_node->attrs.dict["dtype"] = dt_name;
+ cast_node->op()->attr_parser(&(cast_node->attrs));
+ return NodeEntry{std::move(cast_node), 0, 0};
+ }
+
+ public:
+ const NodeEntry entry;
+ const int original_dtype; // original dtype of the entry
+
+ private:
+ int dtype; // current dtype of the entry
+ std::unordered_map<int, NodeEntry> casts;
+};
+
+using EntryMap_t = nnvm::NodeEntryMap<MappedNodeEntry>;
+using NodeMap_t = std::unordered_map<Node*, ObjectPtr>;
+using NodeEntrySet_t = std::unordered_set<NodeEntry, nnvm::NodeEntryHash, nnvm::NodeEntryEqual>;
+using NodesEntries_t = std::unordered_map<Node*, NodeEntrySet_t>;
+using DstNodes_t = std::unordered_map<Node*, std::unordered_map<Node*, NodeEntry>>;
+
+/*! \brief Makes sure the node in the new graph will work with the same precision as in the original
+ * graph */
+static void KeepOriginalNode(const ObjectPtr& old_node,
+ const NodeMap_t& node_map,
+ EntryMap_t* const entry_map) {
+ const ObjectPtr& new_node = node_map.at(old_node.get());
+ for (const auto& old_ne : old_node->inputs) {
+ new_node->inputs.push_back(entry_map->at(old_ne).AsOriginal());
}
- return suffix;
}
-// add amp_cast node between curr_node and input
-static void AddCastNode(const nnvm::NodeEntry& e,
- const std::string& suffix,
- const nnvm::NodeEntry& input,
- const std::string dtype,
- nnvm::NodeEntryMap<NodeEntry>* mirror_entry_map,
- ObjectPtr curr_node) {
- ObjectPtr cast_node =
- InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" +
dtype, curr_node, input);
- cast_node->attrs.dict["dtype"] = dtype;
- cast_node->op()->attr_parser(&(cast_node->attrs));
- (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version};
- return;
+/*! \brief Tries to convert the node to low precision. Returns whether the node has been
+ * successfully converted
+ */
+static bool TryLowPrecision(const int target_dtype,
+ const ObjectPtr& old_node,
+ const NodeMap_t& node_map,
+ const NodesEntries_t& nodes_entries,
+ EntryMap_t* const entry_map) {
+ static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+ static const auto& fmutate_inputs = Op::GetAttr<nnvm::FMutateInputs>("FMutateInputs");
+
+ std::vector<int> in_types(old_node->inputs.size(), -1);
+ std::vector<int> out_types(old_node->num_outputs(), -1);
+ in_types[0] = target_dtype;
+ if (infertype.count(old_node->op()) == 0 ||
+ infertype[old_node->op()](old_node->attrs, &in_types, &out_types) == false) {
+ return false;
+ }
+
+ if (fmutate_inputs.count(old_node->op()) != 0) {
+ std::vector<uint32_t> mutable_inputs = fmutate_inputs[old_node->op()](old_node->attrs);
+ for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+ if (in_types[i] == target_dtype) {
+ if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != mutable_inputs.end()) {
+ return false;
+ }
+ }
+ }
+ }
+
+ for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+ MappedNodeEntry& mapped_ne = entry_map->at(old_node->inputs[i]);
+ // if this tensor needs a cast, check whether MappedNodeEntry can actually cast it
+ if (in_types[i] != -1 && !mapped_ne.HasDTypeEntry(in_types[i]) &&
+ !mapped_ne.CanBeCastTo(in_types[i])) {
+ return false;
+ }
+ }
+
+ const ObjectPtr& new_node = node_map.at(old_node.get());
+ for (size_t i = 0; i < old_node->inputs.size(); ++i) {
+ new_node->inputs.push_back(entry_map->at(old_node->inputs[i]).AsType(in_types[i]));
+ }
+
+ for (const NodeEntry& old_ne : nodes_entries.at(old_node.get())) {
+ entry_map->at(old_ne).UpdateDTypeAfterConversion(out_types[old_ne.index]);
+ }
+
+ return true;
}
-// add amp_multicast node between curr_node and inputs
-static void AddMultiCastNode(const std::vector<NodeEntry>& inputs,
- const std::string& node_name,
- const std::unordered_map<Node*, ObjectPtr>& mirror_map,
- ObjectPtr curr_node) {
- ObjectPtr node =
- CreateNode("amp_multicast", inputs[0].node->attrs.name + node_name +
"_amp_multicast");
- for (const auto& node_entry : inputs) {
- ObjectPtr mirror_node = mirror_map.at(node_entry.node.get());
- NodeEntry mirror_entry =
- NodeEntry{std::move(mirror_node), node_entry.index, node_entry.version};
- node->inputs.emplace_back(mirror_entry);
+/*! \brief Tries to convert the node to low precision if all of its inputs already have the correct
+ * dtype. Otherwise keeps the node unchanged.
+ */
+static void HandleWidestDtypeNode(const int target_dtype,
+ const ObjectPtr& old_node,
+ const NodeMap_t& node_map,
+ const NodesEntries_t& nodes_entries,
+ EntryMap_t* const entry_map) {
+ static const auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
+
+ std::vector<int> in_types(old_node->inputs.size(), target_dtype);
+ std::vector<int> out_types(old_node->num_outputs(), -1);
+ const bool inferred = (infertype.count(old_node->op()) > 0 &&
+ infertype[old_node->op()](old_node->attrs, &in_types, &out_types));
+
+ bool has_lp_inputs = inferred;
+ for (int i = 0; has_lp_inputs && i < old_node->inputs.size(); ++i) {
+ const NodeEntry& input = old_node->inputs[i];
+ has_lp_inputs &= entry_map->at(input).HasDTypeEntry(in_types[i]);
+ }
+
+ if (!has_lp_inputs ||
+ !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+ KeepOriginalNode(old_node, node_map, entry_map);
}
- node->attrs.dict["num_outputs"] = std::to_string(inputs.size());
- node->op()->attr_parser(&(node->attrs));
- for (uint32_t i = 0; i < inputs.size(); ++i) {
- const auto& e = inputs[i];
- curr_node->inputs.emplace_back(NodeEntry{node, static_cast<uint32_t>(i), e.version});
+}
+/*!
+ * \brief Tries to convert the node to low precision if some of its inputs are already converted.
+ * Otherwise keeps the node unchanged.
+ */
+void HandleDTypeNeutralNode(const int target_dtype,
+ const ObjectPtr& old_node,
+ const NodeMap_t& node_map,
+ const NodesEntries_t& nodes_entries,
+ EntryMap_t* const entry_map) {
+ const auto& is_lp = [&](const auto& old_ne) {
+ return entry_map->at(old_ne).HasDTypeEntry(target_dtype);
+ };
+ if (!std::any_of(old_node->inputs.begin(), old_node->inputs.end(), is_lp) ||
+ !TryLowPrecision(target_dtype, old_node, node_map, nodes_entries, entry_map)) {
+ KeepOriginalNode(old_node, node_map, entry_map);
}
- return;
}
-static bool CheckConditionalFP32(
- const std::unordered_map<std::string,
- std::unordered_map<std::string, std::vector<std::string>>>&
- conditional_fp32_ops,
- const std::unordered_set<std::string>& excluded_syms,
- ObjectPtr node) {
- if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) ||
- conditional_fp32_ops.count(node->op()->name) == 0) {
- return false;
- } else {
- // Iterate through all conditional ops
- auto it = conditional_fp32_ops.find(node->op()->name);
- if (it != conditional_fp32_ops.end()) {
- auto it_params = it->second;
- // For each param name, iterate through param values to check
- // if the provided param name is equal to any of the values
- for (auto& it_param : it_params) {
- auto param_key = node->attrs.dict.find(it_param.first);
- if (param_key != node->attrs.dict.end()) {
- auto it_param_vals = it_param.second;
- if (std::find(it_param_vals.begin(), it_param_vals.end(), param_key->second) !=
- it_param_vals.end()) {
- return true;
+/* \brief Decides which parameters can be cast offline and removes redundant cast nodes from the
+ * graph */
+static void RemoveParamCasts(const int target_dtype,
+ const std::string& offline_param_cast_attr,
+ const NodeMap_t& node_map,
+ const DstNodes_t& old_param_dst_nodes,
+ EntryMap_t* entry_map) {
+ for (const auto& [old_param, old_param_dsts] : old_param_dst_nodes) {
+ const ObjectPtr& new_param = node_map.at(old_param);
+ const auto& can_be_cast_offline = [&](const auto& old_node_x_ne_pair) {
Review comment:
`old_node_x_ne_pair` is an instance of `std::pair<Node*, NodeEntry>`. I will try to improve this part.
> i would add more indent in lambda body
I don't think I can - the formatter picks the indent.
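For illustration only (this is not code from the PR, and the lambda body is not shown in the diff above), one hedged sketch of what "improving this part" could look like is keeping the lambda parameter as the map's value_type but naming its members with a structured binding, mirroring the range-for loop above it. The names `dst`, `dst_node`, and `dst_entry` are hypothetical:

```cpp
// Hypothetical sketch: descriptive names for the pair members instead of the
// opaque `old_node_x_ne_pair`. The real predicate body is elided in the diff,
// so only a placeholder return is shown here.
const auto& can_be_cast_offline = [&](const auto& dst) {
  // dst is the map's value_type, std::pair<Node* const, NodeEntry>
  const auto& [dst_node, dst_entry] = dst;  // destination node and the entry feeding it
  // ... the actual offline-cast check over dst_entry would go here ...
  return dst_node != nullptr && dst_entry.node != nullptr;  // placeholder only, not the PR's logic
};
```

Whether this reads better than the explicit pair access is a judgment call the author may resolve differently.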