from pylab import axes
from numpy import dot,array
from math import cos, sin, atan2

def rotate(xy,theta):
    """Rotate 2-D coordinates about the origin by ``theta`` radians.

    ``xy`` is multiplied from the left by the standard 2x2 rotation
    matrix, so it may be a length-2 vector or a 2xN array whose first
    row holds x values and whose second row holds y values.  Positive
    ``theta`` rotates counter-clockwise in a right-handed x/y system.

    Returns the rotated coordinates as produced by ``numpy.dot``.
    """
    c = cos(theta)
    s = sin(theta)
    rotation = array([[c, -s],
                      [s,  c]])
    return dot(rotation, xy)
    
def arrow(x0,y0,x1,y1,size=1,ax=None,lw=1,color='k',tip=True,tail=False, head = None):
    """Draw an arrow from (x0, y0) to (x1, y1) with an undistorted head.

    The head outline is defined in a coordinate system relative to the
    arrow tip, rotated to the on-screen direction of the arrow and then
    scaled separately in x and y by the number of data units per inch,
    so its shape does not depend on the aspect ratio of the axes.  The
    scale is derived from the axes limits and figure size at call time;
    changing them afterwards is not compensated for.

    Parameters
    ----------
    x0, y0 : coordinates of the tail of the arrow (data units)
    x1, y1 : coordinates of the tip of the arrow (data units)
    size   : size of the arrow head (the head outline is scaled by 0.1*size)
    ax     : axes object to draw in; ``axes()`` is used when None
    lw     : linewidth of the shaft and the head
    color  : color of the shaft and the head
    tip    : draw a head at the tip (x1, y1)
    tail   : draw a (mirrored) head at the tail (x0, y0)
    head   : custom head outline, given relative to the tip as the matrix
             [[x0, x1, ..., xn],
              [y0, y1, ..., yn]]
             for an arrow pointing in the +x direction
    """
    if ax is None:
        ax = axes()

    head_scale = 0.1*size

    # Real size of the figure, as reported by the figure object
    fig = ax.get_figure()
    fig_w = fig.get_figwidth()
    fig_h = fig.get_figheight()

    # Size of the axes within the figure.  Current matplotlib returns a
    # Bbox here, which cannot be indexed; very old versions returned a
    # [left, bottom, width, height] sequence, so fall back to indexing.
    pos = ax.get_position()
    try:
        frac_w, frac_h = pos.width, pos.height
    except AttributeError:
        frac_w, frac_h = pos[2], pos[3]

    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()

    # Axis units per inch
    upix = (x_hi - x_lo)/(frac_w*fig_w)
    upiy = (y_hi - y_lo)/(frac_h*fig_h)

    # Direction of the arrow measured in inches rather than data units,
    # so the head lines up with the shaft as it is actually displayed.
    # atan2(0, 0) is 0, so a zero-length arrow is drawn pointing in +x
    # instead of raising.
    theta = atan2((y1 - y0)/upiy, (x1 - x0)/upix)

    # Default head: a simple open "V" seen relative to the tip
    if head is None:
        head = array([[-2,0,-2],
                      [ 1,0,-1]])

    # Get rotated head
    headm = rotate(head,theta)

    # Shaft
    ax.plot([x0,x1],[y0,y1],color=color,lw=lw)

    # Heads: convert the rotated outline from inches to data units
    head_dx = head_scale*upix*headm[0,:]
    head_dy = head_scale*upiy*headm[1,:]
    if tip:
        ax.plot(x1+head_dx, y1+head_dy,color=color,lw=lw)
    if tail:
        # Mirrored through the tail point so the head points away from the tip
        ax.plot(x0-head_dx, y0-head_dy,color=color,lw=lw)

def varrow(x0,y0,y1,bhead=False,ehead=True,hsx=None,hsy=None,ax=None):
    """Draw a vertical arrow at x = x0 running from y0 to y1.

    ``ehead`` requests a head at the end point (x0, y1) and ``bhead`` a
    head at the start point (x0, y0).  ``hsx`` and ``hsy`` are accepted
    but not used by this function.  ``ax`` is forwarded to ``arrow``.
    """
    arrow(x0, y0, x0, y1, ax=ax, tail=bhead, tip=ehead)
    
def harrow(x0,x1,y0,bhead=False,ehead=True,hsx=None,hsy=None,ax=None):
    """Draw a horizontal arrow at y = y0 running from x0 to x1.

    ``ehead`` requests a head at the end point (x1, y0) and ``bhead`` a
    head at the start point (x0, y0).  ``hsx`` and ``hsy`` are accepted
    but not used by this function.  ``ax`` is forwarded to ``arrow``.
    """
    arrow(x0, y0, x1, y0, ax=ax, tail=bhead, tip=ehead)


if __name__ == "__main__":
    # Visual demo: arrows drawn by this module fanning out from (1, 0)
    # next to pylab's own arrows fanning out from (5, 0).
    from pylab import savefig,axis,plot,show
    from pylab import arrow as pyarrow

    # Two marker points that fix the extent of the plot
    plot([0,8],[-0.5,1],'x')

    # Fan with varying color / linewidth, all ending at y = 0.5
    styled_tips = [
        (1,   {'color': 'r'}),
        (1.5, {'color': 'b'}),
        (2.0, {'color': 'g'}),
        (2.5, {'lw': 2}),
        (3.0, {'lw': 3}),
    ]
    for tip_x, style in styled_tips:
        arrow(1,0,tip_x,0.5,**style)

    # Fan with default style, all ending at x = 3.0
    for tip_y in (0.4, 0.3, 0.2, 0.1, 0.0):
        arrow(1,0,3.0,tip_y)

    # Head at the tail only
    arrow(1,1,3.0,0.8,tip=False,tail=True)

    # The same fan drawn with pylab's arrow (takes dx, dy instead of a tip)
    py_offsets = [
        (0, 0.5), (.5, 0.5), (1.0, 0.5), (1.5, 0.5), (2.0, 0.5),
        (2.0, 0.4), (2.0, 0.3), (2.0, 0.2), (2.0, 0.1), (2.0, 0.0),
    ]
    for off_x, off_y in py_offsets:
        pyarrow(5,0,off_x,off_y,head_width=0.2)

    # Double-headed vertical and horizontal arrows
    varrow(4,0,0.5,ehead=True,bhead=True)
    harrow(3.5,4.5,0.25,ehead=True,bhead=True)
    #savefig('test.eps')
    show()
