Following my own post, I attach an upgraded version of the sankey.py script to draw single-direction Sankey diagrams. It now supports multiple inputs and double-sided diagrams (see attached example). Let me know if you find it useful and/or have any comments.

   .~.   Yannick COPIN  (o:>*  Doctus cum libro
   /V\   Institut de physique nucleaire de Lyon (IN2P3 - France)
  // \\  Tel: (33/0) 472 431 968     AIM: YnCopin ICQ: 236931013
 /(   )\ http://snovae.in2p3.fr/ycopin/
#!/usr/bin/env python

# Module metadata: author contact and a time-stamp string used as the version.
# NOTE(review): the e-mail addresses appear truncated ("yco...") in this copy.
__author__ = "Yannick Copin <yco...@ipnl.in2p3.fr>"
__version__ = "Time-stamp: <10/02/2010 16:49 yco...@lyopc548.in2p3.fr>"

import numpy as N

def sankey(ax,
           outputs=[100.], outlabels=None,
           inputs=[100.], inlabels='',
           dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
    """Draw a Sankey diagram on the axes `ax`.

    The flow runs from left (inputs) to right (outputs).  The first input
    and the last output are drawn horizontally; every other entry leaves
    the main band vertically, upwards if its value is positive and
    downwards if it is negative.

    ax: matplotlib axes the patch and the labels are added to
    outputs: array of outputs, absolute values should sum up to 100%
    outlabels: output labels (same length as outputs),
               or None (use default labels) or '' (no labels)
    inputs and inlabels: similar for inputs
    dx: horizontal elongation
    dy: vertical elongation
    outangle: output arrow angle [deg]
    w: output arrow shoulder
    inangle: input dip angle [deg]
    offset: text offset
    **kwargs: propagated to Patch (e.g. fill=False)

    The axes limits and aspect ratio of `ax` are adjusted to the diagram.

    Raise AssertionError if inputs or outputs do not sum up to 100, or if
    a list of labels does not have the same length as its values.

    Return (patch,[intexts,outtexts])."""

    import matplotlib.patches as mpatches
    from matplotlib.path import Path

    # N.array copies its argument, so the (mutable) default lists are never
    # modified.  Forcing floats keeps 'loss/2' a true division everywhere.
    outputs = N.array(outputs, dtype=float)
    inputs = N.array(inputs, dtype=float)
    outs = N.absolute(outputs)
    ins = N.absolute(inputs)

    # Tolerant comparison: percentages such as 33.3 are not exact in binary
    assert abs(sum(outs) - 100) < 1e-6, "Outputs don't sum up to 100%"
    assert abs(sum(ins) - 100) < 1e-6, "Inputs don't sum up to 100%"

    outsigns = N.sign(outputs)
    outsigns[-1] = 0                       # Last output is horizontal
    insigns = N.sign(inputs)
    insigns[0] = 0                         # First input is horizontal

    outtips = []                        # Output arrow tip dir. and positions
    indips = []                         # Input arrow dip dir. and positions

    def add_output(path, loss, sign=1):
        """Append an output arrow of width `loss` to `path` and record the
        position of its tip in `outtips`."""
        h = (loss/2 + w)*N.tan(outangle/180.*N.pi)  # Arrow tip height
        _, (x, y) = path[-1]            # Use last point as reference
        if sign == 0:                   # Final loss (horizontal)
            path.extend([(Path.LINETO, [x+dx, y]),
                         (Path.LINETO, [x+dx, y+w]),
                         (Path.LINETO, [x+dx+h, y-loss/2]),  # Tip
                         (Path.LINETO, [x+dx, y-loss-w]),
                         (Path.LINETO, [x+dx, y-loss])])
            outtips.append((sign, path[-3][1]))
        else:                           # Intermediate loss (vertical)
            path.extend([(Path.CURVE4, [x+dx/2, y]),
                         (Path.CURVE4, [x+dx, y]),
                         (Path.CURVE4, [x+dx, y+sign*dy]),
                         (Path.LINETO, [x+dx-w, y+sign*dy]),
                         (Path.LINETO, [x+dx+loss/2, y+sign*(dy+h)]),  # Tip
                         (Path.LINETO, [x+dx+loss+w, y+sign*dy]),
                         (Path.LINETO, [x+dx+loss, y+sign*dy]),
                         (Path.CURVE3, [x+dx+loss, y-sign*loss]),
                         (Path.CURVE3, [x+dx/2+loss, y-sign*loss])])
            outtips.append((sign, path[-5][1]))

    def add_input(path, gain, sign=1):
        """Append an input notch of width `gain` to `path` and record the
        label anchor position in `indips`."""
        h = (gain/2)*N.tan(inangle/180.*N.pi)  # Dip depth
        _, (x, y) = path[-1]            # Use last point as reference
        if sign == 0:                   # First gain (horizontal)
            path.extend([(Path.LINETO, [x-dx, y]),
                         (Path.LINETO, [x-dx+h, y+gain/2]),  # Dip
                         (Path.LINETO, [x-dx, y+gain])])
            xd, yd = path[-2][1]        # Dip position
            indips.append((sign, [xd-h, yd]))
        else:                           # Intermediate gain (vertical)
            path.extend([(Path.CURVE4, [x-dx/2, y]),
                         (Path.CURVE4, [x-dx, y]),
                         (Path.CURVE4, [x-dx, y+sign*dy]),
                         (Path.LINETO, [x-dx-gain/2, y+sign*(dy-h)]),  # Dip
                         (Path.LINETO, [x-dx-gain, y+sign*dy]),
                         (Path.CURVE3, [x-dx-gain, y-sign*gain]),
                         (Path.CURVE3, [x-dx/2-gain, y-sign*gain])])
            xd, yd = path[-4][1]        # Dip position
            indips.append((sign, [xd, yd+sign*h]))

    def revert(path):
        """A path is not just revertable by path[::-1] because of Bezier
        curves: a code describes the segment *ending* at its vertex, so
        when walking backwards each vertex takes the code of the vertex
        that followed it in the original path."""
        rpath = []
        nextmove = Path.LINETO
        for move, pos in path[::-1]:
            rpath.append((nextmove, pos))
            nextmove = move
        return rpath

    # Outputs: non-negative signs go on the upper edge, negative ones on
    # the lower edge
    urpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper right path
    lrpath = [(Path.LINETO, [0, 0])]    # 1st point of lower right path
    for loss, sign in zip(outs, outsigns):
        add_output(urpath if sign >= 0 else lrpath, loss, sign=sign)

    # Inputs are added from the last to the first, starting at the origin
    llpath = [(Path.LINETO, [0, 0])]    # 1st point of lower left path
    ulpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper left path
    for gain, sign in list(zip(ins, insigns))[::-1]:
        add_input(llpath if sign <= 0 else ulpath, gain, sign=sign)

    # Concatenate subpathes in correct order
    path = urpath + revert(lrpath) + llpath + revert(ulpath)
    codes, verts = zip(*path)
    verts = N.array(verts)

    # Path patch
    path = Path(verts, codes)
    patch = mpatches.PathPatch(path, **kwargs)
    ax.add_patch(patch)                 # Without this nothing is drawn

    if False:                           # DEBUG
        print("urpath %s" % (urpath,))
        print("lrpath %s" % (revert(lrpath),))
        print("llpath %s" % (llpath,))
        print("ulpath %s" % (revert(ulpath),))

        xs, ys = zip(*verts)
        ax.plot(xs, ys, 'go-')          # Overlay the raw path vertices

    # Labels

    def set_labels(labels, values):
        """Set or check labels according to values."""
        if labels == '':                 # No labels
            return labels
        elif labels is None:             # Default labels
            return ['%2d%%' % val for val in values]
        else:                            # User labels: check the length
            assert len(labels) == len(values)
            return labels

    def put_labels(labels, positions, output=True):
        """Put labels to positions and return the created text objects."""
        texts = []
        # Input positions were recorded from the last input to the first
        lbls = labels if output else labels[::-1]
        for label, (s, (x, y)) in zip(lbls, positions):
            if s == 0:                  # Horizontal arrow
                t = ax.text(x+offset, y, label,
                            ha='left' if output else 'right', va='center')
            elif s > 0:                 # Upward arrow
                t = ax.text(x, y+offset, label, ha='center', va='bottom')
            else:                       # Downward arrow
                t = ax.text(x, y-offset, label, ha='center', va='top')
            texts.append(t)
        return texts

    outlabels = set_labels(outlabels, outs)
    outtexts = put_labels(outlabels, outtips, output=True)

    inlabels = set_labels(inlabels, ins)
    intexts = put_labels(inlabels, indips, output=False)

    # Axes management
    ax.set_xlim(verts[:, 0].min()-dx, verts[:, 0].max()+dx)
    ax.set_ylim(verts[:, 1].min()-dy, verts[:, 1].max()+dy)
    ax.set_aspect('equal', adjustable='datalim')

    return patch, [intexts, outtexts]

if __name__=='__main__':

    # Demo: a double-sided diagram with three inputs and six outputs.
    import matplotlib.pyplot as P

    # Positive values leave the band upwards, negative ones downwards;
    # sankey() draws the last output (and the first input) horizontally.
    outputs = [10.,-20.,5.,15.,-10.,40.]
    outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
    # Add the absolute percentage on a second line of each label
    outlabels = [ s+'\n%d%%' % abs(val) for val,s in zip(outputs,outlabels) ]

    inputs = [60.,-25.,15.]

    fig = P.figure()
    ax = fig.add_subplot(1,1,1, xticks=[],yticks=[],
                         title="Sankey diagram")

    # inlabels=None requests the default percentage labels for the inputs;
    # fc and alpha are forwarded to the patch.
    patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
                                      inputs=inputs, inlabels=None,
                                      fc='g', alpha=0.2)

    P.show()                            # Display the figure


<<inline: sankey.png>>

SOLARIS 10 is the OS for Data Centers - provides features such as DTrace,
Predictive Self Healing and Award Winning ZFS. Get Solaris 10 NOW
Matplotlib-users mailing list

Reply via email to