cjolivier01 closed pull request #9883: added function for loading content of nd_array files from a buffer
URL: https://github.com/apache/incubator-mxnet/pull/9883
This is a PR merged from a forked repository. Because GitHub hides the original diff of a foreign (fork) pull request once it is merged, the diff is reproduced below for the sake of provenance: diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0f9019bbd2..c09c559b18 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -157,3 +157,4 @@ List of Contributors * [Tao Hu](https://github.com/dongzhuoyao) * [Sorokin Evgeniy](https://github.com/TheTweak) * [dwSun](https://github.com/dwSun/) +* [David Braude](https://github.com/dabraude/) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b1fdbf1db2..e85afe522f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -500,6 +500,28 @@ MXNET_DLL int MXNDArrayLoad(const char* fname, NDArrayHandle** out_arr, mx_uint *out_name_size, const char*** out_names); + +/*! + * \brief Load list / dictionary of narrays from file content loaded into memory. + * This will load a list of ndarrays in a similar + * manner to MXNDArrayLoad, however, it loads from + * buffer containing the contents of a file, rather than + * from a specified file. + * \param ndarray_buffer pointer to the start of the ndarray file content + * \param size size of the file + * \param out_size number of narray loaded. + * \param out_arr head of the returning narray handles. + * \param out_name_size size of output name arrray. + * \param out_names the names of returning NDArrays, can be NULL + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, + size_t size, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names); + /*! * \brief Perform a synchronize copy from a continugous CPU memory region. 
* diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index fc4a55d8f5..21193f0790 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -27,7 +27,7 @@ from .op import * from .ndarray import * # pylint: enable=wildcard-import -from .utils import load, save, zeros, empty, array +from .utils import load, load_frombuffer, save, zeros, empty, array from .sparse import _ndarray_cls from .ndarray import _GRAD_REQ_MAP diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py index 4f597c749f..ff93d0be6d 100644 --- a/python/mxnet/ndarray/utils.py +++ b/python/mxnet/ndarray/utils.py @@ -34,7 +34,7 @@ except ImportError: spsp = None -__all__ = ['zeros', 'empty', 'array', 'load', 'save'] +__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save'] def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs): @@ -182,6 +182,43 @@ def load(fname): for i in range(out_size.value)) +def load_frombuffer(buf): + """Loads an array dictionary or list from a buffer + + See more details in ``save``. + + Parameters + ---------- + buf : str + Buffer containing contents of a file as a string or bytes. + + Returns + ------- + list of NDArray, RowSparseNDArray or CSRNDArray, or \ + dict of str to NDArray, RowSparseNDArray or CSRNDArray + Loaded data. 
+ """ + if not isinstance(buf, string_types + tuple([bytes])): + raise TypeError('buf required to be a string or bytes') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoadFromBuffer(buf, + mx_uint(len(buf)), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) + + def save(fname, data): """Saves a list of arrays or a dict of str->array to file. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3ef0043aca..b41a142ab6 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -322,6 +322,40 @@ int MXNDArrayLoad(const char* fname, API_END(); } +int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, + size_t size, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + ret->ret_vec_str.clear(); + API_BEGIN(); + CHECK_NOTNULL(ndarray_buffer); + std::vector<NDArray> data; + std::vector<std::string> &names = ret->ret_vec_str; + { + std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream( + const_cast<void*>(ndarray_buffer), size)); + mxnet::NDArray::Load(fi.get(), &data, &names); + } + ret->ret_handles.resize(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + NDArray *ptr = new NDArray(); + *ptr = data[i]; + ret->ret_handles[i] = ptr; + } + ret->ret_vec_charp.resize(names.size()); + for (size_t i = 0; i < names.size(); ++i) { + ret->ret_vec_charp[i] = names[i].c_str(); + } + *out_size = static_cast<mx_uint>(data.size()); + *out_arr = dmlc::BeginPtr(ret->ret_handles); + 
*out_name_size = static_cast<mx_uint>(names.size()); + *out_names = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast<NDArray*>(handle); diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 35c7e3ffc8..635bdcc609 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -19,6 +19,7 @@ import mxnet as mx import numpy as np import random +import shutil curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../common/')) sys.path.insert(0, os.path.join(curr_path, '../../../python')) @@ -26,6 +27,7 @@ import models from contextlib import contextmanager from nose.tools import make_decorator +import tempfile def assertRaises(expected_exception, func, *args, **kwargs): try: @@ -225,3 +227,17 @@ def setup_module(): # the 'with_seed()' decoration. Inform the user of this once here at the module level. 
if os.getenv('MXNET_TEST_SEED') is not None: logger.warn('*** test-level seed set: all "@with_seed()" tests run deterministically ***') + +try: + from tempfile import TemporaryDirectory +except: + # really simple implementation of TemporaryDirectory + class TemporaryDirectory(object): + def __init__(self, suffix='', prefix='', dir=''): + self._dirname = tempfile.mkdtemp(suffix, prefix, dir) + + def __enter__(self): + return self._dirname + + def __exit__(self, exc_type, exc_value, traceback): + shutil.rmtree(self._dirname) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 6c104878ab..a3694b2ce6 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -21,7 +21,7 @@ import pickle as pkl import unittest from nose.tools import raises -from common import setup_module, with_seed +from common import setup_module, with_seed, assertRaises, TemporaryDirectory from mxnet.test_utils import almost_equal from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import default_context @@ -291,6 +291,52 @@ def test_ndarray_legacy_load(): assert same(data[i].asnumpy(), legacy_data[i].asnumpy()) +@with_seed() +def test_buffer_load(): + nrepeat = 10 + with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir: + for repeat in range(nrepeat): + # test load_buffer as list + data = [] + for i in range(10): + data.append(random_ndarray(np.random.randint(1, 5))) + fname = os.path.join(tmpdir, 'list_{0}.param'.format(repeat)) + mx.nd.save(fname, data) + with open(fname, 'rb') as dfile: + buf_data = dfile.read() + data2 = mx.nd.load_frombuffer(buf_data) + assert len(data) == len(data2) + for x, y in zip(data, data2): + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + # test garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_data[:-10]) + # test load_buffer as dict + dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)} + fname = os.path.join(tmpdir, 
'dict_{0}.param'.format(repeat)) + mx.nd.save(fname, dmap) + with open(fname, 'rb') as dfile: + buf_dmap = dfile.read() + dmap2 = mx.nd.load_frombuffer(buf_dmap) + assert len(dmap2) == len(dmap) + for k, x in dmap.items(): + y = dmap2[k] + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + # test garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_dmap[:-10]) + + # we expect the single ndarray to be converted into a list containing the ndarray + single_ndarray = data[0] + fname = os.path.join(tmpdir, 'single_{0}.param'.format(repeat)) + mx.nd.save(fname, single_ndarray) + with open(fname, 'rb') as dfile: + buf_single_ndarray = dfile.read() + single_ndarray_loaded = mx.nd.load_frombuffer(buf_single_ndarray) + assert len(single_ndarray_loaded) == 1 + single_ndarray_loaded = single_ndarray_loaded[0] + assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0 + # test garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single_ndarray[:-10]) + @with_seed() def test_ndarray_slice(): shape = (10,) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services