On Jun 7, 2010, at 11:16 AM, Ian Stokes-Rees wrote:
What you're doing sounds very similar to a Hinton diagram (or at least the resulting image looks similar). There's an example of plotting such a diagram in the SciPy Cookbook. That implementation is pretty slow because it loops through the data and draws each square one by one. I wrote a faster alternative a while back (see attached). It uses a custom RegularPolyCollection subclass, which sizes the squares in data units instead of points. Also, I just noticed there's another implementation of Hinton diagrams in the matplotlib examples folder (examples/api/hinton_demo.py). For some reason, this example doesn't appear on the website; otherwise I'd link to it. I believe the difference between your plot and a Hinton diagram is that you use a different metric for calculating the size of the squares. -Tony
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as collections
import matplotlib.transforms as transforms
class SquareCollection(collections.RegularPolyCollection):
    """Collection of axis-aligned squares whose sizes are in data units.

    A plain ``RegularPolyCollection`` interprets ``sizes`` as the area (in
    points**2) of the circle circumscribing each polygon, so the markers keep
    a fixed on-screen size.  This subclass rescales the collection's master
    transform at draw time so that ``sizes`` are instead interpreted in
    (data units)**2, keeping the squares attached to the data grid.

    All keyword arguments are forwarded to ``RegularPolyCollection``.
    """

    def __init__(self, **kwargs):
        # A 4-sided regular polygon rotated by 45 degrees is an axis-aligned
        # square.
        super(SquareCollection, self).__init__(4, rotation=np.pi / 4., **kwargs)

    def get_transform(self):
        """Return the transform that maps marker sizes from points to data units.

        Until the collection has been added to an Axes (``self.axes`` is None)
        there is no data scale to use, so the parent transform is returned.
        """
        ax = self.axes
        if ax is None:
            return super(SquareCollection, self).get_transform()
        # RegularPolyCollection scales its sizes from points to pixels by
        # dpi / 72; multiplying by 72 / dpi cancels that, so the size value is
        # then treated as data units and converted with pixels-per-data-unit.
        pixels_to_points = 72.0 / ax.figure.dpi
        # abs(): the view-limit extent is negative on an inverted axis (hinton
        # inverts the y-axis); only the magnitude of the scale is wanted.
        scale_x = pixels_to_points * ax.bbox.width / abs(ax.viewLim.width)
        scale_y = pixels_to_points * ax.bbox.height / abs(ax.viewLim.height)
        return transforms.Affine2D().scale(scale_x, scale_y)


def hinton(inarray, max_weight=None, ax=None):
    """Draw a Hinton diagram for visualizing a weight matrix.

    Each nonzero entry is drawn as a square centred on its (column, row)
    cell: white for positive values, black for negative values.  The area of
    a square is ``|value| / max_weight`` (clipped to 1, i.e. a full cell).

    Parameters
    ----------
    inarray : array_like, 2-D
        Weight matrix.  Row 0 is drawn at the top.
    max_weight : float, optional
        Magnitude that maps to a full-cell square.  Defaults to the smallest
        power of two >= max(|inarray|), or 1 if the matrix is all zeros or
        empty.
    ax : matplotlib Axes, optional
        Axes to draw into.  Defaults to the current Axes (``plt.gca()``).

    Returns
    -------
    ax : the Axes that was drawn into.
    """
    if ax is None:
        ax = plt.gca()
    # set_axis_bgcolor was removed in matplotlib 2.2; set_facecolor replaces it.
    ax.set_facecolor('gray')
    # Make sure we're working with a plain ndarray, not a numpy matrix.
    inarray = np.asarray(inarray)
    height, width = inarray.shape

    if max_weight is None:
        largest = np.max(np.abs(inarray)) if inarray.size else 0
        # log2(0) is -inf, which would make max_weight 0 and every weight NaN.
        max_weight = 2 ** np.ceil(np.log2(largest)) if largest > 0 else 1.0
    # true_divide: avoid integer floor division when ints are passed.
    weights = np.clip(np.true_divide(inarray, max_weight), -1, 1)

    rows, cols = np.mgrid[:height, :width]
    pos = np.where(weights > 0)
    neg = np.where(weights < 0)
    for idx, color in zip([pos, neg], ['white', 'black']):
        if len(idx[0]) == 0:
            continue
        # zip() is a one-shot iterator on Python 3 and cannot be converted to
        # an (N, 2) offsets array, so build the array explicitly.
        xy = np.column_stack([cols[idx], rows[idx]])
        # Sizes are circumscribing-circle areas; a square of area A has a
        # circumscribed circle of area (pi / 2) * A.
        circle_areas = np.pi / 2 * np.abs(weights[idx])
        # offset_transform replaces the older transOffset keyword
        # (matplotlib >= 3.6).
        squares = SquareCollection(sizes=circle_areas, offsets=xy,
                                   offset_transform=ax.transData,
                                   facecolor=color, edgecolor=color)
        ax.add_collection(squares, autolim=True)

    ax.axis('scaled')
    ax.set_xlim(-0.5, width - 0.5)
    # Inverted y-limits put row 0 at the top, like a printed matrix.
    ax.set_ylim(height - 0.5, -0.5)
    ax.set_xlabel('column')
    ax.set_ylabel('row')
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    return ax


if __name__ == '__main__':
    A = np.array([[0, 0.5, 1], [-0.5, 0, 0.5], [-1, -0.5, 0]])
    print(A)
    hinton(A)
    plt.show()
------------------------------------------------------------------------------ ThinkGeek and WIRED's GeekDad team up for the Ultimate GeekDad Father's Day Giveaway. ONE MASSIVE PRIZE to the lucky parental unit. See the prize list and enter to win: http://p.sf.net/sfu/thinkgeek-promo
_______________________________________________ Matplotlib-users mailing list Matplotlib-users@lists.sourceforge.net https://lists.sourceforge.net/lists/listinfo/matplotlib-users