[SYSTEMML-2242] Performance spark binary block reblock operations This patch makes several performance improvements to spark reblock operations for binary block matrices. This includes (1) reduced overhead per non-zero value, and (2) a special case for aligned block sizes, i.e., output block sizes that are integer multiples of the input block sizes (e.g., when reblocking from 1K to 2K, which is often used for ultra-sparse matrices in practice).
On a scenario of reblocking a 1.2K x 63G matrix w/ 943 non-zeros and empty block materialization from blocksize 1K to 2K, the reblock-sum runtime improved from 107s to 15s. Furthermore, this also reduced GC overhead for subsequent spark operations. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ec044885 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ec044885 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ec044885 Branch: refs/heads/master Commit: ec04488508e0c1743dda435d4c8f7205d161868c Parents: f822950 Author: Matthias Boehm <[email protected]> Authored: Fri Apr 13 23:35:40 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 13 23:35:40 2018 -0700 ---------------------------------------------------------------------- .../functions/ExtractBlockForBinaryReblock.java | 81 +++++++-------- .../functions/data/FullReblockTest.java | 103 +++++++++++-------- 2 files changed, 98 insertions(+), 86 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ec044885/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java index 66a2271..0ce654b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java @@ -36,17 +36,14 @@ public class ExtractBlockForBinaryReblock implements PairFlatMapFunction<Tuple2< { private static final long serialVersionUID = -762987655085029215L; - private long 
rlen; - private long clen; - private int in_brlen; - private int in_bclen; - private int out_brlen; - private int out_bclen; + private final long rlen, clen; + private final int in_brlen, in_bclen; + private final int out_brlen, out_bclen; public ExtractBlockForBinaryReblock(MatrixCharacteristics mcIn, MatrixCharacteristics mcOut) { - rlen = mcIn.getRows(); + rlen = mcIn.getRows(); clen = mcIn.getCols(); - in_brlen = mcIn.getRowsPerBlock(); + in_brlen = mcIn.getRowsPerBlock(); in_bclen = mcIn.getColsPerBlock(); out_brlen = mcOut.getRowsPerBlock(); out_bclen = mcOut.getColsPerBlock(); @@ -54,7 +51,7 @@ public class ExtractBlockForBinaryReblock implements PairFlatMapFunction<Tuple2< //sanity check block sizes if(in_brlen <= 0 || in_bclen <= 0 || out_brlen <= 0 || out_bclen <= 0) { throw new DMLRuntimeException("Block sizes not unknown:" + - in_brlen + "," + in_bclen + "," + out_brlen + "," + out_bclen); + in_brlen + "," + in_bclen + "," + out_brlen + "," + out_bclen); } } @@ -65,58 +62,58 @@ public class ExtractBlockForBinaryReblock implements PairFlatMapFunction<Tuple2< MatrixIndexes ixIn = arg0._1(); MatrixBlock in = arg0._2(); - // The global cell indexes don't change in reblock operations - long startRowGlobalCellIndex = UtilFunctions.computeCellIndex(ixIn.getRowIndex(), in_brlen, 0); - long endRowGlobalCellIndex = getEndGlobalIndex(ixIn.getRowIndex(), true, true); - long startColGlobalCellIndex = UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), in_bclen, 0); - long endColGlobalCellIndex = getEndGlobalIndex(ixIn.getColumnIndex(), true, false); + final long startRowGlobalCellIndex = UtilFunctions.computeCellIndex(ixIn.getRowIndex(), in_brlen, 0); + final long endRowGlobalCellIndex = getEndGlobalIndex(ixIn.getRowIndex(), true, true); + final long startColGlobalCellIndex = UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), in_bclen, 0); + final long endColGlobalCellIndex = getEndGlobalIndex(ixIn.getColumnIndex(), true, false); - long out_startRowBlockIndex = 
UtilFunctions.computeBlockIndex(startRowGlobalCellIndex, out_brlen); - long out_endRowBlockIndex = UtilFunctions.computeBlockIndex(endRowGlobalCellIndex, out_brlen); - long out_startColBlockIndex = UtilFunctions.computeBlockIndex(startColGlobalCellIndex, out_bclen); - long out_endColBlockIndex = UtilFunctions.computeBlockIndex(endColGlobalCellIndex, out_bclen); + final long out_startRowBlockIndex = UtilFunctions.computeBlockIndex(startRowGlobalCellIndex, out_brlen); + final long out_endRowBlockIndex = UtilFunctions.computeBlockIndex(endRowGlobalCellIndex, out_brlen); + final long out_startColBlockIndex = UtilFunctions.computeBlockIndex(startColGlobalCellIndex, out_bclen); + final long out_endColBlockIndex = UtilFunctions.computeBlockIndex(endColGlobalCellIndex, out_bclen); + final boolean aligned = out_brlen%in_brlen==0 && out_bclen%in_bclen==0; //e.g, 1K -> 2K ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> retVal = new ArrayList<>(); - for(long i = out_startRowBlockIndex; i <= out_endRowBlockIndex; i++) { for(long j = out_startColBlockIndex; j <= out_endColBlockIndex; j++) { MatrixIndexes indx = new MatrixIndexes(i, j); - int new_lrlen = UtilFunctions.computeBlockSize(rlen, i, out_brlen); - int new_lclen = UtilFunctions.computeBlockSize(clen, j, out_bclen); + final int new_lrlen = UtilFunctions.computeBlockSize(rlen, i, out_brlen); + final int new_lclen = UtilFunctions.computeBlockSize(clen, j, out_bclen); MatrixBlock blk = new MatrixBlock(new_lrlen, new_lclen, true); + if( in.isEmptyBlock(false) ) continue; + + final long rowLower = Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), startRowGlobalCellIndex); + final long rowUpper = Math.min(getEndGlobalIndex(i, false, true), endRowGlobalCellIndex); + final long colLower = Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), startColGlobalCellIndex); + final long colUpper = Math.min(getEndGlobalIndex(j, false, false), endColGlobalCellIndex); + final int aixi = UtilFunctions.computeCellInBlock(rowLower, 
in_brlen); + final int aixj = UtilFunctions.computeCellInBlock(colLower, in_bclen); + final int cixi = UtilFunctions.computeCellInBlock(rowLower, out_brlen); + final int cixj = UtilFunctions.computeCellInBlock(colLower, out_bclen); - if( !in.isEmptyBlock(false) ) { - long rowLower = Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), startRowGlobalCellIndex); - long rowUpper = Math.min(getEndGlobalIndex(i, false, true), endRowGlobalCellIndex); - long colLower = Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), startColGlobalCellIndex); - long colUpper = Math.min(getEndGlobalIndex(j, false, false), endColGlobalCellIndex); - int in_i1 = UtilFunctions.computeCellInBlock(rowLower, in_brlen); - int out_i1 = UtilFunctions.computeCellInBlock(rowLower, out_brlen); - - for(long i1 = rowLower; i1 <= rowUpper; i1++, in_i1++, out_i1++) { - int in_j1 = UtilFunctions.computeCellInBlock(colLower, in_bclen); - int out_j1 = UtilFunctions.computeCellInBlock(colLower, out_bclen); - for(long j1 = colLower; j1 <= colUpper; j1++, in_j1++, out_j1++) { - double val = in.quickGetValue(in_i1, in_j1); - blk.appendValue(out_i1, out_j1, val); - } - } + if( aligned ) { + blk.appendToSparse(in, cixi, cixj); + blk.setNonZeros(in.getNonZeros()); + } + else { //general case + for(int i2 = 0; i2 <= (int)(rowUpper-rowLower); i2++) + for(int j2 = 0; j2 <= (int)(colUpper-colLower); j2++) + blk.appendValue(cixi+i2, cixj+j2, in.quickGetValue(aixi+i2, aixj+j2)); } retVal.add(new Tuple2<>(indx, blk)); } } + return retVal.iterator(); } - private long getEndGlobalIndex(long blockIndex, boolean isIn, boolean isRow) - { + private long getEndGlobalIndex(long blockIndex, boolean isIn, boolean isRow) { //determine dimension and block sizes long len = isRow ? rlen : clen; - int blen = isIn ? (isRow ? in_brlen : in_bclen) - : (isRow ? out_brlen : out_bclen); + int blen = isIn ? (isRow ? in_brlen : in_bclen) : (isRow ? 
out_brlen : out_bclen); //compute 1-based global cell index in block int new_len = UtilFunctions.computeBlockSize(len, blockIndex, blen); return UtilFunctions.computeCellIndex(blockIndex, blen, new_len-1); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ec044885/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java b/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java index c1a46ce..9e2f902 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/data/FullReblockTest.java @@ -260,40 +260,64 @@ public class FullReblockTest extends AutomatedTestBase } @Test - public void testBinaryBlockSingleMDenseSP() - { + public void testBinaryBlockSingleMDenseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Single, ExecType.SPARK); } @Test - public void testBinaryBlockSingeMSparseSP() - { + public void testBinaryBlockSingeMSparseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Single, ExecType.SPARK); } @Test - public void testBinaryBlockSingleVDenseSP() - { + public void testBinaryBlockSingleVDenseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Vector, ExecType.SPARK); } @Test - public void testBinaryBlockSingeVSparseSP() - { + public void testBinaryBlockSingeVSparseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Vector, ExecType.SPARK); } @Test - public void testBinaryBlockMultipleMDenseSP() - { + public void testBinaryBlockMultipleMDenseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Multiple, ExecType.SPARK); } @Test - public void testBinaryBlockMultipleMSparseSP() - { + public void 
testBinaryBlockMultipleMSparseSP() { runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Multiple, ExecType.SPARK); } + + @Test + public void testBinaryBlockSingleMDenseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Single, ExecType.SPARK, 500); + } + + @Test + public void testBinaryBlockSingeMSparseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Single, ExecType.SPARK, 500); + } + + @Test + public void testBinaryBlockSingleVDenseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Vector, ExecType.SPARK, 500); + } + + @Test + public void testBinaryBlockSingeVSparseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Vector, ExecType.SPARK, 500); + } + + @Test + public void testBinaryBlockMultipleMDenseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, false, Type.Multiple, ExecType.SPARK, 500); + } + + @Test + public void testBinaryBlockMultipleMSparseSPAligned() { + runReblockTest(OutputInfo.BinaryBlockOutputInfo, true, Type.Multiple, ExecType.SPARK, 500); + } //csv @@ -406,17 +430,15 @@ public class FullReblockTest extends AutomatedTestBase runReblockTest(OutputInfo.CSVOutputInfo, true, Type.Multiple, ExecType.MR); } - /** - * - * @param oi - * @param sparse - * @param type - * @param et - */ - private void runReblockTest( OutputInfo oi, boolean sparse, Type type, ExecType et ) - { - String TEST_NAME = (type==Type.Multiple) ? TEST_NAME2 : TEST_NAME1; - double sparsity = (sparse) ? sparsity2 : sparsity1; + private void runReblockTest( OutputInfo oi, boolean sparse, Type type, ExecType et ) { + //force binary reblock for 999 to match 1000 + runReblockTest(oi, sparse, type, et, blocksize-1); + } + + private void runReblockTest( OutputInfo oi, boolean sparse, Type type, ExecType et, int srcBlksize ) + { + String TEST_NAME = (type==Type.Multiple) ? TEST_NAME2 : TEST_NAME1; + double sparsity = (sparse) ? 
sparsity2 : sparsity1; int rows = (type==Type.Vector)? rowsV : rowsM; int cols = (type==Type.Vector)? colsV : colsM; @@ -448,38 +470,31 @@ public class FullReblockTest extends AutomatedTestBase boolean success = false; long seed1 = System.nanoTime(); long seed2 = System.nanoTime()+7; - - try - { + + try { //run test cases with single or multiple inputs if( type==Type.Multiple ) { double[][] A1 = getRandomMatrix(rows, cols, 0, 1, sparsity, seed1); double[][] A2 = getRandomMatrix(rows, cols, 0, 1, sparsity, seed2); - - //force binary reblock for 999 to match 1000 - writeMatrix(A1, input("A1"), oi, rows, cols, blocksize-1, blocksize-1); - writeMatrix(A2, input("A2"), oi, rows, cols, blocksize-1, blocksize-1); + writeMatrix(A1, input("A1"), oi, rows, cols, blocksize-1, blocksize-1); + writeMatrix(A2, input("A2"), oi, rows, cols, blocksize-1, blocksize-1); runTest(true, false, null, -1); - double[][] C1 = readMatrix(output("C1"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); - double[][] C2 = readMatrix(output("C2"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); - TestUtils.compareMatrices(A1, C1, rows, cols, eps); - TestUtils.compareMatrices(A2, C2, rows, cols, eps); - } - else - { + double[][] C1 = readMatrix(output("C1"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); + double[][] C2 = readMatrix(output("C2"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); + TestUtils.compareMatrices(A1, C1, rows, cols, eps); + TestUtils.compareMatrices(A2, C2, rows, cols, eps); + } + else { double[][] A = getRandomMatrix(rows, cols, 0, 1, sparsity, seed1); - - //force binary reblock for 999 to match 1000 - writeMatrix(A, input("A"), oi, rows, cols, blocksize-1, blocksize-1); + writeMatrix(A, input("A"), oi, rows, cols, blocksize-1, blocksize-1); runTest(true, false, null, -1); - double[][] C = readMatrix(output("C"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); - - 
TestUtils.compareMatrices(A, C, rows, cols, eps); + double[][] C = readMatrix(output("C"), InputInfo.BinaryBlockInputInfo, rows, cols, blocksize, blocksize); + TestUtils.compareMatrices(A, C, rows, cols, eps); } success = true; - } + } catch (Exception e) { e.printStackTrace(); Assert.fail();
