Revision: 8015
          http://matplotlib.svn.sourceforge.net/matplotlib/?rev=8015&view=rev
Author:   heeres
Date:     2009-12-10 00:03:03 +0000 (Thu, 10 Dec 2009)

Log Message:
-----------
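mplot3d updates:
* Fix scatter markers (a converted PatchCollection now keeps the draw method of its original class)
* Add facecolors support (plus norm, vmin, vmax and shade options) to plot_surface
* Fix the drawing order of the X/Y/Z panes: all panes are now drawn before any axis
* Add examples (axes rotation and wireframe animations, colored surface)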

Modified Paths:
--------------
    trunk/matplotlib/examples/mplot3d/bars3d_demo.py
    trunk/matplotlib/examples/mplot3d/hist3d_demo.py
    trunk/matplotlib/examples/mplot3d/scatter3d_demo.py
    trunk/matplotlib/examples/mplot3d/surface3d_demo.py
    trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py
    trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py
    trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py

Added Paths:
-----------
    trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py
    trunk/matplotlib/examples/mplot3d/surface3d_demo3.py
    trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py

Modified: trunk/matplotlib/examples/mplot3d/bars3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/bars3d_demo.py    2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/bars3d_demo.py    2009-12-10 00:03:03 UTC (rev 8015)
@@ -7,8 +7,13 @@
 for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]):
     xs = np.arange(20)
     ys = np.random.rand(20)
-    ax.bar(xs, ys, zs=z, zdir='y', color=c, alpha=0.8)
 
+    # You can provide either a single color or an array. To demonstrate this,
+    # the first bar of each set will be colored cyan.
+    cs = [c] * len(xs)
+    cs[0] = 'c'
+    ax.bar(xs, ys, zs=z, zdir='y', color=cs, alpha=0.8)
+
 ax.set_xlabel('X')
 ax.set_ylabel('Y')
 ax.set_zlabel('Z')

Modified: trunk/matplotlib/examples/mplot3d/hist3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/hist3d_demo.py    2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/hist3d_demo.py    2009-12-10 00:03:03 UTC (rev 8015)
@@ -16,6 +16,7 @@
 dx = 0.5 * np.ones_like(zpos)
 dy = dx.copy()
 dz = hist.flatten()
+
 ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b')
 
 plt.show()

Added: trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py    (rev 0)
+++ trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py    2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,15 @@
+from mpl_toolkits.mplot3d import axes3d
+import matplotlib.pyplot as plt
+import numpy as np
+
+plt.ion()
+
+fig = plt.figure()
+ax = axes3d.Axes3D(fig)
+X, Y, Z = axes3d.get_test_data(0.1)
+ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5)
+
+for angle in range(0, 360):
+    ax.view_init(30, angle)
+    plt.draw()
+

Modified: trunk/matplotlib/examples/mplot3d/scatter3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -2,18 +2,17 @@
 from mpl_toolkits.mplot3d import Axes3D
 import matplotlib.pyplot as plt
 
-
 def randrange(n, vmin, vmax):
     return (vmax-vmin)*np.random.rand(n) + vmin
 
 fig = plt.figure()
 ax = Axes3D(fig)
 n = 100
-for c, zl, zh in [('r', -50, -25), ('b', -30, -5)]:
+for c, m, zl, zh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]:
     xs = randrange(n, 23, 32)
     ys = randrange(n, 0, 100)
     zs = randrange(n, zl, zh)
-    ax.scatter(xs, ys, zs, c=c)
+    ax.scatter(xs, ys, zs, c=c, marker=m)
 
 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')

Modified: trunk/matplotlib/examples/mplot3d/surface3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -1,5 +1,6 @@
 from mpl_toolkits.mplot3d import Axes3D
 from matplotlib import cm
+from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
 import matplotlib.pyplot as plt
 import numpy as np
 
@@ -10,7 +11,14 @@
 X, Y = np.meshgrid(X, Y)
 R = np.sqrt(X**2 + Y**2)
 Z = np.sin(R)
-ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
+surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet,
+        linewidth=0, antialiased=False)
+ax.set_zlim3d(-1.01, 1.01)
 
+ax.w_zaxis.set_major_locator(LinearLocator(10))
+ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
+
+fig.colorbar(surf, shrink=0.5, aspect=5)
+
 plt.show()
 

Added: trunk/matplotlib/examples/mplot3d/surface3d_demo3.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/surface3d_demo3.py    (rev 0)
+++ trunk/matplotlib/examples/mplot3d/surface3d_demo3.py    2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,31 @@
+from mpl_toolkits.mplot3d import Axes3D
+from matplotlib import cm
+from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
+import matplotlib.pyplot as plt
+import numpy as np
+
+fig = plt.figure()
+ax = Axes3D(fig)
+X = np.arange(-5, 5, 0.25)
+xlen = len(X)
+Y = np.arange(-5, 5, 0.25)
+ylen = len(Y)
+X, Y = np.meshgrid(X, Y)
+R = np.sqrt(X**2 + Y**2)
+Z = np.sin(R)
+
+colortuple = ('y', 'b')
+colors = np.empty(X.shape, dtype=str)
+for y in range(ylen):
+    for x in range(xlen):
+        colors[x, y] = colortuple[(x + y) % len(colortuple)]
+
+surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
+        linewidth=0, antialiased=False)
+
+ax.set_zlim3d(-1.01, 1.01)
+ax.w_zaxis.set_major_locator(LinearLocator(10))
+ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
+
+plt.show()
+

Added: trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py    (rev 0)
+++ trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py    2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,34 @@
+from mpl_toolkits.mplot3d import axes3d
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+
+def generate(X, Y, phi):
+    R = 1 - np.sqrt(X**2 + Y**2)
+    return np.cos(2 * np.pi * X + phi) * R
+
+plt.ion()
+fig = plt.figure()
+ax = axes3d.Axes3D(fig)
+
+xs = np.linspace(-1, 1, 50)
+ys = np.linspace(-1, 1, 50)
+X, Y = np.meshgrid(xs, ys)
+Z = generate(X, Y, 0.0)
+
+wframe = None
+tstart = time.time()
+for phi in np.linspace(0, 360 / 2 / np.pi, 100):
+
+    oldcol = wframe
+
+    Z = generate(X, Y, phi)
+    wframe = ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2)
+
+    # Remove old line collection before drawing
+    if oldcol is not None:
+        ax.collections.remove(oldcol)
+
+    plt.draw()
+
+print 'FPS: %f' % (100 / (time.time() - tstart))

Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py  2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py  2009-12-10 00:03:03 UTC (rev 8015)
@@ -274,6 +274,7 @@
 
     def __init__(self, *args, **kwargs):
         PatchCollection.__init__(self, *args, **kwargs)
+        self._old_draw = lambda x: PatchCollection.draw(self, x)
 
     def set_3d_properties(self, zs, zdir):
         xs, ys = zip(*self.get_offsets())
@@ -293,10 +294,15 @@
         return min(vzs)
 
     def draw(self, renderer):
-        PatchCollection.draw(self, renderer)
+        self._old_draw(renderer)
 
 def patch_collection_2d_to_3d(col, zs=0, zdir='z'):
     """Convert a PatchCollection to a Patch3DCollection object."""
+
+    # The tricky part here is that there are several classes that are
+    # derived from PatchCollection. We need to use the right draw method.
+    col._old_draw = col.draw
+
     col.__class__ = Patch3DCollection
     col.set_3d_properties(zs, zdir)
 

Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -13,7 +13,7 @@
 from matplotlib.transforms import Bbox
 from matplotlib import collections
 import numpy as np
-from matplotlib.colors import Normalize, colorConverter
+from matplotlib.colors import Normalize, colorConverter, LightSource
 
 import art3d
 import proj3d
@@ -37,6 +37,21 @@
     """
 
     def __init__(self, fig, rect=None, *args, **kwargs):
+        '''
+        Build an :class:`Axes3D` instance in
+        :class:`~matplotlib.figure.Figure` *fig* with
+        *rect=[left, bottom, width, height]* in
+        :class:`~matplotlib.figure.Figure` coordinates
+
+        Optional keyword arguments:
+
+          ================   =========================================
+          Keyword            Description
+          ================   =========================================
+          *azim*             Azimuthal viewing angle (default -60)
+          *elev*             Elevation viewing angle (default 30)
+        '''
+
         if rect is None:
             rect = [0.0, 0.0, 1.0, 1.0]
         self.fig = fig
@@ -146,9 +161,12 @@
         for i, (z, patch) in enumerate(zlist):
             patch.zorder = i
 
-        self.w_xaxis.draw(renderer)
-        self.w_yaxis.draw(renderer)
-        self.w_zaxis.draw(renderer)
+        axes = (self.w_xaxis, self.w_yaxis, self.w_zaxis)
+        for ax in axes:
+            ax.draw_pane(renderer)
+        for ax in axes:
+            ax.draw(renderer)
+
         Axes.draw(self, renderer)
 
     def get_axis_position(self):
@@ -322,8 +340,9 @@
         self.grid(rcParams['axes3d.grid'])
 
     def _button_press(self, event):
-        self.button_pressed = event.button
-        self.sx, self.sy = event.xdata, event.ydata
+        if event.inaxes == self:
+            self.button_pressed = event.button
+            self.sx, self.sy = event.xdata, event.ydata
 
     def _button_release(self, event):
         self.button_pressed = None
@@ -565,6 +584,12 @@
         *cstride*   Array column stride (step size)
         *color*     Color of the surface patches
         *cmap*      A colormap for the surface patches.
+        *facecolors* Face colors for the individual patches
+        *norm*      An instance of Normalize to map values to colors
+        *vmin*      Minimum value to map
+        *vmax*      Maximum value to map
+        *shade*     Whether to shade the facecolors, default:
+                    false when cmap specified, true otherwise
         ==========  ================================================
         '''
 
@@ -575,13 +600,28 @@
         rstride = kwargs.pop('rstride', 10)
         cstride = kwargs.pop('cstride', 10)
 
-        color = kwargs.pop('color', 'b')
-        color = np.array(colorConverter.to_rgba(color))
+        if 'facecolors' in kwargs:
+            fcolors = kwargs.pop('facecolors')
+        else:
+            color = np.array(colorConverter.to_rgba(kwargs.pop('color', 'b')))
+            fcolors = None
+
         cmap = kwargs.get('cmap', None)
+        norm = kwargs.pop('norm', None)
+        vmin = kwargs.pop('vmin', None)
+        vmax = kwargs.pop('vmax', None)
+        linewidth = kwargs.get('linewidth', None)
+        shade = kwargs.pop('shade', cmap is None)
+        lightsource = kwargs.pop('lightsource', None)
 
+        # Shade the data
+        if shade and cmap is not None and fcolors is not None:
+            fcolors = self._shade_colors_lightsource(Z, cmap, lightsource)
+
         polys = []
         normals = []
-        avgz = []
+        #colset contains the data for coloring: either average z or the facecolor
+        colset = []
         for rs in np.arange(0, rows-1, rstride):
             for cs in np.arange(0, cols-1, cstride):
                 ps = []
@@ -609,19 +649,38 @@
                         lastp = p
                         avgzsum += p[2]
                 polys.append(ps2)
-                avgz.append(avgzsum / len(ps2))
 
-                v1 = np.array(ps2[0]) - np.array(ps2[1])
-                v2 = np.array(ps2[2]) - np.array(ps2[0])
-                normals.append(np.cross(v1, v2))
+                if fcolors is not None:
+                    colset.append(fcolors[rs][cs])
+                else:
+                    colset.append(avgzsum / len(ps2))
 
+                # Only need vectors to shade if no cmap
+                if cmap is None and shade:
+                    v1 = np.array(ps2[0]) - np.array(ps2[1])
+                    v2 = np.array(ps2[2]) - np.array(ps2[0])
+                    normals.append(np.cross(v1, v2))
+
         polyc = art3d.Poly3DCollection(polys, *args, **kwargs)
-        if cmap is not None:
-            polyc.set_array(np.array(avgz))
-            polyc.set_linewidth(0)
+
+        if fcolors is not None:
+            if shade:
+                colset = self._shade_colors(colset, normals)
+            polyc.set_facecolors(colset)
+            polyc.set_edgecolors(colset)
+        elif cmap:
+            colset = np.array(colset)
+            polyc.set_array(colset)
+            if vmin is not None or vmax is not None:
+                polyc.set_clim(vmin, vmax)
+            if norm is not None:
+                polyc.set_norm(norm)
         else:
-            colors = self._shade_colors(color, normals)
-            polyc.set_facecolors(colors)
+            if shade:
+                colset = self._shade_colors(color, normals)
+            else:
+                colset = color
+            polyc.set_facecolors(colset)
 
         self.add_collection(polyc)
         self.auto_scale_xyz(X, Y, Z, had_data)
@@ -643,24 +702,39 @@
         return normals
 
     def _shade_colors(self, color, normals):
+        '''
+        Shade *color* using normal vectors given by *normals*.
+        *color* can also be an array of the same length as *normals*.
+        '''
+
         shade = []
         for n in normals:
-            n = n / proj3d.mod(n) * 5
+            n = n / proj3d.mod(n)
             shade.append(np.dot(n, [-1, -1, 0.5]))
 
         shade = np.array(shade)
         mask = ~np.isnan(shade)
 
        if len(shade[mask]) > 0:
-           norm = Normalize(min(shade[mask]), max(shade[mask]))
-           color = color.copy()
-           color[3] = 1
-           colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
+            norm = Normalize(min(shade[mask]), max(shade[mask]))
+            if art3d.iscolor(color):
+                color = color.copy()
+                color[3] = 1
+                colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
+            else:
+                colors = [np.array(colorConverter.to_rgba(c)) * \
+                            (0.5 + norm(v) * 0.5) \
+                            for c, v in zip(color, shade)]
         else:
-           colors = color.copy()
+            colors = color.copy()
 
         return colors
 
+    def _shade_colors_lightsource(self, data, cmap, lightsource):
+        if lightsource is None:
+            lightsource = LightSource(azdeg=135, altdeg=55)
+        return lightsource.shade(data, cmap)
+
     def plot_wireframe(self, X, Y, Z, *args, **kwargs):
         '''
         Plot a 3D wireframe.

Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -75,7 +75,7 @@
         maxis.XAxis.__init__(self, axes, *args, **kwargs)
         self.line = mlines.Line2D(xdata=(0, 0), ydata=(0, 0),
                                  linewidth=0.75,
-                                 color=(0,0, 0,0),
+                                 color=(0, 0, 0, 1),
                                  antialiased=True,
                            )
 
@@ -100,8 +100,8 @@
         majorLabels = [self.major.formatter(val, i) for i, val in enumerate(majorLocs)]
         return majorLabels, majorLocs
 
-    def get_major_ticks(self):
-        ticks = maxis.XAxis.get_major_ticks(self)
+    def get_major_ticks(self, numticks=None):
+        ticks = maxis.XAxis.get_major_ticks(self, numticks)
         for t in ticks:
             t.tick1line.set_transform(self.axes.transData)
             t.tick2line.set_transform(self.axes.transData)
@@ -132,23 +132,7 @@
         else:
             return len(text) > 4
 
-    def draw(self, renderer):
-        self.label._transform = self.axes.transData
-        renderer.open_group('axis3d')
-
-        # code from XAxis
-        majorTicks = self.get_major_ticks()
-        majorLocs = self.major.locator()
-
-        # filter locations here so that no extra grid lines are drawn
-        interval = self.get_view_interval()
-        majorLocs = [loc for loc in majorLocs if \
-                interval[0] < loc < interval[1]]
-        self.major.formatter.set_locs(majorLocs)
-        majorLabels = [self.major.formatter(val, i)
-                       for i, val in enumerate(majorLocs)]
-
-        # Determine bounds
+    def _get_coord_info(self, renderer):
         minx, maxx, miny, maxy, minz, maxz = self.axes.get_w_lims()
         mins = np.array((minx, miny, minz))
         maxs = np.array((maxx, maxy, maxz))
@@ -157,15 +141,19 @@
         mins = mins - deltas / 4.
         maxs = maxs + deltas / 4.
 
-        # Determine which planes should be visible by the avg z value
         vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
         tc = self.axes.tunit_cube(vals, renderer.M)
-        #raise RuntimeError('WTF: p1=%s'%p1)
         avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2] for \
                 p1, p2, p3, p4 in self._PLANES]
         highs = np.array([avgz[2*i] < avgz[2*i+1] for i in range(3)])
 
-        # Draw plane
+        return mins, maxs, centers, deltas, tc, highs
+
+    def draw_pane(self, renderer):
+        renderer.open_group('pane3d')
+
+        mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
+
         info = self._AXINFO[self.adir]
         index = info['i']
         if not highs[index]:
@@ -176,6 +164,29 @@
         self.set_pane(xys, info['color'])
         self.pane.draw(renderer)
 
+        renderer.close_group('pane3d')
+
+    def draw(self, renderer):
+        self.label._transform = self.axes.transData
+        renderer.open_group('axis3d')
+
+        # code from XAxis
+        majorTicks = self.get_major_ticks()
+        majorLocs = self.major.locator()
+
+        info = self._AXINFO[self.adir]
+        index = info['i']
+
+        # filter locations here so that no extra grid lines are drawn
+        interval = self.get_view_interval()
+        majorLocs = [loc for loc in majorLocs if \
+                interval[0] < loc < interval[1]]
+        self.major.formatter.set_locs(majorLocs)
+        majorLabels = [self.major.formatter(val, i)
+                       for i, val in enumerate(majorLocs)]
+
+        mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
+
         # Determine grid lines
         minmax = np.where(highs, maxs, mins)
 


This was sent by the SourceForge.net collaborative development platform, the 
world's largest Open Source development site.

------------------------------------------------------------------------------
Return on Information:
Google Enterprise Search pays you back
Get the facts.
http://p.sf.net/sfu/google-dev2dev
_______________________________________________
Matplotlib-checkins mailing list
Matplotlib-checkins@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/matplotlib-checkins

Reply via email to