from fipy import *

import math
from numpy import *
from pylab import *

import matplotlib
from matplotlib import pyplot

# Time step used by every sweep of the transient solve.
dt = 0.002

# Square domain [-xmax, xmax] x [-xmax, xmax] covered by uniform cells of
# width dx.
xmax = 4.5
dx = 0.2
# Number of cells per side.  2*xmax/dx is a float (45.0); round it to an int
# explicitly so a value like 44.999... can never silently lose a cell, and
# pass it to the mesh so the cell count is stated rather than derived.
nx = int(round(2*xmax/dx))
# Build the grid on [0, 2*xmax]^2, then shift it so it is centred on the origin.
mesh = Grid2D(dx=dx, nx=nx, dy=dx, ny=nx) + [[-xmax], [-xmax]]

psi = CellVariable(name="psi", mesh=mesh, hasOld=True)
analytical = CellVariable(name="analytical", mesh=mesh)

# Cell-centre coordinates (used by the source coefficient and initial data).
x = mesh.getCellCenters()[0]
y = mesh.getCellCenters()[1]

# Face-centre coordinates (the convection coefficient is defined on faces).
X = mesh.getFaceCenters()[0]
Y = mesh.getFaceCenters()[1]

# Convection coefficient (-y, 0): the face y-values times the column
# ((1,), (0,)) give a vector-valued coefficient pointing along x.
C = -Y*((1.,), (0.,))

# d(psi)/dt = div(C*psi) - 2*x*y*(y**2 + 1)*psi
# With C = (-y, 0), div(C*psi) = -y * d(psi)/dx.  For
# psi = exp(-(x**2+1)*(y**2+1)) that equals 2*x*y*(y**2+1)*psi, which cancels
# the source term, so that field is a steady state of this equation.
eqn1 = (TransientTerm(coeff=1, var=psi) ==
        CentralDifferenceConvectionTerm(coeff=C, var=psi)
        + ImplicitSourceTerm(coeff=-2*x*y*(y**2+1), var=psi))

eqn = eqn1

# Initial guess for psi: a Gaussian that is already tiny (order 1e-4 or
# less) near the domain edges, so it is approximately consistent with the
# boundaries (no boundary conditions are set explicitly in this script).
psi.setValue(exp((-x**2-y**2)/2.))
# Analytical steady-state solution, for visual comparison with psi.
analytical.setValue(exp(-(x**2+1.)*(y**2+1.)))

print(psi[0])

# Viewer takes a tuple of variables; note the trailing comma, without which
# the parentheses are just grouping.  To fix the colour scale, add
# datamin=-0.25, datamax=1.25 to the Viewer calls.
raw_input("press enter to see analytical solution")
vi2 = Viewer(vars=(analytical,))
vi2.plot()


raw_input("press enter to see initial function")
vi1 = Viewer(vars=(psi,))
vi1.plot()


raw_input("press enter to initialize")

# Time-stepping controls.
steps = 10          # number of time steps of size dt to take
res_tol = 0.001     # stop sweeping a step once the residual drops below this
max_sweeps = 100    # cap so a step that never converges cannot hang the script

# The solver settings never change, so build the solver once rather than on
# every sweep.
solver = GeneralSolver(iterations=20, tolerance=1e-5)

for step in range(steps):
    # Move forward in time by one time step: the current psi becomes psi.old.
    psi.updateOld()

    # Sweep until the residual is below res_tol, giving up after max_sweeps.
    res = 100.
    sweeps = 0
    while res > res_tol and sweeps < max_sweeps:
        res = eqn.sweep(dt=dt, solver=solver)
        sweeps += 1
        print(res)
        vi1.plot()
    if res > res_tol:
        print("warning: step %d did not converge after %d sweeps (residual %g)"
              % (step, sweeps, res))

raw_input("press enter to close")
