SINGA-317 Extend ImageBatchIter to read labels in a general format. This enables the image list file to include more general information than a single label index. The second part of each line (separated by a user-defined delimiter) can now be label strings, variable-length label indices, or a single label index. If it is a single label index, we return a numpy array of length batchsize containing the label indices. Otherwise, we return a list of length batchsize containing the meta information of each image, and users have to parse that information in their own code.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3415099a Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3415099a Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3415099a Branch: refs/heads/master Commit: 3415099a96a9b8c319fe36dd06bf28a7eea3ee92 Parents: be093f1 Author: Wei Wang <[email protected]> Authored: Wed May 24 19:46:05 2017 +0800 Committer: Wei Wang <[email protected]> Committed: Wed May 24 19:46:05 2017 +0800 ---------------------------------------------------------------------- python/singa/data.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3415099a/python/singa/data.py ---------------------------------------------------------------------- diff --git a/python/singa/data.py b/python/singa/data.py index 3a99ad3..ec4aa97 100644 --- a/python/singa/data.py +++ b/python/singa/data.py @@ -61,7 +61,12 @@ class ImageBatchIter: Args: img_list_file(str): name of the file containing image meta data; each - line consists of image_path_suffix delimiter label + line consists of image_path_suffix delimiter meta_info, + where meta info could be label index or label strings, etc. + meta_info should not contain the delimiter. If the meta_info + of each image is just the label index, then we will parse the + label index into a numpy array with length=batchsize + (for compatibility); otherwise, we return a list of meta_info batch_size(int): num of samples in one mini-batch image_transform: a function for image augmentation; it accepts the full image path and outputs a list of augmented images. 
@@ -106,21 +111,21 @@ class ImageBatchIter: def run(self): img_list = [] + is_labelindex = True for line in open(self.img_list_file, 'r'): - item = line.split(self.delimiter) - img_path = item[0] - img_label = int(item[1]) - img_list.append((img_label, img_path)) + item = line.strip('\n').split(self.delimiter) + if not item[1].strip().isdigit(): # the meta info is not label index + is_labelindex = False + img_list.append((item[0].strip(), item[1].strip())) index = 0 # index for the image if self.shuffle: random.shuffle(img_list) while not self.stop: if not self.queue.full(): - x = [] - y = np.empty(self.batch_size, dtype=np.int32) + x, y = [], [] i = 0 while i < self.batch_size: - img_label, img_path = img_list[index] + img_path, img_meta = img_list[index] aug_images = self.image_transform( os.path.join(self.image_folder, img_path)) assert i + len(aug_images) <= self.batch_size, \ @@ -129,7 +134,10 @@ class ImageBatchIter: for img in aug_images: ary = np.asarray(img.convert('RGB'), dtype=np.float32) x.append(ary.transpose(2, 0, 1)) - y[i] = img_label + if is_labelindex: + y.append(int(img_meta)) + else: + y.append(img_meta) i += 1 index += 1 if index == self.num_samples: @@ -137,7 +145,10 @@ class ImageBatchIter: if self.shuffle: random.shuffle(img_list) # enqueue one mini-batch - self.queue.put((np.asarray(x), y)) + if is_labelindex: + self.queue.put((np.asarray(x), np.asarray(y, dtype=np.int32))) + else: + self.queue.put((np.asarray(x), y)) else: time.sleep(0.1) return @@ -155,11 +166,12 @@ if __name__ == '__main__': (96, 96)).flip().get() data = ImageBatchIter('train.txt', 3, - image_transform, shuffle=True, delimiter=',', + image_transform, shuffle=False, delimiter=',', image_folder='images/', capacity=10) data.start() imgs, labels = data.next() + print labels for idx in range(imgs.shape[0]): img = Image.fromarray(imgs[idx].astype(np.uint8).transpose(1, 2, 0), 'RGB')
