marcoabreu commented on a change in pull request #10090: [MXNET-86] Revert to pre-profile-changes copy code
URL: https://github.com/apache/incubator-mxnet/pull/10090#discussion_r174304377
##########
File path: tests/python/unittest/test_gluon_data.py
##########
@@ -112,6 +117,76 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()
+@with_seed()
+def test_multi_worker_data_loader():
+    class Dummy(Dataset):
+        def __init__(self, random_shape):
+            self.random_shape = random_shape
+
+        def __getitem__(self, idx):
+            key = idx
+            if self.random_shape:
+                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+                labels = np.random.uniform(size=(random.randint(10, 15)))
+            else:
+                out = np.random.uniform(size=(1000, 40))
+                labels = np.random.uniform(size=(10))
+            return key, out, labels
+
+        def __len__(self):
+            return 50
+
+        def batchify(self, data):
+            """
+            Collate data into a batch. Use shared memory for stacking.
+
+            :param data: a list of arrays, with layout 'NTC'.
+            :return: x and x's unpadded lengths if labels are not supplied;
+                     otherwise x, x's unpadded lengths, y and y's unpadded lengths.
+            """
+
+            # input layout is NTC
+            keys, inputs, labels = [item[0] for item in data], \
+                                   [item[1] for item in data], \
+                                   [item[2] for item in data]
+
+            if len(data) > 1:
+                max_data_len = max([seq.shape[0] for seq in inputs])
+                max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+            else:
+                max_data_len = inputs[0].shape[0]
+                max_labels_len = 0 if not labels else labels[0].shape[0]
+
+            x_lens = [item.shape[0] for item in inputs]
+            y_lens = [item.shape[0] for item in labels]
+
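+            # pad inputs with zeros and labels with -1,
+            # up to the longest sequence in the batch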
+            for i, seq in enumerate(inputs):
+                pad_len = max_data_len - seq.shape[0]
+                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)),
+                                   'constant', constant_values=0)
+                labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                                   'constant', constant_values=-1)
+
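+            # stack into dense float32 arrays and transpose from NTC to TNC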
+            inputs = np.asarray(inputs, dtype=np.float32)
+            if labels is not None:
+                labels = np.asarray(labels, dtype=np.float32)
+            inputs = inputs.transpose((1, 0, 2))
+            labels = labels.transpose((1, 0))
+
+            return (nd.array(inputs, dtype=inputs.dtype,
+                             ctx=context.Context('cpu_shared', 0)),
Review comment:
I guess the context is not really important here, right?
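Background on the question: in Gluon's multi-worker DataLoader, the `cpu_shared` context places the collated batch in shared memory, which appears to be why the test pins the batch to that context; it lets worker processes hand batches to the main process without serializing the underlying buffer. A minimal sketch of the effect (shapes and variable names are illustrative, not from the PR), assuming an MXNet build with the `cpu_shared` device type the quoted code already uses:

    import numpy as np
    import mxnet as mx

    # Allocate a collated batch in the 'cpu_shared' context: the NDArray
    # is backed by shared memory, so a DataLoader worker process can pass
    # it back to the main process without copying the data through pickle.
    batch = mx.nd.array(np.zeros((4, 1000, 40), dtype=np.float32),
                        ctx=mx.Context('cpu_shared', 0))
    print(batch.context)                 # cpu_shared(0)
    print(batch.copyto(mx.cpu()).shape)  # (4, 1000, 40)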
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services