This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 52116d4 Fix lazy record io when used with dataloader and multi_worker > 0 (#12554)
52116d4 is described below
commit 52116d4d3205361f861141013eb0de40be003e3c
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Thu Sep 13 15:22:14 2018 -0700
Fix lazy record io when used with dataloader and multi_worker > 0 (#12554)
* temporary solution for record file datasets used with multiple workers
* fix cascaded datasets for the gluon DataLoader when multi_worker > 0 is used
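For context, a minimal sketch of the scenario this commit addresses, mirroring the new unit test below; the record file name 'val.rec' and the worker count are illustrative only, not part of the change:

    from mxnet import gluon

    # identity transform; .transform() lazily wraps the record-backed dataset
    fn = lambda x, y: (x, y)
    dataset = gluon.data.vision.ImageRecordDataset('val.rec').transform(fn)

    # with num_workers > 0 the DataLoader forks worker processes; before this
    # fix, the wrapped dataset kept the parent's record file handle, so reads
    # from the record file inside the workers could fail
    loader = gluon.data.DataLoader(dataset, batch_size=1, num_workers=4)
    for x, y in loader:
        pass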
---
python/mxnet/gluon/data/dataloader.py | 20 ++++++++++++++++++--
python/mxnet/gluon/data/dataset.py | 8 --------
tests/python/unittest/test_gluon_data.py | 12 +++++++++++-
3 files changed, 29 insertions(+), 11 deletions(-)
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 412d313..1c54158 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -36,6 +36,7 @@ except ImportError:
from . import sampler as _sampler
from ... import nd, context
+from ...recordio import MXRecordIO
if sys.platform == 'darwin' or sys.platform == 'win32':
    def rebuild_ndarray(*args):
@@ -158,10 +159,24 @@ def _as_in_context(data, ctx):
        return [_as_in_context(d, ctx) for d in data]
    return data
+def _recursive_fork_recordio(obj, depth, max_depth=1000):
+    """Recursively find instances of MXRecordIO and reset their file handlers.
+    This is required because MXRecordIO holds a C pointer to an opened file,
+    which is no longer valid in the child process after fork.
+    """
+    if depth >= max_depth:
+        return
+    if isinstance(obj, MXRecordIO):
+        obj.close()
+        obj.open()  # re-obtain the file handler in the new process
+    elif hasattr(obj, '__dict__'):
+        for _, v in obj.__dict__.items():
+            _recursive_fork_recordio(v, depth + 1, max_depth)
+
def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
- if hasattr(dataset, '_fork') and callable(dataset._fork):
- dataset._fork()
+ # re-fork a new recordio handler in new process if applicable
+ _recursive_fork_recordio(dataset, 0, 1000)
+
while True:
idx, samples = key_queue.get()
if idx is None:
@@ -181,6 +196,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False):
            batch = _as_in_context(batch, context.cpu())
        data_buffer[idx] = batch
+
class _MultiWorkerIter(object):
    """Internal multi-worker iterator for DataLoader."""
    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
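The recursive walk added above is what makes "cascaded" (wrapped) datasets work: after .transform(), the record handle is no longer an attribute of the top-level dataset object, so the old per-dataset _fork() hook was never reached. A rough sketch of the nesting the helper traverses; the wrapper attribute names (_data, _record) are internal details quoted here for illustration and may differ between versions:

    from mxnet import gluon
    from mxnet.recordio import MXRecordIO

    # 'val.rec' is a placeholder record file
    dataset = gluon.data.vision.ImageRecordDataset('val.rec').transform(lambda x, y: (x, y))

    # dataset               -> lazily transformed wrapper
    # dataset._data         -> ImageRecordDataset
    # dataset._data._record -> MXIndexedRecordIO (a subclass of MXRecordIO)
    assert isinstance(dataset._data._record, MXRecordIO)

    # _recursive_fork_recordio(dataset, 0, 1000) walks __dict__ at every level,
    # finds the buried MXRecordIO and calls close()/open() so each forked
    # worker re-opens its own file handle.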
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 13e2b57..c93a4b1 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -94,11 +94,6 @@ class Dataset(object):
            return fn(x)
        return self.transform(base_fn, lazy)
-    def _fork(self):
-        """Protective operations required when launching multiprocess workers."""
-        # for non file descriptor related datasets, just skip
-        pass
-
class SimpleDataset(Dataset):
"""Simple Dataset wrapper for lists and arrays.
@@ -180,9 +175,6 @@ class RecordFileDataset(Dataset):
    def __init__(self, filename):
        self.idx_file = os.path.splitext(filename)[0] + '.idx'
        self.filename = filename
-        self._fork()
-
-    def _fork(self):
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')
    def __getitem__(self, idx):
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 53ce600..cc80aac 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -65,7 +65,8 @@ def prepare_record():
@with_seed()
def test_recordimage_dataset():
    recfile = prepare_record()
-    dataset = gluon.data.vision.ImageRecordDataset(recfile)
+    fn = lambda x, y: (x, y)
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
    loader = gluon.data.DataLoader(dataset, 1)
    for i, (x, y) in enumerate(loader):
@@ -84,6 +85,15 @@ def test_recordimage_dataset_with_data_loader_multiworker():
        assert x.shape[0] == 1 and x.shape[3] == 3
        assert y.asscalar() == i
+    # with transform
+    fn = lambda x, y: (x, y)
+    dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(fn)
+    loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+    for i, (x, y) in enumerate(loader):
+        assert x.shape[0] == 1 and x.shape[3] == 3
+        assert y.asscalar() == i
+
@with_seed()
def test_sampler():
    seq_sampler = gluon.data.SequentialSampler(10)