This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9cc8ea3 Re-enable all op segments when in batch mode (#9055)
9cc8ea3 is described below
commit 9cc8ea3be23fb7adf4630e4cf065a2473094fbc8
Author: Kellen Sunderland <[email protected]>
AuthorDate: Mon Jan 15 21:05:10 2018 +0100
Re-enable all op segments when in batch mode (#9055)
* Re-enable all op segments when in batch mode
* Split training/inference logic; during inference, split segments only when required.
As suggested by Haibin, this splits segments after kLocal and kCrossDeviceCopy ops.
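For context, the heart of this change is a single segmentation pattern: walk the graph in topological order, grow a bulk segment, and cut it whenever an op cannot run synchronously inside a bulked segment. kLocal and kCrossDeviceCopy ops report a non-kSync exec type, which is how the exec_type() != ExecType::kSync check in the diff covers them. Below is a minimal standalone sketch of the inference-side cut logic; Op, Segment, and BulkInferenceSegments are simplified stand-ins for illustration, not the real MXNet types.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-ins for the real MXNet types (illustration only).
enum class ExecType { kSync, kAsync, kLocal, kCrossDeviceCopy };

struct Op {
  bool is_variable;
  ExecType exec_type;
};

struct Segment {
  std::size_t start, end;  // half-open range [start, end) of node ids
};

// Mirrors the shape of BulkInferenceOpSegs: try to bulk the whole graph,
// closing the current segment whenever a non-kSync op appears, since such
// ops (e.g. kLocal, kCrossDeviceCopy) cannot run inside a bulked segment.
std::vector<Segment> BulkInferenceSegments(const std::vector<Op>& ops) {
  std::vector<Segment> segments;
  std::size_t topo_start = 0;
  for (std::size_t nid = 0; nid < ops.size(); ++nid) {
    if (ops[nid].is_variable) continue;  // variables never force a cut
    if (ops[nid].exec_type != ExecType::kSync) {
      // Close the segment before the non-sync op; the op itself then
      // executes outside any bulked segment.
      if (nid > topo_start) segments.push_back({topo_start, nid});
      topo_start = nid + 1;
    }
  }
  if (topo_start < ops.size()) segments.push_back({topo_start, ops.size()});
  return segments;
}

int main() {
  // Example graph: two sync ops, a variable, a cross-device copy,
  // then two more sync ops.
  std::vector<Op> ops = {{false, ExecType::kSync}, {false, ExecType::kSync},
                         {true, ExecType::kSync},
                         {false, ExecType::kCrossDeviceCopy},
                         {false, ExecType::kSync}, {false, ExecType::kSync}};
  for (const Segment& s : BulkInferenceSegments(ops)) {
    std::cout << "bulk segment [" << s.start << ", " << s.end << ")\n";
  }
  // Prints segments [0, 3) and [4, 6); node 3 runs on its own.
  return 0;
}

The training-side variant in the diff adds two further cut conditions on top of this: a segment-size cap read from MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN, and a cut after any op that writes an output gradient.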
---
src/executor/graph_executor.cc | 135 +++++++++++++++++++++++++----------------
src/executor/graph_executor.h | 4 ++
2 files changed, 86 insertions(+), 53 deletions(-)
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 42508b1..2a7d2b9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1348,71 +1348,100 @@ void GraphExecutor::InitOpSegs() {
bool prefer_bulk_exec_inference =
dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
// Whether to perform bulk exec for training
bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1);
+
+ bool is_training = num_forward_nodes_ != total_num_nodes;
+
+ if (prefer_bulk_exec && is_training) {
+ this->BulkTrainingOpSegs(total_num_nodes);
+ }
+
+ if (prefer_bulk_exec_inference && !is_training) {
+ this->BulkInferenceOpSegs();
+ }
+
+ return;
+}
+
+void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
// The maximum number of nodes in a segment executed in bulk
size_t num_nodes_threshold =
dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
- if (prefer_bulk_exec_inference && num_forward_nodes_ == total_num_nodes) {
- // bulk the whole graph for inference
- num_nodes_threshold = std::numeric_limits<size_t>::max();
- }
-
- if (prefer_bulk_exec) {
- // create forward segments for training
- size_t topo_start = 0;
- for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
- auto &node = graph_.indexed_graph()[nid].source;
- auto &op_node = op_nodes_[nid];
- // check if the segment relies on external input, or exceeds the maximum number of nodes,
- // or requires async ops
- if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
- op_node.exec->exec_type() != ExecType::kSync) {
- // create a new segment for the previous nodes if the current one cannot be bulked
- cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
- topo_start = nid + 1;
- }
- }
- // the last segmenet
- if (topo_start != num_forward_nodes_) {
- cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+
+ // create forward segments for training
+ size_t topo_start = 0;
+ for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
+ auto &node = graph_.indexed_graph()[nid].source;
+ auto &op_node = op_nodes_[nid];
+ // check if the segment relies on external input, or exceeds the maximum number of nodes,
+ // or requires async ops
+ if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
+ op_node.exec->exec_type() != ExecType::kSync) {
+ // create a new segment for the previous nodes if the current one cannot be bulked
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+ topo_start = nid + 1;
}
+ }
+ // the last segment
+ if (topo_start != num_forward_nodes_) {
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+ }
- // create backward segments for training
- // get all gradient variables
- std::unordered_set<engine::VarHandle> grad_vars;
- for (auto &kv : grad_store_) {
- grad_vars.insert(kv.second.var());
+ // create backward segments for training
+ // get all gradient variables
+ std::unordered_set<engine::VarHandle> grad_vars;
+ for (auto &kv : grad_store_) {
+ grad_vars.insert(kv.second.var());
+ }
+ auto &idx = graph_.indexed_graph();
+ topo_start = num_forward_nodes_;
+ for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
+ auto &op_node = op_nodes_[nid];
+ if (op_node.skip_exec_node || op_node.exec == nullptr) {
+ continue;
}
- auto &idx = graph_.indexed_graph();
- topo_start = num_forward_nodes_;
- for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
- auto &op_node = op_nodes_[nid];
- if (op_node.skip_exec_node || op_node.exec == nullptr) {
- continue;
+ if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
+ op_node.exec->exec_type() != ExecType::kSync) {
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+ topo_start = nid + 1;
+ } else {
+ // If it produces output gradient, don't include it in the segment
+ bool output_gradient = false;
+ for (auto &out_arr : op_node.exec->out_array) {
+ if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
+ output_gradient = true;
+ }
}
- if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
- op_node.exec->exec_type() != ExecType::kSync) {
+ if (output_gradient) {
cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
topo_start = nid + 1;
- } else {
- // If it produces output gradient, don't include it in the segment
- bool output_gradient = false;
- for (auto &out_arr : op_node.exec->out_array) {
- if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
- output_gradient = true;
- }
- }
- if (output_gradient) {
- cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
- topo_start = nid + 1;
- }
}
}
- // last segment for backward
- if (topo_start < total_num_nodes) {
- cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
- }
}
+ // last segment for backward
+ if (topo_start < total_num_nodes) {
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
+ }
+}
- return;
+void GraphExecutor::BulkInferenceOpSegs() {
+ // Attempt to bulk the whole graph for inference. We will only create new segments when
+ // required for non-kSync operations.
+ size_t topo_start = 0;
+ for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
+ auto &node = graph_.indexed_graph()[nid].source;
+ auto &op_node = op_nodes_[nid];
+
+ // Variables do not need to be segmented at inference time.
+ if (node->is_variable()) continue;
+
+ if (op_node.exec->exec_type() != ExecType::kSync) {
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+ topo_start = nid + 1;
+ }
+ }
+ // The last segment
+ if (topo_start != num_forward_nodes_) {
+ cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
+ }
}
void GraphExecutor::ExecuteMonCallback(size_t nid) {
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index 0e5ef32..ee32db7 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -197,6 +197,10 @@ class GraphExecutor : public Executor {
CachedSegOpr CreateCachedSegOpr(size_t topo_start, size_t topo_end);
// run the monitor callback for node `nid`
void ExecuteMonCallback(size_t nid);
+ // perform bulking and segmentation on an inference graph
+ void BulkInferenceOpSegs();
+ // perform bulking and segmentation on a training graph
+ void BulkTrainingOpSegs(size_t total_num_nodes);
// internal graph
nnvm::Graph graph_;
--
To stop receiving notification emails like this one, please contact
"[email protected]" <[email protected]>.