zheng-da commented on a change in pull request #11251: [WIP] Graph partitioner and subgraph op
URL: https://github.com/apache/incubator-mxnet/pull/11251#discussion_r196630557
 
 

 ##########
 File path: src/operator/subgraph/partition_graph.cc
 ##########
 @@ -0,0 +1,688 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file partition_graph.cc
+ * \brief
+ */
+#include <queue>
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <mxnet/op_attr_types.h>
+#include <unordered_set>
+#include <stack>
+
+#include "./default_subgraph_op.h"
+#include "./common.h"
+
+namespace nnvm {
+// Forward declaration of nnvm::CreateVariableNode, which is defined elsewhere.
+// NOTE(review): presumably declared locally because the nnvm headers included
+// above do not export it -- confirm before relying on it.
+NodePtr CreateVariableNode(const std::string& name);
+}
+
+namespace mxnet {
+
+namespace op {
+
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// TODO(junwu): Change this to 0
+// When non-zero, the debug printing helpers guarded by `#if SUBGRAPH_DEBUG`
+// (PrintSubgraph, PrintNodeEntry, PrintNodeEntries) are compiled in.
+#define SUBGRAPH_DEBUG 1
+
+namespace sg {  // sg stands for subgraph
+
+#if SUBGRAPH_DEBUG
+/*!
+ * \brief Log the names of the nodes that make up a subgraph, each followed by a space.
+ * \param simple_nodes the nodes of the subgraph
+ */
+void PrintSubgraph(const std::vector<SimpleNode*>& simple_nodes) {
+  std::string names;
+  for (const SimpleNode* sn : simple_nodes) {
+    names += sn->node->attrs.name;
+    names += ' ';
+  }
+  LOG(INFO) << "Subgraph node names: " << names;
+}
+
+/*!
+ * \brief Log the producing node name, output index and version of a node entry.
+ * \param entry the entry to print
+ */
+void PrintNodeEntry(const nnvm::NodeEntry& entry) {
+  std::string msg = "NodeEntry: node_name=";
+  msg += entry.node->attrs.name;
+  msg += ", index=";
+  msg += std::to_string(entry.index);
+  msg += ", version=";
+  msg += std::to_string(entry.version);
+  LOG(INFO) << msg;
+}
+
+/*!
+ * \brief Log every entry in `entries`, one log line per entry, in order.
+ * \param entries pointers to the entries to print
+ */
+void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
+  for (const nnvm::NodeEntry* entry : entries) {
+    PrintNodeEntry(*entry);
+  }
+}
+#endif
+
+/*!
+ * \brief Given an MXNet computational graph, build an undirected view of it.
+ *
+ * One SimpleNode is appended per graph node, in the order DFSVisit visits them.
+ * Producers are looked up by their indexed-graph node id, so this relies on that
+ * visiting order matching the node ids of g.indexed_graph() and on *simple_nodes
+ * being empty on entry (the CHECK_LT below guards the lookup). Each SimpleNode's
+ * `outputs` maps every consumer node to the consumer's input slots fed by it.
+ * \param g the MXNet computational graph
+ * \param simple_nodes output: the undirected-graph nodes in topologically sorted order
+ */
+void CreateSimpleGraph(const Graph& g,
+                       std::vector<SimpleNodePtr>* simple_nodes) {
+  const auto& idx = g.indexed_graph();
+  simple_nodes->reserve(idx.num_nodes());
+  DFSVisit(g.outputs, [&](const NodePtr& node) {
+    SimpleNodePtr snode = SimpleNode::Create();
+    snode->node = node.get();
+    const auto& inputs = snode->node->inputs;
+    for (size_t slot = 0; slot < inputs.size(); ++slot) {
+      const auto producer_id = idx.node_id(inputs[slot].node.get());
+      CHECK_LT(producer_id, simple_nodes->size());
+      // operator[] starts an empty slot list the first time this consumer is seen
+      (*simple_nodes)[producer_id]->outputs[snode->node].push_back(slot);
+    }
+    simple_nodes->emplace_back(std::move(snode));
+  });
+}
+
+/*!
+ * \brief Set the label of every node in *subgraph_nodes back to -1 (not yet
+ * visited) and then empty *subgraph_nodes.
+ * \param g the whole graph, used to map nodes to their node ids
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraph_nodes the nodes whose labels are reset; cleared on return
+ */
+void ResetNodeLabels(const nnvm::Graph& g,
+                     const std::vector<SimpleNodePtr>& simple_nodes,
+                     std::vector<nnvm::Node*>* subgraph_nodes) {
+  const auto& idx = g.indexed_graph();
+  for (const nnvm::Node* n : *subgraph_nodes) {
+    simple_nodes[idx.node_id(n)]->label = -1;
+  }
+  subgraph_nodes->clear();
+}
+
+/*!
+ * \brief Starting from a seed node, traverse the computation graph breadth-first
+ * along both input and output edges and label every reached node that the
+ * subgraph selector accepts with `label`.
+ *
+ * Before returning, look for a neighbor that would close a cycle once the
+ * subgraph is collapsed into a single node: a node reached along a rejected edge
+ * that both feeds some subgraph node and consumes the output of some subgraph
+ * node. If one exists, the subgraph node with the largest node id among those
+ * adjacent to such neighbors is added to excluded_nodes, the labels of all nodes
+ * in *subgraph_nodes are reset to -1, and false is returned so the caller can
+ * retry. Only direct neighbors are examined; a cycle through a longer path of
+ * outside nodes is not detected here.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in topologically sorted order
+ * \param subgraph_nodes output: nodes of the seed's subgraph, sorted by node id
+ *        on success and cleared on failure
+ * \param excluded_nodes nodes that must not join the subgraph. May be nullptr,
+ *        but then a detected cycle cannot be recorded and is a fatal error.
+ * \return true if no cycle was found; false if a node was excluded and the
+ *         caller should retry
+ */
+bool LabelSubgraph(const Graph& g,
+                   SubgraphSelectorPtr subgraph_selector,
+                   const int label,
+                   const size_t snid,  // simple node id, this is a seed
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<nnvm::Node*>* subgraph_nodes,
+                   std::unordered_set<const nnvm::Node*>* excluded_nodes = nullptr) {
+  const auto& indexed_graph = g.indexed_graph();
+  auto is_excluded = [&](const nnvm::Node* n) {
+    return excluded_nodes != nullptr && excluded_nodes->count(n) > 0;
+  };
+  auto node_id_of = [&](const nnvm::Node* n) {
+    return static_cast<int>(indexed_graph.node_id(n));
+  };
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  std::queue<SimpleNode*> node_queue;
+  if (!is_excluded(simple_nodes[snid]->node)) {
+    CHECK_EQ(simple_nodes[snid]->label, -1);
+    simple_nodes[snid]->label = label;
+    node_queue.push(simple_nodes[snid].get());
+  }
+  // key: a neighbor of the subgraph reached along an edge that was rejected
+  //      (by excluded_nodes or by the selector).
+  // value.first: subgraph nodes that consume an output of the key.
+  // value.second: subgraph nodes whose output the key consumes.
+  // If both vectors are non-empty, collapsing the subgraph into a single node
+  // would create a cycle through the key.
+  std::unordered_map<const nnvm::Node*,
+    std::pair<std::vector<const nnvm::Node*>,
+              std::vector<const nnvm::Node*>>> non_subgraph_node_map;
+  while (!node_queue.empty()) {
+    SimpleNode* cur_node = node_queue.front();
+    node_queue.pop();
+    subgraph_nodes->push_back(cur_node->node);
+    // visit adjacent input nodes
+    for (auto& e : cur_node->node->inputs) {
+      const bool select_input = !is_excluded(e.node.get())
+        && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+      if (select_input) {
+        // e.node joins the subgraph unless it has already been labeled
+        const auto nid = indexed_graph.node_id(e.node.get());
+        CHECK_LT(nid, simple_nodes.size());
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // e.node feeds the subgraph from outside
+        non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
+      }
+    }
+    // visit adjacent output nodes
+    for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
+      const bool select_output = !is_excluded(it->first)
+        && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+      if (select_output) {
+        // it->first joins the subgraph unless it has already been labeled
+        const auto nid = indexed_graph.node_id(it->first);
+        CHECK_LT(nid, simple_nodes.size());
+        if (simple_nodes[nid]->label == -1) {
+          simple_nodes[nid]->label = label;
+          node_queue.push(simple_nodes[nid].get());
+        }
+      } else {
+        // it->first consumes the subgraph's output from outside
+        non_subgraph_node_map[it->first].second.push_back(cur_node->node);
+      }
+    }
+  }
+  // Among subgraph nodes adjacent to a cycle-forming neighbor, pick the one with
+  // the largest node id; it is the node excluded on the next attempt.
+  int excluded_node_id = -1;
+  for (const auto& kv : non_subgraph_node_map) {
+    const auto& consumers = kv.second.first;
+    const auto& producers = kv.second.second;
+    if (consumers.empty() || producers.empty()) {
+      continue;  // no cycle through this neighbor
+    }
+    for (const nnvm::Node* n : consumers) {
+      excluded_node_id = std::max(excluded_node_id, node_id_of(n));
+    }
+    for (const nnvm::Node* n : producers) {
+      excluded_node_id = std::max(excluded_node_id, node_id_of(n));
+    }
+  }
+  if (excluded_node_id != -1) {
+    CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
+    CHECK_NE(excluded_node_id, static_cast<int>(snid))
+      << "A cycle is found in the computational graph between nodes "
+      << simple_nodes[excluded_node_id]->node->attrs.name << " and "
+      << simple_nodes[snid]->node->attrs.name;
+    // Recording the node that breaks the cycle requires an excluded set;
+    // fail loudly instead of dereferencing the default nullptr.
+    CHECK(excluded_nodes != nullptr)
+      << "A cycle is found between the subgraph seeded at node "
+      << simple_nodes[snid]->node->attrs.name
+      << " and its neighbors, but no excluded_nodes set was provided";
+    excluded_nodes->insert(simple_nodes[excluded_node_id]->node);
+    ResetNodeLabels(g, simple_nodes, subgraph_nodes);
+    return false;
+  }
+  std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), node_cmp);
+  return true;
+}
+
+/*!
+ * \brief Finds all the nodes belonging to the same subgraph given a seed node.
+ *
+ * Calls LabelSubgraph repeatedly. Each failed attempt adds one node to the
+ * excluded set and retries. If no attempt succeeds within the retry budget, the
+ * seed node alone is labeled and returned as a one-node subgraph.
+ * \param g the whole graph
+ * \param subgraph_selector determines whether a visited node should be chosen or not
+ * \param label the label of the current subgraph
+ * \param snid node id of the seed simple node
+ * \param simple_nodes all simple nodes in topologically sorted order
+ * \param subgraph_nodes output: subgraph node candidates sorted in topological
+ *        (node id) order
+ */
+void PreSelectSubgraphNodes(const Graph& g,
+                            SubgraphSelectorPtr subgraph_selector,
+                            const int label,
+                            const size_t snid,
+                            const std::vector<SimpleNodePtr>& simple_nodes,
+                            std::vector<nnvm::Node*>* subgraph_nodes) {
+  std::unordered_set<const nnvm::Node*> excluded_nodes;
+  const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
+  size_t count = 0;
+  bool success = false;
+  while (!success && count < max_num_retry) {
+    success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes,
+                            subgraph_nodes, &excluded_nodes);
+    if (!success) {
+      CHECK(!excluded_nodes.empty());
+      std::string excluded_node_names;
+      for (auto node : excluded_nodes) {
+        excluded_node_names += node->attrs.name + ", ";
+      }
+      LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
+                << ". Excluding nodes " << excluded_node_names << "and retrying";
+    }
+    ++count;
+  }
+  if (!success) {
+    LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
+              << simple_nodes[snid]->node->attrs.name << " without success because a loop "
+                 "is always found between the subgraph and some other nodes. Will treat "
+                 "seed node " << simple_nodes[snid]->node->attrs.name
+              << " as a subgraph with one node";
+    // LabelSubgraph clears subgraph_nodes on every failed attempt.
+    CHECK(subgraph_nodes->empty());
+    simple_nodes[snid]->label = label;
+    subgraph_nodes->push_back(simple_nodes[snid]->node);
+  }
+}
+
+/*!
+ * \brief Given a vector of nodes, group them into individual subgraphs
+ * based upon their connectivity.
+ *
+ * Two candidates are connected when one is a direct input or output of the
+ * other, so candidates with no connecting path through other candidates end up
+ * in different subgraphs. Every node of a group is labeled with the current
+ * *subgraph_id, which is then advanced by one. Each group is appended to
+ * *subgraphs, sorted by node id.
+ * \param g the whole graph
+ * \param nodes the candidate nodes to group
+ * \param simple_nodes all simple nodes, indexed by node id
+ * \param subgraphs output: one vector of simple nodes per connected group
+ * \param subgraph_id in/out: label of the next group; incremented once per group
+ */
+void PostProcessNodeCandidates(const nnvm::Graph& g,
+                               const std::vector<nnvm::Node*>& nodes,
+                               const std::vector<SimpleNodePtr>& simple_nodes,
+                               std::vector<std::vector<SimpleNode*>>* subgraphs,
+                               size_t* subgraph_id) {
+  const auto& idx = g.indexed_graph();
+  // candidates that have not been placed in a group yet
+  std::unordered_set<nnvm::Node*> unassigned(nodes.begin(), nodes.end());
+  auto by_node_id = [&](const SimpleNode* a, const SimpleNode* b) {
+    return idx.node_id(a->node) < idx.node_id(b->node);
+  };
+  for (nnvm::Node* seed : nodes) {
+    if (unassigned.count(seed) == 0) {
+      continue;  // already placed in an earlier group
+    }
+    subgraphs->emplace_back();
+    std::vector<SimpleNode*>& group = subgraphs->back();
+    std::queue<nnvm::Node*> frontier;
+    // Label n, add it to the current group and schedule its neighbors for a visit.
+    auto claim = [&](nnvm::Node* n) {
+      SimpleNode* sn = simple_nodes[idx.node_id(n)].get();
+      sn->label = *subgraph_id;
+      group.push_back(sn);
+      frontier.push(n);
+    };
+    CHECK_EQ(unassigned.erase(seed), 1U);
+    claim(seed);
+    while (!frontier.empty()) {
+      nnvm::Node* cur = frontier.front();
+      frontier.pop();
+      for (const auto& entry : cur->inputs) {
+        nnvm::Node* producer = entry.node.get();
+        if (unassigned.erase(producer) != 0) {
+          claim(producer);
+        }
+      }
+      for (const auto& kv : simple_nodes[idx.node_id(cur)]->outputs) {
+        if (unassigned.erase(kv.first) != 0) {
+          claim(kv.first);
+        }
+      }
+    }
+    std::sort(group.begin(), group.end(), by_node_id);
+    ++(*subgraph_id);
+  }
+  CHECK(unassigned.empty());
+}
+
+/*!
+ * \brief Finds subgraphs with all nodes that meet certain criteria.
+ * All nodes in a subgraph are marked with the same label.
+ */
+void FindSubgraphs(Graph* g,
+                   const SubgraphProperty &subg_prop,
+                   const std::vector<SimpleNodePtr>& simple_nodes,
+                   std::vector<std::vector<SimpleNode*>>* subgraph_nodes) {
+  //CHECK(simple_nodes != nullptr);
+  const auto& indexed_graph = g->indexed_graph();
+  CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
+  auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
+    return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
+  };
+  size_t subgraph_id = 0;
+  for (size_t i = 0; i < simple_nodes.size(); ++i) {
+    nnvm::Node* node = simple_nodes[i]->node;
+    auto subgraph_selector = subg_prop.CreateSubgraphSelector();
+    if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+      // pre-select nodes that can be grouped in a subgraph
+      std::vector<nnvm::Node*> preselected_nodes;
+      PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i, 
simple_nodes,
+                             &preselected_nodes);
+
+      // filter out unqualified pre-selected nodes
+      std::vector<nnvm::Node*> filtered_nodes = preselected_nodes;
+      subgraph_selector->Filter(g, &filtered_nodes);
+
+      // make sure filtered_nodes is a subset of preselected_nodes 
+      for (const auto n : filtered_nodes) {
+        const auto nit = std::find(preselected_nodes.begin(), 
preselected_nodes.end(), n);
+        CHECK(nit != preselected_nodes.end())
+          << "Node " << n->attrs.name << " is not found in the pre-selected 
subgraph nodes."
+             " Please make sure that no new nodes were added in your subgraph"
+             " selector's Filter function";
+      }
+
+      // make sure nodes are sorted
+      std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
+
+      // reset node labels that are not in filtered nodes
+      for (const auto n : preselected_nodes) {
+        const auto nit = std::find(filtered_nodes.begin(), 
filtered_nodes.end(), n);
+        if (nit == filtered_nodes.end()) {
+          simple_nodes[indexed_graph.node_id(n)]->label = -1;
+        }
+      }
+      // find out subgraphs from the filtered nodes
+      std::vector<std::vector<SimpleNode*>> subgraphs;
+      PostProcessNodeCandidates(*g, filtered_nodes, simple_nodes, &subgraphs, 
&subgraph_id);
 
 Review comment:
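   Should we require that the nodes returned from Filter belong to a single subgraph? Otherwise PostProcessNodeCandidates silently splits them into several subgraphs, one per connected group.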

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to