This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 52116d4  Fix lazy record io when used with dataloader and multi_worker > 0 (#12554)
52116d4 is described below

commit 52116d4d3205361f861141013eb0de40be003e3c
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Thu Sep 13 15:22:14 2018 -0700

    Fix lazy record io when used with dataloader and multi_worker > 0 (#12554)
    
    * temp solution to record file dataset with multi worker
    
    * fix cascaded dataset for gluon dataloader, when multi_worker > 0 is used
---
 python/mxnet/gluon/data/dataloader.py    | 20 ++++++++++++++++++--
 python/mxnet/gluon/data/dataset.py       |  8 --------
 tests/python/unittest/test_gluon_data.py | 12 +++++++++++-
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 412d313..1c54158 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -36,6 +36,7 @@ except ImportError:
 
 from . import sampler as _sampler
 from ... import nd, context
+from ...recordio import MXRecordIO
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -158,10 +159,24 @@ def _as_in_context(data, ctx):
         return [_as_in_context(d, ctx) for d in data]
     return data
 
+def _recursive_fork_recordio(obj, depth, max_depth=1000):
+    """Recursively find instance of MXRecordIO and reset file handler.
+    This is required for MXRecordIO which holds a C pointer to an opened file after fork.
+    """
+    if depth >= max_depth:
+        return
+    if isinstance(obj, MXRecordIO):
+        obj.close()
+        obj.open()  # re-obtain file handler in new process
+    elif (hasattr(obj, '__dict__')):
+        for _, v in obj.__dict__.items():
+            _recursive_fork_recordio(v, depth + 1, max_depth)
+
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
-    if hasattr(dataset, '_fork') and callable(dataset._fork):
-        dataset._fork()
+    # re-fork a new recordio handler in new process if applicable
+    _recursive_fork_recordio(dataset, 0, 1000)
+
     while True:
         idx, samples = key_queue.get()
         if idx is None:
@@ -181,6 +196,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False):
             batch = _as_in_context(batch, context.cpu())
         data_buffer[idx] = batch
 
+
 class _MultiWorkerIter(object):
     """Interal multi-worker iterator for DataLoader."""
     def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 13e2b57..c93a4b1 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -94,11 +94,6 @@ class Dataset(object):
             return fn(x)
         return self.transform(base_fn, lazy)
 
-    def _fork(self):
-        """Protective operations required when launching multiprocess workers."""
-        # for non file descriptor related datasets, just skip
-        pass
-
 
 class SimpleDataset(Dataset):
     """Simple Dataset wrapper for lists and arrays.
@@ -180,9 +175,6 @@ class RecordFileDataset(Dataset):
     def __init__(self, filename):
         self.idx_file = os.path.splitext(filename)[0] + '.idx'
         self.filename = filename
-        self._fork()
-
-    def _fork(self):
         self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')
 
     def __getitem__(self, idx):
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 53ce600..cc80aac 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -65,7 +65,8 @@ def prepare_record():
 @with_seed()
 def test_recordimage_dataset():
     recfile = prepare_record()
-    dataset = gluon.data.vision.ImageRecordDataset(recfile)
+    fn = lambda x, y : (x, y)
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
     loader = gluon.data.DataLoader(dataset, 1)
 
     for i, (x, y) in enumerate(loader):
@@ -84,6 +85,15 @@ def test_recordimage_dataset_with_data_loader_multiworker():
             assert x.shape[0] == 1 and x.shape[3] == 3
             assert y.asscalar() == i
 
+        # with transform
+        fn = lambda x, y : (x, y)
+        dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
+        loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+        for i, (x, y) in enumerate(loader):
+            assert x.shape[0] == 1 and x.shape[3] == 3
+            assert y.asscalar() == i
+
 @with_seed()
 def test_sampler():
     seq_sampler = gluon.data.SequentialSampler(10)

Reply via email to