[SYSTEMML-641] Cache-conscious short matrix-matrix block multiply

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7fb11176
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7fb11176
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7fb11176

Branch: refs/heads/master
Commit: 7fb11176cd6c09d19f9752308274a627d582c8b0
Parents: 39cf164
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 4 23:08:26 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 4 23:08:26 2016 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 37 +++++++++++---------
 1 file changed, 21 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7fb11176/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index bd28376..782987a 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -1060,25 +1060,30 @@ public class LibMatrixMult
                        }
                        else if( pm2 && m<=16 )    //MATRIX-MATRIX (short lhs) 
                        {
-                               //parallelization over rows in rhs matrix
-                               final int kn = (ru-rl)%2;                       
        
+                               //cache-conscious parallelization over rows in 
rhs matrix
+                               final int kn = (ru-rl)%4;                       
        
                                
                                //rest not aligned to blocks of 2 rows
-                               if( kn == 1 )
-                                       for( int i=0, aix=0, cix=0; i<m; i++, 
aix+=cd, cix+=n )
-                                               if( a[aix+rl] != 0 )
-                                                       
vectMultiplyAdd(a[aix+rl], b, c, rl*n, cix, n);
-                               
-                               //compute blocks of 2 rows (w/ repeated scan 
for each row in lhs) 
-                               for( int k=rl+kn, bix=(rl+kn)*n; k<ru; k+=2, 
bix+=2*n )
-                                       for( int i=0, aix=0, cix=0; i<m; i++, 
aix+=cd, cix+=n ){
-                                               if( a[aix+k] != 0 && a[aix+k+1] 
!= 0  )
-                                                       
vectMultiplyAdd2(a[aix+k], a[aix+k+1], b, c, bix, bix+n, cix, n);
-                                               else if( a[aix+k] != 0 )
+                               for( int i=0, aix=0, cix=0; i<m; i++, aix+=cd, 
cix+=n )
+                                       for( int k=rl, bix=rl*n; k<rl+kn; k++, 
bix+=n )
+                                               if( a[aix+k] != 0 )
                                                        
vectMultiplyAdd(a[aix+k], b, c, bix, cix, n);
-                                               else if( a[aix+k+1] != 0 )      
-                                                       
vectMultiplyAdd(a[aix+k+1], b, c, bix+n, cix, n);
-                                       }                               
+
+                               final int blocksizeK = 48;  
+                               final int blocksizeJ = 1024; 
+                               
+                               //blocked execution
+                               for( int bk = rl+kn; bk < ru; bk+=blocksizeK ) 
+                                       for( int bj = 0, bkmin = Math.min(cd, 
bk+blocksizeK); bj < n; bj+=blocksizeJ ) 
+                                       {
+                                               //compute blocks of 4 rows in 
rhs w/ IKJ
+                                               int bjlen = Math.min(n, 
bj+blocksizeJ)-bj;
+                                               for( int i=0, aix=0, cix=bj; 
i<m; i++, aix+=cd, cix+=n )
+                                                       for( int k=bk, 
bix=bk*n; k<bkmin; k+=4, bix+=4*n ) {
+                                                               
vectMultiplyAdd4(a[aix+k], a[aix+k+1], a[aix+k+2], a[aix+k+3], 
+                                                                               
b, c, bix, bix+n, bix+2*n, bix+3*n, cix, bjlen);
+                                                       }
+                                       }
                        }
                        else if( tm2 )             //MATRIX-MATRIX (skinny rhs)
                        {

Reply via email to