On Jun 7, 2010, at 11:16 AM, Ian Stokes-Rees wrote:
What you're doing sounds very similar to a Hinton diagram (or at least the resulting image looks similar). There's an example of plotting such a diagram in the SciPy Cookbook (link not included here). That implementation is pretty slow because it loops through the data and draws each square one by one. I wrote a faster alternative a while back (attached below). It uses a custom PolyCollection that measures the square areas in data units instead of figure units (points). Also, I just noticed there's another implementation of Hinton diagrams in the matplotlib examples folder (examples/api/hinton_demo.py). For some reason, this example doesn't appear on the website; otherwise I'd link to it. I believe the difference between your plot and a Hinton diagram is that you use a different metric for calculating the size of the squares. -Tony
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.collections as collections
import matplotlib.transforms as transforms
class SquareCollection(collections.RegularPolyCollection):
    """Collection of squares whose sizes are measured in data units.

    Each member is a 4-sided regular polygon rotated by pi/4.  ``sizes``
    keep the RegularPolyCollection meaning (area of the circle that
    circumscribes each polygon), but ``get_transform`` rescales them from
    points to the parent axes' data units.
    """

    def __init__(self, **kwargs):
        # Four sides, turned by 45 degrees; all other options pass straight
        # through to RegularPolyCollection.
        super(SquareCollection, self).__init__(4, rotation=np.pi / 4., **kwargs)

    def get_transform(self):
        """Return an Affine2D that maps size units onto data units.

        Per axis the factor is (points per pixel) * (pixels per data unit),
        taken from the parent axes' current bounding box and view limits.
        """
        parent = self.axes
        points_per_pixel = 72.0 / parent.figure.dpi
        x_factor = points_per_pixel * parent.bbox.width / parent.viewLim.width
        y_factor = points_per_pixel * parent.bbox.height / parent.viewLim.height
        scaling = transforms.Affine2D()
        return scaling.scale(x_factor, y_factor)
def _make_squares(ax, sizes, offsets, color):
    """Build a single-colour SquareCollection whose offsets are in ax's data coordinates."""
    style = dict(sizes=sizes, offsets=offsets, facecolor=color, edgecolor=color)
    try:
        # Matplotlib 3.6 renamed ``transOffset`` to ``offset_transform`` and
        # deprecated the old spelling.
        return SquareCollection(offset_transform=ax.transData, **style)
    except (AttributeError, TypeError):
        # Older releases reject the unknown property; use the legacy keyword.
        return SquareCollection(transOffset=ax.transData, **style)


def hinton(inarray, max_weight=None):
    """Draw a Hinton diagram for visualizing a weight matrix on the current axes.

    Every non-zero entry becomes a square centred on its (column, row) cell.
    The square's area (in data units) is ``abs(weight) / max_weight``,
    clipped to 1.  Positive weights are white and negative weights black;
    zeros are not drawn.  The background is gray.  Row 0 is at the top and
    the column axis is labelled along the top edge, as in a printed matrix.

    Parameters
    ----------
    inarray : array_like
        2-D array of weights; ``numpy.matrix`` input is accepted.
    max_weight : float, optional
        Magnitude that fills a whole cell.  By default this is the smallest
        power of two that is >= the largest absolute weight, or 1 when all
        weights are zero or the array is empty.
    """
    ax = plt.gca()
    # Axes.set_axis_bgcolor was removed in Matplotlib 2.2 in favour of
    # set_facecolor; fall back only for releases that predate the latter.
    set_background = getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor
    set_background('gray')

    # np.asarray drops the numpy.matrix subclass; the float dtype keeps the
    # division below a true division even for integer input on Python 2.
    inarray = np.asarray(inarray, dtype=float)
    height, width = inarray.shape

    if max_weight is None:
        peak = np.abs(inarray).max() if inarray.size else 0.0
        # log2 is exact for powers of two, so a peak of exactly 2**k stays
        # 2**k.  A zero peak would give log(0) and a NaN scale.
        max_weight = 2 ** np.ceil(np.log2(peak)) if peak > 0 else 1.0
    weights = np.clip(inarray / max_weight, -1, 1)

    rows, cols = np.mgrid[:height, :width]
    for mask, color in ((weights > 0, 'white'), (weights < 0, 'black')):
        if not mask.any():
            continue
        # (x, y) = (column, row) as an (N, 2) array.  A bare zip() is a
        # one-shot iterator on Python 3, not a valid ``offsets`` value.
        xy = np.column_stack((cols[mask], rows[mask]))
        # RegularPolyCollection sizes are circumscribing-circle areas; a
        # square of area A has a circumscribing circle of area (pi / 2) * A.
        circle_areas = np.pi / 2 * np.abs(weights[mask])
        ax.add_collection(_make_squares(ax, circle_areas, xy, color),
                          autolim=True)

    plt.axis('scaled')
    ax.set_xlim(-0.5, width - 0.5)
    ax.set_ylim(height - 0.5, -0.5)  # inverted so row 0 is drawn at the top
    ax.set_xlabel('column')
    ax.set_ylabel('row')
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
if __name__ == '__main__':
    # Demo: a 3x3 matrix mixing positive, zero and negative weights.
    A = np.array([[0, 0.5, 1], [-0.5, 0, 0.5], [-1, -0.5, 0]])
    # The parenthesised form is valid on both Python 2 and Python 3; the
    # original ``print A`` statement is a SyntaxError on Python 3.
    print(A)
    hinton(A)
    plt.show()
------------------------------------------------------------------------------ ThinkGeek and WIRED's GeekDad team up for the Ultimate GeekDad Father's Day Giveaway. ONE MASSIVE PRIZE to the lucky parental unit. See the prize list and enter to win: http://p.sf.net/sfu/thinkgeek-promo
_______________________________________________ Matplotlib-users mailing list [email protected] https://lists.sourceforge.net/lists/listinfo/matplotlib-users
