Unused whitespace is a pet peeve of mine, so I tend to use bbox_inches='tight'
when saving figures. However, when producing publications, I want figures with
a very specific size (e.g., to fit the column width of the page), but calling
bbox_inches='tight' changes the figure size. Stretching to fit is out of the
question (it screws up the font sizes).

Anyway, I coded up a way to automatically choose values for subplots_adjust. My
main goal was to tighten the borders (top, bottom, left, right), but along the
way I also wrote code to automatically adjust 'hspace' and 'wspace'.

Just curious if there's any interest in adding this functionality.


Code Notes:
* The code to tighten the subplot spacing only works for regular grids: not 
with subplots that span multiple columns/rows.
* The code to tighten up the borders is short and fairly intuitive, but the
code for subplot spacing is pretty messy (mainly because wspace/hspace depend
on the border spacing and the number of rows/columns).
* The code draws the figure twice to calculate subplot parameters (not sure if 
this is a big issue, but I thought it was worth mentioning).
* Just execute the file to plot some examples with random label sizes to 
demonstrate subplot adjustment.

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.transforms import TransformedBbox, Affine2D


# Default padding in inches (same value as matplotlib's default savefig
# pad_inches). Must be defined before the functions that use it as a default.
PAD_INCHES = 0.1


def tight_layout(pad_inches=PAD_INCHES, h_pad_inches=None, w_pad_inches=None,
                 fig=None):
    """Adjust subplot parameters to give specified padding.

    pad_inches : float
        minimum padding between the figure edge and the edges of subplots.
    h_pad_inches, w_pad_inches : float
        minimum padding (height/width) between edges of adjacent subplots.
        Defaults to `pad_inches`.
    fig : Figure, optional
        figure to adjust. Defaults to the current figure (`plt.gcf()`).
    """
    if h_pad_inches is None:
        h_pad_inches = pad_inches
    if w_pad_inches is None:
        w_pad_inches = pad_inches
    if fig is None:
        fig = plt.gcf()
    # NOTE: border padding affects subplot spacing; tighten borders first.
    tight_borders(fig, pad_inches=pad_inches)
    tight_subplot_spacing(fig, h_pad_inches, w_pad_inches)

def tight_borders(fig, pad_inches=PAD_INCHES):
    """Stretch subplot boundaries to figure edges plus padding.

    Moves `left`, `bottom`, `right` and `top` of `fig.subplotpars` so that the
    figure's tight bounding box, grown by `pad_inches` on every side, lines up
    with the figure edges.
    """
    # Draw first so the renderer is up to date and text/tick bboxes are
    # accurate before measuring.
    fig.canvas.draw()
    bbox_original = fig.bbox_inches
    bbox_tight = _get_tightbbox(fig, pad_inches)
    # Figure dimensions ordered like bbox.extents: x0, y0, x1, y1.
    lengths = np.array([bbox_original.width, bbox_original.height,
                        bbox_original.width, bbox_original.height])
    # Signed gap between the padded tight bbox and the figure edge, as a
    # fraction of the figure. Subtracting it from the current borders moves
    # each border by exactly the surplus (or deficit) of space on that side.
    whitespace = (bbox_tight.extents - bbox_original.extents) / lengths
    # Border positions ordered like bbox.extents: x0, y0, x1, y1.
    current_borders = np.array([fig.subplotpars.left, fig.subplotpars.bottom,
                                fig.subplotpars.right, fig.subplotpars.top])
    left, bottom, right, top = current_borders - whitespace
    fig.subplots_adjust(bottom=bottom, top=top, left=left, right=right)

def _get_tightbbox(fig, pad_inches):
    """Return `fig`'s tight bounding box grown by `pad_inches` on every side.

    The bbox comes from `fig.get_tightbbox`, evaluated with the canvas's
    renderer, and is then expanded with `Bbox.padded`.
    """
    tight_bbox = fig.get_tightbbox(fig.canvas.get_renderer())
    return tight_bbox.padded(pad_inches)

def tight_subplot_spacing(fig, h_pad_inches, w_pad_inches):
    """Stretch subplots so adjacent subplots are separated by given padding.

    h_pad_inches, w_pad_inches : float
        minimum gap (height/width) between the tight bboxes of neighbouring
        subplots.

    Only regular grids (no subplots spanning several rows/columns) are
    supported. The resulting `hspace`/`wspace` are applied to `fig` via
    `subplots_adjust`; a direction with a single row/column is left untouched.
    """
    # Zero hspace and wspace so adjacent axes abut; any remaining overlap of
    # their tight bboxes is then exactly the extra space that must be added.
    fig.subplots_adjust(hspace=0, wspace=0)
    # Redraw so tick labels etc. reflect the new axes positions.
    fig.canvas.draw()
    figbox = fig.bbox_inches
    ax_bottom, ax_top, ax_left, ax_right = _get_grid_boundaries(fig)
    nrows, ncols = ax_bottom.shape

    spacing = {}
    if nrows > 1:
        subplots_height = fig.subplotpars.top - fig.subplotpars.bottom
        # Row i+1's top edge minus row i's bottom edge (row 0 is the top row):
        # positive values mean the labels of stacked subplots overlap.
        h_overlap_inches = ax_top[1:] - ax_bottom[:-1]
        hspace_inches = h_overlap_inches.max() + h_pad_inches
        hspace_fig_frac = hspace_inches / figbox.height
        spacing['hspace'] = _fig_frac_to_cell_frac(hspace_fig_frac,
                                                   subplots_height, nrows)
    if ncols > 1:
        subplots_width = fig.subplotpars.right - fig.subplotpars.left
        # Column j's right edge minus column j+1's left edge.
        w_overlap_inches = ax_right[:, :-1] - ax_left[:, 1:]
        wspace_inches = w_overlap_inches.max() + w_pad_inches
        wspace_fig_frac = wspace_inches / figbox.width
        spacing['wspace'] = _fig_frac_to_cell_frac(wspace_fig_frac,
                                                   subplots_width, ncols)

    # Previously the computed values were discarded; actually apply them.
    if spacing:
        fig.subplots_adjust(**spacing)

def _get_grid_boundaries(fig):
    """Return grid boundaries for bboxes of subplots.

    Returns
    -------
    ax_bottom, ax_top, ax_left, ax_right : masked array, shape (nrows, ncols)
        Edges, in inches from the figure's lower-left corner, of each
        subplot's tight bbox, indexed by grid cell in subplot-number order
        (row-major, starting at the top-left cell). Cells without a subplot
        stay masked. Subplots spanning several cells are not supported yet.
    """
    # Grid shape is taken from the first axes; all axes are assumed to share
    # the same regular grid.
    nrows, ncols, _ = fig.axes[0].get_geometry()
    # Masked arrays leave room for future support of subplots that span
    # multiple rows/columns, whose internal grid boundaries would be masked.
    ax_bottom, ax_top, ax_left, ax_right = [np.ma.masked_all((nrows, ncols))
                                            for _ in range(4)]
    renderer = fig.canvas.get_renderer()
    # Tight bboxes are in display pixels; divide by dpi to get inches.
    px2inches_trans = Affine2D().scale(1. / fig.dpi)
    for ax in fig.axes:
        ax_bbox = ax.get_tightbbox(renderer)
        x0, y0, x1, y1 = TransformedBbox(ax_bbox, px2inches_trans).extents
        # Subplot numbers start at 1; flat matrix indices start at 0.
        i = ax.get_geometry()[2] - 1
        ax_bottom.flat[i] = y0
        ax_top.flat[i] = y1
        ax_left.flat[i] = x0
        ax_right.flat[i] = x1
    return ax_bottom, ax_top, ax_left, ax_right

def _fig_frac_to_cell_frac(fig_frac, subplots_frac, num_cells):
    """Return fraction of cell (row/column) from a given fraction of the figure
    fig_frac : float
        length given as a fraction of figure height or width
    subplots_frac : float
        fraction of figure (height or width) occupied by subplots
    num_cells : int
        number of rows or columns.
    # This function is reverse engineered from the calculation of `sepH` and 
    # `sepW` in  `GridSpecBase.get_grid_positions`.
    return (fig_frac * num_cells) / (subplots_frac - fig_frac*(num_cells-1))

if __name__ == '__main__':
    import random
    fontsizes = [8, 16, 24, 32]

    def example_plot(ax):
        """Plot a line with randomly sized axis labels and title."""
        ax.plot([1, 2])
        ax.set_xlabel('x-label', fontsize=random.choice(fontsizes))
        ax.set_ylabel('y-label', fontsize=random.choice(fontsizes))
        ax.set_title('Title', fontsize=random.choice(fontsizes))

    # Each figure is filled and tightened while it is still the current
    # figure, because tight_layout() works on plt.gcf() by default.
    fig, ax = plt.subplots()
    example_plot(ax)
    tight_layout()

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
    for ax in (ax1, ax2, ax3, ax4):
        example_plot(ax)
    tight_layout()

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
    for ax in (ax1, ax2):
        example_plot(ax)
    tight_layout()

    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)
    for ax in (ax1, ax2):
        example_plot(ax)
    tight_layout()

    fig, axes = plt.subplots(nrows=3, ncols=3)
    for row in axes:
        for ax in row:
            example_plot(ax)
    tight_layout()

    plt.show()

This SF.net email is sponsored by Sprint
What will you do first with EVO, the first 4G phone?
Visit sprint.com/first -- http://p.sf.net/sfu/sprint-com-first
Matplotlib-devel mailing list

Reply via email to