reminisce closed pull request #11907: [DO NOT REVIEW] Fix bug of eliminating
cycles
URL: https://github.com/apache/incubator-mxnet/pull/11907
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/src/operator/subgraph/partition_graph.cc
b/src/operator/subgraph/partition_graph.cc
index 0546430dc08..9672877eb1d 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -147,9 +147,12 @@ bool LabelSubgraph(const Graph& g,
// key: nodes that serve as input/output nodes to the subgraph
// value: pair of vectors of nodes in the subgraph. The first vector
contains the
// output nodes of the key in the subgraph, and the second vector contains
the
- // input ndoes of the key in the subgraph. If both vectors are non-empty,
- // it means there is a loop between the subgraph and the key node.
- // When breaking the loop, we want to start removing the node with the
largest node id.
+ // input nodes of the key in the subgraph.
+ // If a non-subgraph node has inputs from the subgraph and the other
non-subgraph node
+ // has outputs to the subgraph, and the first non-subgraph node is an
ancestor
+ // of the second non-subgraph node, there exists a cycle.
+ // When breaking the cycle, we want to start by removing the node with the
largest node id
+ // in the subgraph.
std::unordered_map<const nnvm::Node*,
std::pair<std::vector<const nnvm::Node*>,
std::vector<const nnvm::Node*>>> non_subgraph_node_map;
@@ -194,23 +197,75 @@ bool LabelSubgraph(const Graph& g,
}
}
}
+ // prepare to check if there is a cycle
auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
};
- // check whether there is a loop between the subgraph and its input/output
nodes
- int excluded_node_id = -1;
+ std::vector<const nnvm::Node*> non_subgraph_nodes;
+ non_subgraph_nodes.reserve(non_subgraph_node_map.size());
for (auto& kv : non_subgraph_node_map) {
auto& output_nodes = kv.second.first;
+ std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
auto& input_nodes = kv.second.second;
+ std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+ non_subgraph_nodes.push_back(kv.first);
+ }
+ // check whether there is a cycle between the subgraph and its input/output
nodes
+ auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node*
descendant,
+ const std::vector<nnvm::Node*>& snodes) {
+ if (ancestor == descendant) return true;
+ std::stack<const nnvm::Node*> s;
+ s.push(descendant);
+ size_t count = 0;
+ while (!s.empty()) {
+ CHECK_LT(count, indexed_graph.num_nodes()) << "Finding ancestor failed.
There is probably"
+ " a loop in the graph";
+ ++count;
+ const nnvm::Node* top = s.top();
+ s.pop();
+ if (top == ancestor) {
+ return true;
+ }
+ for (const auto& entry : top->inputs) {
+ // when searching for the ancestor, the path cannot cross any subgraph
node
+ auto it = std::find(snodes.begin(), snodes.end(), entry.node.get());
+ if (it == snodes.end()) {
+ s.push(entry.node.get());
+ }
+ }
+ }
+ return false;
+ };
+ std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
+ int excluded_node_id = -1;
+ for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
+ auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
+ CHECK(it1 != non_subgraph_node_map.end());
+ auto& output_nodes = it1->second.first; // has been top sorted
+ auto& input_nodes = it1->second.second; // has been top sorted
if (!output_nodes.empty() && !input_nodes.empty()) {
- // there is a loop between kv->first and the subgraph
- std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
- std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
+ // there is a loop between node i and the subgraph
const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
indexed_graph.node_id(input_nodes.back()));
excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
+ } else if (!input_nodes.empty()) {
+ // node i is an input to the subgraph, find out if there is a node j
+ // which is an output of the subgraph and also a child of node i.
+ for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
+ auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
+ CHECK(it2 != non_subgraph_node_map.end());
+ // i is topologically before j, j might be a direct/indirect output
node of i
+ CHECK_LT(indexed_graph.node_id(it1->first),
indexed_graph.node_id(it2->first));
+ if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first,
*subgraph_nodes)) {
+ // found a loop
+ const auto node_id =
std::max(indexed_graph.node_id(input_nodes.back()),
+
indexed_graph.node_id(it2->second.first.back()));
+ excluded_node_id = std::max(excluded_node_id,
static_cast<int>(node_id));
+ }
+ }
}
}
+
if (excluded_node_id != -1) {
CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
CHECK_NE(excluded_node_id, static_cast<int>(snid))
diff --git a/tests/python/unittest/test_subgraph_op.py
b/tests/python/unittest/test_subgraph_op.py
index c3b408b9027..f6a33c244a7 100644
--- a/tests/python/unittest/test_subgraph_op.py
+++ b/tests/python/unittest/test_subgraph_op.py
@@ -114,12 +114,23 @@ def get_graph():
for sym, op_names in get_graph():
check_subgraph_exe(sym, op_names)
+ def test_network_structure_7():
+ # in this graph, the subgraph node and the other two external nodes
form a cycle
+ data = mx.sym.Variable('data', shape=(1,))
+ ret1 = mx.sym.sin(data)
+ ret2 = mx.sym.cos(ret1)
+ for _ in range(5):
+ ret2 = mx.sym.cos(ret2)
+ ret = ret1 + ret2
+ check_subgraph_exe(ret, ['sin', 'elemwise_add', '_plus', '_Plus'])
+
test_network_structure_1()
test_network_structure_2()
test_network_structure_3()
test_network_structure_4()
test_network_structure_5()
test_network_structure_6()
+ test_network_structure_7()
if __name__ == '__main__':
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services