This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new 804403e  Fix the bug of `MXEnginePushAsyncND` and `MXEnginePushSyncND` (#15751) (#15792)
804403e is described below

commit 804403e999d1567f371c5243f5565127ad7f2f93
Author: JackieWu <w...@live.cn>
AuthorDate: Thu Aug 8 13:55:35 2019 +0800

    Fix the bug of `MXEnginePushAsyncND` and `MXEnginePushSyncND` (#15751) (#15792)
    
    * fix push sync nd api
    
    * align code
    
    * update test for syncnd
    
    * fix bug in tests/cpp/engine/threaded_engine_test
    
    * add more testcases for MXEnginePushSyncND and MXEnginePushAsyncND
    
    * fix test
    
    * fix
    
    * fix
    
    * lint
    
    * ci
    
    * retrigger CI
---
 include/mxnet/c_api.h                    |  22 +++---
 src/c_api/c_api.cc                       |  40 +++++------
 tests/cpp/engine/threaded_engine_test.cc | 117 +++++++++++++++++++------------
 3 files changed, 105 insertions(+), 74 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index a2da6db..c73b366 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2863,12 +2863,12 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
   * \param wait Whether this is a WaitForVar operation.
   */
 MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
-                                EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
-                                NDArrayHandle const_nds_handle, int 
num_const_nds,
-                                NDArrayHandle mutable_nds_handle, int 
num_mutable_nds,
-                                EngineFnPropertyHandle prop_handle 
DEFAULT(NULL),
-                                int priority DEFAULT(0), const char* opr_name 
DEFAULT(NULL),
-                                bool wait DEFAULT(false));
+                                  EngineFuncParamDeleter deleter, 
ContextHandle ctx_handle,
+                                  NDArrayHandle* const_nds_handle, int 
num_const_nds,
+                                  NDArrayHandle* mutable_nds_handle, int 
num_mutable_nds,
+                                  EngineFnPropertyHandle prop_handle 
DEFAULT(NULL),
+                                  int priority DEFAULT(0), const char* 
opr_name DEFAULT(NULL),
+                                  bool wait DEFAULT(false));
 
 /*!
   * \brief Push a synchronous operation to the engine.
@@ -2886,11 +2886,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
   * \param opr_name The operation name.
   */
 MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
-                               EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
-                               NDArrayHandle const_nds_handle, int 
num_const_nds,
-                               NDArrayHandle mutable_nds_handle, int 
num_mutable_nds,
-                               EngineFnPropertyHandle prop_handle 
DEFAULT(NULL),
-                               int priority DEFAULT(0), const char* opr_name 
DEFAULT(NULL));
+                                 EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
+                                 NDArrayHandle* const_nds_handle, int 
num_const_nds,
+                                 NDArrayHandle* mutable_nds_handle, int 
num_mutable_nds,
+                                 EngineFnPropertyHandle prop_handle 
DEFAULT(NULL),
+                                 int priority DEFAULT(0), const char* opr_name 
DEFAULT(NULL));
 
 #ifdef __cplusplus
 }
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 35bd3ee..6ba46bd 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1535,18 +1535,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
 }
 
 int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
-                      EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
-                      NDArrayHandle const_nds_handle, int num_const_nds,
-                      NDArrayHandle mutable_nds_handle, int num_mutable_nds,
-                      EngineFnPropertyHandle prop_handle, int priority,
-                      const char* opr_name, bool wait) {
-  API_BEGIN();
-  NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
-  NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
+                        EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
+                        NDArrayHandle* const_nds_handle, int num_const_nds,
+                        NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
+                        EngineFnPropertyHandle prop_handle, int priority,
+                        const char* opr_name, bool wait) {
+  API_BEGIN();
+  NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
+  NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
   std::vector<VarHandle> const_var_vec(num_const_nds);
-  for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = 
(const_nds+i)->var();
+  for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = 
const_nds[i]->var();
   std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
-  for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = 
(mutable_nds+i)->var();
+  for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = 
mutable_nds[i]->var();
   return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
                            const_var_vec.data(), num_const_nds,
                            mutable_var_vec.data(), num_mutable_nds,
@@ -1555,18 +1555,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
 }
 
 int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
-                     EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
-                     NDArrayHandle const_nds_handle, int num_const_nds,
-                     NDArrayHandle mutable_nds_handle, int num_mutable_nds,
-                     EngineFnPropertyHandle prop_handle, int priority,
-                     const char* opr_name) {
-  API_BEGIN();
-  NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
-  NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
+                       EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
+                       NDArrayHandle* const_nds_handle, int num_const_nds,
+                       NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
+                       EngineFnPropertyHandle prop_handle, int priority,
+                       const char* opr_name) {
+  API_BEGIN();
+  NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
+  NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
   std::vector<VarHandle> const_var_vec(num_const_nds);
-  for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = 
(const_nds+i)->var();
+  for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = 
const_nds[i]->var();
   std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
-  for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = 
(mutable_nds+i)->var();
+  for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = 
mutable_nds[i]->var();
   return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
                           const_var_vec.data(), num_const_nds,
                           mutable_var_vec.data(), num_mutable_nds,
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 6b863f8..cea92a0 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {
 
 TEST(Engine, PushFuncND) {
   auto ctx = mxnet::Context{};
-  mxnet::NDArray nd(ctx);
-
-  // Test #1
-  LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
-  int* a = new int(100);
-  int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, 
nullptr, 0);
-  EXPECT_EQ(res, 0);
-
-  // Test #2
-  LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
-  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, 
&nd, 0);
-  EXPECT_EQ(res, 0);
-
-  // Test #3
-  LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
-  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, 
nullptr, 0);
-  EXPECT_EQ(res, -1);
-
-  // Test #4
-  LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds 
=====";
-  res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, 
&nd, -1);
-  EXPECT_EQ(res, -1);
-
-  // Test #5
-  LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
-  int* b = new int(101);
-  res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, 
nullptr, 0);
-  EXPECT_EQ(res, 0);
-
-  // Test #6
-  LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
-  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, 
&nd, 1);
-  EXPECT_EQ(res, 0);
-
-  // Test #7
-  LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
-  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, 
nullptr, 0);
-  EXPECT_EQ(res, -1);
-
-  // Test #8
-  LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
-  res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, 
&nd, -1);
-  EXPECT_EQ(res, -1);
+  std::vector<mxnet::NDArray*> nds;
+  const int num_nds = 5;
+  for (int i = 0; i < num_nds; ++i) {
+      mxnet::NDArray *pnd = new mxnet::NDArray(ctx);
+      nds.push_back(pnd);
+  }
+  for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) {
+      int num_mutable_nds = num_nds - num_const_nds;
+      void** const_nds_handle = num_const_nds > 0 ?
+          reinterpret_cast<void**>(nds.data()) : nullptr;
+      void** mutable_nds_handle = num_mutable_nds > 0 ?
+          reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr;
+
+      // Test #1
+      LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
+      int* a = new int(100);
+      int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, 0);
+
+      // Test #2
+      LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter 
=====";
+      res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, 0);
+
+      // Test #3
+      LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds 
=====";
+      res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, -1,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, -1);
+
+      // Test #4
+      LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds 
=====";
+      res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, -1);
+      EXPECT_EQ(res, -1);
+
+      // Test #5
+      LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
+      int* b = new int(101);
+      res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, 0);
+
+      // Test #6
+      LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter 
=====";
+      res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, 0);
+
+      // Test #7
+      LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds 
=====";
+      res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, -1,
+              mutable_nds_handle, num_mutable_nds);
+      EXPECT_EQ(res, -1);
+
+      // Test #8
+      LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds 
=====";
+      res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
+              const_nds_handle, num_const_nds,
+              mutable_nds_handle, -1);
+      EXPECT_EQ(res, -1);
+  }
+  for (mxnet::NDArray* pnd : nds) {
+      delete pnd;
+  }
 }
 
 TEST(Engine, basics) {

Reply via email to