This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3781816  Add exception handling support for waitall (#14397)
3781816 is described below

commit 37818165a7d97031a3db613d98511284bcccfbc7
Author: Anirudh Subramanian <[email protected]>
AuthorDate: Mon Apr 8 00:21:36 2019 -0700

    Add exception handling support for waitall (#14397)
    
    * Relax constexpr restriction
    
    * Change imagenet_gen_qsym_mkldnn
    
    * Add exception handling support for waitall
    
    * Fix exception handling documentation
    
    * Revert constexpr change
    
    * Add comments
    
    * Fix test
    
    * Skip exception for op check names
    
    * Print exceptions thrown for CPP Package NDArray module
    
    * Reducing batch_size to make cpp-package example pass
    
    * Fix bug: #14426
    
    * use ExceptionRef in threaded_engine code
    
    * add note for performance impact of waitall
    
    * Add check for GPU context
    
    * Use range for with const reference
    
    * Improve comments and error message for exception handling test
    
    * Change exception_ptr name in waitall
    
    * Fix bug
---
 cpp-package/example/resnet.cpp             |   2 +-
 cpp-package/include/mxnet-cpp/ndarray.hpp  |   6 +-
 docs/architecture/exception_handling.md    |   3 -
 python/mxnet/ndarray/ndarray.py            |   7 +-
 src/engine/threaded_engine.cc              |  20 +++++
 src/engine/threaded_engine.h               |  36 ++++++++-
 src/resource.cc                            |  14 ++--
 tests/python/unittest/test_exc_handling.py | 113 +++++++++++++++++++++--------
 tests/python/unittest/test_operator.py     |  14 +++-
 9 files changed, 159 insertions(+), 56 deletions(-)

diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp
index f59f606..8f8fd12 100644
--- a/cpp-package/example/resnet.cpp
+++ b/cpp-package/example/resnet.cpp
@@ -185,7 +185,7 @@ int main(int argc, char const *argv[]) {
 #if !MXNET_USE_CPU
   if (num_gpu > 0) {
     ctx = Context::gpu();
-    batch_size = 50;
+    batch_size = 32;
   }
 #endif
 
diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp 
b/cpp-package/include/mxnet-cpp/ndarray.hpp
index 966cf75..b667542 100644
--- a/cpp-package/include/mxnet-cpp/ndarray.hpp
+++ b/cpp-package/include/mxnet-cpp/ndarray.hpp
@@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) 
const {
   return NDArray(handle);
 }
 inline void NDArray::WaitToRead() const {
-  CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0);
+  CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError();
 }
 inline void NDArray::WaitToWrite() {
-  CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0);
+  CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError();
 }
-inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
+inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << 
MXGetLastError(); }
 inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) 
{
   Operator("_random_normal")(mu, sigma).Invoke(*out);
 }
diff --git a/docs/architecture/exception_handling.md 
b/docs/architecture/exception_handling.md
index 6a9ab9a..87481bc 100644
--- a/docs/architecture/exception_handling.md
+++ b/docs/architecture/exception_handling.md
@@ -123,6 +123,3 @@ except mx.base.MXNetError as ex:
 d.asnumpy()
 ```
 
-### Limitation
-
-Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your 
code executes a few operators and then calls `waitall` instead of 
`wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls 
in your code unless you are confident about your code not throwing exception in 
any scenario.
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index acb7b28..87f2712 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -158,12 +158,9 @@ def waitall():
 
     This function is used for benchmarking only.
 
-    .. warning::
+    .. note::
 
-       If your code has exceptions, `waitall` can cause silent failures.
-       For this reason you should avoid `waitall` in your code.
-       Use it only if you are confident that your code is error free.
-       Then make sure you call `wait_to_read` on all outputs after `waitall`.
+       If your mxnet code throws an exception, then waitall can cause 
performance impact.
     """
     check_call(_LIB.MXNDArrayWaitAll())
 
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index b5897a1..986b6ad 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -415,6 +415,23 @@ void ThreadedEngine::WaitForAll() {
   finished_cv_.wait(lock, [this]() {
       return pending_.load() == 0 || kill_.load();
     });
+  std::exception_ptr exception_to_rethrow = nullptr;
+  if (!global_exception_refs_.empty()) {
+    // iterate through all exception refs
+    for (const auto& global_exception_ref : global_exception_refs_) {
+      // the first exception will be saved to be rethrown later
+      if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) 
{
+        exception_to_rethrow = *global_exception_ref;
+      }
+      // clear exceptions, WaitToRead following WaitForAll shouldn't throw
+      *global_exception_ref = nullptr;
+    }
+    // A waitall following a waitall shouldn't throw any exceptions
+    global_exception_refs_.clear();
+    if (exception_to_rethrow != nullptr) {
+      std::rethrow_exception(exception_to_rethrow);
+    }
+  }
 }
 
 inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
@@ -428,6 +445,9 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* 
threaded_opr) {
   for (auto&& i : threaded_opr->mutable_vars) {
     if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
       i->var_exception = threaded_opr->opr_exception;
+      // add current operator exceptions to global exceptions if not already
+      // added
+      AddToGlobalExceptions(threaded_opr->opr_exception);
     }
     const bool debug_info = (engine_info_ && debug_wait_var_ == i);
     if (debug_info) {
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 640eac4..3d2119d 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -60,6 +60,9 @@ namespace engine {
 // Forward declarations
 struct ThreadedOpr;
 
+/*! shared_ptr to exception_ptr, used for exception handling */
+typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
+
 /*!
  * \brief Operation block in the scheduler.
  *  Each OprBlock corresponds to an operation pushed to the engine.
@@ -177,8 +180,12 @@ class ThreadedVar final
   static std::atomic<std::size_t> counter;
   ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
 #endif  // ENGINE_DEBUG
-  /*! \brief exception_ptr associated with the ThreadedVar */
-  std::shared_ptr<std::exception_ptr> var_exception;
+  /*!
+   * \brief exception_ptr associated with the ThreadedVar
+   * cannot modify state of exception object since dereferencing
+   * exception_ptr is undefined behavior. Using shared_ptr to hold
+   * exception_ptr and overcome this limitation */
+  ExceptionRef var_exception;
 
  private:
   // TODO(hotpxl) change this to spinlock for faster runtime
@@ -254,8 +261,12 @@ struct ThreadedOpr final : public Opr,
   }
   // define possible debug information
   DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
-  /*! \brief exception_ptr associated with the ThreadedOpr */
-  std::shared_ptr<std::exception_ptr> opr_exception;
+  /*!
+   * \brief exception_ptr associated with the ThreadedOpr
+   * cannot modify state of exception object since dereferencing
+   * exception_ptr is undefined behavior. Using shared_ptr to hold
+   * exception_ptr and overcome this limitation */
+  ExceptionRef opr_exception;
 };  // struct ThreadedOpr
 
 /*!
@@ -432,6 +443,7 @@ class ThreadedEngine : public Engine {
   };
   /*! thread local store for bulk */
   typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
+
   /*!
    * \brief check if thee is duplication in const_vars and mutable_vars.
    * \param const_vars the variables to read from.
@@ -460,6 +472,7 @@ class ThreadedEngine : public Engine {
     for (auto&& i : threaded_opr->const_vars) {
       if (i->var_exception && *i->var_exception) {
         threaded_opr->opr_exception = i->var_exception;
+        AddToGlobalExceptions(threaded_opr->opr_exception);
         break;
       }
     }
@@ -467,6 +480,7 @@ class ThreadedEngine : public Engine {
       for (auto&& i : threaded_opr->mutable_vars) {
         if (i->var_exception && *i->var_exception) {
           threaded_opr->opr_exception = i->var_exception;
+          AddToGlobalExceptions(threaded_opr->opr_exception);
           break;
         }
       }
@@ -475,6 +489,18 @@ class ThreadedEngine : public Engine {
 
   static void OnCompleteStatic(Engine *engine, void *threaded_opr,
                                const dmlc::Error* error);
+  /*!
+   * \brief find exception in global_exception_refs and add it if missing
+   * \param opr_exception the exception to be added to global_exception_refs
+   */
+  inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
+    auto it = std::find(global_exception_refs_.begin(),
+                        global_exception_refs_.end(), opr_exception);
+    if (it == global_exception_refs_.end()) {
+      global_exception_refs_.push_back(opr_exception);
+    }
+    return;
+  }
   /*! \brief append an operator to bulk */
   inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
                          std::vector<VarHandle> const& const_vars,
@@ -542,6 +568,8 @@ class ThreadedEngine : public Engine {
    */
   std::mutex finished_m_;
   std::condition_variable finished_cv_;
+  /*! \brief global exception refs, which are rethrown when WaitForAll is 
called */
+  std::vector<ExceptionRef> global_exception_refs_;
 
   /*!
    * \brief Holding a shared_ptr to the object pool to prevent it from being 
destructed too early
diff --git a/src/resource.cc b/src/resource.cc
index de24286..cd6320d 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -189,12 +189,14 @@ class ResourceManagerImpl : public ResourceManager {
     cpu_rand_->Seed(seed);
     cpu_parallel_rand_->Seed(seed);
 #if MXNET_USE_CUDA
-    gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
-      return new ResourceRandom<gpu>(ctx, seed);
-    })->Seed(seed);
-    gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
-      return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
-    })->Seed(seed);
+    if (ctx.dev_type == Context::kGPU) {
+      gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
+        return new ResourceRandom<gpu>(ctx, seed);
+      })->Seed(seed);
+      gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
+        return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, 
seed);
+      })->Seed(seed);
+    }
 #endif
   }
 
diff --git a/tests/python/unittest/test_exc_handling.py 
b/tests/python/unittest/test_exc_handling.py
index e9e161d..60799f8 100644
--- a/tests/python/unittest/test_exc_handling.py
+++ b/tests/python/unittest/test_exc_handling.py
@@ -34,11 +34,11 @@ def test_exc_imperative():
             c.asnumpy()
 
     imperative(exec_numpy=False)
-    assert_raises(MXNetError, imperative, True)
+    assert_raises(MXNetError, imperative, exec_numpy=True)
 
 @with_seed()
 def test_exc_symbolic():
-    def symbolic(exec_backward=True):
+    def symbolic(exec_backward=True, waitall=True):
         x = mx.sym.Variable('x')
         y = mx.sym.Variable('y')
         z = mx.sym.Variable('z')
@@ -58,16 +58,25 @@ def test_exc_symbolic():
         outputs = exec1.forward()
         if exec_backward:
             exec1.backward()
-            exec1.grad_arrays[0].asnumpy()
+            if waitall:
+                mx.nd.waitall()
+            else:
+                exec1.grad_arrays[0].asnumpy()
         else:
-            outputs[0].asnumpy()
+            if waitall:
+                mx.nd.waitall()
+            else:
+                outputs[0].asnumpy()
 
-    assert_raises(MXNetError, symbolic, False)
-    assert_raises(MXNetError, symbolic, True)
+    assert_raises(MXNetError, symbolic, exec_backward=False)
+    assert_raises(MXNetError, symbolic, exec_backward=True)
+
+    assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True)
+    assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True)
 
 @with_seed()
 def test_exc_gluon():
-    def gluon(exec_wait=True):
+    def gluon(exec_wait=True, waitall=False):
         model = nn.Sequential()
         model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
         model.add(nn.Dropout(1))
@@ -77,46 +86,86 @@ def test_exc_gluon():
         y = model(x)
         model.collect_params().initialize(ctx=[default_context()])
         z = model(mx.nd.random.normal(10, -10, (32, 2, 10), 
ctx=default_context()))
-        if exec_wait:
+        if waitall:
+            mx.nd.waitall()
+        elif exec_wait:
             z.wait_to_read()
 
     gluon(exec_wait=False)
-    assert_raises(MXNetError, gluon, True)
+    assert_raises(MXNetError, gluon, exec_wait=True)
+
+    assert_raises(MXNetError, gluon, waitall=True)
 
 @with_seed()
 def test_exc_multiple_waits():
-    caught = False
-    try:
-        a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
-        a.wait_to_read()
-    except MXNetError:
-        caught = True
-    assert caught, "No exception thrown"
-    try:
-        b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
-        b.wait_to_read()
-    except MXNetError:
-        caught = True
-    assert caught, "No exception thrown"
+    def multiple_waits(waitall=False):
+        # Test calling failed op followed by wait_to_read or waitall twice
+        # Intention is to test rethrow for multiple wait_to_reads and waitalls
+        # for vars with exceptions in same scope
+        caught = False
+        try:
+            a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+            if waitall:
+                mx.nd.waitall()
+            else:
+                a.wait_to_read()
+        except MXNetError:
+            caught = True
+        assert caught, "No exception thrown, exception should be rethrown with 
wait_to_read/waitall"
+        try:
+            b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+            if waitall:
+                mx.nd.waitall()
+            else:
+                b.wait_to_read()
+        except MXNetError:
+            caught = True
+        assert caught, "No exception thrown, exception should be rethrown with 
wait_to_read/waitall"
+
+    multiple_waits(waitall=False)
+    multiple_waits(waitall=True)
 
 @with_seed()
 def test_exc_post_fail():
+    def post_fail(waitall=False):
+        caught = False
+        try:
+            a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+            if waitall:
+                mx.nd.waitall()
+            else:
+                a.asnumpy()
+        except MXNetError:
+            caught = True
+        assert caught, "No exception thrown"
+        b.asnumpy()
+    post_fail(waitall=False)
+    post_fail(waitall=True)
+
+@with_seed()
+def test_exc_mutable_var_fail():
+    def mutable_var_check(waitall=False):
+        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+        a = mx.nd.dot(a, a)
+        if waitall:
+            mx.nd.waitall()
+        else:
+            a.asnumpy()
+    assert_raises(MXNetError, mutable_var_check, waitall=False)
+    assert_raises(MXNetError, mutable_var_check, waitall=True)
+
+@with_seed()
+def test_multiple_waitalls():
     caught = False
     try:
-        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
-        a.asnumpy()
+        a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+        mx.nd.waitall()
     except MXNetError:
         caught = True
     assert caught, "No exception thrown"
-    b.asnumpy()
+    mx.nd.waitall()
+
 
-@with_seed()
-def test_exc_mutable_var_fail():
-    def mutable_var_check():
-        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
-        a = mx.nd.dot(a, a)
-        a.asnumpy()
-    assert_raises(MXNetError, mutable_var_check)
 
 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index f96a6ae..17618e4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7164,7 +7164,12 @@ def test_op_output_names_monitor():
 
         op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
         op_exe.set_monitor_callback(get_output_names_callback, 
monitor_all=False)
-        op_exe.forward()
+        try:
+            op_exe.forward()
+            mx.nd.waitall()
+        except mx.base.MXNetError:
+            # skip errors since test is to check output names
+            pass
         for output_name, expected_name in zip(output_names, expected_names):
             assert output_name == expected_name
 
@@ -7210,7 +7215,12 @@ def test_op_all_names_monitor():
 
         op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
         op_exe.set_monitor_callback(get_output_names_callback, 
monitor_all=True)
-        op_exe.forward()
+        try:
+            op_exe.forward()
+            mx.nd.waitall()
+        except mx.base.MXNetError:
+            # skip errors since test is to check all names
+            pass
         for output_name, expected_name in zip(output_names, expected_names):
             assert output_name == expected_name
 

Reply via email to