This is an automated email from the ASF dual-hosted git repository.

arcadiaphy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f84682  Properly handling custom op exception by modify engine (#14693)
1f84682 is described below

commit 1f84682576433ec152ae8a5874a687ddad9190b4
Author: Wang Jiajun <[email protected]>
AuthorDate: Tue Apr 16 03:03:33 2019 -0500

    Properly handling custom op exception by modify engine (#14693)
    
    * fix custom except handling by modify engine
    
    * add test
    
    * fix lint
    
    * update test
    
    * fix test
    
    * trigger CI
---
 docs/faq/env_var.md                    |   3 -
 include/mxnet/engine.h                 |   6 +-
 src/engine/naive_engine.cc             |   3 +
 src/engine/threaded_engine.cc          |   5 ++
 src/engine/threaded_engine.h           |   5 +-
 src/operator/custom/custom-inl.h       |  45 +++++++++---
 tests/python/unittest/test_operator.py | 129 +++++++++++++++++++++++++++------
 7 files changed, 158 insertions(+), 38 deletions(-)

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 2768f64..095c214 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -60,9 +60,6 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 * MXNET_MP_OPENCV_NUM_THREADS
   - Values: Int ```(default=0)```
   - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarge this number may boost the performance of individual workers when executing underlying OpenCV functions but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
-* MXNET_CUSTOM_OP_NUM_THREADS
-  - Values: Int ```(default=16)```
-  - The maximum number of threads given to custom operators.
 
 ## Memory Options
 
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 408a70a..9d63675 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -106,7 +106,9 @@ enum class FnProperty {
   /*! \brief Delete variable call */
   kDeleteVar,
   /*! \brief Prioritized sync operation on GPU */
-  kGPUPrioritized
+  kGPUPrioritized,
+  /*! \brief Operation not to be skipped even with associated exception */
+  kNoSkip
 };  // enum class FnProperty
 
 /*!
@@ -230,6 +232,8 @@ class MXNET_API Engine {
    * \brief Wait until all the activity of engine finishes.
    */
   virtual void WaitForAll() = 0;
+  /*!\brief Throw if threre are associated exception with var */
+  virtual void Throw(VarHandle var) = 0;
   /*!\brief virtual destructor */
   virtual ~Engine() noexcept(false) {}
   /*!
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index db44919..93853c4 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -212,6 +212,9 @@ class NaiveEngine final : public Engine {
   void WaitForAll() override {
   }
 
+  void Throw(VarHandle var) override {
+  }
+
   void NotifyShutdown() override {
     shutdown_phase_.store(true);
   }
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 986b6ad..3831190 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -498,6 +498,11 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
   return;
 }
 
+void ThreadedEngine::Throw(VarHandle var) {
+  ThreadedVar *threaded_var = ThreadedVar::CastFromBase(var);
+  ThrowException(threaded_var);
+}
+
 void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_,
                                       const dmlc::Error* error) {
   OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index 3d2119d..7df232b 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -306,6 +306,7 @@ class ThreadedEngine : public Engine {
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
   void WaitForVar(VarHandle var) override;
   void WaitForAll() override;
+  void Throw(VarHandle var) override;
   void NotifyShutdown() override {
     shutdown_phase_.store(true);
   }
@@ -374,8 +375,8 @@ class ThreadedEngine : public Engine {
           LOG(INFO) << "ExecuteOprFn ";
         }
         try {
-          if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
-              threaded_opr->wait) {
+          if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
+              threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) {
             threaded_opr->fn(run_ctx, callback);
           } else {
             callback();
diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h
index c5eaea1..3bf63b7 100644
--- a/src/operator/custom/custom-inl.h
+++ b/src/operator/custom/custom-inl.h
@@ -96,7 +96,12 @@ class CustomOperator {
       bool prev_recording = Imperative::Get()->set_is_recording(recording);
       bool prev_training = Imperative::Get()->set_is_training(training);
 
-      func();
+      try {
+        func();
+      } catch (dmlc::Error& e) {
+        exception_ =
+            std::make_shared<std::exception_ptr>(std::current_exception());
+      }
 
       Imperative::Get()->set_is_training(prev_training);
       Imperative::Get()->set_is_recording(prev_recording);
@@ -116,6 +121,16 @@ class CustomOperator {
 
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
+            try {
+              Throw();
+              for (const auto& i : arrs) {
+                Engine::Get()->Throw(i.var());
+              }
+            } catch(dmlc::Error& err) {
+              ctx.async_on_complete(&err);
+              return;
+            }
+
             for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
               if (arrs[i].storage_type() == kDefaultStorage ||
                   arrs[i].storage_type() == kUndefinedStorage)
@@ -125,14 +140,15 @@ class CustomOperator {
                 out_idx++;
               }
             }
+
             ctx.async_on_complete();
           },
-          ctx.run_ctx.ctx, vars, vars2, FnProperty::kNormal, 0,
+          ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0,
           "CustomOperator");
     });
     // increase num_threads if there is not enough threads to execute custom operator
-    if (q_.size() > num_free_threads)
-      CreateThreads(q_.size() - num_free_threads);
+    if (q_.size() > num_free_threads_)
+      CreateThreads(q_.size() - num_free_threads_);
     cv_.notify_all();
   }
 
@@ -142,9 +158,10 @@ class CustomOperator {
   }
 
   void Start() {
-    num_free_threads = 0;
+    num_free_threads_ = 0;
     destructing_ = false;
     naive_engine_ = true;
+    exception_ = nullptr;
     if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
       naive_engine_ = false;
     }
@@ -162,6 +179,14 @@ class CustomOperator {
     workers_.clear();
   }
 
+  inline void Throw() {
+    if (exception_ && *exception_) {
+      std::exception_ptr tmp = *exception_;
+      exception_ = nullptr;
+      std::rethrow_exception(tmp);
+    }
+  }
+
  private:
   CustomOperator() {
     this->Start();
@@ -171,21 +196,20 @@ class CustomOperator {
     while (!q_.empty() || !destructing_) {
       cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
       while (!q_.empty()) {
-        --num_free_threads;
+        --num_free_threads_;
         auto fn = q_.front();
         q_.pop();
         lock.unlock();
         fn();
-        ++num_free_threads;
+        ++num_free_threads_;
         lock.lock();
       }
     }
   }
   void SetNumThreads(int num_threads) {
-    num_threads = std::min(dmlc::GetEnv("MXNET_CUSTOM_OP_NUM_THREADS", 16), num_threads);
     for (int i = workers_.size(); i < num_threads; ++i) {
       workers_.emplace_back(std::thread([this]{this->ThreadTarget();}));
-      ++num_free_threads;
+      ++num_free_threads_;
     }
   }
   void CreateThreads(int num_new_threads) {
@@ -196,8 +220,9 @@ class CustomOperator {
   // async worker
   std::condition_variable cv_;
   std::vector<std::thread> workers_;
-  std::atomic<uint32_t> num_free_threads;
+  std::atomic<uint32_t> num_free_threads_;
   std::queue<std::function<void(void)> > q_;
+  std::shared_ptr<std::exception_ptr> exception_;
   bool naive_engine_;
   bool destructing_;
 };
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 59d72d4..287b974 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -29,6 +29,8 @@ from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
+from common import run_in_spawned_process
+from nose.tools import assert_raises
 import unittest
 import os
 
@@ -5360,29 +5362,29 @@ def test_custom_op():
 
     # test custom operator fork
     # see https://github.com/apache/incubator-mxnet/issues/14396
-    if not sys.platform.startswith('win'):  # no fork in windows
-        class AdditionOP(mx.operator.CustomOp):
-            def __init__(self):
-                super(AdditionOP, self).__init__()
-            def forward(self, is_train, req, in_data, out_data, aux):
-                out_data[0][:] = in_data[0] + in_data[1]
-            def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
-                in_grad[0][:] = out_grad[0]
-                in_grad[1][:] = out_grad[0]
-
-        @mx.operator.register("AdditionOP")
-        class AdditionOPProp(mx.operator.CustomOpProp):
-            def __init__(self):
-                super(AdditionOPProp, self).__init__()
-            def list_arguments(self):
-                return ['a', 'b']
-            def list_outputs(self):
-                return ['output']
-            def infer_shape(self, in_shape):
-                return in_shape, [in_shape[0]]
-            def create_operator(self, ctx, shapes, dtypes):
-                return AdditionOP()
+    class AdditionOP(mx.operator.CustomOp):
+        def __init__(self):
+            super(AdditionOP, self).__init__()
+        def forward(self, is_train, req, in_data, out_data, aux):
+            out_data[0][:] = in_data[0] + in_data[1]
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            in_grad[0][:] = out_grad[0]
+            in_grad[1][:] = out_grad[0]
 
+    @mx.operator.register("AdditionOP")
+    class AdditionOPProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(AdditionOPProp, self).__init__()
+        def list_arguments(self):
+            return ['a', 'b']
+        def list_outputs(self):
+            return ['output']
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]]
+        def create_operator(self, ctx, shapes, dtypes):
+            return AdditionOP()
+
+    if not sys.platform.startswith('win'):  # no fork in windows
         def custom_add():
             a = mx.nd.array([1, 2, 3])
             b = mx.nd.array([4, 5, 6])
@@ -5397,6 +5399,89 @@ def test_custom_op():
         p.join(5)
         assert not p.is_alive(), "deadlock may exist in custom operator"
 
+
+def _build_dot_custom(fun_forward, name):
+    class Dot(mx.operator.CustomOp):
+        def __init__(self):
+            super(Dot, self).__init__()
+        def forward(self, is_train, req, in_data, out_data, aux):
+            fun_forward(in_data, out_data)
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            pass
+
+    @mx.operator.register(name)
+    class DotProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(DotProp, self).__init__()
+        def list_arguments(self):
+            return ['a', 'b']
+        def list_outputs(self):
+            return ['output']
+        def infer_shape(self, in_shape):
+            return in_shape, [(in_shape[0][0], in_shape[1][1])]
+        def create_operator(self, ctx, shapes, dtypes):
+            return Dot()
+
+def _custom_exc3(seed):
+    def custom_exc3():
+        def f(in_data, out_data):
+            out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
+            out_data[0].wait_to_read()
+        _build_dot_custom(f, 'Dot3')
+        n = int(1e8)
+        a = mx.nd.zeros((n, 1))
+        b = mx.nd.zeros((1, n))
+        # trigger OOM
+        c = mx.nd.Custom(a, b, op_type='Dot3')
+        c.wait_to_read()
+    assert_raises(MXNetError, custom_exc3)
+
+def _custom_exc4(seed):
+    def custom_exc4():
+        def f(in_data, out_data):
+            out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
+        _build_dot_custom(f, 'Dot4')
+        n = int(1e8)
+        a = mx.nd.zeros((n, 1))
+        b = mx.nd.zeros((1, n))
+        # trigger OOM
+        c = mx.nd.Custom(a, b, op_type='Dot4')
+        c.wait_to_read()
+    assert_raises(MXNetError, custom_exc4)
+
+@with_seed()
+def test_custom_op_exc():
+    # test except handling
+    # see https://github.com/apache/incubator-mxnet/pull/14693
+    # 1. error in python code
+    def custom_exc1():
+        def f(in_data, out_data):
+            assert False
+            out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
+        _build_dot_custom(f, 'Dot1')
+        a = mx.nd.zeros((4, 1))
+        b = mx.nd.zeros((1, 4))
+        c = mx.nd.Custom(a, b, op_type='Dot1')
+        c.wait_to_read()
+    assert_raises(MXNetError, custom_exc1)
+
+    # 2. error in pushing operator to engine
+    def custom_exc2():
+        def f(in_data, out_data):
+            out_data[0][:] = mx.nd.dot(in_data[0], in_data[1])
+        _build_dot_custom(f, 'Dot2')
+        a = mx.nd.zeros((4, 2))
+        b = mx.nd.zeros((1, 4))
+        # trigger error by invalid input shapes of operands
+        c = mx.nd.Custom(a, b, op_type='Dot2')
+        c.wait_to_read()
+    assert_raises(MXNetError, custom_exc2)
+
+    # 3. error in real execution
+    run_in_spawned_process(_custom_exc3, {})
+    run_in_spawned_process(_custom_exc4, {})
+
+
 @with_seed()
 def test_psroipooling():
     for num_rois in [1, 2]:

Reply via email to