On 22/01/2012 20:01, Ondřej Čertík wrote:
> On Sun, Jan 22, 2012 at 3:13 AM, Sebastian Haase <[email protected]> wrote:
>> How do the algorithm and timing compare to this one:
>> http://code.google.com/p/priithon/source/browse/Priithon/mandel.py?spec=svna6117f5e81ec00abcfb037f0f9da2937bb2ea47f&r=a6117f5e81ec00abcfb037f0f9da2937bb2ea47f
>> The author of the original version is Dan Goodman
>> # FAST FRACTALS WITH PYTHON AND NUMPY
> Thanks Sebastian. This one is much faster: 2.7 s on my laptop with
> the same dimensions/iterations.
> It uses better data structures: it only keeps track of points that
> still need to be iterated. Very clever.
> If I have time, I'll try to provide an equivalent Fortran version too,
> for comparison.
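To make the "only keep track of points that still need to be iterated" idea concrete, here is a minimal NumPy sketch on a handful of points. The function names `escape_counts_full` and `escape_counts_active` are hypothetical and not taken from the linked mandel.py. The first updates every point on every iteration; the second drops escaped points from its working arrays, so later iterations touch less memory. Both should report the same escape iteration per point (0 meaning "did not escape").

```python
import numpy as np


def escape_counts_full(c, itermax):
    """Update every point on every iteration; record the first escape."""
    z = c.copy()
    counts = np.zeros(c.shape, dtype=int)
    # Escaped points keep being squared and eventually overflow, hence errstate
    with np.errstate(over='ignore', invalid='ignore'):
        for i in range(itermax):
            z = z * z + c
            newly = (np.abs(z) > 2.0) & (counts == 0)
            counts[newly] = i + 1
    return counts


def escape_counts_active(c, itermax):
    """Same counts, but escaped points are dropped from the working arrays."""
    c = np.ravel(c)
    idx = np.arange(c.size)  # flat positions of points still being iterated
    z = c.copy()
    counts = np.zeros(c.size, dtype=int)
    for i in range(itermax):
        if idx.size == 0:
            break  # every point has escaped
        z = z * z + c
        escaped = np.abs(z) > 2.0
        counts[idx[escaped]] = i + 1
        keep = ~escaped
        z, c, idx = z[keep], c[keep], idx[keep]
    return counts


c = np.array([0, 1, -2, 2, -1], dtype=complex)
print(escape_counts_full(c, 50))    # [0 2 0 1 0]
print(escape_counts_active(c, 50))  # [0 2 0 1 0]
```

The compacted version never squares an escaped value again, which is also why it needs no `np.errstate` guard against overflow.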
I spent a little while trying to optimise my algorithm using only numpy
and couldn't get it running much faster than that. Given the relatively
low number of iterations, it's probably not a problem of Python
overhead, so I guess memory access is indeed the bottleneck.
One way to get round this using numexpr would be something like this.
Write f(z) = z^2 + c and then f(n+1, z) = f(n, f(z)), with f(1, z) = f(z).
Now, instead of computing z -> f(z) on each iteration, write down the
formula for z -> f(n, z) for a few different n and use that in numexpr,
e.g. z -> f(2, z) = (z^2 + c)^2 + c. This amounts to doing several
iterations per step, but it means that you spend more time doing
floating point operations and less time waiting for memory, so it might
get closer to Fortran/C speeds.
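The nested expression f(n, z) can be generated mechanically, and it is easy to check that one evaluation of it matches n single steps. This is a small pure-Python sketch, assuming Python's built-in `eval` is an acceptable stand-in for `ne.evaluate` when only the algebra is being checked; `composed_expr` and `iterate` are hypothetical helper names. The string-building loop follows the same pattern as `nemandel` in the attached code.

```python
def composed_expr(depth):
    """Return the expression for f(depth, z), where f(z) = z**2 + c."""
    expr = 'z**2+c'
    for _ in range(depth - 1):
        expr = '({})**2+c'.format(expr)  # wrap one more application of f
    return expr


def iterate(z, c, n):
    """Apply z -> z**2 + c n times, one step at a time."""
    for _ in range(n):
        z = z**2 + c
    return z


z0, c = 0.1 + 0.2j, -0.3 + 0.1j
print(composed_expr(2))  # (z**2+c)**2+c
for n in (1, 2, 3, 8):
    # eval stands in for numexpr here, only to confirm the algebra
    one_shot = eval(composed_expr(n), {'z': z0, 'c': c})
    assert abs(one_shot - iterate(z0, c, n)) < 1e-12
```

The speed argument itself relies on numexpr evaluating the whole nested expression in one pass over cache-sized chunks, instead of allocating a full-size temporary array for each operation as plain NumPy does; `eval` on scalars says nothing about speed.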
Actually, my curiosity was piqued, so I tried it out. On my laptop the
idea above gives the largest speed increase at n=8; beyond that you
start to get overflow errors, so it runs slower. At n=8 it runs about
4.5x faster than the original version. So if you got the same speedup,
it would run in about 0.6 s (2.7 s / 4.5) compared to your Fortran 0.7 s.
It's not a fair comparison, since numexpr is using multiple cores
(though only at about 60% peak on my dual-core laptop), but it's still
nice to see what can be achieved with numexpr. :)
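One plausible source of the overflow at larger n: in `nemandel`, the first evaluation applies all n iterations to every grid point, starting from z = c, including points that escape after a single step, and each extra iteration roughly squares the magnitude of an escaping value. This pure-Python sketch (`first_nonfinite_depth` is a hypothetical helper, and the grid corner c = -2-1.25j is an assumed test point) counts how many iterations it takes before such a value is no longer a finite double. The exact threshold depends on the point and on how the expression is evaluated, so it is not meant to reproduce the n=8 figure above.

```python
import math


def first_nonfinite_depth(c, max_depth):
    """Smallest n for which n iterations of z -> z*z + c, starting from z = c,
    no longer give a finite complex number (None if it stays finite)."""
    z = c
    for n in range(1, max_depth + 1):
        # z*z rather than z**2: CPython raises OverflowError for complex **,
        # while plain complex multiplication just produces inf/nan
        z = z * z + c
        if not (math.isfinite(z.real) and math.isfinite(z.imag)):
            return n
    return None


corner = complex(-2.0, -1.25)  # (xmin, ymin) of the grid used in the script
print(first_nonfinite_depth(corner, 20))
```

A related hazard: if both parts of a value end up as nan, `abs(z) > 2.0` is False, so such a point would not be removed from the active arrays.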
Code attached.
Dan
```python
import numpy as np
import numexpr as ne


def mandel(n, m, itermax, xmin, xmax, ymin, ymax):
    ix, iy = np.mgrid[0:n, 0:m]
    x = np.linspace(xmin, xmax, n)[ix]
    y = np.linspace(ymin, ymax, m)[iy]
    c = x + 1j * y
    del x, y  # save a bit of memory, we only need c from here on
    img = np.zeros(c.shape, dtype=int)
    # Flatten so that escaped points can be dropped with boolean masks
    ix, iy, c = ix.ravel(), iy.ravel(), c.ravel()
    z = np.copy(c)
    for i in range(itermax):
        if not len(z):
            break  # all points have escaped
        np.multiply(z, z, z)  # z = z*z, in place
        np.add(z, c, z)       # z = z + c, in place
        rem = abs(z) > 2.0
        img[ix[rem], iy[rem]] = i + 1  # record the escape iteration
        # Keep only bounded points; unary minus on bool arrays is an error in current NumPy
        rem = ~rem
        z = z[rem]
        ix, iy = ix[rem], iy[rem]
        c = c[rem]
    return img


def nemandel(n, m, itermax, xmin, xmax, ymin, ymax, depth=1):
    # Build f(depth, z) as one expression, e.g. depth=2 gives '(z**2+c)**2+c'
    expr = 'z**2+c'
    for _ in range(depth - 1):
        expr = '({expr})**2+c'.format(expr=expr)
    itermax = itermax // depth  # each evaluation performs `depth` iterations
    print('Expression used:', expr)
    ix, iy = np.mgrid[0:n, 0:m]
    x = np.linspace(xmin, xmax, n)[ix]
    y = np.linspace(ymin, ymax, m)[iy]
    c = x + 1j * y
    del x, y  # save a bit of memory, we only need c from here on
    img = np.zeros(c.shape, dtype=int)
    ix, iy, c = ix.ravel(), iy.ravel(), c.ravel()
    z = np.copy(c)
    for i in range(itermax):
        if not len(z):
            break  # all points have escaped
        z = ne.evaluate(expr, local_dict={'z': z, 'c': c})
        rem = abs(z) > 2.0
        img[ix[rem], iy[rem]] = i + 1  # counts evaluations, i.e. blocks of `depth` iterations
        rem = ~rem
        z = z[rem]
        ix, iy = ix[rem], iy[rem]
        c = c[rem]
    img[img == 0] = itermax + 1  # points that never escaped
    return img


if __name__ == '__main__':
    import time
    import matplotlib.pyplot as plt

    doplot = True
    args = (1000, 1000, 100, -2, .5, -1.25, 1.25)
    start = time.time()
    I = mandel(*args)
    print('Mandel time taken:', time.time() - start)
    start = time.time()
    I2 = nemandel(*args)
    print('Nemandel time taken:', time.time() - start)
    start = time.time()
    I3 = nemandel(*args, depth=2)
    print('Nemandel 2 time taken:', time.time() - start)
    start = time.time()
    I4 = nemandel(*args, depth=3)
    print('Nemandel 3 time taken:', time.time() - start)
    for d in range(4, 9):
        start = time.time()
        I4 = nemandel(*args, depth=d)  # after the loop, I4 holds the depth=8 image
        print('Nemandel', d, 'time taken:', time.time() - start)
    if doplot:
        plt.subplot(221)
        plt.imshow(I.T, origin='lower')
        plt.subplot(222)
        plt.imshow(I2.T, origin='lower')
        plt.subplot(223)
        plt.imshow(I3.T, origin='lower')
        plt.subplot(224)
        plt.imshow(I4.T, origin='lower')
        plt.show()
```
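A detail of the attached code worth keeping in mind when comparing the four panels: in `nemandel`, `i` counts evaluations of the composed expression, so a point that escapes after k single iterations should be recorded as ceil(k/depth) (as long as nothing overflows), and points that never escape are set to `itermax + 1` with `itermax` already divided by `depth`, whereas `mandel` leaves them at 0. This plain-NumPy sketch mimics that counting without numexpr, so it shows the bookkeeping only, not the speedup; `block_counts` is a hypothetical name.

```python
import numpy as np


def block_counts(c, itermax, depth=1):
    """Escape counts in units of composed steps, each applying `depth`
    iterations of z -> z*z + c (plain NumPy stand-in for ne.evaluate)."""
    c = np.ravel(c)
    idx = np.arange(c.size)  # flat positions of still-active points
    z = c.copy()
    counts = np.zeros(c.size, dtype=int)
    for i in range(itermax // depth):
        if idx.size == 0:
            break  # every point has escaped
        for _ in range(depth):
            z = z * z + c
        escaped = np.abs(z) > 2.0
        counts[idx[escaped]] = i + 1  # i counts composed steps, as in nemandel
        keep = ~escaped
        z, c, idx = z[keep], c[keep], idx[keep]
    return counts


c = np.array([0, 1, -2, 2, -1], dtype=complex)
print(block_counts(c, 60, depth=1))  # [0 2 0 1 0]
print(block_counts(c, 60, depth=2))  # [0 1 0 1 0]
```

Because the counts are coarser by a factor of `depth`, the depth-8 panel has far fewer distinct colour bands than the `mandel` panel, even though the set of escaped points is essentially the same.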
_______________________________________________
NumPy-Discussion mailing list
[email protected]
http://mail.scipy.org/mailman/listinfo/numpy-discussion