Repository: systemml
Updated Branches:
  refs/heads/master a83c25019 -> 07bab605e


[SYSTEMML-2184] Apply basic wdivmm conservatively (sparse/dense inputs)

This patch improves the compiler to apply the fused basic wdivmm
for the pattern (X * U%*%t(V)) more conservatively, namely only when (1) X
is sparse and U/V are unknown or dense, or (2) U and V are known to be
dense, which are the primary applications for this sparsity-exploiting
fused operator. At the same time, this helps to avoid slowdowns in
scenarios where U and V are actually ultra-sparse and wdivmm would
hurt performance.

For example, on perftest stratstats 10K x 10K, this patch reduced the
end-to-end runtime from 128s to 29s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/07bab605
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/07bab605
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/07bab605

Branch: refs/heads/master
Commit: 07bab605ead7b5757ce6ed0211042c216b5ffb41
Parents: a83c250
Author: Matthias Boehm <[email protected]>
Authored: Mon Mar 12 23:07:52 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Mar 12 23:07:52 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java      |  7 ++++++-
 .../RewriteAlgebraicSimplificationDynamic.java   | 15 ++++++---------
 .../test/integration/AutomatedTestBase.java      |  6 ++----
 .../quaternary/WeightedDivMatrixMultTest.java    | 19 +++++++++----------
 .../org/apache/sysml/test/utils/TestUtils.java   | 10 ++++++++--
 5 files changed, 31 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index dce21ab..735cac0 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -840,11 +840,16 @@ public class HopRewriteUtils
                return (Hop.getOpOp2ForOuterVectorOperation(opcode) == op);
        }
        
-       public static boolean isSparse( Hop hop ) {
+       public static boolean isSparse(Hop hop) {
                return hop.dimsKnown(true) //dims and nnz known
                        && MatrixBlock.evalSparseFormatInMemory(hop.getDim1(), 
hop.getDim2(), hop.getNnz());
        }
        
+       public static boolean isDense(Hop hop) {
+               return hop.dimsKnown(true) //dims and nnz known
+                       && !MatrixBlock.evalSparseFormatInMemory(hop.getDim1(), 
hop.getDim2(), hop.getNnz());
+       }
+       
        public static boolean isSparse( Hop hop, double threshold ) {
                return hop.getSparsity() < threshold;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index f07a379..89e2146 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1851,18 +1851,15 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        Hop V = hi.getInput().get(1).getInput().get(1);
                        
                        //for this basic pattern, we're more conservative and 
only apply wdivmm if
-                       //the factors are not known to be sparse
-                       if( !HopRewriteUtils.isSparse(U) && 
!HopRewriteUtils.isSparse(V) ) {
-                               if( !HopRewriteUtils.isTransposeOperation(V) )
-                                       V = HopRewriteUtils.createTranspose(V);
-                               else 
-                                       V = V.getInput().get(0);
-                               
+                       //W is sparse and U/V unknown or dense; or if U/V are 
dense
+                       if( (HopRewriteUtils.isSparse(W) && 
!HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V))
+                               || (HopRewriteUtils.isDense(U) && 
HopRewriteUtils.isDense(V)) ) {
+                               V = !HopRewriteUtils.isTransposeOperation(V) ?
+                                       HopRewriteUtils.createTranspose(V) : 
V.getInput().get(0);
                                hnew = new QuaternaryOp(hi.getName(), 
DataType.MATRIX, ValueType.DOUBLE, 
-                                                 OpOp4.WDIVMM, W, U, V, new 
LiteralOp(-1), 0, true, false);
+                                       OpOp4.WDIVMM, W, U, V, new 
LiteralOp(-1), 0, true, false);
                                hnew.setOutputBlocksizes(W.getRowsInBlock(), 
W.getColsInBlock());
                                hnew.refreshSizeInformation();
-                               
                                appliedPattern = true;
                                LOG.debug("Applied simplifyWeightedDivMM7 (line 
"+hi.getBeginLine()+")");
                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 3bd78b0..01f44cf 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -504,14 +504,12 @@ public abstract class AutomatedTestBase
                return matrix;
        }
 
-       protected double[][] writeInputMatrixWithMTD(String name, double[][] 
matrix, boolean bIncludeR)
-       {
+       protected double[][] writeInputMatrixWithMTD(String name, double[][] 
matrix, boolean bIncludeR) {
                MatrixCharacteristics mc = new 
MatrixCharacteristics(matrix.length, matrix[0].length, 
OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
                return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc);
        }
 
-       protected double[][] writeInputMatrixWithMTD(String name, double[][] 
matrix, int nnz, boolean bIncludeR)
-       {
+       protected double[][] writeInputMatrixWithMTD(String name, double[][] 
matrix, long nnz, boolean bIncludeR) {
                MatrixCharacteristics mc = new 
MatrixCharacteristics(matrix.length, matrix[0].length, 
OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz);
                return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
index d8053b3..a94eb61 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/quaternary/WeightedDivMatrixMultTest.java
@@ -626,12 +626,12 @@ public class WeightedDivMatrixMultTest extends 
AutomatedTestBase
                        int rank = sparse ? rank2 : rank1;
                        
                        //generate actual dataset 
-                       double[][] W = getRandomMatrix(rows, cols, 0, 1, 
sparsity, 7); 
-                       writeInputMatrixWithMTD("W", W, true);
-                       double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 
713); 
-                       writeInputMatrixWithMTD("U", U, true);
-                       double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 
812); 
-                       writeInputMatrixWithMTD("V", V, true);
+                       double[][] W = getRandomMatrix(rows, cols, 0, 1, 
sparsity, 7);
+                       writeInputMatrixWithMTD("W", W, 
TestUtils.computeNNZ(W), true);
+                       double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 
713);
+                       writeInputMatrixWithMTD("U", U, 
TestUtils.computeNNZ(U), true);
+                       double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 
812);
+                       writeInputMatrixWithMTD("V", V, 
TestUtils.computeNNZ(V), true);
                        
                        runTest(true, false, null, -1); 
                        runRScript(true); 
@@ -652,12 +652,11 @@ public class WeightedDivMatrixMultTest extends 
AutomatedTestBase
                                Assert.assertTrue("Missing opcode sp_wdivmm", 
Statistics.getCPHeavyHitterOpCodes().contains(opcode) );
                        }
                }
-               finally
-               {
+               finally {
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewritesOld;
                        QuaternaryOp.FORCE_REPLICATION = forceOld;
                }
-       }       
-}
\ No newline at end of file
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/07bab605/src/test/java/org/apache/sysml/test/utils/TestUtils.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/utils/TestUtils.java 
b/src/test/java/org/apache/sysml/test/utils/TestUtils.java
index aaaade23..612ed40 100644
--- a/src/test/java/org/apache/sysml/test/utils/TestUtils.java
+++ b/src/test/java/org/apache/sysml/test/utils/TestUtils.java
@@ -2070,8 +2070,7 @@ public class TestUtils
                return data;
        }
        
-       public static double sum(double[][] data, int rows, int cols)
-       {
+       public static double sum(double[][] data, int rows, int cols) {
                double sum = 0;
                for (int i = 0; i< rows; i++){
                        for (int j = 0; j < cols; j++){
@@ -2080,4 +2079,11 @@ public class TestUtils
                }
                return sum;
        }
+       
+       public static long computeNNZ(double[][] data) {
+               long nnz = 0;
+               for(int i=0; i<data.length; i++)
+                       nnz += UtilFunctions.computeNnz(data[i], 0, 
data[i].length);
+               return nnz;
+       }
 }

Reply via email to