[SYSTEMML-2242] Performance of spark binary block reblock operations

This patch makes several performance improvements to spark reblock
operations for binary block matrices. These include (1) reduced overhead
per non-zero value, and (2) a special case for aligned block sizes, i.e.,
when the output block size is an integer multiple of the input block size
in both dimensions (e.g., when reblocking from 1K to 2K, which is often
used for ultra-sparse matrices in practice).

In a scenario that reblocks a 1.2K x 63G matrix with 943 non-zeros (and
with empty block materialization) from blocksize 1K to 2K, the reblock-sum
runtime improved from 107s to 15s. Furthermore, this patch also reduces GC
overhead for subsequent spark operations.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ec044885
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ec044885
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ec044885

Branch: refs/heads/master
Commit: ec04488508e0c1743dda435d4c8f7205d161868c
Parents: f822950
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Apr 13 23:35:40 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Apr 13 23:35:40 2018 -0700

----------------------------------------------------------------------
 .../functions/ExtractBlockForBinaryReblock.java |  81 +++++++--------
 .../functions/data/FullReblockTest.java         | 103 +++++++++++--------
 2 files changed, 98 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ec044885/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
index 66a2271..0ce654b 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
@@ -36,17 +36,14 @@ public class ExtractBlockForBinaryReblock implements 
PairFlatMapFunction<Tuple2<
 {
        private static final long serialVersionUID = -762987655085029215L;
        
-       private long rlen; 
-       private long clen; 
-       private int in_brlen; 
-       private int in_bclen; 
-       private int out_brlen; 
-       private int out_bclen;
+       private final long rlen, clen; 
+       private final int in_brlen, in_bclen;
+       private final int out_brlen, out_bclen;
        
        public ExtractBlockForBinaryReblock(MatrixCharacteristics mcIn, 
MatrixCharacteristics mcOut) {
-               rlen = mcIn.getRows(); 
+               rlen = mcIn.getRows();
                clen = mcIn.getCols();
-               in_brlen = mcIn.getRowsPerBlock(); 
+               in_brlen = mcIn.getRowsPerBlock();
                in_bclen = mcIn.getColsPerBlock();
                out_brlen = mcOut.getRowsPerBlock(); 
                out_bclen = mcOut.getColsPerBlock();
@@ -54,7 +51,7 @@ public class ExtractBlockForBinaryReblock implements 
PairFlatMapFunction<Tuple2<
                //sanity check block sizes
                if(in_brlen <= 0 || in_bclen <= 0 || out_brlen <= 0 || 
out_bclen <= 0) {
                        throw new DMLRuntimeException("Block sizes not 
unknown:" + 
-                      in_brlen + "," + in_bclen + "," +  out_brlen + "," + 
out_bclen);
+                               in_brlen + "," + in_bclen + "," +  out_brlen + 
"," + out_bclen);
                }
        }
        
@@ -65,58 +62,58 @@ public class ExtractBlockForBinaryReblock implements 
PairFlatMapFunction<Tuple2<
                MatrixIndexes ixIn = arg0._1();
                MatrixBlock in = arg0._2();
                
-               // The global cell indexes don't change in reblock operations
-               long startRowGlobalCellIndex = 
UtilFunctions.computeCellIndex(ixIn.getRowIndex(), in_brlen, 0);
-               long endRowGlobalCellIndex = 
getEndGlobalIndex(ixIn.getRowIndex(), true, true);
-               long startColGlobalCellIndex = 
UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), in_bclen, 0);
-               long endColGlobalCellIndex = 
getEndGlobalIndex(ixIn.getColumnIndex(), true, false);
+               final long startRowGlobalCellIndex = 
UtilFunctions.computeCellIndex(ixIn.getRowIndex(), in_brlen, 0);
+               final long endRowGlobalCellIndex = 
getEndGlobalIndex(ixIn.getRowIndex(), true, true);
+               final long startColGlobalCellIndex = 
UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), in_bclen, 0);
+               final long endColGlobalCellIndex = 
getEndGlobalIndex(ixIn.getColumnIndex(), true, false);
                
-               long out_startRowBlockIndex = 
UtilFunctions.computeBlockIndex(startRowGlobalCellIndex, out_brlen);
-               long out_endRowBlockIndex = 
UtilFunctions.computeBlockIndex(endRowGlobalCellIndex, out_brlen);
-               long out_startColBlockIndex = 
UtilFunctions.computeBlockIndex(startColGlobalCellIndex, out_bclen);
-               long out_endColBlockIndex = 
UtilFunctions.computeBlockIndex(endColGlobalCellIndex, out_bclen);
+               final long out_startRowBlockIndex = 
UtilFunctions.computeBlockIndex(startRowGlobalCellIndex, out_brlen);
+               final long out_endRowBlockIndex = 
UtilFunctions.computeBlockIndex(endRowGlobalCellIndex, out_brlen);
+               final long out_startColBlockIndex = 
UtilFunctions.computeBlockIndex(startColGlobalCellIndex, out_bclen);
+               final long out_endColBlockIndex = 
UtilFunctions.computeBlockIndex(endColGlobalCellIndex, out_bclen);
+               final boolean aligned = out_brlen%in_brlen==0 && 
out_bclen%in_bclen==0; //e.g, 1K -> 2K
                
                ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> retVal = new 
ArrayList<>();
-               
                for(long i = out_startRowBlockIndex; i <= out_endRowBlockIndex; 
i++) {
                        for(long j = out_startColBlockIndex; j <= 
out_endColBlockIndex; j++) {
                                MatrixIndexes indx = new MatrixIndexes(i, j);
-                               int new_lrlen = 
UtilFunctions.computeBlockSize(rlen, i, out_brlen);
-                               int new_lclen = 
UtilFunctions.computeBlockSize(clen, j, out_bclen);
+                               final int new_lrlen = 
UtilFunctions.computeBlockSize(rlen, i, out_brlen);
+                               final int new_lclen = 
UtilFunctions.computeBlockSize(clen, j, out_bclen);
                                MatrixBlock blk = new MatrixBlock(new_lrlen, 
new_lclen, true);
+                               if( in.isEmptyBlock(false) ) continue;
+                               
+                               final long rowLower = 
Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), 
startRowGlobalCellIndex);
+                               final long rowUpper = 
Math.min(getEndGlobalIndex(i, false, true), endRowGlobalCellIndex);
+                               final long colLower = 
Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), 
startColGlobalCellIndex);
+                               final long colUpper = 
Math.min(getEndGlobalIndex(j, false, false), endColGlobalCellIndex);
+                               final int aixi = 
UtilFunctions.computeCellInBlock(rowLower, in_brlen);
+                               final int aixj = 
UtilFunctions.computeCellInBlock(colLower, in_bclen);
+                               final int cixi = 
UtilFunctions.computeCellInBlock(rowLower, out_brlen);
+                               final int cixj = 
UtilFunctions.computeCellInBlock(colLower, out_bclen);
                                
-                               if( !in.isEmptyBlock(false) ) {
-                                       long rowLower = 
Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), 
startRowGlobalCellIndex);
-                                       long rowUpper = 
Math.min(getEndGlobalIndex(i, false, true), endRowGlobalCellIndex);
-                                       long colLower = 
Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), 
startColGlobalCellIndex);
-                                       long colUpper = 
Math.min(getEndGlobalIndex(j, false, false), endColGlobalCellIndex);
-                                       int in_i1 = 
UtilFunctions.computeCellInBlock(rowLower, in_brlen);
-                                       int out_i1 = 
UtilFunctions.computeCellInBlock(rowLower, out_brlen);
-                                       
-                                       for(long i1 = rowLower; i1 <= rowUpper; 
i1++, in_i1++, out_i1++) {
-                                               int in_j1 = 
UtilFunctions.computeCellInBlock(colLower, in_bclen);
-                                               int out_j1 = 
UtilFunctions.computeCellInBlock(colLower, out_bclen);
-                                               for(long j1 = colLower; j1 <= 
colUpper; j1++, in_j1++, out_j1++) {
-                                                       double val = 
in.quickGetValue(in_i1, in_j1);
-                                                       blk.appendValue(out_i1, 
out_j1, val);
-                                               }
-                                       }
+                               if( aligned ) {
+                                       blk.appendToSparse(in, cixi, cixj);
+                                       blk.setNonZeros(in.getNonZeros());
+                               }
+                               else { //general case
+                                       for(int i2 = 0; i2 <= 
(int)(rowUpper-rowLower); i2++)
+                                               for(int j2 = 0; j2 <= 
(int)(colUpper-colLower); j2++)
+                                                       
blk.appendValue(cixi+i2, cixj+j2, in.quickGetValue(aixi+i2, aixj+j2));
                                }
                                retVal.add(new Tuple2<>(indx, blk));
                        }
                }
+               
                return retVal.iterator();
        }
 
-       private long getEndGlobalIndex(long blockIndex, boolean isIn, boolean 
isRow) 
-       {
+       private long getEndGlobalIndex(long blockIndex, boolean isIn, boolean 
isRow) {
                //determine dimension and block sizes
                long len = isRow ? rlen : clen;
-               int blen = isIn ? (isRow ? in_brlen : in_bclen) 
-                                       : (isRow ? out_brlen : out_bclen);
+               int blen = isIn ? (isRow ? in_brlen : in_bclen) : (isRow ? 
out_brlen : out_bclen);
                
                //compute 1-based global cell index in block
                int new_len = UtilFunctions.computeBlockSize(len, blockIndex, 
blen);
                return UtilFunctions.computeCellIndex(blockIndex, blen, 
new_len-1);
        }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec044885/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java
index c1a46ce..9e2f902 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java
@@ -260,40 +260,64 @@ public class FullReblockTest extends AutomatedTestBase
        }
 
        @Test
-       public void testBinaryBlockSingleMDenseSP() 
-       {
+       public void testBinaryBlockSingleMDenseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Single, ExecType.SPARK);
        }
        
        @Test
-       public void testBinaryBlockSingeMSparseSP() 
-       {
+       public void testBinaryBlockSingeMSparseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Single, ExecType.SPARK);
        }
        
        @Test
-       public void testBinaryBlockSingleVDenseSP() 
-       {
+       public void testBinaryBlockSingleVDenseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Vector, ExecType.SPARK);
        }
        
        @Test
-       public void testBinaryBlockSingeVSparseSP() 
-       {
+       public void testBinaryBlockSingeVSparseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Vector, ExecType.SPARK);
        }
        
        @Test
-       public void testBinaryBlockMultipleMDenseSP() 
-       {
+       public void testBinaryBlockMultipleMDenseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Multiple, ExecType.SPARK);
        }
        
        @Test
-       public void testBinaryBlockMultipleMSparseSP() 
-       {
+       public void testBinaryBlockMultipleMSparseSP() {
                runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Multiple, ExecType.SPARK);
        }
+       
+       @Test
+       public void testBinaryBlockSingleMDenseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Single, ExecType.SPARK, 500);
+       }
+       
+       @Test
+       public void testBinaryBlockSingeMSparseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Single, ExecType.SPARK, 500);
+       }
+       
+       @Test
+       public void testBinaryBlockSingleVDenseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Vector, ExecType.SPARK, 500);
+       }
+       
+       @Test
+       public void testBinaryBlockSingeVSparseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Vector, ExecType.SPARK, 500);
+       }
+       
+       @Test
+       public void testBinaryBlockMultipleMDenseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, 
Type.Multiple, ExecType.SPARK, 500);
+       }
+       
+       @Test
+       public void testBinaryBlockMultipleMSparseSPAligned() {
+               runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, 
Type.Multiple, ExecType.SPARK, 500);
+       }
 
        //csv
        
@@ -406,17 +430,15 @@ public class FullReblockTest extends AutomatedTestBase
                runReblockTest(OutputInfo.CSVOutputInfo, true, Type.Multiple, 
ExecType.MR);
        }
 
-       /**
-        * 
-        * @param oi
-        * @param sparse
-        * @param type
-        * @param et
-        */
-       private void runReblockTest( OutputInfo oi, boolean sparse, Type type, 
ExecType et )
-       {               
-               String TEST_NAME = (type==Type.Multiple) ? TEST_NAME2 : 
TEST_NAME1;              
-               double sparsity = (sparse) ? sparsity2 : sparsity1;             
+       private void runReblockTest( OutputInfo oi, boolean sparse, Type type, 
ExecType et ) {
+               //force binary reblock for 999 to match 1000
+               runReblockTest(oi, sparse, type, et, blocksize-1);
+       }
+       
+       private void runReblockTest( OutputInfo oi, boolean sparse, Type type, 
ExecType et, int srcBlksize )
+       {
+               String TEST_NAME = (type==Type.Multiple) ? TEST_NAME2 : 
TEST_NAME1;
+               double sparsity = (sparse) ? sparsity2 : sparsity1;
                int rows = (type==Type.Vector)? rowsV : rowsM;
                int cols = (type==Type.Vector)? colsV : colsM;
                
@@ -448,38 +470,31 @@ public class FullReblockTest extends AutomatedTestBase
                boolean success = false;
                long seed1 = System.nanoTime();
                long seed2 = System.nanoTime()+7;
-        
-               try 
-               {
+
+               try {
                        //run test cases with single or multiple inputs
                        if( type==Type.Multiple )
                        {
                                double[][] A1 = getRandomMatrix(rows, cols, 0, 
1, sparsity, seed1);
                                double[][] A2 = getRandomMatrix(rows, cols, 0, 
1, sparsity, seed2);
-                       
-                               //force binary reblock for 999 to match 1000
-                       writeMatrix(A1, input("A1"), oi, rows, cols, 
blocksize-1, blocksize-1);
-                       writeMatrix(A2, input("A2"), oi, rows, cols, 
blocksize-1, blocksize-1);
+                               writeMatrix(A1, input("A1"), oi, rows, cols, 
blocksize-1, blocksize-1);
+                               writeMatrix(A2, input("A2"), oi, rows, cols, 
blocksize-1, blocksize-1);
                                runTest(true, false, null, -1);
-                       double[][] C1 = readMatrix(output("C1"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
-                       double[][] C2 = readMatrix(output("C2"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
-                       TestUtils.compareMatrices(A1, C1, rows, cols, eps);
-                       TestUtils.compareMatrices(A2, C2, rows, cols, eps);
-                   }
-                       else
-                       {
+                               double[][] C1 = readMatrix(output("C1"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
+                               double[][] C2 = readMatrix(output("C2"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
+                               TestUtils.compareMatrices(A1, C1, rows, cols, 
eps);
+                               TestUtils.compareMatrices(A2, C2, rows, cols, 
eps);
+                       }
+                       else {
                                double[][] A = getRandomMatrix(rows, cols, 0, 
1, sparsity, seed1);
-
-                               //force binary reblock for 999 to match 1000
-                       writeMatrix(A, input("A"), oi, rows, cols, blocksize-1, 
blocksize-1);
+                               writeMatrix(A, input("A"), oi, rows, cols, 
blocksize-1, blocksize-1);
                                runTest(true, false, null, -1);
-                       double[][] C = readMatrix(output("C"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
-                               
-                       TestUtils.compareMatrices(A, C, rows, cols, eps);
+                               double[][] C = readMatrix(output("C"), 
InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize);
+                               TestUtils.compareMatrices(A, C, rows, cols, 
eps);
                        }
                        
                        success = true;
-               } 
+               }
                catch (Exception e) {
                        e.printStackTrace();
                        Assert.fail();

Reply via email to