samskalicky commented on a change in pull request #18779:
URL: https://github.com/apache/incubator-mxnet/pull/18779#discussion_r468738020



##########
File path: include/mxnet/lib_api.h
##########
@@ -748,62 +759,410 @@ struct JsonParser {
     return JsonVal();
   }
   // generic parse function
-  JsonVal parse(const std::string& json, unsigned int *idx) {
+  static JsonVal parse(const std::string& json, unsigned int *idx) {
     JsonVal ret;
     while (*idx < json.size()) {
       if (json[*idx] == '"') {
         ++(*idx);
-        ret = parse_string(json, idx);
+        ret = JsonVal::parse_string(json, idx);
       } else if (json[*idx] >= '0' && json[*idx] <= '9') {
-        ret = parse_num(json, idx);
+        ret = JsonVal::parse_num(json, idx);
       } else if (json[*idx] == '[') {
         ++(*idx);
-        ret = parse_list(json, idx);
+        ret = JsonVal::parse_list(json, idx);
       } else if (json[*idx] == '{') {
         ++(*idx);
-        ret = parse_map(json, idx);
+        ret = JsonVal::parse_map(json, idx);
       } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;}
       if (ret.type != ERR) return ret;
       ++(*idx);
     }
     return ret;
   }
-  // convert JSON object back to JSON-compatible string
-  std::string dump(const JsonVal &val) {
+  // debug function to convert data structure to a debugstring
+  std::string toString() const {
     std::string ret;
-    switch (val.type) {
+    switch (type) {
     case ERR:
       ret = "json(Error)";
       break;
     case STR:
-      ret = "\"" + val.str + "\"";
+      ret = "json(STR:" + str + ")";
       break;
     case NUM:
-      ret = val.str;
+      ret = "json(INT:" + str + ")";
       break;
     case LIST:
-      ret = "[";
-      for (unsigned i=0; i < val.list.size(); i++) {
-        auto &item = val.list[i];
-        ret += dump(item);
-        if (i < val.list.size()-1)
-          ret += ",";
-      }
-      ret += "]";
+      ret = "json(LIST:[";
+      for (auto &item : list)
+        ret += item.toString() + ",";
+      ret += "])";
       break;
     case MAP:
-      ret = "{";
-      unsigned cnt = 0;
-      for (auto &item : val.map) {
-        ret += dump(item.first) + " : " + dump(item.second);
-        if (cnt++ < val.map.size()-1)
-          ret += ",";
-      }
-      ret += "}";
+      ret = "json(MAP:{";
+      for (auto &item : map)
+        ret += item.first.toString() + " : " + item.second.toString() + ",";
+      ret += "})";
       break;
     }
     return ret;
   }
+  JsonType type;
+  int num;
+  std::string str;
+  std::vector<JsonVal> list;
+  std::map<JsonVal, JsonVal> map;
+};
+
+/*!
+ * \brief Graph utility to parse serialized subgraph symbol
+ */
+class Node;
+class Graph;
+
+// Representation of an input/output to a node: a reference to another node
+// plus which of that node's entries is being referred to.
+struct NodeEntry {
+  Node* node;  // other node that's producing/consuming inputs/outputs
+  // entry index from other node (i.e. output index from producing node)
+  int entry;
+};
+
+// Representation of a node in the graph
+class Node {
+ public:
+  // Both pointers start null. `res` in particular must be null-initialized:
+  // alloc_arg/alloc_aux test it to detect use outside of a graph pass, and
+  // reading an uninitialized pointer there would be undefined behaviour.
+  Node() : tensor(nullptr), res(nullptr) {}
+  // internally set passResource to enable tensor allocation for graph passes
+  void _setPassResource(PassResource* res_) {res = res_;}
+  /* \brief allocate an arg tensor for this node
+   * \param shapes dimensions of the tensor to allocate
+   * \param ctx    context to allocate the tensor in
+   * \param dtype  data type of the tensor
+   * \throws std::runtime_error if no PassResource has been set
+   */
+  void alloc_arg(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_arg "
+                 "outside of graph passes.");
+    tensor = res->alloc_arg(name, shapes, ctx, dtype);
+  }
+  /* \brief allocate an aux tensor for this node
+   * \param shapes dimensions of the tensor to allocate
+   * \param ctx    context to allocate the tensor in
+   * \param dtype  data type of the tensor
+   * \throws std::runtime_error if no PassResource has been set
+   */
+  void alloc_aux(const std::vector<int64_t>& shapes,
+                 const MXContext &ctx, MXDType dtype) {
+    if (!res)
+      throw std::runtime_error(
+                 "Node not initialized. Cannot use alloc_aux "
+                 "outside of graph passes.");
+    tensor = res->alloc_aux(name, shapes, ctx, dtype);
+  }
+  std::string op;  // operator name (e.g. Convolution)
+  std::string name;  // unique node name (e.g. conv_0 or conv_1)
+  MXTensor* tensor;  // tensor data for input nodes
+  std::vector<NodeEntry> inputs;  // set of inputs to the node
+  std::vector<NodeEntry> outputs;  // set of outputs from the node
+  std::vector<Graph*> subgraphs;  // set of subgraphs within this node
+  std::unordered_map<std::string, std::string> attrs;  // node attributes
+
+ private:
+  // set via _setPassResource; null means tensor allocation is unavailable
+  PassResource* res;
+};
+
+// Representation of the graph
+class Graph {
+ public:
+  Graph() : res(nullptr) {}
+  /* \brief deleted nodes when deleting the graph */
+  ~Graph() {
+    for (int i = 0; i < nodes.size(); i++)
+      delete nodes[i];
+  }
+
+  /* \brief create a graph object from an unparsed string */
+  static Graph* fromString(const std::string& json) {
+    JsonVal val = JsonVal::parse(json);
+    return fromJson(val);
+  }
+
+  /* \brief create a graph object from a parsed JSON object */
+  static Graph* fromJson(JsonVal val) {
+    // get nodes list
+    JsonVal nodes = val.map[JsonVal("nodes")];
+    Graph *g = new Graph();
+
+    std::map<int, Node*> nodeMap;
+    // loop over nodes
+    for (int i = 0; i < nodes.list.size(); i++) {
+      Node* n = new Node();
+      g->nodes.push_back(n);
+      JsonVal node = nodes.list[i];
+
+      // set the op info
+      n->op = node.map[JsonVal("op")].str;
+      n->name = node.map[JsonVal("name")].str;
+
+      // if op is null its an input to the graph

Review comment:
       fixed!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to