Maybe there is some gain in using Strassen multiplication on Z or Q in gmpy
mode.
Attached is an implementation. For matrices of size n with entries
randint(0,N), tested with N=10, 10**3 and 10**6, the speedup increases with N.
The break-even point is n=190; for n=448 the speedup with Strassen is
11% for N=10, 42% for N=1000 and 47% for N=10**6.
On Saturday, March 17, 2012 4:09:49 PM UTC+1, Sai Nikhil wrote:
>
> Hi,
>
> Intro: I am Sai Nikhil, a 3rd-year graduate student. I have sufficient
> experience with Python programming.
>
> I am following up on several topics listed on the SymPy GSoC 2012 ideas page.
> I found the Series, Matrices and Functions modules especially interesting. I
> wanted to know which algorithm is implemented in the matrix_multiply
> function. Is it the naive algorithm? If so, the running time of the code
> would be of order O(n^3). The element-wise multiplication also has a
> running time of order O(n^3). The Strassen algorithm would be more
> efficient than this, as its running time is of order
> O(n^lg7) ≈ O(n^2.807). I would like to implement it, and I would
> appreciate your comments/views on this.
>
> I also submitted my first pull here:
> https://github.com/sympy/sympy/pull/1130/
>
> Please go through it and tell me if I need to make any edits, so that you
> can merge it into sympy master.
>
> *-thanks,*
> *Sai Nikhil.T <http://www.tsndiffopera.in>*
> *
> *
> *1*
>
--
You received this message because you are subscribed to the Google Groups
"sympy" group.
To view this discussion on the web visit
https://groups.google.com/d/msg/sympy/-/2QKQrDh-DxsJ.
To post to this group, send email to [email protected].
To unsubscribe from this group, send email to
[email protected].
For more options, visit this group at
http://groups.google.com/group/sympy?hl=en.
from gmpy import mpq,mpz
from random import randint,seed
from time import time
def rand_matrix(n, typ=int, N=10):
    """Return an n x n matrix (a list of row lists) of random entries.

    Each entry is drawn with ``randint(0, N)`` (both bounds inclusive), in
    row-major order, and then converted with ``typ``.

    n   -- number of rows and columns
    typ -- callable applied to every drawn integer (e.g. ``int``, ``mpz``,
           ``mpq``); with the default ``int`` the entries are plain ints
    N   -- upper bound of the random entries
    """
    # The conversion through ``typ`` is what makes the requested element type
    # take effect; without it every entry would stay a plain int.
    return [[typ(randint(0, N)) for _ in range(n)] for _ in range(n)]
def madd(a, b):
    """Return a new matrix holding the element-wise sum ``a + b``.

    Matrices are lists of row lists.  The shape of the result is taken from
    ``a`` (row by row), so rectangular matrices work as well as square ones.
    ``b`` must provide at least the same rows and columns; an IndexError is
    raised otherwise.  Neither argument is modified.
    """
    total = []
    for i, a_row in enumerate(a):
        b_row = b[i]
        # Index explicitly (instead of zip) so that a too-short ``b`` raises
        # rather than silently truncating the result.
        total.append([a_row[j] + b_row[j] for j in range(len(a_row))])
    return total
def imadd(a, b):
    """Add matrix ``b`` to matrix ``a`` in place and return ``a``.

    Matrices are lists of row lists.  Every entry of ``a`` is updated with
    ``+=``; the row lengths of ``a`` drive the loops, so rectangular
    matrices are handled too.  ``b`` must provide at least the same rows and
    columns; an IndexError is raised otherwise.
    """
    for i, a_row in enumerate(a):
        b_row = b[i]
        for j in range(len(a_row)):
            a_row[j] += b_row[j]
    return a
def msub(a, b):
    """Return a new matrix holding the element-wise difference ``a - b``.

    Matrices are lists of row lists.  The shape of the result is taken from
    ``a`` (row by row), so rectangular matrices work as well as square ones.
    ``b`` must provide at least the same rows and columns; an IndexError is
    raised otherwise.  Neither argument is modified.
    """
    diff = []
    for i, a_row in enumerate(a):
        b_row = b[i]
        # Index explicitly (instead of zip) so that a too-short ``b`` raises
        # rather than silently truncating the result.
        diff.append([a_row[j] - b_row[j] for j in range(len(a_row))])
    return diff
def imsub(a, b):
    """Subtract matrix ``b`` from matrix ``a`` in place and return ``a``.

    Matrices are lists of row lists.  Every entry of ``a`` is updated with
    ``-=``; the row lengths of ``a`` drive the loops, so rectangular
    matrices are handled too.  ``b`` must provide at least the same rows and
    columns; an IndexError is raised otherwise.
    """
    for i, a_row in enumerate(a):
        b_row = b[i]
        for j in range(len(a_row)):
            a_row[j] -= b_row[j]
    return a
def mmul(a, b):
    """Return the matrix product ``a * b`` with the schoolbook algorithm.

    Matrices are lists of row lists.  ``a`` may be m x k and ``b`` k x p;
    the result is a new m x p matrix.  An empty ``a`` yields ``[]``.

    Raises ValueError when the shapes are not compatible.

    NOTE: this file defines ``mmul`` a second time further down (with an
    extra ``trans`` argument); once the module is loaded, that later
    definition replaces this one.
    """
    if not a:
        return []
    inner = len(b)
    width = len(b[0]) if b else 0
    if any(len(row) != inner for row in a):
        raise ValueError("rows of a must have as many entries as b has rows")
    if any(len(row) != width for row in b):
        raise ValueError("all rows of b must have the same length")
    product = []
    for a_row in a:
        out_row = []
        for j in range(width):
            # Accumulate from the int 0 so the entry type is decided by the
            # matrix entries themselves.
            total = 0
            for k in range(inner):
                total += a_row[k] * b[k][j]
            out_row.append(total)
        product.append(out_row)
    return product
def mmul(a, b, trans=0):
    """Return the matrix product of ``a`` and ``b`` (schoolbook algorithm).

    Matrices are lists of row lists.

    a     -- left factor, m x k
    b     -- right factor.  If ``trans`` is false, ``b`` is the k x p right
             factor itself.  If ``trans`` is true, ``b`` is the *transpose*
             of the right factor (p x k), i.e. each row of ``b`` is a column
             of the matrix being multiplied.
    trans -- whether ``b`` is already transposed

    Returns a new m x p matrix; an empty ``a`` yields ``[]``.  ``b`` is never
    modified.  Raises ValueError when the shapes are not compatible.
    """
    if not a:
        return []
    inner = len(a[0])
    if trans:
        b_rows = b
    else:
        if len(b) != inner:
            raise ValueError("rows of a must have as many entries as b has rows")
        width = len(b[0]) if b else 0
        if any(len(row) != width for row in b):
            raise ValueError("all rows of b must have the same length")
        # Work on the transpose so each dot product below runs over two
        # plain rows instead of indexing down a column of b.
        b_rows = [list(column) for column in zip(*b)]
    if any(len(row) != inner for row in a):
        raise ValueError("all rows of a must have the same length")
    if any(len(row) != inner for row in b_rows):
        raise ValueError("incompatible matrix shapes")
    product = []
    for a_row in a:
        out_row = []
        for b_row in b_rows:
            # Accumulate from the int 0 so the entry type is decided by the
            # matrix entries themselves.
            total = 0
            for x, y in zip(a_row, b_row):
                total += x * y
            out_row.append(total)
        product.append(out_row)
    return product
def strassen_mul(a, b, trans=0, threshold=128):
    """Multiply two square matrices with Strassen's algorithm.

    Matrices are lists of row lists of the same size n x n.

    a         -- left factor
    b         -- right factor; if ``trans`` is true, ``b`` is the transpose
                 of the right factor (same convention as ``mmul``)
    trans     -- whether ``b`` is already transposed
    threshold -- sizes ``n <= threshold`` are handed to ``mmul`` directly

    The recursion splits each matrix into four n/2 x n/2 blocks and forms
    seven block products instead of eight.  Odd sizes (which cannot be
    halved) and sizes up to ``threshold`` fall back to ``mmul``.  Returns a
    new matrix; the inputs are not modified.
    """
    n = len(a)
    # ``n == 0`` guards against endless recursion when threshold is negative.
    if n % 2 == 1 or n <= threshold or n == 0:
        return mmul(a, b, trans)
    # Floor division keeps the half size an int on Python 2 and Python 3.
    half = n // 2
    top = range(half)
    bottom = range(half, n)
    A00 = [a[i][:half] for i in top]
    A01 = [a[i][half:] for i in top]
    A10 = [a[i][:half] for i in bottom]
    A11 = [a[i][half:n] for i in bottom]
    # All B blocks are stored transposed, because the recursive calls below
    # are made with trans=1.
    if trans:
        # b is already the transpose, so the off-diagonal blocks swap roles.
        B00 = [b[i][:half] for i in top]
        B10 = [b[i][half:] for i in top]
        B01 = [b[i][:half] for i in bottom]
        B11 = [b[i][half:n] for i in bottom]
    else:
        B00 = [[b[j][i] for j in top] for i in top]
        B01 = [[b[j][i] for j in top] for i in bottom]
        B10 = [[b[j][i] for j in bottom] for i in top]
        B11 = [[b[j][i] for j in bottom] for i in bottom]
    # The seven Strassen products.
    M1 = strassen_mul(madd(A00, A11), madd(B00, B11), 1, threshold)
    M2 = strassen_mul(madd(A10, A11), B00, 1, threshold)
    M3 = strassen_mul(A00, msub(B01, B11), 1, threshold)
    M4 = strassen_mul(A11, msub(B10, B00), 1, threshold)
    M5 = strassen_mul(madd(A00, A01), B11, 1, threshold)
    M6 = strassen_mul(msub(A10, A00), madd(B00, B01), 1, threshold)
    M7 = strassen_mul(msub(A01, A11), madd(B10, B11), 1, threshold)
    # Recombine the products into the four blocks of the result.
    C00 = madd(msub(madd(M1, M4), M5), M7)
    C01 = madd(M3, M5)
    C10 = madd(M2, M4)
    C11 = madd(madd(msub(M1, M2), M3), M6)
    # Glue the blocks back together row by row: top half, then bottom half.
    c = [C00[i] + C01[i] for i in top]
    c.extend(C10[i] + C11[i] for i in top)
    return c
def test_1(N):
    """Benchmark ``mmul`` against ``strassen_mul`` and check they agree.

    For n = 64, 128, ..., 768 two random n x n matrices with entries drawn
    from ``randint(0, N)`` are multiplied with both routines.  One line per
    size is printed: the size, the ``mmul`` time and the ``strassen_mul``
    time in seconds.

    Raises AssertionError if the two products differ.
    """
    seed(2)  # fixed seed so every run multiplies the same matrices
    for n in range(64, 800, 64):
        a = rand_matrix(n, mpq, N)
        b = rand_matrix(n, mpq, N)
        t0 = time()
        C = mmul(a, b)
        t1 = time()
        c = strassen_mul(a, b)
        t2 = time()
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if c != C:
            raise AssertionError('strassen_mul and mmul disagree for n=%d' % n)
        # Parenthesised single argument: valid on Python 2 and Python 3.
        print('n=%d %.2f %.2f' % (n, t1 - t0, t2 - t1))
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, so importing this
    # module for its matrix routines does not start a long computation.
    test_1(1000000)