This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 89c7d57  Improve bulking in Gluon (#13890)
89c7d57 is described below

commit 89c7d57333ce121cdda528129ffe6efdea2dc524
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Wed Jan 30 15:47:18 2019 -0800

    Improve bulking in Gluon (#13890)
    
    * Improve bulking in Gluon
    
    * Trigger CI
---
 include/mxnet/base.h              |  4 ++++
 src/engine/stream_manager.h       | 10 ++++++----
 src/engine/threaded_engine.h      |  6 ++++++
 src/imperative/cached_op.cc       |  5 +----
 src/imperative/imperative_utils.h | 10 +---------
 5 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index f88b227..7f12643 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -233,6 +233,10 @@ struct RunContext {
    */
   void *stream;
   /*!
+   * \brief indicator of whether this execution is run in bulk mode
+   */
+  bool is_bulk;
+  /*!
    * \brief get mshadow stream from Context
    * \return the mshadow stream
    * \tparam xpu the device type of the stream
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 516e04b..8d44d9c 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -67,7 +67,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr};
+      ret = RunContext{ctx, nullptr, false};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
@@ -85,7 +85,9 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
         use_counter = counter;
         counter = (counter + 1) % kStreams;
       }
-      ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)};
+      ret = RunContext{ctx,
+                       gpu_streams_.at(ctx.dev_id).at(use_counter),
+                       false};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -103,7 +105,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr};
+      ret = RunContext{ctx, nullptr, false};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
@@ -114,7 +116,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
           gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
         }
       }
-      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id)};
+      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index fae120d..18018cb 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -499,7 +499,13 @@ class ThreadedEngine : public Engine {
     DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
     SyncFn fn = std::move(bulk_status.fn);
     this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
+        ctx.is_bulk = true;
         fn(ctx);
+        ctx.is_bulk = false;
+        bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
+        if (is_gpu) {
+          ctx.get_stream<gpu>()->Wait();
+        }
         on_complete();
       }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
       FnProperty::kNormal, 0, "ImperativeBulk");
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index f4047d1..58ec4e6 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -583,14 +583,11 @@ void CachedOp::StaticInitExec(
     }
 
     size_t bulk_size = idx.num_nodes();
-    std::unordered_set<uint32_t> excludes;
     if (recording || keep_fwd) {
      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
-      for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
-      for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
     }
 
-    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
+    CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
                       state.execs, skip_plus_node, &state.opr_segs);
   }
 
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 4b0d131..98f6c8f 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -406,7 +406,7 @@ inline void PushFCompute(const FCompute& fn,
       fn(attrs, opctx, input_blobs, tmp_req, output_blobs);
       // post-fcompute fallback, cast to original storage type
       CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-      if (is_gpu) {
+      if (is_gpu && !rctx.is_bulk) {
         rctx.get_stream<gpu>()->Wait();
       }
     }, ctx, read_vars, write_vars, FnProperty::kNormal,
@@ -928,7 +928,6 @@ inline void CreateEngineOpSeg(
     const size_t start_nid,
     const size_t end_nid,
     const size_t bulk_size,
-    const std::unordered_set<uint32_t>& excludes,
     const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
     const std::vector<int> skip_plus_node,
     std::vector<EngineOprSeg> *opr_segs) {
@@ -944,13 +943,6 @@ inline void CreateEngineOpSeg(
 
    // Stop at async nodes and invalid node (due to input/output is not allocated)
     bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
-    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
-      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
-    }
-    auto num_outputs = node.source->num_outputs();
-    for (size_t i = 0; i < num_outputs && !stop; ++i) {
-      if (excludes.count(idx.entry_id(nid, i))) stop = true;
-    }
 
     // Create opr segment for previous nodes.
     if (stop && nid > seg_start) {

Reply via email to