ThomasDelteil closed pull request #10628: [MXNET-342] Fix the multi worker
Dataloader
URL: https://github.com/apache/incubator-mxnet/pull/10628
This pull request was merged from a forked repository.
Because GitHub hides the original diff of a foreign (fork-based) pull
request once it has been merged, the diff is reproduced below for
the sake of provenance:
diff --git a/python/mxnet/gluon/data/dataloader.py
b/python/mxnet/gluon/data/dataloader.py
index 7f09e286742..5601b1b68ea 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -30,6 +30,7 @@
from . import sampler as _sampler
from ... import nd, context
+from . import RecordFileDataset
def rebuild_ndarray(*args):
@@ -112,6 +113,9 @@ def default_mp_batchify_fn(data):
def worker_loop(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
+ if isinstance(dataset, RecordFileDataset):
+ dataset.reload_recordfile()
+
while True:
idx, samples = key_queue.get()
if idx is None:
diff --git a/python/mxnet/gluon/data/dataset.py
b/python/mxnet/gluon/data/dataset.py
index bf5fa0a6d1e..6cf874fc4a0 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -173,8 +173,15 @@ class RecordFileDataset(Dataset):
Path to rec file.
"""
def __init__(self, filename):
- idx_file = os.path.splitext(filename)[0] + '.idx'
- self._record = recordio.MXIndexedRecordIO(idx_file, filename, 'r')
+ self._filename = filename
+ self.reload_recordfile()
+
+ def reload_recordfile(self):
+ """
+ Reload the record file to open a new file descriptor
+ """
+ idx_file = os.path.splitext(self._filename)[0] + '.idx'
+ self._record = recordio.MXIndexedRecordIO(idx_file, self._filename,
'r')
def __getitem__(self, idx):
return self._record.read_idx(self._record.keys[idx])
diff --git a/tests/python/unittest/test_gluon_data.py
b/tests/python/unittest/test_gluon_data.py
index 93160aa0940..faa6197257d 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -72,6 +72,18 @@ def test_recordimage_dataset():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i
+@with_seed()
+def test_recordimage_dataset_with_data_loader_multiworker():
+ # This test is pointless on Windows because Windows doesn't fork
+ if platform.system() != 'Windows':
+ recfile = prepare_record()
+ dataset = gluon.data.vision.ImageRecordDataset(recfile)
+ loader = gluon.data.DataLoader(dataset, 1, num_workers=5)
+
+ for i, (x, y) in enumerate(loader):
+ assert x.shape[0] == 1 and x.shape[3] == 3
+ assert y.asscalar() == i
+
@with_seed()
def test_sampler():
seq_sampler = gluon.data.SequentialSampler(10)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services