--- origsprite.py	Mon Apr 21 11:11:47 2008
+++ patsprite.py	Mon Apr 21 11:14:04 2008
@@ -77,7 +77,7 @@
 import pygame
 from pygame import Rect
 from pygame.time import get_ticks
-
+from pygame.mask import from_surface
 
 
 
@@ -1056,49 +1056,165 @@
 
 # some different collision detection functions that could be used.
 
-def collide_rect( left, right ):
-    return left.rect.colliderect( right.rect )
+def collide_rect(left, right):
+    """pygame.sprite.collide_rect(left, right) -> bool
+       collision detection between two sprites, using rects.
+
+       Tests for collision between two sprites. Uses the
+       pygame rect colliderect function to calculate the
+       collision. Intended to be passed as a collided
+       callback function to the *collide functions.
+       Sprites must have a "rect" attributes."""
+    return left.rect.colliderect(right.rect)
 
 class collide_rect_ratio:
+    """A callable class that checks for collisions between
+       two sprites, using a scaled version of the sprites
+       rects.
+
+       Is created with a ratio, the instance is then intended
+       to be passed as a collided callback function to the
+       *collide functions."""
+    
     def __init__( self, ratio ):
+        """Creates a new collide_rect_ratio callable. ratio is
+           expected to be a floating point value used to scale
+           the underlying sprite rect before checking for
+           collisions."""
         self.ratio = ratio
 
-    def __inflate_by_ratio( self, rect ):
-        width = rect.width * self.ratio
-        height = rect.height * self.ratio
-        return rect.inflate( width - rect.width, height - rect.height )
-
     def __call__( self, left, right ):
-        leftrect = self.__inflate_by_ratio( left.rect )
-        rightrect = self.__inflate_by_ratio( right.rect )
+        """pygame.sprite.collide_rect_ratio(ratio)(left, right) -> bool
+           collision detection between two sprites, using scaled rects.
+
+           Tests for collision between two sprites. Uses the
+           pygame rect colliderect function to calculate the
+           collision, after scaling the rects by the stored ratio.
+           Sprites must have a "rect" attributes."""
+        ratio = self.ratio
+        
+        leftrect = left.rect
+        width = leftrect.width
+        height = leftrect.height
+        leftrect = leftrect.inflate( width * ratio - width, height * ratio - height )
+        
+        rightrect = right.rect
+        width = rightrect.width
+        height = rightrect.height
+        rightrect = rightrect.inflate( width * ratio - width, height * ratio - height )
+        
         return leftrect.colliderect( rightrect )
 
 def collide_circle( left, right ):
-    def get_radius_squared( sprite ):
-        try:
-            radiusSquared = sprite.radius ** 2
-        except AttributeError:
-            rect = sprite.rect
-            radiusSquared = ( rect.width ** 2 + rect.height ** 2 ) / 4
-        return radiusSquared
-
-    xDistance = left.rect.centerx - right.rect.centerx
-    yDistance = left.rect.centery - right.rect.centery
-    distanceSquared = xDistance ** 2 + yDistance ** 2
-    return distanceSquared < get_radius_squared( left ) + get_radius_squared( right )
-
-
-def collide_mask( left, right ):
-
-    offset_x = left.rect[0] - right.rect[0]
-    offset_y = left.rect[1] - right.rect[1]
-    # See if the two masks at the offset are overlapping.
-    overlap = left.mask.overlap(right.mask, (offset_x, offset_y))
-    return overlap
-    
+    """pygame.sprite.collide_circle(left, right) -> bool
+       collision detection between two sprites, using circles.
 
+       Tests for collision between two sprites, by testing to
+       see if two circles centered on the sprites overlap. If
+       the sprites have a "radius" attribute, that is used to
+       create the circle, otherwise a circle is created that
+       is big enough to completely enclose the sprites rect as
+       given by the "rect" attribute. Intended to be passed as
+       a collided callback function to the *collide functions.
+       Sprites must have a "rect" and an optional "radius"
+       attribute."""
+    xdistance = left.rect.centerx - right.rect.centerx
+    ydistance = left.rect.centery - right.rect.centery
+    distancesquared = xdistance ** 2 + ydistance ** 2
+    try:
+        leftradiussquared = left.radius ** 2
+    except AttributeError:
+        leftrect = left.rect
+        leftradiussquared = ( leftrect.width ** 2 + leftrect.height ** 2 ) / 4
+    try:
+        rightradiussquared = right.radius ** 2
+    except AttributeError:
+        rightrect = right.rect
+        rightradiussquared = ( rightrect.width ** 2 + rightrect.height ** 2 ) / 4
+    return distancesquared < leftradiussquared + rightradiussquared
+
+class collide_circle_ratio( object ):
+    """A callable class that checks for collisions between
+       two sprites, using a scaled version of the sprites
+       radius.
+
+       Is created with a ratio, the instance is then intended
+       to be passed as a collided callback function to the
+       *collide functions."""
+    
+    def __init__( self, ratio ):
+        """Creates a new collide_circle_ratio callable. ratio is
+           expected to be a floating point value used to scale
+           the underlying sprite radius before checking for
+           collisions."""
+        self.ratio = ratio
+        # Constant value that folds in division for diameter to radius,
+        # when calculating from a rect.
+        self.halfratio = ratio ** 2 / 4.0
 
+    def __call__( self, left, right ):
+        """pygame.sprite.collide_circle_radio(ratio)(left, right) -> bool
+           collision detection between two sprites, using scaled circles.
 
+           Tests for collision between two sprites, by testing to
+           see if two circles centered on the sprites overlap, after
+           scaling the circles radius by the stored ratio. If
+           the sprites have a "radius" attribute, that is used to
+           create the circle, otherwise a circle is created that
+           is big enough to completely enclose the sprites rect as
+           given by the "rect" attribute. Intended to be passed as
+           a collided callback function to the *collide functions.
+           Sprites must have a "rect" and an optional "radius"
+           attribute."""
+        ratio = self.ratio
+        xdistance = left.rect.centerx - right.rect.centerx
+        ydistance = left.rect.centery - right.rect.centery
+        distancesquared = xdistance ** 2 + ydistance ** 2
+        # Optimize for not containing radius attribute, as if radius was
+        # set consistently, would probably be using collide_circle instead.
+        if hasattr( left, "radius" ):
+            leftradiussquared = (left.radius * ratio) ** 2
+            
+            if hasattr( right, "radius" ):
+                rightradiussquared = (right.radius * ratio) ** 2
+            else:
+                halfratio = self.halfratio
+                rightrect = right.rect
+                rightradiussquared = (rightrect.width ** 2 + rightrect.height ** 2) * halfratio
+        else:
+            halfratio = self.halfratio
+            leftrect = left.rect
+            leftradiussquared = (leftrect.width ** 2 + leftrect.height ** 2) * halfratio
+            
+            if hasattr( right, "radius" ):
+                rightradiussquared = (right.radius * ratio) ** 2
+            else:
+                rightrect = right.rect
+                rightradiussquared = (rightrect.width ** 2 + rightrect.height ** 2) * halfratio
+        return distancesquared < leftradiussquared + rightradiussquared
+
+def collide_mask(left, right):
+    """pygame.sprite.collide_mask(left, right) -> bool
+       collision detection between two sprites, using masks.
+
+       Tests for collision between two sprites, by testing if
+       thier bitmasks overlap. If the sprites have a "mask"
+       attribute, that is used as the mask, otherwise a mask is
+       created from the sprite image. Intended to be passed as
+       a collided callback function to the *collide functions.
+       Sprites must have a "rect" and an optional "mask"
+       attribute."""
+    xoffset = right.rect[0] - left.rect[0]
+    yoffset = right.rect[1] - left.rect[1]
+    try:
+        leftmask = left.mask
+    except AttributeError:
+        leftmask = from_surface(left.image)
+    try:
+        rightmask = right.mask
+    except AttributeError:
+        rightmask = from_surface(right.image)
+    return leftmask.overlap(rightmask, (xoffset, yoffset))
 
 def spritecollide(sprite, group, dokill, collided = None):
     """pygame.sprite.spritecollide(sprite, group, dokill) -> list
@@ -1107,16 +1223,19 @@
        given a sprite and a group of sprites, this will
        return a list of all the sprites that intersect
        the given sprite.
-       all sprites must have a "rect" value, which is a
-       rectangle of the sprite area. if the dokill argument
-       is true, the sprites that do collide will be
-       automatically removed from all groups."""
+       if the dokill argument is true, the sprites that
+       do collide will be automatically removed from all
+       groups.
+       collided is a callback function used to calculate if
+       two sprites are colliding. it should take two sprites
+       as values, and return a boolean value indicating if
+       they are colliding. if collided is not passed, all
+       sprites must have a "rect" value, which is a
+       rectangle of the sprite area, which will be used
+       to calculate the collision."""
     crashed = []
-    acollide = collided
-
-
-    if not collided:
-        # special case 
+    if collided is None:
+        # Special case old behaviour for speed.
         spritecollide = sprite.rect.colliderect
         if dokill:
             for s in group.sprites():
@@ -1130,12 +1249,12 @@
     else:
         if dokill:
             for s in group.sprites():
-                if acollide( sprite, s ):
+                if collided(sprite, s):
                     s.kill()
                     crashed.append(s)
         else:
             for s in group:
-                if acollide( sprite, s ):
+                if collided(sprite, s):
                     crashed.append(s)
     return crashed
 
@@ -1150,12 +1269,16 @@
        is a list of the sprites in the second group it
        collides with. the two dokill arguments control if
        the sprites from either group will be automatically
-       removed from all groups."""
+       removed from all groups.
+       collided is a callback function used to calculate if
+       two sprites are colliding. it should take two sprites
+       as values, and return a boolean value indicating if
+       they are colliding. if collided is not passed, all
+       sprites must have a "rect" value, which is a
+       rectangle of the sprite area, which will be used
+       to calculate the collision."""
     crashed = {}
     SC = spritecollide
-    if not collided:
-        collided = collide_rect
-
     if dokilla:
         for s in groupa.sprites():
             c = SC(s, groupb, dokillb, collided)
@@ -1182,18 +1305,21 @@
        spritecollide function, this function will be a
        bit quicker.
 
-       all sprites must have a "rect" value, which is a
-       rectangle of the sprite area."""
-
-    if collided:
+       collided is a callback function used to calculate if
+       two sprites are colliding. it should take two sprites
+       as values, and return a boolean value indicating if
+       they are colliding. if collided is not passed, all
+       sprites must have a "rect" value, which is a
+       rectangle of the sprite area, which will be used
+       to calculate the collision."""
+    if collided is None:
+        # Special case old behaviour for speed.
+        spritecollide = sprite.rect.colliderect
         for s in group:
-            if collided( sprite, s ):
+            if spritecollide(s.rect):
                 return s
     else:
-        spritecollide = sprite.rect.colliderect
         for s in group:
-            if spritecollide(s.rect):
+            if collided(sprite, s):
                 return s
-
-
     return None
