cjolivier01 closed pull request #10261: [MXNET-128] added load from buffer functions URL: https://github.com/apache/incubator-mxnet/pull/10261
This PR was merged from a forked repository. Because GitHub hides the original diff of a pull request from a fork once it is merged, the diff is reproduced below for provenance: diff --git a/cpp-package/include/mxnet-cpp/ndarray.h b/cpp-package/include/mxnet-cpp/ndarray.h index 1166643e4e8..6f37d91aa68 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.h +++ b/cpp-package/include/mxnet-cpp/ndarray.h @@ -398,6 +398,32 @@ class NDArray { */ static std::vector<NDArray> LoadToList(const std::string &file_name); /*! + * \brief Load NDArrays from buffer. + * \param buffer Pointer to buffer. (ie contents of param file) + * \param size Size of buffer + * \param array_list a list of NDArrays returned, do not fill the list if + * nullptr is given. + * \param array_map a map from names to NDArrays returned, do not fill the map + * if nullptr is given or no names is stored in binary file. + */ + static void LoadFromBuffer(const void *buffer, size_t size, + std::vector<NDArray> *array_list = nullptr, + std::map<std::string, NDArray> *array_map = nullptr); + /*! + * \brief Load map of NDArrays from buffer. + * \param buffer Pointer to buffer. (ie contents of param file) + * \param size Size of buffer + * \return a list of NDArrays. + */ + static std::map<std::string, NDArray> LoadFromBufferToMap(const void *buffer, size_t size); + /*! + * \brief Load list of NDArrays from buffer. + * \param buffer Pointer to buffer. (ie contents of param file) + * \param size Size of buffer + * \return a map from names to NDArrays. + */ + static std::vector<NDArray> LoadFromBufferToList(const void *buffer, size_t size); + /*! * \brief save a map of string->NDArray to binary file. * \param file_name name of the binary file. * \param array_map a map from names to NDArrays. 
diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 3c3b85d3732..966cf75c912 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -255,6 +255,7 @@ inline void NDArray::Load(const std::string &file_name, &out_names), 0); if (array_list != nullptr) { + array_list->reserve(out_size); for (mx_uint i = 0; i < out_size; ++i) { array_list->push_back(NDArray(out_arr[i])); } @@ -291,6 +292,60 @@ inline std::vector<NDArray> NDArray::LoadToList(const std::string &file_name) { CHECK_EQ(MXNDArrayLoad(file_name.c_str(), &out_size, &out_arr, &out_name_size, &out_names), 0); + array_list.reserve(out_size); + for (mx_uint i = 0; i < out_size; ++i) { + array_list.push_back(NDArray(out_arr[i])); + } + return array_list; +} +inline void NDArray::LoadFromBuffer(const void *buffer, size_t size, + std::vector<NDArray> *array_list, + std::map<std::string, NDArray> *array_map) { + mx_uint out_size, out_name_size; + NDArrayHandle *out_arr; + const char **out_names; + CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size, + &out_names), + 0); + if (array_list != nullptr) { + array_list->reserve(out_size); + for (mx_uint i = 0; i < out_size; ++i) { + array_list->push_back(NDArray(out_arr[i])); + } + } + if (array_map != nullptr && out_name_size > 0) { + CHECK_EQ(out_name_size, out_size); + for (mx_uint i = 0; i < out_size; ++i) { + (*array_map)[out_names[i]] = NDArray(out_arr[i]); + } + } +} +inline std::map<std::string, NDArray> NDArray::LoadFromBufferToMap( + const void *buffer, size_t size) { + std::map<std::string, NDArray> array_map; + mx_uint out_size, out_name_size; + NDArrayHandle *out_arr; + const char **out_names; + CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size, + &out_names), + 0); + if (out_name_size > 0) { + CHECK_EQ(out_name_size, out_size); + for (mx_uint i = 0; i < out_size; ++i) { + array_map[out_names[i]] 
= NDArray(out_arr[i]); + } + } + return array_map; +} +inline std::vector<NDArray> NDArray::LoadFromBufferToList(const void *buffer, size_t size) { + std::vector<NDArray> array_list; + mx_uint out_size, out_name_size; + NDArrayHandle *out_arr; + const char **out_names; + CHECK_EQ(MXNDArrayLoadFromBuffer(buffer, size, &out_size, &out_arr, &out_name_size, + &out_names), + 0); + array_list.reserve(out_size); for (mx_uint i = 0; i < out_size; ++i) { array_list.push_back(NDArray(out_arr[i])); } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services