[SYSTEMML-2507] New rewrites for cumulative aggregate patterns

This patch adds the following simplification rewrites as well as related
tests:
(a) X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri, if X is square
(b) colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
(c) rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9a1f64b4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9a1f64b4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9a1f64b4

Branch: refs/heads/master
Commit: 9a1f64b42c177a82a98716ad9ef34d4d266178d2
Parents: b96807b
Author: Matthias Boehm <[email protected]>
Authored: Tue Dec 11 20:10:23 2018 +0100
Committer: Matthias Boehm <[email protected]>
Committed: Tue Dec 11 20:10:46 2018 +0100

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  |  33 ++++-
 .../RewriteAlgebraicSimplificationStatic.java   |  45 +++++++
 .../hops/rewrite/RewriteGPUSpecificOps.java     |  26 ++--
 .../misc/RewriteCumulativeAggregatesTest.java   | 126 +++++++++++++++++++
 .../misc/RewriteCumulativeAggregates.R          |  43 +++++++
 .../misc/RewriteCumulativeAggregates.dml        |  49 ++++++++
 6 files changed, 306 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 36864aa..9556181 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifyMatrixMultDiag(hop, hi, i);          
//e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1 
                        hi = simplifyDiagMatrixMult(hop, hi, i);          
//e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
                        hi = simplifySumDiagToTrace(hi);                  
//e.g., sum(diag(X)) -> trace(X); if col vector
+                       hi = simplifyLowerTriExtraction(hop, hi, i);      
//e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
                        hi = pushdownBinaryOperationOnDiag(hop, hi, i);   
//e.g., diag(X)*7 -> diag(X*7); if col vector
                        hi = pushdownSumOnAdditiveBinary(hop, hi, i);     
//e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1046,7 +1047,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                if( hi instanceof AggUnaryOp ) 
                {
                        AggUnaryOp au = (AggUnaryOp) hi;
-                       if( au.getOp()==AggOp.SUM && 
au.getDirection()==Direction.RowCol )      //sum   
+                       if( au.getOp()==AggOp.SUM && 
au.getDirection()==Direction.RowCol )      //sum
                        {
                                Hop hi2 = au.getInput().get(0);
                                if( hi2 instanceof ReorgOp && 
((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V
@@ -1054,7 +1055,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                        Hop hi3 = hi2.getInput().get(0);
                                        
                                        //remove diag operator
-                                       
HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0); 
+                                       
HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
                                        
HopRewriteUtils.cleanupUnreferenced(hi2);
                                        
                                        //change sum to trace
@@ -1063,12 +1064,38 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                        LOG.debug("Applied 
simplifySumDiagToTrace");
                                }
                        }
-                               
                }
                
                return hi;
        }
        
+       private static Hop simplifyLowerTriExtraction(Hop parent, Hop hi, int 
pos) {
+               //pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri 
(only right)
+               if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) 
+                       && hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) {
+                       Hop left = hi.getInput().get(0);
+                       Hop right = hi.getInput().get(1);
+                       
+                       if( HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) && 
right.getParent().size()==1
+                               && 
HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG)
+                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0),
 1d))
+                       {
+                               LinkedHashMap<String,Hop> args = new 
LinkedHashMap<>();
+                               args.put("target", left);
+                               args.put("diag", new LiteralOp(true));
+                               args.put("values", new LiteralOp(true));
+                               Hop hnew = 
HopRewriteUtils.createParameterizedBuiltinOp(
+                                       left, args, ParamBuiltinOp.LOWER_TRI);
+                               HopRewriteUtils.replaceChildReference(parent, 
hi, hnew);
+                               HopRewriteUtils.removeAllChildReferences(right);
+                               
+                               hi = hnew;
+                               LOG.debug("Applied simplifyLowerTriExtraction");
+                       }
+               }
+               return hi;
+       }
+       
        @SuppressWarnings("unchecked")
        private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, 
int pos) 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 62a5d4f..9a3956c 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -183,6 +183,9 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        }
                        hi = simplifyOuterSeqExpand(hop, hi, i);             
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, 
cast=false)
                        hi = simplifyBinaryComparisonChain(hop, hi, i);      
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> 
outer(v1,v2,"!="), 
+                       hi = simplifyCumsumColOrFullAggregates(hi);          
//e.g., colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
+                       hi = simplifyCumsumReverse(hop, hi, i);              
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+                       
                        
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
@@ -1844,6 +1847,48 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                return hi;
        }
        
+       private static Hop simplifyCumsumColOrFullAggregates(Hop hi) {
+               //pattern: colSums(cumsum(X)) -> colSums(X*seq(nrow(X),1))
+               if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col)
+                       || HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, 
Direction.RowCol))
+                       && HopRewriteUtils.isUnary(hi.getInput().get(0), 
OpOp1.CUMSUM)
+                       && hi.getInput().get(0).getParent().size()==1)
+               {
+                       Hop cumsumX = hi.getInput().get(0);
+                       Hop X = cumsumX.getInput().get(0);
+                       Hop mult = HopRewriteUtils.createBinary(X,
+                               HopRewriteUtils.createSeqDataGenOp(X, false), 
OpOp2.MULT);
+                       HopRewriteUtils.replaceChildReference(hi, cumsumX, 
mult);
+                       HopRewriteUtils.removeAllChildReferences(cumsumX);
+                       LOG.debug("Applied simplifyCumsumColOrFullAggregates 
(line "+hi.getBeginLine()+")");
+               }
+               return hi;
+       }
+       
+       private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) {
+               //pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+               if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+                       && HopRewriteUtils.isUnary(hi.getInput().get(0), 
OpOp1.CUMSUM)
+                       && hi.getInput().get(0).getParent().size()==1
+                       && 
HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV)
+                       && 
hi.getInput().get(0).getInput().get(0).getParent().size()==1)
+               {
+                       Hop cumsumX = hi.getInput().get(0);
+                       Hop revX = cumsumX.getInput().get(0);
+                       Hop X = revX.getInput().get(0);
+                       Hop plus = HopRewriteUtils.createBinary(X, 
HopRewriteUtils
+                               .createAggUnaryOp(X, AggOp.SUM, Direction.Col), 
OpOp2.PLUS);
+                       Hop minus = HopRewriteUtils.createBinary(plus,
+                               HopRewriteUtils.createUnary(X, OpOp1.CUMSUM), 
OpOp2.MINUS);
+                       HopRewriteUtils.replaceChildReference(parent, hi, 
minus, pos);
+                       HopRewriteUtils.cleanupUnreferenced(hi, cumsumX, revX);
+                       
+                       hi = minus;
+                       LOG.debug("Applied simplifyCumsumReverse (line 
"+hi.getBeginLine()+")");
+               }
+               return hi;
+       }
+       
        /**
         * NOTE: currently disabled since this rewrite is INVALID in the
         * presence of NaNs (because (NaN!=NaN) is true). 

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
index ab40d7b..1d87c09 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteGPUSpecificOps.java
@@ -176,19 +176,19 @@ public class RewriteGPUSpecificOps extends 
HopRewriteRuleWithPatternMatcher {
        // norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
        // # Compute gradients during training
        // dgamma = util::channel_sums(dout*norm, C, Hin, Win)
-       private static final HopDagPatternMatcher _batchNormDGamma;
-       static {
-               _batchNormDGamma = util_channel_sums(
-                               mult(   leaf("dout", MATRIX).fitsOnGPU(3),
-                                               
bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", 
MATRIX))), 
-                               leaf("ema_var", MATRIX))), leaf("C", SCALAR), 
leaf("HW", SCALAR));
-       }
-       private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi 
-> {
-               LOG.debug("Applied batchNormDGamma rewrite.");
-               Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, 
OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
-                               "ema_mean", "dout", "X", "ema_var");
-               return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
-       };
+//     private static final HopDagPatternMatcher _batchNormDGamma;
+//     static {
+//             _batchNormDGamma = util_channel_sums(
+//                             mult(   leaf("dout", MATRIX).fitsOnGPU(3),
+//                                             
bias_multiply(bias_add(leaf("X", MATRIX), unaryMinus(leaf("ema_mean", 
MATRIX))), 
+//                             leaf("ema_var", MATRIX))), leaf("C", SCALAR), 
leaf("HW", SCALAR));
+//     }
+//     private static final Function<Hop, Hop> _batchNormDGammaReplacer = hi 
-> {
+//             LOG.debug("Applied batchNormDGamma rewrite.");
+//             Hop newHop = HopRewriteUtils.createDnnOp(_batchNormDGamma, 
OpOpDnn.BATCH_NORM2D_BACKWARD_DGAMMA, 
+//                             "ema_mean", "dout", "X", "ema_var");
+//             return HopRewriteUtils.rewireAllParentChildReferences(hi, 
newHop);
+//     };
                
        // Pattern 3:
        private static final HopDagPatternMatcher _batchNormTest;

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
new file mode 100644
index 0000000..da13502
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class RewriteCumulativeAggregatesTest extends AutomatedTestBase 
+{      
+       private static final String TEST_NAME = "RewriteCumulativeAggregates";
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteCumulativeAggregatesTest.class.getSimpleName() + "/";
+       
+       private static final int rows = 1234;
+       private static final int cols = 7;
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testCumAggRewrite1False() {
+               testCumAggRewrite(1, false);
+       }
+       
+       @Test
+       public void testCumAggRewrite1True() {
+               testCumAggRewrite(1, true);
+       }
+       
+       @Test
+       public void testCumAggRewrite2False() {
+               testCumAggRewrite(2, false);
+       }
+       
+       @Test
+       public void testCumAggRewrite2True() {
+               testCumAggRewrite(2, true);
+       }
+       
+       @Test
+       public void testCumAggRewrite3False() {
+               testCumAggRewrite(3, false);
+       }
+       
+       @Test
+       public void testCumAggRewrite3True() {
+               testCumAggRewrite(3, true);
+       }
+       
+       @Test
+       public void testCumAggRewrite4False() {
+               testCumAggRewrite(4, false);
+       }
+       
+       @Test
+       public void testCumAggRewrite4True() {
+               testCumAggRewrite(4, true);
+       }
+       
+       private void testCumAggRewrite(int num, boolean rewrites)
+       {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{ "-stats", "-args",
+                               input("A"), String.valueOf(num), output("R") };
+                       rCmd = getRCmd(inputDir(), String.valueOf(num), 
expectedDir());
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       
+                       //generate input data
+                       double[][] A = getRandomMatrix((num==4)?1:rows,
+                               (num==1)?rows:cols, -1, 1, 0.9, 7); 
+                       writeInputMatrixWithMTD("A", A, true);
+                       
+                       //run performance tests
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+                       
+                       //compare matrices
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("R");
+                       HashMap<CellIndex, Double> rfile = 
readRMatrixFromFS("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, 1e-7, 
"Stat-DML", "Stat-R");
+                       
+                       //check applied rewrites
+                       if( rewrites )
+                               
Assert.assertTrue(!heavyHittersContainsString((num==2) ? "rev" : "ucumk+"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R 
b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
new file mode 100644
index 0000000..f8a8576
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+num = as.integer(args[2]);
+
+#note: cumsum and rev only over vectors
+if( num == 1 ) {
+  R = lower.tri(X,diag=TRUE) * X;
+} else if( num == 2 ) {
+  A = X[seq(nrow(X),1),]
+  R = apply(A, 2, cumsum);
+  R = R[seq(nrow(X),1),]
+} else if( num == 3 ) {
+  R = t(as.matrix(colSums(apply(X, 2, cumsum))));
+} else if( num == 4 ) {
+  R = X;
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/9a1f64b4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml 
b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
new file mode 100644
index 0000000..f4c3486
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function( Matrix[Double] A ) return( Matrix[Double] B )
+{
+   for( i in 1:1 ) {
+     continue = TRUE;
+     if( sum(A)<0 ) {
+        continue = FALSE;
+     }
+     iter = 0;
+     if( continue ) {
+        iter = iter+1;
+     }
+     B = A+iter;
+   }
+}
+
+X = read($1);
+
+if( $2 == 1 )
+  R = X * cumsum(diag(matrix(1,nrow(X),1)));
+else if( $2 == 2 )
+  R = rev(cumsum(rev(X)));
+else if( $2 == 3 )
+  R = colSums(cumsum(X));
+else if( $2 == 4 )
+  R = cumsum(X);
+
+write(R, $3);
\ No newline at end of file

Reply via email to