cjolivier01 closed pull request #10261: [MXNET-128] added load from buffer functions
URL: https://github.com/apache/incubator-mxnet/pull/10261
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/cpp-package/include/mxnet-cpp/ndarray.h b/cpp-package/include/mxnet-cpp/ndarray.h
index 1166643e4e8..6f37d91aa68 100644
--- a/cpp-package/include/mxnet-cpp/ndarray.h
+++ b/cpp-package/include/mxnet-cpp/ndarray.h
@@ -398,6 +398,32 @@ class NDArray {
   */
   static std::vector<NDArray> LoadToList(const std::string &file_name);
   /*!
+  * \brief Load NDArrays from buffer.
+  * \param buffer Pointer to buffer. (ie contents of param file)
+  * \param size Size of buffer
+  * \param array_list a list of NDArrays returned, do not fill the list if
+  * nullptr is given.
+  * \param array_map a map from names to NDArrays returned, do not fill the map
+  * if nullptr is given or no names is stored in binary file.
+  */
+  static void LoadFromBuffer(const void *buffer, size_t size,
+                   std::vector<NDArray> *array_list = nullptr,
+                   std::map<std::string, NDArray> *array_map = nullptr);
+  /*!
+  * \brief Load map of NDArrays from buffer.
+  * \param buffer Pointer to buffer. (ie contents of param file)
+  * \param size Size of buffer
+  * \return a list of NDArrays.
+  */
+  static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size);
+  /*!
+  * \brief Load list of NDArrays from buffer.
+  * \param buffer Pointer to buffer. (ie contents of param file)
+  * \param size Size of buffer
+  * \return a map from names to NDArrays.
+  */
+  static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size);
+  /*!
   * \brief save a map of string->NDArray to binary file.
   * \param file_name name of the binary file.
   * \param array_map a map from names to NDArrays.
diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp
index 3c3b85d3732..966cf75c912 100644
--- a/cpp-package/include/mxnet-cpp/ndarray.hpp
+++ b/cpp-package/include/mxnet-cpp/ndarray.hpp
@@ -255,6 +255,7 @@ inline void NDArray::Load(const std::string &file_name,
                          &out_names),
            0);
   if (array_list != nullptr) {
+    array_list->reserve(out_size);
     for (mx_uint i = 0; i < out_size; ++i) {
       array_list->push_back(NDArray(out_arr[i]));
     }
@@ -291,6 +292,60 @@ inline std::vector<NDArray> NDArray::LoadToList(const std::string &file_name) {
   CHECK_EQ(MXNDArrayLoad(file_name.c_str(), &out_size, &out_arr, &out_name_size,
                          &out_names),
            0);
+  array_list.reserve(out_size);
+  for (mx_uint i = 0; i < out_size; ++i) {
+    array_list.push_back(NDArray(out_arr[i]));
+  }
+  return array_list;
+}
+inline void NDArray::LoadFromBuffer(const void *buffer, size_t size,
+                          std::vector<NDArray> *array_list,
+                          std::map<std::string, NDArray> *array_map) {
+  mx_uint out_size, out_name_size;
+  NDArrayHandle *out_arr;
+  const char **out_names;
+  CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
+                         &out_names),
+           0);
+  if (array_list != nullptr) {
+    array_list->reserve(out_size);
+    for (mx_uint i = 0; i < out_size; ++i) {
+      array_list->push_back(NDArray(out_arr[i]));
+    }
+  }
+  if (array_map != nullptr && out_name_size > 0) {
+    CHECK_EQ(out_name_size, out_size);
+    for (mx_uint i = 0; i < out_size; ++i) {
+      (*array_map)[out_names[i]] = NDArray(out_arr[i]);
+    }
+  }
+}
+inline std::map<std::string, NDArray> NDArray::LoadFromBufferToMap(
+    const void *buffer, size_t size) {
+  std::map<std::string, NDArray> array_map;
+  mx_uint out_size, out_name_size;
+  NDArrayHandle *out_arr;
+  const char **out_names;
+  CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
+                         &out_names),
+           0);
+  if (out_name_size > 0) {
+    CHECK_EQ(out_name_size, out_size);
+    for (mx_uint i = 0; i < out_size; ++i) {
+      array_map[out_names[i]] = NDArray(out_arr[i]);
+    }
+  }
+  return array_map;
+}
+inline std::vector<NDArray> NDArray::LoadFromBufferToList(const void *buffer, size_t size) {
+  std::vector<NDArray> array_list;
+  mx_uint out_size, out_name_size;
+  NDArrayHandle *out_arr;
+  const char **out_names;
+  CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size,
+                         &out_names),
+           0);
+  array_list.reserve(out_size);
   for (mx_uint i = 0; i < out_size; ++i) {
     array_list.push_back(NDArray(out_arr[i]));
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to