Revision: 8124
          http://matplotlib.svn.sourceforge.net/matplotlib/?rev=8124&view=rev
Author:   jdh2358
Date:     2010-02-10 16:00:15 +0000 (Wed, 10 Feb 2010)

Log Message:
-----------
Added Yannick Copin's updated Sankey demo.

Modified Paths:
--------------
    trunk/matplotlib/examples/api/sankey_demo.py

Modified: trunk/matplotlib/examples/api/sankey_demo.py
===================================================================
--- trunk/matplotlib/examples/api/sankey_demo.py        2010-02-10 15:56:34 UTC (rev 8123)
+++ trunk/matplotlib/examples/api/sankey_demo.py        2010-02-10 16:00:15 UTC (rev 8124)
@@ -1,105 +1,188 @@
 #!/usr/bin/env python
-# Time-stamp: <2010-02-10 01:49:08 ycopin>
 
-import numpy as np
-import matplotlib.pyplot as plt
-import matplotlib.patches as mpatches
-from matplotlib.path import Path
+__author__ = "Yannick Copin <yco...@ipnl.in2p3.fr>"
+__version__ = "Time-stamp: <10/02/2010 16:49 yco...@lyopc548.in2p3.fr>"
 
-def sankey(ax, losses, labels=None,
-           dx=40, dy=10, angle=45, w=3, dip=10, offset=2, **kwargs):
+import numpy as N
+
+def sankey(ax,
+           outputs=[100.], outlabels=None,
+           inputs=[100.], inlabels='',
+           dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
     """Draw a Sankey diagram.
 
-    losses: array of losses, should sum up to 100%
-    labels: loss labels (same length as losses),
-            or None (use default labels) or '' (no labels)
+    outputs: array of outputs, should sum up to 100%
+    outlabels: output labels (same length as outputs),
+               or None (use default labels) or '' (no labels)
+    inputs and inlabels: similar for inputs
     dx: horizontal elongation
     dy: vertical elongation
-    angle: arrow angle [deg]
-    w: arrow shoulder
-    dip: input dip
+    outangle: output arrow angle [deg]
+    w: output arrow shoulder
+    inangle: input dip angle
     offset: text offset
     **kwargs: propagated to Patch (e.g. fill=False)
 
-    Return (patch,texts)."""
+    Return (patch,[intexts,outtexts])."""
 
-    assert sum(losses)==100, "Input losses don't sum up to 100%"
+    import matplotlib.patches as mpatches
+    from matplotlib.path import Path
 
-    def add_loss(loss, last=False):
-        h = (loss/2+w)*np.tan(angle/180.*np.pi) # Arrow tip height
+    outs = N.absolute(outputs)
+    outsigns = N.sign(outputs)
+    outsigns[-1] = 0                       # Last output
+
+    ins = N.absolute(inputs)
+    insigns = N.sign(inputs)
+    insigns[0] = 0                         # First input
+
+    assert sum(outs)==100, "Outputs don't sum up to 100%"
+    assert sum(ins)==100, "Inputs don't sum up to 100%"
+
+    def add_output(path, loss, sign=1):
+        h = (loss/2+w)*N.tan(outangle/180.*N.pi) # Arrow tip height
         move,(x,y) = path[-1]           # Use last point as reference
-        if last:                        # Final loss (horizontal)
+        if sign==0:                     # Final loss (horizontal)
             path.extend([(Path.LINETO,[x+dx,y]),
                          (Path.LINETO,[x+dx,y+w]),
                          (Path.LINETO,[x+dx+h,y-loss/2]), # Tip
                          (Path.LINETO,[x+dx,y-loss-w]),
                          (Path.LINETO,[x+dx,y-loss])])
-            tips.append(path[-3][1])
+            outtips.append((sign,path[-3][1]))
         else:                           # Intermediate loss (vertical)
-            path.extend([(Path.LINETO,[x+dx/2,y]),
-                        (Path.CURVE3,[x+dx,y]),
-                        (Path.CURVE3,[x+dx,y+dy]),
-                        (Path.LINETO,[x+dx-w,y+dy]),
-                        (Path.LINETO,[x+dx+loss/2,y+dy+h]), # Tip
-                        (Path.LINETO,[x+dx+loss+w,y+dy]),
-                        (Path.LINETO,[x+dx+loss,y+dy]),
-                        (Path.CURVE3,[x+dx+loss,y-loss]),
-                        (Path.CURVE3,[x+dx/2+loss,y-loss])])
-            tips.append(path[-5][1])
+            path.extend([(Path.CURVE4,[x+dx/2,y]),
+                         (Path.CURVE4,[x+dx,y]),
+                         (Path.CURVE4,[x+dx,y+sign*dy]),
+                         (Path.LINETO,[x+dx-w,y+sign*dy]),
+                         (Path.LINETO,[x+dx+loss/2,y+sign*(dy+h)]), # Tip
+                         (Path.LINETO,[x+dx+loss+w,y+sign*dy]),
+                         (Path.LINETO,[x+dx+loss,y+sign*dy]),
+                         (Path.CURVE3,[x+dx+loss,y-sign*loss]),
+                         (Path.CURVE3,[x+dx/2+loss,y-sign*loss])])
+            outtips.append((sign,path[-5][1]))
 
-    tips = []                           # Arrow tip positions
-    path = [(Path.MOVETO,[0,100])]      # 1st point
-    for i,loss in enumerate(losses):
-        add_loss(loss, last=(i==(len(losses)-1)))
-    path.extend([(Path.LINETO,[0,0]),
-                 (Path.LINETO,[dip,50]), # Dip
-                 (Path.CLOSEPOLY,[0,100])])
+    def add_input(path, gain, sign=1):
+        h = (gain/2)*N.tan(inangle/180.*N.pi) # Dip depth
+        move,(x,y) = path[-1]           # Use last point as reference
+        if sign==0:                     # First gain (horizontal)
+            path.extend([(Path.LINETO,[x-dx,y]),
+                         (Path.LINETO,[x-dx+h,y+gain/2]), # Dip
+                         (Path.LINETO,[x-dx,y+gain])])
+            xd,yd = path[-2][1]         # Dip position
+            indips.append((sign,[xd-h,yd]))
+        else:                           # Intermediate gain (vertical)
+            path.extend([(Path.CURVE4,[x-dx/2,y]),
+                         (Path.CURVE4,[x-dx,y]),
+                         (Path.CURVE4,[x-dx,y+sign*dy]),
+                         (Path.LINETO,[x-dx-gain/2,y+sign*(dy-h)]), # Dip
+                         (Path.LINETO,[x-dx-gain,y+sign*dy]),
+                         (Path.CURVE3,[x-dx-gain,y-sign*gain]),
+                         (Path.CURVE3,[x-dx/2-gain,y-sign*gain])])
+            xd,yd = path[-4][1]         # Dip position
+            indips.append((sign,[xd,yd+sign*h]))
+
+    outtips = []                        # Output arrow tip dir. and positions
+    urpath = [(Path.MOVETO,[0,100])]    # 1st point of upper right path
+    lrpath = [(Path.LINETO,[0,0])]      # 1st point of lower right path
+    for loss,sign in zip(outs,outsigns):
+        add_output(sign>=0 and urpath or lrpath, loss, sign=sign)
+
+    indips = []                         # Input arrow tip dir. and positions
+    llpath = [(Path.LINETO,[0,0])]      # 1st point of lower left path
+    ulpath = [(Path.MOVETO,[0,100])]    # 1st point of upper left path
+    for gain,sign in zip(ins,insigns)[::-1]:
+        add_input(sign<=0 and llpath or ulpath, gain, sign=sign)
+
+    def revert(path):
+        """A path is not just revertable by path[::-1] because of Bezier
+        curves."""
+        rpath = []
+        nextmove = Path.LINETO
+        for move,pos in path[::-1]:
+            rpath.append((nextmove,pos))
+            nextmove = move
+        return rpath
+
+    # Concatenate subpathes in correct order
+    path = urpath + revert(lrpath) + llpath + revert(ulpath)
+    
     codes,verts = zip(*path)
-    verts = np.array(verts)
+    verts = N.array(verts)
 
     # Path patch
     path = Path(verts,codes)
     patch = mpatches.PathPatch(path, **kwargs)
     ax.add_patch(patch)
 
+    if False:                           # DEBUG
+        print "urpath", urpath
+        print "lrpath", revert(lrpath)
+        print "llpath", llpath
+        print "ulpath", revert(ulpath)
+
+        xs,ys = zip(*verts)
+        ax.plot(xs,ys,'go-')
+
     # Labels
-    if labels=='':                      # No labels
-        pass
-    elif labels is None:                # Default labels
-        labels = [ '%2d%%' % loss for loss in losses ]
-    else:
-        assert len(labels)==len(losses)
 
-    texts = []
-    for i,label in enumerate(labels):
-        x,y = tips[i]                   # Label position
-        last = (i==(len(losses)-1))
-        if last:
-            t = ax.text(x+offset,y,label, ha='left', va='center')
+    def set_labels(labels,values):
+        """Set or check labels according to values."""
+        if labels=='':                   # No labels
+            return labels
+        elif labels is None:             # Default labels
+            return [ '%2d%%' % val for val in values ]
         else:
-            t = ax.text(x,y+offset,label, ha='center', va='bottom')
-        texts.append(t)
+            assert len(labels)==len(values)
+            return labels
 
+    def put_labels(labels,positions,output=True):
+        """Put labels to positions."""
+        texts = []
+        lbls = output and labels or labels[::-1]
+        for i,label in enumerate(lbls):
+            s,(x,y) = positions[i]      # Label direction and position
+            if s==0:
+                t = ax.text(x+offset,y,label,
+                            ha=output and 'left' or 'right', va='center')
+            elif s>0:
+                t = ax.text(x,y+offset,label, ha='center', va='bottom')
+            else:
+                t = ax.text(x,y-offset,label, ha='center', va='top')
+            texts.append(t)
+        return texts
+
+    outlabels = set_labels(outlabels, outs)
+    outtexts = put_labels(outlabels, outtips, output=True)
+
+    inlabels = set_labels(inlabels, ins)
+    intexts = put_labels(inlabels, indips, output=False)
+
     # Axes management
-    ax.set_xlim(verts[:,0].min()-10, verts[:,0].max()+40)
-    ax.set_ylim(verts[:,1].min()-10, verts[:,1].max()+20)
+    ax.set_xlim(verts[:,0].min()-dx, verts[:,0].max()+dx)
+    ax.set_ylim(verts[:,1].min()-dy, verts[:,1].max()+dy)
     ax.set_aspect('equal', adjustable='datalim')
-    ax.set_xticks([])
-    ax.set_yticks([])
 
-    return patch,texts
+    return patch,[intexts,outtexts]
 
 if __name__=='__main__':
 
-    losses = [10.,20.,5.,15.,10.,40.]
-    labels = ['First','Second','Third','Fourth','Fifth','Hurray!']
-    labels = [ s+'\n%d%%' % l for l,s in zip(losses,labels) ]
+    import matplotlib.pyplot as P
 
-    fig = plt.figure()
-    ax = fig.add_subplot(1,1,1)
+    outputs = [10.,-20.,5.,15.,-10.,40.]
+    outlabels = ['First','Second','Third','Fourth','Fifth','Hurray!']
+    outlabels = [ s+'\n%d%%' % abs(l) for l,s in zip(outputs,outlabels) ]
 
-    patch,texts = sankey(ax, losses, labels, fc='g', alpha=0.2)
-    texts[1].set_color('r')
-    texts[-1].set_fontweight('bold')
+    inputs = [60.,-25.,15.]
 
-    plt.show()
+    fig = P.figure()
+    ax = fig.add_subplot(1,1,1, xticks=[],yticks=[],
+                         title="Sankey diagram"
+                         )
+
+    patch,(intexts,outtexts) = sankey(ax, outputs=outputs, outlabels=outlabels,
+                                      inputs=inputs, inlabels=None,
+                                      fc='g', alpha=0.2)
+    outtexts[1].set_color('r')
+    outtexts[-1].set_fontweight('bold')
+
+    P.show()


This was sent by the SourceForge.net collaborative development platform, the 
world's largest Open Source development site.

------------------------------------------------------------------------------
SOLARIS 10 is the OS for Data Centers - provides features such as DTrace,
Predictive Self Healing and Award Winning ZFS. Get Solaris 10 NOW
http://p.sf.net/sfu/solaris-dev2dev
_______________________________________________
Matplotlib-checkins mailing list
Matplotlib-checkins@lists.sourceforge.net
https://lists.sourceforge.net/lists/listinfo/matplotlib-checkins

Reply via email to