Hi all, here's my first patch for matplotlib. Someone pointed out on Stack
Overflow that the plot_surface function in mplot3d is slow when plotting a
lot of points (with small rstrides/cstrides) using shading and a single
color. I found some parts of the code that weren't vectorized. These are my
changes so far.

Summary of changes:
1. Changed the double loop to iterate over xrange instead of np.arange.
2. Made the normalization of the normals and their dot product with the
vector [-1,-1,0.5] (used to find the shading) a vectorized operation.
3. Replaced a list comprehension that calculated the colors one at a time
with the built-in vectorization of the Normalize class and the np.outer
function. The result is a numpy array rather than a list, which actually
speeds things up down the line.
4. Removed the corners array from plot_surface, which was never used or
returned. It didn't really slow things down, but I think it is cruft.
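
To illustrate change 3 outside of matplotlib, here is a minimal standalone
sketch comparing the per-face list comprehension with a single np.outer call.
The normalize helper is only a stand-in I wrote for what the Normalize
instance does in _shade_colors (linear scaling between a min and a max), and
the color and shade values are made up for the example.

```python
import numpy as np

def normalize(values, vmin, vmax):
    # Hypothetical stand-in for matplotlib's Normalize: linear map to [0, 1].
    return (np.asarray(values, dtype=float) - vmin) / (vmax - vmin)

def colors_loop(color, shade, vmin, vmax):
    # Old approach: scale the RGBA color once per face in a Python loop.
    return [color * (0.5 + normalize(v, vmin, vmax) * 0.5) for v in shade]

def colors_outer(color, shade, vmin, vmax):
    # New approach: one outer product builds the whole (N, 4) array.
    return np.outer(0.5 + normalize(shade, vmin, vmax) * 0.5, color)

# Made-up inputs: one RGBA color and four shade values.
color = np.array([0.2, 0.4, 0.6, 1.0])
shade = np.array([-0.9, -0.1, 0.3, 1.2])
vmin, vmax = shade.min(), shade.max()

loop_colors = np.array(colors_loop(color, shade, vmin, vmax))
outer_colors = colors_outer(color, shade, vmin, vmax)
print(np.allclose(loop_colors, outer_colors))
```

Both versions give one row of RGBA values per face. The np.outer version
just does it without a Python-level loop and returns an array directly.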

For change number two, I made a separate function that generates the shades,
but feel free to move it around if you prefer. Maybe it should be a function
whose name begins with a _, since it shouldn't be used externally. These
changes give varying levels of speed improvement depending on the number of
points and the rstride/cstride arguments. With larger numbers of points
and small rstrides/cstrides, these changes can more than halve the running
time. I have found no difference in output after my changes.
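
For anyone who wants to check the equivalence for change 2 without applying
the patch, here is a rough standalone sketch. The function names and the
random test normals are just for illustration, and the explicit square root
in the loop version is my stand-in for proj3d.mod.

```python
import numpy as np

# Direction the normals are dotted with, as in the patch.
LIGHT = np.array([-1.0, -1.0, 0.5])

def shades_loop(normals):
    # Old approach: normalize and dot each normal one at a time.
    shade = []
    for n in normals:
        n = n / np.sqrt(np.dot(n, n))  # stand-in for proj3d.mod(n)
        shade.append(np.dot(n, LIGHT))
    return np.array(shade)

def shades_vectorized(normals):
    # New approach: normalize every row at once, then take a single dot product.
    normals = np.asarray(normals, dtype=float)
    lengths = np.sqrt((normals ** 2).sum(axis=1))
    return np.dot(normals / lengths[:, np.newaxis], LIGHT)

# Made-up data: 1000 random normal vectors.
rng = np.random.RandomState(0)
normals = rng.standard_normal((1000, 3))

print(np.allclose(shades_loop(normals), shades_vectorized(normals)))
```

The [:, np.newaxis] turns the row lengths into a column so that each normal
is divided by its own length through broadcasting.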

I know there is more work to be done within the plot_surface function and
I'll submit more changes soon.

Justin
Index: lib/mpl_toolkits/mplot3d/axes3d.py
===================================================================
--- lib/mpl_toolkits/mplot3d/axes3d.py	(revision 8802)
+++ lib/mpl_toolkits/mplot3d/axes3d.py	(working copy)
@@ -722,10 +722,9 @@
         normals = []
         #colset contains the data for coloring: either average z or the facecolor
         colset = []
-        for rs in np.arange(0, rows-1, rstride):
-            for cs in np.arange(0, cols-1, cstride):
+        for rs in xrange(0, rows-1, rstride):
+            for cs in xrange(0, cols-1, cstride):
                 ps = []
-                corners = []
                 for a, ta in [(X, tX), (Y, tY), (Z, tZ)]:
                     ztop = a[rs][cs:min(cols, cs+cstride+1)]
                     zleft = ta[min(cols-1, cs+cstride)][rs:min(rows, rs+rstride+1)]
@@ -733,7 +732,6 @@
                     zbase = zbase[::-1]
                     zright = ta[cs][rs:min(rows, rs+rstride+1):]
                     zright = zright[::-1]
-                    corners.append([ztop[0], ztop[-1], zbase[0], zbase[-1]])
                     z = np.concatenate((ztop, zleft, zbase, zright))
                     ps.append(z)
 
@@ -802,18 +800,21 @@
             normals.append(np.cross(v1, v2))
         return normals
 
+    def getshades(self, normals):
+        '''
+        Find the normalized vectors of the normals and dot them with
+        the vector [-1,-1,0.5] to get the proper shadings.
+        '''
+        return np.dot(normals/np.sqrt((normals**2).sum(axis=1))[:,np.newaxis],
+               np.array([[-1],[-1],[0.5]])).squeeze()
+
     def _shade_colors(self, color, normals):
         '''
         Shade *color* using normal vectors given by *normals*.
         *color* can also be an array of the same length as *normals*.
         '''
 
-        shade = []
-        for n in normals:
-            n = n / proj3d.mod(n)
-            shade.append(np.dot(n, [-1, -1, 0.5]))
-
-        shade = np.array(shade)
+        shade = self.getshades(np.array(normals))
         mask = ~np.isnan(shade)
 
         if len(shade[mask]) > 0:
@@ -821,7 +822,7 @@
             if art3d.iscolor(color):
                 color = color.copy()
                 color[3] = 1
-                colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
+                colors = np.outer(0.5 + norm(shade) * 0.5, color)
             else:
                 colors = [np.array(colorConverter.to_rgba(c)) * \
                             (0.5 + norm(v) * 0.5) \
_______________________________________________
Matplotlib-devel mailing list
Matplotlib-devel@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/matplotlib-devel
