# Re: [Matplotlib-users] Sankey diagram

Hi,

Following up on my own post, I attach an upgraded version of the sankey.py script to draw single-direction Sankey diagrams. It now supports multiple inputs and double-sided diagrams (see attached example). Let me know if you find it useful and/or have any comments.

Cheers,
--
.~.   Yannick COPIN  (o:>*  Doctus cum libro
/V\   Institut de physique nucleaire de Lyon (IN2P3 - France)
// \\  Tel: (33/0) 472 431 968     AIM: YnCopin ICQ: 236931013
/(   )\ http://snovae.in2p3.fr/ycopin/
^`~'^
#!/usr/bin/env python

__author__ = "Yannick Copin <yco...@ipnl.in2p3.fr>"
__version__ = "Time-stamp: <10/02/2010 16:49 yco...@lyopc548.in2p3.fr>"

import numpy as N

def sankey(ax,
           outputs=(100.,), outlabels=None,
           inputs=(100.,), inlabels='',
           dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
    """Draw a Sankey diagram on the axes *ax*.

    The main flow is a band running from (0,0)-(0,100) outwards: outputs
    are peeled off to the right as arrows, inputs are merged in from the
    left as notched tails.

    outputs: array of outputs, should sum up to 100%.  The sign of each
        value selects the side the arrow leaves from (positive: up,
        negative: down); the last output is always drawn horizontally.
    outlabels: output labels (same length as outputs),
        or None (use default labels) or '' (no labels)
    inputs and inlabels: similar for inputs; the first input is always
        drawn horizontally.
    dx: horizontal elongation
    dy: vertical elongation
    outangle: output arrow angle [deg]
    w: output arrow shoulder
    inangle: input dip angle [deg]
    offset: text offset
    **kwargs: propagated to Patch (e.g. fill=False)

    Raise AssertionError if outputs or inputs do not sum up to 100, or if
    explicit labels do not match the number of values.

    The patch is added to *ax* and the axes limits are adjusted to
    enclose the diagram.

    Return (patch,[intexts,outtexts])."""

    import matplotlib.patches as mpatches
    from matplotlib.path import Path

    outs = N.absolute(outputs)
    ins = N.absolute(inputs)

    # Validate before indexing so that empty sequences fail with the
    # explicit message below rather than with an IndexError.  The sum is
    # compared with a tolerance: percentages such as 33.3+33.3+33.4 do not
    # add up to exactly 100 in floating point.
    assert N.allclose(N.sum(outs), 100.), "Outputs don't sum up to 100%"
    assert N.allclose(N.sum(ins), 100.), "Inputs don't sum up to 100%"

    # N.sign returns a new array, the caller's sequences are not modified.
    outsigns = N.sign(outputs)
    outsigns[-1] = 0                       # Last output
    insigns = N.sign(inputs)
    insigns[0] = 0                         # First input

    outtips = []                    # Output arrow tip dir. and positions
    indips = []                     # Input arrow dip dir. and positions

    def add_output(path, loss, sign):
        """Append to path an output arrow of width loss.

        sign is 0 for the final (horizontal) arrow, +1/-1 for an
        intermediate arrow leaving upwards/downwards."""
        h = (loss/2. + w)*N.tan(N.radians(outangle))  # Arrow tip height
        _, (x, y) = path[-1]            # Use last point as reference
        if sign == 0:                   # Final loss (horizontal)
            path.extend([(Path.LINETO, [x+dx, y]),
                         (Path.LINETO, [x+dx, y+w]),
                         (Path.LINETO, [x+dx+h, y-loss/2.]),  # Tip
                         (Path.LINETO, [x+dx, y-loss-w]),
                         (Path.LINETO, [x+dx, y-loss])])
            outtips.append((sign, path[-3][1]))
        else:                           # Intermediate loss (vertical)
            path.extend([(Path.CURVE4, [x+dx/2., y]),
                         (Path.CURVE4, [x+dx, y]),
                         (Path.CURVE4, [x+dx, y+sign*dy]),
                         (Path.LINETO, [x+dx-w, y+sign*dy]),
                         (Path.LINETO, [x+dx+loss/2., y+sign*(dy+h)]),  # Tip
                         (Path.LINETO, [x+dx+loss+w, y+sign*dy]),
                         (Path.LINETO, [x+dx+loss, y+sign*dy]),
                         (Path.CURVE3, [x+dx+loss, y-sign*loss]),
                         (Path.CURVE3, [x+dx/2.+loss, y-sign*loss])])
            outtips.append((sign, path[-5][1]))

    def add_input(path, gain, sign):
        """Append to path an input tail of width gain.

        sign is 0 for the first (horizontal) input, +1/-1 for an
        intermediate input coming from above/below."""
        h = (gain/2.)*N.tan(N.radians(inangle))  # Dip depth
        _, (x, y) = path[-1]            # Use last point as reference
        if sign == 0:                   # First gain (horizontal)
            path.extend([(Path.LINETO, [x-dx, y]),
                         (Path.LINETO, [x-dx+h, y+gain/2.]),  # Dip
                         (Path.LINETO, [x-dx, y+gain])])
            xd, yd = path[-2][1]        # Dip position
            indips.append((sign, [xd-h, yd]))
        else:                           # Intermediate gain (vertical)
            path.extend([(Path.CURVE4, [x-dx/2., y]),
                         (Path.CURVE4, [x-dx, y]),
                         (Path.CURVE4, [x-dx, y+sign*dy]),
                         (Path.LINETO, [x-dx-gain/2., y+sign*(dy-h)]),  # Dip
                         (Path.LINETO, [x-dx-gain, y+sign*dy]),
                         (Path.CURVE3, [x-dx-gain, y-sign*gain]),
                         (Path.CURVE3, [x-dx/2.-gain, y-sign*gain])])
            xd, yd = path[-4][1]        # Dip position
            indips.append((sign, [xd, yd+sign*h]))

    urpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper right path
    lrpath = [(Path.LINETO, [0, 0])]    # 1st point of lower right path
    for loss, sign in zip(outs, outsigns):
        add_output(urpath if sign >= 0 else lrpath, loss, sign)

    llpath = [(Path.LINETO, [0, 0])]    # 1st point of lower left path
    ulpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper left path
    # Inputs are laid out from the last one to the first one, so that the
    # first (horizontal) input closes the left-hand side; put_labels
    # reverses the input labels accordingly.
    for gain, sign in reversed(list(zip(ins, insigns))):
        add_input(llpath if sign <= 0 else ulpath, gain, sign)

    def revert(path):
        """A path is not just revertable by path[::-1] because of Bezier
        curves: each code describes how its vertex is reached, so codes
        have to be shifted by one vertex when walking backwards."""
        rpath = []
        nextmove = Path.LINETO
        for move, pos in path[::-1]:
            rpath.append((nextmove, pos))
            nextmove = move
        return rpath

    # Concatenate subpathes in correct order
    path = urpath + revert(lrpath) + llpath + revert(ulpath)

    codes, verts = zip(*path)
    verts = N.array(verts)

    # Path patch; it has to be attached to the axes to be drawn at all
    patch = mpatches.PathPatch(Path(verts, codes), **kwargs)
    ax.add_patch(patch)

    # Labels

    def set_labels(labels, values):
        """Set or check labels according to values."""
        if labels == '':                 # No labels
            return labels
        elif labels is None:             # Default labels
            return ['%2d%%' % val for val in values]
        else:
            assert len(labels) == len(values)
            return labels

    def put_labels(labels, positions, output=True):
        """Put labels to positions."""
        texts = []
        # Input positions were recorded from the last input to the first
        lbls = labels if output else labels[::-1]
        for label, (s, (x, y)) in zip(lbls, positions):
            # s: label direction, (x,y): label position
            if s == 0:
                t = ax.text(x+offset, y, label,
                            ha='left' if output else 'right', va='center')
            elif s > 0:
                t = ax.text(x, y+offset, label, ha='center', va='bottom')
            else:
                t = ax.text(x, y-offset, label, ha='center', va='top')
            texts.append(t)
        return texts

    outlabels = set_labels(outlabels, outs)
    outtexts = put_labels(outlabels, outtips, output=True)

    inlabels = set_labels(inlabels, ins)
    intexts = put_labels(inlabels, indips, output=False)

    # Axes management
    ax.set_xlim(verts[:, 0].min()-dx, verts[:, 0].max()+dx)
    ax.set_ylim(verts[:, 1].min()-dy, verts[:, 1].max()+dy)

    return patch, [intexts, outtexts]

if __name__=='__main__':

    # Demo: a double-sided diagram with six outputs and three inputs.
    import matplotlib.pyplot as P

    # Signs select the side an arrow leaves from or arrives at; the
    # absolute values have to sum up to 100.
    outputs = [10.,-20.,5.,15.,-10.,40.]
    outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
    # Append the (unsigned) percentage on a second line of each label
    outlabels = [ name+'\n%d%%' % abs(value)
                  for value,name in zip(outputs,outlabels) ]

    inputs = [60.,-25.,15.]

    fig = P.figure()
    # The axes have to exist before sankey() can draw on them
    ax = fig.add_subplot(1,1,1,
                         title="Sankey diagram"
                         )

    patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
                                      inputs=inputs, inlabels=None,
                                      fc='g', alpha=0.2)
    # The returned text objects can be restyled individually
    outtexts[1].set_color('r')
    outtexts[-1].set_fontweight('bold')

    P.show()

<<inline: sankey.png>>

------------------------------------------------------------------------------
SOLARIS 10 is the OS for Data Centers - provides features such as DTrace,
Predictive Self Healing and Award Winning ZFS. Get Solaris 10 NOW
http://p.sf.net/sfu/solaris-dev2dev
_______________________________________________
Matplotlib-users mailing list
Matplotlib-users@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/matplotlib-users