This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 7b24137 Better Exception Handling for Operators (#9681) 7b24137 is described below commit 7b24137ed45df605defa4ce72ec91554f6e445f0 Author: Anirudh Subramanian <anirudh2...@gmail.com> AuthorDate: Tue Feb 13 11:13:04 2018 -0800 Better Exception Handling for Operators (#9681) * Add support for threaded engine * Add support for threaded engine * Remove on_start_callback for else * Add support for global_ex_ptr * Rethrow in waitall only once * run tests for gpu * Add comments for exception_ptr * Fix lint * Push exc_handling tests * Add comments for OnStart * Fixes for exc handling * Catch std::exception for all other exceptions * Rollback std::move use * Fix style * Fix onstart * Fix debug_info * Throw exception only once in an execution graph * make test naming consistent * Fix symbolic test * Remove unused code --- include/mxnet/engine.h | 8 +- src/engine/naive_engine.cc | 6 +- src/engine/threaded_engine.cc | 39 +++++++--- src/engine/threaded_engine.h | 94 +++++++++++++++++------ src/storage/cpu_device_storage.h | 4 +- src/storage/gpu_device_storage.h | 2 +- tests/python/gpu/test_operator_gpu.py | 1 + tests/python/unittest/test_exc_handling.py | 116 +++++++++++++++++++++++++++++ 8 files changed, 231 insertions(+), 39 deletions(-) diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 366a6b6..fd1fe89 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -141,13 +141,15 @@ class MXNET_API Engine { * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \param opr_name The operator name. + * \param wait Whether this is a WaitForVar operation * \return The new operator allocated. */ virtual OprHandle NewOperator(AsyncFn fn, std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, - const char* opr_name = nullptr) = 0; + const char* opr_name = nullptr, + bool wait = false) = 0; /*! * \brief Delete the given operator. * \param op The operator to delete. @@ -176,13 +178,15 @@ class MXNET_API Engine { * \param prop Property of the function. * \param priority Priority of the action, as hint to the engine. * \param opr_name The operator name. + * \param wait Whether this is a WaitForVar operation */ virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, int priority = 0, - const char* opr_name = nullptr) = 0; + const char* opr_name = nullptr, + bool wait = false) = 0; /*! * \brief Schedule the deletion of a variable. * diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 86f3877..adf14b1 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -73,7 +73,8 @@ class NaiveEngine final : public Engine { std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, - const char* opr_name = nullptr) override { + const char* opr_name = nullptr, + bool wait = false) override { NaiveOpr *opr = new NaiveOpr(); opr->fn = fn; opr->const_vars = const_vars; @@ -125,7 +126,8 @@ class NaiveEngine final : public Engine { std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, int priority = 0, - const char* opr_name = nullptr) override { + const char* opr_name = nullptr, + bool wait = false) override { CallbackOnComplete callback = CreateCallback( NaiveEngine::OnComplete, nullptr); this->req_completed_ = false; diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 2b28a7d..e166a6d 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator( std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, FnProperty prop, - const char* opr_name) { + const char* opr_name, + bool wait) { auto ret = ThreadedOpr::New(); ret->opr_name = opr_name; ret->fn = std::move(fn); ret->prop = prop; ret->const_vars.resize(const_vars.size()); ret->mutable_vars.resize(mutable_vars.size()); + ret->wait = wait; std::transform(const_vars.begin(), const_vars.end(), ret->const_vars.begin(), ThreadedVar::CastFromBase); std::transform(mutable_vars.begin(), mutable_vars.end(), @@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, std::vector<VarHandle> const& mutable_vars, FnProperty prop, int priority, - const char* opr_name) { + const char* opr_name, + bool wait) { BulkFlush(); - ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name); + ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait); opr->temporary = true; #if MXNET_USE_PROFILER Profiler *profiler = Profiler::Get(); @@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn, void ThreadedEngine::WaitForVar(VarHandle var) { BulkFlush(); ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); - if (threaded_var->ready_to_read()) return; + if (threaded_var->ready_to_read()) { + ThrowException(threaded_var); + return; + } if (engine_info_) { LOG(INFO) << "Wait for " << threaded_var; debug_wait_var_ = threaded_var; @@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) { } on_complete(); }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0, - PROFILER_MESSAGE("WaitForVar")); + PROFILER_MESSAGE("WaitForVar"), true); { std::unique_lock<std::mutex> lock{finished_m_}; finished_cv_.wait(lock, [this, &done]() { return done.load() || kill_.load(); }); } + + ThrowException(threaded_var); } void ThreadedEngine::WaitForAll() { @@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { bool is_temporary_opr = threaded_opr->temporary; // Mark complete for read variables for (auto&& i : threaded_opr->const_vars) { - i->CompleteReadDependency([this](OprBlock* opr) { - this->PushToExecute(opr, false); - }); + i->CompleteReadDependency( + [this](OprBlock* opr) { this->PushToExecute(opr, false); }); } // Mark complete for write variables. for (auto&& i : threaded_opr->mutable_vars) { + if (threaded_opr->opr_exception && *threaded_opr->opr_exception) { + i->var_exception = threaded_opr->opr_exception; + } const bool debug_info = (engine_info_ && debug_wait_var_ == i); if (debug_info) { LOG(INFO) << "Complete write dep for " << i; } - const bool to_delete = i->CompleteWriteDependency( - [this, debug_info](OprBlock* opr) { + const bool to_delete = + i->CompleteWriteDependency([this, debug_info](OprBlock* opr) { if (debug_info) { LOG(INFO) << "PushToExecute " << opr; debug_push_opr_ = opr; @@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { } } +inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) { + if (threaded_var->var_exception && *threaded_var->var_exception) { + std::exception_ptr tmp = *threaded_var->var_exception; + *threaded_var->var_exception = nullptr; + std::rethrow_exception(tmp); + } + return; +} + void ThreadedEngine::OnCompleteStatic( Engine *engine, void *opr_block_) { OprBlock *opr_block = static_cast<OprBlock*>(opr_block_); diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 1524f25..f647074 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -175,6 +175,8 @@ class ThreadedVar final static std::atomic<std::size_t> counter; ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } #endif // ENGINE_DEBUG + /*! \brief exception_ptr associated with the ThreadedVar */ + std::shared_ptr<std::exception_ptr> var_exception; private: // TODO(hotpxl) change this to spinlock for faster runtime @@ -237,6 +239,10 @@ struct ThreadedOpr final : public Opr, */ bool temporary{false}; /*! + * \brief Whether this is a WaitForVar operation + */ + bool wait{false}; + /*! * \brief Cast a Opr pointer to ThreadedOpr pointer * \param ptr pointer from base. * \return a casted pointer. @@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr, } // define possible debug information DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); + /*! \brief exception_ptr associated with the ThreadedOpr */ + std::shared_ptr<std::exception_ptr> opr_exception; }; // struct ThreadedOpr /*! @@ -265,7 +273,8 @@ class ThreadedEngine : public Engine { std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, - const char* opr_name = nullptr) override; + const char* opr_name = nullptr, + bool wait = false) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, @@ -273,7 +282,8 @@ class ThreadedEngine : public Engine { std::vector<VarHandle> const& mutable_vars, FnProperty prop = FnProperty::kNormal, int priority = 0, - const char* opr_name = nullptr) override; + const char* opr_name = nullptr, + bool wait = false) override; void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector<VarHandle> const& const_vars, std::vector<VarHandle> const& mutable_vars, @@ -321,50 +331,63 @@ class ThreadedEngine : public Engine { * \param run_ctx runtime context used to execute the function. * \param opr_block the opr_block to be executed and deleted. */ - void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) { + void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) { ThreadedOpr* threaded_opr = opr_block->opr; #if MXNET_USE_PROFILER if (opr_block->profiling && threaded_opr->opr_name) { const Context& ctx = opr_block->ctx; - opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id); + opr_block->opr_stat = + Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id); uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id()); opr_block->opr_stat->thread_id = id; - strncpy(opr_block->opr_stat->opr_name, - threaded_opr->opr_name, - sizeof(opr_block->opr_stat->opr_name) - 1); + strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name, + sizeof(opr_block->opr_stat->opr_name) - 1); // record operator start timestamp SetOprStart(opr_block->opr_stat); } #endif - CallbackOnComplete callback = this->CreateCallback( - ThreadedEngine::OnCompleteStatic, opr_block); - bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); + CallbackOnComplete callback = + this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); + const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); if (debug_info) { LOG(INFO) << "ExecuteOprBlock " << opr_block << "shutdown_phase=" << shutdown_phase_; } if (!shutdown_phase_) { try { + OnStart(threaded_opr); if (debug_info) { LOG(INFO) << "ExecuteOprFn "; } - threaded_opr->fn(run_ctx, callback); + try { + if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || + threaded_opr->wait) { + threaded_opr->fn(run_ctx, callback); + } else { + callback(); + } + } catch (dmlc::Error& e) { + threaded_opr->opr_exception = + std::make_shared<std::exception_ptr>(std::current_exception()); + callback(); + } if (debug_info) { LOG(INFO) << "Fin ExecuteOprFn "; } - } catch(dmlc::Error &e) { + } catch (std::exception& e) { std::string what = e.what(); if (what.find("driver shutting down") == std::string::npos && !shutdown_phase_) { - LOG(FATAL) << e.what() << "\n" << - "A fatal error occurred in asynchronous engine operation. " - "If you do not know what caused this error, " - "you can try set environment variable MXNET_ENGINE_TYPE " - "to NaiveEngine and run with debugger (i.e. gdb). " - "This will force all operations to be synchronous and " - "backtrace will give you the series of calls that lead " - "to this error. Remember to set MXNET_ENGINE_TYPE back to " - "empty after debugging."; + LOG(FATAL) + << e.what() << "\n" + << "A fatal error occurred in asynchronous engine operation. " + "If you do not know what caused this error, " + "you can try set environment variable MXNET_ENGINE_TYPE " + "to NaiveEngine and run with debugger (i.e. gdb). " + "This will force all operations to be synchronous and " + "backtrace will give you the series of calls that lead " + "to this error. Remember to set MXNET_ENGINE_TYPE back to " + "empty after debugging."; } } } else { @@ -414,7 +437,34 @@ class ThreadedEngine : public Engine { * On operation completion, this will trigger subsequent operations. */ inline void OnComplete(ThreadedOpr* threaded_opr); - // callback to the threaded engine + /*! + * \brief rethrow caught exception in WaitForVar + * \param threaded_var the var that we are waiting to read + */ + inline void ThrowException(ThreadedVar* threaded_var); + /*! + * \brief Mark exceptions before operation execution. + * + * Will mark the operator as a failure and associate exception_ptr + * if any of the read dependencies have exception associated. + */ + inline void OnStart(ThreadedOpr* threaded_opr) { + for (auto&& i : threaded_opr->const_vars) { + if (i->var_exception && *i->var_exception) { + threaded_opr->opr_exception = i->var_exception; + break; + } + } + if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) { + for (auto&& i : threaded_opr->mutable_vars) { + if (i->var_exception && *i->var_exception) { + threaded_opr->opr_exception = i->var_exception; + break; + } + } + } + } + static void OnCompleteStatic(Engine *engine, void *threaded_opr); /*! \brief append an operator to bulk */ inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h index f0dd61f..beb7389 100644 --- a/src/storage/cpu_device_storage.h +++ b/src/storage/cpu_device_storage.h @@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) { void* ptr; #if _MSC_VER ptr = _aligned_malloc(size, alignment_); - if (ptr == NULL) throw std::bad_alloc(); + if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory"; #else int ret = posix_memalign(&ptr, alignment_, size); - if (ret != 0) throw std::bad_alloc(); + if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory"; #endif return ptr; } diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h index c859892..435c7e8 100644 --- a/src/storage/gpu_device_storage.h +++ b/src/storage/gpu_device_storage.h @@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) { #endif // MXNET_USE_NCCL cudaError_t e = cudaMalloc(&ret, size); if (e != cudaSuccess && e != cudaErrorCudartUnloading) - throw std::bad_alloc(); + LOG(FATAL) << "CUDA: " << cudaGetErrorString(e); #else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f9f00c4..91eb958 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -34,6 +34,7 @@ from test_optimizer import * from test_random import * from test_gluon import * from test_loss import * +from test_exc_handling import * #from test_rnn import * from test_gluon_rnn import * from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py new file mode 100644 index 0000000..0c2c43d --- /dev/null +++ b/tests/python/unittest/test_exc_handling.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +import numpy as np +from mxnet import gluon +from mxnet.gluon import nn +from mxnet.base import MXNetError +from mxnet.test_utils import assert_exception, default_context, set_default_context +from nose.tools import assert_raises + +def test_exc_imperative(): + def imperative(exec_numpy=True): + a = mx.nd.random.normal(0, 1, (2, 2)) + b = mx.nd.random.normal(0, -1, (2, 2)) + c = mx.nd.dot(a, b) + if exec_numpy: + c.asnumpy() + + imperative(exec_numpy=False) + assert_raises(MXNetError, imperative, True) + +def test_exc_symbolic(): + def symbolic(exec_backward=True): + x = mx.sym.Variable('x') + y = mx.sym.Variable('y') + z = mx.sym.Variable('z') + x_shape = (2, 2) + z_shape = (3, 2) + inputs = [x, y] + out = mx.symbol.ElementWiseSum(*inputs, name="esum") + out = mx.sym.dot(z, out) + out2 = mx.sym.random.normal(0, -1, x_shape, ctx=default_context()) + out = mx.sym.dot(out, out2) + out = mx.sym.make_loss(out) + arr = {'x': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()), + 'y': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()), + 'z': mx.nd.random.normal(0, 1, z_shape, ctx=default_context())} + arr_grad = {'x': mx.nd.empty(x_shape), 'y': mx.nd.empty(x_shape), 'z': mx.nd.empty(z_shape)} + exec1 = out.bind(ctx=default_context(), args=arr, args_grad=arr_grad) + outputs = exec1.forward() + if exec_backward: + exec1.backward() + exec1.grad_arrays[0].asnumpy() + else: + outputs[0].asnumpy() + + assert_raises(MXNetError, symbolic, False) + assert_raises(MXNetError, symbolic, True) + +def test_exc_gluon(): + def gluon(exec_wait=True): + model = nn.Sequential() + model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False)) + model.add(nn.Dropout(1)) + model.add(nn.Dense(64, activation='tanh', in_units=256), + nn.Dense(32, in_units=64)) + x = mx.sym.var('data') + y = model(x) + model.collect_params().initialize(ctx=[default_context()]) + z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context())) + if exec_wait: + z.wait_to_read() + + gluon(exec_wait=False) + assert_raises(MXNetError, gluon, True) + +def test_exc_multiple_waits(): + caught = False + try: + a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + a.wait_to_read() + except MXNetError: + caught = True + assert caught, "No exception thrown" + try: + b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context()) + b.wait_to_read() + except MXNetError: + caught = True + assert caught, "No exception thrown" + +def test_exc_post_fail(): + caught = False + try: + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + a.asnumpy() + except MXNetError: + caught = True + assert caught, "No exception thrown" + b.asnumpy() + +def test_exc_mutable_var_fail(): + def mutable_var_check(): + a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context()) + a = mx.nd.dot(a, a) + a.asnumpy() + assert_raises(MXNetError, mutable_var_check) + +if __name__ == '__main__': + import nose + nose.runmodule() -- To stop receiving notification emails like this one, please contact j...@apache.org.