mseth10 commented on a change in pull request #17623: Dynamic subgraph compile support
URL: https://github.com/apache/incubator-mxnet/pull/17623#discussion_r393451877
 
 

 ##########
 File path: src/operator/subgraph/partitioner/custom_subgraph_property.h
 ##########
 @@ -188,31 +278,133 @@ class  CustomSubgraphProperty: public SubgraphProperty {
       }
 
       std::string subgraph_json = nnvm::pass::SaveJSON(g);
-      CHECK(call_review_subgraph_(review_subgraph_, subgraph_json.c_str(),
-                                subgraph_id, &accept, opt_keys_.data(),
-                                opt_vals_.data(), opt_keys_.size(),
-                                &attr_keys, &attr_vals, &num_attr))
+      CHECK(call_review_subgraph_(review_subgraph_, subgraph_json.c_str(),  
subgraph_id,
+                                  &accept, opt_keys_.data(), opt_vals_.data(),
+                                  opt_keys_.size(),  &attr_keys, &attr_vals, 
&num_attr,
+                                  arg_names.data(), arg_names.size(), 
arg_data.data(),
+                                  arg_shapes.data(), arg_dims.data(), 
arg_types.data(),
+                                  arg_verIDs.data(), arg_dev_type.data(),
+                                  arg_dev_id.data(), aux_names.data(), 
aux_names.size(),
+                                  aux_data.data(), aux_shapes.data(), 
aux_dims.data(),
+                                  aux_types.data(), aux_verIDs.data(),
+                                  aux_dev_type.data(), aux_dev_id.data()))
         << "Error calling review_subgraph for '" << subgraph_prop << "'";
+
+      if (num_attr > 0) {
+        // set user specified attributes
+        for (int i=0; i < num_attr; i++) {
+          user_attrs[attr_keys[i]] = attr_vals[i];
+          call_free_(attr_vals[i]);
+          call_free_(attr_keys[i]);
+        }
+        // free memory used by custom op to allocate attributes
+        call_free_(attr_vals);
+        call_free_(attr_keys);
+      }
     }
+
     if (accept) {
       nnvm::ObjectPtr n = nnvm::Node::Create();
       n->attrs.op = Op::Get(subgraph_op_name);
       n->attrs.name = "_op" + std::to_string(subgraph_id);
       n->attrs.subgraphs.push_back(std::make_shared<nnvm::Symbol>(sym));
-      // set user specified attributes
-      for (int i=0; i < num_attr; i++) {
-        n->attrs.dict[attr_keys[i]] = attr_vals[i];
-        call_free_(attr_vals[i]);
-        call_free_(attr_keys[i]);
+
+      // set shapes
+      {
+        std::stringstream ss;
+        ss << "[";
+        for (unsigned i=0; i < sym.outputs.size(); i++) {
+          nnvm::Node* n = sym.outputs[i].node.get();
+          if (n->attrs.dict.count("__shape__") > 0) {
+            std::string& shape = n->attrs.dict["__shape__"];
 
 Review comment:
   Modify the logic to handle the case where `n` is a subgraph node, in which case `shape` is a list of lists rather than a single shape.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to