Repository: systemml Updated Branches: refs/heads/master a83c25019 -> 07bab605e
[SYSTEMML-2184] Apply basic wdivmm conservatively (sparse/dense inputs) This patch makes a compiler improvement that applies the fused basic wdivmm operator for the pattern (X * U%*%t(V)) more conservatively, namely only when (1) X is sparse and U/V are unknown or dense, or (2) U and V are known to be dense, which are the primary applications of this sparsity-exploiting fused operator. At the same time, this helps avoid performance issues in scenarios where U and V are actually ultra-sparse and wdivmm would hurt performance. For example, on perftest stratstats 10K x 10K, this patch improved end-to-end performance from 128s to 29s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/07bab605 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/07bab605 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/07bab605 Branch: refs/heads/master Commit: 07bab605ead7b5757ce6ed0211042c216b5ffb41 Parents: a83c250 Author: Matthias Boehm <[email protected]> Authored: Mon Mar 12 23:07:52 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Mar 12 23:07:52 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 7 ++++++- .../RewriteAlgebraicSimplificationDynamic.java | 15 ++++++--------- .../test/integration/AutomatedTestBase.java | 6 ++---- .../quaternary/WeightedDivMatrixMultTest.java | 19 +++++++++---------- .../org/apache/sysml/test/utils/TestUtils.java | 10 ++++++++-- 5 files changed, 31 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index dce21ab..735cac0 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -840,11 +840,16 @@ public class HopRewriteUtils return (Hop.getOpOp2ForOuterVectorOperation(opcode) == op); } - public static boolean isSparse( Hop hop ) { + public static boolean isSparse(Hop hop) { return hop.dimsKnown(true) //dims and nnz known && MatrixBlock.evalSparseFormatInMemory(hop.getDim1(), hop.getDim2(), hop.getNnz()); } + public static boolean isDense(Hop hop) { + return hop.dimsKnown(true) //dims and nnz known + && !MatrixBlock.evalSparseFormatInMemory(hop.getDim1(), hop.getDim2(), hop.getNnz()); + } + public static boolean isSparse( Hop hop, double threshold ) { return hop.getSparsity() < threshold; } http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index f07a379..89e2146 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -1851,18 +1851,15 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop V = hi.getInput().get(1).getInput().get(1); //for this basic pattern, we're more conservative and only apply wdivmm if - //the factors are not known to be sparse - if( !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V) ) { - if( !HopRewriteUtils.isTransposeOperation(V) ) - V = HopRewriteUtils.createTranspose(V); - else - V = V.getInput().get(0); - + //W is sparse and U/V unknown or dense; or 
if U/V are dense + if( (HopRewriteUtils.isSparse(W) && !HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V)) + || (HopRewriteUtils.isDense(U) && HopRewriteUtils.isDense(V)) ) { + V = !HopRewriteUtils.isTransposeOperation(V) ? + HopRewriteUtils.createTranspose(V) : V.getInput().get(0); hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, - OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false); + OpOp4.WDIVMM, W, U, V, new LiteralOp(-1), 0, true, false); hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock()); hnew.refreshSizeInformation(); - appliedPattern = true; LOG.debug("Applied simplifyWeightedDivMM7 (line "+hi.getBeginLine()+")"); } http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 3bd78b0..01f44cf 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -504,14 +504,12 @@ public abstract class AutomatedTestBase return matrix; } - protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR) - { + protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR) { MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, -1); return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc); } - protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, int nnz, boolean bIncludeR) - { + protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, long nnz, boolean bIncludeR) { MatrixCharacteristics mc = new 
MatrixCharacteristics(matrix.length, matrix[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz); return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc); } http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java index d8053b3..a94eb61 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java @@ -626,12 +626,12 @@ public class WeightedDivMatrixMultTest extends AutomatedTestBase int rank = sparse ? rank2 : rank1; //generate actual dataset - double[][] W = getRandomMatrix(rows, cols, 0, 1, sparsity, 7); - writeInputMatrixWithMTD("W", W, true); - double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 713); - writeInputMatrixWithMTD("U", U, true); - double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 812); - writeInputMatrixWithMTD("V", V, true); + double[][] W = getRandomMatrix(rows, cols, 0, 1, sparsity, 7); + writeInputMatrixWithMTD("W", W, TestUtils.computeNNZ(W), true); + double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 713); + writeInputMatrixWithMTD("U", U, TestUtils.computeNNZ(U), true); + double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 812); + writeInputMatrixWithMTD("V", V, TestUtils.computeNNZ(V), true); runTest(true, false, null, -1); runRScript(true); @@ -652,12 +652,11 @@ public class WeightedDivMatrixMultTest extends AutomatedTestBase Assert.assertTrue("Missing opcode sp_wdivmm", Statistics.getCPHeavyHitterOpCodes().contains(opcode) ); } } - finally - { + finally { 
rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld; QuaternaryOp.FORCE_REPLICATION = forceOld; } - } -} \ No newline at end of file + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/utils/TestUtils.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/utils/TestUtils.java b/src/test/java/org/apache/sysml/test/utils/TestUtils.java index aaaade23..612ed40 100644 --- a/src/test/java/org/apache/sysml/test/utils/TestUtils.java +++ b/src/test/java/org/apache/sysml/test/utils/TestUtils.java @@ -2070,8 +2070,7 @@ public class TestUtils return data; } - public static double sum(double[][] data, int rows, int cols) - { + public static double sum(double[][] data, int rows, int cols) { double sum = 0; for (int i = 0; i< rows; i++){ for (int j = 0; j < cols; j++){ @@ -2080,4 +2079,11 @@ public class TestUtils } return sum; } + + public static long computeNNZ(double[][] data) { + long nnz = 0; + for(int i=0; i<data.length; i++) + nnz += UtilFunctions.computeNnz(data[i], 0, data[i].length); + return nnz; + } }
