[SYSTEMML-1659] New simplification rewrite 'aggregate elimination'

This new static algebraic simplification rewrite removes unnecessary
row- or column-wise aggregates that are fed directly into a full
aggregate over all rows and columns. For example, sum(rowSums(X)), as it
appears in nn-cross_entropy_loss::forward, is now rewritten to sum(X).


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/50d211ba
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/50d211ba
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/50d211ba

Branch: refs/heads/master
Commit: 50d211baa91e6a74b32cd8c1780758608d33c7c8
Parents: a68648d
Author: Matthias Boehm <[email protected]>
Authored: Fri Jun 2 21:47:59 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 3 10:48:31 2017 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   |  28 ++++
 .../test/integration/AutomatedTestBase.java     |   2 +-
 .../misc/RewriteEliminateAggregatesTest.java    | 136 +++++++++++++++++++
 .../functions/misc/RewriteEliminateAggregate.R  |  41 ++++++
 .../misc/RewriteEliminateAggregate.dml          |  41 ++++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 6 files changed, 248 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index a3db317..74f5488 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -148,6 +148,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
                        hi = simplifyBushyBinaryOperation(hop, hi, i);       
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
                        hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
+                       hi = removeUnnecessaryAggregates(hi);                
//e.g., sum(rowSums(X)) -> sum(X)
                        hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
                        hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
@@ -817,6 +818,33 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                return hi;
        }
        
+       private Hop removeUnnecessaryAggregates(Hop hi)
+       {
+               //sum(rowSums(X)) -> sum(X), sum(colSums(X)) -> sum(X)
+               //min(rowMins(X)) -> min(X), min(colMins(X)) -> min(X)
+               //max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X)
+               //sum(rowSums(X^2)) -> sum(X^2), sum(colSums(X^2)) -> sum(X^2)
+               if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof 
AggUnaryOp
+                       && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+                       && hi.getInput().get(0).getParent().size()==1 )
+               {
+                       AggUnaryOp au1 = (AggUnaryOp) hi;
+                       AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0);
+                       if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM 
|| au2.getOp()==AggOp.SUM_SQ)) 
+                               || (au1.getOp()==AggOp.MIN && 
au2.getOp()==AggOp.MIN)
+                               || (au1.getOp()==AggOp.MAX && 
au2.getOp()==AggOp.MAX) )
+                       {
+                               Hop input = au2.getInput().get(0);
+                               HopRewriteUtils.removeAllChildReferences(au2);
+                               HopRewriteUtils.replaceChildReference(au1, au2, 
input);
+                               
+                               LOG.debug("Applied removeUnnecessaryAggregates 
(line "+hi.getBeginLine()+").");
+                       }
+               }
+               
+               return hi;
+       }
+       
        private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, 
int pos ) 
                throws HopsException
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 0e56655..7b93211 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1818,7 +1818,7 @@ public abstract class AutomatedTestBase
                for( String opcode : Statistics.getCPHeavyHitterOpCodes())
                        for( String s : str )
                                if(opcode.contains(s))
-                               return true;
+                                       return true;
                return false;
        }
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java
new file mode 100644
index 0000000..741ef31
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEliminateAggregatesTest.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class RewriteEliminateAggregatesTest extends AutomatedTestBase 
+{
+       private static final String TEST_NAME = "RewriteEliminateAggregate";
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteEliminateAggregatesTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }) );
+       }
+       
+       @Test
+       public void testEliminateSumSumNoRewrite() {
+               testRewriteEliminateAggregate(1, false);
+       }
+       
+       @Test
+       public void testEliminateMinMinNoRewrite() {
+               testRewriteEliminateAggregate(2, false);
+       }
+       
+       @Test
+       public void testEliminateMaxMaxNoRewrite() {
+               testRewriteEliminateAggregate(3, false);
+       }
+       
+       @Test
+       public void testEliminateSumSqSumNoRewrite() {
+               testRewriteEliminateAggregate(4, false);
+       }
+       
+       @Test
+       public void testEliminateMinSumNoRewrite() {
+               testRewriteEliminateAggregate(5, false);
+       }
+       
+       @Test
+       public void testEliminateSumSumRewrite() {
+               testRewriteEliminateAggregate(1, true);
+       }
+       
+       @Test
+       public void testEliminateMinMinRewrite() {
+               testRewriteEliminateAggregate(2, true);
+       }
+       
+       @Test
+       public void testEliminateMaxMaxRewrite() {
+               testRewriteEliminateAggregate(3, true);
+       }
+       
+       @Test
+       public void testEliminateSumSqSumRewrite() {
+               testRewriteEliminateAggregate(4, true);
+       }
+       
+       @Test
+       public void testEliminateMinSumRewrite() {
+               testRewriteEliminateAggregate(5, true);
+       }
+       
+       private void testRewriteEliminateAggregate(int type, boolean rewrites)
+       {       
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try
+               {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{ "-stats","-args", 
+                               input("A"), String.valueOf(type), 
output("Scalar") };
+                       
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = getRCmd(inputDir(), String.valueOf(type), 
expectedDir());                        
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+                       
+                       //generate actual dataset 
+                       double[][] A = getRandomMatrix(123, 12, 0, 1, 0.9, 7); 
+                       writeInputMatrixWithMTD("A", A, true);
+                       
+                       //run test
+                       runTest(true, false, null, -1); 
+                       runRScript(true); 
+                       
+                       //compare scalars 
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLScalarFromHDFS("Scalar");
+                       HashMap<CellIndex, Double> rfile  = 
readRScalarFromFS("Scalar");
+                       TestUtils.compareScalars(dmlfile.toString(), 
rfile.toString());
+                       
+                       //check for applied rewrites
+                       if( rewrites ) {
+                               Assert.assertEquals(type==5, 
+                                       heavyHittersContainsSubString("uar", 
"uac"));
+                       } 
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }       
+       }       
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/scripts/functions/misc/RewriteEliminateAggregate.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEliminateAggregate.R 
b/src/test/scripts/functions/misc/RewriteEliminateAggregate.R
new file mode 100644
index 0000000..6848443
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEliminateAggregate.R
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+type = args[2]
+
+if( type==1 ) {
+  agg = sum(rowSums(A));
+} else if( type==2 ) {
+  agg = min(rowMins(A));
+} else if( type==3 ) {
+  agg = max(rowMaxs(A));
+} else if( type==4 ) {
+  agg = sum(rowSums(A^2));
+} else if( type==5 ) {
+  agg = sum(rowMins(A));
+}
+
+write(agg, paste(args[3], "Scalar",sep=""))
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml 
b/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml
new file mode 100644
index 0000000..e00199d
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEliminateAggregate.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+type = $2;
+
+if( type==1 ) {
+  agg = sum(rowSums(A));
+}
+else if( type==2 ) {
+  agg = min(rowMins(A));
+}
+else if( type==3 ) {
+  agg = max(rowMaxs(A));
+}
+else if( type==4 ) {
+  agg = sum(rowSums(A^2));
+}
+else if( type==5 ) {
+  agg = sum(rowMins(A));
+}
+
+write(agg, $3);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/50d211ba/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 8a06322..7da786d 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -49,6 +49,7 @@ import org.junit.runners.Suite;
        ReadAfterWriteTest.class,
        RewriteCSETransposeScalarTest.class,
        RewriteCTableToRExpandTest.class,
+       RewriteEliminateAggregatesTest.class,
        RewriteFusedRandTest.class,
        RewriteLoopVectorization.class,
        RewriteMatrixMultChainOptTest.class,

Reply via email to