This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.4.x by this push:
     new 60e8a1f  Add MXEnginePushAsync and MXEnginePushSync C APIs (#14615) (#14770)
60e8a1f is described below

commit 60e8a1f252a8dd5d083f3fb515089dcc373eb87a
Author: Yuxi Hu <[email protected]>
AuthorDate: Mon Apr 22 18:04:19 2019 -0700

    Add MXEnginePushAsync and MXEnginePushSync C APIs (#14615) (#14770)
    
    * add PushAsyncPtr and PushSyncPtr APIs in engine
    
    * avoid using shared_ptr for param in new APIs
    
    * avoid using std::vector in parameters
    
    * change to C API
    
    * address comments and add tests
    
    * fix perl build
    
    * use int instead of size_t
---
 include/mxnet/c_api.h                    | 61 +++++++++++++++++++++-
 src/c_api/c_api.cc                       | 88 ++++++++++++++++++++++++++++++++
 tests/cpp/engine/threaded_engine_test.cc | 78 ++++++++++++++++++++++++++++
 3 files changed, 225 insertions(+), 2 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e9f1e2d..412cee5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -95,10 +95,22 @@ typedef void *CudaKernelHandle;
 typedef void *ProfileHandle;
 /*! \brief handle to DLManagedTensor*/
 typedef void *DLManagedTensorHandle;
-
+/*! \brief Opaque handle to an execution Context (read as a const Context*). */
+typedef const void *ContextHandle;
+/*! \brief Opaque handle to an engine FnProperty value (read as a const FnProperty*). */
+typedef const void *EngineFnPropertyHandle;
+/*! \brief Opaque pointer to engine variable(s); the push APIs read it as an array of VarHandle. */
+typedef void *EngineVarHandle;
+
+/*!
+ * \brief Engine asynchronous operation.
+ *  Invoked as async_func(run_ctx, on_complete, param): run_ctx points to the
+ *  engine RunContext, on_complete points to the engine completion callback
+ *  (which must be invoked once the work is done), and param is the
+ *  func_param given at push time.
+ */
+typedef void (*EngineAsyncFunc)(void *run_ctx, void *on_complete, void *param);
+/*!
+ * \brief Engine synchronous operation.
+ *  Invoked as sync_func(run_ctx, param); completion is signalled by returning.
+ */
+typedef void (*EngineSyncFunc)(void *run_ctx, void *param);
+/*! \brief Frees the func_param that was handed to an EngineAsyncFunc/EngineSyncFunc push. */
+typedef void (*EngineFuncParamDeleter)(void *param);
 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
-                                        void *);
+                                        void*);
 
 struct NativeOpInfo {
   void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -2486,6 +2498,51 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle 
handle, int* shared_pid,
 MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, 
const mx_uint *shape,
                                            mx_uint ndim, int dtype, 
NDArrayHandle *out);
 
+/*!
+  * \brief Push an asynchronous operation to the engine.
+  * \param async_func Execution function whici takes a parameter on_complete
+  *                   that must be called when the execution ompletes.
+  * \param func_param The parameter set on calling async_func, can be NULL.
+  * \param deleter The callback to free func_param, can be NULL.
+  * \param ctx_handle Execution context.
+  * \param const_vars_handle The variables that current operation will use
+  *                          but not mutate.
+  * \param num_const_vars The number of const_vars.
+  * \param mutable_vars_handle The variables that current operation will 
mutate.
+  * \param num_mutable_vars The number of mutable_vars.
+  * \param prop_handle Property of the function.
+  * \param priority Priority of the action, as hint to the engine.
+  * \param opr_name The operation name.
+  * \param wait Whether this is a WaitForVar operation.
+  */
+MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
+                                EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
+                                EngineVarHandle const_vars_handle, int 
num_const_vars,
+                                EngineVarHandle mutable_vars_handle, int 
num_mutable_vars,
+                                EngineFnPropertyHandle prop_handle = NULL, int 
priority = 0,
+                                const char* opr_name = NULL, bool wait = 
false);
+
+/*!
+  * \brief Push a synchronous operation to the engine.
+  * \param sync_func Execution function that executes the operation.
+  * \param func_param The parameter set on calling sync_func, can be NULL.
+  * \param deleter The callback to free func_param, can be NULL.
+  * \param ctx_handle Execution context.
+  * \param const_vars_handle The variables that current operation will use
+  *                          but not mutate.
+  * \param num_const_vars The number of const_vars.
+  * \param mutable_vars_handle The variables that current operation will 
mutate.
+  * \param num_mutable_vars The number of mutable_vars.
+  * \param prop_handle Property of the function.
+  * \param priority Priority of the action, as hint to the engine.
+  * \param opr_name The operation name.
+  */
+MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
+                               EngineFuncParamDeleter deleter, ContextHandle 
ctx_handle,
+                               EngineVarHandle const_vars_handle, int 
num_const_vars,
+                               EngineVarHandle mutable_vars_handle, int 
num_mutable_vars,
+                               EngineFnPropertyHandle prop_handle = NULL, int 
priority = 0,
+                               const char* opr_name = NULL);
 
 #ifdef __cplusplus
 }
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 80bd605..0568c19 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1389,3 +1389,91 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int 
shared_id, const mx_uint *s
   *out = new NDArray(shared_pid, shared_id, TShape(shape, shape + ndim), 
dtype);
   API_END();
 }
+
+// Short aliases for the engine types used by the MXEngine* C API functions.
+typedef Engine::VarHandle VarHandle;
+typedef Engine::CallbackOnComplete CallbackOnComplete;
+
+/*!
+ * \brief Reject negative variable counts received through the C API.
+ *  Fails via CHECK, which the surrounding API_BEGIN/API_END turns into a -1
+ *  return for the C caller. Internal linkage: this helper is private to this
+ *  translation unit and must not collide with symbols elsewhere.
+ */
+static void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
+  CHECK_GE(num_const_vars, 0) << "Non-negative number of const vars expected.";
+  CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected.";
+}
+
+/*!
+ * \brief C entry point for Engine::PushAsync; see c_api.h for the contract.
+ *  Returns 0 on success and -1 (with the error recorded) on invalid input.
+ */
+int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
+                      EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                      EngineVarHandle const_vars_handle, int num_const_vars,
+                      EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                      EngineFnPropertyHandle prop_handle, int priority,
+                      const char* opr_name, bool wait) {
+  API_BEGIN();
+
+  // Bind func_param (and its deleter, if any) into the engine closure before
+  // validating anything else. exec_fn is a local, so if a CHECK below throws
+  // it is destroyed during unwinding and the deleter still runs exactly once:
+  // the caller never has to free func_param after calling this function.
+  Engine::AsyncFn exec_fn;
+  if (deleter == nullptr) {
+    exec_fn = [async_func, func_param](RunContext rctx,
+                                       CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_complete, func_param);
+    };
+  } else {
+    // The shared_ptr calls deleter(func_param) when the last copy of the
+    // closure holding it goes out of scope.
+    std::shared_ptr<void> shared_func_param(func_param, deleter);
+    exec_fn = [async_func, shared_func_param](RunContext rctx,
+                                              CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_complete, shared_func_param.get());
+    };
+  }
+
+  // Reject bad input with an error code instead of dereferencing NULL here or
+  // calling a NULL function later on an engine worker.
+  CHECK(async_func != nullptr) << "async_func must not be NULL.";
+  CHECK(ctx_handle != nullptr) << "ctx_handle must not be NULL.";
+  AssertValidNumberVars(num_const_vars, num_mutable_vars);
+  CHECK(num_const_vars == 0 || const_vars_handle != nullptr)
+      << "const_vars_handle must not be NULL when num_const_vars > 0.";
+  CHECK(num_mutable_vars == 0 || mutable_vars_handle != nullptr)
+      << "mutable_vars_handle must not be NULL when num_mutable_vars > 0.";
+
+  const Context exec_ctx = *static_cast<const Context*>(ctx_handle);
+  const FnProperty prop = prop_handle != nullptr
+      ? *static_cast<const FnProperty*>(prop_handle)
+      : FnProperty::kNormal;
+  auto const_vars = static_cast<VarHandle*>(const_vars_handle);
+  auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
+  std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
+  std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
+  Engine::Get()->PushAsync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
+                           prop, priority, opr_name, wait);
+
+  API_END();
+}
+
+/*!
+ * \brief C entry point for Engine::PushSync; see c_api.h for the contract.
+ *  Returns 0 on success and -1 (with the error recorded) on invalid input.
+ */
+int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
+                     EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                     EngineVarHandle const_vars_handle, int num_const_vars,
+                     EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                     EngineFnPropertyHandle prop_handle, int priority,
+                     const char* opr_name) {
+  API_BEGIN();
+
+  // Bind func_param (and its deleter, if any) into the engine closure before
+  // validating anything else. exec_fn is a local, so if a CHECK below throws
+  // it is destroyed during unwinding and the deleter still runs exactly once:
+  // the caller never has to free func_param after calling this function.
+  Engine::SyncFn exec_fn;
+  if (deleter == nullptr) {
+    exec_fn = [sync_func, func_param](RunContext rctx) {
+      sync_func(&rctx, func_param);
+    };
+  } else {
+    // The shared_ptr calls deleter(func_param) when the last copy of the
+    // closure holding it goes out of scope.
+    std::shared_ptr<void> shared_func_param(func_param, deleter);
+    exec_fn = [sync_func, shared_func_param](RunContext rctx) {
+      sync_func(&rctx, shared_func_param.get());
+    };
+  }
+
+  // Reject bad input with an error code instead of dereferencing NULL here or
+  // calling a NULL function later on an engine worker.
+  CHECK(sync_func != nullptr) << "sync_func must not be NULL.";
+  CHECK(ctx_handle != nullptr) << "ctx_handle must not be NULL.";
+  AssertValidNumberVars(num_const_vars, num_mutable_vars);
+  CHECK(num_const_vars == 0 || const_vars_handle != nullptr)
+      << "const_vars_handle must not be NULL when num_const_vars > 0.";
+  CHECK(num_mutable_vars == 0 || mutable_vars_handle != nullptr)
+      << "mutable_vars_handle must not be NULL when num_mutable_vars > 0.";
+
+  const Context exec_ctx = *static_cast<const Context*>(ctx_handle);
+  const FnProperty prop = prop_handle != nullptr
+      ? *static_cast<const FnProperty*>(prop_handle)
+      : FnProperty::kNormal;
+  auto const_vars = static_cast<VarHandle*>(const_vars_handle);
+  auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
+  std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
+  std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
+  Engine::Get()->PushSync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
+                          prop, priority, opr_name);
+
+  API_END();
+}
diff --git a/tests/cpp/engine/threaded_engine_test.cc 
b/tests/cpp/engine/threaded_engine_test.cc
index 6d669c1..405f3b3 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -27,6 +27,7 @@
 #include <dmlc/thread_group.h>
 #include <dmlc/omp.h>
 #include <gtest/gtest.h>
+#include <mxnet/c_api.h>
 #include <mxnet/engine.h>
 #include <dmlc/timer.h>
 #include <cstdio>
@@ -176,6 +177,83 @@ TEST(Engine, RandSumExpr) {
 
 void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); }
 
+// EngineAsyncFunc used by the PushAsync tests: checks the optional int payload
+// (expected to be 100 when present), then signals completion to the engine.
+void FooAsyncFunc(void*, void* cb_ptr, void* param) {
+  if (param != nullptr) {
+    const int* payload = static_cast<const int*>(param);
+    EXPECT_EQ(*payload, 100);
+    LOG(INFO) << "The fox asynchronously says receiving " << *payload;
+  } else {
+    LOG(INFO) << "The fox asynchronously says receiving nothing.";
+  }
+  // cb_ptr points at the engine's completion callback; the operation is not
+  // finished until it has been invoked.
+  mxnet::engine::CallbackOnComplete& on_complete =
+      *static_cast<mxnet::engine::CallbackOnComplete*>(cb_ptr);
+  on_complete();
+}
+
+// EngineSyncFunc used by the PushSync tests: checks the optional int payload,
+// which is expected to be 101 when present.
+void FooSyncFunc(void*, void* param) {
+  if (param != nullptr) {
+    const int* payload = static_cast<const int*>(param);
+    EXPECT_EQ(*payload, 101);
+    LOG(INFO) << "The fox synchronously says receiving " << *payload;
+    return;
+  }
+  LOG(INFO) << "The fox synchronously says receiving nothing.";
+}
+
+// EngineFuncParamDeleter for the heap-allocated int payloads used by the push
+// tests; tolerates NULL.
+void FooFuncDeleter(void* param) {
+  if (param == nullptr) {
+    return;
+  }
+  int* payload = static_cast<int*>(param);
+  LOG(INFO) << "The fox says deleting " << *payload;
+  delete payload;
+}
+
+// Exercises MXEnginePushAsync / MXEnginePushSync: valid pushes with and
+// without a param/deleter, and rejection (-1) of negative variable counts.
+TEST(Engine, PushFunc) {
+  auto var = mxnet::Engine::Get()->NewVariable();
+  auto ctx = mxnet::Context{};
+
+  // Test #1
+  LOG(INFO) << "===== Test #1: PushAsync param and deleter =====";
+  int* a = new int(100);
+  int res = MXEnginePushAsync(FooAsyncFunc, a, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #2
+  LOG(INFO) << "===== Test #2: PushAsync NULL param and NULL deleter =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #3
+  LOG(INFO) << "===== Test #3: PushAsync invalid number of const vars =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #4
+  LOG(INFO) << "===== Test #4: PushAsync invalid number of mutable vars =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
+  EXPECT_EQ(res, -1);
+
+  // Test #5
+  LOG(INFO) << "===== Test #5: PushSync param and deleter =====";
+  int* b = new int(101);
+  res = MXEnginePushSync(FooSyncFunc, b, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #6
+  LOG(INFO) << "===== Test #6: PushSync NULL param and NULL deleter =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 1);
+  EXPECT_EQ(res, 0);
+
+  // Test #7
+  LOG(INFO) << "===== Test #7: PushSync invalid number of const vars =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #8
+  LOG(INFO) << "===== Test #8: PushSync invalid number of mutable vars =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
+  EXPECT_EQ(res, -1);
+
+  // Drain the engine so the EXPECT_EQs inside the pushed callbacks (and the
+  // deleters) run while this test is still active, then release the variable.
+  mxnet::Engine::Get()->WaitForAll();
+  mxnet::Engine::Get()->DeleteVariable([](mxnet::RunContext) {}, ctx, var);
+}
+
 TEST(Engine, basics) {
   auto&& engine = mxnet::Engine::Get();
   auto&& var = engine->NewVariable();

Reply via email to