This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 74e71e9  [MXNET-400] support string type for kvstore key in cpp-package (#10792)
74e71e9 is described below

commit 74e71e9b55f3a2e2650e33d6cb0b6b78ad44e9b4
Author: nihui <[email protected]>
AuthorDate: Tue Apr 9 15:56:04 2019 +0800

    [MXNET-400] support string type for kvstore key in cpp-package (#10792)
    
    * Kvstore strkey (#2)
    
    * support string type for kvstore key in cpp-package
    
    * make lines short
    
    * fix build
    
    * add kvstore testcase
    
    * no rand() use
    
    * fix cpplint sanity check
    
    * support string type for kvstore key in cpp-package
    
    * make lines short
    
    * fix build
    
    * print error log
    
    * Update test_kvstore.cpp
    
    * update
    
    * add gpu unittest
    
    * check gpu count
    
    * fix sanity check
---
 cpp-package/example/test_kvstore.cpp      | 201 ++++++++++++++++++++++++++++++
 cpp-package/include/mxnet-cpp/kvstore.h   |  13 +-
 cpp-package/include/mxnet-cpp/kvstore.hpp |  78 +++++++++++-
 cpp-package/tests/ci_test.sh              |   7 +-
 4 files changed, 292 insertions(+), 7 deletions(-)

diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp
new file mode 100644
index 0000000..d9e0400
--- /dev/null
+++ b/cpp-package/example/test_kvstore.cpp
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "mxnet/c_api.h"  // MXGetGPUCount()
+#include "mxnet-cpp/MxNetCpp.h"
+
+using namespace mxnet::cpp;
+
+// Round-trips one string key through the KVStore on `context`:
+// Init followed by Pull must return the initial data, then Push followed by
+// Pull must return the pushed gradient (main() sets no optimizer).
+// Returns true on success; logs the first mismatch and returns false.
+static bool test_single_key(const Context &context,
+                            const std::string &context_str) {
+  const std::string key = "singlekeytest-" + context_str;
+
+  // Copies the contents of a CPU NDArray into a host vector. Callers wait with
+  // NDArray::WaitAll() before reading, as everywhere else in this file.
+  auto to_host = [](const NDArray &arr) {
+    const mx_float *begin = arr.GetData();
+    return std::vector<mx_float>(begin, begin + arr.Size());
+  };
+
+  // Logs and returns false on a size mismatch or on the first element where
+  // `got` differs from `expected`; `what` names the stage being checked.
+  // Exact comparison is intended: the values are copied, not computed.
+  auto check = [&key](const std::vector<mx_float> &got,
+                      const std::vector<mx_float> &expected,
+                      const char *what) {
+    if (got.size() != expected.size()) {
+      LG << "Error: wrong " << what << " size in " << key
+          << ", expect " << expected.size() << " got " << got.size();
+      return false;
+    }
+    for (size_t j = 0; j < got.size(); j++) {
+      if (got[j] != expected[j]) {
+        LG << "Error: wrong " << what << " in " << key
+            << ", expect " << expected[j]
+            << " got " << got[j];
+        return false;
+      }
+    }
+    return true;
+  };
+
+  NDArray result(Shape(4), context);
+
+  // initialize data
+  NDArray data_cpu({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
+  NDArray data = data_cpu.Copy(context);
+  NDArray::WaitAll();
+
+  KVStore::Init(key, data);
+  NDArray::WaitAll();
+
+  // retrieve result
+  KVStore::Pull(key, &result);
+  NDArray::WaitAll();
+
+  NDArray result_cpu = result.Copy(Context::cpu());
+  NDArray::WaitAll();
+
+  if (!check(to_host(result_cpu), to_host(data_cpu), "initialized data")) {
+    return false;
+  }
+
+  // push gradient; the test expects Pull to return exactly the pushed value
+  NDArray grad_cpu({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
+  NDArray grad = grad_cpu.Copy(context);
+  NDArray::WaitAll();
+
+  KVStore::Push(key, grad);
+  NDArray::WaitAll();
+
+  // retrieve result
+  KVStore::Pull(key, &result);
+  NDArray::WaitAll();
+
+  result_cpu = result.Copy(Context::cpu());
+  NDArray::WaitAll();
+
+  return check(to_host(result_cpu), to_host(grad_cpu), "gradient data");
+}
+
+// Exercises the vector-of-string-keys overloads on `context`:
+// 1. Init two keys and check that Pull returns the initial data.
+// 2. Push three gradients in one call, naming the second key twice, then
+//    check that the first key holds its gradient and the second holds the
+//    element-wise sum of its two gradients.
+// Returns true on success; logs the first mismatch and returns false.
+static bool test_multiple_key(const Context &context,
+                              const std::string &context_str) {
+  std::vector<std::string> keys(2);
+  keys[0] = "multikeytest-0-" + context_str;
+  keys[1] = "multikeytest-1-" + context_str;
+
+  // Copies the contents of a CPU NDArray into a host vector. Callers wait with
+  // NDArray::WaitAll() before reading, as everywhere else in this file.
+  auto to_host = [](const NDArray &arr) {
+    const mx_float *begin = arr.GetData();
+    return std::vector<mx_float>(begin, begin + arr.Size());
+  };
+
+  // Logs and returns false on a size mismatch or on the first element where
+  // `got` differs from `expected`; `what` names the stage, `key` the key.
+  auto check = [](const std::vector<mx_float> &got,
+                  const std::vector<mx_float> &expected,
+                  const char *what, const std::string &key) {
+    if (got.size() != expected.size()) {
+      LG << "Error: wrong " << what << " size in " << key
+          << ", expect " << expected.size() << " got " << got.size();
+      return false;
+    }
+    for (size_t j = 0; j < got.size(); j++) {
+      if (got[j] != expected[j]) {
+        LG << "Error: wrong " << what << " in " << key
+            << ", expect " << expected[j]
+            << " got " << got[j];
+        return false;
+      }
+    }
+    return true;
+  };
+
+  std::vector<NDArray> results(2);
+  results[0] = NDArray(Shape(4), context);
+  results[1] = NDArray(Shape(4), context);
+  std::vector<NDArray> results_cpu(2);
+
+  // initialize data
+  std::vector<NDArray> data_cpu(2);
+  data_cpu[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
+  data_cpu[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
+  std::vector<NDArray> data(2);
+  data[0] = data_cpu[0].Copy(context);
+  data[1] = data_cpu[1].Copy(context);
+  NDArray::WaitAll();
+
+  KVStore::Init(keys, data);
+  NDArray::WaitAll();
+
+  // retrieve result
+  KVStore::Pull(keys, &results);
+  NDArray::WaitAll();
+
+  results_cpu[0] = results[0].Copy(Context::cpu());
+  results_cpu[1] = results[1].Copy(Context::cpu());
+  NDArray::WaitAll();
+
+  for (size_t i = 0; i < keys.size(); i++) {
+    if (!check(to_host(results_cpu[i]), to_host(data_cpu[i]),
+               "initialized data", keys[i])) {
+      return false;
+    }
+  }
+
+  // push gradients; the second key appears twice so the store must reduce them
+  std::vector<std::string> push_keys = {keys[0], keys[1], keys[1]};
+
+  std::vector<NDArray> grads_cpu(3);
+  grads_cpu[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
+  grads_cpu[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
+  grads_cpu[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
+  std::vector<NDArray> grads(3);
+  grads[0] = grads_cpu[0].Copy(context);
+  grads[1] = grads_cpu[1].Copy(context);
+  grads[2] = grads_cpu[2].Copy(context);
+  NDArray::WaitAll();
+
+  KVStore::Push(push_keys, grads);
+  NDArray::WaitAll();
+
+  // retrieve result
+  KVStore::Pull(keys, &results);
+  NDArray::WaitAll();
+
+  results_cpu[0] = results[0].Copy(Context::cpu());
+  results_cpu[1] = results[1].Copy(Context::cpu());
+  NDArray::WaitAll();
+
+  // the first key holds its single gradient
+  if (!check(to_host(results_cpu[0]), to_host(grads_cpu[0]),
+             "gradient data", keys[0])) {
+    return false;
+  }
+
+  // The second key holds the element-wise sum of its two gradients. Exact
+  // comparison is kept: float addition of two operands is commutative, so the
+  // order in which the store sums them does not matter (assumes the store
+  // accumulates in float32 on every context -- TODO confirm).
+  std::vector<mx_float> expected_sum = to_host(grads_cpu[1]);
+  const std::vector<mx_float> second = to_host(grads_cpu[2]);
+  for (size_t j = 0; j < expected_sum.size(); j++) {
+    expected_sum[j] += second[j];
+  }
+
+  return check(to_host(results_cpu[1]), expected_sum,
+               "reduced gradient data", keys[1]);
+}
+
+// Runs the string-key KVStore tests on CPU, and also on GPU when
+// MXGetGPUCount reports at least one device. Exits 0 only if every test
+// that ran passed.
+int main() {
+  KVStore::SetType("local");
+
+  // CPU cases always run.
+  bool success1 = test_single_key(Context::cpu(), "cpu");
+  bool success2 = test_multiple_key(Context::cpu(), "cpu");
+
+  // GPU cases default to "passed" and run only when a GPU is visible.
+  bool success3 = true;
+  bool success4 = true;
+
+  int gpu_count = 0;
+  if (MXGetGPUCount(&gpu_count) != 0) {
+    LG << "Error: MXGetGPUCount: " << MXGetLastError();
+
+    MXNotifyShutdown();
+    return 1;
+  }
+
+  if (gpu_count > 0) {
+    success3 = test_single_key(Context::gpu(), "gpu");
+    success4 = test_multiple_key(Context::gpu(), "gpu");
+  }
+
+  int ret = (success1 && success2 && success3 && success4) ? 0 : 1;
+
+  // MXNotifyShutdown is called on every exit path of this program.
+  MXNotifyShutdown();
+  return ret;
+}
diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h
index d5aa150..67f984f 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.h
+++ b/cpp-package/include/mxnet-cpp/kvstore.h
@@ -39,12 +39,21 @@ class KVStore {
   static void SetType(const std::string& type);
   static void RunServer();
   static void Init(int key, const NDArray& val);
+  static void Init(const std::string& key, const NDArray& val);
   static void Init(const std::vector<int>& keys, const std::vector<NDArray>& 
vals);
+  static void Init(const std::vector<std::string>& keys, const 
std::vector<NDArray>& vals);
   static void Push(int key, const NDArray& val, int priority = 0);
+  static void Push(const std::string& key, const NDArray& val, int priority = 
0);
   static void Push(const std::vector<int>& keys,
-      const std::vector<NDArray>& vals, int priority = 0);
+                   const std::vector<NDArray>& vals, int priority = 0);
+  static void Push(const std::vector<std::string>& keys,
+                   const std::vector<NDArray>& vals, int priority = 0);
   static void Pull(int key, NDArray* out, int priority = 0);
-  static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, 
int priority = 0);
+  static void Pull(const std::string& key, NDArray* out, int priority = 0);
+  static void Pull(const std::vector<int>& keys,
+                   std::vector<NDArray>* outs, int priority = 0);
+  static void Pull(const std::vector<std::string>& keys,
+                   std::vector<NDArray>* outs, int priority = 0);
   // TODO(lx): put lr in optimizer or not?
   static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = 
false);
   static std::string GetType();
diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp
index f2b5e74..6cd405b 100644
--- a/cpp-package/include/mxnet-cpp/kvstore.hpp
+++ b/cpp-package/include/mxnet-cpp/kvstore.hpp
@@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
   CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 
0);
 }
 
+/*!
+ * \brief Initializes a single value in the store under a string key.
+ * \param key string key; its c_str() is passed to MXKVStoreInitEx.
+ * \param val initial value.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Init(const std::string& key, const NDArray& val) {
+  const char* key_handle = key.c_str();
+  NDArrayHandle val_handle = val.GetHandle();
+  CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle,
+                           &val_handle), 0) << MXGetLastError();
+}
+
 inline void KVStore::Init(const std::vector<int>& keys, const 
std::vector<NDArray>& vals) {
   CHECK_EQ(keys.size(), vals.size());
   std::vector<NDArrayHandle> val_handles(vals.size());
@@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
       val_handles.data()), 0);
 }
 
+/*!
+ * \brief Initializes several values in the store under string keys.
+ * \param keys string keys; must be the same length as vals.
+ * \param vals initial values, paired with keys by index.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Init(const std::vector<std::string>& keys,
+                          const std::vector<NDArray>& vals) {
+  CHECK_EQ(keys.size(), vals.size());
+  // The C API takes raw C strings; these pointers borrow from `keys`, which
+  // outlives the call below.
+  std::vector<const char*> key_handles(keys.size());
+  std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+      [](const std::string& key) {
+        return key.c_str();
+      });
+  std::vector<NDArrayHandle> val_handles(vals.size());
+  std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
+      [](const NDArray& val) {
+        return val.GetHandle();
+      });
+
+  CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(),
+                           static_cast<mx_uint>(key_handles.size()),
+                           key_handles.data(), val_handles.data()), 0)
+      << MXGetLastError();
+}
+
 inline void KVStore::Push(int key, const NDArray& val, int priority) {
+  // Single integer-key push: forwards val's handle to MXKVStorePush.
   NDArrayHandle val_handle = val.GetHandle();
   CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, 
 priority), 0);
 }
 
+/*!
+ * \brief Pushes a single value to the store under a string key.
+ * \param key string key; its c_str() is passed to MXKVStorePushEx.
+ * \param val value to push.
+ * \param priority priority forwarded to MXKVStorePushEx.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Push(const std::string& key, const NDArray& val,
+                          int priority) {
+  const char* key_handle = key.c_str();
+  NDArrayHandle val_handle = val.GetHandle();
+  CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle,
+                           &val_handle, priority), 0) << MXGetLastError();
+}
+
 inline void KVStore::Push(const std::vector<int>& keys,
-                          const std::vector<NDArray>& vals,
-                          int priority) {
+                          const std::vector<NDArray>& vals, int priority) {
   CHECK_EQ(keys.size(), vals.size());
   std::vector<NDArrayHandle> val_handles(vals.size());
   std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
@@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
       val_handles.data(), priority), 0);
 }
 
+/*!
+ * \brief Pushes several values to the store under string keys in one call.
+ * \param keys string keys; must be the same length as vals. A key may appear
+ *             more than once (test_kvstore.cpp relies on this for reduction).
+ * \param vals values to push, paired with keys by index.
+ * \param priority priority forwarded to MXKVStorePushEx.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Push(const std::vector<std::string>& keys,
+                          const std::vector<NDArray>& vals, int priority) {
+  CHECK_EQ(keys.size(), vals.size());
+  // The C API takes raw C strings; these pointers borrow from `keys`, which
+  // outlives the call below.
+  std::vector<const char*> key_handles(keys.size());
+  std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+      [](const std::string& key) {
+        return key.c_str();
+      });
+  std::vector<NDArrayHandle> val_handles(vals.size());
+  std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
+      [](const NDArray& val) {
+        return val.GetHandle();
+      });
+
+  CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(),
+                           static_cast<mx_uint>(key_handles.size()),
+                           key_handles.data(), val_handles.data(), priority), 0)
+      << MXGetLastError();
+}
+
 inline void KVStore::Pull(int key, NDArray* out, int priority) {
+  // Single integer-key pull into *out: forwards out's handle to MXKVStorePull.
   NDArrayHandle out_handle = out->GetHandle();
   CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, 
 priority), 0);
 }
 
-inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* 
outs, int priority) {
+/*!
+ * \brief Pulls the value stored under a string key into *out.
+ * \param key string key; its c_str() is passed to MXKVStorePullEx.
+ * \param out destination array; must be non-null (it is dereferenced).
+ * \param priority priority forwarded to MXKVStorePullEx.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {
+  const char* key_handle = key.c_str();
+  NDArrayHandle out_handle = out->GetHandle();
+  CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle,
+                           &out_handle, priority), 0) << MXGetLastError();
+}
+
+inline void KVStore::Pull(const std::vector<int>& keys,
+                          std::vector<NDArray>* outs, int priority) {
   CHECK_EQ(keys.size(), outs->size());
 
   std::vector<NDArrayHandle> out_handles(keys.size());
@@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
       out_handles.data(), priority), 0);
 }
 
+/*!
+ * \brief Pulls the values stored under several string keys in one call.
+ * \param keys string keys; must be the same length as *outs.
+ * \param outs destination arrays, paired with keys by index; must be non-null.
+ * \param priority priority forwarded to MXKVStorePullEx.
+ * On a non-zero C API return the CHECK fails and reports MXGetLastError().
+ */
+inline void KVStore::Pull(const std::vector<std::string>& keys,
+                          std::vector<NDArray>* outs, int priority) {
+  CHECK_EQ(keys.size(), outs->size());
+
+  // The C API takes raw C strings; these pointers borrow from `keys`, which
+  // outlives the call below.
+  std::vector<const char*> key_handles(keys.size());
+  std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
+      [](const std::string& key) {
+        return key.c_str();
+      });
+  std::vector<NDArrayHandle> out_handles(keys.size());
+  std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
+      [](const NDArray& val) {
+        return val.GetHandle();
+      });
+
+  CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(),
+                           static_cast<mx_uint>(key_handles.size()),
+                           key_handles.data(), out_handles.data(), priority), 0)
+      << MXGetLastError();
+}
+
 inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
                              void* handle_) {
   Optimizer *opt = static_cast<Optimizer*>(handle_);
diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh
index 18fabea..2d1f8e4 100755
--- a/cpp-package/tests/ci_test.sh
+++ b/cpp-package/tests/ci_test.sh
@@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
 cp ../../build/cpp-package/example/mlp_gpu .
 ./mlp_gpu
 
- cp ../../build/cpp-package/example/test_optimizer .
- ./test_optimizer
+cp ../../build/cpp-package/example/test_optimizer .
+./test_optimizer
+
+cp ../../build/cpp-package/example/test_kvstore .
+./test_kvstore
 
 cp ../../build/cpp-package/example/test_score .
 ./test_score 0.93

Reply via email to