This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b24137  Better Exception Handling for Operators (#9681)
7b24137 is described below

commit 7b24137ed45df605defa4ce72ec91554f6e445f0
Author: Anirudh Subramanian <anirudh2...@gmail.com>
AuthorDate: Tue Feb 13 11:13:04 2018 -0800

    Better Exception Handling for Operators (#9681)
    
    * Add support for threaded engine
    
    * Add support for threaded engine
    
    * Remove on_start_callback for else
    
    * Add support for global_ex_ptr
    
    * Rethrow in waitall only once
    
    * run tests for gpu
    
    * Add comments for exception_ptr
    
    * Fix lint
    
    * Push exc_handling tests
    
    * Add comments for OnStart
    
    * Fixes for exc handling
    
    * Catch std::exception for all other exceptions
    
    * Rollback std::move use
    
    * Fix style
    
    * Fix onstart
    
    * Fix debug_info
    
    * Throw exception only once in an execution graph
    
    * make test naming consistent
    
    * Fix symbolic test
    
    * Remove unused code
---
 include/mxnet/engine.h                     |   8 +-
 src/engine/naive_engine.cc                 |   6 +-
 src/engine/threaded_engine.cc              |  39 +++++++---
 src/engine/threaded_engine.h               |  94 +++++++++++++++++------
 src/storage/cpu_device_storage.h           |   4 +-
 src/storage/gpu_device_storage.h           |   2 +-
 tests/python/gpu/test_operator_gpu.py      |   1 +
 tests/python/unittest/test_exc_handling.py | 116 +++++++++++++++++++++++++++++
 8 files changed, 231 insertions(+), 39 deletions(-)

diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 366a6b6..fd1fe89 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -141,13 +141,15 @@ class MXNET_API Engine {
    * \param mutable_vars The variables that current operation will mutate.
    * \param prop Property of the function.
    * \param opr_name The operator name.
+   * \param wait Whether this is a WaitForVar operation
    * \return The new operator allocated.
    */
   virtual OprHandle NewOperator(AsyncFn fn,
                                 std::vector<VarHandle> const& const_vars,
                                 std::vector<VarHandle> const& mutable_vars,
                                 FnProperty prop = FnProperty::kNormal,
-                                const char* opr_name = nullptr) = 0;
+                                const char* opr_name = nullptr,
+                                bool wait = false) = 0;
   /*!
    * \brief Delete the given operator.
    * \param op The operator to delete.
@@ -176,13 +178,15 @@ class MXNET_API Engine {
    * \param prop Property of the function.
    * \param priority Priority of the action, as hint to the engine.
    * \param opr_name The operator name.
+   * \param wait Whether this is a WaitForVar operation
    */
   virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                          std::vector<VarHandle> const& const_vars,
                          std::vector<VarHandle> const& mutable_vars,
                          FnProperty prop = FnProperty::kNormal,
                          int priority = 0,
-                         const char* opr_name = nullptr) = 0;
+                         const char* opr_name = nullptr,
+                         bool wait = false) = 0;
   /*!
    * \brief Schedule the deletion of a variable.
    *
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 86f3877..adf14b1 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -73,7 +73,8 @@ class NaiveEngine final : public Engine {
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
-                        const char* opr_name = nullptr) override {
+                        const char* opr_name = nullptr,
+                        bool wait = false) override {
     NaiveOpr *opr = new NaiveOpr();
     opr->fn = fn;
     opr->const_vars = const_vars;
@@ -125,7 +126,8 @@ class NaiveEngine final : public Engine {
                  std::vector<VarHandle> const& mutable_vars,
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
-                 const char* opr_name = nullptr) override {
+                 const char* opr_name = nullptr,
+                 bool wait = false) override {
     CallbackOnComplete callback = CreateCallback(
         NaiveEngine::OnComplete, nullptr);
     this->req_completed_ = false;
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 2b28a7d..e166a6d 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator(
     std::vector<VarHandle> const& const_vars,
     std::vector<VarHandle> const& mutable_vars,
     FnProperty prop,
-    const char* opr_name) {
+    const char* opr_name,
+    bool wait) {
   auto ret = ThreadedOpr::New();
   ret->opr_name = opr_name;
   ret->fn = std::move(fn);
   ret->prop = prop;
   ret->const_vars.resize(const_vars.size());
   ret->mutable_vars.resize(mutable_vars.size());
+  ret->wait = wait;
   std::transform(const_vars.begin(), const_vars.end(),
                  ret->const_vars.begin(), ThreadedVar::CastFromBase);
   std::transform(mutable_vars.begin(), mutable_vars.end(),
@@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop,
                                int priority,
-                               const char* opr_name) {
+                               const char* opr_name,
+                               bool wait) {
   BulkFlush();
-  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
+  ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
   opr->temporary = true;
 #if MXNET_USE_PROFILER
   Profiler *profiler = Profiler::Get();
@@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
 void ThreadedEngine::WaitForVar(VarHandle var) {
   BulkFlush();
   ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
-  if (threaded_var->ready_to_read()) return;
+  if (threaded_var->ready_to_read()) {
+    ThrowException(threaded_var);
+    return;
+  }
   if (engine_info_) {
     LOG(INFO) << "Wait for " << threaded_var;
     debug_wait_var_ = threaded_var;
@@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
       }
       on_complete();
     }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
-    PROFILER_MESSAGE("WaitForVar"));
+    PROFILER_MESSAGE("WaitForVar"), true);
   {
     std::unique_lock<std::mutex> lock{finished_m_};
     finished_cv_.wait(lock, [this, &done]() {
         return done.load() || kill_.load();
     });
   }
+
+  ThrowException(threaded_var);
 }
 
 void ThreadedEngine::WaitForAll() {
@@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
   bool is_temporary_opr = threaded_opr->temporary;
   // Mark complete for read variables
   for (auto&& i : threaded_opr->const_vars) {
-    i->CompleteReadDependency([this](OprBlock* opr) {
-        this->PushToExecute(opr, false);
-      });
+    i->CompleteReadDependency(
+        [this](OprBlock* opr) { this->PushToExecute(opr, false); });
   }
   // Mark complete for write variables.
   for (auto&& i : threaded_opr->mutable_vars) {
+    if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
+      i->var_exception = threaded_opr->opr_exception;
+    }
     const bool debug_info = (engine_info_ && debug_wait_var_ == i);
     if (debug_info) {
       LOG(INFO) << "Complete write dep for " << i;
     }
-    const bool to_delete = i->CompleteWriteDependency(
-        [this, debug_info](OprBlock* opr) {
+    const bool to_delete =
+        i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
           if (debug_info) {
             LOG(INFO) << "PushToExecute " << opr;
             debug_push_opr_ = opr;
@@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
   }
 }
 
+inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
+  if (threaded_var->var_exception && *threaded_var->var_exception) {
+    std::exception_ptr tmp = *threaded_var->var_exception;
+    *threaded_var->var_exception = nullptr;
+    std::rethrow_exception(tmp);
+  }
+  return;
+}
+
 void ThreadedEngine::OnCompleteStatic(
     Engine *engine, void *opr_block_) {
   OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 1524f25..f647074 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -175,6 +175,8 @@ class ThreadedVar final
   static std::atomic<std::size_t> counter;
   ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
 #endif  // ENGINE_DEBUG
+  /*! \brief exception_ptr associated with the ThreadedVar */
+  std::shared_ptr<std::exception_ptr> var_exception;
 
  private:
   // TODO(hotpxl) change this to spinlock for faster runtime
@@ -237,6 +239,10 @@ struct ThreadedOpr final : public Opr,
    */
   bool temporary{false};
   /*!
+   * \brief Whether this is a WaitForVar operation
+   */
+  bool wait{false};
+  /*!
    * \brief Cast a Opr pointer to ThreadedOpr pointer
    * \param ptr pointer from base.
    * \return a casted pointer.
@@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
   }
   // define possible debug information
   DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
+  /*! \brief exception_ptr associated with the ThreadedOpr */
+  std::shared_ptr<std::exception_ptr> opr_exception;
 };  // struct ThreadedOpr
 
 /*!
@@ -265,7 +273,8 @@ class ThreadedEngine : public Engine {
                            std::vector<VarHandle> const& const_vars,
                            std::vector<VarHandle> const& mutable_vars,
                            FnProperty prop = FnProperty::kNormal,
-                           const char* opr_name = nullptr) override;
+                           const char* opr_name = nullptr,
+                           bool wait = false) override;
   void DeleteOperator(OprHandle op) override;
   void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
   void PushAsync(AsyncFn exec_fun, Context exec_ctx,
@@ -273,7 +282,8 @@ class ThreadedEngine : public Engine {
                  std::vector<VarHandle> const& mutable_vars,
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
-                 const char* opr_name = nullptr) override;
+                 const char* opr_name = nullptr,
+                 bool wait = false) override;
   void PushSync(SyncFn exec_fn, Context exec_ctx,
                 std::vector<VarHandle> const& const_vars,
                 std::vector<VarHandle> const& mutable_vars,
@@ -321,50 +331,63 @@ class ThreadedEngine : public Engine {
    * \param run_ctx runtime context used to execute the function.
    * \param opr_block the opr_block to be executed and deleted.
    */
-  void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
+  void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
     ThreadedOpr* threaded_opr = opr_block->opr;
 #if MXNET_USE_PROFILER
     if (opr_block->profiling && threaded_opr->opr_name) {
       const Context& ctx = opr_block->ctx;
-      opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
+      opr_block->opr_stat =
+          Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
       uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
       opr_block->opr_stat->thread_id = id;
-      strncpy(opr_block->opr_stat->opr_name,
-        threaded_opr->opr_name,
-        sizeof(opr_block->opr_stat->opr_name) - 1);
+      strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name,
+              sizeof(opr_block->opr_stat->opr_name) - 1);
       // record operator start timestamp
       SetOprStart(opr_block->opr_stat);
     }
 #endif
-    CallbackOnComplete callback = this->CreateCallback(
-        ThreadedEngine::OnCompleteStatic, opr_block);
-    bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
+    CallbackOnComplete callback =
+        this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
+    const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
     if (debug_info) {
       LOG(INFO) << "ExecuteOprBlock " << opr_block
                 << "shutdown_phase=" << shutdown_phase_;
     }
     if (!shutdown_phase_) {
       try {
+        OnStart(threaded_opr);
         if (debug_info) {
           LOG(INFO) << "ExecuteOprFn ";
         }
-        threaded_opr->fn(run_ctx, callback);
+        try {
+          if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
+              threaded_opr->wait) {
+            threaded_opr->fn(run_ctx, callback);
+          } else {
+            callback();
+          }
+        } catch (dmlc::Error& e) {
+          threaded_opr->opr_exception =
+              std::make_shared<std::exception_ptr>(std::current_exception());
+          callback();
+        }
         if (debug_info) {
           LOG(INFO) << "Fin ExecuteOprFn ";
         }
-      } catch(dmlc::Error &e) {
+      } catch (std::exception& e) {
         std::string what = e.what();
         if (what.find("driver shutting down") == std::string::npos &&
             !shutdown_phase_) {
-          LOG(FATAL) << e.what() << "\n" <<
-            "A fatal error occurred in asynchronous engine operation. "
-            "If you do not know what caused this error, "
-            "you can try set environment variable MXNET_ENGINE_TYPE "
-            "to NaiveEngine and run with debugger (i.e. gdb). "
-            "This will force all operations to be synchronous and "
-            "backtrace will give you the series of calls that lead "
-            "to this error. Remember to set MXNET_ENGINE_TYPE back to "
-            "empty after debugging.";
+          LOG(FATAL)
+              << e.what() << "\n"
+              << "A fatal error occurred in asynchronous engine operation. "
+                 "If you do not know what caused this error, "
+                 "you can try set environment variable MXNET_ENGINE_TYPE "
+                 "to NaiveEngine and run with debugger (i.e. gdb). "
+                 "This will force all operations to be synchronous and "
+                 "backtrace will give you the series of calls that lead "
+                 "to this error. Remember to set MXNET_ENGINE_TYPE back to "
+                 "empty after debugging.";
         }
       }
     } else {
@@ -414,7 +437,34 @@ class ThreadedEngine : public Engine {
    * On operation completion, this will trigger subsequent operations.
    */
   inline void OnComplete(ThreadedOpr* threaded_opr);
-  // callback to the threaded engine
+  /*!
+   * \brief rethrow caught exception in WaitForVar
+   * \param threaded_var the var that we are waiting to read
+   */
+  inline void ThrowException(ThreadedVar* threaded_var);
+  /*!
+   * \brief Mark exceptions before operation execution.
+   *
+   * Will mark the operator as a failure and associate exception_ptr
+   * if any of the read dependencies have exception associated.
+   */
+  inline void OnStart(ThreadedOpr* threaded_opr) {
+    for (auto&& i : threaded_opr->const_vars) {
+      if (i->var_exception && *i->var_exception) {
+        threaded_opr->opr_exception = i->var_exception;
+        break;
+      }
+    }
+    if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
+      for (auto&& i : threaded_opr->mutable_vars) {
+        if (i->var_exception && *i->var_exception) {
+          threaded_opr->opr_exception = i->var_exception;
+          break;
+        }
+      }
+    }
+  }
+
   static void OnCompleteStatic(Engine *engine, void *threaded_opr);
   /*! \brief append an operator to bulk */
   inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
diff --git a/src/storage/cpu_device_storage.h b/src/storage/cpu_device_storage.h
index f0dd61f..beb7389 100644
--- a/src/storage/cpu_device_storage.h
+++ b/src/storage/cpu_device_storage.h
@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
   void* ptr;
 #if _MSC_VER
   ptr = _aligned_malloc(size, alignment_);
-  if (ptr == NULL) throw std::bad_alloc();
+  if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
 #else
   int ret = posix_memalign(&ptr, alignment_, size);
-  if (ret != 0) throw std::bad_alloc();
+  if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
 #endif
   return ptr;
 }
diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h
index c859892..435c7e8 100644
--- a/src/storage/gpu_device_storage.h
+++ b/src/storage/gpu_device_storage.h
@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
 #endif  // MXNET_USE_NCCL
   cudaError_t e = cudaMalloc(&ret, size);
   if (e != cudaSuccess && e != cudaErrorCudartUnloading)
-    throw std::bad_alloc();
+    LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
 #else   // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
 #endif  // MXNET_USE_CUDA
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f9f00c4..91eb958 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -34,6 +34,7 @@ from test_optimizer import *
 from test_random import *
 from test_gluon import *
 from test_loss import *
+from test_exc_handling import *
 #from test_rnn import *
 from test_gluon_rnn import *
 from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
diff --git a/tests/python/unittest/test_exc_handling.py b/tests/python/unittest/test_exc_handling.py
new file mode 100644
index 0000000..0c2c43d
--- /dev/null
+++ b/tests/python/unittest/test_exc_handling.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.base import MXNetError
+from mxnet.test_utils import assert_exception, default_context, set_default_context
+from nose.tools import assert_raises
+
+def test_exc_imperative():
+    def imperative(exec_numpy=True):
+        a = mx.nd.random.normal(0, 1, (2, 2))
+        b = mx.nd.random.normal(0, -1, (2, 2))
+        c = mx.nd.dot(a, b)
+        if exec_numpy:
+            c.asnumpy()
+
+    imperative(exec_numpy=False)
+    assert_raises(MXNetError, imperative, True)
+
+def test_exc_symbolic():
+    def symbolic(exec_backward=True):
+        x = mx.sym.Variable('x')
+        y = mx.sym.Variable('y')
+        z = mx.sym.Variable('z')
+        x_shape = (2, 2)
+        z_shape = (3, 2)
+        inputs = [x, y]
+        out = mx.symbol.ElementWiseSum(*inputs, name="esum")
+        out = mx.sym.dot(z, out)
+        out2 = mx.sym.random.normal(0, -1, x_shape, ctx=default_context())
+        out = mx.sym.dot(out, out2)
+        out = mx.sym.make_loss(out)
+        arr = {'x': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()),
+               'y': mx.nd.random.normal(0, 1, x_shape, ctx=default_context()),
+               'z': mx.nd.random.normal(0, 1, z_shape, ctx=default_context())}
+        arr_grad = {'x': mx.nd.empty(x_shape), 'y': mx.nd.empty(x_shape), 'z': mx.nd.empty(z_shape)}
+        exec1 = out.bind(ctx=default_context(), args=arr, args_grad=arr_grad)
+        outputs = exec1.forward()
+        if exec_backward:
+            exec1.backward()
+            exec1.grad_arrays[0].asnumpy()
+        else:
+            outputs[0].asnumpy()
+
+    assert_raises(MXNetError, symbolic, False)
+    assert_raises(MXNetError, symbolic, True)
+
+def test_exc_gluon():
+    def gluon(exec_wait=True):
+        model = nn.Sequential()
+        model.add(nn.Dense(128, activation='tanh', in_units=10, flatten=False))
+        model.add(nn.Dropout(1))
+        model.add(nn.Dense(64, activation='tanh', in_units=256),
+                  nn.Dense(32, in_units=64))
+        x = mx.sym.var('data')
+        y = model(x)
+        model.collect_params().initialize(ctx=[default_context()])
+        z = model(mx.nd.random.normal(10, -10, (32, 2, 10), ctx=default_context()))
+        if exec_wait:
+            z.wait_to_read()
+
+    gluon(exec_wait=False)
+    assert_raises(MXNetError, gluon, True)
+
+def test_exc_multiple_waits():
+    caught = False
+    try:
+        a = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+        a.wait_to_read()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+    try:
+        b = mx.nd.random.normal(0, -1, (2, 2)).copyto(default_context())
+        b.wait_to_read()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+
+def test_exc_post_fail():
+    caught = False
+    try:
+        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+        a.asnumpy()
+    except MXNetError:
+        caught = True
+    assert caught, "No exception thrown"
+    b.asnumpy()
+
+def test_exc_mutable_var_fail():
+    def mutable_var_check():
+        a, b = mx.nd.random_normal(0, -1, (2, 2)).copyto(default_context())
+        a = mx.nd.dot(a, a)
+        a.asnumpy()
+    assert_raises(MXNetError, mutable_var_check)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to