This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.4.x by this push:
new 60e8a1f Add MXEnginePushAsync and MXEnginePushSync C APIs (#14615)
(#14770)
60e8a1f is described below
commit 60e8a1f252a8dd5d083f3fb515089dcc373eb87a
Author: Yuxi Hu <[email protected]>
AuthorDate: Mon Apr 22 18:04:19 2019 -0700
Add MXEnginePushAsync and MXEnginePushSync C APIs (#14615) (#14770)
* add PushAsyncPtr and PushSyncPtr APIs in engine
* avoid using shared_ptr for param in new APIs
* avoid using std::vector in parameters
* change to C API
* address comments and add tests
* fix perl build
* use int instead of size_t
---
include/mxnet/c_api.h | 61 +++++++++++++++++++++-
src/c_api/c_api.cc | 88 ++++++++++++++++++++++++++++++++
tests/cpp/engine/threaded_engine_test.cc | 78 ++++++++++++++++++++++++++++
3 files changed, 225 insertions(+), 2 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e9f1e2d..412cee5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -95,10 +95,22 @@ typedef void *CudaKernelHandle;
typedef void *ProfileHandle;
/*! \brief handle to DLManagedTensor*/
typedef void *DLManagedTensorHandle;
-
+/*! \brief handle to Context; points to an mxnet::Context */
+typedef const void *ContextHandle;
+/*! \brief handle to Engine FnProperty; points to a FnProperty value */
+typedef const void *EngineFnPropertyHandle;
+/*! \brief handle to an array of Engine VarHandle (i.e. an Engine::VarHandle*) */
+typedef void *EngineVarHandle;
+
+/*!
+ * \brief Engine asynchronous operation, called as
+ *        (RunContext*, CallbackOnComplete*, func_param). The CallbackOnComplete
+ *        must be invoked when the execution completes.
+ */
+typedef void (*EngineAsyncFunc)(void*, void*, void*);
+/*! \brief Engine synchronous operation, called as (RunContext*, func_param) */
+typedef void (*EngineSyncFunc)(void*, void*);
+/*!
+ * \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc; invoked
+ *        when the function object wrapping the pushed operation is destroyed.
+ */
+typedef void (*EngineFuncParamDeleter)(void*);
typedef void (*ExecutorMonitorCallback)(const char*,
NDArrayHandle,
- void *);
+ void*);
struct NativeOpInfo {
void (*forward)(int, float**, int*, unsigned**, int*, void*);
@@ -2486,6 +2498,51 @@ MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle
handle, int* shared_pid,
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id,
const mx_uint *shape,
mx_uint ndim, int dtype,
NDArrayHandle *out);
+/*!
+ * \brief Push an asynchronous operation to the engine.
+ * \param async_func Execution function which takes a parameter on_complete
+ *                   that must be called when the execution completes.
+ * \param func_param The parameter set on calling async_func, can be NULL.
+ * \param deleter The callback to free func_param, can be NULL. When non-NULL it is
+ *                invoked on func_param once the engine no longer needs it.
+ * \param ctx_handle Execution context.
+ * \param const_vars_handle The variables that current operation will use
+ *                          but not mutate.
+ * \param num_const_vars The number of const_vars.
+ * \param mutable_vars_handle The variables that current operation will mutate.
+ * \param num_mutable_vars The number of mutable_vars.
+ * \param prop_handle Property of the function; NULL means FnProperty::kNormal.
+ * \param priority Priority of the action, as hint to the engine.
+ * \param opr_name The operation name.
+ * \param wait Whether this is a WaitForVar operation.
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
+                                EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                                EngineVarHandle const_vars_handle, int num_const_vars,
+                                EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                                EngineFnPropertyHandle prop_handle = NULL, int priority = 0,
+                                const char* opr_name = NULL, bool wait = false);
+
+/*!
+ * \brief Push a synchronous operation to the engine.
+ * \param sync_func Execution function that executes the operation.
+ * \param func_param The parameter set on calling sync_func, can be NULL.
+ * \param deleter The callback to free func_param, can be NULL. When non-NULL it is
+ *                invoked on func_param once the engine no longer needs it.
+ * \param ctx_handle Execution context.
+ * \param const_vars_handle The variables that current operation will use
+ *                          but not mutate.
+ * \param num_const_vars The number of const_vars.
+ * \param mutable_vars_handle The variables that current operation will mutate.
+ * \param num_mutable_vars The number of mutable_vars.
+ * \param prop_handle Property of the function; NULL means FnProperty::kNormal.
+ * \param priority Priority of the action, as hint to the engine.
+ * \param opr_name The operation name.
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
+                               EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                               EngineVarHandle const_vars_handle, int num_const_vars,
+                               EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                               EngineFnPropertyHandle prop_handle = NULL, int priority = 0,
+                               const char* opr_name = NULL);
#ifdef __cplusplus
}
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 80bd605..0568c19 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1389,3 +1389,91 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int
shared_id, const mx_uint *s
*out = new NDArray(shared_pid, shared_id, TShape(shape, shape + ndim),
dtype);
API_END();
}
+
+// Short aliases for the engine types used by the MXEngine* C API functions below.
+using VarHandle = Engine::VarHandle;
+using CallbackOnComplete = Engine::CallbackOnComplete;
+
+/*!
+ * \brief Check that the variable counts passed through the C API are non-negative.
+ *        A failed CHECK makes the calling C API function return -1.
+ * \param num_const_vars Number of read-only variables.
+ * \param num_mutable_vars Number of mutated variables.
+ */
+static void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
+  CHECK_GE(num_const_vars, 0) << "Non-negative number of const vars expected.";
+  CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected.";
+}
+
+/*!
+ * \brief Wrap a C asynchronous callback into an Engine::AsyncFn and push it
+ *        to the engine. See the declaration in include/mxnet/c_api.h.
+ */
+int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
+                      EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                      EngineVarHandle const_vars_handle, int num_const_vars,
+                      EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                      EngineFnPropertyHandle prop_handle, int priority,
+                      const char* opr_name, bool wait) {
+  API_BEGIN();
+
+  // Take ownership of func_param before validating anything. If a CHECK below
+  // fails, exec_fn is destroyed during unwinding and the deleter still runs, so
+  // a failed push never leaks the caller's param.
+  Engine::AsyncFn exec_fn;
+  if (deleter == nullptr) {
+    exec_fn = [async_func, func_param](RunContext rctx,
+                                       CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_complete, func_param);
+    };
+  } else {
+    // Wrap func_param in a shared_ptr with deleter such that deleter
+    // will be called when the lambda goes out of scope.
+    std::shared_ptr<void> shared_func_param(func_param, deleter);
+    exec_fn = [async_func, shared_func_param](RunContext rctx,
+                                              CallbackOnComplete on_complete) {
+      async_func(&rctx, &on_complete, shared_func_param.get());
+    };
+  }
+
+  // Reject inputs that would otherwise crash (NULL dereference) or read out of
+  // bounds. These report -1 to the caller instead.
+  CHECK(async_func != nullptr) << "async_func must not be NULL.";
+  CHECK(ctx_handle != nullptr) << "ctx_handle must not be NULL.";
+  AssertValidNumberVars(num_const_vars, num_mutable_vars);
+  CHECK(const_vars_handle != nullptr || num_const_vars == 0)
+      << "const_vars_handle must not be NULL when num_const_vars > 0.";
+  CHECK(mutable_vars_handle != nullptr || num_mutable_vars == 0)
+      << "mutable_vars_handle must not be NULL when num_mutable_vars > 0.";
+
+  const Context exec_ctx = *static_cast<const Context*>(ctx_handle);
+  auto const_vars = static_cast<VarHandle*>(const_vars_handle);
+  auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
+  FnProperty prop = FnProperty::kNormal;
+  if (prop_handle != nullptr) {
+    prop = *static_cast<const FnProperty*>(prop_handle);
+  }
+
+  std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
+  std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
+  Engine::Get()->PushAsync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
+                           prop, priority, opr_name, wait);
+
+  API_END();
+}
+
+/*!
+ * \brief Wrap a C synchronous callback into an Engine::SyncFn and push it
+ *        to the engine. See the declaration in include/mxnet/c_api.h.
+ */
+int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
+                     EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
+                     EngineVarHandle const_vars_handle, int num_const_vars,
+                     EngineVarHandle mutable_vars_handle, int num_mutable_vars,
+                     EngineFnPropertyHandle prop_handle, int priority,
+                     const char* opr_name) {
+  API_BEGIN();
+
+  // Take ownership of func_param before validating anything. If a CHECK below
+  // fails, exec_fn is destroyed during unwinding and the deleter still runs, so
+  // a failed push never leaks the caller's param.
+  Engine::SyncFn exec_fn;
+  if (deleter == nullptr) {
+    exec_fn = [sync_func, func_param](RunContext rctx) {
+      sync_func(&rctx, func_param);
+    };
+  } else {
+    // Wrap func_param in a shared_ptr with deleter such that deleter
+    // will be called when the lambda goes out of scope.
+    std::shared_ptr<void> shared_func_param(func_param, deleter);
+    exec_fn = [sync_func, shared_func_param](RunContext rctx) {
+      sync_func(&rctx, shared_func_param.get());
+    };
+  }
+
+  // Reject inputs that would otherwise crash (NULL dereference) or read out of
+  // bounds. These report -1 to the caller instead.
+  CHECK(sync_func != nullptr) << "sync_func must not be NULL.";
+  CHECK(ctx_handle != nullptr) << "ctx_handle must not be NULL.";
+  AssertValidNumberVars(num_const_vars, num_mutable_vars);
+  CHECK(const_vars_handle != nullptr || num_const_vars == 0)
+      << "const_vars_handle must not be NULL when num_const_vars > 0.";
+  CHECK(mutable_vars_handle != nullptr || num_mutable_vars == 0)
+      << "mutable_vars_handle must not be NULL when num_mutable_vars > 0.";
+
+  const Context exec_ctx = *static_cast<const Context*>(ctx_handle);
+  auto const_vars = static_cast<VarHandle*>(const_vars_handle);
+  auto mutable_vars = static_cast<VarHandle*>(mutable_vars_handle);
+  FnProperty prop = FnProperty::kNormal;
+  if (prop_handle != nullptr) {
+    prop = *static_cast<const FnProperty*>(prop_handle);
+  }
+
+  std::vector<VarHandle> const_var_vec(const_vars, const_vars + num_const_vars);
+  std::vector<VarHandle> mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars);
+  Engine::Get()->PushSync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec,
+                          prop, priority, opr_name);
+
+  API_END();
+}
diff --git a/tests/cpp/engine/threaded_engine_test.cc
b/tests/cpp/engine/threaded_engine_test.cc
index 6d669c1..405f3b3 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -27,6 +27,7 @@
#include <dmlc/thread_group.h>
#include <dmlc/omp.h>
#include <gtest/gtest.h>
+#include <mxnet/c_api.h>
#include <mxnet/engine.h>
#include <dmlc/timer.h>
#include <cstdio>
@@ -176,6 +177,83 @@ TEST(Engine, RandSumExpr) {
+// Prints the given integer; the RunContext argument is unused.
void Foo(mxnet::RunContext, int i) { printf("The fox says %d\n", i); }
+// EngineAsyncFunc used by the PushFunc test. Logs the optional int param
+// (expected to be 100 when present), then signals completion via the
+// CallbackOnComplete passed in cb_ptr.
+void FooAsyncFunc(void*, void* cb_ptr, void* param) {
+  if (param != nullptr) {
+    auto num = static_cast<int*>(param);
+    EXPECT_EQ(*num, 100);
+    LOG(INFO) << "The fox asynchronously says receiving " << *num;
+  } else {
+    LOG(INFO) << "The fox asynchronously says receiving nothing.";
+  }
+  auto& on_complete = *static_cast<mxnet::engine::CallbackOnComplete*>(cb_ptr);
+  on_complete();
+}
+
+// EngineSyncFunc used by the PushFunc test. Logs the optional int param
+// (expected to be 101 when present).
+void FooSyncFunc(void*, void* param) {
+  if (param != nullptr) {
+    auto num = static_cast<int*>(param);
+    EXPECT_EQ(*num, 101);
+    LOG(INFO) << "The fox synchronously says receiving " << *num;
+    return;
+  }
+  LOG(INFO) << "The fox synchronously says receiving nothing.";
+}
+
+// EngineFuncParamDeleter used by the PushFunc test. Logs and frees an int
+// allocated with new; a NULL param is ignored.
+void FooFuncDeleter(void* param) {
+  if (param == nullptr) {
+    return;
+  }
+  int* num = static_cast<int*>(param);
+  LOG(INFO) << "The fox says deleting " << *num;
+  delete num;
+}
+
+// Exercises MXEnginePushAsync / MXEnginePushSync with and without a param and
+// deleter, and checks that negative variable counts make the call return -1.
+TEST(Engine, PushFunc) {
+  auto var = mxnet::Engine::Get()->NewVariable();
+  auto ctx = mxnet::Context{};
+
+  // Test #1
+  LOG(INFO) << "===== Test #1: PushAsync param and deleter =====";
+  int* a = new int(100);
+  int res = MXEnginePushAsync(FooAsyncFunc, a, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #2
+  LOG(INFO) << "===== Test #2: PushAsync NULL param and NULL deleter =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #3
+  LOG(INFO) << "===== Test #3: PushAsync invalid number of const vars =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #4
+  LOG(INFO) << "===== Test #4: PushAsync invalid number of mutable vars =====";
+  res = MXEnginePushAsync(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
+  EXPECT_EQ(res, -1);
+
+  // Test #5
+  LOG(INFO) << "===== Test #5: PushSync param and deleter =====";
+  int* b = new int(101);
+  res = MXEnginePushSync(FooSyncFunc, b, FooFuncDeleter, &ctx, &var, 1, nullptr, 0);
+  EXPECT_EQ(res, 0);
+
+  // Test #6
+  LOG(INFO) << "===== Test #6: PushSync NULL param and NULL deleter =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, 1);
+  EXPECT_EQ(res, 0);
+
+  // Test #7
+  LOG(INFO) << "===== Test #7: PushSync invalid number of const vars =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, &var, -1, nullptr, 0);
+  EXPECT_EQ(res, -1);
+
+  // Test #8
+  LOG(INFO) << "===== Test #8: PushSync invalid number of mutable vars =====";
+  res = MXEnginePushSync(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &var, -1);
+  EXPECT_EQ(res, -1);
+
+  // Wait for every pushed operation to finish so none of them runs after this
+  // test returns (or during a later test), then release the variable instead
+  // of leaking it.
+  auto engine = mxnet::Engine::Get();
+  engine->WaitForAll();
+  engine->DeleteVariable([](mxnet::RunContext) {}, ctx, var);
+  engine->WaitForAll();
+}
+
TEST(Engine, basics) {
auto&& engine = mxnet::Engine::Get();
auto&& var = engine->NewVariable();