This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 318c689  fix shared_storage free (#11159)
318c689 is described below

commit 318c6899031d024acb28a4937678b85f141f98d8
Author: Joshua Z. Zhang <cheungc...@gmail.com>
AuthorDate: Tue Jun 5 22:57:51 2018 -0700

    fix shared_storage free (#11159)
    
    * fix shared_storage free
    
    * fix bracket
    
    * make local ref
    
    * cpplint
    
    * fix tests
    
    * fix tests
---
 python/mxnet/gluon/data/dataloader.py    |  2 ++
 src/storage/cpu_shared_storage_manager.h | 10 ++++++++--
 tests/python/unittest/test_gluon_data.py | 18 ++++++++++++++++++
 tests/python/unittest/test_ndarray.py    |  1 -
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 151b49d..29b9b81 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -57,6 +57,8 @@ else:
 
     def reduce_ndarray(data):
         """Reduce ndarray to shared memory handle"""
+        # keep a local ref before duplicating fd
+        data = data.as_in_context(context.Context('cpu_shared', 0))
         pid, fd, shape, dtype = data._to_shared_mem()
         if sys.version_info[0] == 2:
             fd = multiprocessing.reduction.reduce_handle(fd)
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 85c6a35..a52d779 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -174,8 +174,12 @@ void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
   }
 
   if (fid == -1) {
-    LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
-               << strerror(errno);
+    if (is_new) {
+      LOG(FATAL) << "Failed to open shared memory. shm_open failed with error "
+                 << strerror(errno);
+    } else {
+      LOG(FATAL) << "Invalid file descriptor from shared array.";
+    }
   }
 
   if (is_new) CHECK_EQ(ftruncate(fid, size), 0);
@@ -216,9 +220,11 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
       << strerror(errno);
 
 #ifdef __linux__
+  if (handle.shared_id != -1) {
   CHECK_EQ(close(handle.shared_id), 0)
       << "Failed to close shared memory. close failed with error "
       << strerror(errno);
+  }
 #else
   if (count == 0) {
     auto filename = SharedHandleToString(handle.shared_pid, handle.shared_id);
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 93160aa..751886b 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -140,6 +140,16 @@ def test_multi_worker_forked_data_loader():
         def __len__(self):
             return 50
 
+        def batchify_list(self, data):
+            """
+            return list of ndarray without stack/concat/pad
+            """
+            if isinstance(data, (tuple, list)):
+                return list(data)
+            if isinstance(data, mx.nd.NDArray):
+                return [data]
+            return data
+
         def batchify(self, data):
             """
             Collate data into batch. Use shared memory for stacking.
@@ -194,6 +204,14 @@ def test_multi_worker_forked_data_loader():
                     print(data)
                     print('{}:{}'.format(epoch, i))
 
+        data = Dummy(True)
+        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
+        for epoch in range(1):
+            for i, data in enumerate(loader):
+                if i % 100 == 0:
+                    print(data)
+                    print('{}:{}'.format(epoch, i))
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 496f80f..a060465 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1304,7 +1304,6 @@ def test_norm(ctx=default_context()):
             assert arr1.shape == arr2.shape
             mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
 
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to