Revision: 8015 http://matplotlib.svn.sourceforge.net/matplotlib/?rev=8015&view=rev Author: heeres Date: 2009-12-10 00:03:03 +0000 (Thu, 10 Dec 2009)
Log Message: ----------- mplot3d updates: * Fix scatter markers * Add facecolor support for plot_surface * Fix drawing order of the XYZ panes * Add examples (animations, colored surface) Modified Paths: -------------- trunk/matplotlib/examples/mplot3d/bars3d_demo.py trunk/matplotlib/examples/mplot3d/hist3d_demo.py trunk/matplotlib/examples/mplot3d/scatter3d_demo.py trunk/matplotlib/examples/mplot3d/surface3d_demo.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py Added Paths: ----------- trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py trunk/matplotlib/examples/mplot3d/surface3d_demo3.py trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py Modified: trunk/matplotlib/examples/mplot3d/bars3d_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/bars3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/examples/mplot3d/bars3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -7,8 +7,13 @@ for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]): xs = np.arange(20) ys = np.random.rand(20) - ax.bar(xs, ys, zs=z, zdir='y', color=c, alpha=0.8) + # You can provide either a single color or an array. To demonstrate this, + # the first bar of each set will be colored cyan. 
+ cs = [c] * len(xs) + cs[0] = 'c' + ax.bar(xs, ys, zs=z, zdir='y', color=cs, alpha=0.8) + ax.set_xlabel('X') ax.set_ylabel('Y') ax.set_zlabel('Z') Modified: trunk/matplotlib/examples/mplot3d/hist3d_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/hist3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/examples/mplot3d/hist3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -16,6 +16,7 @@ dx = 0.5 * np.ones_like(zpos) dy = dx.copy() dz = hist.flatten() + ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b') plt.show() Added: trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py (rev 0) +++ trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -0,0 +1,15 @@ +from mpl_toolkits.mplot3d import axes3d +import matplotlib.pyplot as plt +import numpy as np + +plt.ion() + +fig = plt.figure() +ax = axes3d.Axes3D(fig) +X, Y, Z = axes3d.get_test_data(0.1) +ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5) + +for angle in range(0, 360): + ax.view_init(30, angle) + plt.draw() + Modified: trunk/matplotlib/examples/mplot3d/scatter3d_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -2,18 +2,17 @@ from mpl_toolkits.mplot3d import Axes3D import matplotlib.pyplot as plt - def randrange(n, vmin, vmax): return (vmax-vmin)*np.random.rand(n) + vmin fig = plt.figure() ax = Axes3D(fig) n = 100 -for c, zl, zh in [('r', -50, -25), ('b', -30, -5)]: +for c, m, zl, zh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]: xs = randrange(n, 23, 32) ys = randrange(n, 0, 100) zs = randrange(n, zl, zh) - ax.scatter(xs, ys, zs, c=c) + ax.scatter(xs, ys, zs, c=c, 
marker=m) ax.set_xlabel('X Label') ax.set_ylabel('Y Label') Modified: trunk/matplotlib/examples/mplot3d/surface3d_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -1,5 +1,6 @@ from mpl_toolkits.mplot3d import Axes3D from matplotlib import cm +from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter import matplotlib.pyplot as plt import numpy as np @@ -10,7 +11,14 @@ X, Y = np.meshgrid(X, Y) R = np.sqrt(X**2 + Y**2) Z = np.sin(R) -ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet) +surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet, + linewidth=0, antialiased=False) +ax.set_zlim3d(-1.01, 1.01) +ax.w_zaxis.set_major_locator(LinearLocator(10)) +ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f')) + +fig.colorbar(surf, shrink=0.5, aspect=5) + plt.show() Added: trunk/matplotlib/examples/mplot3d/surface3d_demo3.py =================================================================== --- trunk/matplotlib/examples/mplot3d/surface3d_demo3.py (rev 0) +++ trunk/matplotlib/examples/mplot3d/surface3d_demo3.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -0,0 +1,31 @@ +from mpl_toolkits.mplot3d import Axes3D +from matplotlib import cm +from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter +import matplotlib.pyplot as plt +import numpy as np + +fig = plt.figure() +ax = Axes3D(fig) +X = np.arange(-5, 5, 0.25) +xlen = len(X) +Y = np.arange(-5, 5, 0.25) +ylen = len(Y) +X, Y = np.meshgrid(X, Y) +R = np.sqrt(X**2 + Y**2) +Z = np.sin(R) + +colortuple = ('y', 'b') +colors = np.empty(X.shape, dtype=str) +for y in range(ylen): + for x in range(xlen): + colors[x, y] = colortuple[(x + y) % len(colortuple)] + +surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors, + linewidth=0, antialiased=False) + 
+ax.set_zlim3d(-1.01, 1.01) +ax.w_zaxis.set_major_locator(LinearLocator(10)) +ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f')) + +plt.show() + Added: trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py (rev 0) +++ trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -0,0 +1,34 @@ +from mpl_toolkits.mplot3d import axes3d +import matplotlib.pyplot as plt +import numpy as np +import time + +def generate(X, Y, phi): + R = 1 - np.sqrt(X**2 + Y**2) + return np.cos(2 * np.pi * X + phi) * R + +plt.ion() +fig = plt.figure() +ax = axes3d.Axes3D(fig) + +xs = np.linspace(-1, 1, 50) +ys = np.linspace(-1, 1, 50) +X, Y = np.meshgrid(xs, ys) +Z = generate(X, Y, 0.0) + +wframe = None +tstart = time.time() +for phi in np.linspace(0, 360 / 2 / np.pi, 100): + + oldcol = wframe + + Z = generate(X, Y, phi) + wframe = ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2) + + # Remove old line collection before drawing + if oldcol is not None: + ax.collections.remove(oldcol) + + plt.draw() + +print 'FPS: %f' % (100 / (time.time() - tstart)) Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -274,6 +274,7 @@ def __init__(self, *args, **kwargs): PatchCollection.__init__(self, *args, **kwargs) + self._old_draw = lambda x: PatchCollection.draw(self, x) def set_3d_properties(self, zs, zdir): xs, ys = zip(*self.get_offsets()) @@ -293,10 +294,15 @@ return min(vzs) def draw(self, renderer): - PatchCollection.draw(self, renderer) + self._old_draw(renderer) def patch_collection_2d_to_3d(col, zs=0, zdir='z'): """Convert a PatchCollection to a 
Patch3DCollection object.""" + + # The tricky part here is that there are several classes that are + # derived from PatchCollection. We need to use the right draw method. + col._old_draw = col.draw + col.__class__ = Patch3DCollection col.set_3d_properties(zs, zdir) Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -13,7 +13,7 @@ from matplotlib.transforms import Bbox from matplotlib import collections import numpy as np -from matplotlib.colors import Normalize, colorConverter +from matplotlib.colors import Normalize, colorConverter, LightSource import art3d import proj3d @@ -37,6 +37,21 @@ """ def __init__(self, fig, rect=None, *args, **kwargs): + ''' + Build an :class:`Axes3D` instance in + :class:`~matplotlib.figure.Figure` *fig* with + *rect=[left, bottom, width, height]* in + :class:`~matplotlib.figure.Figure` coordinates + + Optional keyword arguments: + + ================ ========================================= + Keyword Description + ================ ========================================= + *azim* Azimuthal viewing angle (default -60) + *elev* Elevation viewing angle (default 30) + ''' + if rect is None: rect = [0.0, 0.0, 1.0, 1.0] self.fig = fig @@ -146,9 +161,12 @@ for i, (z, patch) in enumerate(zlist): patch.zorder = i - self.w_xaxis.draw(renderer) - self.w_yaxis.draw(renderer) - self.w_zaxis.draw(renderer) + axes = (self.w_xaxis, self.w_yaxis, self.w_zaxis) + for ax in axes: + ax.draw_pane(renderer) + for ax in axes: + ax.draw(renderer) + Axes.draw(self, renderer) def get_axis_position(self): @@ -322,8 +340,9 @@ self.grid(rcParams['axes3d.grid']) def _button_press(self, event): - self.button_pressed = event.button - self.sx, self.sy = event.xdata, event.ydata + if event.inaxes == self: + 
self.button_pressed = event.button + self.sx, self.sy = event.xdata, event.ydata def _button_release(self, event): self.button_pressed = None @@ -565,6 +584,12 @@ *cstride* Array column stride (step size) *color* Color of the surface patches *cmap* A colormap for the surface patches. + *facecolors* Face colors for the individual patches + *norm* An instance of Normalize to map values to colors + *vmin* Minimum value to map + *vmax* Maximum value to map + *shade* Whether to shade the facecolors, default: + false when cmap specified, true otherwise ========== ================================================ ''' @@ -575,13 +600,28 @@ rstride = kwargs.pop('rstride', 10) cstride = kwargs.pop('cstride', 10) - color = kwargs.pop('color', 'b') - color = np.array(colorConverter.to_rgba(color)) + if 'facecolors' in kwargs: + fcolors = kwargs.pop('facecolors') + else: + color = np.array(colorConverter.to_rgba(kwargs.pop('color', 'b'))) + fcolors = None + cmap = kwargs.get('cmap', None) + norm = kwargs.pop('norm', None) + vmin = kwargs.pop('vmin', None) + vmax = kwargs.pop('vmax', None) + linewidth = kwargs.get('linewidth', None) + shade = kwargs.pop('shade', cmap is None) + lightsource = kwargs.pop('lightsource', None) + # Shade the data + if shade and cmap is not None and fcolors is not None: + fcolors = self._shade_colors_lightsource(Z, cmap, lightsource) + polys = [] normals = [] - avgz = [] + #colset contains the data for coloring: either average z or the facecolor + colset = [] for rs in np.arange(0, rows-1, rstride): for cs in np.arange(0, cols-1, cstride): ps = [] @@ -609,19 +649,38 @@ lastp = p avgzsum += p[2] polys.append(ps2) - avgz.append(avgzsum / len(ps2)) - v1 = np.array(ps2[0]) - np.array(ps2[1]) - v2 = np.array(ps2[2]) - np.array(ps2[0]) - normals.append(np.cross(v1, v2)) + if fcolors is not None: + colset.append(fcolors[rs][cs]) + else: + colset.append(avgzsum / len(ps2)) + # Only need vectors to shade if no cmap + if cmap is None and shade: + v1 = 
np.array(ps2[0]) - np.array(ps2[1]) + v2 = np.array(ps2[2]) - np.array(ps2[0]) + normals.append(np.cross(v1, v2)) + polyc = art3d.Poly3DCollection(polys, *args, **kwargs) - if cmap is not None: - polyc.set_array(np.array(avgz)) - polyc.set_linewidth(0) + + if fcolors is not None: + if shade: + colset = self._shade_colors(colset, normals) + polyc.set_facecolors(colset) + polyc.set_edgecolors(colset) + elif cmap: + colset = np.array(colset) + polyc.set_array(colset) + if vmin is not None or vmax is not None: + polyc.set_clim(vmin, vmax) + if norm is not None: + polyc.set_norm(norm) else: - colors = self._shade_colors(color, normals) - polyc.set_facecolors(colors) + if shade: + colset = self._shade_colors(color, normals) + else: + colset = color + polyc.set_facecolors(colset) self.add_collection(polyc) self.auto_scale_xyz(X, Y, Z, had_data) @@ -643,24 +702,39 @@ return normals def _shade_colors(self, color, normals): + ''' + Shade *color* using normal vectors given by *normals*. + *color* can also be an array of the same length as *normals*. 
+ ''' + shade = [] for n in normals: - n = n / proj3d.mod(n) * 5 + n = n / proj3d.mod(n) shade.append(np.dot(n, [-1, -1, 0.5])) shade = np.array(shade) mask = ~np.isnan(shade) if len(shade[mask]) > 0: - norm = Normalize(min(shade[mask]), max(shade[mask])) - color = color.copy() - color[3] = 1 - colors = [color * (0.5 + norm(v) * 0.5) for v in shade] + norm = Normalize(min(shade[mask]), max(shade[mask])) + if art3d.iscolor(color): + color = color.copy() + color[3] = 1 + colors = [color * (0.5 + norm(v) * 0.5) for v in shade] + else: + colors = [np.array(colorConverter.to_rgba(c)) * \ + (0.5 + norm(v) * 0.5) \ + for c, v in zip(color, shade)] else: - colors = color.copy() + colors = color.copy() return colors + def _shade_colors_lightsource(self, data, cmap, lightsource): + if lightsource is None: + lightsource = LightSource(azdeg=135, altdeg=55) + return lightsource.shade(data, cmap) + def plot_wireframe(self, X, Y, Z, *args, **kwargs): ''' Plot a 3D wireframe. Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-09 20:29:10 UTC (rev 8014) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-10 00:03:03 UTC (rev 8015) @@ -75,7 +75,7 @@ maxis.XAxis.__init__(self, axes, *args, **kwargs) self.line = mlines.Line2D(xdata=(0, 0), ydata=(0, 0), linewidth=0.75, - color=(0,0, 0,0), + color=(0, 0, 0, 1), antialiased=True, ) @@ -100,8 +100,8 @@ majorLabels = [self.major.formatter(val, i) for i, val in enumerate(majorLocs)] return majorLabels, majorLocs - def get_major_ticks(self): - ticks = maxis.XAxis.get_major_ticks(self) + def get_major_ticks(self, numticks=None): + ticks = maxis.XAxis.get_major_ticks(self, numticks) for t in ticks: t.tick1line.set_transform(self.axes.transData) t.tick2line.set_transform(self.axes.transData) @@ -132,23 +132,7 @@ else: return len(text) > 4 - def draw(self, renderer): - self.label._transform = 
self.axes.transData - renderer.open_group('axis3d') - - # code from XAxis - majorTicks = self.get_major_ticks() - majorLocs = self.major.locator() - - # filter locations here so that no extra grid lines are drawn - interval = self.get_view_interval() - majorLocs = [loc for loc in majorLocs if \ - interval[0] < loc < interval[1]] - self.major.formatter.set_locs(majorLocs) - majorLabels = [self.major.formatter(val, i) - for i, val in enumerate(majorLocs)] - - # Determine bounds + def _get_coord_info(self, renderer): minx, maxx, miny, maxy, minz, maxz = self.axes.get_w_lims() mins = np.array((minx, miny, minz)) maxs = np.array((maxx, maxy, maxz)) @@ -157,15 +141,19 @@ mins = mins - deltas / 4. maxs = maxs + deltas / 4. - # Determine which planes should be visible by the avg z value vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2] tc = self.axes.tunit_cube(vals, renderer.M) - #raise RuntimeError('WTF: p1=%s'%p1) avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2] for \ p1, p2, p3, p4 in self._PLANES] highs = np.array([avgz[2*i] < avgz[2*i+1] for i in range(3)]) - # Draw plane + return mins, maxs, centers, deltas, tc, highs + + def draw_pane(self, renderer): + renderer.open_group('pane3d') + + mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer) + info = self._AXINFO[self.adir] index = info['i'] if not highs[index]: @@ -176,6 +164,29 @@ self.set_pane(xys, info['color']) self.pane.draw(renderer) + renderer.close_group('pane3d') + + def draw(self, renderer): + self.label._transform = self.axes.transData + renderer.open_group('axis3d') + + # code from XAxis + majorTicks = self.get_major_ticks() + majorLocs = self.major.locator() + + info = self._AXINFO[self.adir] + index = info['i'] + + # filter locations here so that no extra grid lines are drawn + interval = self.get_view_interval() + majorLocs = [loc for loc in majorLocs if \ + interval[0] < loc < interval[1]] + self.major.formatter.set_locs(majorLocs) + majorLabels = 
[self.major.formatter(val, i) + for i, val in enumerate(majorLocs)] + + mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer) + # Determine grid lines minmax = np.where(highs, maxs, mins) This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. ------------------------------------------------------------------------------ Return on Information: Google Enterprise Search pays you back Get the facts. http://p.sf.net/sfu/google-dev2dev _______________________________________________ Matplotlib-checkins mailing list Matplotlib-checkins@lists.sourceforge.net https://lists.sourceforge.net/lists/listinfo/matplotlib-checkins