"""
Copyright (C) 2008, Andre Weideman
"""

import numpy as np
from numpy.fft import fft, fftshift
from numpy import pi, sqrt, tan

def cef(z, N=100):
    """Approximate the complex error function (Faddeeva function).

    A rational series approximation of

    w(z) = exp(-z^2) erfc(-iz)

    is calculated.  The series coefficients are obtained from an FFT of
    samples of exp(-t^2) (L^2 + t^2) taken at t = L tan(theta/2).

    Parameters
    ----------
    z : complex scalar or array_like
        Evaluation point(s).  The approximation is only valid for
        Im(z) > 0 or Im(z) = 0.  Values in the lower half-plane can be
        obtained from the reflection formula w(-z) = 2 exp(-z^2) - w(z).
    N : int, optional
        Number of terms in the rational series (default 100).  Must be a
        positive integer.

    Returns
    -------
    w : complex scalar or complex ndarray
        Approximation of w(z), with the same shape as `z`.

    Raises
    ------
    ValueError
        If N < 1 (the series would be empty and L would be zero).
    """
    if N < 1:
        raise ValueError("N must be a positive integer, got %r" % (N,))
    # Accept scalars, lists and tuples as well as ndarrays.
    z = np.asarray(z, dtype=complex)

    M, M2 = 2*N, 4*N                       # M2 is nr. of sampling points
    k = np.arange(-M+1, M)
    L = sqrt(N/sqrt(2))                    # Optimal choice of L.
    theta = k*pi/M
    t = L*tan(theta/2.)
    # Function to be transformed.  The leading 0 fills the k = -M slot
    # (theta = -pi, t = -inf), where exp(-t^2)(L^2+t^2) tends to 0.
    f = np.hstack([0, np.exp(-t**2) * (L**2+t**2)])
    # f is real and even, so its DFT is real in exact arithmetic; keep only
    # the real part so round-off imaginary noise does not pollute w(z).
    a = fft(fftshift(f)).real/M2           # Coefficients of transform.
    # a[1..N] multiply Z^0..Z^(N-1); np.polyval wants highest power first.
    a = a[1:N+1][::-1]
    Z = (L+1j*z)/(L-1j*z)
    p = np.polyval(a, Z)

    return 2*p/(L-1j*z)**2 + (1/sqrt(pi))/(L-1j*z)  # Evaluate w(z).

if __name__ == "__main__":
    from matplotlib import pyplot as plt

    N = 64
    step = 0.005

    # Sample the first quadrant (0 <= Re z < 4, 0 <= Im z < 4), where cef
    # is valid.
    yy, xx = np.mgrid[0:4:step, 0:4:step]
    z = xx + 1j*yy
    w = cef(z, N)

    # Reflection formula w(-z) = 2 exp(-z^2) - w(z) gives w at -z (third
    # quadrant).  Because w(-conj(z)) = conj(w(z)), |w(-x-iy)| equals
    # |w(x-iy)|, so after flipping the rows these values are |w| on the
    # lower half (Im < 0) of the plotted strip.
    ww = 2*np.exp(-z**2) - w
    ww = ww[::-1]
    W = np.log(np.abs(np.vstack((ww, w))))

    # With extent given, row 0 of W is drawn at Im = -4 and the last row at
    # Im = 4.
    plt.contourf(W, 40, extent=[0, 4, -4, 4])
    plt.axis('image')
    plt.title('Level curves of the Complex Error Function, N=%s' % N)
    plt.xlabel('Real')
    plt.ylabel('Imag')
    plt.show()
