mseth10 commented on a change in pull request #18779:
URL: https://github.com/apache/incubator-mxnet/pull/18779#discussion_r470249438



##########
File path: include/mxnet/lib_api.h
##########
@@ -748,62 +759,410 @@ struct JsonParser {
     return JsonVal();
   }
   // generic parse function
-  JsonVal parse(const std::string& json, unsigned int *idx) {
+  static JsonVal parse(const std::string& json, unsigned int *idx) {
     JsonVal ret;
     while (*idx < json.size()) {
       if (json[*idx] == '"') {
         ++(*idx);
-        ret = parse_string(json, idx);
+        ret = JsonVal::parse_string(json, idx);
       } else if (json[*idx] >= '0' && json[*idx] <= '9') {
-        ret = parse_num(json, idx);
+        ret = JsonVal::parse_num(json, idx);
       } else if (json[*idx] == '[') {
         ++(*idx);
-        ret = parse_list(json, idx);
+        ret = JsonVal::parse_list(json, idx);
       } else if (json[*idx] == '{') {
         ++(*idx);
-        ret = parse_map(json, idx);
+        ret = JsonVal::parse_map(json, idx);
       } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;}
       if (ret.type != ERR) return ret;
       ++(*idx);
     }
     return ret;
   }
-  // convert JSON object back to JSON-compatible string
-  std::string dump(const JsonVal &val) {
+  // debug function to convert data structure to a debugstring
+  std::string toString() const {
     std::string ret;
-    switch (val.type) {
+    switch (type) {
     case ERR:
       ret = "json(Error)";
       break;
     case STR:
-      ret = "\"" + val.str + "\"";
+      ret = "json(STR:" + str + ")";
       break;
     case NUM:
-      ret = val.str;
+      ret = "json(INT:" + str + ")";
       break;
     case LIST:
-      ret = "[";
-      for (unsigned i=0; i < val.list.size(); i++) {
-        auto &item = val.list[i];
-        ret += dump(item);
-        if (i < val.list.size()-1)
-          ret += ",";
-      }
-      ret += "]";
+      ret = "json(LIST:[";
+      for (auto &item : list)
+        ret += item.toString() + ",";
+      ret += "])";
       break;
     case MAP:
-      ret = "{";
-      unsigned cnt = 0;
-      for (auto &item : val.map) {
-        ret += dump(item.first) + " : " + dump(item.second);
-        if (cnt++ < val.map.size()-1)
-          ret += ",";
-      }
-      ret += "}";
+      ret = "json(MAP:{";
+      for (auto &item : map)
+        ret += item.first.toString() + " : " + item.second.toString() + ",";
+      ret += "})";
       break;
     }
     return ret;
   }
+  JsonType type;
+  int num;
+  std::string str;
+  std::vector<JsonVal> list;
+  std::map<JsonVal, JsonVal> map;
+};
+
+/*!
+ * \brief Graph utility to parse serialized subgraph symbol
+ */
+class Node;
+class Graph;
+
+// Representation of an input/output to a node
+struct NodeEntry {
+  Node* node;  // other node thats producing/consuming inputs/outputs
+  int entry;  // entry index from other node (ie. output index from producing node)
+};
+
+// Representation of a node in the graph
+class Node {
+ public:
+  Node() {tensor = nullptr;}
+  // internally set passResource to enable tensor allocation for graph passes
+  void _setPassResource(PassResource* res_) {res = res_;}
+  /* \brief allocate an arg tensor for this node */
+  void alloc_arg(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_arg outside of graph passes.");
+    tensor = res->alloc_arg(name, shapes, ctx, dtype);
+  }
+  /* \brief allocate an aux tensor for this node */
+  void alloc_aux(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_aux outside of graph passes.");
+    tensor = res->alloc_aux(name, shapes, ctx, dtype);
+  }
+  std::string op;  // operator name (ie. Convolution)
+  std::string name;  // unique node name (ie. conv_0 or conv_1)
+  MXTensor* tensor;  // tensor data for input nodes
+  std::vector<NodeEntry> inputs;  // set of inputs to the node
+  std::vector<NodeEntry> outputs;  // set of outputs from the node
+  std::vector<Graph*> subgraphs;  // set of subgraphs within this node
+  std::unordered_map<std::string, std::string> attrs;  // node attributes
+
+ private:
+  PassResource* res;
+};
+
+// Representation of the graph
+class Graph {
+ public:
+  Graph() : res(nullptr) {}
+  /* \brief deleted nodes when deleting the graph */
+  ~Graph() {
+    for (int i = 0; i < nodes.size(); i++)
+      delete nodes[i];
+  }
+
+  /* \brief create a graph object from an unparsed string */
+  static Graph* fromString(const std::string& json) {
+    JsonVal val = JsonVal::parse(json);
+    return fromJson(val);
+  }
+
+  /* \brief create a graph object from a parsed JSON object */
+  static Graph* fromJson(JsonVal val) {
+    // get nodes list
+    JsonVal nodes = val.map[JsonVal("nodes")];
+    Graph *g = new Graph();
+
+    std::map<int, Node*> nodeMap;
+    // loop over nodes
+    for (int i = 0; i < nodes.list.size(); i++) {
+      Node* n = new Node();
+      g->nodes.push_back(n);
+      JsonVal node = nodes.list[i];
+
+      // set the op info
+      n->op = node.map[JsonVal("op")].str;
+      n->name = node.map[JsonVal("name")].str;
+
+      // if op is null it is an input to the graph
+      if (n->op.compare("null") == 0)
+        g->inputs.push_back(n);
+
+      // set attrs
+      JsonVal attributes = node.map[JsonVal("attrs")];
+      for (auto& kv : attributes.map) {
+        n->attrs[kv.first.str] = kv.second.str;
+      }
+
+      // set subgraphs, parsing each into a graph
+      if (node.map.count(JsonVal("subgraphs")) > 0) {
+        JsonVal subgraphs = node.map[JsonVal("subgraphs")];
+        for (auto &subgraph : subgraphs.list) {
+          n->subgraphs.push_back(fromJson(subgraph));
+        }
+      }
+
+      // set node inputs
+      JsonVal node_inputs = node.map[JsonVal("inputs")];
+      n->inputs.resize(node_inputs.list.size());
+      for (int j = 0; j < node_inputs.list.size(); j++) {
+        JsonVal input = node_inputs.list[j];
+        NodeEntry& entry = n->inputs[j];
+        // get pointer to other node
+        entry.node = nodeMap[input.list[0].num];
+        // get the other node's output index
+        entry.entry = input.list[1].num;
+        // set other nodes output as connected to this node
+        entry.node->outputs.push_back({n, j});
+      }
+      nodeMap[i] = n;
+    }
+
+    // set graph level outputs
+    JsonVal& heads = val.map[JsonVal("heads")];
+    g->outputs.resize(heads.list.size());
+    for (int i = 0; i < heads.list.size(); i++) {
+      JsonVal head = heads.list[i];
+      g->outputs[i].node = nodeMap[head.list[0].num];
+      g->outputs[i].entry = head.list[1].num;
+    }
+
+    // add all attributes to the graph
+    for (auto& kv : val.map) {
+      if (kv.first.str.compare("nodes") != 0 &&
+         kv.first.str.compare("heads") != 0 &&
+         kv.first.str.compare("node_row_ptr") != 0 &&
+         kv.first.str.compare("arg_nodes") != 0) {
+        g->attrs[kv.first.str] = kv.second;
+      }
+    }
+    return g;
+  }
+
+  /* \brief convert graph object back to JSON object */
+  JsonVal toJson() {
+    // top level object is a map
+    JsonVal val(MAP);
+
+    // add attributes
+    for (auto& kv : attrs) {
+      val.map[JsonVal(kv.first)] = kv.second;
+    }
+
+    // sort graph nodes in topological order, create mapping of node to index
+    std::map<Node*, int> nodeMap;
+    std::vector<Node*> sorted = topological_sort();
+    // nodes are in reverse topological order in the vector (back is first)
+    // so loop from end to front over the vector 'sorted'
+    for (int i = sorted.size()-1; i >= 0; i--) {
+      nodeMap[sorted[i]] = sorted.size()-1-i;
+    }
+
+    // create node_row_ptr entry
+    val.map[JsonVal("node_row_ptr")] = JsonVal(LIST);
+    JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")];
+    for (int i = 0; i < nodes.size(); i++)
+      node_row_ptr.list.push_back(JsonVal(i));
+
+    // add all input nodes
+    val.map[JsonVal("arg_nodes")] = JsonVal(LIST);
+    JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")];
+    for (int i = 0; i < inputs.size(); i++)
+      arg_nodes.list.push_back(JsonVal(nodeMap[inputs[i]]));
+
+    // add all output nodes
+    val.map[JsonVal("heads")] = JsonVal(LIST);
+    JsonVal& heads = val.map[JsonVal("heads")];
+    for (int i = 0; i < outputs.size(); i++) {
+      heads.list.push_back(JsonVal(LIST));
+      JsonVal& out = heads.list[i];
+      out.list.push_back(JsonVal(nodeMap[outputs[i].node]));
+      out.list.push_back(JsonVal(outputs[i].entry));
+      out.list.push_back(JsonVal(0));
+    }
+
+    // add all graph nodes
+    val.map[JsonVal("nodes")] = JsonVal(LIST);
+    JsonVal& nodes_ = val.map[JsonVal("nodes")];
+    for (int i = sorted.size()-1; i >= 0; i--) {
+      // each node is a map
+      nodes_.list.push_back(JsonVal(MAP));
+      Node* n = sorted[i];
+      JsonVal& n_ = nodes_.list[nodes_.list.size()-1];
+
+      n_.map[JsonVal("op")] = JsonVal(n->op);
+      n_.map[JsonVal("name")] = JsonVal(n->name);
+      n_.map[JsonVal("inputs")] = JsonVal(LIST);
+
+      // add inputs for this node
+      JsonVal& inputs_ = n_.map[JsonVal("inputs")];
+      for (int j = 0; j < n->inputs.size(); j++) {
+        inputs_.list.push_back(JsonVal(LIST));
+        NodeEntry& entry = n->inputs[j];
+        JsonVal& in = inputs_.list[j];
+        in.list.push_back(JsonVal(nodeMap[entry.node]));
+        in.list.push_back(JsonVal(entry.entry));
+        in.list.push_back(JsonVal(0));
+      }
+
+      // add subgraphs for this node, convert each back to JSON
+      if (n->subgraphs.size() > 0) {
+        n_.map[JsonVal("subgraphs")] = JsonVal(LIST);
+        JsonVal &subgraphs_ = n_.map[JsonVal("subgraphs")];
+        for (Graph *subgraph : n->subgraphs) {
+          subgraphs_.list.push_back(subgraph->toJson());
+        }
+      }
+
+      // add attributes for this node
+      n_.map[JsonVal("attrs")] = JsonVal(MAP);
+      JsonVal& attrs_ = n_.map[JsonVal("attrs")];
+      for (auto& kv : n->attrs) {
+        attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second);
+      }
+    }
+    return val;
+  }
+
+  /* \brief convert graph object to JSON string */
+  std::string toString() {
+    return toJson().dump();
+  }
+
+  /* \brief visits a node "n" */
+  void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
+                 std::function<void(Node*)> handler) const {
+    to_visit->erase(n);  // remove node now that we're visiting it
+    for (NodeEntry& e : n->outputs) {
+      Node* o = e.node;
+      if (to_visit->count(o) != 0) {
+        _dfs_util(o, to_visit, handler);  // visit neighbor
+      }
+    }
+    handler(n);  // post-order visit this node
+  }
+
+  /* \brief post-order DFS graph traversal */
+  void DFS(std::function<void(Node*)> handler) const {
+    std::unordered_set<Node*> to_visit;
+    // put all nodes in set to visit
+    for (auto& n : nodes)
+      to_visit.insert(n);
+    // visit all inputs first
+    for (auto& i : inputs)
+      if (to_visit.count(i) != 0)
+        _dfs_util(i, &to_visit, handler);
+    // visit any nodes left
+    while (to_visit.size() > 0)
+      _dfs_util(*(to_visit.begin()), &to_visit, handler);
+  }
+
+  /* \brief sort graph nodes in topological order */
+  std::vector<Node*> topological_sort() const {
+    std::vector<Node*> sorted;
+    auto handler = [&](Node* n) {
+      sorted.push_back(n);  // when visiting each node, add it in order to the vector
+    };
+    DFS(handler);
+    return sorted;
+  }
+
+  /* \brief print out graph details */
+  void print(int indent = 0) const {
+    std::string space = "";
+    for (int i = 0; i < indent; i++) space+=" ";
+
+    std::cout << space << "########### Graph #############" << std::endl;
+    std::cout << space << "attributes: " << std::endl;
+    for (auto &kv : attrs)
+      std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl;
+    std::cout << space << "inputs: " << inputs.size() << std::endl;
+    std::cout << space << "outputs: " << outputs.size() << std::endl;
+    std::cout << space << "nodes: " << nodes.size() << std::endl;
+    std::vector<Node*> sorted = topological_sort();
+    // loop over each node and print out its inputs/outputs
+    for (int i = sorted.size()-1; i >= 0; i--) {
+      std::cout << space << "Node: " << sorted[i]->name << std::endl;
+      for (int j = 0; j < sorted[i]->inputs.size(); j++) {
+        std::cout << space << "\tInput: " << sorted[i]->inputs[j].node->name << " "
+                  << sorted[i]->inputs[j].entry << std::endl;
+      }
+      for (int j = 0; j < sorted[i]->outputs.size(); j++) {
+        std::cout << space << "\tOutput: " << sorted[i]->outputs[j].node->name << " "
+                  << sorted[i]->outputs[j].entry << std::endl;
+      }
+      if (sorted[i]->subgraphs.size() > 0) {
+        for (auto &subgraph : sorted[i]->subgraphs) {
+          std::cout << space << "\tSubgraph:" << std::endl;
+          subgraph->print(indent+2);
+        }
+      }
+    }
+    std::cout << space << "###############################" << std::endl;
+  }
+
+  /* \brief add a new node to this graph */
+  Node* addNode(const std::string& name, const std::string& op) {
+    Node* n = new Node();
+    n->name = name;
+    n->op = op;
+    if (res)
+      n->_setPassResource(res);
+    return n;
+  }
+  /* \brief get node at index in graph */
+  Node* getNode(size_t idx) {

Review comment:
       why do we need two `getNode` functions?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to