This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a352d1e  added function for loading content of nd_array files from a buffer (#9883)
a352d1e is described below

commit a352d1e2b31f7b479264ff3ba5fa2d9433f03013
Author: David Braude <dabra...@gmail.com>
AuthorDate: Mon Feb 26 19:48:11 2018 +0000

    added function for loading content of nd_array files from a buffer (#9883)
    
    * added function for loading content of nd_array files
    
    * changed function name and added check for NULL
    
    * removed no lint
    
    * fixed whitespace
    
    * corrected the casting
    
    * added python wrapper for buffer loading
    
    * added unit tests for loading from buffer
    
    * whitespace fixes
    
    * fix for python 3
    
    * fixed test for py3
    
    * python 3 problems
    
    * fixed test
    
    * switched to using temp files
    
    * better use of temp files
    
    * hopefully fixed permission issue
    
    * removed specified directory
    
    * hopefully this will work with windows
    
    * fixed indentation
    
    * check in to relaunch tests
    
    Python 3 windows failed for no obvious reason, deleted some whitespace to relaunch
    
    * switched to using temporary directory class
    
    * removed unneeded imports
    
    * moved imports to 1 location
---
 CONTRIBUTORS.md                       |  1 +
 include/mxnet/c_api.h                 | 22 ++++++++++++++++
 python/mxnet/ndarray/__init__.py      |  2 +-
 python/mxnet/ndarray/utils.py         | 39 +++++++++++++++++++++++++++-
 src/c_api/c_api.cc                    | 34 +++++++++++++++++++++++++
 tests/python/unittest/common.py       | 16 ++++++++++++
 tests/python/unittest/test_ndarray.py | 48 ++++++++++++++++++++++++++++++++++-
 7 files changed, 159 insertions(+), 3 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 0f9019b..c09c559 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -157,3 +157,4 @@ List of Contributors
 * [Tao Hu](https://github.com/dongzhuoyao)
 * [Sorokin Evgeniy](https://github.com/TheTweak)
 * [dwSun](https://github.com/dwSun/)
+* [David Braude](https://github.com/dabraude/)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index b1fdbf1..e85afe5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -500,6 +500,28 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
                             NDArrayHandle** out_arr,
                             mx_uint *out_name_size,
                             const char*** out_names);
+
+/*!
+ * \brief Load list / dictionary of narrays from file content loaded into memory.
+ * This will load a list of ndarrays in a similar
+ * manner to MXNDArrayLoad, however, it loads from
+ * buffer containing the contents of a file, rather than
+ * from a specified file.
+ * \param ndarray_buffer pointer to the start of the ndarray file content
+ * \param size size of the file
+ * \param out_size number of narray loaded.
+ * \param out_arr head of the returning narray handles.
+ * \param out_name_size size of output name array.
+ * \param out_names the names of returning NDArrays, can be NULL
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
+                            size_t size,
+                            mx_uint *out_size,
+                            NDArrayHandle** out_arr,
+                            mx_uint *out_name_size,
+                            const char*** out_names);
+
 /*!
  * \brief Perform a synchronize copy from a continugous CPU memory region.
  *
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index fc4a55d..21193f0 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -27,7 +27,7 @@ from . import register
 from .op import *
 from .ndarray import *
 # pylint: enable=wildcard-import
-from .utils import load, save, zeros, empty, array
+from .utils import load, load_frombuffer, save, zeros, empty, array
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 
diff --git a/python/mxnet/ndarray/utils.py b/python/mxnet/ndarray/utils.py
index 4f597c7..ff93d0b 100644
--- a/python/mxnet/ndarray/utils.py
+++ b/python/mxnet/ndarray/utils.py
@@ -34,7 +34,7 @@ try:
 except ImportError:
     spsp = None
 
-__all__ = ['zeros', 'empty', 'array', 'load', 'save']
+__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save']
 
 
 def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
@@ -182,6 +182,43 @@ def load(fname):
             for i in range(out_size.value))
 
 
+def load_frombuffer(buf):
+    """Loads an array dictionary or list from a buffer
+
+    See more details in ``save``.
+
+    Parameters
+    ----------
+    buf : str
+        Buffer containing contents of a file as a string or bytes.
+
+    Returns
+    -------
+    list of NDArray, RowSparseNDArray or CSRNDArray, or \
+    dict of str to NDArray, RowSparseNDArray or CSRNDArray
+        Loaded data.
+    """
+    if not isinstance(buf, string_types + tuple([bytes])):
+        raise TypeError('buf required to be a string or bytes')
+    out_size = mx_uint()
+    out_name_size = mx_uint()
+    handles = ctypes.POINTER(NDArrayHandle)()
+    names = ctypes.POINTER(ctypes.c_char_p)()
+    check_call(_LIB.MXNDArrayLoadFromBuffer(buf,
+                                            mx_uint(len(buf)),
+                                            ctypes.byref(out_size),
+                                            ctypes.byref(handles),
+                                            ctypes.byref(out_name_size),
+                                            ctypes.byref(names)))
+    if out_name_size.value == 0:
+        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
+    else:
+        assert out_name_size.value == out_size.value
+        return dict(
+            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
+            for i in range(out_size.value))
+
+
 def save(fname, data):
     """Saves a list of arrays or a dict of str->array to file.
 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 3ef0043..b41a142 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -322,6 +322,40 @@ int MXNDArrayLoad(const char* fname,
   API_END();
 }
 
+int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
+                            size_t size,
+                            mx_uint *out_size,
+                            NDArrayHandle** out_arr,
+                            mx_uint *out_name_size,
+                            const char*** out_names) {
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  ret->ret_vec_str.clear();
+  API_BEGIN();
+  CHECK_NOTNULL(ndarray_buffer);
+  std::vector<NDArray> data;
+  std::vector<std::string> &names = ret->ret_vec_str;
+  {
+    std::unique_ptr<dmlc::MemoryFixedSizeStream> fi(new dmlc::MemoryFixedSizeStream(
+        const_cast<void*>(ndarray_buffer), size));
+    mxnet::NDArray::Load(fi.get(), &data, &names);
+  }
+  ret->ret_handles.resize(data.size());
+  for (size_t i = 0; i < data.size(); ++i) {
+    NDArray *ptr = new NDArray();
+    *ptr = data[i];
+    ret->ret_handles[i] = ptr;
+  }
+  ret->ret_vec_charp.resize(names.size());
+  for (size_t i = 0; i < names.size(); ++i) {
+    ret->ret_vec_charp[i] = names[i].c_str();
+  }
+  *out_size = static_cast<mx_uint>(data.size());
+  *out_arr = dmlc::BeginPtr(ret->ret_handles);
+  *out_name_size = static_cast<mx_uint>(names.size());
+  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
+  API_END();
+}
+
 int MXNDArrayFree(NDArrayHandle handle) {
   API_BEGIN();
   delete static_cast<NDArray*>(handle);
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index 35c7e3f..635bdcc 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -19,6 +19,7 @@ import sys, os, logging
 import mxnet as mx
 import numpy as np
 import random
+import shutil
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.append(os.path.join(curr_path, '../common/'))
 sys.path.insert(0, os.path.join(curr_path, '../../../python'))
@@ -26,6 +27,7 @@ sys.path.insert(0, os.path.join(curr_path, '../../../python'))
 import models
 from contextlib import contextmanager
 from nose.tools import make_decorator
+import tempfile
 
 def assertRaises(expected_exception, func, *args, **kwargs):
     try:
@@ -225,3 +227,17 @@ def setup_module():
     #  the 'with_seed()' decoration.  Inform the user of this once here at the module level.
     if os.getenv('MXNET_TEST_SEED') is not None:
         logger.warn('*** test-level seed set: all "@with_seed()" tests run deterministically ***')
+
+try:
+    from tempfile import TemporaryDirectory
+except:
+    # really simple implementation of TemporaryDirectory
+    class TemporaryDirectory(object):
+        def __init__(self, suffix='', prefix='', dir=''):
+            self._dirname = tempfile.mkdtemp(suffix, prefix, dir)
+
+        def __enter__(self):
+            return self._dirname
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self._dirname)
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 52697c2..0daf74a 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -21,7 +21,7 @@ import os
 import pickle as pkl
 import unittest
 from nose.tools import raises
-from common import setup_module, with_seed
+from common import setup_module, with_seed, assertRaises, TemporaryDirectory
 from mxnet.test_utils import almost_equal
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import default_context
@@ -292,6 +292,52 @@ def test_ndarray_legacy_load():
 
 
 @with_seed()
+def test_buffer_load():
+    nrepeat = 10
+    with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:
+        for repeat in range(nrepeat):
+            # test load_buffer as list
+            data = []
+            for i in range(10):
+                data.append(random_ndarray(np.random.randint(1, 5)))
+            fname = os.path.join(tmpdir, 'list_{0}.param'.format(repeat))
+            mx.nd.save(fname, data)
+            with open(fname, 'rb') as dfile:
+                buf_data = dfile.read()
+                data2 = mx.nd.load_frombuffer(buf_data)
+                assert len(data) == len(data2)
+                for x, y in zip(data, data2):
+                    assert np.sum(x.asnumpy() != y.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_data[:-10])
+            # test load_buffer as dict
+            dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)}
+            fname = os.path.join(tmpdir, 'dict_{0}.param'.format(repeat))
+            mx.nd.save(fname, dmap)
+            with open(fname, 'rb') as dfile:
+                buf_dmap = dfile.read()
+                dmap2 = mx.nd.load_frombuffer(buf_dmap)
+                assert len(dmap2) == len(dmap)
+                for k, x in dmap.items():
+                    y = dmap2[k]
+                    assert np.sum(x.asnumpy() != y.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_dmap[:-10])
+
+            # we expect the single ndarray to be converted into a list containing the ndarray
+            single_ndarray = data[0]
+            fname = os.path.join(tmpdir, 'single_{0}.param'.format(repeat))
+            mx.nd.save(fname, single_ndarray)
+            with open(fname, 'rb') as dfile:
+                buf_single_ndarray = dfile.read()
+                single_ndarray_loaded = mx.nd.load_frombuffer(buf_single_ndarray)
+                assert len(single_ndarray_loaded) == 1
+                single_ndarray_loaded = single_ndarray_loaded[0]
+                assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0
+                # test garbage values
+                assertRaises(mx.base.MXNetError,  mx.nd.load_frombuffer, buf_single_ndarray[:-10])
+
+@with_seed()
 def test_ndarray_slice():
     shape = (10,)
     A = mx.nd.array(np.random.uniform(-10, 10, shape))

-- 
To stop receiving notification emails like this one, please contact
cjolivie...@apache.org.

Reply via email to