This is an automated email from the ASF dual-hosted git repository. cjolivier01 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new a352d1e added function for loading content of nd_array files from a buffer (#9883) a352d1e is described below commit a352d1e2b31f7b479264ff3ba5fa2d9433f03013 Author: David Braude <dabra...@gmail.com> AuthorDate: Mon Feb 26 19:48:11 2018 +0000 added function for loading content of nd_array files from a buffer (#9883) * added function for loading content of nd_array files * changed function name and added check for NULL * removed no lint * fixed whitespace * corrected the casting * added python wrapper for buffer loading * added unit tests for loading from buffer * whitespace fixes * fix for python 3 * fixed test for py3 * python 3 problems * fixed test * switched to using temp files * better use of temp files * hopefully fixed permission issue * removed specified directory * hopefully this will work with windows * fixed indentation * check in to relaunch tests Python 3 windows failed for no obvious reason, deleted some whitespace to relaunch * switched to using temporary directory class * removed unneeded imports * moved imports to 1 location --- CONTRIBUTORS.md | 1 + include/mxnet/c_api.h | 22 ++++++++++++++++ python/mxnet/ndarray/__init__.py | 2 +- python/mxnet/ndarray/utils.py | 39 +++++++++++++++++++++++++++- src/c_api/c_api.cc | 34 +++++++++++++++++++++++++ tests/python/unittest/common.py | 16 ++++++++++++ tests/python/unittest/test_ndarray.py | 48 ++++++++++++++++++++++++++++++++++- 7 files changed, 159 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 0f9019b..c09c559 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -157,3 +157,4 @@ List of Contributors * [Tao Hu](https://github.com/dongzhuoyao) * [Sorokin Evgeniy](https://github.com/TheTweak) * [dwSun](https://github.com/dwSun/) +* [David Braude](https://github.com/dabraude/) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index b1fdbf1..e85afe5 100644 --- a/include/mxnet/c_api.h +++ 
b/include/mxnet/c_api.h @@ -500,6 +500,28 @@ MXNET_DLL int MXNDArrayLoad(const char* fname, NDArrayHandle** out_arr, mx_uint *out_name_size, const char*** out_names); + +/*! + * \brief Load a list / dictionary of NDArrays from file content already loaded into memory. + * This will load a list of NDArrays in a similar + * manner to MXNDArrayLoad; however, it loads from a + * buffer containing the contents of a file, rather than + * from a specified file. + * \param ndarray_buffer pointer to the start of the ndarray file content + * \param size size of the buffer in bytes + * \param out_size number of NDArrays loaded. + * \param out_arr head of the returned NDArray handles. + * \param out_name_size size of the output name array; 0 when the content holds an unnamed list. + * \param out_names names of the returned NDArrays (empty when out_name_size is 0). + * \return 0 on success, -1 on failure + */ +MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, + size_t size, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names); + /*! * \brief Perform a synchronize copy from a continugous CPU memory region. * diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index fc4a55d..21193f0 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -27,7 +27,7 @@ from . 
import register from .op import * from .ndarray import * # pylint: enable=wildcard-import -from .utils import load, save, zeros, empty, array +from .utils import load, load_frombuffer, save, zeros, empty, array from .sparse import _ndarray_cls from .ndarray import _GRAD_REQ_MAP diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py index 4f597c7..ff93d0b 100644 --- a/python/mxnet/ndarray/utils.py +++ b/python/mxnet/ndarray/utils.py @@ -34,7 +34,7 @@ try: except ImportError: spsp = None -__all__ = ['zeros', 'empty', 'array', 'load', 'save'] +__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save'] def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs): @@ -182,6 +182,43 @@ def load(fname): for i in range(out_size.value)) +def load_frombuffer(buf): + """Loads an array dictionary or list from a buffer. + + See more details in ``save``. + + Parameters + ---------- + buf : str or bytes + Buffer containing contents of a file as a string or bytes. + + Returns + ------- + list of NDArray, RowSparseNDArray or CSRNDArray, or \ + dict of str to NDArray, RowSparseNDArray or CSRNDArray + Loaded data. + """ + if not isinstance(buf, string_types + tuple([bytes])): + raise TypeError('buf required to be a string or bytes') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoadFromBuffer(buf, + mx_uint(len(buf)), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) + + def save(fname, data): """Saves a list of arrays or a dict of str->array to file.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3ef0043..b41a142 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -322,6 +322,40 @@ int MXNDArrayLoad(const char* fname, API_END(); } +int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, + size_t size, + mx_uint *out_size, + NDArrayHandle** out_arr, + mx_uint *out_name_size, + const char*** out_names) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + ret->ret_vec_str.clear(); + API_BEGIN(); + CHECK_NOTNULL(ndarray_buffer); + std::vector<NDArray> data; + std::vector<std::string> &names = ret->ret_vec_str; + { + std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream( + const_cast<void*>(ndarray_buffer), size)); + mxnet::NDArray::Load(fi.get(), &data, &names); + } + ret->ret_handles.resize(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + NDArray *ptr = new NDArray(); + *ptr = data[i]; + ret->ret_handles[i] = ptr; + } + ret->ret_vec_charp.resize(names.size()); + for (size_t i = 0; i < names.size(); ++i) { + ret->ret_vec_charp[i] = names[i].c_str(); + } + *out_size = static_cast<mx_uint>(data.size()); + *out_arr = dmlc::BeginPtr(ret->ret_handles); + *out_name_size = static_cast<mx_uint>(names.size()); + *out_names = dmlc::BeginPtr(ret->ret_vec_charp); + API_END(); +} + int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast<NDArray*>(handle); diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py index 35c7e3f..635bdcc 100644 --- a/tests/python/unittest/common.py +++ b/tests/python/unittest/common.py @@ -19,6 +19,7 @@ import sys, os, logging import mxnet as mx import numpy as np import random +import shutil curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.append(os.path.join(curr_path, '../common/')) sys.path.insert(0, os.path.join(curr_path, '../../../python')) @@ -26,6 +27,7 @@ sys.path.insert(0, os.path.join(curr_path, '../../../python')) import models from contextlib 
import contextmanager from nose.tools import make_decorator +import tempfile def assertRaises(expected_exception, func, *args, **kwargs): try: @@ -225,3 +227,17 @@ def setup_module(): # the 'with_seed()' decoration. Inform the user of this once here at the module level. if os.getenv('MXNET_TEST_SEED') is not None: logger.warn('*** test-level seed set: all "@with_seed()" tests run deterministically ***') + +try: + from tempfile import TemporaryDirectory +except: + # really simple implementation of TemporaryDirectory + class TemporaryDirectory(object): + def __init__(self, suffix='', prefix='', dir=''): + self._dirname = tempfile.mkdtemp(suffix, prefix, dir) + + def __enter__(self): + return self._dirname + + def __exit__(self, exc_type, exc_value, traceback): + shutil.rmtree(self._dirname) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 52697c2..0daf74a 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -21,7 +21,7 @@ import os import pickle as pkl import unittest from nose.tools import raises -from common import setup_module, with_seed +from common import setup_module, with_seed, assertRaises, TemporaryDirectory from mxnet.test_utils import almost_equal from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import default_context @@ -292,6 +292,52 @@ def test_ndarray_legacy_load(): @with_seed() +def test_buffer_load(): + nrepeat = 10 + with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir: + for repeat in range(nrepeat): + # test load_buffer as list + data = [] + for i in range(10): + data.append(random_ndarray(np.random.randint(1, 5))) + fname = os.path.join(tmpdir, 'list_{0}.param'.format(repeat)) + mx.nd.save(fname, data) + with open(fname, 'rb') as dfile: + buf_data = dfile.read() + data2 = mx.nd.load_frombuffer(buf_data) + assert len(data) == len(data2) + for x, y in zip(data, data2): + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + # test 
garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_data[:-10]) + # test load_buffer as dict + dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)} + fname = os.path.join(tmpdir, 'dict_{0}.param'.format(repeat)) + mx.nd.save(fname, dmap) + with open(fname, 'rb') as dfile: + buf_dmap = dfile.read() + dmap2 = mx.nd.load_frombuffer(buf_dmap) + assert len(dmap2) == len(dmap) + for k, x in dmap.items(): + y = dmap2[k] + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + # test garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_dmap[:-10]) + + # we expect the single ndarray to be converted into a list containing the ndarray + single_ndarray = data[0] + fname = os.path.join(tmpdir, 'single_{0}.param'.format(repeat)) + mx.nd.save(fname, single_ndarray) + with open(fname, 'rb') as dfile: + buf_single_ndarray = dfile.read() + single_ndarray_loaded = mx.nd.load_frombuffer(buf_single_ndarray) + assert len(single_ndarray_loaded) == 1 + single_ndarray_loaded = single_ndarray_loaded[0] + assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0 + # test garbage values + assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single_ndarray[:-10]) + +@with_seed() def test_ndarray_slice(): shape = (10,) A = mx.nd.array(np.random.uniform(-10, 10, shape)) -- To stop receiving notification emails like this one, please contact cjolivie...@apache.org.