Repository: incubator-systemml
Updated Branches:
  refs/heads/master 1385cf1ca -> 201238fd3


[SYSTEMML-1254] New rewrite 'pushdown CSE transpose-scalar', incl. tests

This new rewrite pushes a transpose down below a matrix-scalar
binary operation (except quantile and centralMoment) in order to reuse
an existing transpose as a common subexpression, e.g.,
a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2, so that t(X) is computed only once.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0e6411da
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0e6411da
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0e6411da

Branch: refs/heads/master
Commit: 0e6411dada77870ae29049288b1789313a35a9f6
Parents: 1385cf1
Author: Matthias Boehm <[email protected]>
Authored: Tue Feb 14 17:11:59 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Feb 15 10:49:19 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  19 +++-
 .../RewriteAlgebraicSimplificationStatic.java   |  36 +++++++
 .../java/org/apache/sysml/utils/Statistics.java |   7 +-
 .../misc/RewriteCSETransposeScalarTest.java     | 104 +++++++++++++++++++
 .../misc/RewriteCSETransposeScalarMult.dml      |  31 ++++++
 .../misc/RewriteCSETransposeScalarPow.dml       |  31 ++++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 7 files changed, 223 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 50501dc..d3be09d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -711,17 +711,28 @@ public class HopRewriteUtils
                return ret;
        }
        
-       public static boolean isTransposeOperation(Hop hop)
-       {
+       public static boolean isTransposeOperation(Hop hop) {
                return (hop instanceof ReorgOp && 
((ReorgOp)hop).getOp()==ReOrgOp.TRANSPOSE);
        }
        
-       public static boolean isTransposeOfItself(Hop hop1, Hop hop2)
-       {
+       public static boolean containsTransposeOperation(ArrayList<Hop> hops) {
+               boolean ret = false;
+               for( Hop hop : hops )
+                       ret |= isTransposeOperation(hop);
+               return ret;
+       }
+       
+       public static boolean isTransposeOfItself(Hop hop1, Hop hop2) {
                return hop1 instanceof ReorgOp && 
((ReorgOp)hop1).getOp()==ReOrgOp.TRANSPOSE && hop1.getInput().get(0) == hop2
                        || hop2 instanceof ReorgOp && 
((ReorgOp)hop2).getOp()==ReOrgOp.TRANSPOSE && hop2.getInput().get(0) == hop1;   
  
        }
        
+       public static boolean isBinaryMatrixScalarOperation(Hop hop) {
+               return hop instanceof BinaryOp && 
+                       ((hop.getInput().get(0).getDataType().isMatrix() && 
hop.getInput().get(1).getDataType().isScalar())
+                       ||(hop.getInput().get(1).getDataType().isMatrix() && 
hop.getInput().get(0).getDataType().isScalar()));
+       }
+       
        public static boolean isNonZeroIndicator(Hop pred, Hop hop )
        {
                if( pred instanceof BinaryOp && 
((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ad6a4da..41459b4 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -148,6 +148,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyUnaryAggReorgOperation(hop, hi, i);     
//e.g., sum(t(X)) -> sum(X)
                        hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
+                       hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
                        hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lamda*X) -> lamda*sum(X)
                        hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
                        hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
@@ -942,6 +943,41 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                
                return hi;
        }
+       
+       private Hop pushdownCSETransposeScalarOperation( Hop parent, Hop hi, 
int pos )
+       {
+               // a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
+               // probed at root node of b in above example
+               // (with support for left or right scalar operations)
+               if( HopRewriteUtils.isTransposeOperation(hi) && 
hi.getParent().size()==1
+                       && 
HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
+                       && hi.getInput().get(0).getParent().size()==1) 
+               {
+                       int Xpos = 
hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1;
+                       Hop X = hi.getInput().get(0).getInput().get(Xpos);
+                       BinaryOp binary = (BinaryOp) hi.getInput().get(0);
+                       
+                       if( 
HopRewriteUtils.containsTransposeOperation(X.getParent()) 
+                               && !HopRewriteUtils.isValidOp(binary.getOp(), 
new OpOp2[]{OpOp2.CENTRALMOMENT, OpOp2.QUANTILE})) 
+                       {
+                               //clear existing wiring
+                               
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);     
+                               HopRewriteUtils.removeChildReference(hi, 
binary);
+                               HopRewriteUtils.removeChildReference(binary, X);
+                               
+                               //re-wire operators
+                               HopRewriteUtils.addChildReference(parent, 
binary, pos);
+                               HopRewriteUtils.addChildReference(binary, hi, 
Xpos);
+                               HopRewriteUtils.addChildReference(hi, X);
+                               //note: common subexpression later eliminated 
by dedicated rewrite
+               
+                               hi = binary;
+                               LOG.debug("Applied 
pushdownCSETransposeScalarOperation (line "+hi.getBeginLine()+").");
+                       }       
+               }
+               
+               return hi;
+       }
 
        private Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) throws 
HopsException {
                //pattern:  sum(lamda*X) -> lamda*sum(X)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index cf9b5fb..e371d9c 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -464,10 +464,13 @@ public class Statistics
                _cpInstCounts.put(key, newCnt);
        }
        
-       public static Set<String> getCPHeavyHitterOpCodes()
-       {
+       public static Set<String> getCPHeavyHitterOpCodes() {
                return _cpInstTime.keySet();
        }
+       
+       public static long getCPHeavyHitterCount(String opcode) {
+               return _cpInstCounts.get(opcode);
+       }
 
        @SuppressWarnings("unchecked")
        public static String getHeavyHitters( int num )

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
new file mode 100644
index 0000000..61daf38
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCSETransposeScalarTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * 
+ * 
+ */
+public class RewriteCSETransposeScalarTest extends AutomatedTestBase 
+{      
+       private static final String TEST_NAME1 = 
"RewriteCSETransposeScalarPow"; //right scalar
+       private static final String TEST_NAME2 = 
"RewriteCSETransposeScalarMult"; //left scalar
+       
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteCSETransposeScalarTest.class.getSimpleName() + "/";
+       
+       private static final int rows = 1932;
+       private static final int cols = 14;
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testRewriteCSETransposePow()  {
+               testRewriteCSETransposeScalar( TEST_NAME1, true );
+       }
+       
+       @Test
+       public void testRewriteCSETransposePowNoRewrite()  {
+               testRewriteCSETransposeScalar( TEST_NAME1, false );
+       }
+       
+       @Test
+       public void testRewriteCSETransposeMult()  {
+               testRewriteCSETransposeScalar( TEST_NAME2, true );
+       }
+       
+       @Test
+       public void testRewriteCSETransposeMultNoRewrite()  {
+               testRewriteCSETransposeScalar( TEST_NAME2, false );
+       }
+       
+       /**
+        * 
+        * @param testname
+        * @param et
+        */
+       private void testRewriteCSETransposeScalar( String testname, boolean 
rewrites )
+       {       
+               boolean rewritesOld = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+               
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats", "-args", 
String.valueOf(rows), 
+                                       String.valueOf(cols), output("R") };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       //run performance tests
+                       runTest(true, false, null, -1); 
+                       
+                       //compare output  
+                       double ret = TestUtils.readDMLScalar(output("R"));
+                       Assert.assertEquals("Wrong result, expected: 
"+(rows*cols), new Double(rows*cols), new Double(ret));
+                       Assert.assertEquals(new Long(rewrites?1:2), new 
Long(Statistics.getCPHeavyHitterCount("r'")));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewritesOld;
+               }
+       }       
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml 
b/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
new file mode 100644
index 0000000..07e67dc
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCSETransposeScalarMult.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=$1, cols=$2, min=1, max=10);
+if(1==1){}
+
+a = t(X);
+b = t(2*X);
+
+if(1==1){}
+
+R = sum(2*a == b);
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml 
b/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
new file mode 100644
index 0000000..f47c227
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCSETransposeScalarPow.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=$1, cols=$2, min=1, max=10);
+if(1==1){}
+
+a = t(X);
+b = t(X^2);
+
+if(1==1){}
+
+R = sum(a^2 == b);
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0e6411da/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 32b5f7b..1b3478d 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -46,6 +46,7 @@ import org.junit.runners.Suite;
        PrintExpressionTest.class,
        PrintMatrixTest.class,
        ReadAfterWriteTest.class,
+       RewriteCSETransposeScalarTest.class,
        RewriteFusedRandTest.class,
        RewriteLoopVectorization.class,
        RewritePushdownSumBinaryMult.class,

Reply via email to