ArmageddonKnight commented on a change in pull request #18228:
URL: https://github.com/apache/incubator-mxnet/pull/18228#discussion_r420391476
##########
File path: src/executor/graph_executor.cc
##########
@@ -356,29 +343,45 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
}
}
- int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
- auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
- if (node.is_variable()) return 0;
- const std::string& type = node.attrs.op->name;
- if (type == "Dropout") return false;
- if (get_node_attr(node, "__force_mirroring__", false)) return true;
- if (do_mirror == 0) return false;
- if (type == "Convolution") return false;
- if (type == "FullyConnected") return false;
- if (type == "Concat") return false;
- if (type == "SoftmaxOutput") return false;
- return true;
- };
+ std::function<int(const nnvm::Node&)> need_mirror =
+ [](const nnvm::Node& node) -> int {
+ if (node.is_variable()) return false;
+ const std::string& type = node.attrs.op->name;
+ if (type == "Dropout") return false;
+ // We follow the hidden key attribute "force_mirroring" if it is
+ // explicitly set.
+ auto iter = node.attrs.dict.find("__force_mirroring__");
+ if (iter != node.attrs.dict.end()) {
+ bool do_mirror;
+ dmlc::parameter::FieldEntry<bool> e;
+ e.Init("__force_mirroring__", &do_mirror, do_mirror);
+ e.Set(&do_mirror, iter->second);
+ return do_mirror;
+ }
+ if (type == "Embedding") return false;
+ if (type == "Convolution") return false;
+ if (type == "FullyConnected") return false;
+ if (type == "Concat") return false;
+ if (type == "SoftmaxOutput") return false;
+ return true;
+ };
std::vector<const nnvm::Op*> zero_ops;
zero_ops.push_back(nnvm::Op::Get("zeros_like"));
zero_ops.push_back(nnvm::Op::Get("_zeros"));
+ LOG(INFO) << "Doing Memory Optimization?: " <<
dmlc::GetEnv("MXNET_MEMORY_OPT", 0);
+ const char *do_memory_opt = getenv("MXNET_MEMORY_OPT");
+ LOG(INFO) << "Doing Memory Optimization?: " << do_memory_opt;
Review comment:
For some reason, the test cases are not passing in the Windows environment.
Since I do not have access to a Windows workstation, I am using these log
statements for debugging purposes and will remove them later.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org