[SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats

In the spirit of our SPOOF compiler framework and the existing
sum(X%*%Y) rewrite, this patch adds the following two sum-product
rewrites (the first of which applies multiple times in stratstats):

* colSums(X %*% Y) -> colSums(X) %*% Y
* rowSums(X %*% Y) -> X %*% rowSums(Y)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b3ba9916
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b3ba9916
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b3ba9916

Branch: refs/heads/master
Commit: b3ba991604cf79c5e3e2c0992fe2439ae47ce023
Parents: d3e617b
Author: Matthias Boehm <[email protected]>
Authored: Sun Feb 12 08:29:36 2017 +0100
Committer: Matthias Boehm <[email protected]>
Committed: Sun Feb 12 09:42:03 2017 +0100

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  | 54 +++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3ba9916/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e8e3862..6ffcbd5 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2547,11 +2547,12 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
 
        private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos)
        {
-               //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B))
-               //if not dot product, not applied since aggregate removed
-               //if sum not the only consumer, not applied to prevent 
redundancy 
+               //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten 
to dot-product
+               //colSums(A%*%B) -> colSums(A)%*%B
+               //rowSums(A%*%B) -> A%*%rowSums(B)
+               //-- if not dot product, not applied since aggregate removed
+               //-- if sum not the only consumer, not applied to prevent 
redundancy 
                if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getOp()==AggOp.SUM  //sum
-                       && ((AggUnaryOp)hi).getDirection() == Direction.RowCol  
         //full aggregate
                        && hi.getInput().get(0) instanceof AggBinaryOp          
         //A%*%B
                        && (hi.getInput().get(0).getDim1()>1 || 
hi.getInput().get(0).getDim2()>1) //not dot product
                        && hi.getInput().get(0).getParent().size()==1 )     
//not multiple consumers of matrix mult
@@ -2560,34 +2561,39 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        Hop left = hi2.getInput().get(0);
                        Hop right = hi2.getInput().get(1);
                                
-                       //remove link from parent to diag
+                       //remove link from parent to matrix mult
                        HopRewriteUtils.removeChildReference(hi, hi2);
                                
                        //create new operators
-                       AggUnaryOp colSum = new AggUnaryOp(left.getName(), 
left.getDataType(), left.getValueType(), AggOp.SUM, Direction.Col, left);
-                       colSum.setRowsInBlock(left.getRowsInBlock());
-                       colSum.setColsInBlock(left.getColsInBlock());
-                       colSum.refreshSizeInformation();
-                       ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
-                       AggUnaryOp rowSum = new AggUnaryOp(right.getName(), 
right.getDataType(), right.getValueType(), AggOp.SUM, Direction.Row, right);
-                       rowSum.setRowsInBlock(right.getRowsInBlock());
-                       rowSum.setColsInBlock(right.getColsInBlock());
-                       rowSum.refreshSizeInformation();
-                       BinaryOp mult = new BinaryOp(right.getName(), 
right.getDataType(), right.getValueType(), OpOp2.MULT, trans, rowSum);
-                       mult.setRowsInBlock(right.getRowsInBlock());
-                       mult.setColsInBlock(right.getColsInBlock());
-                       mult.refreshSizeInformation();
-                               
+                       Hop root = null;
+                       //pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), 
later rewritten to dot-product
+                       if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol 
) {
+                               AggUnaryOp colSum = 
HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+                               ReorgOp trans = 
HopRewriteUtils.createTranspose(colSum);
+                               AggUnaryOp rowSum = 
HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+                               root = HopRewriteUtils.createBinary(trans, 
rowSum, OpOp2.MULT);
+                               LOG.debug("Applied simplifySumMatrixMult RC.");
+                       }
+                       //colSums(A%*%B) -> colSums(A)%*%B
+                       else if( ((AggUnaryOp)hi).getDirection() == 
Direction.Col ) {
+                               AggUnaryOp colSum = 
HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+                               root = 
HopRewriteUtils.createMatrixMultiply(colSum, right);
+                               LOG.debug("Applied simplifySumMatrixMult C.");
+                       }
+                       //rowSums(A%*%B) -> A%*%rowSums(B)
+                       else if( ((AggUnaryOp)hi).getDirection() == 
Direction.Row ) {
+                               AggUnaryOp rowSum = 
HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+                               root = 
HopRewriteUtils.createMatrixMultiply(left, rowSum);
+                               LOG.debug("Applied simplifySumMatrixMult R.");
+                       }
                        
                        //rehang new subdag under current node (keep hi intact)
-                       HopRewriteUtils.addChildReference(hi, mult, 0);         
                
+                       HopRewriteUtils.addChildReference(hi, root, 0);         
                
                        hi.refreshSizeInformation();
-                               
+                       
                        //cleanup if only consumer of intermediate
                        if( hi2.getParent().isEmpty() ) 
-                               HopRewriteUtils.removeAllChildReferences( hi2 );
-                       
-                       LOG.debug("Applied simplifySumMatrixMult.");    
+                               HopRewriteUtils.removeAllChildReferences( hi2 
);        
                }
                
                return hi;

Reply via email to