cjolivier01 commented on a change in pull request #10090: [MXNET-86] Revert to pre-profile-changes copy code URL: https://github.com/apache/incubator-mxnet/pull/10090#discussion_r174306094
########## File path: tests/python/unittest/test_gluon_data.py ########## @@ -112,6 +117,76 @@ def test_multi_worker(): for i, batch in enumerate(loader): assert (batch.asnumpy() == i).all() +@with_seed() +def test_multi_worker_data_loader(): + class Dummy(Dataset): + def __init__(self, random_shape): + self.random_shape = random_shape + + def __getitem__(self, idx): + key = idx + if self.random_shape: + out = np.random.uniform(size=(random.randint(1000, 1100), 40)) + labels = np.random.uniform(size=(random.randint(10, 15))) + else: + out = np.random.uniform(size=(1000, 40)) + labels = np.random.uniform(size=(10)) + return key, out, labels + + def __len__(self): + return 50 + + def batchify(self, data): + """ + Collate data into batch. Use shared memory for stacking. + + :param data: a list of array, with layout of 'NTC'. + :return either x and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths + if labels are not supplied. + """ + + # input layout is NTC + keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \ + [item[2] for item in data] + + if len(data) > 1: + max_data_len = max([seq.shape[0] for seq in inputs]) + max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels]) + else: + max_data_len = inputs[0].shape[0] + max_labels_len = 0 if not labels else labels[0].shape[0] + + x_lens = [item.shape[0] for item in inputs] + y_lens = [item.shape[0] for item in labels] + + for i, seq in enumerate(inputs): + pad_len = max_data_len - seq.shape[0] + inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0) + labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]), + 'constant', constant_values=-1) + + inputs = np.asarray(inputs, dtype=np.float32) + if labels is not None: + labels = np.asarray(labels, dtype=np.float32) + inputs = inputs.transpose((1, 0, 2)) + labels = labels.transpose((1, 0)) + + return (nd.array(inputs, dtype=inputs.dtype, 
ctx=context.Context('cpu_shared', 0)), Review comment: Why would you think that? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services