This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 74e71e9 [MXNET-400] support string type for kvstore key in cpp-package (#10792)
74e71e9 is described below
commit 74e71e9b55f3a2e2650e33d6cb0b6b78ad44e9b4
Author: nihui <[email protected]>
AuthorDate: Tue Apr 9 15:56:04 2019 +0800
[MXNET-400] support string type for kvstore key in cpp-package (#10792)
* Kvstore strkey (#2)
* support string type for kvstore key in cpp-package
* make lines short
* fix build
* add kvstore testcase
* no rand() use
* fix cpplint sanity check
* support string type for kvstore key in cpp-package
* make lines short
* fix build
* print error log
* Update test_kvstore.cpp
* update
* add gpu unittest
* check gpu count
* fix sanity check
---
cpp-package/example/test_kvstore.cpp | 201 ++++++++++++++++++++++++++++++
cpp-package/include/mxnet-cpp/kvstore.h | 13 +-
cpp-package/include/mxnet-cpp/kvstore.hpp | 78 +++++++++++-
cpp-package/tests/ci_test.sh | 7 +-
4 files changed, 292 insertions(+), 7 deletions(-)
diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp
new file mode 100644
index 0000000..d9e0400
--- /dev/null
+++ b/cpp-package/example/test_kvstore.cpp
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "mxnet/c_api.h" // MXGetGPUCount()
+#include "mxnet-cpp/MxNetCpp.h"
+
+using namespace mxnet::cpp;
+
+static bool test_single_key(const Context &context, const std::string
&context_str) {  // true iff both Init/Pull and Push/Pull round-trips match
+ std::string key = "singlekeytest-" + context_str;  // suffix keeps cpu/gpu keys apart
+
+ NDArray result(Shape(4), context);  // Pull destination on the tested context
+ NDArray result_cpu;  // host copy of result, used for element reads
+
+ // initialize data
+ NDArray data_cpu({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
+ NDArray data = data_cpu.Copy(context);
+ NDArray::WaitAll();
+
+ KVStore::Init(key, data);  // exercises the std::string-key overload
+ NDArray::WaitAll();
+
+ // retrieve result
+ KVStore::Pull(key, &result);
+ NDArray::WaitAll();
+
+ result_cpu = result.Copy(Context::cpu());  // comparisons read the host copy
+ NDArray::WaitAll();
+
+ // compare
+ for (size_t j=0; j < result_cpu.Size(); j++) {
+ if (result_cpu.GetData()[j] != data_cpu.GetData()[j]) {  // exact: assumes Init/Pull copy unchanged
+ LG << "Error: wrong initialized data in singlekeytest-" << context_str
+ << ", expect " << data_cpu.GetData()[j]
+ << " got " << result_cpu.GetData()[j];
+ return false;
+ }
+ }
+
+ // push gradient
+ NDArray grad_cpu({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
+ NDArray grad = grad_cpu.Copy(context);
+ NDArray::WaitAll();
+
+ KVStore::Push(key, grad);
+ NDArray::WaitAll();
+
+ // retrieve result
+ KVStore::Pull(key, &result);
+ NDArray::WaitAll();
+
+ result_cpu = result.Copy(Context::cpu());
+ NDArray::WaitAll();
+
+ // compare
+ for (size_t j=0; j < result_cpu.Size(); j++) {
+ if (result_cpu.GetData()[j] != grad_cpu.GetData()[j]) {  // expects Push to replace the value (no optimizer set here)
+ LG << "Error: wrong gradient data in singlekeytest-" << context_str
+ << ", expect " << grad_cpu.GetData()[j]
+ << " got " << result_cpu.GetData()[j];
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool test_multiple_key(const Context &context, const std::string
&context_str) {  // vector-of-string-key Init/Push/Pull; true on success
+ std::vector<std::string> keys(2);
+ keys[0] = "multikeytest-0-" + context_str;
+ keys[1] = "multikeytest-1-" + context_str;
+
+ std::vector<NDArray> results(2);  // Pull destinations, one per key
+ results[0] = NDArray(Shape(4), context);
+ results[1] = NDArray(Shape(4), context);
+ std::vector<NDArray> results_cpu(2);  // host copies for element reads
+
+ // initialize data
+ std::vector<NDArray> data_cpu(2);
+ data_cpu[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
+ data_cpu[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
+ std::vector<NDArray> data(2);
+ data[0] = data_cpu[0].Copy(context);
+ data[1] = data_cpu[1].Copy(context);
+ NDArray::WaitAll();
+
+ KVStore::Init(keys, data);  // keys[i] is initialized with data[i]
+ NDArray::WaitAll();
+
+ // retrieve result
+ KVStore::Pull(keys, &results);
+ NDArray::WaitAll();
+
+ results_cpu[0] = results[0].Copy(Context::cpu());
+ results_cpu[1] = results[1].Copy(Context::cpu());
+ NDArray::WaitAll();
+
+ // compare
+ for (size_t i=0; i < results_cpu.size(); i++) {
+ for (size_t j=0; j < results_cpu[i].Size(); j++) {
+ if (results_cpu[i].GetData()[j] != data_cpu[i].GetData()[j]) {
+ LG << "Error: wrong initialized data in multikeytest-" << context_str
+ << ", expect " << data_cpu[i].GetData()[j]
+ << " got " << results_cpu[i].GetData()[j];
+ return false;
+ }
+ }
+ }
+
+ // push gradient, reduce for the second
+ std::vector<std::string> push_keys(3);  // key 1 is listed twice on purpose
+ push_keys[0] = "multikeytest-0-" + context_str;
+ push_keys[1] = "multikeytest-1-" + context_str;
+ push_keys[2] = "multikeytest-1-" + context_str;
+
+ std::vector<NDArray> grads_cpu(3);
+ grads_cpu[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
+ grads_cpu[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
+ grads_cpu[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
+ std::vector<NDArray> grads(3);
+ grads[0] = grads_cpu[0].Copy(context);
+ grads[1] = grads_cpu[1].Copy(context);
+ grads[2] = grads_cpu[2].Copy(context);
+ NDArray::WaitAll();
+
+ KVStore::Push(push_keys, grads);  // grads[1] and grads[2] target the same key
+ NDArray::WaitAll();
+
+ // retrieve result
+ KVStore::Pull(keys, &results);
+ NDArray::WaitAll();
+
+ results_cpu[0] = results[0].Copy(Context::cpu());
+ results_cpu[1] = results[1].Copy(Context::cpu());
+ NDArray::WaitAll();
+
+ // compare the first
+ for (size_t j=0; j < results_cpu[0].Size(); j++) {
+ if (results_cpu[0].GetData()[j] != grads_cpu[0].GetData()[j]) {
+ LG << "Error: wrong gradient data in multikeytest-" << context_str
+ << ", expect " << grads_cpu[0].GetData()[j]
+ << " got " << results_cpu[0].GetData()[j];
+ return false;
+ }
+ }
+
+ // compare the second
+ for (size_t j=0; j < results_cpu[1].Size(); j++) {
+ if (results_cpu[1].GetData()[j] != (grads_cpu[1].GetData()[j] +
grads_cpu[2].GetData()[j])) {  // expects duplicate pushes summed; exact float compare
+ LG << "Error: wrong reduced gradient data in multikeytest-" <<
context_str
+ << ", expect " << (grads_cpu[1].GetData()[j] +
grads_cpu[2].GetData()[j])
+ << " got " << results_cpu[1].GetData()[j];
+ return false;
+ }
+ }
+
+ return true;
+}
+
+int main(int argc, char** argv) {  // exit 0 only if all tests pass; argc/argv unused
+ KVStore::SetType("local");  // every test below uses the "local" store type
+
+ bool success1 = test_single_key(Context::cpu(), "cpu");
+ bool success2 = test_multiple_key(Context::cpu(), "cpu");
+
+ bool success3 = true;  // GPU results count as passed when no GPU exists
+ bool success4 = true;
+
+ int gpu_count = 0;
+ if (MXGetGPUCount(&gpu_count) != 0) {  // a failed GPU query fails the whole run
+ LG << "Error: MXGetGPUCount";
+
+ MXNotifyShutdown();
+ return 1;
+ }
+
+ if (gpu_count > 0) {
+ success3 = test_single_key(Context::gpu(), "gpu");
+ success4 = test_multiple_key(Context::gpu(), "gpu");
+ }
+
+ int ret = (success1 && success2 && success3 && success4) ? 0 : 1;
+
+ MXNotifyShutdown();  // called on both exit paths of main
+ return ret;
+}
diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h
index d5aa150..67f984f 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.h
+++ b/cpp-package/include/mxnet-cpp/kvstore.h
@@ -39,12 +39,21 @@ class KVStore {
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
+ static void Init(const std::string& key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>&
vals);
+ static void Init(const std::vector<std::string>& keys, const
std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
+ static void Push(const std::string& key, const NDArray& val, int priority =
0);
static void Push(const std::vector<int>& keys,
- const std::vector<NDArray>& vals, int priority = 0);
+ const std::vector<NDArray>& vals, int priority = 0);
+ static void Push(const std::vector<std::string>& keys,
+ const std::vector<NDArray>& vals, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
- static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs,
int priority = 0);
+ static void Pull(const std::string& key, NDArray* out, int priority = 0);
+ static void Pull(const std::vector<int>& keys,
+ std::vector<NDArray>* outs, int priority = 0);
+ static void Pull(const std::vector<std::string>& keys,
+ std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local =
false);
static std::string GetType();
diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp
index f2b5e74..6cd405b 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.hpp
+++ b/cpp-package/include/mxnet-cpp/kvstore.hpp
@@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle),
0);
}
+inline void KVStore::Init(const std::string& key, const NDArray& val) {  // string-key overload of Init
+ const char* key_handle = key.c_str();  // borrowed from key; used only during this call
+ NDArrayHandle val_handle = val.GetHandle();
+ CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle,
&val_handle), 0);  // non-zero C API status fails the CHECK
+}
+
inline void KVStore::Init(const std::vector<int>& keys, const
std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
@@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
val_handles.data()), 0);
}
+inline void KVStore::Init(const std::vector<std::string>& keys, const
std::vector<NDArray>& vals) {  // batch Init: keys[i] gets vals[i]
+ CHECK_EQ(keys.size(), vals.size());  // exactly one value per key
+ std::vector<const char*> key_handles(keys.size());  // borrows each key's c_str()
+ std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+ [](const std::string& key) {
+ return key.c_str();  // valid while keys is alive and unmodified
+ });
+ std::vector<NDArrayHandle> val_handles(vals.size());
+ std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
+ [](const NDArray& val) {
+ return val.GetHandle();
+ });
+
+ CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(),  // NOTE(review): size_t count may narrow to the C API type
key_handles.data(),
+ val_handles.data()), 0);  // non-zero C API status fails the CHECK
+}
+
inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle,
priority), 0);
}
+inline void KVStore::Push(const std::string& key, const NDArray& val, int
priority) {  // string-key overload of Push
+ const char* key_handle = key.c_str();  // borrowed from key; used only during this call
+ NDArrayHandle val_handle = val.GetHandle();
+ CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle,
&val_handle, priority), 0);  // priority is forwarded unchanged
+}
+
inline void KVStore::Push(const std::vector<int>& keys,
- const std::vector<NDArray>& vals,
- int priority) {
+ const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
@@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
val_handles.data(), priority), 0);
}
+inline void KVStore::Push(const std::vector<std::string>& keys,
+ const std::vector<NDArray>& vals, int priority) {  // batch Push: vals[i] to keys[i]
+ CHECK_EQ(keys.size(), vals.size());  // exactly one value per key
+ std::vector<const char*> key_handles(keys.size());  // borrows each key's c_str()
+ std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+ [](const std::string& key) {
+ return key.c_str();  // valid while keys is alive and unmodified
+ });
+ std::vector<NDArrayHandle> val_handles(vals.size());
+ std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
+ [](const NDArray& val) {
+ return val.GetHandle();
+ });
+
+ CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(),
key_handles.data(),
+ val_handles.data(), priority), 0);  // non-zero C API status fails the CHECK
+}
+
inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle,
priority), 0);
}
-inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>*
outs, int priority) {
+inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {  // pulls key's value into *out
+ const char* key_handle = key.c_str();  // borrowed from key; used only during this call
+ NDArrayHandle out_handle = out->GetHandle();  // out is dereferenced: must be non-null
+ CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle,
&out_handle, priority), 0);  // non-zero C API status fails the CHECK
+}
+
+inline void KVStore::Pull(const std::vector<int>& keys,
+ std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());
std::vector<NDArrayHandle> out_handles(keys.size());
@@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
out_handles.data(), priority), 0);
}
+inline void KVStore::Pull(const std::vector<std::string>& keys,
+ std::vector<NDArray>* outs, int priority) {  // batch Pull: keys[i] into (*outs)[i]
+ CHECK_EQ(keys.size(), outs->size());  // outs must be pre-sized to match keys
+
+ std::vector<const char*> key_handles(keys.size());  // borrows each key's c_str()
+ std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+ [](const std::string& key) {
+ return key.c_str();
+ });
+ std::vector<NDArrayHandle> out_handles(keys.size());  // sizes already checked equal
+ std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
+ [](const NDArray& val) {
+ return val.GetHandle();
+ });
+
+ CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(),
key_handles.data(),
+ out_handles.data(), priority), 0);  // non-zero C API status fails the CHECK
+}
+
inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh
index 18fabea..2d1f8e4 100755
--- a/cpp-package/tests/ci_test.sh
+++ b/cpp-package/tests/ci_test.sh
@@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
cp ../../build/cpp-package/example/mlp_gpu .
./mlp_gpu
- cp ../../build/cpp-package/example/test_optimizer .
- ./test_optimizer
+cp ../../build/cpp-package/example/test_optimizer .
+./test_optimizer
+
+cp ../../build/cpp-package/example/test_kvstore .
+./test_kvstore
cp ../../build/cpp-package/example/test_score .
./test_score 0.93