cjolivier01 commented on a change in pull request #15167: Pointwise fusion for
GPU
URL: https://github.com/apache/incubator-mxnet/pull/15167#discussion_r306512834
##########
File path: src/executor/infer_graph_attr_pass.cc
##########
@@ -63,6 +63,135 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const
nnvm::Graph& g,
return true;
}
+template<typename AttrType, typename IsNone>
+inline void GetAttrFromForwardNode(const uint32_t nid,  // id of the node whose entries are filled in
+ const nnvm::IndexedGraph &idx,  // indexed graph that nid and all entry ids refer to
+ std::vector<AttrType>* rshape_ptr,  // attribute vector indexed by entry id, updated in place
+ IsNone fis_none) {  // predicate; an entry is only overwritten when fis_none(entry) is true
+ std::vector<AttrType>& rshape = *rshape_ptr;
+ const auto& inode = idx[nid];
+ // FGradient attribute map (fetched once, function-local static), used to get node correspondence.
+ static auto& fgrad =
+ Op::GetAttr<nnvm::FGradient>("FGradient");
+ nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];  // first control dependency is used as the forward node
+ const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]];  // same forward node, as seen by the indexed graph
+ // use gradient function to find out the correspondence; ograd holds placeholder entries with only `index` set.
+ std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
+ for (size_t i = 0; i < ograd.size(); ++i) {
+ ograd[i].index = static_cast<uint32_t>(i);
+ }
+ // input gradient list; igrad[i] is paired with fnode.inputs[i] in the loop below
+ const std::vector<nnvm::NodeEntry>& igrad = fgrad[fwd_ptr->op()](fwd_ptr,
ograd);
+ const nnvm::Node* igrad_node = nullptr;
+ // Input gradient assignment: only igrad entries produced by an op equal to this node's op are considered
+ for (size_t i = 0; i < igrad.size(); ++i) {
+ if (igrad[i].node->op() == inode.source->op()) {
+ uint32_t eid = idx.entry_id(nid, igrad[i].index);  // output entry of this node matching igrad[i]
+ if (fis_none(rshape[eid])) {
+ rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];  // copy the attribute of the i-th forward input
+ } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+ // Need to skip empty forward shape, because it may not be
+ // available now and it is possible to infer the forward
+ // shape in one of the next few passes
+ CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+ << "Backward shape inconsistent with the forward shape";
+ }
+ if (igrad_node == nullptr) {
+ igrad_node = igrad[i].node.get();  // remember the first matching gradient node
+ } else {
+ CHECK(igrad_node == igrad[i].node.get());  // every matching entry must come from that same node
+ }
+ }
+ }
+ // out grad entries: inputs of igrad_node with a null node (presumably the ograd placeholders above) take the forward output's attribute
+ CHECK(igrad_node != nullptr)
+ << "Cannot find matching backward op for " << inode.source->attrs.name;
+ for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+ const nnvm::NodeEntry& e = igrad_node->inputs[i];
+ if (e.node == nullptr) {
+ uint32_t eid = idx.entry_id(inode.inputs[i]);  // i-th input entry of this node
+ if (fis_none(rshape[eid])) {
+ rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];  // output e.index of the forward node
+ }
+ }
+ }
+}
+
+template<typename FAccessSubgraphType, typename AttrType, typename IsNone>
+void GetAttrFromFusedNode(uint32_t nid,
+ const nnvm::IndexedGraph& idx,
+ std::vector<AttrType>* rshape_ptr,
+ IsNone fis_none,
+ const std::string& infer_fusion_name) {
+ std::vector<AttrType>& rshape = *rshape_ptr;
+ const auto& inode = idx[nid];
Review comment:
nit: “auto” is used too often in this PR in places where the type is not
obvious, which makes the code harder to read.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services