leezu commented on a change in pull request #19748:
URL: https://github.com/apache/incubator-mxnet/pull/19748#discussion_r558845282
##########
File path: tests/python/unittest/test_gluon_data.py
##########
@@ -158,10 +158,11 @@ def __getitem__(self, key):
def test_multi_worker():
data = Dataset()
for thread_pool in [True, False]:
- loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5,
thread_pool=thread_pool)
- for i, batch in enumerate(loader):
- assert (batch.asnumpy() == i).all()
-
+ for auto_reload in [True, False]:
Review comment:
Use pytest.mark.parametrize instead of a for loop
##########
File path: python/mxnet/gluon/data/dataloader.py
##########
@@ -655,3 +655,149 @@ def __del__(self):
# https://bugs.python.org/issue34172
assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
self._worker_pool.terminate()
+
+class PrefetchedDataLoader(DataLoader):
+ """Prefetch data from a dataset and returns mini-batches of data.
+ The prefetch is performed immediately after it yields the last item from its
iterator.
+ When multiple iterators are generated from PrefetchedDataLoader, only the first
iter
+ is the prefetched iter, and PrefetchedDataLoader will not prefetch any data
+ if its prefetched iter is not being used.
+
+ Example:
+ >>> from mxnet.gluon.data import PrefetchedDataLoader, ArrayDataset
+ >>> train_data = ArrayDataset([i for i in range(10)],[9-i for i in
range(10)])
+ >>> def transform_train(sample):
+ ... if sample == 0 : print('(pre)fetching data here')
+ ... return sample
+ ...
+ >>> train_iter =
PrefetchedDataLoader(train_data.transform_first(transform_train),
+ ... batch_size=1,num_workers=1)
+ (pre)fetching data here
+ >>> it = iter(train_iter) # nothing is generated since lazy-evaluation
occurs
+ >>> it2 = iter(train_iter)
+ >>> it3 = iter(train_iter)
+ >>> it4 = iter(train_iter)
+ >>> _ = next(it2) # the first iter we are using is the prefetched iter.
+ >>> _ = next(it) # since the prefetched iter is consumed, we have to
fetch data for `it`.
+ (pre)fetching data here
+ >>> _ = [None for _ in it3]
+ (pre)fetching data here
+ (pre)fetching data here
+ >>> # Here, 2 prefetches are triggered, one is fetching the first batch of
`it3` and
+ >>> # another is when `it3` yields its last item, a prefetch is
automatically performed.
+ >>> _ = [None for _ in it]
+ >>> # no prefetch happens since train_iter has already prefetched data.
+ >>> _ = next(it4)
+ >>> # since the prefetch is performed, it4 becomes the prefetched iter.
+ >>>
+ >>> test_data = ArrayDataset([i for i in range(10)],[9-i for i in
range(10)])
+ >>> test_iter = PrefetchedDataLoader(test_data,
+ ... batch_size=1,num_workers=1)
+ >>> for epoch in range(200):
+ ... # there is almost no difference between it and the default DataLoader
+ ... for data, label in train_iter:
+ ... # training...
+ ... for data, label in test_iter:
+ ... # testing...
+ ...
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Source dataset. Note that numpy and mxnet arrays can be directly used
+ as a Dataset.
+ batch_size : int
+ Size of mini-batch.
+ shuffle : bool
+ Whether to shuffle the samples.
+ sampler : Sampler
+ The sampler to use. Either specify sampler or shuffle, not both.
+ last_batch : {'keep', 'discard', 'rollover'}
+ How to handle the last batch if batch_size does not evenly divide
+ `len(dataset)`.
+
+ keep - A batch with less samples than previous batches is returned.
+ discard - The last batch is discarded if it's incomplete.
+ rollover - The remaining samples are rolled over to the next epoch.
+ batch_sampler : Sampler
+ A sampler that returns mini-batches. Do not specify batch_size,
+ shuffle, sampler, and last_batch if batch_sampler is specified.
+ batchify_fn : callable
+ Callback function to allow users to specify how to merge samples
+ into a batch. Defaults to `default_batchify_fn`::
+
+ def default_batchify_fn(data):
+ if isinstance(data[0], nd.NDArray):
+ return nd.stack(*data)
+ elif isinstance(data[0], tuple):
+ data = zip(*data)
+ return [default_batchify_fn(i) for i in data]
+ else:
+ data = np.asarray(data)
+ return nd.array(data, dtype=data.dtype)
+
+ num_workers : int, default 0
+ The number of multiprocessing workers to use for data preprocessing.
+ pin_memory : boolean, default False
+ If ``True``, the dataloader will copy NDArrays into pinned memory
+ before returning them. Copying from CPU pinned memory to GPU is faster
+ than from normal CPU memory.
+ pin_device_id : int, default 0
+ The device id to use for allocating pinned memory if pin_memory is
``True``
+ prefetch : int, default is `num_workers * 2`
+ The number of batches to prefetch; only works if `num_workers` > 0.
+ If `prefetch` > 0, it allows worker processes to prefetch certain batches
before
+ acquiring data from iterators.
+ Note that using large prefetching batch will provide smoother
bootstrapping performance,
+ but will consume more shared_memory. Using a smaller number may forfeit
the purpose of using
+ multiple worker processes; try reducing `num_workers` in this case.
+ Defaults to `num_workers * 2`.
+ thread_pool : bool, default False
+ If ``True``, use threading pool instead of multiprocessing pool. Using
threadpool
+ can avoid shared memory usage. If `DataLoader` is more IO bound or
the GIL is not a serious
+ problem, the threadpool version may achieve better performance than
multiprocessing.
+ timeout : int, default is 120
+ The timeout in seconds for each worker to fetch a batch of data. Only
modify this number
+ if you are experiencing timeouts and you know they are due to slow data
loading.
+ Sometimes a full `shared_memory` will cause all workers to hang and
cause a timeout. In these
+ cases please reduce `num_workers` or increase system `shared_memory`
size instead.
+ """
+ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+ last_batch=None, batch_sampler=None, batchify_fn=None,
+ num_workers=0, pin_memory=False, pin_device_id=0,
+ prefetch=None, thread_pool=False, timeout=120):
+ super(PrefetchedDataLoader, self).\
+ __init__(dataset, batch_size, shuffle, sampler,
+ last_batch, batch_sampler, batchify_fn,
+ num_workers, pin_memory, pin_device_id,
+ prefetch, thread_pool, timeout)
+ self.refresh()
+
+ def __iter__(self):
+ if self._iter is None:
+ self.refresh()
+ t = self._iter
+ self._iter = None # ensure a single iter would not using twice.
+ for item in t:
+ yield item
+ if self._iter is None: # ensure we do not waste any exist iter by
mistake
+ self.refresh()
+
+ def refresh(self):
Review comment:
Why is changing the behavior of MXNet 1.x to match 2.x not a breaking
change for 1.x?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]