-- Copyright (c) 2003 Matthew P. Donadio (m.p.donadio@ieee.org)
--
-- This program is free software; you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation; either version 2 of the License, or
-- (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

Recursive implementation of a Fast Fourier Transform.

TODO: Winograd Fourier Transform Algorithm (WFTA)

TODO: MR-DIT

TODO: SR-DIT

TODO: Exponent (-1) SR-DIT

TODO: Work on list/array optimization in fft_r2dit and fft_r2dif.

TODO: Work on list/array optimization in ct1 and ct2.

TODO: Algorithm for 2N-point real FFT^-1 computed with N-point complex FFT

TODO: Algorithm for 2 N-point real FFT's computed with N-point complex FFT

TODO: Lyons's book derives the 2N-point real FFT by separating out the
real and imaginary parts.  The version written with complex
arithmetic is in P&M.

> module FFT2 where

> import Array
> import Complex

> import DFT

-------------------------------------------------------------------------------

> -- | Forward DFT of a zero-based complex array.  The algorithm is picked
> -- from the length: hard-coded kernels for 1 to 4 points, radix-2 DIF for
> -- larger powers of two, the direct 'dft' when 'choose_factor' finds no
> -- factor above 1 (e.g. prime lengths), the prime factor algorithm when
> -- the factors l and n/l are coprime, and Cooley-Tukey otherwise.
> fft a = case size of
>           1 -> a
>           2 -> fft'2 a
>           3 -> fft'3 a
>           4 -> fft'4 a
>           _ | is_power2 size -> fft_r2dif a
>             | otherwise      -> composite (choose_factor size)
>     where size = snd (bounds a) + 1
>           -- split size = p * q and hand off to the matching algorithm
>           composite p
>               | p == 1       = dft a -- rader a size
>               | gcd p q == 1 = pfa a p q
>               | otherwise    = ct1 a p q
>               where q = size `div` p

> -- | True when n equals 2^k for some k >= 1.  Note that 1 is not counted
> -- (the result is False); 'fft' handles length 1 before calling this.
> --
> -- Candidate powers are produced by doubling.  With a fixed-width type
> -- such as Int the doubling eventually overflows and stops increasing.
> -- The search therefore ends as soon as doubling fails to grow the
> -- candidate, instead of looping forever on large non-powers.
> is_power2 n = go 2
>     where go t | n == t     = True
>                | n <  t     = False
>                | 2 * t <= t = False   -- doubling overflowed: no larger power exists
>                | otherwise  = go (2 * t)

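choose_factor (with its helpers choose1 and choose2) is borrowed from
FFTW.  It prefers the largest factor l <= sqrt n that is coprime to n/l,
so that the prime factor algorithm applies.  Otherwise it takes the
largest factor l <= sqrt n.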

> -- | Largest divisor i of n with i*i <= n and gcd i (n/i) == 1, i.e. the
> -- largest "small" factor that is coprime to its cofactor; 1 if none.
> choose1 n = last (1 : coprimeSplits)
>     where coprimeSplits = [ i | i <- takeWhile (\k -> k * k <= n) [1 ..]
>                               , n `mod` i == 0
>                               , gcd i (n `div` i) == 1 ]

> -- | Largest divisor of n that does not exceed sqrt n (1 when n is prime).
> choose2 n = last (1 : [ i | i <- takeWhile (\k -> k * k <= n) [1 ..], n `mod` i == 0 ])

> -- | Factor l used to split an n-point transform into l x (n/l).  A
> -- coprime split (which enables 'pfa') is preferred; otherwise any factor
> -- up to sqrt n is used.  The result is 1 when no factor above 1 exists.
> choose_factor n = let coprime = choose1 n
>                   in  if coprime > 1 then coprime else choose2 n

-------------------------------------------------------------------------------

This is a recursive implementation of the FFT.  It is a radix-2
decimation-in-time (DIT) FFT: the input is split into its even- and
odd-indexed samples.  This is the special case of the Cooley-Tukey
algorithm for N = 2^v.

This algorithm was taken from Cormen, Leiserson, and Rivest's
_Introduction to Algorithms_; we added hard-coded kernels for N = 2 and N = 4.

> -- | Recursive radix-2 decimation-in-time FFT.  The length must be a
> -- power of two.  The input is split into even- and odd-indexed halves,
> -- and each half is transformed recursively.  The halves are then
> -- recombined with the butterfly
> --   y[k] = y0[k] + w^k y1[k],  y[k+n/2] = y0[k] - w^k y1[k].
> --
> -- Each twiddle w^k is computed directly with 'cis' rather than by
> -- repeatedly multiplying by w, so rounding error does not accumulate as
> -- k grows.
> fft_r2dit a | n==1      = a
>             | n==2      = fft'2 a
>             | n==4      = fft'4 a
>             | otherwise = y
>     where w  = listArray (0,n2-1) [ cis (-2 * pi * fromIntegral k / fromIntegral n) | k <- [0..(n2-1)] ]
>           a0 = listArray (0,n2-1) [ a!k | k <- [0,2..(n-1)] ]
>           a1 = listArray (0,n2-1) [ a!k | k <- [1,3..(n-1)] ]
>           y0 = fft_r2dit a0
>           y1 = fft_r2dit a1
>           y  = listArray (0,n-1) ([ y0!k + w!k * y1!k | k <- [0..(n2-1)] ] ++
>                                   [ y0!k - w!k * y1!k | k <- [0..(n2-1)] ])
>           n  = snd (bounds a) + 1
>           n2 = n `div` 2

-------------------------------------------------------------------------------

Radix-2 Decimation-in-Frequency (DIF) FFT

> -- | Recursive radix-2 decimation-in-frequency FFT.  The length must be a
> -- power of two.  The sums a[k] + a[k+n/2] give the even output bins.
> -- The twiddled differences (a[k] - a[k+n/2]) * w^k give the odd bins.
> -- The two half-length results are then interleaved.
> --
> -- Twiddles are computed directly with 'cis' (no error-accumulating
> -- recurrence).  The interleave uses 'zipWith', which is total; the old
> -- local helper had no clause for lists of unequal length.
> fft_r2dif a | n==1      = a
>             | n==2      = fft'2 a
>             | n==4      = fft'4 a
>             | otherwise = y
>     where w  = listArray (0,n2-1) [ cis (-2 * pi * fromIntegral k / fromIntegral n) | k <- [0..(n2-1)] ]
>           ae = listArray (0,n2-1) [  a!k + a!(k+n2)        | k <- [0..(n2-1)] ]
>           ao = listArray (0,n2-1) [ (a!k - a!(k+n2)) * w!k | k <- [0..(n2-1)] ]
>           ye = fft_r2dif ae
>           yo = fft_r2dif ao
>           y  = listArray (0,n-1) (concat (zipWith (\e o -> [e, o]) (elems ye) (elems yo)))
>           n  = snd (bounds a) + 1
>           n2 = n `div` 2

-------------------------------------------------------------------------------

Cooley-Tukey algorithm

Cooley-Tukey algorithm doing row FFT's then column FFT's

> -- | Cooley-Tukey FFT of length n = l*m.  It performs row FFTs (length m),
> -- multiplies by the twiddles W_n^(i*j), then performs column FFTs
> -- (length l).  Input index n1 + l*n2 goes to grid cell (n1,n2).  Result
> -- cell (k1,k2) goes to output index m*k1 + k2 (see 'ct_index_map1').
> ct1 a l m = array (0, n-1) (zip outIdx (elems spectrum))
>     where n           = l * m
>           dims        = ((0,0), (l-1,m-1))
>           (inIdx, outIdx) = ct_index_map1 l m
>           grid        = listArray dims (map (a !) inIdx)
>           rowFFT      = listArray dims (flatten_rows (map fft (rows grid)))
>           twiddled    = listArray dims [ rowFFT!(i,j) * twiddle i j | i <- [0..(l-1)], j <- [0..(m-1)] ]
>           spectrum    = listArray dims (flatten_cols l (map fft (cols twiddled)))
>           twiddle i j = cis (-2 * pi * fromIntegral (i*j) / fromIntegral n)

Cooley-Tukey algorithm doing column FFT's then row FFT's

> -- | Cooley-Tukey FFT of length n = l*m in the other order: column FFTs
> -- (length l), the twiddles W_n^(i*j), then row FFTs (length m).  Input
> -- index m*n1 + n2 goes to grid cell (n1,n2).  Result cell (k1,k2) goes
> -- to output index k1 + l*k2 (see 'ct_index_map2').
> ct2 a l m = array (0, n-1) (zip outIdx (elems spectrum))
>     where n           = l * m
>           dims        = ((0,0), (l-1,m-1))
>           (inIdx, outIdx) = ct_index_map2 l m
>           grid        = listArray dims (map (a !) inIdx)
>           colFFT      = listArray dims (flatten_cols l (map fft (cols grid)))
>           twiddled    = listArray dims [ colFFT!(i,j) * twiddle i j | i <- [0..(l-1)], j <- [0..(m-1)] ]
>           spectrum    = listArray dims (flatten_rows (map fft (rows twiddled)))
>           twiddle i j = cis (-2 * pi * fromIntegral (i*j) / fromIntegral n)

> -- | Index maps for 'ct1' with n = l*m.  Both lists are in row-major
> -- order (first index outer): input positions n1 + l*n2 and output
> -- positions m*k1 + k2.
> ct_index_map1 l m = (inputs, outputs)
>     where pairs   = [ (p, q) | p <- [0..(l-1)], q <- [0..(m-1)] ]
>           inputs  = map (\(n1, n2) -> n1 + l * n2) pairs
>           outputs = map (\(k1, k2) -> m * k1 + k2) pairs

> -- | Index maps for 'ct2' with n = l*m.  Both lists are in row-major
> -- order (first index outer): input positions m*n1 + n2 and output
> -- positions k1 + l*k2.
> ct_index_map2 l m = (inputs, outputs)
>     where pairs   = [ (p, q) | p <- [0..(l-1)], q <- [0..(m-1)] ]
>           inputs  = map (\(n1, n2) -> m * n1 + n2) pairs
>           outputs = map (\(k1, k2) -> k1 + l * k2) pairs

> -- | FFT of row i of a 2-D array.  Columns are read from 0 up to the upper
> -- column bound, so the lower bounds are assumed to be (0,0).
> row_dft x i = fft (listArray (0, hi) (map (\j -> x!(i,j)) [0..hi]))
>     where (_, (_, hi)) = bounds x

> -- | FFT of column j of a 2-D array.  Rows are read from 0 up to the upper
> -- row bound, so the lower bounds are assumed to be (0,0).
> col_dft x j = fft (listArray (0, hi) (map (\i -> x!(i,j)) [0..hi]))
>     where ((_, _), (hi, _)) = bounds x

> -- | The rows of a 2-D array with lower bounds (0,0), each returned as a
> -- zero-based 1-D array.
> rows x = map row [0..lastRow]
>     where (_, (lastRow, lastCol)) = bounds x
>           row i = listArray (0, lastCol) [ x!(i,j) | j <- [0..lastCol] ]

> -- | The columns of a 2-D array with lower bounds (0,0), each returned as
> -- a zero-based 1-D array.
> cols x = map column [0..lastCol]
>     where (_, (lastRow, lastCol)) = bounds x
>           column j = listArray (0, lastRow) [ x!(i,j) | i <- [0..lastRow] ]

> flatten_rows a = foldr (++) [] (map elems a)

> flatten_cols l a = foldr (++) [] (map (flatten_cols' a) [0..l])

> -- | Element i of each array in the list, in list order.
> flatten_cols' arrs i = map (! i) arrs

-------------------------------------------------------------------------------

Prime Factor Algorithm

> -- | Prime factor (Good-Thomas) FFT of length n = l*m, assuming
> -- gcd l m == 1.  With the index maps from 'pfa_index_map' the 2-D
> -- decomposition needs no twiddle factors.  Only row FFTs (length m)
> -- followed by column FFTs (length l) are performed.
> pfa a l m = array (0,n-1) $ zip ks (elems x')
>     where x  = listArray ((0,0),(l-1,m-1)) [ a!i | i <- xs ]
>           f  = listArray ((0,0),(l-1,m-1)) (flatten_rows $ map fft $ rows x)
>           x' = listArray ((0,0),(l-1,m-1)) (flatten_cols l $ map fft $ cols f)
>           (xs,ks) = pfa_index_map l m
>           n  = l * m

> -- | Good-Thomas index maps for n = l*m with gcd l m == 1, in row-major
> -- order over (n1,n2) and (k1,k2):
> --   inputs  (m*n1 + l*n2) mod n
> --   outputs (c*m*k1 + d*l*k2) mod n, where c = m^-1 mod l and d = l^-1 mod m.
> pfa_index_map l m = (map inputAt grid, map outputAt grid)
>     where n    = l * m
>           grid = [ (i, j) | i <- [0..(l-1)], j <- [0..(m-1)] ]
>           inputAt  (n1, n2) = (m * n1 + l * n2) `mod` n
>           outputAt (k1, k2) = (c * m * k1 + d * l * k2) `mod` n
>           c    = find_inverse m l
>           d    = find_inverse l m

> -- | Multiplicative inverse of a modulo n: the a' in [0, n-1] with
> -- (a * a') `mod` n == 1 `mod` n.  It uses the extended Euclidean
> -- algorithm, which is O(log n).  If gcd a n /= 1 there is no inverse and
> -- 'error' is called.  For n == 1 the result is 0.
> find_inverse a n
>     | g /= 1    = error "find_inverse: no inverse (gcd a n /= 1)"
>     | otherwise = s `mod` n
>     where (g, s, _) = egcd (a `mod` n) n
>           -- egcd x y = (gcd x y, s, t) with s*x + t*y == gcd x y
>           egcd x 0 = (x, 1, 0)
>           egcd x y = let (d, s', t') = egcd y (x `mod` y)
>                      in  (d, t', s' - (x `div` y) * t')

-------------------------------------------------------------------------------

Rader's Algorithm

> -- | Rader's DFT of the zero-based array f of length n.  It assumes n is
> -- prime, since a primitive root is required (see 'generator').  The
> -- former size test (n <= 31 / otherwise) called 'rader2' in both
> -- branches, so the redundant guard is gone.
> -- NOTE(review): the small-n branch was presumably meant for 'rader1'
> -- (direct convolution), which is unfinished.
> rader f n = rader2 f n

Rader's Algorithm using direct convolution

> -- | Debug helper: the powers a^0, a^1, ..., a^(n-2) modulo n of
> -- a = generator n, i.e. the index permutation used by Rader's algorithm.
> -- Powers are reduced modulo n at every step, so no intermediate value
> -- exceeds n^2.  The previous (a ^ i) `mod` n formed a^i exactly and
> -- overflowed fixed-width Int once a^(n-2) exceeded maxBound.
> foo n = [ a ^* i | i <- [0..(n-2)] ]
>     where b ^* e = powMod (b `mod` n) e
>           -- modular exponentiation by repeated squaring
>           powMod _ 0 = 1 `mod` n
>           powMod b e
>               | even e    = let h = powMod b (e `div` 2) in h * h `mod` n
>               | otherwise = b * powMod b (e - 1) `mod` n
>           a = generator n

> -- | Rader's algorithm via direct O(n^2) cyclic convolution.  NOT YET
> -- IMPLEMENTED: the input f is ignored and the debug permutation
> -- 'foo' n is returned.  The intended computation, for prime n, is:
> --
> --   a  = generator n
> --   w i = cis (-2 * pi * fromIntegral i / fromIntegral n)
> --   h  = listArray (0,n-2) [ f!(a ^* (n-(1+n'))) | n' <- [0..(n-2)] ]
> --   g  = listArray (0,n-2) [ w (a ^* n') | n' <- [0..(n-2)] ]
> --   f' = array (0,n-1) ((0, sum [ f!i | i <- [0..(n-1)] ]) :
> --          [ (a ^* i, f!0 + sum [ h!j * g!((i-j) `mod` (n-1)) | j <- [0..(n-2)] ])
> --          | i <- [0..(n-2)] ])
> --
> -- where b ^* e is b^e modulo n.  It should be computed with per-step
> -- reduction (as in 'rader2'), because forming b^e exactly overflows Int.
> rader1 f n = foo n

Rader's Algorithm using FFT convolution

> -- | Rader's algorithm for a prime length n.  The non-zero indices are
> -- reordered by powers of a primitive root a = generator n.  This turns
> -- X[a^i] - f[0] into a cyclic convolution of length n-1, which is
> -- evaluated with 'fft' and 'ifft'.  X[0] is the plain sum of the inputs.
> --
> -- Powers a^e are reduced modulo n at every step.  Forming a^e exactly,
> -- as (a ^ e) `mod` n did, overflows fixed-width Int once a^(n-2)
> -- exceeds maxBound, which silently produced wrong indices.
> rader2 f n = f'
>     where h   = listArray (0,n-2) [ f!(a ^* (n-(1+n'))) | n' <- [0..(n-2)] ]
>           g   = listArray (0,n-2) [ w (a ^* n') | n' <- [0..(n-2)] ]
>           h'  = fft h
>           g'  = fft g
>           hg' = listArray (0,n-2) [ h'!i * g'!i | i <- [0..(n-2)] ]
>           hg  = ifft hg'
>           f'  = array (0,n-1) ((0, sum [ f!i | i <- [0..(n-1)] ]) : [ (a ^* i, f!0 + hg!i) | i <- [0..(n-2)] ])
>           w i = cis (-2 * pi * fromIntegral i / fromIntegral n)
>           b ^* e = powMod (b `mod` n) e
>           -- modular exponentiation by repeated squaring
>           powMod _ 0 = 1 `mod` n
>           powMod b e
>               | even e    = let s = powMod b (e `div` 2) in s * s `mod` n
>               | otherwise = b * powMod b (e - 1) `mod` n
>           a   = generator n

Haskell translation of find_generator from FFTW

> -- | Smallest primitive root modulo p: the first x in 1, 2, ... whose
> -- multiplicative order is p-1.  Translated from FFTW's find_generator.
> -- If the candidates wrap around to 0 without success, 'error' is called.
> -- The order search is capped at p steps.  For composite p, a non-unit
> -- candidate never reaches 1, and an uncapped search would loop forever
> -- instead of reaching the error.
> generator p = findgen 1
>   where findgen 0 = error "rader: generator: no primative root?"
>         findgen x | period x == p - 1 = x
>                   | otherwise         = findgen ((x + 1) `mod` p)
>         -- smallest k >= 1 with x^k == 1 (mod p), or 0 if none within p steps
>         period x = go 1 x
>           where go k q | q == 1    = k
>                        | k >= p    = 0
>                        | otherwise = go (k + 1) (q * x `mod` p)

-------------------------------------------------------------------------------

We want to define the inverse and real-valued FFTs in terms of the
forward complex FFT.  This way, if we implement a speedup, we only
have to do it in one place.  Personally, I don't like adding a sign
argument to the FFT to signify forward versus inverse transforms.

x(n) = 1/N * ~(fft ~X(k))
  where X(k) = fft(x(n))
        ~z   = conjugate z   (complex conjugation)
        N    = length x

P&M and Rick Lyons's books have the derivation.

ifft a = fmap (/ fromIntegral n) $ fmap conjugate $ fft $ fmap conjugate a
  where n = snd (bounds a) + 1

We can also replace complex conjugation by swapping the real and
imaginary parts (before and after the forward FFT) and get the same
result.  Rick Lyons's book has the derivation.

> -- | Inverse FFT built from the forward 'fft'.  The real and imaginary
> -- parts are swapped before and after the forward transform, which gives
> -- the same result as conjugating before and after.  The output is then
> -- scaled by 1/N, where N is the number of samples.
> ifft a = fmap (scale . swap) (fft (fmap swap a))
>   where n               = snd (bounds a) + 1
>         scale z         = z / fromIntegral n
>         swap (re :+ im) = im :+ re

-------------------------------------------------------------------------------

This is the algorithm for computing a 2N-point real FFT with an N-point
complex FFT.  This formulation is from Rick Lyons's book.  Only the first
N output bins are returned.

> -- | Real FFT of 2N real samples computed with one N-point complex FFT
> -- (Lyons's formulation).  Even-indexed samples form the real part and
> -- odd-indexed samples the imaginary part of the packed N-point input.
> -- Half-sums and half-differences of each bin m and its mirror N-m in
> -- that spectrum are recombined with cos/sin(pi*m/N).  Only output bins
> -- m = 0 .. N-1 are returned.  With an odd input length, the trailing
> -- unpaired sample is ignored.
> rfft a = listArray (0,n-1) (map bin [0..(n-1)])
>   where n       = (snd (bounds a) + 1) `div` 2
>         samples = elems a
>         evens (s:_:rest) = s : evens rest
>         evens rest       = rest
>         packed  = zipWith (:+) (evens samples) (evens (drop 1 samples))
>         spec    = fft (listArray (0,n-1) packed)
>         re      = fmap realPart spec
>         im      = fmap imagPart spec
>         -- (half-sum, half-difference) of bin m and its mirror n-m; bin 0
>         -- is its own mirror, so its difference is exactly 0
>         split v m | m == 0    = (v!0, 0)
>                   | otherwise = ((v!m + v!(n-m)) / 2, (v!m - v!(n-m)) / 2)
>         bin m = (rPlus + cos t * iPlus - sin t * rMinus) :+
>                 (iMinus - sin t * iPlus - cos t * rMinus)
>           where (rPlus, rMinus) = split re m
>                 (iPlus, iMinus) = split im m
>                 t = pi * fromIntegral m / fromIntegral n

-------------------------------------------------------------------------------

These are the hard-coded DFTs for N = 2, 3 and 4, borrowed from FFTW.

> -- | Hard-coded 2-point DFT [x0 + x1, x0 - x1], computed separately on
> -- the real and imaginary parts.
> fft'2 a = listArray (0,1) [ (r0 + r1) :+ (i0 + i1), (r0 - r1) :+ (i0 - i1) ]
>     where (r0, i0) = parts (a!0)
>           (r1, i1) = parts (a!1)
>           parts z  = (realPart z, imagPart z)

> -- | Hard-coded 3-point DFT (after FFTW), with W = exp(-2*pi*i/3).  X0 is
> -- the sum of the inputs; X1 and X2 share the mid-point x0 - (x1+x2)/2
> -- and differ in the sign of the sqrt(3)/2-scaled cross terms.
> fft'3 a = listArray (0,2) [ (r0 + rSum) :+ (i0 + iSum)
>                           , (rMid + iDiff) :+ (rDiff + iMid)
>                           , (rMid - iDiff) :+ (iMid - rDiff) ]
>     where half  = 0.5
>           s3    = sqrt 3 / 2
>           (r0, i0) = (realPart (a!0), imagPart (a!0))
>           (r1, i1) = (realPart (a!1), imagPart (a!1))
>           (r2, i2) = (realPart (a!2), imagPart (a!2))
>           rSum  = r1 + r2
>           iSum  = i1 + i2
>           rDiff = s3 * (r2 - r1)
>           iDiff = s3 * (i1 - i2)
>           rMid  = r0 - (half * rSum)
>           iMid  = i0 - (half * iSum)

> -- | Hard-coded 4-point DFT (after FFTW).  It performs length-2
> -- butterflies on (x0,x2) and (x1,x3), then combines them.  The
> -- odd-index difference is multiplied by -i for X1 and by +i for X3.
> fft'4 a = listArray (0,3) [ (rA + rC) :+ (iA + iC)
>                           , (rB + iD) :+ (iB - rD)
>                           , (rA - rC) :+ (iA - iC)
>                           , (rB - iD) :+ (rD + iB) ]
>     where (r0, i0) = (realPart (a!0), imagPart (a!0))
>           (r1, i1) = (realPart (a!1), imagPart (a!1))
>           (r2, i2) = (realPart (a!2), imagPart (a!2))
>           (r3, i3) = (realPart (a!3), imagPart (a!3))
>           -- butterfly on the even-indexed pair (x0, x2)
>           rA = r0 + r2
>           rB = r0 - r2
>           iA = i0 + i2
>           iB = i0 - i2
>           -- butterfly on the odd-indexed pair (x1, x3)
>           rC = r1 + r3
>           rD = r1 - r3
>           iC = i1 + i3
>           iD = i1 - i3

-------------------------------------------------------------------------------

Test routines

> -- | Deterministic test signal of length n: the sample at 0-based index k
> -- is sqrt (k+1) :+ log (k+1).
> gendata :: Int -> Array Int (Complex Double)
> gendata n = listArray (0, n-1) (map sample [1 .. n])
>     where sample i = let x = fromIntegral i in sqrt x :+ log x

> magsq (x:+y) = x*x + y*y

> mean x = sum x / (fromIntegral $ length x)

> -- | Root-mean-square of the element-wise difference of two complex arrays
> -- (compared in 'elems' order, truncated to the shorter one).
> rms :: Array Int (Complex Double) -> Array Int (Complex Double) -> Double
> rms x y = sqrt (mean (zipWith (\u v -> magsq (u - v)) (elems x) (elems y)))

> -- | RMS difference between 'fft' and the reference 'dft' on the n-point
> -- test signal from 'gendata'; values near zero mean they agree.
> testfft :: Int -> Double
> testfft n = let signal = gendata n in rms (fft signal) (dft signal)
