marcoabreu commented on a change in pull request #10090: [MXNET-86] Revert to 
pre-profile-changes copy code
URL: https://github.com/apache/incubator-mxnet/pull/10090#discussion_r174321395
 
 

 ##########
 File path: tests/python/unittest/test_gluon_data.py
 ##########
 @@ -112,6 +117,76 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()
 
+@with_seed()
+def test_multi_worker_data_loader():
+    class Dummy(Dataset):
+        def __init__(self, random_shape):
+            self.random_shape = random_shape
+
+        def __getitem__(self, idx):
+            key = idx
+            if self.random_shape:
+                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+                labels = np.random.uniform(size=(random.randint(10, 15)))
+            else:
+                out = np.random.uniform(size=(1000, 40))
+                labels = np.random.uniform(size=(10))
+            return key, out, labels
+
+        def __len__(self):
+            return 50
+
+        def batchify(self, data):
+            """
+            Collate data into batch. Use shared memory for stacking.
+
+            :param data: a list of array, with layout of 'NTC'.
+            :return: either x and x's unpadded lengths, or x, x's unpadded lengths,
+                    y and y's unpadded lengths if labels are supplied.
+            """
+
+            # input layout is NTC
+            keys, inputs, labels = [item[0] for item in data], [item[1] for 
item in data], \
+                                   [item[2] for item in data]
+
+            if len(data) > 1:
+                max_data_len = max([seq.shape[0] for seq in inputs])
+                max_labels_len = 0 if not labels else max([seq.shape[0] for 
seq in labels])
+            else:
+                max_data_len = inputs[0].shape[0]
+                max_labels_len = 0 if not labels else labels[0].shape[0]
+
+            x_lens = [item.shape[0] for item in inputs]
+            y_lens = [item.shape[0] for item in labels]
+
+            for i, seq in enumerate(inputs):
+                pad_len = max_data_len - seq.shape[0]
+                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', 
constant_values=0)
+                labels[i] = np.pad(labels[i], (0, max_labels_len - 
labels[i].shape[0]),
+                                   'constant', constant_values=-1)
+
+            inputs = np.asarray(inputs, dtype=np.float32)
+            if labels is not None:
+                labels = np.asarray(labels, dtype=np.float32)
+            inputs = inputs.transpose((1, 0, 2))
+            labels = labels.transpose((1, 0))
+
+            return (nd.array(inputs, dtype=inputs.dtype, 
ctx=context.Context('cpu_shared', 0)),
 
 Review comment:
   Ah I see, thanks

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to