Revision: 6763
          http://matplotlib.svn.sourceforge.net/matplotlib/?rev=6763&view=rev
Author:   jdh2358
Date:     2009-01-08 02:50:18 +0000 (Thu, 08 Jan 2009)

Log Message:
-----------
added a nearest neighbor Cython search example

Added Paths:
-----------
    trunk/py4science/examples/pyrex/nnbf/
    trunk/py4science/examples/pyrex/nnbf/nnbf.pyx
    trunk/py4science/examples/pyrex/nnbf/nnbf_proto.py
    trunk/py4science/examples/pyrex/nnbf/nnbf_v1.pyx
    trunk/py4science/examples/pyrex/nnbf/setup.py
    trunk/py4science/examples/pyrex/nnbf/test_nnbf.py

Added: trunk/py4science/examples/pyrex/nnbf/nnbf.pyx
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/nnbf.pyx                              (rev 0)
+++ trunk/py4science/examples/pyrex/nnbf/nnbf.pyx       2009-01-08 02:50:18 UTC (rev 6763)
@@ -0,0 +1,112 @@
+"""
+A brute force nearest neighbor routine with incremental add.  The
+internal array data structure grows as you add points
+"""
+
+import numpy as np
+cimport numpy as np
+
+cdef extern from "math.h":
+     float sqrt(float)
+
+cdef inline int is_neighbor(int n, double*row, double*pp, double d2max):
+    """
+    return 1 if the sum of squared differences row[j]-pp[j] over the n elements is <= d2max, else 0
+    """
+    cdef int j
+    cdef double d, d2
+
+    d2 = 0.
+
+    for j in range(n):
+        d = row[j] - pp[j]
+        d2 += d*d
+        if d2>d2max:
+            return 0
+    return 1
+
+cdef class NNBF:
+    cdef readonly object data
+    #cdef double* raw_data
+    cdef readonly int n, numrows, numpoints
+
+    def __init__(self, n):
+        """
+        create a buffer to hold n dimensional points
+        """
+        #cdef np.ndarray[double, ndim=2] inner_data
+
+        
+        self.n = n
+        self.numrows = 100
+        #  XXX how to create empty as contiguous w/o copy?
+        self.data = np.empty((self.numrows, self.n), dtype=np.float)
+        #inner_data = self.data
+        #self.raw_data = <double*>inner_data.data
+        self.numpoints = 0
+
+
+    def add(NNBF self, object point):
+        """
+        add a point to the buffer, grow if necessary
+        """
+        #cdef np.ndarray[double, ndim=2] inner_data
+        cdef np.ndarray[double, ndim=1] pp
+        pp = np.asarray(point).astype(np.float)
+
+
+        self.data[self.numpoints] = pp
+        self.numpoints += 1
+        if self.numpoints==self.numrows:
+            ## XXX do I need to do memory management here, eg free
+            ## raw_data if I were using it?
+            self.numrows *= 2
+            newdata = np.empty((self.numrows, self.n), np.float)
+            newdata[:self.numpoints] = self.data
+            self.data = newdata
+            #self.raw_data = <double*>inner_data.data
+
+    def get_data(NNBF self):
+        """
+        return the data added so far as a numpoints x n array
+        """
+        return self.data[:self.numpoints]
+
+
+    def find_neighbors(NNBF self, object point, double radius):
+        """
+        return a list of indices into data which are within radius
+        from point
+        """
+        cdef int i, neighbor
+        cdef double d2max
+        cdef np.ndarray[double, ndim=1] pp
+        cdef np.ndarray[double, ndim=1] row        
+
+        if len(point)!=self.n:
+            raise ValueError('Expected a length %d vector'%self.n)
+        
+        pp = np.asarray(point).astype(np.float)
+
+        d2max = radius*radius
+        neighbors = []
+
+
+        for i in range(self.numpoints):
+            # XXX : is there a more efficient way to access the row
+            # data?  Can/should we be using raw_data here?
+            row = self.data[i]
+            neighbor = is_neighbor(
+                self.n,
+                <double*>row.data,
+                <double*>pp.data,
+                d2max)
+            
+            # if the number of points in the cluster is small, the
+            # python list performance should not kill us
+            if neighbor:
+                neighbors.append(i)
+
+        return neighbors
+
+
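
For orientation, a minimal usage sketch of the NNBF extension class defined
above (an illustration, not part of the commit; it assumes the module has been
compiled, e.g. via the setup.py below, and is importable as nnbf):

    import numpy as np
    import nnbf

    nn = nnbf.NNBF(3)                  # buffer for 3 dimensional points
    for i in range(1000):
        nn.add(np.random.rand(3))      # internal buffer grows as needed
    # indices of stored points within radius 0.1 of the query point
    ind = nn.find_neighbors([0.5, 0.5, 0.5], 0.1)
    print len(ind), 'neighbors found'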

Added: trunk/py4science/examples/pyrex/nnbf/nnbf_proto.py
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/nnbf_proto.py                         (rev 0)
+++ trunk/py4science/examples/pyrex/nnbf/nnbf_proto.py  2009-01-08 02:50:18 UTC (rev 6763)
@@ -0,0 +1,102 @@
+"""
+A brute force nearest neighbor routine
+"""
+import math
+import numpy as np
+
+class NNBF:
+    def __init__(self, n):
+        """
+        create a buffer to hold n dimensional points
+        """
+        self.n = n
+        self.numrows = 100
+        self.data = np.empty((self.numrows, self.n), np.float)
+        self.numpoints = 0
+
+    def add(self, point):
+        self.data[self.numpoints] = point
+        self.numpoints += 1
+        if self.numpoints==self.numrows:
+            self.numrows *= 2
+            newdata = np.empty((self.numrows, self.n), np.float)
+            newdata[:self.numpoints] = self.data
+            self.data = newdata
+
+    def get_data(self):
+        """
+        return the data added so far
+        """
+        return self.data[:self.numpoints]
+
+
+    def find_neighbors(self, point, radius):
+        """
+        return a list of indices into data which are within radius
+        from point
+        """
+        data = self.get_data()
+        neighbors = []
+        for i in range(self.numpoints):
+            row = data[i]
+            fail = False
+            d2 = 0.
+            for j in range(self.n):
+                rowval = row[j]
+                pntval = point[j]
+                d = rowval-pntval
+                if d>radius:
+                    fail = True
+                    break
+                d2 += d*d
+
+            if fail: continue
+
+            r = math.sqrt(d2)
+            if r<=radius:
+                neighbors.append(i)
+
+        return neighbors
+
+    def find_neighbors_numpy(self, point, radius):
+        """
+        Use plain ol numpy to find neighbors; a vectorized version of
+        find_neighbors for comparison
+        """
+        data = self.get_data()
+        distance = data - point
+        r = np.sqrt((distance*distance).sum(axis=1))
+        return np.nonzero(r<=radius)[0]
+
+def find_neighbors_numpy(data, point, radius):
+    """
+    do a plain ol numpy lookup to compare performance and output
+
+      *data* is a numpoints x numdims array
+      *point* is a numdims length vector
+      *radius* is the maximum distance
+
+    return an array of indices into data which are within radius
+    """
+    numpoints, n = data.shape
+
+    distance = data - point
+    r = np.sqrt((distance*distance).sum(axis=1))
+    return np.nonzero(r<=radius)[0]
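
The module level find_neighbors_numpy above is the vectorized reference the
tests compare against; a short sketch of calling it directly on synthetic data
(an illustration, not part of the commit; shapes follow its docstring):

    import numpy as np
    from nnbf_proto import NNBF, find_neighbors_numpy

    data = np.random.rand(500, 4)          # 500 points in 4 dimensions
    point = np.random.rand(4)

    # vectorized lookup over the whole array at once
    ind = find_neighbors_numpy(data, point, 0.25)

    # the incremental pure python class should report the same indices
    nn = NNBF(4)
    for row in data:
        nn.add(row)
    print list(ind) == nn.find_neighbors(point, 0.25)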

Added: trunk/py4science/examples/pyrex/nnbf/nnbf_v1.pyx
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/nnbf_v1.pyx                           (rev 0)
+++ trunk/py4science/examples/pyrex/nnbf/nnbf_v1.pyx    2009-01-08 02:50:18 UTC (rev 6763)
@@ -0,0 +1,112 @@
+"""
+A brute force nearest neighbor routine with incremental add.  The
+internal array data structure grows as you add points
+"""
+
+import numpy as np
+cimport numpy as np
+
+cdef extern from "math.h":
+     float sqrt(float)
+
+cdef inline int is_neighbor(int n, double*row, double*pp, double d2max):
+    """
+    return 1 if the sum of squared differences row[j]-pp[j] over the n elements is <= d2max, else 0
+    """
+    cdef int j
+    cdef double d, d2
+
+    d2 = 0.
+
+    for j in range(n):
+        d = row[j] - pp[j]
+        d2 += d*d
+        if d2>d2max:
+            return 0
+    return 1
+
+cdef class NNBF:
+    cdef readonly object data
+    #cdef double* raw_data
+    cdef readonly int n, numrows, numpoints
+
+    def __init__(self, n):
+        """
+        create a buffer to hold n dimensional points
+        """
+        #cdef np.ndarray[double, ndim=2] inner_data
+
+        
+        self.n = n
+        self.numrows = 100
+        #  XXX how to create empty as contiguous w/o copy?
+        self.data = np.empty((self.numrows, self.n), dtype=np.float)
+        #inner_data = self.data
+        #self.raw_data = <double*>inner_data.data
+        self.numpoints = 0
+
+
+    def add(NNBF self, object point):
+        """
+        add a point to the buffer, grow if necessary
+        """
+        #cdef np.ndarray[double, ndim=2] inner_data
+        cdef np.ndarray[double, ndim=1] pp
+        pp = np.asarray(point).astype(np.float)
+
+
+        self.data[self.numpoints] = pp
+        self.numpoints += 1
+        if self.numpoints==self.numrows:
+            ## XXX do I need to do memory management here, eg free
+            ## raw_data if I were using it?
+            self.numrows *= 2
+            newdata = np.empty((self.numrows, self.n), np.float)
+            newdata[:self.numpoints] = self.data
+            self.data = newdata
+            #self.raw_data = <double*>inner_data.data
+
+    def get_data(NNBF self):
+        """
+        return the data added so far as a numpoints x n array
+        """
+        return self.data[:self.numpoints]
+
+
+    def find_neighbors(NNBF self, object point, double radius):
+        """
+        return a list of indices into data which are within radius
+        from point
+        """
+        cdef int i, neighbor
+        cdef double d2max
+        cdef np.ndarray[double, ndim=1] pp
+        cdef np.ndarray[double, ndim=1] row        
+
+        if len(point)!=self.n:
+            raise ValueError('Expected a length %d vector'%self.n)
+        
+        pp = np.asarray(point).astype(np.float)
+
+        d2max = radius*radius
+        neighbors = []
+
+
+        for i in range(self.numpoints):
+            # XXX : is there a more efficient way to access the row
+            # data?  Can/should we be using raw_data here?
+            row = self.data[i]
+            neighbor = is_neighbor(
+                self.n,
+                <double*>row.data,
+                <double*>pp.data,
+                d2max)
+            
+            # if the number of points in the cluster is small, the
+            # python list performance should not kill us
+            if neighbor:
+                neighbors.append(i)
+
+        return neighbors
+
+

Added: trunk/py4science/examples/pyrex/nnbf/setup.py
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/setup.py                              (rev 0)
+++ trunk/py4science/examples/pyrex/nnbf/setup.py       2009-01-08 02:50:18 UTC (rev 6763)
@@ -0,0 +1,29 @@
+
+from distutils.core import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+import numpy
+
+
+nnbf = Extension('nnbf',
+                 ['nnbf.pyx'],
+                 include_dirs = [numpy.get_include()])
+
+
+setup ( name = "nnbf",
+        version = "0.0000",
+        description = "incremental nearest neighbor brute force",
+        author = "John Hunter",
+        author_email = "[email protected]",
+        ext_modules = [nnbf],
+        cmdclass    = {'build_ext': build_ext},
+
+        )
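
A brief build note (an assumption about the intended workflow, not part of the
commit): with Cython and numpy installed, the extension is typically compiled
in place with "python setup.py build_ext --inplace", after which a quick smoke
check looks like:

    import nnbf

    nn = nnbf.NNBF(2)
    nn.add([0.0, 0.0])
    nn.add([1.0, 1.0])
    # only the origin lies within 0.5 of (0.1, 0.1)
    print nn.find_neighbors([0.1, 0.1], 0.5)    # -> [0]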

Added: trunk/py4science/examples/pyrex/nnbf/test_nnbf.py
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/test_nnbf.py                          (rev 0)
+++ trunk/py4science/examples/pyrex/nnbf/test_nnbf.py   2009-01-08 02:50:18 UTC (rev 6763)
@@ -0,0 +1,88 @@
+import time
+import numpy as np
+import nose, nose.tools as nt
+import numpy.testing as nptest
+
+
+# the pure python prototype
+from nnbf_proto import find_neighbors_numpy
+
+
+#import nnbf_proto as nnbf
+
+# the cython extension
+import nnbf
+
+
+
+def jdh_add_data():
+    nn = nnbf.NNBF(6)
+
+    for i in range(202):
+        x = np.random.rand(6)
+        nn.add(x)
+        data = nn.get_data()
+        nptest.assert_equal((x==data[-1]).all(), True)
+        nptest.assert_equal(len(data), i+1)
+
+def test_neighbors():
+    NUMDIM = 4
+    nn = nnbf.NNBF(NUMDIM)
+
+    for i in range(2000):
+        x = np.random.rand(NUMDIM)
+        nn.add(x)
+
+    radius = 0.2
+    x = np.random.rand(NUMDIM)
+    ind = nn.find_neighbors(x, radius)
+    data = nn.get_data()
+
+    indnumpy = find_neighbors_numpy(data, x, radius)
+
+    nptest.assert_equal(np.asarray(ind), indnumpy)
+
+
+
+
+if 1:
+#def test_performance():
+    NUMDIM = 6
+    nn = nnbf.NNBF(NUMDIM)
+
+    print 'loading data... this could take a while'
+    # this could take a while
+    for i in range(200000):
+        x = np.random.rand(NUMDIM)
+        nn.add(x)
+
+    x = np.random.rand(NUMDIM)
+    radius = 0.2
+    data = nn.get_data()
+
+    print 'testing nnbf...'
+    times = np.zeros(10)
+    for i in range(len(times)): 
+        start = time.clock()
+        ind = nn.find_neighbors(x, radius)
+        end = time.clock()
+        times[i] = end-start
+    print '    10 trials: mean=%1.4f, min=%1.4f'%(times.mean(), times.min())
+
+    print 'testing numpy...'
+    for i in range(len(times)):     
+        start = time.clock()
+        ind = find_neighbors_numpy(data, x, radius)
+        end = time.clock() 
+        times[i] = end-start
+    print '    10 trials: mean=%1.4f, min=%1.4f'%(times.mean(), times.min())
+
+
+
+
+
+if __name__=='__main__':
+
+    #nose.runmodule(argv=['-s','--with-doctest'], exit=False)
+    pass
+
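
For reference, a small sketch of driving these tests from Python with nose
(an assumption about the intended workflow; note that the "if 1:" benchmark
block above runs at import time, so the 200000 point load happens first):

    import nose
    # equivalent to "nosetests -s test_nnbf.py" from the shell
    nose.run(argv=['nosetests', '-s', 'test_nnbf.py'])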

