SINGA-246 Imgtool for image augmentation

Reformat the code and add comments.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/26d9cd44
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/26d9cd44
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/26d9cd44

Branch: refs/heads/master
Commit: 26d9cd44598715b4c04dfdfccce44ba140ced248
Parents: 809a592
Author: Wei Wang <[email protected]>
Authored: Thu Sep 15 17:54:19 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Thu Sep 15 17:54:19 2016 +0800

----------------------------------------------------------------------
 python/singa/imgtool.py | 414 +++++++++++++++++++++++--------------------
 1 file changed, 226 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/26d9cd44/python/singa/imgtool.py
----------------------------------------------------------------------
diff --git a/python/singa/imgtool.py b/python/singa/imgtool.py
index aaaf9e0..021a060 100644
--- a/python/singa/imgtool.py
+++ b/python/singa/imgtool.py
@@ -17,408 +17,446 @@
 
 import random
 import numpy as np
-from PIL import Image,ImageEnhance
+from PIL import Image, ImageEnhance
+
 
 def load_img(path, grayscale=False):
-  img = Image.open(path)
-  if grayscale:
-      img = img.convert('L')
-  else:  # Ensure 3 channel even when loaded image is grayscale
-      img = img.convert('RGB')
-  return img
+    '''Read the image from a give path'''
+    img = Image.open(path)
+    if grayscale:
+        img = img.convert('L')
+    else:  # Ensure 3 channel even when loaded image is grayscale
+        img = img.convert('RGB')
+    return img
 
-def do_crop(img,crop,position):
 
+def do_crop(img, crop, position):
+    '''Crop the input image into given size at given position.
+
+    Args:
+        crop(tuple): width and height of the result image.
+        position(list(str)): left_top, left_bottom, right_top, right_bottom
+        and center.
+    '''
     if img.size[0] < crop[0]:
-        raise Exception('img size[0] %d is smaller than crop[0]: %d' % (img[0],crop[0]))
+        raise Exception(
+            'img size[0] %d is smaller than crop[0]: %d' % (img[0], crop[0]))
     if img.size[1] < crop[1]:
-        raise Exception('img size[1] %d is smaller than crop[1]: %d' % (img[1],crop[1]))
+        raise Exception(
+            'img size[1] %d is smaller than crop[1]: %d' % (img[1], crop[1]))
 
     if position == 'left_top':
-        left,upper=0,0
+        left, upper = 0, 0
     elif position == 'left_bottom':
-        left,upper=0,img.size[1]-crop[1] 
+        left, upper = 0, img.size[1]-crop[1]
     elif position == 'right_top':
-        left,upper=img.size[0]-crop[0],0
+        left, upper = img.size[0]-crop[0], 0
     elif position == 'right_bottom':
-        left,upper=img.size[0]-crop[0],img.size[1]-crop[1] 
+        left, upper = img.size[0]-crop[0], img.size[1]-crop[1]
     elif position == 'center':
-        left,upper=(img.size[0]-crop[0])/2,(img.size[1]-crop[1])/2
+        left, upper = (img.size[0]-crop[0])/2, (img.size[1]-crop[1])/2
     else:
         raise Exception('position is wrong')
 
-    box =(left,upper,left+crop[0],upper+crop[1]) 
+    box = (left, upper, left+crop[0], upper+crop[1])
     new_img = img.crop(box)
-    #print "crop to box %d,%d,%d,%d" % box
+    # print "crop to box %d,%d,%d,%d" % box
     return new_img
 
-def do_crop_and_scale(img,crop,position):
+
+def do_crop_and_scale(img, crop, position):
+    '''Crop a max square patch of the input image at given position and scale
+    it into given size.
+
+    Args:
+        crop(tuple): width, height
+        position(list(str)): left, center, right, top, middle, bottom.
+    '''
     size = img.size
     if position == 'left':
-        left,upper=0,0
-        right,bottom = size[1],size[1]
+        left, upper = 0, 0
+        right, bottom = size[1], size[1]
     elif position == 'center':
-        left,upper=(size[0]-size[1])/2,0 
-        right,bottom =(size[0]+size[1])/2,size[1]
+        left, upper = (size[0]-size[1])/2, 0
+        right, bottom = (size[0]+size[1])/2, size[1]
     elif position == 'right':
-        left,upper=size[0]-size[1],0 
-        right,bottom =size[0],size[1]
+        left, upper = size[0]-size[1], 0
+        right, bottom = size[0], size[1]
     elif position == 'top':
-        left,upper=0,0 
-        right,bottom =size[0],size[0]
+        left, upper = 0, 0
+        right, bottom = size[0], size[0]
     elif position == 'middle':
-        left,upper=0,(size[1]-size[0])/2
-        right,bottom =size[0],(size[1]+size[0])/2
+        left, upper = 0, (size[1]-size[0])/2
+        right, bottom = size[0], (size[1]+size[0])/2
     elif position == 'bottom':
-        left,upper=0,size[1]-size[0]
-        right,bottom =size[0],size[1]
+        left, upper = 0, size[1]-size[0]
+        right, bottom = size[0], size[1]
     else:
         raise Exception('position is wrong')
-    box =(left,upper,right,bottom) 
+    box = (left, upper, right, bottom)
     new_img = img.crop(box)
 
     new_img = img.resize(crop)
-    #print box+crop
-    #print "crop to box %d,%d,%d,%d and scale to %d,%d" % (box+crop)
+    # print box+crop
+    # print "crop to box %d,%d,%d,%d and scale to %d,%d" % (box+crop)
     return new_img
 
-def do_resize(img,small_size):
+
+def do_resize(img, small_size):
+    '''Resize the image to make the smaller side be at the given size'''
     size = img.size
-    if size[0]<size[1]:
-        new_size = ( small_size, int(small_size*size[1]/size[0]) )
+    if size[0] < size[1]:
+        new_size = (small_size, int(small_size*size[1]/size[0]))
     else:
-        new_size = ( int(small_size*size[0]/size[1]), small_size )
-    new_img=img.resize(new_size)
-    #print 'resize to (%d,%d)' % new_size
+        new_size = (int(small_size*size[0]/size[1]), small_size)
+    new_img = img.resize(new_size)
+    # print 'resize to (%d,%d)' % new_size
     return new_img
- 
 
-def do_color_cast(img,offset):
 
+def do_color_cast(img, offset):
+    '''Add a random number from [-offset, offset] to each channel'''
     x = np.asarray(img, dtype='uint8')
-    x.flags.writeable = True  
-    cast_value=[0,0,0]
+    x.flags.writeable = True
+    cast_value = [0, 0, 0]
     for i in range(3):
-        r= random.randint(0,1)
+        r = random.randint(0, 1)
         if r:
-            cast_value[i] = random.randint(-offset,offset)
+            cast_value[i] = random.randint(-offset, offset)
     for w in range(img.size[0]):
         for h in range(img.size[1]):
             for c in range(3):
-                if cast_value[c]==0:
+                if cast_value[c] == 0:
                     continue
-                v=x[w][h][c]+cast_value[c] 
-                if v<0:
-                    v=0
-                if v>255:
-                    v=255
-                x[w][h][c]=v
-    new_img= Image.fromarray(x.astype('uint8'), 'RGB')
+                v = x[w][h][c]+cast_value[c]
+                if v < 0:
+                    v = 0
+                if v > 255:
+                    v = 255
+                x[w][h][c] = v
+    new_img = Image.fromarray(x.astype('uint8'), 'RGB')
     return new_img
 
 
-def do_enhance(img,scale):
+def do_enhance(img, scale):
+    '''Apply random enhancement for Color,Contrast,Brightness,Sharpness.
 
-    # Color,Contrast,Brightness,Sharpness
-    enhance_value=[1.0,1.0,1.0,1.0]
+    Args:
+        scale(float): enhancement degree is from [1-scale, 1+scale]
+    '''
+    enhance_value = [1.0, 1.0, 1.0, 1.0]
     for i in range(4):
-        r= random.randint(0,1)
+        r = random.randint(0, 1)
         if r:
-            enhance_value[i] = random.uniform(1-scale,1+scale)
-    if not enhance_value[0]==1.0:
-        enhancer = ImageEnhance.Color(img) 
+            enhance_value[i] = random.uniform(1-scale, 1+scale)
+    if not enhance_value[0] == 1.0:
+        enhancer = ImageEnhance.Color(img)
         img = enhancer.enhance(enhance_value[0])
-    if not enhance_value[1]==1.0:
-        enhancer = ImageEnhance.Contrast(img) 
+    if not enhance_value[1] == 1.0:
+        enhancer = ImageEnhance.Contrast(img)
         img = enhancer.enhance(enhance_value[1])
-    if not enhance_value[2]==1.0:
-        enhancer = ImageEnhance.Brightness(img) 
+    if not enhance_value[2] == 1.0:
+        enhancer = ImageEnhance.Brightness(img)
         img = enhancer.enhance(enhance_value[2])
-    if not enhance_value[3]==1.0:
-        enhancer = ImageEnhance.Sharpness(img) 
+    if not enhance_value[3] == 1.0:
+        enhancer = ImageEnhance.Sharpness(img)
         img = enhancer.enhance(enhance_value[3])
     return img
 
+
 def do_flip(img):
-    #print 'flip'
+    # print 'flip'
     new_img = img.transpose(Image.FLIP_LEFT_RIGHT)
-    return new_img     
- 
-def get_list_sample(l,sample_size):
-    return [ l[i] for i in sorted(random.sample(xrange(len(l)), sample_size))]
+    return new_img
+
+
+def get_list_sample(l, sample_size):
+    return [l[i] for i in sorted(random.sample(xrange(len(l)), sample_size))]
+
 
 class Imgtool():
 
     def __init__(self):
-        self.imgs=[]
+        self.imgs = []
         return
 
-    def load(self,path):
+    def load(self, path):
         img = load_img(path)
-        self.imgs=[img]
+        self.imgs = [img]
         return self
 
-    def set(self,imgs):
+    def set(self, imgs):
         self.imgs = imgs
         return self
 
-    def append(self,img):
+    def append(self, img):
         self.imgs.append(img)
         return self
 
     def get(self):
         return self.imgs
- 
-    def resize_by_range(self,rng,k=1,update=True):
 
+    def resize_by_range(self, rng, k=1, update=True):
         '''
         Args:
             rng: a tuple (begin,end), include begin, exclude end
-            k: number of samples, must be smaller than or equare to the length of 
-               the range list, if k=0, then sample all
+            k: number of samples, must be smaller than or equare to the length
+               of the range list, if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        size_list = range(rng[0],rng[1])
-        return self.resize_by_list(size_list,k,update)
-        
-    def resize_by_list(self,size_list,k=1,update=True):
+        size_list = range(rng[0], rng[1])
+        return self.resize_by_list(size_list, k, update)
+
+    def resize_by_list(self, size_list, k=1, update=True):
         '''
         Args:
-            k: number of samples, must be smaller than or equare to the length of 
-               size_list, if k=0, then sample all
+            k: number of samples, must be smaller than or equare to the length
+               of size_list, if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        if k<0 or k > len(size_list):
-            raise Exception('k must be smaller in [0,%d(length of size_list)]' % len(size_list))
+        new_imgs = []
+        if k < 0 or k > len(size_list):
+            raise Exception(
+                'k must be smaller in [0,%d(length of size_list)]' %
+                len(size_list))
         for img in self.imgs:
-            if k ==0 or k== len(size_list):
+            if k == 0 or k == len(size_list):
                 small_sizes = size_list
             else:
-                small_sizes = get_list_sample(size_list,k) 
+                small_sizes = get_list_sample(size_list, k)
 
             for small_size in small_sizes:
-                new_img= do_resize(img,small_size)
+                new_img = do_resize(img, small_size)
                 new_imgs.append(new_img)
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
-    def resize_for_test(self,rng):
+    def resize_for_test(self, rng):
         '''
         Args:
             rng: a tuple (begin,end)
         '''
-        size_list=[rng[0],rng[0]/2+rng[1]/2,rng[1]]
-        return self.resize_by_list(size_list,k=3)
-
-    def rotate_by_range(self,rng,k=1,update=True):
+        size_list = [rng[0], rng[0]/2+rng[1]/2, rng[1]]
+        return self.resize_by_list(size_list, k=3)
 
+    def rotate_by_range(self, rng, k=1, update=True):
         '''
         Args:
-            rng: a tuple (begin,end), include begin, exclude end
-            k: number of samples, must be smaller than or equare to the length of 
-               the range list, if k=0, then sample all
+            rng: a tuple (begin,end) in degree, include begin, exclude end
+            k: number of samples, must be smaller than or equare to the length
+               of the range list, if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        size_list = range(rng[0],rng[1])
-        return self.rotate_by_list(angle_list,k,update)
-        
-    def rotate_by_list(self,angle_list,k=1,update=True):
+        angle_list = range(rng[0], rng[1])
+        return self.rotate_by_list(angle_list, k, update)
+
+    def rotate_by_list(self, angle_list, k=1, update=True):
         '''
         Args:
-            k: number of samples, must be smaller than or equare to the length of 
-               size_list, if k=0, then sample all
+            k: number of samples, must be smaller than or equare to the length
+               of size_list, if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        if k<0 or k > len(angle_list):
-            raise Exception('k must be smaller in [0,%d(length of angle_list)]' % len(angle_list))
+        new_imgs = []
+        if k < 0 or k > len(angle_list):
+            raise Exception(
+                'k must be smaller in [0,%d(length of angle_list)]' %
+                len(angle_list))
 
         for img in self.imgs:
-            if k ==0 or k== len(angle_list):
+            if k == 0 or k == len(angle_list):
                 angles = angle_list
             else:
-                angles = get_list_sample(angle_list,k) 
+                angles = get_list_sample(angle_list, k)
 
             for angle in angles:
-                new_img= img.rotate(angle)
+                new_img = img.rotate(angle)
                 new_imgs.append(new_img)
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
+    def crop_5(self, crop_size, k=1, update=True):
+        '''Crop at positions from [left_top, left_bottom, right_top,
+        right_bottom, and center].
 
-    def crop_5(self,crop_size,k=1,update=True):
-        '''
         Args:
+            crop_size(tuple): width and height of the result image.
             k: number of samples, must be in [0,5], if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        positions=["left_top","left_bottom","right_top","right_bottom","center"]
-        if k > 5 or k <0:
+        new_imgs = []
+        positions = [
+            "left_top",
+            "left_bottom",
+            "right_top",
+            "right_bottom",
+            "center"]
+        if k > 5 or k < 0:
             raise Exception('k must be in [0,5]')
         for img in self.imgs:
 
-            if k>0 and k<5:
-                positions = get_list_sample(positions,k) 
+            if k > 0 and k < 5:
+                positions = get_list_sample(positions, k)
 
             for position in positions:
-                new_img=do_crop(img,crop_size,position)
+                new_img = do_crop(img, crop_size, position)
                 new_imgs.append(new_img)
 
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
-    def crop_and_scale(self,crop_size,k=1,update=True):
-        '''
-        crop and scale the image into the given size.
-        According to img size, crop position could be either (left, center, right)
-        or (up, center, low).
+    def crop_and_scale(self, crop_size, k=1, update=True):
+        '''Crop a max square patch of the input image at given position and
+        scale it into given size.
+
+        According to img size, crop position could be either
+        (left, center, right) or (top, middle, bottom).
 
         Args:
+            crop_size(tuple): the width and height the output image
             k: number of samples, must be in [0,3], if k=0, then sample all
             update: update imgs or not ( return new_imgs)
         '''
         if not crop_size[0] == crop_size[1]:
             raise Exception('crop_size must be a square')
-        new_imgs=[]
-        if k > 3 or k <0:
+        new_imgs = []
+        if k > 3 or k < 0:
             raise Exception('k must be in [0,3]')
-        positions_horizental=["left","center","right"]
-        positions_vertical=["top","middle","bottom"]
+        positions_horizental = ["left", "center", "right"]
+        positions_vertical = ["top", "middle", "bottom"]
         for img in self.imgs:
             size = img.size
-            if size[0] > size[1]: 
-                if k>0 and k<3:
-                    positions = get_list_sample(positions_horizental,k) 
+            if size[0] > size[1]:
+                if k > 0 and k < 3:
+                    positions = get_list_sample(positions_horizental, k)
                 else:
-                    positions = positions_horizental 
+                    positions = positions_horizental
             else:
-                if k>0 and k<3:
-                    positions = get_list_sample(positions_vertical,k) 
+                if k > 0 and k < 3:
+                    positions = get_list_sample(positions_vertical, k)
                 else:
-                    positions = positions_vertical 
- 
+                    positions = positions_vertical
+
             for position in positions:
-                new_img=do_crop_and_scale(img,crop_size,position)
+                new_img = do_crop_and_scale(img, crop_size, position)
                 new_imgs.append(new_img)
 
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
-    def crop_union(self,crop_size,k=1,update=True):
+    def crop_union(self, crop_size, k=1, update=True):
+        '''This is a union of crop_5 and crop_and_scale.
 
-        '''
-        this is a union of crop_5 and crop_and_scale 
-        you can follow this example to union any number of imgtool methods
+        You can follow this example to union any number of imgtool methods
         '''
         crop_5_num = 5
         crop_and_scale_num = 3
-        if k<0 or k> crop_5_num+crop_and_scale_num:
-            raise Exception('k must be in [0,%d]' % (crop_5_num+crop_and_scale_num) )
-        if k==0 or k == crop_5_num+crop_and_scale_num:
+        if k < 0 or k > crop_5_num+crop_and_scale_num:
+            raise Exception(
+                'k must be in [0,%d]' %
+                (crop_5_num+crop_and_scale_num))
+        if k == 0 or k == crop_5_num+crop_and_scale_num:
             count = crop_5_num
         else:
-            sample_list = range(0,crop_5_num+crop_and_scale_num)
-            samples = get_list_sample(sample_list,k)
-            count=0
+            sample_list = range(0, crop_5_num+crop_and_scale_num)
+            samples = get_list_sample(sample_list, k)
+            count = 0
             for s in samples:
                 if s < crop_5_num:
-                    count+=1 
-        new_imgs=[]
+                    count += 1
+        new_imgs = []
         if count > 0:
-            new_imgs += self.crop_5(crop_size,k=count,update=False) 
+            new_imgs += self.crop_5(crop_size, k=count, update=False)
         if k-count > 0:
-            new_imgs += self.crop_and_scale(crop_size,k=k-count,update=False)
+            new_imgs += self.crop_and_scale(crop_size, k=k-count, update=False)
 
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
-  
-    def flip(self,k=1,update=True):
+
+    def flip(self, k=1, update=True):
         '''
         randomly flip a img left to right
         Args:
             k: number of samples, must be in [0,1,2], if k=0,2, then sample all
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        if k<0 or k>2:
+        new_imgs = []
+        if k < 0 or k > 2:
             raise Exception('k must be in [0,2]')
         for img in self.imgs:
-            flips = [0,1]
-            if k>0 and k <2:
-                r=random.randint(0,1)
+            flips = [0, 1]
+            if k > 0 and k < 2:
+                r = random.randint(0, 1)
                 flips = [r]
             for flip in flips:
-                if flip: 
-                    new_img=do_flip(img) 
+                if flip:
+                    new_img = do_flip(img)
                 else:
-                    new_img=img              
+                    new_img = img
                 new_imgs.append(new_img)
 
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
-    def color_cast(self,offset=20,k=1,update=True):
-        '''
-        randomly do color cast on rgb channels of a img
+    def color_cast(self, offset=20, k=1, update=True):
+        '''Add a random number from [-offset, offset] to each channel
+
         Args:
             offset: cast offset, >0 and <255
-            k: number of samples, must be larger than 0 
+            k: number of samples, must be larger than 0
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        if k<=0:
+        new_imgs = []
+        if k <= 0:
             raise Exception('k must be larger than 0')
-        if offset<0 or offset>255:
+        if offset < 0 or offset > 255:
             raise Exception('offset must be >0 and <255')
-         
+
         for img in self.imgs:
             for i in range(k):
-                new_img=do_color_cast(img,offset)
+                new_img = do_color_cast(img, offset)
                 new_imgs.append(new_img)
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
 
-    def enhance(self,scale=0.2,k=1,update=True):
-        '''
-        randomly do color, contrast, brightness and sharpness enhance on img
+    def enhance(self, scale=0.2, k=1, update=True):
+        '''Apply random enhancement for Color,Contrast,Brightness,Sharpness.
+
         Args:
-            scale: cast scale, >0 and <1
-            k: number of samples, must be larger than 0 
+            scale(float): enhancement degree is from [1-scale, 1+scale]
+            k: number of samples, must be larger than 0
             update: update imgs or not ( return new_imgs)
         '''
-        new_imgs=[]
-        if k<=0:
+        new_imgs = []
+        if k <= 0:
             raise Exception('k must be larger than 0')
         for img in self.imgs:
             for i in range(k):
-                new_img=do_enhance(img,scale)
+                new_img = do_enhance(img, scale)
                 new_imgs.append(new_img)
         if update:
-            self.imgs=new_imgs
+            self.imgs = new_imgs
             return self
         else:
             return new_imgs
-

Reply via email to