This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 78fc40f Fix serialization bug for writing large arrays to npz (#19596)
78fc40f is described below
commit 78fc40f68e062258f6897987e8b93686f47af85a
Author: Leonard Lausen <[email protected]>
AuthorDate: Mon Nov 30 16:10:35 2020 -0700
Fix serialization bug for writing large arrays to npz (#19596)
---
src/serialization/cnpy.cc | 3 ++-
tests/python/unittest/test_numpy_ndarray.py | 10 ++++++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/src/serialization/cnpy.cc b/src/serialization/cnpy.cc
index 67abb99..f9532f8 100644
--- a/src/serialization/cnpy.cc
+++ b/src/serialization/cnpy.cc
@@ -315,7 +315,8 @@ size_t npy_header_blob_read_callback(void *pOpaque, mz_uint64 file_ofs, void *pB
std::memcpy(pBuf_blob, blob->dptr_, n - npy_header_n);
} else {
// Read n bytes from blob
- const void* pSrc = static_cast<const void*>(static_cast<char*>(blob->dptr_) + file_ofs);
+ const void* pSrc = static_cast<const void*>(
+ static_cast<char*>(blob->dptr_) + file_ofs - npy_header->size());
std::memcpy(pBuf, pSrc, n);
}
return n;
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 1775397..fa189f8 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -1000,6 +1000,16 @@ def test_np_ndarray_indexing():
@use_np
+@pytest.mark.parametrize('load_fn', [_np.load, npx.load])
+def test_np_save_load_large_ndarrays(load_fn, tmp_path):
+ weight = mx.np.arange(32768 * 512).reshape((32768, 512))
+ mx.npx.savez(str(tmp_path / 'params.npz'), weight=weight)
+ arr_loaded = load_fn(str(tmp_path / 'params.npz'))['weight']
+ assert _np.array_equal(arr_loaded.asnumpy() if load_fn is npx.load
+ else arr_loaded, weight)
+
+
+@use_np
@pytest.mark.serial
@pytest.mark.parametrize('load_fn', [_np.load, npx.load])
def test_np_save_load_ndarrays(load_fn):