from fipy import *

import math
from numpy import *
from pylab import *

import matplotlib
from matplotlib import pyplot

# Time-step size passed to every sweep of the transient solve.
dt = 0.2

# Uniform square mesh covering [0, xmax] x [0, xmax] with spacing dx.
# A domain symmetric about the origin, [-xmax, xmax]^2, could be built instead with
#     Grid2D(dx=dx, Lx=2*xmax, dy=dx, Ly=2*xmax) + [[-xmax], [-xmax]]
# (nx would then have to be doubled to match).
xmax = 4.5
dx = 0.1
# Number of cells per side of the mesh actually built below (45 here).
nx = int(round(xmax / dx))
mesh = Grid2D(dx=dx, dy=dx, nx=nx, ny=nx)

# Unknown field. hasOld=True keeps a previous-time-level copy, which the
# time loop refreshes with psi.updateOld() before each step.
psi = CellVariable(name="psi", mesh=mesh, hasOld=True)
# Closed-form reference solution; only filled in and displayed, never solved.
analytical = CellVariable(name="analytical", mesh=mesh)

# Cell-centre coordinates, used in the source term and the field initialisers.
x, y = mesh.cellCenters

# Face-centre coordinates, used to build the face-based convection velocity.
X, Y = mesh.faceCenters

# Pin the face gradient of psi to zero on the left boundary faces (x = 0).
psi.faceGrad.constrain(0, where=mesh.facesLeft)

# Face convection velocity: the vector (-Y, 0), obtained by broadcasting the
# face y-coordinate against the column (1, 0).
C = -Y * ((1.,), (0.,))

# d(psi)/dt = upwind convection with velocity C
#             + implicit source with coefficient -2 x y (y^2 + 1)
eqn1 = (
    TransientTerm(var=psi, coeff=1)
    == UpwindConvectionTerm(var=psi, coeff=C)
    + ImplicitSourceTerm(var=psi, coeff=-2 * x * y * (y**2 + 1))
)

# Name under which the time-stepping loop sweeps the equation.
eqn = eqn1

# Initial ansatz: the Gaussian exp(-(x^2 + y^2) / 2), intended to satisfy the BCs.
psi.setValue(exp(-(x**2 + y**2) / 2.))

# Reference solution exp(-(x^2 + 5)(y^2 + 1)), shown for visual comparison.
analytical.setValue(exp(-(x**2 + 5.) * (y**2 + 1.)))

# Quick sanity check: value of psi in the first cell.
print(psi[0])

def _show(prompt, var):
    """Wait for the user to press enter, then open a viewer on *var* and draw it.

    Returns the viewer so it can be redrawn later.
    Optional fixed colour limits: datamin=-0.25, datamax=1.25.
    """
    raw_input(prompt)
    viewer = Viewer(vars=var)
    viewer.plot()
    return viewer


vi2 = _show("press enter to see analytical solution", analytical)

# vi1 is redrawn by the time-stepping loop as psi evolves.
vi1 = _show("press enter to see initial function", psi)


raw_input("press enter to initialize")

# One LU solver instance, reused by every sweep instead of rebuilt per sweep.
solver = LinearLUSolver()

# A time step is considered converged once the sweep residual drops to this...
tolerance = 0.001
# ...and sweeping stops after this many sweeps regardless, so a residual that
# stalls above the tolerance cannot hang the run.
max_sweeps = 100

for step in range(100000):
    # Move forward in time: the current psi becomes the "old" value.
    psi.updateOld()

    # Sweep the equation until the returned residual is small enough.
    res = 100.
    sweeps = 0
    while res > tolerance and sweeps < max_sweeps:
        res = eqn.sweep(dt=dt, solver=solver)
        sweeps += 1

    if res > tolerance:
        print("step %d: residual %g still above %g after %d sweeps"
              % (step, res, tolerance, max_sweeps))

    # Redraw once per completed time step rather than after every sweep.
    vi1.plot()

raw_input("press enter to close")
