cjolivier01 closed pull request #9883: added function for loading content of nd_array files from a buffer
URL: https://github.com/apache/incubator-mxnet/pull/9883
 
 
   

This PR was merged from a forked repository. Because GitHub does not show
the original diff of a pull request from a fork once it has been merged,
the diff is reproduced below for the sake of provenance:

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 0f9019bbd2..c09c559b18 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -157,3 +157,4 @@ List of Contributors
 * [Tao Hu](https://github.com/dongzhuoyao)
 * [Sorokin Evgeniy](https://github.com/TheTweak)
 * [dwSun](https://github.com/dwSun/)
+* [David Braude](https://github.com/dabraude/)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index b1fdbf1db2..e85afe522f 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -500,6 +500,28 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
                             NDArrayHandle** out_arr,
                             mx_uint *out_name_size,
                             const char*** out_names);
+
+/*!
+ * \brief Load list / dictionary of narrays from file content loaded into memory.
+ * This will load a list of ndarrays in a similar
+ * manner to MXNDArrayLoad, however, it loads from
+ * buffer containing the contents of a file, rather than
+ * from a specified file.
+ * \param ndarray_buffer pointer to the start of the ndarray file content
+ * \param size size of the file
+ * \param out_size number of narray loaded.
+ * \param out_arr head of the returning narray handles.
+ * \param out_name_size size of output name arrray.
+ * \param out_names the names of returning NDArrays, can be NULL
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
+                            size_t size,
+                            mx_uint *out_size,
+                            NDArrayHandle** out_arr,
+                            mx_uint *out_name_size,
+                            const char*** out_names);
+
 /*!
  * \brief Perform a synchronize copy from a continugous CPU memory region.
  *
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index fc4a55d8f5..21193f0790 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -27,7 +27,7 @@
 from .op import *
 from .ndarray import *
 # pylint: enable=wildcard-import
-from .utils import load, save, zeros, empty, array
+from .utils import load, load_frombuffer, save, zeros, empty, array
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 
diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py
index 4f597c749f..ff93d0be6d 100644
--- a/python/mxnet/ndarray/utils.py
+++ b/python/mxnet/ndarray/utils.py
@@ -34,7 +34,7 @@
 except ImportError:
     spsp = None
 
-__all__ = ['zeros', 'empty', 'array', 'load', 'save']
+__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save']
 
 
 def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
@@ -182,6 +182,43 @@ def load(fname):
             for i in range(out_size.value))
 
 
+def load_frombuffer(buf):
+    """Loads an array dictionary or list from a buffer
+
+    See more details in ``save``.
+
+    Parameters
+    ----------
+    buf : str
+        Buffer containing contents of a file as a string or bytes.
+
+    Returns
+    -------
+    list of NDArray, RowSparseNDArray or CSRNDArray, or \
+    dict of str to NDArray, RowSparseNDArray or CSRNDArray
+        Loaded data.
+    """
+    if not isinstance(buf, string_types + tuple([bytes])):
+        raise TypeError('buf required to be a string or bytes')
+    out_size = mx_uint()
+    out_name_size = mx_uint()
+    handles = ctypes.POINTER(NDArrayHandle)()
+    names = ctypes.POINTER(ctypes.c_char_p)()
+    check_call(_LIB.MXNDArrayLoadFromBuffer(buf,
+                                            mx_uint(len(buf)),
+                                            ctypes.byref(out_size),
+                                            ctypes.byref(handles),
+                                            ctypes.byref(out_name_size),
+                                            ctypes.byref(names)))
+    if out_name_size.value == 0:
+        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
+    else:
+        assert out_name_size.value == out_size.value
+        return dict(
+            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
+            for i in range(out_size.value))
+
+
 def save(fname, data):
     """Saves a list of arrays or a dict of str->array to file.
 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 3ef0043aca..b41a142ab6 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -322,6 +322,40 @@ int MXNDArrayLoad(const char* fname,
   API_END();
 }
 
+int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
+                            size_t size,
+                            mx_uint *out_size,
+                            NDArrayHandle** out_arr,
+                            mx_uint *out_name_size,
+                            const char*** out_names) {
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  ret->ret_vec_str.clear();
+  API_BEGIN();
+  CHECK_NOTNULL(ndarray_buffer);
+  std::vector<NDArray> data;
+  std::vector<std::string> &names = ret->ret_vec_str;
+  {
+    std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream(
+        const_cast<void*>(ndarray_buffer), size));
+    mxnet::NDArray::Load(fi.get(), &data, &names);
+  }
+  ret->ret_handles.resize(data.size());
+  for (size_t i = 0; i < data.size(); ++i) {
+    NDArray *ptr = new NDArray();
+    *ptr = data[i];
+    ret->ret_handles[i] = ptr;
+  }
+  ret->ret_vec_charp.resize(names.size());
+  for (size_t i = 0; i < names.size(); ++i) {
+    ret->ret_vec_charp[i] = names[i].c_str();
+  }
+  *out_size = static_cast<mx_uint>(data.size());
+  *out_arr = dmlc::BeginPtr(ret->ret_handles);
+  *out_name_size = static_cast<mx_uint>(names.size());
+  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
+  API_END();
+}
+
 int MXNDArrayFree(NDArrayHandle handle) {
   API_BEGIN();
   delete static_cast<NDArray*>(handle);
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index 35c7e3ffc8..635bdcc609 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -19,6 +19,7 @@
 import mxnet as mx
 import numpy as np
 import random
+import shutil
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.append(os.path.join(curr_path, '../common/'))
 sys.path.insert(0, os.path.join(curr_path, '../../../python'))
@@ -26,6 +27,7 @@
 import models
 from contextlib import contextmanager
 from nose.tools import make_decorator
+import tempfile
 
 def assertRaises(expected_exception, func, *args, **kwargs):
     try:
@@ -225,3 +227,17 @@ def setup_module():
     #  the 'with_seed()' decoration.  Inform the user of this once here at the module level.
     if os.getenv('MXNET_TEST_SEED') is not None:
         logger.warn('*** test-level seed set: all "@with_seed()" tests run deterministically ***')
+
+try:
+    from tempfile import TemporaryDirectory
+except:
+    # really simple implementation of TemporaryDirectory
+    class TemporaryDirectory(object):
+        def __init__(self, suffix='', prefix='', dir=''):
+            self._dirname = tempfile.mkdtemp(suffix, prefix, dir)
+
+        def __enter__(self):
+            return self._dirname
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self._dirname)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 6c104878ab..a3694b2ce6 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -21,7 +21,7 @@
 import pickle as pkl
 import unittest
 from nose.tools import raises
-from common import setup_module, with_seed
+from common import setup_module, with_seed, assertRaises, TemporaryDirectory
 from mxnet.test_utils import almost_equal
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import default_context
@@ -291,6 +291,52 @@ def test_ndarray_legacy_load():
         assert same(data[i].asnumpy(), legacy_data[i].asnumpy())
 
 
+@with_seed()
+def test_buffer_load():
+    nrepeat = 10
+    with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:
+        for repeat in range(nrepeat):
+            # test load_buffer as list
+            data = []
+            for i in range(10):
+                data.append(random_ndarray(np.random.randint(1, 5)))
+            fname = os.path.join(tmpdir, 'list_{0}.param'.format(repeat))
+            mx.nd.save(fname, data)
+            with open(fname, 'rb') as dfile:
+                buf_data = dfile.read()
+                data2 = mx.nd.load_frombuffer(buf_data)
+                assert len(data) == len(data2)
+                for x, y in zip(data, data2):
+                    assert np.sum(x.asnumpy() != y.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_data[:-10])
+            # test load_buffer as dict
+            dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)}
+            fname = os.path.join(tmpdir, 'dict_{0}.param'.format(repeat))
+            mx.nd.save(fname, dmap)
+            with open(fname, 'rb') as dfile:
+                buf_dmap = dfile.read()
+                dmap2 = mx.nd.load_frombuffer(buf_dmap)
+                assert len(dmap2) == len(dmap)
+                for k, x in dmap.items():
+                    y = dmap2[k]
+                    assert np.sum(x.asnumpy() != y.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_dmap[:-10])
+
+            # we expect the single ndarray to be converted into a list containing the ndarray
+            single_ndarray = data[0]
+            fname = os.path.join(tmpdir, 'single_{0}.param'.format(repeat))
+            mx.nd.save(fname, single_ndarray)
+            with open(fname, 'rb') as dfile:
+                buf_single_ndarray = dfile.read()
+                single_ndarray_loaded = mx.nd.load_frombuffer(buf_single_ndarray)
+                assert len(single_ndarray_loaded) == 1
+                single_ndarray_loaded = single_ndarray_loaded[0]
+                assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_single_ndarray[:-10])
+
 @with_seed()
 def test_ndarray_slice():
     shape = (10,)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to