This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3781816 Add exception handling support for waitall (#14397)
3781816 is described below
commit 37818165a7d97031a3db613d98511284bcccfbc7
Author: Anirudh Subramanian <[email protected]>
AuthorDate: Mon Apr 8 00:21:36 2019 -0700
Add exception handling support for waitall (#14397)
* Relax constexpr restriction
* Change imagenet_gen_qsym_mkldnn
* Add exception handling support for waitall
* Fix exception handling documentation
* Revert constexpr change
* Add comments
* Fix test
* Skip exception for op check names
* Print exceptions thrown for CPP Package NDArray module
* Reducing batch_size to make cpp-package example pass
* Fix bug: #14426
* use ExceptionRef in threaded_engine code
* add note for performance impact of waitall
* Add check for GPU context
* Use range for with const reference
* Improve comments and error message for exception handling test
* Change exception_ptr name in waitall
* Fix bug
---
cpp-package/example/resnet.cpp | 2 +-
cpp-package/include/mxnet-cpp/ndarray.hpp | 6 +-
docs/architecture/exception_handling.md | 3 -
python/mxnet/ndarray/ndarray.py | 7 +-
src/engine/threaded_engine.cc | 20 +++++
src/engine/threaded_engine.h | 36 ++++++++-
src/resource.cc | 14 ++--
tests/python/unittest/test_exc_handling.py | 113 +++++++++++++++++++++--------
tests/python/unittest/test_operator.py | 14 +++-
9 files changed, 159 insertions(+), 56 deletions(-)
diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp
index f59f606..8f8fd12 100644
--- a/cpp-package/example/resnet.cpp
+++ b/cpp-package/example/resnet.cpp
@@ -185,7 +185,7 @@ int main(int argc, char const *argv[]) {
#if !MXNET_USE_CPU
if (num_gpu > 0) {
ctx = Context::gpu();
- batch_size = 50;
+ batch_size = 32;
}
#endif
diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp
index 966cf75..b667542 100644
--- a/cpp-package/include/mxnet-cpp/ndarray.hpp
+++ b/cpp-package/include/mxnet-cpp/ndarray.hpp
@@ -233,12 +233,12 @@ inline NDArray NDArray::Reshape(const Shape &new_shape) const {
return NDArray(handle);
}
inline void NDArray::WaitToRead() const {
- CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0);
+ CHECK_EQ(MXNDArrayWaitToRead(blob_ptr_->handle_), 0) << MXGetLastError();
}
inline void NDArray::WaitToWrite() {
- CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0);
+ CHECK_EQ(MXNDArrayWaitToWrite(blob_ptr_->handle_), 0) << MXGetLastError();
}
-inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0); }
+inline void NDArray::WaitAll() { CHECK_EQ(MXNDArrayWaitAll(), 0) << MXGetLastError(); }
inline void NDArray::SampleGaussian(mx_float mu, mx_float sigma, NDArray *out) {
Operator("_random_normal")(mu, sigma).Invoke(*out);
}
diff --git a/docs/architecture/exception_handling.md b/docs/architecture/exception_handling.md
index 6a9ab9a..87481bc 100644
--- a/docs/architecture/exception_handling.md
+++ b/docs/architecture/exception_handling.md
@@ -123,6 +123,3 @@ except mx.base.MXNetError as ex:
d.asnumpy()
```
-### Limitation
-
-Rethrowing exceptions as part of `mx.nd.waitall` is not supported. So if your code executes a few operators and then calls `waitall` instead of `wait_to_read`/`asnumpy`, the exception will disappear. Please avoid waitalls in your code unless you are confident about your code not throwing exception in any scenario.
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index acb7b28..87f2712 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -158,12 +158,9 @@ def waitall():
This function is used for benchmarking only.
- .. warning::
+ .. note::
- If your code has exceptions, `waitall` can cause silent failures.
- For this reason you should avoid `waitall` in your code.
- Use it only if you are confident that your code is error free.
- Then make sure you call `wait_to_read` on all outputs after `waitall`.
+ If your mxnet code throws an exception, then waitall can cause performance impact.
"""
check_call(_LIB.MXNDArrayWaitAll())
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index b5897a1..986b6ad 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -415,6 +415,23 @@ void ThreadedEngine::WaitForAll() {
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
});
+ std::exception_ptr exception_to_rethrow = nullptr;
+ if (!global_exception_refs_.empty()) {
+ // iterate through all exception refs
+ for (const auto& global_exception_ref : global_exception_refs_) {
+ // the first exception will be saved to be rethrown later
+ if (*global_exception_ref != nullptr && exception_to_rethrow == nullptr) {
+ exception_to_rethrow = *global_exception_ref;
+ }
+ // clear exceptions, WaitToRead following WaitForAll shouldn't throw
+ *global_exception_ref = nullptr;
+ }
+ // A waitall following a waitall shouldn't throw any exceptions
+ global_exception_refs_.clear();
+ if (exception_to_rethrow != nullptr) {
+ std::rethrow_exception(exception_to_rethrow);
+ }
+ }
}
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
@@ -428,6 +445,9 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
+ // add current operator exceptions to global exceptions if not already
+ // added
+ AddToGlobalExceptions(threaded_opr->opr_exception);
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 640eac4..3d2119d 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -60,6 +60,9 @@ namespace engine {
// Forward declarations
struct ThreadedOpr;
+/*! shared_ptr to exception_ptr, used for exception handling */
+typedef std::shared_ptr<std::exception_ptr> ExceptionRef;
+
/*!
* \brief Operation block in the scheduler.
* Each OprBlock corresponds to an operation pushed to the engine.
@@ -177,8 +180,12 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
- /*! \brief exception_ptr associated with the ThreadedVar */
- std::shared_ptr<std::exception_ptr> var_exception;
+ /*!
+ * \brief exception_ptr associated with the ThreadedOpr
+ * cannot modify state of exception object since dereferencing
+ * exception_ptr is undefined behavior. Using shared_ptr to hold
+ * exception_ptr and overcome this limitation */
+ ExceptionRef var_exception;
private:
// TODO(hotpxl) change this to spinlock for faster runtime
@@ -254,8 +261,12 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
- /*! \brief exception_ptr associated with the ThreadedOpr */
- std::shared_ptr<std::exception_ptr> opr_exception;
+ /*!
+ * \brief exception_ptr associated with the ThreadedOpr
+ * cannot modify state of exception object since dereferencing
+ * exception_ptr is undefined behavior. Using shared_ptr to hold
+ * exception_ptr and overcome this limitation */
+ ExceptionRef opr_exception;
}; // struct ThreadedOpr
/*!
@@ -432,6 +443,7 @@ class ThreadedEngine : public Engine {
};
/*! thread local store for bulk */
typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
+
/*!
* \brief check if thee is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
@@ -460,6 +472,7 @@ class ThreadedEngine : public Engine {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
+ AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
@@ -467,6 +480,7 @@ class ThreadedEngine : public Engine {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
+ AddToGlobalExceptions(threaded_opr->opr_exception);
break;
}
}
@@ -475,6 +489,18 @@ class ThreadedEngine : public Engine {
static void OnCompleteStatic(Engine *engine, void *threaded_opr,
const dmlc::Error* error);
+ /*!
+ * \brief find exception in global_exception_refs and add it if missing
+ * \param opr_exception the exception to be added to global_exception_refs
+ */
+ inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) {
+ auto it = std::find(global_exception_refs_.begin(),
+ global_exception_refs_.end(), opr_exception);
+ if (it == global_exception_refs_.end()) {
+ global_exception_refs_.push_back(opr_exception);
+ }
+ return;
+ }
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
@@ -542,6 +568,8 @@ class ThreadedEngine : public Engine {
*/
std::mutex finished_m_;
std::condition_variable finished_cv_;
+ /*! \brief global exception refs, which are rethrown when WaitForAll is called */
+ std::vector<ExceptionRef> global_exception_refs_;
/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
diff --git a/src/resource.cc b/src/resource.cc
index de24286..cd6320d 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -189,12 +189,14 @@ class ResourceManagerImpl : public ResourceManager {
cpu_rand_->Seed(seed);
cpu_parallel_rand_->Seed(seed);
#if MXNET_USE_CUDA
- gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
- return new ResourceRandom<gpu>(ctx, seed);
- })->Seed(seed);
- gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
- return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
- })->Seed(seed);
+ if (ctx.dev_type == Context::kGPU) {
+ gpu_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
+ return new ResourceRandom<gpu>(ctx, seed);
+ })->Seed(seed);
+ gpu_parallel_rand_.Get(ctx.dev_id, [ctx, seed, this]() {
+ return new ResourceParallelRandom<gpu>(ctx, gpu_native_rand_copy_, seed);
+ })->Seed(seed);
+ }
#endif
}
diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py
index e9e161d..60799f8 100644
--- a/tests/python/unittest/test_exc_handling.py
+++ b/tests/python/unittest/test_exc_handling.py
@@ -34,11 +34,11 @@ def test_exc_imperative():
c.asnumpy()
imperative(exec_numpy=False)
- assert_raises(MXNetError, imperative, True)
+ assert_raises(MXNetError, imperative, exec_numpy=True)
@with_seed()
def test_exc_symbolic():
- def symbolic(exec_backward=True):
+ def symbolic(exec_backward=True, waitall=True):
x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.Variable('z')
@@ -58,16 +58,25 @@ def test_exc_symbolic():
outputs = exec1.forward()
if exec_backward:
exec1.backward()
- exec1.grad_arrays[0].asnumpy()
+ if waitall:
+ mx.nd.waitall()
+ else:
+ exec1.grad_arrays[0].asnumpy()
else:
- outputs[0].asnumpy()
+ if waitall:
+ mx.nd.waitall()
+ else:
+ outputs[0].asnumpy()
- assert_raises(MXNetError, symbolic, False)
- assert_raises(MXNetError, symbolic, True)
+ assert_raises(MXNetError, symbolic, exec_backward=False)
+ assert_raises(MXNetError, symbolic, exec_backward=True)
+
+ assert_raises(MXNetError, symbolic, exec_backward=False, waitall=True)
+ assert_raises(MXNetError, symbolic, exec_backward=True, waitall=True)
@with_seed()
def test_exc_gluon():
- def gluon(exec_wait=True):
+ def gluon(exec_wait=True, waitall=False):
model = nn.Sequential()
model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
model.add(nn.Dropout(1))
@@ -77,46 +86,86 @@ def test_exc_gluon():
y = model(x)
model.collect_params().initialize(ctx=[default_context()])
z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
- if exec_wait:
+ if waitall:
+ mx.nd.waitall()
+ elif exec_wait:
z.wait_to_read()
gluon(exec_wait=False)
- assert_raises(MXNetError, gluon, True)
+ assert_raises(MXNetError, gluon, exec_wait=True)
+
+ assert_raises(MXNetError, gluon, waitall=True)
@with_seed()
def test_exc_multiple_waits():
- caught = False
- try:
- a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
- a.wait_to_read()
- except MXNetError:
- caught = True
- assert caught, "No exception thrown"
- try:
- b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
- b.wait_to_read()
- except MXNetError:
- caught = True
- assert caught, "No exception thrown"
+ def multiple_waits(waitall=False):
+ # Test calling failed op followed by wait_to_read or waitall twice
+ # Intention is to test rethrow for multiple wait_to_reads and waitalls
+ # for vars with exceptions in same scope
+ caught = False
+ try:
+ a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+ if waitall:
+ mx.nd.waitall()
+ else:
+ a.wait_to_read()
+ except MXNetError:
+ caught = True
+ assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"
+ try:
+ b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+ if waitall:
+ mx.nd.waitall()
+ else:
+ b.wait_to_read()
+ except MXNetError:
+ caught = True
+ assert caught, "No exception thrown, exception should be rethrown with wait_to_read/waitall"
+
+ multiple_waits(waitall=False)
+ multiple_waits(waitall=True)
@with_seed()
def test_exc_post_fail():
+ def post_fail(waitall=False):
+ caught = False
+ try:
+ a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+ if waitall:
+ mx.nd.waitall()
+ else:
+ a.asnumpy()
+ except MXNetError:
+ caught = True
+ assert caught, "No exception thrown"
+ b.asnumpy()
+ post_fail(waitall=False)
+ post_fail(waitall=True)
+
+@with_seed()
+def test_exc_mutable_var_fail():
+ def mutable_var_check(waitall=False):
+ a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+ a = mx.nd.dot(a, a)
+ if waitall:
+ mx.nd.waitall()
+ else:
+ a.asnumpy()
+ assert_raises(MXNetError, mutable_var_check, waitall=False)
+ assert_raises(MXNetError, mutable_var_check, waitall=True)
+
+@with_seed()
+def test_multiple_waitalls():
caught = False
try:
- a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
- a.asnumpy()
+ a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+ mx.nd.waitall()
except MXNetError:
caught = True
assert caught, "No exception thrown"
- b.asnumpy()
+ mx.nd.waitall()
+
-@with_seed()
-def test_exc_mutable_var_fail():
- def mutable_var_check():
- a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
- a = mx.nd.dot(a, a)
- a.asnumpy()
- assert_raises(MXNetError, mutable_var_check)
if __name__ == '__main__':
import nose
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f96a6ae..17618e4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7164,7 +7164,12 @@ def test_op_output_names_monitor():
op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=False)
- op_exe.forward()
+ try:
+ op_exe.forward()
+ mx.nd.waitall()
+ except mx.base.MXNetError:
+ # skip errors since test is to check output names
+ pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name
@@ -7210,7 +7215,12 @@ def test_op_all_names_monitor():
op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
op_exe.set_monitor_callback(get_output_names_callback, monitor_all=True)
- op_exe.forward()
+ try:
+ op_exe.forward()
+ mx.nd.waitall()
+ except mx.base.MXNetError:
+ # skip errors since test is to check all names
+ pass
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name