Repository: systemml
Updated Branches:
  refs/heads/master ee5f61307 -> 5f6748fda


[SYSTEMML-2176] Performance spark reshape ultra-sparse matrices

This patch significantly improves the performance of ultra-sparse
distributed Spark reshape operations, especially for the special cases
of vector-matrix reshapes. In detail, this includes a better enumeration
of output block indexes (which becomes a bottleneck for ultra-sparse
matrices if not done right), avoiding an unnecessary sort for
column-wise reshapes of vectors to matrices, and a guard to ensure
correct sparse/dense output block formats.

On a specific algorithm with 110 distributed reshape operations over
ultra-sparse inputs and intermediates, this patch improved the
end-to-end runtime from 508s to 139s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5f6748fd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5f6748fd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5f6748fd

Branch: refs/heads/master
Commit: 5f6748fda8ea09567a494b7259b01400d65a5ed5
Parents: ee5f613
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Mar 7 22:38:32 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Mar 7 22:38:32 2018 -0800

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixReorg.java     | 134 +++++++++++--------
 1 file changed, 80 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5f6748fd/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
index 3f302a9..145a474 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
@@ -416,6 +416,11 @@ public class LibMatrixReorg
        /**
         * CP reshape operation (single input, single output matrix) 
         * 
+        * NOTE: In contrast to R, the rowwise parameter specifies both
+        * the read and write order, with row-wise being the default, while
+        * R uses always a column-wise read, rowwise specifying the write
+        * order and column-wise being the default. 
+        * 
         * @param in input matrix
         * @param out output matrix
         * @param rows number of rows
@@ -497,11 +502,13 @@ public class LibMatrixReorg
                        reshapeSparse(mbIn, row_offset, col_offset, rblk, mcIn, 
mcOut, rowwise);
                else //dense
                        reshapeDense(mbIn, row_offset, col_offset, rblk, mcIn, 
mcOut, rowwise);
-
+               
                //prepare output
                out = new ArrayList<>();
-               for( Entry<MatrixIndexes, MatrixBlock> e : rblk.entrySet() )
+               for( Entry<MatrixIndexes, MatrixBlock> e : rblk.entrySet() ) {
+                       e.getValue().examSparsity(); //ensure correct format
                        out.add(new 
IndexedMatrixValue(e.getKey(),e.getValue()));
+               }
                
                return out;
        }
@@ -1493,65 +1500,84 @@ public class LibMatrixReorg
                
                long row_offset = 
(ixin.getRowIndex()-1)*mcOut.getRowsPerBlock();
                long col_offset = 
(ixin.getColumnIndex()-1)*mcOut.getColsPerBlock();
+               long max_row_offset = 
Math.min(mcIn.getRows(),row_offset+mcIn.getRowsPerBlock())-1;
+               long max_col_offset = 
Math.min(mcIn.getCols(),col_offset+mcIn.getColsPerBlock())-1;
                
                if( rowwise ) {
-                       long max_col_offset = 
Math.min(mcIn.getCols(),col_offset+mcIn.getColsPerBlock())-1;
-                       for( long i=row_offset; 
i<Math.min(mcIn.getRows(),row_offset+mcIn.getRowsPerBlock()); i++ ) {
+                       if( mcIn.getCols() == 1 ) {
+                               MatrixIndexes first = 
computeResultBlockIndex(new MatrixIndexes(), row_offset, 0, mcIn, mcOut, 
rowwise);
+                               MatrixIndexes last = 
computeResultBlockIndex(new MatrixIndexes(), max_row_offset, 0, mcIn, mcOut, 
rowwise);
+                               createRowwiseIndexes(first, last, 
mcOut.getNumColBlocks(), ret);
+                       
+                       }
+                       for( long i=row_offset; i<max_row_offset+1; i++ ) {
                                MatrixIndexes first = 
computeResultBlockIndex(new MatrixIndexes(), i, col_offset, mcIn, mcOut, 
rowwise);
                                MatrixIndexes last = 
computeResultBlockIndex(new MatrixIndexes(), i, max_col_offset, mcIn, mcOut, 
rowwise);
-                               if( first.getRowIndex()<=0 || 
first.getColumnIndex()<=0 )
-                                       throw new RuntimeException("Invalid 
computed first index: "+first.toString());
-                               if( last.getRowIndex()<=0 || 
last.getColumnIndex()<=0 )
-                                       throw new RuntimeException("Invalid 
computed last index: "+last.toString());
-                               
-                               //add first row block
-                               ret.add(first);
-                               
-                               //add blocks in between first and last
-                               if( !first.equals(last) ) {
-                                       boolean fill = 
first.getRowIndex()==last.getRowIndex()
-                                               && first.getColumnIndex() > 
last.getColumnIndex();
-                                       for( long k1=first.getRowIndex(); 
k1<=last.getRowIndex(); k1++ ) {
-                                               long k2_start = 
(k1==first.getRowIndex() ? first.getColumnIndex()+1 : 1);
-                                               long k2_end = 
(k1==last.getRowIndex() && !fill) ?
-                                                       last.getColumnIndex()-1 
: mcOut.getNumColBlocks();
-                                               for( long k2=k2_start; 
k2<=k2_end; k2++ )
-                                                       ret.add(new 
MatrixIndexes(k1,k2));
-                                       }
-                                       ret.add(last);
-                               }
+                               createRowwiseIndexes(first, last, 
mcOut.getNumColBlocks(), ret);
                        }
                }
                else{ //colwise
-                       long max_row_offset = 
Math.min(mcIn.getRows(),row_offset+mcIn.getRowsPerBlock())-1;
-                       for( long j=col_offset; 
j<Math.min(mcIn.getCols(),col_offset+mcIn.getColsPerBlock()); j++ ) {
-                               MatrixIndexes first = 
computeResultBlockIndex(new MatrixIndexes(), row_offset, j, mcIn, mcOut, 
rowwise);
-                               MatrixIndexes last = 
computeResultBlockIndex(new MatrixIndexes(), max_row_offset, j, mcIn, mcOut, 
rowwise);
-                               if( first.getRowIndex()<=0 || 
first.getColumnIndex()<=0 )
-                                       throw new RuntimeException("Invalid 
computed first index: "+first.toString());
-                               if( last.getRowIndex()<=0 || 
last.getColumnIndex()<=0 )
-                                       throw new RuntimeException("Invalid 
computed last index: "+last.toString());
-                               
-                               //add first row block
-                               ret.add(first);
-                               
-                               //add blocks in between first and last
-                               if( !first.equals(last) ) {
-                                       boolean fill = 
first.getColumnIndex()==last.getColumnIndex()
-                                                       && first.getRowIndex() 
> last.getRowIndex();
-                                       for( long k1=first.getColumnIndex(); 
k1<=last.getColumnIndex(); k1++ ) {
-                                               long k2_start = 
((k1==first.getColumnIndex()) ? first.getRowIndex()+1 : 1);
-                                               long k2_end = 
((k1==last.getColumnIndex() && !fill) ? 
-                                                       last.getRowIndex()-1 : 
mcOut.getNumRowBlocks());
-                                               for( long k2=k2_start; 
k2<=k2_end; k2++ )
-                                                       ret.add(new 
MatrixIndexes(k1,k2));
-                                       }
-                                       ret.add(last);
+                       if( mcIn.getRows() == 1 ) {
+                               MatrixIndexes first = 
computeResultBlockIndex(new MatrixIndexes(), 0, col_offset, mcIn, mcOut, 
rowwise);
+                               MatrixIndexes last = 
computeResultBlockIndex(new MatrixIndexes(), 0, max_col_offset, mcIn, mcOut, 
rowwise);
+                               createColwiseIndexes(first, last, 
mcOut.getNumRowBlocks(), ret);
+                       }
+                       else {
+                               for( long j=col_offset; j<max_col_offset+1; j++ 
) {
+                                       MatrixIndexes first = 
computeResultBlockIndex(new MatrixIndexes(), row_offset, j, mcIn, mcOut, 
rowwise);
+                                       MatrixIndexes last = 
computeResultBlockIndex(new MatrixIndexes(), max_row_offset, j, mcIn, mcOut, 
rowwise);
+                                       createColwiseIndexes(first, last, 
mcOut.getNumRowBlocks(), ret);
                                }
                        }
                }
                return ret;
        }
+       
+       private static void createRowwiseIndexes(MatrixIndexes first, 
MatrixIndexes last, long ncblks, HashSet<MatrixIndexes> ret) {
+               if( first.getRowIndex()<=0 || first.getColumnIndex()<=0 )
+                       throw new RuntimeException("Invalid computed first 
index: "+first.toString());
+               if( last.getRowIndex()<=0 || last.getColumnIndex()<=0 )
+                       throw new RuntimeException("Invalid computed last 
index: "+last.toString());
+               
+               //add first row block
+               ret.add(first);
+               
+               //add blocks in between first and last
+               if( !first.equals(last) ) {
+                       boolean fill = first.getRowIndex()==last.getRowIndex()
+                               && first.getColumnIndex() > 
last.getColumnIndex();
+                       for( long k1=first.getRowIndex(); 
k1<=last.getRowIndex(); k1++ ) {
+                               long k2_start = (k1==first.getRowIndex() ? 
first.getColumnIndex()+1 : 1);
+                               long k2_end = (k1==last.getRowIndex() && !fill) 
? last.getColumnIndex()-1 : ncblks;
+                               for( long k2=k2_start; k2<=k2_end; k2++ )
+                                       ret.add(new MatrixIndexes(k1,k2));
+                       }
+                       ret.add(last);
+               }
+       }
+       
+       private static void createColwiseIndexes(MatrixIndexes first, 
MatrixIndexes last, long nrblks, HashSet<MatrixIndexes> ret) {
+               if( first.getRowIndex()<=0 || first.getColumnIndex()<=0 )
+                       throw new RuntimeException("Invalid computed first 
index: "+first.toString());
+               if( last.getRowIndex()<=0 || last.getColumnIndex()<=0 )
+                       throw new RuntimeException("Invalid computed last 
index: "+last.toString());
+               
+               //add first row block
+               ret.add(first);
+               
+               //add blocks in between first and last
+               if( !first.equals(last) ) {
+                       boolean fill = 
first.getColumnIndex()==last.getColumnIndex()
+                                       && first.getRowIndex() > 
last.getRowIndex();
+                       for( long k1=first.getColumnIndex(); 
k1<=last.getColumnIndex(); k1++ ) {
+                               long k2_start = ((k1==first.getColumnIndex()) ? 
first.getRowIndex()+1 : 1);
+                               long k2_end = ((k1==last.getColumnIndex() && 
!fill) ? last.getRowIndex()-1 : nrblks);
+                               for( long k2=k2_start; k2<=k2_end; k2++ )
+                                       ret.add(new MatrixIndexes(k2,k1));
+                       }
+                       ret.add(last);
+               }
+       }
 
        @SuppressWarnings("unused")
        private static HashMap<MatrixIndexes, MatrixBlock> 
createAllResultBlocks( Collection<MatrixIndexes> rix, 
@@ -1623,7 +1649,7 @@ public class LibMatrixReorg
                }
                
                //cleanup for sparse blocks
-               if( !rowwise ) {
+               if( !rowwise && mcIn.getRows() > 1 ) {
                        for( MatrixBlock block : rix.values() )
                                if( block.sparse )
                                        block.sortSparseRows();
@@ -1634,7 +1660,7 @@ public class LibMatrixReorg
                        HashMap<MatrixIndexes,MatrixBlock> rix, 
                        MatrixCharacteristics mcIn, MatrixCharacteristics 
mcOut, boolean rowwise ) 
                throws DMLRuntimeException
-    {
+       {
                if( in.isEmptyBlock(false) )
                        return;
                
@@ -1665,12 +1691,12 @@ public class LibMatrixReorg
                }
                
                //cleanup for sparse blocks
-               if( !rowwise ) {
+               if( !rowwise && mcIn.getRows() > 1 ) {
                        for( MatrixBlock block : rix.values() )
                                if( block.sparse )
                                        block.sortSparseRows();
-               }                               
-    }
+               }
+       }
        
        /**
         * Assumes internal (0-begin) indices ai, aj as input; computes 
external block indexes (1-begin) 

Reply via email to