#!/usr/bin/env python

import math
import pyx
from pyx import graph, path, text, color, deco, style, unit
from pyx.graph.axis.painter import ticklength


#ax =  graph.axis.linear(min=0.0, max=1.0)
#ay =  graph.axis.linear(min=0.0, max=1.0)
#g = graph.graphxy(width=8, height=4, x=ax, y=ay)
#g.plot(graph.data.function("y(x)=sin(x)/x", min=0.0, max=1.0))
#g.dolayout()

def func(x):
  """Return the cubic (x-3)(x-5)(x-7), the curve drawn as the solution."""
  return (x-3.0)*(x-5.0)*(x-7.0)

# The same cubic written as an expression string for graph.data.function.
func_s = "y(x)=(x-3)*(x-5)*(x-7)"

def dfunc(x):
  """Return the derivative of func at x, expanded with the product rule."""
  a = x-3.0
  b = x-5.0
  c = x-7.0
  # One term per factor that is differentiated away.
  return b*c + a*c + a*b

def m2a(m):
  """Convert a slope m into its inclination angle in degrees.

  The result lies strictly between -90 and 90 because it is derived from
  math.atan.
  """
  half_turns = math.atan(m)/math.pi
  return 180*half_turns

# Horizontal data range [X0, X1] of the figure and its extent X.
X0, X1 = 0.0, 10.0
X = X1-X0

# Vertical data range [Y0, Y1] and its extent Y.  The limits are the values of
# func at the two ends of the x-range, where this cubic takes its minimum
# (left end) and maximum (right end) over [X0, X1].
Y0, Y1 = func(X0), func(X1)
Y = Y1-Y0

# Drawn size of the graph, passed to graph.graphxy.
width, height = 8, 4

# Ratio of the y data scale to the x data scale on the page: dividing a slope
# given in data units by this value yields the slope as it appears when drawn.
aspect = Y*width/(X*height)

# Sample the direction field at the centres of an Nx-by-Ny grid of cells that
# covers [X0, X1] x [Y0, Y1].  Every entry of arrow_data is a tuple
# (x, y, size, angle), with the angle given in degrees.
Nx = 10
Ny = 10
arrow_size = 1.0  # identical for every arrow
arrow_data = []
x = X0+X/Nx/2.0
for _col in range(Nx):
  # Quantities that depend on x alone are evaluated once per column instead
  # of once per grid point.
  curve_y = func(x)
  # Dividing by aspect turns the data-space slope into the slope as drawn.
  curve_angle = m2a(dfunc(x)/aspect)
  y = Y0+Y/Ny/2.0
  for _row in range(Ny):
    # Direction used far away from the curve; it depends on y only.
    far_angle = m2a(-y/10.0)
    # Blend weight in [0, 1): zero on the curve itself and growing with the
    # vertical distance between the grid point and the curve.
    blend = math.atan( abs(y-curve_y)/100.0 ) * 2 / math.pi
    angle = blend*far_angle + (1-blend)*curve_angle
    arrow_data.append( (x, y, arrow_size, angle) )
    y += Y/Ny
  x += X/Nx

# x0 is the x-position of the initial value and of the only x-tick.  x1 marks
# a second tick position that is not drawn at present.
x0, x1 = 1.5, 9 # x-ticks

# Write two figures: vfield0 shows the arrow field and the initial value,
# vfield1 additionally shows the solution curve through that point.
for i in range(2):
  # Axis painter whose base line ends in an arrow head.
  # NOTE(review): the two positional arguments are presumably the inner and
  # outer tick lengths -- confirm against the PyX painter signature.
  axis_painter = graph.axis.painter.regular(
    ticklength.normal, ticklength.normal,
    basepathattrs=[deco.earrow.normal],
    titlepos=0.98, titledirection=None)
  # NOTE(review): the x-tick at x0 is labelled "$0$" while the y-tick is
  # labelled "$y_0$" -- confirm that this label is intended.
  xticks = [graph.axis.tick.tick(x0, label="$0$")]
  yticks = [graph.axis.tick.tick(func(x0), label="$y_0$")]
  # Only the manual ticks above are used (parter=None); the x2/y2 axes are
  # switched off.
  g = graph.graphxy(width=width, height=height, x2=None, y2=None,
    x=graph.axis.linear(title="", min=X0, max=X1,
      manualticks=xticks,
      parter=None, painter=axis_painter),
    y=graph.axis.linear(title="", min=Y0, max=Y1,
      manualticks=yticks, parter=None, painter=axis_painter))
  # Label the x-axis by hand at the graph position (X1, Y0).
  xpos, ypos = g.pos( X1, Y0 )
  g.text( xpos, ypos, "$x$" )
  # plot arrows: columns 1-4 of arrow_data are x, y, size and angle
  g.plot(
    graph.data.list( arrow_data, x=1, y=2, size=3, angle=4 ),
    [ graph.style.arrow(
      linelength=0.25*unit.v_cm,
      arrowsize=0.10*unit.v_cm,
      lineattrs=[style.linewidth.THIN], arrowattrs=[] ) ])
  # initial value
  g.plot( graph.data.list( [ (x0, func(x0)), ], x=1, y=2 ) )
  if i==1:
    # plot solution
    g.plot(graph.data.function(func_s))

  g.finish()

  filename = "vfield%d"%i
  # The parenthesised form prints the same text on Python 2 and Python 3.
  print(filename)
  g.writeEPSfile(filename)





