piiswrong closed pull request #9681: Better Exception Handling for Operators
URL: https://github.com/apache/incubator-mxnet/pull/9681
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

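For context, the user-visible effect of this change: a failing operator no
longer aborts the process; the error is captured in the engine and rethrown
at the next synchronization point. A minimal sketch of the behavior,
condensed from the tests added at the end of this diff (hypothetical
session; assumes a build containing this PR):

    import mxnet as mx
    from mxnet.base import MXNetError

    a = mx.nd.random.normal(0, -1, (2, 2))  # negative scale: the op fails
    # execution is asynchronous, so no exception has surfaced yet
    try:
        a.asnumpy()  # synchronization point: stored exception is rethrown
    except MXNetError as e:
        print('caught:', e)
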
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 366a6b61b3..fd1fe89bdb 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -141,13 +141,15 @@ class MXNET_API Engine {
    * \param mutable_vars The variables that current operation will mutate.
    * \param prop Property of the function.
    * \param opr_name The operator name.
+   * \param wait Whether this is a WaitForVar operation
    * \return The new operator allocated.
    */
   virtual OprHandle NewOperator(AsyncFn fn,
                                 std::vector<VarHandle> const& const_vars,
                                 std::vector<VarHandle> const& mutable_vars,
                                 FnProperty prop = FnProperty::kNormal,
-                                const char* opr_name = nullptr) = 0;
+                                const char* opr_name = nullptr,
+                                bool wait = false) = 0;
   /*!
    * \brief Delete the given operator.
    * \param op The operator to delete.
@@ -176,13 +178,15 @@ class MXNET_API Engine {
    * \param prop Property of the function.
    * \param priority Priority of the action, as hint to the engine.
    * \param opr_name The operator name.
+   * \param wait Whether this is a WaitForVar operation
    */
   virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                          std::vector<VarHandle> const& const_vars,
                          std::vector<VarHandle> const& mutable_vars,
                          FnProperty prop = FnProperty::kNormal,
                          int priority = 0,
-                         const char* opr_name = nullptr) = 0;
+                         const char* opr_name = nullptr,
+                         bool wait = false) = 0;
   /*!
    * \brief Schedule the deletion of a variable.
    *
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 86f3877397..adf14b131b 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -73,7 +73,8 @@ class NaiveEngine final : public Engine {
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
-                        const char* opr_name = nullptr) override {
+                        const char* opr_name = nullptr,
+                        bool wait = false) override {
     NaiveOpr *opr = new NaiveOpr();
     opr->fn = fn;
     opr->const_vars = const_vars;
@@ -125,7 +126,8 @@ class NaiveEngine final : public Engine {
                  std::vector<VarHandle> const& mutable_vars,
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
-                 const char* opr_name = nullptr) override {
+                 const char* opr_name = nullptr,
+                 bool wait = false) override {
     CallbackOnComplete callback = CreateCallback(
         NaiveEngine::OnComplete, nullptr);
     this->req_completed_ = false;
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 2b28a7d602..e166a6dc90 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator(
     std::vector<VarHandle> const& const_vars,
     std::vector<VarHandle> const& mutable_vars,
     FnProperty prop,
-    const char* opr_name) {
+    const char* opr_name,
+    bool wait) {
   auto ret = ThreadedOpr::New();
   ret->opr_name = opr_name;
   ret->fn = std::move(fn);
   ret->prop = prop;
   ret->const_vars.resize(const_vars.size());
   ret->mutable_vars.resize(mutable_vars.size());
+  ret->wait = wait;
   std::transform(const_vars.begin(), const_vars.end(),
                  ret->const_vars.begin(), ThreadedVar::CastFromBase);
   std::transform(mutable_vars.begin(), mutable_vars.end(),
@@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop,
                                int priority,
-                               const char* opr_name) {
+                               const char* opr_name,
+                               bool wait) {
   BulkFlush();
-  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
+  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
   opr->temporary = true;
 #if MXNET_USE_PROFILER
   Profiler *profiler = Profiler::Get();
@@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
 void ThreadedEngine::WaitForVar(VarHandle var) {
   BulkFlush();
   ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
-  if (threaded_var->ready_to_read()) return;
+  if (threaded_var->ready_to_read()) {
+    ThrowException(threaded_var);
+    return;
+  }
   if (engine_info_) {
     LOG(INFO) << "Wait for " << threaded_var;
     debug_wait_var_ = threaded_var;
@@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
       }
       on_complete();
     }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
-    PROFILER_MESSAGE("WaitForVar"));
+    PROFILER_MESSAGE("WaitForVar"), true);
   {
     std::unique_lock<std::mutex> lock{finished_m_};
     finished_cv_.wait(lock, [this, &done]() {
         return done.load() || kill_.load();
     });
   }
+
+  ThrowException(threaded_var);
 }
 
 void ThreadedEngine::WaitForAll() {
@@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
   bool is_temporary_opr = threaded_opr->temporary;
   // Mark complete for read variables
   for (auto&& i : threaded_opr->const_vars) {
-    i->CompleteReadDependency([this](OprBlock* opr) {
-        this->PushToExecute(opr, false);
-      });
+    i->CompleteReadDependency(
+        [this](OprBlock* opr) { this->PushToExecute(opr, false); });
   }
   // Mark complete for write variables.
   for (auto&& i : threaded_opr->mutable_vars) {
+    if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
+      i->var_exception = threaded_opr->opr_exception;
+    }
     const bool debug_info = (engine_info_ && debug_wait_var_ == i);
     if (debug_info) {
       LOG(INFO) << "Complete write dep for " << i;
     }
-    const bool to_delete = i->CompleteWriteDependency(
-        [this, debug_info](OprBlock* opr) {
+    const bool to_delete =
+        i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
           if (debug_info) {
             LOG(INFO) << "PushToExecute " << opr;
             debug_push_opr_ = opr;
@@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
   }
 }
 
+inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
+  if (threaded_var->var_exception && *threaded_var->var_exception) {
+    std::exception_ptr tmp = *threaded_var->var_exception;
+    *threaded_var->var_exception = nullptr;
+    std::rethrow_exception(tmp);
+  }
+  return;
+}
+
 void ThreadedEngine::OnCompleteStatic(
     Engine *engine, void *opr_block_) {
   OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
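
Note on ThrowException above: it clears the stored exception_ptr before
rethrowing, so a given failure surfaces at most once per variable. A small
illustrative sketch of those semantics (assumes a build with this change):

    import mxnet as mx
    from mxnet.base import MXNetError

    a = mx.nd.random.normal(0, -1, (2, 2))  # fails asynchronously
    try:
        a.wait_to_read()   # first wait rethrows the stored exception
    except MXNetError:
        pass
    a.wait_to_read()       # exception already consumed; returns quietly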
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 1524f25756..f647074cc3 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -175,6 +175,8 @@ class ThreadedVar final
   static std::atomic<std::size_t> counter;
   ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
 #endif  // ENGINE_DEBUG
+  /*! \brief exception_ptr associated with the ThreadedVar */
+  std::shared_ptr<std::exception_ptr> var_exception;
 
  private:
   // TODO(hotpxl) change this to spinlock for faster runtime
@@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr,
    *        that can be deleted right after the operation completed.
    */
   bool temporary{false};
+  /*!
+   * \brief Whether this is a WaitForVar operation
+   */
+  bool wait{false};
   /*!
    * \brief Cast a Opr pointer to ThreadedOpr pointer
    * \param ptr pointer from base.
@@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
   }
   // define possible debug information
   DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
+  /*! \brief exception_ptr associated with the ThreadedOpr */
+  std::shared_ptr<std::exception_ptr> opr_exception;
 };  // struct ThreadedOpr
 
 /*!
@@ -265,7 +273,8 @@ class ThreadedEngine : public Engine {
                            std::vector<VarHandle> const& const_vars,
                            std::vector<VarHandle> const& mutable_vars,
                            FnProperty prop = FnProperty::kNormal,
-                           const char* opr_name = nullptr) override;
+                           const char* opr_name = nullptr,
+                           bool wait = false) override;
   void DeleteOperator(OprHandle op) override;
   void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
   void PushAsync(AsyncFn exec_fun, Context exec_ctx,
@@ -273,7 +282,8 @@ class ThreadedEngine : public Engine {
                  std::vector<VarHandle> const& mutable_vars,
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
-                 const char* opr_name = nullptr) override;
+                 const char* opr_name = nullptr,
+                 bool wait = false) override;
   void PushSync(SyncFn exec_fn, Context exec_ctx,
                 std::vector<VarHandle> const& const_vars,
                 std::vector<VarHandle> const& mutable_vars,
@@ -321,50 +331,63 @@ class ThreadedEngine : public Engine {
    * \param run_ctx runtime context used to execute the function.
    * \param opr_block the opr_block to be executed and deleted.
    */
-  void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
+  void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
     ThreadedOpr* threaded_opr = opr_block->opr;
 #if MXNET_USE_PROFILER
     if (opr_block->profiling && threaded_opr->opr_name) {
       const Context& ctx = opr_block->ctx;
-      opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
+      opr_block->opr_stat =
+          Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
       uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
       opr_block->opr_stat->thread_id = id;
-      strncpy(opr_block->opr_stat->opr_name,
-        threaded_opr->opr_name,
-        sizeof(opr_block->opr_stat->opr_name) - 1);
+      strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name,
+              sizeof(opr_block->opr_stat->opr_name) - 1);
       // record operator start timestamp
       SetOprStart(opr_block->opr_stat);
     }
 #endif
-    CallbackOnComplete callback = this->CreateCallback(
-        ThreadedEngine::OnCompleteStatic, opr_block);
-    bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
+    CallbackOnComplete callback =
+        this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+    const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
     if (debug_info) {
       LOG(INFO) << "ExecuteOprBlock " << opr_block
                 << "shutdown_phase=" << shutdown_phase_;
     }
     if (!shutdown_phase_) {
       try {
+        OnStart(threaded_opr);
         if (debug_info) {
           LOG(INFO) << "ExecuteOprFn ";
         }
-        threaded_opr->fn(run_ctx, callback);
+        try {
+          if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
+              threaded_opr->wait) {
+            threaded_opr->fn(run_ctx, callback);
+          } else {
+            callback();
+          }
+        } catch (dmlc::Error& e) {
+          threaded_opr->opr_exception =
+              std::make_shared<std::exception_ptr>(std::current_exception());
+          callback();
+        }
         if (debug_info) {
           LOG(INFO) << "Fin ExecuteOprFn ";
         }
-      } catch(dmlc::Error &e) {
+      } catch (std::exception& e) {
         std::string what = e.what();
         if (what.find("driver shutting down") == std::string::npos &&
             !shutdown_phase_) {
-          LOG(FATAL) << e.what() << "\n" <<
-            "A fatal error occurred in asynchronous engine operation. "
-            "If you do not know what caused this error, "
-            "you can try set environment variable MXNET_ENGINE_TYPE "
-            "to NaiveEngine and run with debugger (i.e. gdb). "
-            "This will force all operations to be synchronous and "
-            "backtrace will give you the series of calls that lead "
-            "to this error. Remember to set MXNET_ENGINE_TYPE back to "
-            "empty after debugging.";
+          LOG(FATAL)
+              << e.what() << "\n"
+              << "A fatal error occurred in asynchronous engine operation. "
+                 "If you do not know what caused this error, "
+                 "you can try set environment variable MXNET_ENGINE_TYPE "
+                 "to NaiveEngine and run with debugger (i.e. gdb). "
+                 "This will force all operations to be synchronous and "
+                 "backtrace will give you the series of calls that lead "
+                 "to this error. Remember to set MXNET_ENGINE_TYPE back to "
+                 "empty after debugging.";
         }
       }
     } else {
@@ -414,7 +437,34 @@ class ThreadedEngine : public Engine {
    * On operation completion, this will trigger subsequent operations.
    */
   inline void OnComplete(ThreadedOpr* threaded_opr);
-  // callback to the threaded engine
+  /*!
+   * \brief rethrow caught exception in WaitForVar
+   * \param threaded_var the var that we are waiting to read
+   */
+  inline void ThrowException(ThreadedVar* threaded_var);
+  /*!
+   * \brief Mark exceptions before operation execution.
+   *
+   * Will mark the operator as a failure and associate exception_ptr
+   * if any of the read dependencies have exception associated.
+   */
+  inline void OnStart(ThreadedOpr* threaded_opr) {
+    for (auto&& i : threaded_opr->const_vars) {
+      if (i->var_exception && *i->var_exception) {
+        threaded_opr->opr_exception = i->var_exception;
+        break;
+      }
+    }
+    if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
+      for (auto&& i : threaded_opr->mutable_vars) {
+        if (i->var_exception && *i->var_exception) {
+          threaded_opr->opr_exception = i->var_exception;
+          break;
+        }
+      }
+    }
+  }
+
   static void OnCompleteStatic(Engine *engine, void *threaded_opr);
   /*! \brief append an operator to bulk */
   inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
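
OnStart above is what carries a failure across the dependency graph: an
operator whose input or output variable already holds an exception inherits
it and skips its function body, and OnComplete then copies it onto the
operator's mutable variables. Sketched in terms of the new tests (assumes a
build with this change):

    import mxnet as mx
    from mxnet.base import MXNetError

    a = mx.nd.random.normal(0, -1, (2, 2))   # fails asynchronously
    b = mx.nd.dot(a, a)                      # scheduled, but skipped once the
                                             # exception is seen in OnStart
    try:
        b.asnumpy()                          # upstream failure surfaces here
    except MXNetError as e:
        print('propagated:', e)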
diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h
index f0dd61f01a..beb73894bd 100644
--- a/src/storage/cpu_device_storage.h
+++ b/src/storage/cpu_device_storage.h
@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
   void* ptr;
 #if _MSC_VER
   ptr = _aligned_malloc(size, alignment_);
-  if (ptr == NULL) throw std::bad_alloc();
+  if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
 #else
   int ret = posix_memalign(&ptr, alignment_, size);
-  if (ret != 0) throw std::bad_alloc();
+  if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
 #endif
   return ptr;
 }
diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h
index c859892233..435c7e81d2 100644
--- a/src/storage/gpu_device_storage.h
+++ b/src/storage/gpu_device_storage.h
@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
 #endif  // MXNET_USE_NCCL
   cudaError_t e = cudaMalloc(&ret, size);
   if (e != cudaSuccess && e != cudaErrorCudartUnloading)
-    throw std::bad_alloc();
+    LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
 #else   // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
 #endif  // MXNET_USE_CUDA
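
The storage changes swap throw std::bad_alloc for LOG(FATAL), which in
MXNet's dmlc-core configuration throws a dmlc::Error; that is what lets the
engine's new catch block capture allocation failures like any other operator
error. A hedged sketch of the effect (deliberately over-allocates; assumes a
GPU build with this change):

    import mxnet as mx
    from mxnet.base import MXNetError

    try:
        huge = mx.nd.zeros((2**40,), ctx=mx.gpu(0))  # far beyond device memory
        huge.wait_to_read()
    except MXNetError as e:
        print('allocation failed:', e)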
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 55bb30cc7d..83f6193fdc 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -34,6 +34,7 @@
 from test_random import *
 from test_gluon import *
 from test_loss import *
+from test_exc_handling import *
 #from test_rnn import *
 from test_gluon_rnn import *
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py
new file mode 100644
index 0000000000..0c2c43db65
--- /dev/null
+++ b/tests/python/unittest/test_exc_handling.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.base import MXNetError
+from mxnet.test_utils import assert_exception, default_context, set_default_context
+from nose.tools import assert_raises
+
+def test_exc_imperative():
+    def imperative(exec_numpy=True):
+        a = mx.nd.random.normal(0, 1, (2, 2))
+        b = mx.nd.random.normal(0, -1, (2, 2))
+        c = mx.nd.dot(a, b)
+        if exec_numpy:
+            c.asnumpy()
+
+    imperative(exec_numpy=False)
+    assert_raises(MXNetError, imperative, True)
+
+def test_exc_symbolic():
+    def symbolic(exec_backward=True):
+        x = mx.sym.Variable('x')
+        y = mx.sym.Variable('y')
+        z = mx.sym.Variable('z')
+        x_shape = (2, 2)
+        z_shape = (3, 2)
+        inputs = [x, y]
+        out = mx.symbol.ElementWiseSum(*inputs, name="esum")
+        out = mx.sym.dot(z, out)
+        out2 = mx.sym.random.normal(0, -1, x_shape, ctx=default_context())
+        out = mx.sym.dot(out, out2)
+        out = mx.sym.make_loss(out)
+        arr = {'x': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()),
+               'y': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()),
+               'z': mx.nd.random.normal(0, 1, z_shape, ctx=default_context())}
+        arr_grad = {'x': mx.nd.empty(x_shape), 'y': mx.nd.empty(x_shape), 'z': mx.nd.empty(z_shape)}
+        exec1 = out.bind(ctx=default_context(), args=arr, args_grad=arr_grad)
+        outputs = exec1.forward()
+        if exec_backward:
+            exec1.backward()
+            exec1.grad_arrays[0].asnumpy()
+        else:
+            outputs[0].asnumpy()
+
+    assert_raises(MXNetError, symbolic, False)
+    assert_raises(MXNetError, symbolic, True)
+
+def test_exc_gluon():
+    def gluon(exec_wait=True):
+        model = nn.Sequential()
+        model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
+        model.add(nn.Dropout(1))
+        model.add(nn.Dense(64, activation='tanh', in_units=256),
+                  nn.Dense(32, in_units=64))
+        x = mx.sym.var('data')
+        y = model(x)
+        model.collect_params().initialize(ctx=[default_context()])
+        z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
+        if exec_wait:
+            z.wait_to_read()
+
+    gluon(exec_wait=False)
+    assert_raises(MXNetError, gluon, True)
+
+def test_exc_multiple_waits():
+    caught = False
+    try:
+        a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+        a.wait_to_read()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+    try:
+        b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+        b.wait_to_read()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+
+def test_exc_post_fail():
+    caught = False
+    try:
+        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+        a.asnumpy()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+    b.asnumpy()
+
+def test_exc_mutable_var_fail():
+    def mutable_var_check():
+        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+        a = mx.nd.dot(a, a)
+        a.asnumpy()
+    assert_raises(MXNetError, mutable_var_check)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
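
The new test file runs standalone through its __main__ block; a typical
invocation:

    python tests/python/unittest/test_exc_handling.py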


 
