from pylab import axes
from math import cos, sin, atan2

def rotate(x,y,theta):
    """Rotate the point (x, y) about the origin by ``theta`` radians.

    Positive ``theta`` rotates counter-clockwise (standard rotation matrix
    [[cos, -sin], [sin, cos]]).  Returns the rotated point as an
    ``(x, y)`` tuple.
    """
    c = cos(theta)
    s = sin(theta)
    return x*c - y*s, y*c + x*s
    
def arrow(x0,y0,x1,y1,size=1,ax=None,lw=1,color='k'):
    """Draw an arrow from (x0, y0) to (x1, y1) with an undistorted head.

    The shaft and the head are drawn with ``ax.plot``.  The head is sized
    and oriented in physical (inch) units using the figure size, the axes
    position and the axis limits.  As a result it keeps its shape even when
    the x and y axes have very different data scales.  The conversion uses
    the limits in effect when this function runs.  If the limits or the
    figure size change afterwards, the head will look stretched.  Assumes
    linear axes (NOTE(review): log-scaled axes are not handled).

    Parameters
    ----------
    x0, y0 : float
        Tail of the arrow, in data coordinates.
    x1, y1 : float
        Tip of the arrow, in data coordinates.
    size : float, optional
        Head scale factor.  The head is a polyline from (-2, +1) through the
        tip to (-2, -1) in units of ``0.1*size`` inches.  Each side is
        therefore about ``0.22*size`` inches long.
    ax : matplotlib Axes, optional
        Axes to draw on.  Defaults to the current axes (``gca()``).
    lw : float, optional
        Line width used for both shaft and head.
    color : matplotlib color, optional
        Color used for both shaft and head.

    Returns
    -------
    list
        The line artists returned by ``ax.plot``: the shaft, followed by the
        head unless the arrow has zero length.
    """
    if ax is None:
        # pylab.axes() with no arguments adds a *new* Axes in current
        # matplotlib releases; gca() draws on the current Axes instead.
        from pylab import gca
        ax = gca()

    # Plot the shaft first so that, when autoscaling is on, the limits read
    # below already include it and the head is sized for the final view.
    lines = ax.plot([x0, x1], [y0, y1], color=color, lw=lw)

    dx = x1 - x0
    dy = y1 - y0
    if dx == 0 and dy == 0:
        # Zero-length arrow: the direction is undefined, so draw no head.
        return lines

    # Size of the axes box as a fraction of the figure.
    pos = ax.get_position()
    try:
        width_frac, height_frac = pos.width, pos.height
    except AttributeError:
        # Older matplotlib returned a [left, bottom, width, height] sequence
        # instead of a Bbox (which is not indexable).
        width_frac, height_frac = pos[2], pos[3]

    fig = ax.get_figure()
    xl = ax.get_xlim()
    yl = ax.get_ylim()

    # Axis (data) units per inch along x and along y.  These values are
    # negative for inverted axes, and the maths below stays consistent.
    upix = (xl[1] - xl[0]) / (width_frac * fig.get_figwidth())
    upiy = (yl[1] - yl[0]) / (height_frac * fig.get_figheight())

    # Direction of the arrow measured in inches, not data units, so that
    # unequal x/y scales do not skew the head.
    theta = atan2(dy / upiy, dx / upix)

    # Head template points, rotated to the arrow direction.  The scale is
    # s inches per template unit; each offset is converted back to data units.
    s = 0.1 * size
    headx0, heady0 = rotate(-2, 1, theta)
    headx1, heady1 = rotate(-2, -1, theta)

    lines += ax.plot([x1 + s*headx0*upix, x1, x1 + s*headx1*upix],
                     [y1 + s*heady0*upiy, y1, y1 + s*heady1*upiy],
                     color=color, lw=lw)
    return lines

if __name__ == "__main__":
    # Visual comparison: this module's arrow() on the left, matplotlib's
    # built-in pylab.arrow (signature x, y, dx, dy) on the right.
    from pylab import plot, show, savefig
    from pylab import arrow as pyarrow

    # Marker points spanning x 0..8 and y -0.5..1 so that both groups of
    # arrows fall inside the autoscaled view before any head is sized.
    plot([0,8],[-0.5,1],'x')

    # Left group: every arrow starts at (1, 0).
    arrow(1,0,1,0.5,color='r')
    arrow(1,0,1.5,0.5,color='b')
    arrow(1,0,2.0,0.5,color='g')
    arrow(1,0,2.5,0.5,lw=2)
    arrow(1,0,3.0,0.5,lw=3)
    for tip_y in (0.4, 0.3, 0.2, 0.1, 0.0):
        arrow(1,0,3.0,tip_y)

    # Right group: built-in arrows starting at (5, 0), given as offsets.
    offsets = [(0, 0.5), (.5, 0.5), (1.0, 0.5), (1.5, 0.5), (2.0, 0.5),
               (2.0, 0.4), (2.0, 0.3), (2.0, 0.2), (2.0, 0.1), (2.0, 0.0)]
    for off_x, off_y in offsets:
        pyarrow(5,0,off_x,off_y,head_width=0.2)

    # savefig('test.eps')  # uncomment to also write the figure to disk
    show()
