zhreshold closed pull request #13209: Allow dataloader iterator to be reused
URL: https://github.com/apache/incubator-mxnet/pull/13209
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/data/dataloader.py 
b/python/mxnet/gluon/data/dataloader.py
index 86cb835f512..a23ba580dd3 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -29,6 +29,11 @@
 import threading
 import numpy as np
 
+try:
+    from Queue import Empty as QueueEmpty
+except ImportError:
+    from queue import Empty as QueueEmpty
+
 try:
     import multiprocessing.resource_sharer
 except ImportError:
@@ -36,7 +41,6 @@
 
 from . import sampler as _sampler
 from ... import nd, context
-from ...recordio import MXRecordIO
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -159,29 +163,9 @@ def _as_in_context(data, ctx):
         return [_as_in_context(d, ctx) for d in data]
     return data
 
-def _recursive_fork_recordio(obj, depth, max_depth=1000):
-    """Recursively find instance of MXRecordIO and reset file handler.
-    This is required for MXRecordIO which holds a C pointer to a opened file 
after fork.
-    """
-    if depth >= max_depth:
-        return
-    if isinstance(obj, MXRecordIO):
-        obj.close()
-        obj.open()  # re-obtain file hanlder in new process
-    elif (hasattr(obj, '__dict__')):
-        for _, v in obj.__dict__.items():
-            _recursive_fork_recordio(v, depth + 1, max_depth)
 
 def worker_loop(dataset, key_queue, data_queue, batchify_fn):
     """Worker loop for multiprocessing DataLoader."""
-    # re-fork a new recordio handler in new process if applicable
-    # for a dataset with transform function, the depth of MXRecordIO is 1
-    # for a lazy transformer, the depth is 2
-    # for a user defined transformer, the depth is unknown, try a reasonable 
depth
-    limit = sys.getrecursionlimit()
-    max_recursion_depth = min(limit - 5, max(10, limit // 2))
-    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
-
     while True:
         idx, samples = key_queue.get()
         if idx is None:
@@ -207,7 +191,36 @@ def fetcher_loop(data_queue, data_buffer, 
pin_memory=False, data_buffer_lock=Non
 
 
 class _MultiWorkerIter(object):
-    """Interal multi-worker iterator for DataLoader."""
+    """Internal multi-worker iterator for DataLoader.
+    Re-acquiring this iterator via the `iter()` function will reset it if the previous
iteration has finished.
+    All workers are still alive in order to save re-initialization overhead.
+
+    Parameters
+    ----------
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+    worker_fn : callable
+        `worker_fn` is the multiprocess worker function to process data in 
worker processes.
+        It defaults to `worker_loop(dataset, key_queue, data_queue, 
batchify_fn)`.
+        `worker_fn` takes inputs of `dataset` for input data, `key_queue` for 
(idx, batch_sample)
+        from batch sampler, `data_queue` for storing processed batch data as 
`NDArray`, and
+        `batchify_fn` for explicit batching instructions.
+        It is not recommended to customize `worker_fn` unless you have
specific use cases.
+
+    """
     def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, 
pin_memory=False,
                  worker_fn=worker_loop):
         assert num_workers > 0, "_MultiWorkerIter is not for {} 
workers".format(num_workers)
@@ -225,6 +238,7 @@ def __init__(self, num_workers, dataset, batchify_fn, 
batch_sampler, pin_memory=
         self._sent_idx = 0
         self._iter = iter(self._batch_sampler)
         self._shutdown = False
+        self._stop = False
 
         workers = []
         for _ in range(self._num_workers):
@@ -242,18 +256,55 @@ def __init__(self, num_workers, dataset, batchify_fn, 
batch_sampler, pin_memory=
         self._fetcher.daemon = True
         self._fetcher.start()
 
-        # pre-fetch
-        for _ in range(2 * self._num_workers):
-            self._push_next()
+        self._reset()
 
     def __len__(self):
+        """Get length of iterator.
+
+        Returns
+        -------
+        int
+            Length of iterator, equals to batch_sampler length.
+
+        """
         return len(self._batch_sampler)
 
     def __del__(self):
         self.shutdown()
 
+    def _reset(self):
+        """Reset the iterator while keeping the multiprocessing workers alive. Internal use. """
+        assert not self._shutdown, "call reset after shutdown is forbidden"
+        # clear key queue
+        removed_idx = set()
+        while True:
+            try:
+                idx, _ = self._key_queue.get(False)
+                removed_idx.add(idx)
+            except QueueEmpty:
+                break
+
+        # clear data queue
+        while self._rcvd_idx < self._sent_idx:
+            if self._rcvd_idx in removed_idx:
+                self._rcvd_idx += 1
+            elif self._rcvd_idx in self._data_buffer:
+                _ = self._data_buffer.pop(self._rcvd_idx)
+                self._rcvd_idx += 1
+        assert not self._data_buffer, "data buffer should be empty"
+
+        # reset indices and samples
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._stop = False
+
+        # pre-fetch
+        for _ in range(2 * self._num_workers):
+            self._push_next()
+
     def _push_next(self):
-        """Assign next batch workload to workers."""
+        """Assign next batch workload to workers. Internal use only. """
         r = next(self._iter, None)
         if r is None:
             return
@@ -261,10 +312,18 @@ def _push_next(self):
         self._sent_idx += 1
 
     def __next__(self):
+        """Return the next sample; raises `StopIteration` upon reaching the end.
+
+        Returns
+        -------
+        NDArray
+            Batched sample data.
+
+        """
         assert not self._shutdown, "call __next__ after shutdown is forbidden"
         if self._rcvd_idx == self._sent_idx:
             assert not self._data_buffer, "Data buffer should be empty at this 
moment"
-            self.shutdown()
+            self._stop = True
             raise StopIteration
 
         while True:
@@ -276,13 +335,37 @@ def __next__(self):
                 return batch
 
     def next(self):
+        """Compatibility method for `__next__` in Python 2.
+
+        Returns
+        -------
+        NDArray
+            Batched sample data.
+
+        """
         return self.__next__()
 
     def __iter__(self):
+        """Acquiring the iterator will reset the current instance while keeping all
+        workers alive, thus saving the re-initialization cost of forking processes.
+
+        Returns
+        -------
+        iterator
+            Iterator of self.
+
+        """
+        assert not self._shutdown, "get iterator after shutdown is forbidden"
+        if self._stop:
+            self._reset()
         return self
 
     def shutdown(self):
-        """Shutdown internal workers by pushing terminate signals."""
+        """
+        Shutdown internal workers by pushing terminate signals. Once shutdown,
+        you cannot use this instance again, you will need to obtain a new
+        _MultiWorkerIter by `iter(dataloader)`.
+        """
         if not self._shutdown:
             # send shutdown signal to the fetcher and join data queue first
             # Remark:   loop_fetcher need to be joined prior to the workers.
@@ -299,6 +382,96 @@ def shutdown(self):
             self._shutdown = True
 
 
+class _SameProcessIter(object):
+    """Same Process Iterator.
+    Re-acquiring this iterator via the `iter()` function will reset it if the previous
iteration has finished.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+
+    """
+    def __init__(self, dataset, batchify_fn, batch_sampler, pin_memory=False):
+        self._dataset = dataset
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._pin_memory = pin_memory
+        self._stop = False
+        self._reset()
+
+    def __len__(self):
+        """Get length of iterator.
+
+        Returns
+        -------
+        int
+            Length of iterator, equals to batch_sampler length.
+
+        """
+        return len(self._batch_sampler)
+
+    def _reset(self):
+        """Reset iterator."""
+        self._iter = iter(self._batch_sampler)
+        self._stop = False
+
+    def __next__(self):
+        """Return the next sample; raises `StopIteration` upon reaching the end.
+
+        Returns
+        -------
+        NDArray
+            Batched sample data.
+
+        """
+        try:
+            batch = next(self._iter)
+        except StopIteration:
+            self._stop = True
+            raise StopIteration
+        else:
+            ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+            if self._pin_memory:
+                ret = _as_in_context(ret, context.cpu_pinned())
+            return ret
+
+    def next(self):
+        """Compatibility method for `__next__` in Python 2.
+
+        Returns
+        -------
+        NDArray
+            Batched sample data.
+
+        """
+        return self.__next__()
+
+    def __iter__(self):
+        """Acquiring the iterator will reset the current instance.
+
+        Returns
+        -------
+        iterator
+            Iterator of self.
+
+        """
+        if self._stop:
+            self._reset()
+        return self
+
+
 class DataLoader(object):
     """Loads data from a dataset and returns mini-batches of data.
 
@@ -381,13 +554,8 @@ def __init__(self, dataset, batch_size=None, 
shuffle=False, sampler=None,
 
     def __iter__(self):
         if self._num_workers == 0:
-            def same_process_iter():
-                for batch in self._batch_sampler:
-                    ret = self._batchify_fn([self._dataset[idx] for idx in 
batch])
-                    if self._pin_memory:
-                        ret = _as_in_context(ret, context.cpu_pinned())
-                    yield ret
-            return same_process_iter()
+            return _SameProcessIter(self._dataset, self._batchify_fn,
+                                    self._batch_sampler, self._pin_memory)
 
         # multi-worker
         return _MultiWorkerIter(self._num_workers, self._dataset,
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py
index 2def141c934..bdc63235d70 100644
--- a/python/mxnet/recordio.py
+++ b/python/mxnet/recordio.py
@@ -18,6 +18,7 @@
 """Read and write for the RecordIO data format."""
 from __future__ import absolute_import
 from collections import namedtuple
+from multiprocessing import current_process
 
 import ctypes
 import struct
@@ -65,6 +66,7 @@ def __init__(self, uri, flag):
         self.uri = c_str(uri)
         self.handle = RecordIOHandle()
         self.flag = flag
+        self.pid = None
         self.is_open = False
         self.open()
 
@@ -78,6 +80,7 @@ def open(self):
             self.writable = False
         else:
             raise ValueError("Invalid flag %s"%self.flag)
+        self.pid = current_process().pid
         self.is_open = True
 
     def __del__(self):
@@ -109,6 +112,14 @@ def __setstate__(self, d):
         if is_open:
             self.open()
 
+    def _check_pid(self, allow_reset=False):
+        """Check process id to ensure integrity, reset if in new process."""
+        if not self.pid == current_process().pid:
+            if allow_reset:
+                self.reset()
+            else:
+                raise RuntimeError("Forbidden operation in multiple processes")
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -118,6 +129,7 @@ def close(self):
         else:
             check_call(_LIB.MXRecordIOReaderFree(self.handle))
         self.is_open = False
+        self.pid = None
 
     def reset(self):
         """Resets the pointer to first item.
@@ -156,6 +168,7 @@ def write(self, buf):
             Buffer to write.
         """
         assert self.writable
+        self._check_pid(allow_reset=False)
         check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                     ctypes.c_char_p(buf),
                                                     ctypes.c_size_t(len(buf))))
@@ -182,6 +195,9 @@ def read(self):
             Buffer read.
         """
         assert not self.writable
+        # trying to implicitly read from multiple processes is forbidden,
+        # there's no elegant way to handle unless lock is introduced
+        self._check_pid(allow_reset=False)
         buf = ctypes.c_char_p()
         size = ctypes.c_size_t()
         check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
@@ -255,6 +271,7 @@ def seek(self, idx):
         This function is internally called by `read_idx(idx)` to find the 
current
         reader pointer position. It doesn't return anything."""
         assert not self.writable
+        self._check_pid(allow_reset=True)
         pos = ctypes.c_size_t(self.idx[idx])
         check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))
 
diff --git a/tests/python/unittest/test_gluon_data.py 
b/tests/python/unittest/test_gluon_data.py
index e4206095f9b..77f733669e9 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -244,6 +244,28 @@ def test_multi_worker_forked_data_loader():
         for i, data in enumerate(loader):
             pass
 
+
+class _SequentialDummyData(object):
+    def __len__(self):
+        return 100
+
+    def __getitem__(self, idx):
+        return idx
+
+@with_seed()
+def test_cached_iterator_in_dataloader():
+    data = _SequentialDummyData()
+    length = len(data)
+    expect = np.arange(length)
+    for num_worker in range(0, 4):
+        loader = DataLoader(data, batch_size=2, shuffle=False, 
num_workers=num_worker)
+        it = iter(loader)
+        out = []
+        for i, batch in enumerate(it):
+            print(i, batch)
+            out.append(batch.asnumpy().flatten())
+        np.testing.assert_allclose(np.concatenate(out), expect)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to