This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 78fc40f  Fix serialization bug for writing large arrays to npz (#19596)
78fc40f is described below

commit 78fc40f68e062258f6897987e8b93686f47af85a
Author: Leonard Lausen <[email protected]>
AuthorDate: Mon Nov 30 16:10:35 2020 -0700

    Fix serialization bug for writing large arrays to npz (#19596)
---
 src/serialization/cnpy.cc                   |  3 ++-
 tests/python/unittest/test_numpy_ndarray.py | 10 ++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/serialization/cnpy.cc b/src/serialization/cnpy.cc
index 67abb99..f9532f8 100644
--- a/src/serialization/cnpy.cc
+++ b/src/serialization/cnpy.cc
@@ -315,7 +315,8 @@ size_t npy_header_blob_read_callback(void *pOpaque, mz_uint64 file_ofs, void *pB
         std::memcpy(pBuf_blob, blob->dptr_, n - npy_header_n);
     } else {
         // Read n bytes from blob
-        const void* pSrc = static_cast<const void*>(static_cast<char*>(blob->dptr_) + file_ofs);
+        const void* pSrc = static_cast<const void*>(
+            static_cast<char*>(blob->dptr_) + file_ofs - npy_header->size());
         std::memcpy(pBuf, pSrc, n);
     }
     return n;
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 1775397..fa189f8 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -1000,6 +1000,16 @@ def test_np_ndarray_indexing():
 
 
 @use_np
+@pytest.mark.parametrize('load_fn', [_np.load, npx.load])
+def test_np_save_load_large_ndarrays(load_fn, tmp_path):
+    weight = mx.np.arange(32768 * 512).reshape((32768, 512))
+    mx.npx.savez(str(tmp_path / 'params.npz'), weight=weight)
+    arr_loaded = load_fn(str(tmp_path / 'params.npz'))['weight']
+    assert _np.array_equal(arr_loaded.asnumpy() if load_fn is npx.load
+                           else arr_loaded, weight)
+
+
+@use_np
 @pytest.mark.serial
 @pytest.mark.parametrize('load_fn', [_np.load, npx.load])
 def test_np_save_load_ndarrays(load_fn):

Reply via email to