Revision: 6766
          http://matplotlib.svn.sourceforge.net/matplotlib/?rev=6766&view=rev
Author:   jdh2358
Date:     2009-01-08 19:42:15 +0000 (Thu, 08 Jan 2009)

Log Message:
-----------
enhanced tests and moved numpy method into pyx module

Modified Paths:
--------------
    trunk/py4science/examples/pyrex/nnbf/nnbf.pyx
    trunk/py4science/examples/pyrex/nnbf/test_nnbf.py

Modified: trunk/py4science/examples/pyrex/nnbf/nnbf.pyx
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/nnbf.pyx       2009-01-08 19:17:15 UTC 
(rev 6765)
+++ trunk/py4science/examples/pyrex/nnbf/nnbf.pyx       2009-01-08 19:42:15 UTC 
(rev 6766)
@@ -38,8 +38,8 @@
 
 
         self.n = n
-        self.numrows = 100
-        #  XXX how to create mepty as contiguous w/o copy?
+        self.numrows = 10000
+        #  XXX how to create empty as contiguous w/o copy?
         data = np.empty((self.numrows, self.n), dtype=np.float)
         self.data = np.ascontiguousarray(data, dtype=np.float)
         inner_data = self.data
@@ -53,7 +53,7 @@
         """
         cdef np.ndarray[double, ndim=2] inner_data
         cdef np.ndarray[double, ndim=1] pp
-        pp = np.asarray(point).astype(np.float)
+        pp = np.array(point).astype(np.float)
 
 
         self.data[self.numpoints] = pp
@@ -67,8 +67,8 @@
             self.data = np.ascontiguousarray(newdata, dtype=np.float)
             inner_data = self.data
             self.raw_data = <double*>inner_data.data
-            #self.raw_data = <double*>inner_data.data
 
+
     def get_data(NNBF self):
         """
         return a copy of data added so far as a numpoints x n array
@@ -99,16 +99,11 @@
 
         # don't do a python lookup inside the loop
         n = self.n
-        
+
         for i in range(self.numpoints):
-            # XXX : is there a more efficient way to access the row
-            # data?  Can/should we be using raw_data here?
-            #row = self.data[i]
             neighbor = is_neighbor(
                 n,
-                #(<double*>self.data.data)+i*n,
                 self.raw_data + i*n,
-                #dataptr + i*n,
                 <double*>pp.data,
                 d2max)
 
@@ -119,4 +114,19 @@
 
         return neighbors
 
+    def find_neighbors_numpy(self, point, radius):
+        """
+        do a plain ol numpy lookup to compare performance and output
 
+          *data* is the numpoints x numdims array from self.get_data()
+          *point* is a numdims length vector
+          radius is the max distance
+
+        return an array of indices into data which are within radius
+        """
+        data = self.get_data()
+        distance = data - point
+        r = np.sqrt((distance*distance).sum(axis=1))
+        return np.nonzero(r<=radius)[0]
+
+

Modified: trunk/py4science/examples/pyrex/nnbf/test_nnbf.py
===================================================================
--- trunk/py4science/examples/pyrex/nnbf/test_nnbf.py   2009-01-08 19:17:15 UTC 
(rev 6765)
+++ trunk/py4science/examples/pyrex/nnbf/test_nnbf.py   2009-01-08 19:42:15 UTC 
(rev 6766)
@@ -4,18 +4,10 @@
 import numpy.testing as nptest
 
 
-from nnbf_proto import find_neighbors_numpy
-# the pure python prototype
-
-
-#import nnbf_proto as nnbf
-
-# the cython extension
 import nnbf
 
 
-
-def jdh_add_data():
+def test_add_data():
     nn = nnbf.NNBF(6)
 
     for i in range(202):
@@ -38,14 +30,13 @@
     ind = nn.find_neighbors(x, radius)
     data = nn.get_data()
 
-    indnumpy = find_neighbors_numpy(data, x, radius)
+    indnumpy = nn.find_neighbors_numpy(x, radius)
 
     nptest.assert_equal((ind==indnumpy), True)
 
 
 
-if 1:
-#def test_performance():
+def test_performance():
     NUMDIM = 6
     nn = nnbf.NNBF(NUMDIM)
 
@@ -61,27 +52,32 @@
 
     print 'testing nnbf...'
     times = np.zeros(10)
-    for i in range(len(times)): 
+    for i in range(len(times)):
         start = time.clock()
         ind = nn.find_neighbors(x, radius)
         end = time.clock()
         times[i] = end-start
-    print '    10 trials: mean=%1.4f, min=%1.4f'%(times.mean(), times.min())
 
+    munn = times.mean()
+    print '    10 trials: mean=%1.4f, min=%1.4f'%(munn, times.min())
+
     print 'testing numpy...'
-    for i in range(len(times)):     
+    for i in range(len(times)):
         start = time.clock()
-        ind = find_neighbors_numpy(data, x, radius)
-        end = time.clock() 
+        ind = nn.find_neighbors_numpy(x, radius)
+        end = time.clock()
         times[i] = end-start
-    print '    10 trials: mean=%1.4f, min=%1.4f'%(times.mean(), times.min())   
     
+    munumpy = times.mean()
+    print '    10 trials: mean=%1.4f, min=%1.4f'%(munumpy, times.min())
 
+    # nn should be at least 3 times faster
+    nptest.assert_equal((3*munn < munumpy), True)
 
 
 
 
 if __name__=='__main__':
 
-    #nose.runmodule(argv=['-s','--with-doctest'], exit=False)
+    nose.runmodule(argv=['-s','--with-doctest'], exit=False)
     pass
 


This was sent by the SourceForge.net collaborative development platform, the 
world's largest Open Source development site.

------------------------------------------------------------------------------
Check out the new SourceForge.net Marketplace.
It is the best place to buy or sell services for
just about anything Open Source.
http://p.sf.net/sfu/Xq1LFB
_______________________________________________
Matplotlib-checkins mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/matplotlib-checkins

Reply via email to