zhreshold commented on a change in pull request #17841:
URL: https://github.com/apache/incubator-mxnet/pull/17841#discussion_r416866728



##########
File path: python/mxnet/gluon/data/dataloader.py
##########
@@ -655,3 +680,119 @@ def __del__(self):
             # https://bugs.python.org/issue34172
             assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
             self._worker_pool.terminate()
+
+def _check_mx_loader_capability(dataset, batch_sampler, batchify_fn):
+    from ._internal import MXDataset, MXSampler
+    from ._internal import MXBatchifyFunction
+    mx_loader_args = {}
+    error_template = "MXNet backend loader compatibility: " \
+        "[dataset - {}][batchify_fn - {}][batch sampler - {}]"
+
+    # supported dataset
+    if isinstance(dataset, MXDataset):
+        mx_loader_args['dataset'] = dataset
+    elif hasattr(dataset, '__mx_handle__'):
+        try:
+            mx_loader_args['dataset'] = dataset.__mx_handle__()
+        except NotImplementedError:
+            return False, error_template.format('fail', 'unknown', 'unknown')
+    else:
+        return False, error_template.format('fail', 'unknown', 'unknown')
+
+    # supported batchify functions
+    if hasattr(batchify_fn, '__mx_handle__'):
+        mx_loader_args['batchify_fn'] = batchify_fn.__mx_handle__()
+    elif isinstance(batchify_fn, MXBatchifyFunction):
+        mx_loader_args['batchify_fn'] = batchify_fn
+    else:
+        return False, error_template.format('pass', 'fail', 'unknown')
+
+    # supported sampler
+    if isinstance(batch_sampler, _sampler.BatchSampler):
+        if isinstance(batch_sampler._sampler, _sampler.SequentialSampler):
+            mx_loader_args['batch_sampler'] = MXSampler(
+                'SequentialSampler', length=batch_sampler._sampler._length,
+                start=batch_sampler._sampler._start,
+                batch_size=batch_sampler._batch_size,
+                last_batch=batch_sampler._last_batch)
+        elif isinstance(batch_sampler._sampler, _sampler.RandomSampler):
+            mx_loader_args['batch_sampler'] = MXSampler(
+                'RandomSampler', length=batch_sampler._sampler._length,
+                batch_size=batch_sampler._batch_size,
+                last_batch=batch_sampler._last_batch)
+        else:
+            return False, error_template.format('pass', 'pass', 'fail')
+    elif isinstance(batch_sampler, MXSampler):
+        mx_loader_args['batch_sampler'] = batch_sampler
+    else:
+        return False, error_template.format('pass', 'pass', 'fail')
+    # all good
+    return True, mx_loader_args
+
+
+class MXThreadedDataLoader(object):

Review comment:
       Thanks, changed to private




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to