samskalicky commented on a change in pull request #18894:
URL: https://github.com/apache/incubator-mxnet/pull/18894#discussion_r472514368
##########
File path: example/extensions/lib_subgraph/subgraph_lib.cc
##########
@@ -176,70 +174,42 @@ REGISTER_OP(_custom_subgraph_op)
const std::vector<std::string> op_names({"exp","log"});
-MXReturnValue mySupportedOps(const std::string& json,
+MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
std::vector<int>* ids,
const std::unordered_map<std::string,
std::string>& options) {
for (auto kv : options) {
std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
}
- //convert json string to json object
- JsonParser parser;
- JsonVal json_val = parser.parse_to_json(json);
- //get nodes list
- JsonVal nodes = json_val.map[JsonVal("nodes")];
//loop over nodes
- for(int i=0; i<nodes.list.size(); i++) {
- JsonVal node = nodes.list[i];
- JsonVal op = node.map[JsonVal("op")];
+ for(int i=0; i<graph->size(); i++) {
+ const mxnet::ext::Node *node = graph->getNode(i);
//get shape/type if available
std::string shape;
int dtype = -1;
- if(node.map.find(JsonVal("attrs")) != node.map.end()) {
- JsonVal attrs = node.map[JsonVal("attrs")];
- if(attrs.map.find(JsonVal("shape")) != attrs.map.end())
- shape = attrs.map[JsonVal("shape")].str;
- if(attrs.map.find(JsonVal("dtype")) != attrs.map.end())
- dtype = std::stoi(attrs.map[JsonVal("dtype")].str);
- }
+ if(node->attrs.count("shape") > 0)
+ shape = node->attrs.at("shape");
+ if(node->attrs.count("dtype") > 0)
+ dtype = std::stoi(node->attrs.at("dtype"));
//check if op dtype is float, and if option was specified to require float
types
if((dtype == kFloat32 && options.count("reqFloat") > 0) ||
options.count("reqFloat") == 0) {
- //check if op is in whitelist
- if(std::find(op_names.begin(),op_names.end(),op.str.c_str()) !=
op_names.end()) {
- // found op in whitelist, set value to -1 to include op in any subgraph
+ //check if op is in allowlist
+ if(std::find(op_names.begin(),op_names.end(),node->op.c_str()) !=
op_names.end()) {
+ // found op in allowlist, set value to -1 to include op in any subgraph
ids->at(i) = -1;
}
}
}
return MX_SUCCESS;
}
-MXReturnValue myReviewSubgraph(const std::string& json, int subgraph_id, bool*
accept,
- const std::unordered_map<std::string,
std::string>& options,
- std::unordered_map<std::string, std::string>*
attrs,
- const std::unordered_map<std::string,
MXTensor>& args,
- const std::unordered_map<std::string,
MXTensor>& aux) {
+MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int
subgraph_id, bool* accept,
+ const std::unordered_map<std::string,
std::string>& options) {
for (auto kv : options) {
std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
}
- for (auto kv : args) {
- std::cout << "arg: " << kv.first << " ==> (";
- for (auto s : kv.second.shape)
- std::cout << s << ",";
- std::cout << ") [";
- for (int i=0; i<kv.second.size(); i++)
- std::cout << kv.second.data<float>()[i] << ", ";
- std::cout << "]" << std::endl;
- }
-
- // check if option `reqArgs` was specified, and if so check if args were
provided
Review comment:
Also, the args are no longer available in a separate map; they are now
embedded within the Node objects in the graph (i.e., `Node->tensor`).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org