On Fri, Jul 10, 2009 at 3:55 PM, John Schulman<[email protected]> wrote:
> After reading the "Cython for Numpy Users" page, I tried rewriting a
> function in cython, which I had previously implemented in python and
> then scipy.weave.
> It runs just as slow as the python version, which is 100 times slower
> than the C version.
> Maybe someone can tell me what is slowing it down.
> Thanks,
> John
>
> #cython_scripts.py
> from __future__ import division
> import numpy as np
> cimport numpy as np
>
> DTYPE = np.int
> ctypedef np.int_t DTYPE_t
>
> DTYPE2 = np.double
> ctypedef np.double_t DTYPE2_t
>
> cdef inline int int_max(int a, int b): return a if a >= b else b
> cdef inline int int_min(int a, int b): return a if a <= b else b
>
> def firstpass_labels(np.ndarray[DTYPE2_t,ndim=2] arr,list links,int
> s_back,double thresh):
>    cdef int n_s,n_ch
>    n_s = arr.shape[0]; n_ch = arr.shape[1]
>    cdef np.ndarray labels = np.zeros([n_s,n_ch],dtype=DTYPE)
>    cdef DTYPE_t c_label
>    c_label = 1
>    cdef int i_s, i_ch, j_s, j_ch, j_ch_ind, j_sstart
>
>    cdef np.ndarray links_arr = np.zeros([n_ch,n_ch+1],dtype=DTYPE)
>    for (source,targs) in enumerate(links):
>        links_arr[source,0] = len(targs)
>        links_arr[source,1:(len(targs)+1)] = links[source]
>
>    for i_s in range(n_s):
>        for i_ch in range(n_ch):
>            if arr[i_s,i_ch] > thresh:
>                j_sstart = int_max(0,i_s-s_back)
>                for j_s in range(j_sstart,i_s+1):
>                    for j_ch_ind in range(1,links_arr[i_ch][0]+1):
>                        j_ch = links_arr[i_ch,j_ch_ind]
>                        if labels[j_s,j_ch] != 0:
>                            labels[i_s,i_ch] = labels[j_s,j_ch]
>
>                if labels[i_s,i_ch] == 0:
>                    labels[i_s,i_ch] = c_label
>                    c_label += 1
>
>    return labels
>
>
>
> And here's the function I use to run it:
> #test_cy.py
> import time
> import numpy as np
>
> import pyximport; pyximport.install()
> import cython_scripts
>
>
>
> arr = np.random.random( (400000,3))
> s_back = 3
> thresh = .7
> links2 = [[0,1],[0,1,2],[1,2]]
> t = time.time()
> print cython_scripts.firstpass_labels(arr,links2,s_back,thresh)
> print "cython %f"%(time.time()-t)

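I get a 70-80x speedup with the following changes: declaring `labels` and
`links_arr` as typed 2-D buffers, and indexing `links_arr[i_ch, 0]` instead of
`links_arr[i_ch][0]` (Cython only generates fast C indexing when all dimensions
are indexed at once). You can boost it further by turning off bounds checking
and declaring every buffer's mode to be 'c'. Note that disabling bounds
checking is only safe if every index, including the channel numbers taken from
`links`, is known to be in range. See the documentation here: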

http://docs.cython.org/docs/numpy_tutorial.html#tuning-indexing-further

#---------------------------------------------------------------------------
from __future__ import division
import numpy as np
cimport numpy as np

# Element type of the label array: a Python-level dtype object (DTYPE) paired
# with the compile-time C type (DTYPE_t) used in typed buffer declarations.
# np.int was merely an alias for the builtin int (deprecated in NumPy 1.20,
# removed in 1.24); np.int_ is the dtype that np.zeros(..., dtype=int) produced,
# so this keeps the original behaviour on every NumPy version.
# NOTE(review): confirm DTYPE and DTYPE_t still describe the same C type on the
# NumPy version in use (the default-integer rules changed in NumPy 2.x).
DTYPE = np.int_
ctypedef np.int_t DTYPE_t

# Element type of the input signal array: float64 in both worlds.
DTYPE2 = np.double
ctypedef np.double_t DTYPE2_t

# C-level (cdef inline) max/min of two C ints; not callable from Python.
cdef inline int int_max(int a, int b): return a if a >= b else b
cdef inline int int_min(int a, int b): return a if a <= b else b

def firstpass_labels(np.ndarray[DTYPE2_t, ndim=2] arr, list links,
                     int s_back, double thresh):
    """Give every above-threshold entry of ``arr`` a provisional label.

    Rows of ``arr`` are indexed by ``i_s`` and columns (channels) by
    ``i_ch``.  Each entry strictly greater than ``thresh`` copies a nonzero
    label from one of its linked channels in the current row or in any of
    the previous ``s_back`` rows.  If several candidates are labelled, the
    one visited last (latest row, then last entry of ``links[i_ch]``) wins.
    If none is labelled, the entry gets the next fresh label, counting up
    from 1.

    Parameters
    ----------
    arr : 2-D ``DTYPE2`` (float64) array of shape ``(n_s, n_ch)``.
    links : list with at most ``n_ch`` entries; ``links[c]`` is a sequence
        of channel indices whose labels channel ``c`` may inherit.
    s_back : number of previous rows searched for a label.
    thresh : entries strictly greater than this value are labelled.

    Returns
    -------
    ndarray of ``DTYPE`` with shape ``(n_s, n_ch)``; 0 marks unlabelled
    entries.

    Raises
    ------
    IndexError
        If ``links`` has more than ``n_ch`` entries, or a link names a
        channel outside ``range(n_ch)``.
    """
    cdef Py_ssize_t n_s = arr.shape[0]
    cdef Py_ssize_t n_ch = arr.shape[1]
    # Py_ssize_t indices keep every loop a plain C loop and, unlike C int,
    # cannot overflow on arrays with more than 2**31 rows.
    cdef Py_ssize_t i_s, i_ch, j_s, j_ch, j_ch_ind, j_sstart, n_links
    cdef DTYPE_t c_label = 1

    # Typed buffers (dtype + ndim) are what make element access compile to
    # direct C indexing.  Both arrays below are freshly allocated by np.zeros,
    # so declaring them C-contiguous (mode='c') is always valid.
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] labels = np.zeros(
        (n_s, n_ch), dtype=DTYPE)

    if len(links) > n_ch:
        raise IndexError(
            "links has %d entries but arr has only %d channels"
            % (len(links), n_ch))

    # Pack the ragged ``links`` list into a rectangular table: column 0 holds
    # the number of targets, columns 1..count hold the target channels.
    cdef np.ndarray[DTYPE_t, ndim=2, mode='c'] links_arr = np.zeros(
        (n_ch, n_ch + 1), dtype=DTYPE)
    for source, targs in enumerate(links):
        links_arr[source, 0] = len(targs)
        links_arr[source, 1:len(targs) + 1] = targs

    # Reject out-of-range channel numbers up front; a negative one would
    # otherwise silently wrap around to the last channels when indexing.
    for i_ch in range(n_ch):
        for j_ch_ind in range(1, links_arr[i_ch, 0] + 1):
            j_ch = links_arr[i_ch, j_ch_ind]
            if j_ch < 0 or j_ch >= n_ch:
                raise IndexError(
                    "links[%d] refers to channel %d, but arr has only "
                    "%d channels" % (i_ch, j_ch, n_ch))

    for i_s in range(n_s):
        # Earliest row that may donate a label: max(0, i_s - s_back).
        j_sstart = i_s - s_back
        if j_sstart < 0:
            j_sstart = 0
        for i_ch in range(n_ch):
            if arr[i_s, i_ch] > thresh:
                n_links = links_arr[i_ch, 0]
                for j_s in range(j_sstart, i_s + 1):
                    for j_ch_ind in range(1, n_links + 1):
                        j_ch = links_arr[i_ch, j_ch_ind]
                        if labels[j_s, j_ch] != 0:
                            labels[i_s, i_ch] = labels[j_s, j_ch]

                # No labelled neighbour was found: start a new label.
                if labels[i_s, i_ch] == 0:
                    labels[i_s, i_ch] = c_label
                    c_label += 1

    return labels
#-------------------------------------------------------------------------------------------------------------------
_______________________________________________
Cython-dev mailing list
[email protected]
http://codespeak.net/mailman/listinfo/cython-dev

Reply via email to