import numpy
import ncreduce

# Benchmark inputs: uniform random float arrays, listed from smallest to
# largest.  The names identify the (rows, cols) shape each one is created with.
tiny, small, big, big2, huge = (
    numpy.random.rand(rows, cols)
    for rows, cols in ((4, 4), (40, 4), (4000, 10), (10, 4000), (4000, 10000))
)

# Every benchmark and comparison iterates over the arrays in this order.
all_arrays = [tiny, small, big, big2, huge]

all_function_names=( 'sum', 'prod', 'mean', 'std', 'var', 'max', 'min', 'all', 'any' )

# Source template for the per-reduction wrappers.  For every NAME in
# all_function_names it defines, at module level:
#   traditional_NAME(A, axis=None) -> A.NAME(axis=axis)
#   ncr_NAME(A, axis=None)         -> ncreduce.NAME(A, axis=axis)
#   all_NAME = (traditional_NAME, ncr_NAME)    # the reference comes first
# The wrappers are generated as real source (rather than closures around
# getattr) so that each call does a plain attribute access, exactly like a
# hand-written wrapper would.
# A named placeholder is used so the template can be edited without having to
# keep a positional-argument count in sync.
_WRAPPER_TEMPLATE = '''
def traditional_%(name)s(A,axis=None):
    return A.%(name)s(axis=axis)

def ncr_%(name)s(A,axis=None):
    return ncreduce.%(name)s(A,axis=axis)

all_%(name)s = (traditional_%(name)s,ncr_%(name)s)
'''

for funcname in all_function_names:
    # exec is written in call form so the line parses both as the Python 2
    # statement and as the Python 3 builtin.  The template is filled only with
    # the hard-coded names above, never with external input.
    exec(_WRAPPER_TEMPLATE % {'name': funcname})

def bind_axis(f,axis):
    """Fix the ``axis`` keyword argument of a reduction callable.

    f    -- a callable accepting ``(A, axis=...)``, or a list/tuple of such
            callables (possibly nested).
    axis -- the value passed as ``axis`` on every call.

    Returns a one-argument callable ``g(A) == f(A, axis=axis)``.  When ``f``
    is a list or tuple, returns a *list* holding the bound version of each
    element, in the same order.
    """
    # isinstance (rather than an exact type comparison) so that list/tuple
    # subclasses such as namedtuples are treated as sequences too, instead of
    # being wrapped as if they were a single callable.
    if isinstance(f, (list, tuple)):
        return [bind_axis(ff,axis) for ff in f]
    def internal(A):
        return f(A,axis=axis)
    return internal


def compare(arrays,functions,namebase):
    """Check that alternative reductions agree with a reference implementation.

    arrays    -- sequence of 2-d numpy arrays with at least 3 rows and 3
                 columns (row 2 and column 2 are sliced out below).
    functions -- sequence of callables ``f(A, axis=None)``; ``functions[0]``
                 is the reference, every other entry is checked against it.
    namebase  -- label used as the prefix of any mismatch message.

    Seven variants are exercised: the arrays as given, their transposes, the
    reductions bound to axis 0 and to axis 1, axis 1 on the transposes, and
    the 1-d slices ``A[2]`` and ``A[:,2]``.  For each variant the first
    mismatch is printed and the rest of that variant is skipped.

    Returns True when all seven variants agree, False otherwise.
    """
    def agree(expected,actual):
        # Compare as arrays so numpy scalars and arrays are handled uniformly.
        expected = numpy.asarray(expected)
        actual = numpy.asarray(actual)
        if expected.shape != actual.shape:
            return False
        if expected.dtype == bool or actual.dtype == bool:
            # Truth values (all/any) have no meaningful relative error; they
            # must match exactly.
            return bool(numpy.all(expected == actual))
        # 1% relative tolerance measured against the alternative result, with
        # no absolute slack: a zero only matches a zero, and values of
        # opposite sign never match.
        return bool(numpy.allclose(expected, actual, rtol=.01, atol=0))

    def compare1(arrays,functions,name):
        standard=functions[0]
        # enumerate gives the position directly; looking the array up with
        # list.index() would compare ndarrays with ==, which is ambiguous.
        for index,A in enumerate(arrays):
            result=standard(A)
            for f in functions[1:]:
                if not agree(result, f(A)):
                    # Single parenthesised argument: prints the same text
                    # under Python 2 and Python 3.
                    print("%s[%s]: this differs from reference." % (name,index))
                    return False
        return True

    transposed = [A.T for A in arrays]
    cases = [
        (arrays, functions, ': direct'),
        (transposed, functions, ': transpose'),
        (arrays, bind_axis(functions,0), ': bind_axis( . , 0)'),
        (arrays, bind_axis(functions,1), ': bind_axis( . , 1)'),
        (transposed, bind_axis(functions,1), ': transpose+bind_axis( . , 1)'),
        ([A[2] for A in arrays], functions, ': A[2]'),
        ([A[:,2] for A in arrays], functions, ': A[:,2]'),
    ]
    # Run every variant (no short-circuit) so each one gets to report.
    outcomes = [compare1(arrs,funcs,namebase+suffix) for arrs,funcs,suffix in cases]
    return all(outcomes)

def all_comparisions():
    """Run compare() over all_arrays for every name in all_function_names.

    For each name the module-level tuple ``all_<name>`` (reference
    implementation first) is looked up and handed to compare(), using the
    bare name as the message label.
    """
    for funcname in all_function_names:
        # Plain namespace lookup; no need to build and exec source code just
        # to resolve a variable by name.
        compare(all_arrays, globals()['all_%s' % funcname], funcname)

def time_all(arrays,functions,repeats=100):
    """Time every function on every array under seven access patterns.

    arrays    -- sequence of 2-d numpy arrays with at least 3 rows and 3
                 columns (row 2 and column 2 are sliced out below).
    functions -- sequence of callables ``f(A, axis=None)``.
    repeats   -- number of back-to-back calls per measurement (default 100).

    Returns a nested list ``times[variant][array][function]`` holding the
    elapsed time, in milliseconds, of ``repeats`` calls.  The variants are, in
    order: the arrays as given, their transposes, axis 0, axis 1, axis 1 on
    the transposes, the slice ``A[2]`` and the slice ``A[:,2]``.
    """
    # default_timer is the timeit module's platform-appropriate clock.
    from timeit import default_timer

    # Built once, outside the timed region, and re-iterated for every
    # measurement.
    iterations = range(repeats)

    def timethis(arrays,functions):
        times=[]
        for A in arrays:
            thesetimes=[]
            for f in functions:
                start=default_timer()
                for _ in iterations:
                    f(A)
                end=default_timer()
                # seconds -> milliseconds
                thesetimes.append(1000*(end-start))
            times.append(thesetimes)
        return times

    transposed = [A.T for A in arrays]
    times=[]
    times.append(timethis(arrays,functions))
    times.append(timethis(transposed,functions))
    times.append(timethis(arrays,bind_axis(functions,0)))
    times.append(timethis(arrays,bind_axis(functions,1)))
    times.append(timethis(transposed,bind_axis(functions,1)))
    times.append(timethis([A[2] for A in arrays],functions))
    times.append(timethis([A[:,2] for A in arrays],functions))
    return times

if __name__ == '__main__':
    # Running this ensures that stuff is loaded into memory (as well as a sanity check):
    all_comparisions()

    # Every print below takes exactly one parenthesised argument, so it emits
    # the same text whether print is the Python 2 statement or the Python 3
    # function.  print('') produces the blank lines.
    for line in (
            'Values are fold improvements: >1 means ncreduce is faster, <1 means ncreduce is slower',
            'Columns are different array sizes: (4x4) (40x4) (4000x10) (10x4000) (4000x10000)',
            'Rows are different types of reduce operation:',
            ' A.f()',
            ' A.T.f()',
            ' A.f(0)',
            ' A.f(1)',
            ' A.T.f(1)',
            ' A[2].f()',
            ' A[:,2].f()',
            ):
        print(line)
    print('')
    print('')

    # (label, (reference, ncreduce) pair, per-array input transform or None).
    # NOTE: 'prod' is sanity-checked by all_comparisions() above but is not
    # timed here.
    #
    # all() and any() are timed on boolean arrays.  The inputs are uniform on
    # [0, 1), so ``A < .99`` is about 99% True and ``A > .99`` about 99% False.
    # all() can stop at the first False and any() at the first True, so all()
    # gets the mostly-True arrays and any() the mostly-False ones; otherwise
    # the advantage of early exit is too big.
    benchmarks = (
        ('SUM', all_sum, None),
        ('MEAN', all_mean, None),
        ('STD', all_std, None),
        ('VAR', all_var, None),
        ('MAX', all_max, None),
        ('MIN', all_min, None),
        ('ALL', all_all, lambda A: A < .99),
        ('ANY', all_any, lambda A: A > .99),
    )

    for label, functions, transform in benchmarks:
        print('For %s function' % label)
        if transform is None:
            arrays = all_arrays
        else:
            # Built per benchmark so only one set of boolean copies is alive
            # at a time.
            arrays = [transform(A) for A in all_arrays]
        times=numpy.array(time_all(arrays,functions))
        # Column 0 is the reference timing, column 1 the ncreduce timing.
        print(times[:,:,0]/times[:,:,1])
        print('')

